bazaar-research commited on
Commit
fb11af9
·
verified ·
1 Parent(s): 5f59008

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +46 -0
  2. .gitignore +222 -0
  3. .gitmodules +6 -0
  4. .vscode/launch.json +88 -0
  5. LEGAL.md +7 -0
  6. LICENSE +202 -0
  7. Makefile +21 -0
  8. README.md +330 -0
  9. assets/LingBot-VLA.pdf +3 -0
  10. assets/PaliGemmaPI.png +3 -0
  11. assets/QwenPI.png +3 -0
  12. assets/QwenPI_PaliGemmaPI.png +3 -0
  13. assets/Teaser.png +3 -0
  14. assets/exp-gm-100.png +3 -0
  15. assets/exp-robotwin.png +3 -0
  16. assets/norm_stats/libero.json +280 -0
  17. assets/norm_stats/robotwin_50.json +229 -0
  18. assets/norm_stats/robotwin_5_customized.json +201 -0
  19. assets/norm_stats/robotwin_all_new.json +229 -0
  20. assets/scale_ps.png +3 -0
  21. assets/scale_sr.png +3 -0
  22. configs/norm/robotwin_5.yaml +12 -0
  23. configs/vla/robotwin_load20000h.yaml +42 -0
  24. configs/vla/robotwin_load20000h_depth.yaml +68 -0
  25. deploy/__init__.py +0 -0
  26. deploy/image_tools.py +58 -0
  27. deploy/lingbot_robotwin_policy.py +506 -0
  28. deploy/lingbot_robotwin_policy_rep.py +491 -0
  29. deploy/msgpack_numpy.py +57 -0
  30. deploy/websocket_client_policy.py +88 -0
  31. deploy/websocket_policy_server.py +89 -0
  32. docker/Dockerfile +34 -0
  33. docs/Makefile +20 -0
  34. docs/README.md +19 -0
  35. docs/conf.py +66 -0
  36. docs/config/config.md +96 -0
  37. docs/examples/qwen2vl.rst +2 -0
  38. docs/examples/qwen3_moe.md +125 -0
  39. docs/index.rst +2 -0
  40. docs/requirements-docs.txt +9 -0
  41. docs/start/start.rst +2 -0
  42. experiment/libero/README.md +18 -0
  43. experiment/libero/libero/libero_utils.py +112 -0
  44. experiment/libero/libero/req.txt +6 -0
  45. experiment/libero/libero/run_libero_eval.py +300 -0
  46. experiment/libero/robot_utils.py +84 -0
  47. experiment/robotwin/README.md +85 -0
  48. lingbotvla/__init__.py +16 -0
  49. lingbotvla/checkpoint/__init__.py +25 -0
  50. lingbotvla/checkpoint/checkpointer.py +340 -0
.gitattributes CHANGED
@@ -33,3 +33,49 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/LingBot-VLA.pdf filter=lfs diff=lfs merge=lfs -text
37
+ assets/PaliGemmaPI.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/QwenPI.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/QwenPI_PaliGemmaPI.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/Teaser.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/exp-gm-100.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/exp-robotwin.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/scale_ps.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/scale_sr.png filter=lfs diff=lfs merge=lfs -text
45
+ lingbotvla/models/vla/vision_models/MoGe/assets/normal_comaprison.jpg filter=lfs diff=lfs merge=lfs -text
46
+ lingbotvla/models/vla/vision_models/MoGe/assets/overview_simplified.png filter=lfs diff=lfs merge=lfs -text
47
+ lingbotvla/models/vla/vision_models/MoGe/assets/panorama_pipeline.png filter=lfs diff=lfs merge=lfs -text
48
+ lingbotvla/models/vla/vision_models/MoGe/example_images/01_HouseIndoor.jpg filter=lfs diff=lfs merge=lfs -text
49
+ lingbotvla/models/vla/vision_models/MoGe/example_images/02_Office.jpg filter=lfs diff=lfs merge=lfs -text
50
+ lingbotvla/models/vla/vision_models/MoGe/example_images/03_Traffic.jpg filter=lfs diff=lfs merge=lfs -text
51
+ lingbotvla/models/vla/vision_models/MoGe/example_images/05_Mountain.jpg filter=lfs diff=lfs merge=lfs -text
52
+ lingbotvla/models/vla/vision_models/MoGe/example_images/06_MaitreyaBuddha.png filter=lfs diff=lfs merge=lfs -text
53
+ lingbotvla/models/vla/vision_models/MoGe/example_images/07_Breads.jpg filter=lfs diff=lfs merge=lfs -text
54
+ lingbotvla/models/vla/vision_models/MoGe/example_images/08_CatGirl.png filter=lfs diff=lfs merge=lfs -text
55
+ lingbotvla/models/vla/vision_models/MoGe/example_images/09_Restaurant.jpg filter=lfs diff=lfs merge=lfs -text
56
+ lingbotvla/models/vla/vision_models/MoGe/example_images/10_MedievalVillage.jpg filter=lfs diff=lfs merge=lfs -text
57
+ lingbotvla/models/vla/vision_models/MoGe/example_images/panorama/Braunschweig_Panoram.jpg filter=lfs diff=lfs merge=lfs -text
58
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/attention/fig-attention-vis.png filter=lfs diff=lfs merge=lfs -text
59
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/dataset/diversity_figure.png filter=lfs diff=lfs merge=lfs -text
60
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-divided.jpg filter=lfs diff=lfs merge=lfs -text
61
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-full.jpg filter=lfs diff=lfs merge=lfs -text
62
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_grasp/fig-grasp-demo.png filter=lfs diff=lfs merge=lfs -text
63
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-dynamic-tracking.png filter=lfs diff=lfs merge=lfs -text
64
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-scene-tracking-crop.png filter=lfs diff=lfs merge=lfs -text
65
+ lingbotvla/models/vla/vision_models/lingbot-depth/assets/teaser/teaser-crop.png filter=lfs diff=lfs merge=lfs -text
66
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth.png filter=lfs diff=lfs merge=lfs -text
67
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/rgb.png filter=lfs diff=lfs merge=lfs -text
68
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/raw_depth.png filter=lfs diff=lfs merge=lfs -text
69
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/rgb.jpg filter=lfs diff=lfs merge=lfs -text
70
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/raw_depth.png filter=lfs diff=lfs merge=lfs -text
71
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/rgb.png filter=lfs diff=lfs merge=lfs -text
72
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/raw_depth.png filter=lfs diff=lfs merge=lfs -text
73
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/rgb.jpg filter=lfs diff=lfs merge=lfs -text
74
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/raw_depth.png filter=lfs diff=lfs merge=lfs -text
75
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/rgb.png filter=lfs diff=lfs merge=lfs -text
76
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/raw_depth.png filter=lfs diff=lfs merge=lfs -text
77
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/rgb.png filter=lfs diff=lfs merge=lfs -text
78
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/raw_depth.png filter=lfs diff=lfs merge=lfs -text
79
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/raw_depth.png filter=lfs diff=lfs merge=lfs -text
80
+ lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/rgb.jpg filter=lfs diff=lfs merge=lfs -text
81
+ lingbotvla/models/vla/vision_models/lingbot-depth/tech-report.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # log
210
+ *log.txt
211
+ ossutil_output/
212
+ .sumi/
213
+ env.sh
214
+ pids_qwenpi.txt
215
+ run.sh
216
+ start_multi_eval.sh
217
+ trash/
218
+ eval.sh
219
+
220
+ # xwc
221
+ output/
222
+ wandb/
.gitmodules ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [submodule "lingbotvla/models/vla/vision_models/lingbot-depth"]
2
+ path = lingbotvla/models/vla/vision_models/lingbot-depth
3
+ url = https://github.com/Robbyant/lingbot-depth
4
+ [submodule "lingbotvla/models/vla/vision_models/MoGe"]
5
+ path = lingbotvla/models/vla/vision_models/MoGe
6
+ url = https://github.com/microsoft/MoGe.git
.vscode/launch.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "deploy lingbotvla (模块方式)",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "module": "deploy.lingbot_robotwin_policy",
12
+ "console": "integratedTerminal",
13
+ "cwd": "${workspaceFolder}",
14
+ "justMyCode": false,
15
+ "args": [
16
+ "--model_path",
17
+ "output/ori_4/checkpoints/global_step_12850/hf_ckpt",
18
+ "--use_length",
19
+ "50",
20
+ "--chunk_ret",
21
+ "true",
22
+ "--debug_infer_once"
23
+ ],
24
+ "env": {
25
+ "CUDA_VISIBLE_DEVICES": "0",
26
+ "QWEN25_PATH": "/group/ossdphi_algo_scratch_11/weicxu/huggingface_cache/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3"
27
+ }
28
+ },
29
+ {
30
+ "name": "example_call_robotwin_server",
31
+ "type": "debugpy",
32
+ "request": "launch",
33
+ "module": "deploy.example_call_robotwin_server",
34
+ "console": "integratedTerminal",
35
+ "cwd": "${workspaceFolder}",
36
+ "justMyCode": false,
37
+ "args": [
38
+ "--host",
39
+ "127.0.0.1",
40
+ "--port",
41
+ "8006"
42
+ ],
43
+ "env": {
44
+ "CUDA_VISIBLE_DEVICES": "0"
45
+ }
46
+ },
47
+ {
48
+ "name": "train lingbotvla",
49
+ "type": "debugpy",
50
+ "request": "launch",
51
+ "program": "${file}",
52
+ "console": "integratedTerminal",
53
+ "justMyCode": false,
54
+ "args": [
55
+ "configs/vla/robotwin_load20000h.yaml",
56
+ "--model.model_path",
57
+ "robbyant/lingbot-vla-4b",
58
+ "--data.train_path",
59
+ "mixed_robotwin_5tasks_repo_0.1.0",
60
+ "--train.output_dir",
61
+ "output/",
62
+ "--model.tokenizer_path",
63
+ "Qwen/Qwen2.5-VL-3B-Instruct",
64
+ "--train.micro_batch_size",
65
+ "1",
66
+ "--train.global_batch_size",
67
+ "1",
68
+ "--train.enable_full_shard",
69
+ "true",
70
+ "--train.use_compile",
71
+ "false",
72
+ "--train.enable_fp32",
73
+ "false",
74
+ "--train.freeze_vision_encoder",
75
+ "true",
76
+ ],
77
+ "env": {
78
+ "CUDA_VISIBLE_DEVICES": "2",
79
+ "LOCAL_RANK": "0",
80
+ "RANK": "0",
81
+ "WORLD_SIZE": "1",
82
+ "MASTER_ADDR": "localhost",
83
+ "MASTER_PORT": "29500",
84
+ "PYDEVD_USE_SYS_MONITORING": "0"
85
+ }
86
+ }
87
+ ]
88
+ }
LEGAL.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Legal Disclaimer
2
+
3
+ Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
4
+
5
+ 法律免责声明
6
+
7
+ 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [2026] [Robbyant Team]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: build commit quality style test
2
+
3
+ check_dirs := tasks tests lingbot docs setup.py
4
+
5
+ build:
6
+ python3 setup.py sdist bdist_wheel
7
+
8
+ commit:
9
+ pre-commit install
10
+ pre-commit run --all-files
11
+
12
+ quality:
13
+ ruff check $(check_dirs)
14
+ ruff format --check $(check_dirs)
15
+
16
+ style:
17
+ ruff check $(check_dirs) --fix
18
+ ruff format $(check_dirs)
19
+
20
+ test:
21
+ pytest tests/
README.md ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">LingBot-VLA: A Pragmatic VLA Foundation Model</h1>
2
+
3
+ <p align="center">
4
+ <a href="assets/LingBot-VLA.pdf"><img src="https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv"></a>
5
+ <a href="https://technology.robbyant.com/lingbot-vla"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
6
+ <a href="https://huggingface.co/collections/robbyant/lingbot-vla"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=yellow"></a>
7
+ <a href="https://modelscope.cn/collections/Robbyant/LingBot-VLA"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%96%20Model&message=ModelScope&color=purple"></a>
8
+ <a href="https://huggingface.co/datasets/robbyant/gm100"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20GM-100&message=HuggingFace&color=yellow"></a>
9
+ <a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache--2.0-green"></a>
10
+ </p>
11
+
12
+
13
+ <p align="center">
14
+ <img src="assets/Teaser.png" width="100%">
15
+ </p>
16
+
17
+ ## 🥳 We are excited to introduce **LingBot-VLA**, a pragmatic Vision-Language-Action foundation model.
18
+
19
+ **LingBot-VLA** has focused on **Pragmatic**:
20
+ - **Large-scale Pre-training Data**: 20,000 hours of real-world
21
+ data from 9 popular dual-arm robot configurations.
22
+ <p align="center">
23
+ <img src="assets/scale_sr.png" width="45%" style="margin: 0 10px;">
24
+ <img src="assets/scale_ps.png" width="45%" style="margin: 0 10px;">
25
+ </p>
26
+
27
+ - **Strong Performance**: Achieve clear superiority over competitors on simulation and real-world benchmarks.
28
+ - **Training Efficiency**: Represent a 1.5 ∼ 2.8× (depending on the relied VLM base model) speedup over existing VLA-oriented codebases.
29
+
30
+ ## 🚀 News
31
+ - **[2026-01-27]** LingBot-VLA Technical Report is available on Arxiv.
32
+ - **[2026-01-27]** Weights and code released!
33
+
34
+
35
+ ---
36
+
37
+
38
+ ## 🛠️ Installation
39
+ Requirements
40
+ - Python 3.12.3
41
+ - Pytorch 2.8.0
42
+ - CUDA 12.8
43
+
44
+ ```bash
45
+ # Install Lerobot
46
+ pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128
47
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/huggingface/lerobot.git
48
+ cd lerobot
49
+ git checkout 0cf864870cf29f4738d3ade893e6fd13fbd7cdb5
50
+ pip install -e .
51
+ # Install flash attention
52
+ pip install /path/to/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
53
+
54
+ # Clone the repository
55
+ git clone https://github.com/robbyant/lingbot-vla.git
56
+ cd lingbot-vla/
57
+ git submodule update --remote --recursive
58
+ pip install -e .
59
+ pip install -r requirements.txt
60
+ # Install LingBot-Depth dependency
61
+ cd ./lingbotvla/models/vla/vision_models/lingbot-depth/
62
+ pip install -e . --no-deps
63
+ cd ../MoGe
64
+ pip install -e .
65
+ ```
66
+
67
+ ---
68
+
69
+ ## 📦 Model Download
70
+ We release LingBot-VLA pre-trained weights in two configurations: depth-free version and a depth-distillated version.
71
+ - **Pretrained Checkpoints for Post-Training with and without depth**
72
+
73
+ | Model Name | Huggingface | ModelScope | Description |
74
+ | :--- | :---: | :---: | :---: |
75
+ | LingBot-VLA-4B &nbsp; | [🤗 lingbot-vla-4b](https://huggingface.co/robbyant/lingbot-vla-4b) | [🤖 lingbot-vla-4b](https://modelscope.cn/models/Robbyant/lingbot-vla-4b) | LingBot-VLA *w/o* Depth|
76
+ | LingBot-VLA-4B-Depth | [🤗 lingbot-vla-4b-depth](https://huggingface.co/robbyant/lingbot-vla-4b-depth) | [🤖 lingbot-vla-4b-depth](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-depth) | LingBot-VLA *w/* Depth |
77
+
78
+
79
+
80
+
81
+ To train LingBot with our codebase, weights from [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct), [MoGe-2-vitb-normal](https://huggingface.co/Ruicheng/moge-2-vitb-normal), and [LingBot-Depth](https://huggingface.co/robbyant/lingbot-depth-pretrain-vitl-14) also need to be prepared.
82
+ - **Run Command**:
83
+ ```bash
84
+ python3 scripts/download_hf_model.py --repo_id robbyant/lingbot-vla-4b --local_dir lingbot-vla-4b
85
+ ```
86
+ ---
87
+
88
+ ## 💻 Post-Training Example
89
+
90
+ - **Data Preparation**:
91
+ Please follow [RoboTwin2.0 Preparation](experiment/robotwin/README.md)
92
+
93
+ - **Training Configuration**:
94
+ We provide the mixed post-training configuration in five RoboTwin 2.0 tasks ("open_microwave" "click_bell" "stack_blocks_three" "place_shoe" "put_object_cabinet").
95
+ <details>
96
+ <summary><b>Click to expand full YAML configuration</b></summary>
97
+
98
+ ```yaml
99
+ model:
100
+ model_path: "path/to/lingbot_vla_checkpoint" # Path to pre-trained VLA foundation model (w/o or w depth)
101
+ tokenizer_path: "path/to/Qwen2.5-VL-3B-Instruct"
102
+ post_training: true # Enable post-training/fine-tuning mode
103
+ adanorm_time: true
104
+ old_adanorm: true
105
+
106
+ data:
107
+ datasets_type: vla
108
+ data_name: robotwin_5_new
109
+ train_path: "path/to/lerobot_merged_data" # merged data from 5 robotwin2.0 tasks
110
+ num_workers: 8
111
+ norm_type: bounds_99_woclip
112
+ norm_stats_file: assets/norm_stats/robotwin_50.json # file of normalization statistics
113
+
114
+ train:
115
+ output_dir: "path/to/output"
116
+ loss_type: L1_fm # we apply L1 flow-matching loss in robotwin2.0 finetuning
117
+ data_parallel_mode: fsdp2 # Use Fully Sharded Data Parallel (PyTorch FSDP2)
118
+ enable_full_shard: false # Don't apply reshare after forward in FSDP2
119
+ module_fsdp_enable: true
120
+ use_compile: true # Acceleration via torch.compile
121
+ use_wandb: false
122
+ rmpad: false
123
+ rmpad_with_pos_ids: false
124
+ ulysses_parallel_size: 1
125
+ freeze_vision_encoder: false # ViT need to be optimized
126
+ tokenizer_max_length: 24 # token numbers of task prompt
127
+ action_dim: 14 # Target robot action space dimension
128
+ max_action_dim: 75 # action dim in LingBot-VLA
129
+ max_state_dim: 75 # state dim in LingBot-VLA
130
+ lr: 1.0e-4
131
+ lr_decay_style: constant
132
+ num_train_epochs: 69 # finetuning 20k step
133
+ micro_batch_size: 32
134
+ global_batch_size: 256
135
+ max_steps: 220000
136
+ ckpt_manager: dcp
137
+ save_steps: 220000
138
+ save_epochs: 69
139
+ enable_fp32: true
140
+ enable_resume: true # resume training automatically
141
+ # ===========================================================================
142
+ # Depth Injection Parameters
143
+ # (Required only for LingBot-VLA with Depth. Ignore if not using depth)
144
+ # ===========================================================================
145
+ align_params:
146
+ mode: 'query' # Query-based distillation
147
+ num_task_tokens: 8 # Number of learnable task-specific tokens
148
+ use_image_tokens: True
149
+ use_task_tokens: False
150
+ use_text_tokens: False
151
+ use_contrastive: True
152
+ contrastive_loss_weight: 0.3
153
+ depth_loss_weight: 0.002
154
+ llm: # VLM Projection Settings
155
+ dim_out: 2048
156
+ image_token_size: 8
157
+ image_input_size: 224
158
+ depth:
159
+ model_type: MoRGBD
160
+ moge_path: /"path/to/moGe-2-vitb-normal"
161
+ morgbd_path: "path/to/LingBot-Depth"
162
+ num_layers: 1
163
+ num_heads: 4
164
+ dim_head: 32
165
+ ff_mult: 1
166
+ num_backbone_tokens: 256
167
+ token_size: 16
168
+ dim_out: 1024
169
+ input_size: 224
170
+ visual_steps: 10000
171
+ visual_dir: "path/to/output/images" # visualization path of depth distillation
172
+ ```
173
+ </details>
174
+
175
+ - **Run Command**:
176
+ ```bash
177
+ # without detph
178
+ bash train.sh tasks/vla/train_lingbotvla.py ./configs/vla/robotwin_load20000h.yaml --model.model_path /path/to/LingBot-VLA --data.train_path path/to/mixed_robotwin_5tasks --train.output_dir /path/to/lingbot_robotwin5tasks/ --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct --train.micro_batch_size ${your_batch_size} --train.global_batch_size ${your_batch_size * your_gpu_num}
179
+
180
+ # with depth
181
+ bash train.sh tasks/vla/train_lingbotvla.py ./configs/vla/robotwin_load20000h_depth.yaml --model.model_path /path/to/LingBot-VLA-Depth --data.train_path /path/to/mixed_robotwin_5tasks --train.output_dir /path/to/lingbot_depth_robotwin5tasks --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct --model.moge_path /path/to/moge2-vitb-normal.pt --model.morgbd_path /path/to/LingBot-Depth-Pretrained --train.micro_batch_size ${your_batch_size} --train.global_batch_size ${your_batch_size * your_gpu_num}
182
+ ```
183
+
184
+ - **Evaluation**
185
+ ```bash
186
+ # robotwin2.0
187
+ export QWEN25_PATH=path_to_Qwen2.5-VL-3B-Instruct
188
+ python -m deploy.lingbot_robotwin_policy \
189
+ --model_path path_to_your_model \
190
+ --use_length 50 \
191
+ --port port
192
+ ```
193
+
194
+ - **Customized Post-training**:
195
+ To construct post-training in specified downstream tasks, we have provided an example and please refer to [Custom](lingbotvla/data/vla_data/README.md) for details.
196
+ ---
197
+
198
+ ## 🏗️ Efficiency
199
+ <p align="center">
200
+ <img src="assets/QwenPI_PaliGemmaPI.png" width="85%">
201
+ </p>
202
+ We evaluate the training efficiency of our codebase against established baselines for both <b>Qwen2.5-VL-3B-π</b> and <b>PaliGemma-3B-pt-224-π</b> models. The results demonstrate that our codebase
203
+ achieved the fastest training speeds in both model settings. The above figures detail the training throughput across configurations of 8, 16, 32, 128, and 256 GPUs, alongside the theoretical linear scaling limit.
204
+
205
+ > **📢 Note on Throughput Metrics:**
206
+ > All throughput values (e.g., 261 samples/sec) represent the **total aggregate throughput across all GPUs**, not per-GPU performance.
207
+ > <br><sup>(Updated: Previously mislabeled as per-GPU in earlier versions. We apologize for the confusion.)</sup>
208
+
209
+ ---
210
+
211
+ ## 📊 Performance
212
+
213
+ Our LingBot-VLA achieves state-of-the-art results on real-world and simulation benchmarks:
214
+ - **GM-100 across 3 robot platforms**
215
+
216
+ <table>
217
+ <thead>
218
+ <tr>
219
+ <th rowspan="2">Platform</th>
220
+ <th colspan="2">WALL-OSS</th>
221
+ <th colspan="2">GR00T N1.6</th>
222
+ <th colspan="2">π<sub>0.5</sub></th>
223
+ <th colspan="2">Ours w/o depth</th>
224
+ <th colspan="2">Ours w/ depth</th>
225
+ </tr>
226
+ <tr>
227
+ <th>SR</th><th>PS</th>
228
+ <th>SR</th><th>PS</th>
229
+ <th>SR</th><th>PS</th>
230
+ <th>SR</th><th>PS</th>
231
+ <th>SR</th><th>PS</th>
232
+ </tr>
233
+ </thead>
234
+ <tbody>
235
+ <tr>
236
+ <td>Agibot G1</td>
237
+ <td>2.99%</td><td>8.75%</td><td>5.23%</td><td>12.63%</td><td>7.77%</td><td>21.98%</td><td><b>12.82%</b></td><td>30.04%</td><td>11.98%</td><td><b>30.47%</b></td>
238
+ </tr>
239
+ <tr>
240
+ <td>AgileX</td>
241
+ <td>2.26%</td><td>8.16%</td><td>3.26%</td><td>10.52%</td><td>17.20%</td><td>34.82%</td><td>15.50%</td><td>36.31%</td><td><b>18.93%</b></td><td><b>40.36%</b></td>
242
+ </tr>
243
+ <tr>
244
+ <td>Galaxea R1Pro</td>
245
+ <td>6.89%</td><td>14.13%</td><td>14.29%</td><td>24.83%</td><td>14.10%</td><td>26.14%</td><td>18.89%</td><td>34.71%</td><td><b>20.98%</b></td><td><b>35.40%</b></td>
246
+ </tr>
247
+ <tr>
248
+ <td><b>Average</b></td>
249
+ <td>4.05%</td><td>10.35%</td><td>7.59%</td><td>15.99%</td><td>13.02%</td><td>27.65%</td><td>15.74%</td><td>33.69%</td><td><b>17.30%</b></td><td><b>35.41%</b></td>
250
+ </tr>
251
+ </tbody>
252
+ </table>
253
+
254
+
255
+ - **RoboTwin 2.0 (Clean and Randomized)**
256
+
257
+ <table>
258
+ <thead>
259
+ <tr>
260
+ <th rowspan="2" ><b>Simulation Tasks</b></th>
261
+ <th colspan="2"><b>&pi;<sub>0.5</sub></b></th>
262
+ <th colspan="2"><b>Ours w/o depth</b></th>
263
+ <th colspan="2"><b>Ours w/ depth</b></th>
264
+ </tr>
265
+ <tr>
266
+ <th><b>Clean</b></th>
267
+ <th><b>Rand.</b></th>
268
+ <th><b>Clean</b></th>
269
+ <th><b>Rand.</b></th>
270
+ <th><b>Clean</b></th>
271
+ <th><b>Rand.</b></th>
272
+ </tr>
273
+ </thead>
274
+ <tbody>
275
+ <tr style="border-top: 1px solid #ccc;"> <!-- \midrule -->
276
+ <td><b>Average SR</b></td>
277
+ <td>82.74%</td>
278
+ <td>76.76%</td>
279
+ <td>86.50%</td>
280
+ <td>85.34%</td>
281
+ <td>88.56%</td>
282
+ <td>86.68%</td>
283
+ </tr>
284
+ <!-- 您可以在此处继续添加其他任务行 -->
285
+ </tbody>
286
+ </table>
287
+
288
+
289
+ 📢 We have released our checkpoints of LingBot-VLA-Posttrain-Robotwin:
290
+ | Model Name | Huggingface | ModelScope | Description |
291
+ | :--- | :---: | :---: | :---: |
292
+ | LingBot-VLA-4B-Posttrain-Robotwin &nbsp; | [🤗 lingbot-vla-4b-posttrain-robotwin](https://huggingface.co/robbyant/lingbot-vla-4b-posttrain-robotwin) | [🤖 lingbot-vla-4b-posttrain-robotwin](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-posttrain-robotwin) | LingBot-VLA-Posttrain-Robotwin *w/o* Depth|
293
+ | LingBot-VLA-4B-Depth-Posttrain-Robotwin | [🤗 lingbot-vla-4b-depth-posttrain-robotwin](https://huggingface.co/robbyant/lingbot-vla-4b-depth-posttrain-robotwin) | [🤖 lingbot-vla-4b-depth-posttrain-robotwin](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-depth-posttrain-robotwin) | LingBot-VLA-Posttrain-Robotwin *w/* Depth |
294
+
295
+ We also provided [evaluation code](deploy/lingbot_robotwin_policy_rep.py) for the community to reproduce the performance of LingBot-VLA on Robotwin 2.0:
296
+ ```bash
297
+ export QWEN25_PATH=path_to_Qwen2.5-VL-3B-Instruct
298
+ python -m deploy.lingbot_robotwin_policy_rep \
299
+ --model_path Path_to_LingBot-VLA-Posttrain-Robotwin \
300
+ --use_length 50 \
301
+ --port port
302
+ ```
303
+
304
+ <p align="center">
305
+ <img src="assets/exp-gm-100.png" width="45%" style="margin: 0 10px;">
306
+ <img src="assets/exp-robotwin.png" width="45%" style="margin: 0 10px;">
307
+ </p>
308
+
309
+ ---
310
+
311
+ ## 📝 Citation
312
+
313
+ If you find our work useful in your research, feel free to give us a cite.
314
+
315
+ ```bibtex
316
+ @article{wu2026pragmatic,
317
+ title={A Pragmatic VLA Foundation Model},
318
+ author={Wei Wu and Fan Lu and Yunnan Wang and Shuai Yang and Shi Liu and Fangjing Wang and Shuailei Ma and He Sun and Yong Wang and Zhenqi Qiu and Houlong Xiong and Ziyu Wang and Shuai Zhou and Yiyu Ren and Kejia Zhang and Hui Yu and Jingmei Zhao and Qian Zhu and Ran Cheng and Yong-Lu Li and Yongtao Huang and Xing Zhu and Yujun Shen and Kecheng Zheng},
319
+ journal={arXiv preprint arXiv:2601.18692v1},
320
+ year={2026}
321
+ }
322
+ ```
323
+
324
+ ---
325
+
326
+ ## 📄 License Agreement
327
+ This project is licensed under the [Apache-2.0 License](LICENSE).
328
+
329
+ ## 😊 Acknowledgement
330
+ We would like to express our sincere gratitude to the developers of [VeOmni](https://arxiv.org/abs/2508.02317) and [LeRobot](https://github.com/huggingface/lerobot#). This project benefits significantly from their outstanding work and contributions to the open-source community.
assets/LingBot-VLA.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b0a361d6084d74afc0bc9fcdee5051375b701a8e41013460107a46902bd0426
3
+ size 10000817
assets/PaliGemmaPI.png ADDED

Git LFS Details

  • SHA256: e691d3ffcabb56307bd58397b04b575e03186b6e6f98aa86cd0a00f6327659b8
  • Pointer size: 131 Bytes
  • Size of remote file: 458 kB
assets/QwenPI.png ADDED

Git LFS Details

  • SHA256: f327696f64edd947a3f4b6ce4d81d88420bc8ca756fc80b4db937228d571f150
  • Pointer size: 131 Bytes
  • Size of remote file: 442 kB
assets/QwenPI_PaliGemmaPI.png ADDED

Git LFS Details

  • SHA256: 4ce326329047abdf297f713ae303693db983de4849f3ad5f32a92c3ca310658d
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
assets/Teaser.png ADDED

Git LFS Details

  • SHA256: 7081c4c6c8586c21ade32fbfe7547f0841b201c46302ab495c9537cfc982ab54
  • Pointer size: 132 Bytes
  • Size of remote file: 9.14 MB
assets/exp-gm-100.png ADDED

Git LFS Details

  • SHA256: 9afddc707eb74534e0c1e3903eed0ee6a2ea24df883f7eb1b2fc8d0c5862068d
  • Pointer size: 131 Bytes
  • Size of remote file: 516 kB
assets/exp-robotwin.png ADDED

Git LFS Details

  • SHA256: 1d61317bee06123a946302d358ff14f11cc01640cfb820f31630cbf612373ecc
  • Pointer size: 131 Bytes
  • Size of remote file: 396 kB
assets/norm_stats/libero.json ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "state": {
4
+ "mean": [
5
+ -0.04617275670170784,
6
+ 0.034034404903650284,
7
+ 0.7647115588188171,
8
+ 2.971421480178833,
9
+ -0.2198116034269333,
10
+ -0.1260652393102646,
11
+ 0.02694438025355339,
12
+ -0.0272101741284132,
13
+ 0.0,
14
+ 0.0,
15
+ 0.0,
16
+ 0.0,
17
+ 0.0,
18
+ 0.0,
19
+ 0.0,
20
+ 0.0,
21
+ 0.0,
22
+ 0.0,
23
+ 0.0,
24
+ 0.0,
25
+ 0.0,
26
+ 0.0,
27
+ 0.0,
28
+ 0.0,
29
+ 0.0,
30
+ 0.0,
31
+ 0.0,
32
+ 0.0,
33
+ 0.0,
34
+ 0.0,
35
+ 0.0,
36
+ 0.0
37
+ ],
38
+ "std": [
39
+ 0.1049584373831749,
40
+ 0.15187117457389832,
41
+ 0.3785041272640228,
42
+ 0.3451951742172241,
43
+ 0.910057544708252,
44
+ 0.3253032863140106,
45
+ 0.014151589013636112,
46
+ 0.014038060791790485,
47
+ 0.0,
48
+ 0.0,
49
+ 0.0,
50
+ 0.0,
51
+ 0.0,
52
+ 0.0,
53
+ 0.0,
54
+ 0.0,
55
+ 0.0,
56
+ 0.0,
57
+ 0.0,
58
+ 0.0,
59
+ 0.0,
60
+ 0.0,
61
+ 0.0,
62
+ 0.0,
63
+ 0.0,
64
+ 0.0,
65
+ 0.0,
66
+ 0.0,
67
+ 0.0,
68
+ 0.0,
69
+ 0.0,
70
+ 0.0
71
+ ],
72
+ "q01": [
73
+ -0.4003246918797493,
74
+ -0.268838057410717,
75
+ 0.03963126605004072,
76
+ 1.5141939243793487,
77
+ -2.7199491125106814,
78
+ -1.0708919448852539,
79
+ 0.0017206525699933989,
80
+ -0.04004273633235134,
81
+ 0.0,
82
+ 0.0,
83
+ 0.0,
84
+ 0.0,
85
+ 0.0,
86
+ 0.0,
87
+ 0.0,
88
+ 0.0,
89
+ 0.0,
90
+ 0.0,
91
+ 0.0,
92
+ 0.0,
93
+ 0.0,
94
+ 0.0,
95
+ 0.0,
96
+ 0.0,
97
+ 0.0,
98
+ 0.0,
99
+ 0.0,
100
+ 0.0,
101
+ 0.0,
102
+ 0.0,
103
+ 0.0,
104
+ 0.0
105
+ ],
106
+ "q99": [
107
+ 0.1335429027736188,
108
+ 0.3378903574764729,
109
+ 1.2657122139371932,
110
+ 3.2784227243721484,
111
+ 2.4147262454509733,
112
+ 0.5962245464324951,
113
+ 0.04029089962062426,
114
+ -0.001789628425752747,
115
+ 0.0,
116
+ 0.0,
117
+ 0.0,
118
+ 0.0,
119
+ 0.0,
120
+ 0.0,
121
+ 0.0,
122
+ 0.0,
123
+ 0.0,
124
+ 0.0,
125
+ 0.0,
126
+ 0.0,
127
+ 0.0,
128
+ 0.0,
129
+ 0.0,
130
+ 0.0,
131
+ 0.0,
132
+ 0.0,
133
+ 0.0,
134
+ 0.0,
135
+ 0.0,
136
+ 0.0,
137
+ 0.0,
138
+ 0.0
139
+ ]
140
+ },
141
+ "actions": {
142
+ "mean": [
143
+ 0.06667574495077133,
144
+ 0.06483978033065796,
145
+ -0.80384361743927,
146
+ -2.970874071121216,
147
+ 0.22662578523159027,
148
+ 0.11959122866392136,
149
+ -0.036161474883556366,
150
+ 0.0,
151
+ 0.0,
152
+ 0.0,
153
+ 0.0,
154
+ 0.0,
155
+ 0.0,
156
+ 0.0,
157
+ 0.0,
158
+ 0.0,
159
+ 0.0,
160
+ 0.0,
161
+ 0.0,
162
+ 0.0,
163
+ 0.0,
164
+ 0.0,
165
+ 0.0,
166
+ 0.0,
167
+ 0.0,
168
+ 0.0,
169
+ 0.0,
170
+ 0.0,
171
+ 0.0,
172
+ 0.0,
173
+ 0.0,
174
+ 0.0
175
+ ],
176
+ "std": [
177
+ 0.32812511920928955,
178
+ 0.4197826683521271,
179
+ 0.6153613924980164,
180
+ 0.35168182849884033,
181
+ 0.9132273197174072,
182
+ 0.3432939946651459,
183
+ 0.9993459582328796,
184
+ 0.0,
185
+ 0.0,
186
+ 0.0,
187
+ 0.0,
188
+ 0.0,
189
+ 0.0,
190
+ 0.0,
191
+ 0.0,
192
+ 0.0,
193
+ 0.0,
194
+ 0.0,
195
+ 0.0,
196
+ 0.0,
197
+ 0.0,
198
+ 0.0,
199
+ 0.0,
200
+ 0.0,
201
+ 0.0,
202
+ 0.0,
203
+ 0.0,
204
+ 0.0,
205
+ 0.0,
206
+ 0.0,
207
+ 0.0,
208
+ 0.0
209
+ ],
210
+ "q01": [
211
+ -0.7088336983919143,
212
+ -0.8786727856397629,
213
+ -2.097322083187103,
214
+ -3.3041505486488343,
215
+ -2.4138620029449465,
216
+ -0.6111064100980759,
217
+ -1.0,
218
+ 0.0,
219
+ 0.0,
220
+ 0.0,
221
+ 0.0,
222
+ 0.0,
223
+ 0.0,
224
+ 0.0,
225
+ 0.0,
226
+ 0.0,
227
+ 0.0,
228
+ 0.0,
229
+ 0.0,
230
+ 0.0,
231
+ 0.0,
232
+ 0.0,
233
+ 0.0,
234
+ 0.0,
235
+ 0.0,
236
+ 0.0,
237
+ 0.0,
238
+ 0.0,
239
+ 0.0,
240
+ 0.0,
241
+ 0.0,
242
+ 0.0
243
+ ],
244
+ "q99": [
245
+ 1.0219826289415357,
246
+ 1.0526966882944104,
247
+ 0.7265835452556608,
248
+ -1.491220802116394,
249
+ 2.7264903316497806,
250
+ 1.1191907620668413,
251
+ 0.9996,
252
+ 0.0,
253
+ 0.0,
254
+ 0.0,
255
+ 0.0,
256
+ 0.0,
257
+ 0.0,
258
+ 0.0,
259
+ 0.0,
260
+ 0.0,
261
+ 0.0,
262
+ 0.0,
263
+ 0.0,
264
+ 0.0,
265
+ 0.0,
266
+ 0.0,
267
+ 0.0,
268
+ 0.0,
269
+ 0.0,
270
+ 0.0,
271
+ 0.0,
272
+ 0.0,
273
+ 0.0,
274
+ 0.0,
275
+ 0.0,
276
+ 0.0
277
+ ]
278
+ }
279
+ }
280
+ }
assets/norm_stats/robotwin_50.json ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "action.arm.position": {
4
+ "mean": [
5
+ -0.22649447619915009,
6
+ 1.0910465717315674,
7
+ 0.8046976923942566,
8
+ -0.3529793620109558,
9
+ 0.056382808834314346,
10
+ -0.04518803581595421,
11
+ 0.23444592952728271,
12
+ 1.1117788553237915,
13
+ 0.8302268385887146,
14
+ -0.3584558367729187,
15
+ -0.010058438405394554,
16
+ 0.010835078544914722
17
+ ],
18
+ "std": [
19
+ 0.36951732635498047,
20
+ 0.9946224689483643,
21
+ 0.7907869219779968,
22
+ 0.663685142993927,
23
+ 0.24930860102176666,
24
+ 0.5646992921829224,
25
+ 0.32377511262893677,
26
+ 1.0205038785934448,
27
+ 0.8121177554130554,
28
+ 0.7205489277839661,
29
+ 0.25676125288009644,
30
+ 0.6210611462593079
31
+ ],
32
+ "q01": [
33
+ -0.9676963651657111,
34
+ -0.0003164021181873977,
35
+ -0.0008187678098678652,
36
+ -1.5952941972732544,
37
+ -0.4444093635320664,
38
+ -2.2108209049224854,
39
+ -0.13648582720756508,
40
+ -0.0025135905981064077,
41
+ -0.0016476722434163094,
42
+ -1.7023667912483216,
43
+ -1.0292453282356262,
44
+ -1.6702169750213622
45
+ ],
46
+ "q99": [
47
+ 0.17045696868896432,
48
+ 2.5792064671580563,
49
+ 2.4791862522006034,
50
+ 1.263499072647095,
51
+ 1.2283580561399456,
52
+ 1.4622943069458012,
53
+ 1.096450059175491,
54
+ 2.605947977209091,
55
+ 2.5039097490906714,
56
+ 1.3104696589708325,
57
+ 1.074876550579071,
58
+ 2.104229341125489
59
+ ],
60
+ "q02": [
61
+ -0.9234203773498537,
62
+ -0.0003164021181873977,
63
+ -0.0008187678098678652,
64
+ -1.509812859249115,
65
+ -0.32799621334075924,
66
+ -1.656348336791992,
67
+ -0.05942733430862468,
68
+ -0.0025135905981064077,
69
+ -0.0016476722434163094,
70
+ -1.6187864029407502,
71
+ -0.8712951603889465,
72
+ -1.5470734649658198
73
+ ],
74
+ "q98": [
75
+ 0.11836757125854458,
76
+ 2.4944407171577216,
77
+ 2.3239549394726753,
78
+ 1.0776700769424439,
79
+ 1.0128444806575776,
80
+ 1.2158620544433596,
81
+ 0.945415413093567,
82
+ 2.5296102081775667,
83
+ 2.3580759009346366,
84
+ 1.2048114322423933,
85
+ 0.6983346325874327,
86
+ 1.7523907409667974
87
+ ]
88
+ },
89
+ "action.effector.position": {
90
+ "mean": [
91
+ 0.6722026467323303,
92
+ 0.6737783551216125
93
+ ],
94
+ "std": [
95
+ 0.45274168252944946,
96
+ 0.45141810178756714
97
+ ],
98
+ "q01": [
99
+ -1e-10,
100
+ -1e-10
101
+ ],
102
+ "q99": [
103
+ 0.99980000009996,
104
+ 0.99980000009996
105
+ ],
106
+ "q02": [
107
+ -1e-10,
108
+ -1e-10
109
+ ],
110
+ "q98": [
111
+ 0.99980000009996,
112
+ 0.99980000009996
113
+ ]
114
+ },
115
+ "observation.state.arm.position": {
116
+ "mean": [
117
+ -0.22545991837978363,
118
+ 1.0864390134811401,
119
+ 0.8012449741363525,
120
+ -0.3515830338001251,
121
+ 0.05604754388332367,
122
+ -0.0445503294467926,
123
+ 0.23296862840652466,
124
+ 1.1059207916259766,
125
+ 0.8258985280990601,
126
+ -0.3568105697631836,
127
+ -0.00992637686431408,
128
+ 0.010328034870326519
129
+ ],
130
+ "std": [
131
+ 0.3688313364982605,
132
+ 0.9950565099716187,
133
+ 0.7906551957130432,
134
+ 0.6622100472450256,
135
+ 0.24865445494651794,
136
+ 0.5626452565193176,
137
+ 0.32314980030059814,
138
+ 1.0208053588867188,
139
+ 0.8119285702705383,
140
+ 0.718558132648468,
141
+ 0.25572913885116577,
142
+ 0.6181830763816833
143
+ ],
144
+ "q01": [
145
+ -0.9676963651657111,
146
+ -0.0003164021181873977,
147
+ -0.0008187678098678652,
148
+ -1.5938075653076171,
149
+ -0.44261839199066166,
150
+ -2.198074409103393,
151
+ -0.13494465734958627,
152
+ -0.0025135905981064077,
153
+ -0.0016476722434163094,
154
+ -1.7015782970190048,
155
+ -1.0292453282356262,
156
+ -1.6682623161315915
157
+ ],
158
+ "q99": [
159
+ 0.17045696868896432,
160
+ 2.5792064671580563,
161
+ 2.4782622562915084,
162
+ 1.2545792808532719,
163
+ 1.2247761130571364,
164
+ 1.458045475006104,
165
+ 1.0856618701696394,
166
+ 2.6036578441381453,
167
+ 2.502444082275033,
168
+ 1.3057386935949324,
169
+ 1.0699406078338622,
170
+ 2.0983653644561766
171
+ ],
172
+ "q02": [
173
+ -0.9234203773498537,
174
+ -0.0003164021181873977,
175
+ -0.0008187678098678652,
176
+ -1.5083262272834776,
177
+ -0.32799621334075924,
178
+ -1.6499750888824458,
179
+ -0.05942733430862468,
180
+ -0.0025135905981064077,
181
+ -0.0016476722434163094,
182
+ -1.6172094144821167,
183
+ -0.8684746216773986,
184
+ -1.5470734649658198
185
+ ],
186
+ "q98": [
187
+ 0.11836757125854458,
188
+ 2.4944407171577216,
189
+ 2.320258955836296,
190
+ 1.0754401289939883,
191
+ 1.0116504996299742,
192
+ 1.2137376384735115,
193
+ 0.945415413093567,
194
+ 2.528846830487251,
195
+ 2.3551445673033595,
196
+ 1.2016574553251265,
197
+ 0.6969243632316591,
198
+ 1.746526764297485
199
+ ]
200
+ },
201
+ "observation.state.effector.position": {
202
+ "mean": [
203
+ 0.6734354496002197,
204
+ 0.6749846339225769
205
+ ],
206
+ "std": [
207
+ 0.4522727429866791,
208
+ 0.45095184445381165
209
+ ],
210
+ "q01": [
211
+ -1e-10,
212
+ -1e-10
213
+ ],
214
+ "q99": [
215
+ 0.99980000009996,
216
+ 0.99980000009996
217
+ ],
218
+ "q02": [
219
+ -1e-10,
220
+ -1e-10
221
+ ],
222
+ "q98": [
223
+ 0.99980000009996,
224
+ 0.99980000009996
225
+ ]
226
+ }
227
+ },
228
+ "count": 532992
229
+ }
assets/norm_stats/robotwin_5_customized.json ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "action": {
4
+ "mean": [
5
+ -0.32207754254341125,
6
+ 1.406205654144287,
7
+ 1.1087545156478882,
8
+ -0.6245313882827759,
9
+ -0.027720848098397255,
10
+ -0.035565875470638275,
11
+ 0.4717631936073303,
12
+ 0.25276312232017517,
13
+ 0.8104884624481201,
14
+ 0.5522242188453674,
15
+ -0.1358797252178192,
16
+ 0.13210205733776093,
17
+ -0.13196010887622833,
18
+ 0.7805091738700867
19
+ ],
20
+ "std": [
21
+ 0.2855374813079834,
22
+ 0.9229381084442139,
23
+ 0.8118345737457275,
24
+ 0.49564430117607117,
25
+ 0.16244904696941376,
26
+ 0.5517618656158447,
27
+ 0.4883338212966919,
28
+ 0.40702372789382935,
29
+ 1.036325216293335,
30
+ 0.7480976581573486,
31
+ 0.7034134268760681,
32
+ 0.3450477123260498,
33
+ 0.7341580390930176,
34
+ 0.4033139646053314
35
+ ],
36
+ "q01": [
37
+ -0.8213654638230801,
38
+ -5.257390398583084e-7,
39
+ -0.00002296771708643064,
40
+ -1.6557389229632915,
41
+ -0.6564541918039322,
42
+ -1.1997157670021057,
43
+ 0.0,
44
+ -0.0013322193384173175,
45
+ 0.0,
46
+ -0.0000281171942333458,
47
+ -1.4858032744407654,
48
+ -0.013652276556193832,
49
+ -1.5582030366897581,
50
+ 0.0
51
+ ],
52
+ "q99": [
53
+ 0.01988644998967637,
54
+ 2.618066892673189,
55
+ 2.8887816588023267,
56
+ -0.00009503023102874764,
57
+ 0.39941834962368006,
58
+ 1.3274614672660827,
59
+ 0.9998,
60
+ 1.2499000839233396,
61
+ 2.403721238327026,
62
+ 2.223998639903084,
63
+ 1.3482957191944123,
64
+ 1.2036741195514797,
65
+ 2.3008846492767336,
66
+ 0.9998
67
+ ],
68
+ "q02": [
69
+ -0.8116190195694566,
70
+ -5.257390398583084e-7,
71
+ -0.00002296771708643064,
72
+ -1.5653808554142714,
73
+ -0.5909986785650253,
74
+ -0.9318809885978698,
75
+ 0.0,
76
+ -0.0013322193384173175,
77
+ 0.0,
78
+ -0.0000281171942333458,
79
+ -1.400590261220932,
80
+ -0.005905654035508634,
81
+ -1.5582030366897581,
82
+ 0.0
83
+ ],
84
+ "q98": [
85
+ 0.01988644998967637,
86
+ 2.509362170317786,
87
+ 2.6153081541584893,
88
+ -0.00009503023102874764,
89
+ 0.34549802929162987,
90
+ 1.2313367155075077,
91
+ 0.9998,
92
+ 1.2416952819347378,
93
+ 2.374588215923309,
94
+ 2.1395174845976728,
95
+ 1.328065291595459,
96
+ 1.1956508319407702,
97
+ 2.172924092388153,
98
+ 0.9998
99
+ ]
100
+ },
101
+ "observation.state": {
102
+ "mean": [
103
+ -0.320831835269928,
104
+ 1.401549220085144,
105
+ 1.1045918464660645,
106
+ -0.6217827796936035,
107
+ -0.0279570072889328,
108
+ -0.03499468415975571,
109
+ 0.4726906716823578,
110
+ 0.2512069344520569,
111
+ 0.8065828680992126,
112
+ 0.5495453476905823,
113
+ -0.13533149659633636,
114
+ 0.13129419088363647,
115
+ -0.1315813809633255,
116
+ 0.7816013693809509
117
+ ],
118
+ "std": [
119
+ 0.28554511070251465,
120
+ 0.924691379070282,
121
+ 0.8124904036521912,
122
+ 0.49545007944107056,
123
+ 0.16213101148605347,
124
+ 0.5504377484321594,
125
+ 0.4883865714073181,
126
+ 0.40611740946769714,
127
+ 1.035233497619629,
128
+ 0.7470027208328247,
129
+ 0.7013660073280334,
130
+ 0.3439686894416809,
131
+ 0.7313857674598694,
132
+ 0.4025507867336273
133
+ ],
134
+ "q01": [
135
+ -0.8213654638230801,
136
+ -5.257390398583084e-7,
137
+ -0.00002296771708643064,
138
+ -1.6557389229632915,
139
+ -0.6564541918039322,
140
+ -1.1997157670021057,
141
+ 0.0,
142
+ -0.0013322193384173175,
143
+ 0.0,
144
+ -0.0000281171942333458,
145
+ -1.483351101398468,
146
+ -0.013652276556193832,
147
+ -1.5582030366897581,
148
+ 0.0
149
+ ],
150
+ "q99": [
151
+ 0.01988644998967637,
152
+ 2.6186390227908487,
153
+ 2.889423615385998,
154
+ -0.00009503023102874764,
155
+ 0.39780878782272344,
156
+ 1.3274614672660827,
157
+ 0.9998,
158
+ 1.2499000839233396,
159
+ 2.404215018367767,
160
+ 2.2201366442319794,
161
+ 1.347682675933838,
162
+ 1.2036741195514797,
163
+ 2.3008846492767336,
164
+ 0.9998
165
+ ],
166
+ "q02": [
167
+ -0.8116190195694566,
168
+ -5.257390398583084e-7,
169
+ -0.00002296771708643064,
170
+ -1.5653808554142714,
171
+ -0.5909986785650253,
172
+ -0.9318809885978698,
173
+ 0.0,
174
+ -0.0013322193384173175,
175
+ 0.0,
176
+ -0.0000281171942333458,
177
+ -1.3981380881786347,
178
+ -0.005905654035508634,
179
+ -1.5582030366897581,
180
+ 0.0
181
+ ],
182
+ "q98": [
183
+ 0.01988644998967637,
184
+ 2.509362170317786,
185
+ 2.61595011074216,
186
+ -0.00009503023102874764,
187
+ 0.3452297689914703,
188
+ 1.2313367155075077,
189
+ 0.9998,
190
+ 1.2416952819347378,
191
+ 2.374588215923309,
192
+ 2.1380692362210083,
193
+ 1.328065291595459,
194
+ 1.1956508319407702,
195
+ 2.1450514958381657,
196
+ 0.9998
197
+ ]
198
+ }
199
+ },
200
+ "count": 74240
201
+ }
assets/norm_stats/robotwin_all_new.json ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "norm_stats": {
3
+ "action.arm.position": {
4
+ "mean": [
5
+ -0.2260681688785553,
6
+ 1.090435266494751,
7
+ 0.8042582273483276,
8
+ -0.3527189791202545,
9
+ 0.056556474417448044,
10
+ -0.04530515521764755,
11
+ 0.2346765249967575,
12
+ 1.112542748451233,
13
+ 0.8304542303085327,
14
+ -0.357768177986145,
15
+ -0.01014612801373005,
16
+ 0.010991317220032215
17
+ ],
18
+ "std": [
19
+ 0.3691432774066925,
20
+ 0.994762122631073,
21
+ 0.7908730506896973,
22
+ 0.6637247800827026,
23
+ 0.24963052570819855,
24
+ 0.5638052821159363,
25
+ 0.32393988966941833,
26
+ 1.0204970836639404,
27
+ 0.8119731545448303,
28
+ 0.7209287285804749,
29
+ 0.25776439905166626,
30
+ 0.6208906769752502
31
+ ],
32
+ "q01": [
33
+ -0.9676963651657111,
34
+ -0.0003164021181873977,
35
+ -0.0026667596280574857,
36
+ -1.596037513256073,
37
+ -0.4467973255872727,
38
+ -2.20232324104309,
39
+ -0.13648582720756508,
40
+ -0.0017502129077910933,
41
+ -0.0023805056512355804,
42
+ -1.703943779706955,
43
+ -1.0264247895240783,
44
+ -1.6682623161315915
45
+ ],
46
+ "q99": [
47
+ 0.17045696868896432,
48
+ 2.5760957974332737,
49
+ 2.4727182808369395,
50
+ 1.259782492733002,
51
+ 1.2253731035709379,
52
+ 1.4495478111267097,
53
+ 1.0841207003116606,
54
+ 2.6036578441381453,
55
+ 2.4987799152359367,
56
+ 1.3104696589708325,
57
+ 1.0692354731559752,
58
+ 2.104229341125489
59
+ ],
60
+ "q02": [
61
+ -0.9260248472213748,
62
+ -0.0003164021181873977,
63
+ -0.0026667596280574857,
64
+ -1.5090695432662964,
65
+ -0.3291901943683624,
66
+ -1.6520995048522948,
67
+ -0.05942733430862468,
68
+ -0.0017502129077910933,
69
+ -0.0023805056512355804,
70
+ -1.6187864029407502,
71
+ -0.8741156991004944,
72
+ -1.5490281238555905
73
+ ],
74
+ "q98": [
75
+ 0.1157631013870235,
76
+ 2.4936630497265257,
77
+ 2.3193349599272013,
78
+ 1.0769267609596254,
79
+ 1.0140384616851805,
80
+ 1.2073643905639653,
81
+ 0.9469565829515458,
82
+ 2.528083452796936,
83
+ 2.3551445673033595,
84
+ 1.2071769149303435,
85
+ 0.6969243632316591,
86
+ 1.7504360820770266
87
+ ]
88
+ },
89
+ "action.effector.position": {
90
+ "mean": [
91
+ 0.6723259687423706,
92
+ 0.6735112071037292
93
+ ],
94
+ "std": [
95
+ 0.4526418447494507,
96
+ 0.4514695405960083
97
+ ],
98
+ "q01": [
99
+ 0.0,
100
+ 0.0
101
+ ],
102
+ "q99": [
103
+ 0.9998,
104
+ 0.9998
105
+ ],
106
+ "q02": [
107
+ 0.0,
108
+ 0.0
109
+ ],
110
+ "q98": [
111
+ 0.9998,
112
+ 0.9998
113
+ ]
114
+ },
115
+ "observation.state.arm.position": {
116
+ "mean": [
117
+ -0.22502799332141876,
118
+ 1.0857956409454346,
119
+ 0.8007810711860657,
120
+ -0.3513113558292389,
121
+ 0.05622035637497902,
122
+ -0.044659487903118134,
123
+ 0.23319771885871887,
124
+ 1.106688141822815,
125
+ 0.82613205909729,
126
+ -0.3561287522315979,
127
+ -0.010010534897446632,
128
+ 0.010481182485818863
129
+ ],
130
+ "std": [
131
+ 0.3684558570384979,
132
+ 0.9951919317245483,
133
+ 0.7907320857048035,
134
+ 0.6622379422187805,
135
+ 0.24897389113903046,
136
+ 0.5617504119873047,
137
+ 0.32331398129463196,
138
+ 1.0208075046539307,
139
+ 0.8117841482162476,
140
+ 0.718940019607544,
141
+ 0.25672635436058044,
142
+ 0.6180205345153809
143
+ ],
144
+ "q01": [
145
+ -0.9676963651657111,
146
+ -0.0003164021181873977,
147
+ -0.0026667596280574857,
148
+ -1.5938075653076171,
149
+ -0.4462003350734711,
150
+ -2.195949993133545,
151
+ -0.13648582720756508,
152
+ -0.0017502129077910933,
153
+ -0.0023805056512355804,
154
+ -1.703943779706955,
155
+ -1.0257196548461915,
156
+ -1.6663076572418207
157
+ ],
158
+ "q99": [
159
+ 0.16785249881744324,
160
+ 2.5760957974332737,
161
+ 2.47087028901875,
162
+ 1.2516060169219974,
163
+ 1.22238815100193,
164
+ 1.4495478111267097,
165
+ 1.073332511305809,
166
+ 2.602131088757515,
167
+ 2.494382914789021,
168
+ 1.3104696589708325,
169
+ 1.0657097997665406,
170
+ 2.102274682235718
171
+ ],
172
+ "q02": [
173
+ -0.9234203773498537,
174
+ -0.0003164021181873977,
175
+ -0.0026667596280574857,
176
+ -1.5060962793350219,
177
+ -0.3291901943683624,
178
+ -1.6436018409728996,
179
+ -0.05788616445064587,
180
+ -0.0017502129077910933,
181
+ -0.0023805056512355804,
182
+ -1.6164209202528,
183
+ -0.8698848910331727,
184
+ -1.5490281238555905
185
+ ],
186
+ "q98": [
187
+ 0.1157631013870235,
188
+ 2.4928853822953303,
189
+ 2.3174869681090113,
190
+ 1.0754401289939883,
191
+ 1.0122474901437757,
192
+ 1.2031155586242681,
193
+ 0.945415413093567,
194
+ 2.527320075106621,
195
+ 2.3522132336720825,
196
+ 1.202445949554443,
197
+ 0.694808959197998,
198
+ 1.7484814231872559
199
+ ]
200
+ },
201
+ "observation.state.effector.position": {
202
+ "mean": [
203
+ 0.6735715866088867,
204
+ 0.6747165322303772
205
+ ],
206
+ "std": [
207
+ 0.4521658420562744,
208
+ 0.4510030150413513
209
+ ],
210
+ "q01": [
211
+ 0.0,
212
+ 0.0
213
+ ],
214
+ "q99": [
215
+ 0.9998,
216
+ 0.9998
217
+ ],
218
+ "q02": [
219
+ 0.0,
220
+ 0.0
221
+ ],
222
+ "q98": [
223
+ 0.9998,
224
+ 0.9998
225
+ ]
226
+ }
227
+ },
228
+ "count": 535680
229
+ }
assets/scale_ps.png ADDED

Git LFS Details

  • SHA256: b23143996c78b30f658b9a81e0d46c96c2231d9dd2646775b0c057773a1fce14
  • Pointer size: 131 Bytes
  • Size of remote file: 481 kB
assets/scale_sr.png ADDED

Git LFS Details

  • SHA256: 3becc2bb6d5355f672dc110a4578277c3eac1cf53f3cba726e5e6277b8d9c413
  • Pointer size: 131 Bytes
  • Size of remote file: 466 kB
configs/norm/robotwin_5.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ model_path: /path/to/LingBot-VLA-Depth
3
+ tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
4
+
5
+ data:
6
+ datasets_type: vla
7
+ train_path: /path/to/mixed_robotwin_5tasks
8
+ norm_path: assets/norm_stats/robotwin_5_custom.json
9
+
10
+ train:
11
+ global_batch_size: 512
12
+ output_dir: output/norm
configs/vla/robotwin_load20000h.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ model_path: /path/to/LingBot-VLA
3
+ tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
4
+ post_training: true
5
+ adanorm_time: true
6
+ old_adanorm: true
7
+
8
+ data:
9
+ datasets_type: vla
10
+ data_name: robotwin_5_new
11
+ train_path: /path/to/mixed_robotwin_5tasks
12
+ num_workers: 8
13
+ norm_type: bounds_99_woclip
14
+ norm_stats_file: assets/norm_stats/robotwin_50.json
15
+
16
+ train:
17
+ output_dir: /path/to/lingbot_robotwin5tasks/
18
+ loss_type: L1_fm
19
+ data_parallel_mode: fsdp2
20
+ enable_full_shard: false
21
+ module_fsdp_enable: true
22
+ use_compile: true
23
+ use_wandb: false
24
+ rmpad: false
25
+ rmpad_with_pos_ids: false
26
+ ulysses_parallel_size: 1
27
+ freeze_vision_encoder: false
28
+ tokenizer_max_length: 24
29
+ action_dim: 14
30
+ max_action_dim: 75
31
+ max_state_dim: 75
32
+ lr: 1.0e-4
33
+ lr_decay_style: constant
34
+ num_train_epochs: 69
35
+ micro_batch_size: 32
36
+ global_batch_size: 256
37
+ max_steps: 220000
38
+ ckpt_manager: dcp
39
+ save_steps: 220000
40
+ save_epochs: 69
41
+ enable_fp32: true
42
+ enable_resume: true
configs/vla/robotwin_load20000h_depth.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ model_path: /path/to/LingBot-VLA-Depth
3
+ tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
4
+ post_training: true
5
+ adanorm_time: true
6
+ old_adanorm: true
7
+ moge_path: /path/to/moge2-vitb-normal
8
+ morgbd_path: /path/to/LingBot-Depth-Pretrained
9
+
10
+ data:
11
+ datasets_type: vla
12
+ data_name: robotwin_5_new
13
+ train_path: /path/to/mixed_robotwin_5tasks
14
+ num_workers: 8
15
+ norm_type: bounds_99_woclip
16
+ norm_stats_file: assets/norm_stats/robotwin_50.json
17
+
18
+ train:
19
+ output_dir: /path/to/lingbot_depth_robotwin5tasks/
20
+ loss_type: L1_fm
21
+ data_parallel_mode: fsdp2
22
+ enable_full_shard: false
23
+ module_fsdp_enable: true
24
+ use_compile: true
25
+ use_wandb: false
26
+ rmpad: false
27
+ rmpad_with_pos_ids: false
28
+ ulysses_parallel_size: 1
29
+ freeze_vision_encoder: false
30
+ tokenizer_max_length: 24
31
+ action_dim: 14
32
+ max_action_dim: 75
33
+ max_state_dim: 75
34
+ lr: 1.0e-4
35
+ lr_decay_style: constant
36
+ num_train_epochs: 69
37
+ micro_batch_size: 32
38
+ global_batch_size: 256
39
+ max_steps: 220000
40
+ ckpt_manager: dcp
41
+ save_steps: 220000
42
+ save_epochs: 69
43
+ enable_fp32: true
44
+ enable_resume: true
45
+ align_params:
46
+ mode: 'query'
47
+ num_task_tokens: 8
48
+ use_image_tokens: True
49
+ use_task_tokens: False
50
+ use_text_tokens: False
51
+ use_contrastive: True
52
+ contrastive_loss_weight: 0.3
53
+ depth_loss_weight: 0.004
54
+ llm:
55
+ dim_out: 2048
56
+ image_token_size: 8
57
+ image_input_size: 224
58
+ depth:
59
+ model_type: MoRGBD
60
+ num_layers: 1
61
+ num_heads: 4
62
+ dim_head: 32
63
+ ff_mult: 1
64
+ num_backbone_tokens: 256
65
+ token_size: 16
66
+ dim_out: 1024
67
+ input_size: 224
68
+ visual_steps: 10000
deploy/__init__.py ADDED
File without changes
deploy/image_tools.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+
5
+ def convert_to_uint8(img: np.ndarray) -> np.ndarray:
6
+ """Converts an image to uint8 if it is a float image.
7
+
8
+ This is important for reducing the size of the image when sending it over the network.
9
+ """
10
+ if np.issubdtype(img.dtype, np.floating):
11
+ img = (255 * img).astype(np.uint8)
12
+ return img
13
+
14
+
15
+ def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
16
+ """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
17
+
18
+ Args:
19
+ images: A batch of images in [..., height, width, channel] format.
20
+ height: The target height of the image.
21
+ width: The target width of the image.
22
+ method: The interpolation method to use. Default is bilinear.
23
+
24
+ Returns:
25
+ The resized images in [..., height, width, channel].
26
+ """
27
+ # If the images are already the correct size, return them as is.
28
+ if images.shape[-3:-1] == (height, width):
29
+ return images
30
+
31
+ original_shape = images.shape
32
+
33
+ images = images.reshape(-1, *original_shape[-3:])
34
+ resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
35
+ return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
36
+
37
+
38
+ def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
39
+ """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
40
+ width without distortion by padding with zeros.
41
+
42
+ Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
43
+ """
44
+ cur_width, cur_height = image.size
45
+ if cur_width == width and cur_height == height:
46
+ return image # No need to resize if the image is already the correct size.
47
+
48
+ ratio = max(cur_width / width, cur_height / height)
49
+ resized_height = int(cur_height / ratio)
50
+ resized_width = int(cur_width / ratio)
51
+ resized_image = image.resize((resized_width, resized_height), resample=method)
52
+
53
+ zero_image = Image.new(resized_image.mode, (width, height), 0)
54
+ pad_height = max(0, int((height - resized_height) / 2))
55
+ pad_width = max(0, int((width - resized_width) / 2))
56
+ zero_image.paste(resized_image, (pad_width, pad_height))
57
+ assert zero_image.size == (width, height)
58
+ return zero_image
deploy/lingbot_robotwin_policy.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ import random
5
+ import numpy as np
6
+ from collections import deque
7
+ import torchvision
8
+ import yaml
9
+ from types import SimpleNamespace
10
+ from packaging.version import Version
11
+ from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
12
+ from glob import glob
13
+ from tqdm import tqdm
14
+ from safetensors import safe_open
15
+ from safetensors.torch import load_file
16
+ from pathlib import Path
17
+ from PIL import Image
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import Tensor, nn
21
+
22
+
23
+ import transformers
24
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
25
+ from transformers import (
26
+ AutoConfig,
27
+ PretrainedConfig,
28
+ PreTrainedModel,
29
+ AutoProcessor,
30
+ )
31
+
32
+ from lerobot.configs.policies import PreTrainedConfig
33
+ from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
34
+ from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
35
+ from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
36
+ from lingbotvla.models import build_processor
37
+
38
+
39
+ def set_seed_everywhere(seed: int):
40
+ """Sets the random seed for Python, NumPy, and PyTorch functions."""
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed_all(seed)
43
+ np.random.seed(seed)
44
+ random.seed(seed)
45
+ torch.backends.cudnn.deterministic = True
46
+ torch.backends.cudnn.benchmark = False
47
+ os.environ["PYTHONHASHSEED"] = str(seed)
48
+
49
+ set_seed_everywhere(42)
50
+
51
+ BASE_MODEL_PATH = {
52
+ 'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
53
+ 'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
54
+ }
55
+
56
+ def load_model_weights(policy, path_to_pi_model, strict=True):
57
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
58
+ merged_weights = {}
59
+
60
+ for file_path in tqdm(all_safetensors):
61
+ with safe_open(file_path, framework="pt", device="cpu") as f:
62
+ for key in f.keys():
63
+ merged_weights[key] = f.get_tensor(key)
64
+ policy.load_state_dict(merged_weights, strict=strict)
65
+
66
+
67
+ def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
68
+ crop_scale = 0.9
69
+ side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0))) # side length scale
70
+ out_size = (224, 224)
71
+
72
+ # Convert input to PIL Image
73
+ if isinstance(image, np.ndarray):
74
+ arr = image
75
+ if arr.dtype.kind == "f":
76
+ # If floats likely in [0,1], map to [0,255]
77
+ if arr.max() <= 1.0 and arr.min() >= 0.0:
78
+ arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
79
+ else:
80
+ arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
81
+ elif arr.dtype == np.uint16:
82
+ # Map 16-bit to 8-bit
83
+ arr = (arr / 257).astype(np.uint8)
84
+ elif arr.dtype != np.uint8:
85
+ arr = arr.astype(np.uint8)
86
+ pil = Image.fromarray(arr)
87
+ elif isinstance(image, Image.Image):
88
+ pil = image
89
+ else:
90
+ raise TypeError("image must be a numpy array or PIL.Image.Image")
91
+
92
+ # Force RGB for consistent output
93
+ pil = pil.convert("RGB")
94
+ W, H = pil.size
95
+
96
+ # Compute centered crop box (integer pixels)
97
+ crop_w = max(1, int(round(W * side_scale)))
98
+ crop_h = max(1, int(round(H * side_scale)))
99
+ left = (W - crop_w) // 2
100
+ top = (H - crop_h) // 2
101
+ right = left + crop_w
102
+ bottom = top + crop_h
103
+
104
+ cropped = pil.crop((left, top, right, bottom))
105
+ resized = cropped.resize(out_size, resample=Image.BILINEAR)
106
+ return resized
107
+
108
+ def resize_with_pad(img, width, height, pad_value=-1):
109
+ # assume no-op when width height fits already
110
+ if img.ndim != 4:
111
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
112
+
113
+ # channel last to channel first if necessary
114
+ if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
115
+ img = img.permute(0, 3, 1, 2)
116
+
117
+ cur_height, cur_width = img.shape[2:]
118
+
119
+ ratio = max(cur_width / width, cur_height / height)
120
+ resized_height = int(cur_height / ratio)
121
+ resized_width = int(cur_width / ratio)
122
+ resized_img = F.interpolate(
123
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
124
+ )
125
+
126
+ pad_height = max(0, int(height - resized_height))
127
+ pad_width = max(0, int(width - resized_width))
128
+
129
+ # pad on left and top of image
130
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
131
+ return padded_img
132
+
133
+ class PolicyPreprocessMixin:
134
+
135
+ @torch.no_grad
136
+ def select_action(
137
+ self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
138
+ ):
139
+ self.eval()
140
+ device = 'cuda'
141
+ if use_bf16:
142
+ dtype = torch.bfloat16
143
+ else:
144
+ dtype = torch.float32
145
+ s1 = time.time()
146
+
147
+ if len(observation['images'].shape) == 4:
148
+ observation['images'] = observation['images'].unsqueeze(0)
149
+ observation['img_masks'] = observation['img_masks'].unsqueeze(0)
150
+
151
+ if 'expert_imgs' in observation:
152
+ actions = self.model.sample_actions(
153
+ observation['images'].to(dtype=dtype, device=device),
154
+ observation['img_masks'].to(device=device),
155
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
156
+ observation['lang_masks'].unsqueeze(0).to(device=device),
157
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
158
+ observation['expert_imgs'].to(dtype=dtype, device=device),
159
+ vlm_causal = vlm_causal
160
+ )
161
+ else:
162
+ actions = self.model.sample_actions(
163
+ observation['images'].to(dtype=dtype, device=device),
164
+ observation['img_masks'].to(device=device),
165
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
166
+ observation['lang_masks'].unsqueeze(0).to(device=device),
167
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
168
+ vlm_causal = vlm_causal
169
+ )
170
+ delta_time = time.time() - s1
171
+ print(f'sample_actions cost {delta_time} s')
172
+ observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
173
+ if use_bf16:
174
+ observation['state'] = observation['state'].to(dtype=torch.float32)
175
+ data = self.normalizer.unnormalize(observation)
176
+ return data
177
+
178
+ class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
179
+ pass # Only combine necessary functions
180
+
181
+ class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
182
+ pass # Only combine necessary functions
183
+
184
+
185
+ def merge_qwen_config(policy_config, qwen_config):
186
+ if hasattr(qwen_config, 'to_dict'):
187
+ config_dict = qwen_config.to_dict()
188
+ else:
189
+ config_dict = qwen_config
190
+
191
+ text_keys = {
192
+ "hidden_size",
193
+ "intermediate_size",
194
+ "num_hidden_layers",
195
+ "num_attention_heads",
196
+ "num_key_value_heads",
197
+ "rms_norm_eps",
198
+ "rope_theta",
199
+ "vocab_size",
200
+ "max_position_embeddings",
201
+ "hidden_act",
202
+ "tie_word_embeddings",
203
+ "tokenizer_path",
204
+ }
205
+
206
+ for key in text_keys:
207
+ if key in config_dict:
208
+ setattr(policy_config, key, config_dict[key])
209
+ print(f"✅ Merged LLM: {key} = {config_dict[key]}")
210
+
211
+ if "vision_config" in config_dict:
212
+ policy_config.vision_config = qwen_config.vision_config
213
+ else:
214
+ print("⚠️ Warning: 'vision_config' not found in qwen_config!")
215
+
216
+ return policy_config
217
+
218
+
219
+ class QwenPiServer:
220
+ '''
221
+ policy wrapper to support action ensemble or chunk execution
222
+ '''
223
+ def __init__(
224
+ self,
225
+ path_to_pi_model="",
226
+ adaptive_ensemble_alpha=0.1,
227
+ action_ensemble_horizon=8,
228
+ use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
229
+ chunk_ret=False,
230
+ use_bf16=True,
231
+ use_fp32=False,
232
+ ) -> None:
233
+ assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
234
+ self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
235
+ self.use_length = use_length
236
+ self.chunk_ret = chunk_ret
237
+
238
+ self.task_description = None
239
+
240
+ self.vla = self.load_vla(path_to_pi_model)
241
+ self.vla = self.vla.cuda().eval()
242
+ if use_bf16:
243
+ self.vla = self.vla.to(torch.bfloat16)
244
+ elif use_fp32:
245
+ self.vla.model.float()
246
+ self.global_step = 0
247
+ self.last_action_chunk = None
248
+ self.use_bf16 = use_bf16
249
+ self.use_fp32 = use_fp32
250
+
251
+ def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
252
+ # load model
253
+
254
+ print(f"loading model from: {path_to_pi_model}")
255
+ config = PreTrainedConfig.from_pretrained(path_to_pi_model)
256
+
257
+ # load training config
258
+ training_config_path = Path(path_to_pi_model).parent.parent.parent/'lingbotvla_cli.yaml'
259
+ with open(training_config_path, 'r') as f:
260
+ training_config = yaml.safe_load(f)
261
+ f.close()
262
+
263
+ # update model config according to training config
264
+ training_model_config = training_config['model']
265
+ training_model_config.update(training_config['train'])
266
+ for k, v in training_model_config.items():
267
+ v = getattr(config, k, training_model_config[k])
268
+ setattr(config, k, v)
269
+
270
+ # Set attention_implementation to 'eager' to speed up evaluation.
271
+ config.attention_implementation = 'eager'
272
+
273
+ # set base model according to training config
274
+ training_base_model = training_config['model']['tokenizer_path']
275
+ if 'paligemma' in training_base_model:
276
+ model_name = 'pi0'
277
+ config.vocab_size = 257152 # set vocab size for paligamma
278
+ elif 'qwen2' in training_base_model.lower():
279
+ model_name = 'lingbotvla'
280
+ else:
281
+ raise ValueError(f"Unsupported base model of {path_to_pi_model}")
282
+ base_model_path = BASE_MODEL_PATH[model_name]
283
+ config.tokenizer_path = base_model_path
284
+ self.model_name = model_name
285
+
286
+ qwen_config = AutoConfig.from_pretrained(base_model_path)
287
+ config = merge_qwen_config(config, qwen_config)
288
+
289
+ if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
290
+ config.vocab_size = training_config['model']['vocab_size']
291
+ # load processors
292
+ self.processor = build_processor(base_model_path)
293
+ self.language_tokenizer = self.processor.tokenizer
294
+ self.image_processor = self.processor.image_processor
295
+ data_config = SimpleNamespace(**training_config['data'])
296
+
297
+ print('Initializing model ... ')
298
+
299
+ if 'paligemma' in training_base_model:
300
+ policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
301
+ else:
302
+ policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path)
303
+
304
+ load_model_weights(policy, path_to_pi_model, strict=True)
305
+
306
+ policy.feature_transform = None
307
+ self.data_config = data_config
308
+ self.config = config
309
+ self.joint_max_dim = training_config['train']['max_action_dim']
310
+ self.action_dim = training_config['train']['action_dim']
311
+ self.chunk_size = training_config['train']['chunk_size']
312
+ policy.action_dim = self.action_dim
313
+ policy.chunk_size = self.chunk_size
314
+ self.norm_stats_file = data_config.norm_stats_file
315
+ if 'align_params' in training_config['train']:
316
+ self.use_depth_align = True
317
+ else: self.use_depth_align = False
318
+ with open(self.norm_stats_file) as f:
319
+ self.norm_stats = json.load(f)
320
+ policy.normalizer = Normalizer(
321
+ norm_stats=self.norm_stats['norm_stats'],
322
+ from_file=True,
323
+ data_type='robotwin',
324
+ norm_type={
325
+ "observation.images.cam_high": "identity",
326
+ "observation.images.cam_left_wrist": "identity",
327
+ "observation.images.cam_right_wrist": "identity",
328
+ "observation.state": self.data_config.norm_type,
329
+ "action": self.data_config.norm_type,
330
+ },
331
+ )
332
+
333
+ print('Model initialized ... ')
334
+
335
+ return policy
336
+
337
+ def reset(self, robo_name, path_to_pi_model = None) -> None:
338
+
339
+ if path_to_pi_model is not None:
340
+ self.vla = self.load_vla(path_to_pi_model)
341
+ self.vla = self.vla.cuda().eval()
342
+ if self.use_bf16:
343
+ self.vla = self.vla.to(torch.bfloat16)
344
+ elif self.use_fp32:
345
+ self.vla.model.float()
346
+
347
+ self.global_step = 0
348
+ self.last_action_chunk = None
349
+
350
+ if getattr(self.data_config, 'norm_type', None) is None:
351
+ self.data_config.norm_type = 'meanstd'
352
+ if getattr(self.config, 'vlm_causal', None) is None:
353
+ self.config.vlm_causal = False
354
+ if getattr(self.config, 'qwenvl_bos', None) is None:
355
+ self.config.qwenvl_bos = False
356
+
357
+ # if update ckpt path
358
+ if path_to_pi_model is not None:
359
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
360
+ merged_weights = {}
361
+
362
+ for file_path in tqdm(all_safetensors):
363
+ with safe_open(file_path, framework="pt", device="cpu") as f:
364
+ for key in f.keys():
365
+ merged_weights[key] = f.get_tensor(key)
366
+
367
+ self.vla.load_state_dict(merged_weights, strict=True)
368
+
369
+ def resize_image(self, observation):
370
+ for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
371
+ assert image_feature in observation
372
+ assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
373
+ image = observation[image_feature]
374
+ img_pil = Image.fromarray(image)
375
+ image_size = getattr(self.data_config, 'img_size', 224)
376
+ img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)
377
+
378
+ # img_resized shape: C*H*W
379
+ img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
380
+ observation[image_feature] = img_resized / 255.
381
+
382
+ def infer(self, observation, center_crop=True):
383
+ """Generates an action with the VLA policy."""
384
+
385
+ # (If trained with image augmentations) Center crop image and then resize back up to original size.
386
+ # IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
387
+ # the original height and width by sqrt(0.9) -- not 0.9!
388
+ if 'reset' in observation and observation['reset']:
389
+ self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
390
+ return dict(action = None)
391
+
392
+ self.resize_image(observation)
393
+ for k, v in observation.items():
394
+ if isinstance(v, np.ndarray):
395
+ observation[k] = torch.from_numpy(v)
396
+
397
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
398
+ joint_max_dim = getattr(self, 'joint_max_dim')
399
+ action_dim = getattr(self, 'action_dim')
400
+ chunk_size = getattr(self, 'chunk_size')
401
+ normalized_observation = self.vla.normalizer.normalize(observation)
402
+ base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
403
+ left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
404
+ torch.uint8
405
+ )
406
+ right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
407
+ torch.uint8
408
+ )
409
+ obs_dict = {
410
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
411
+ "state": normalized_observation["observation.state"].to(torch.float32),
412
+ "prompt": [observation["task"]],
413
+ }
414
+ state = prepare_state(self.config, obs_dict)
415
+ lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
416
+ images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
417
+ observation = {
418
+ 'images': images,
419
+ 'img_masks': img_masks,
420
+ 'state': state,
421
+ 'lang_tokens': lang_tokens,
422
+ 'lang_masks': lang_masks,
423
+ }
424
+
425
+ if self.use_bf16:
426
+ observation['state'] = observation['state'].to(torch.bfloat16)
427
+
428
+ org_actions = ['action']
429
+ assert len(org_actions)==1, "Only support single action feature"
430
+ if self.chunk_ret:
431
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
432
+ action = action[:self.use_length, :self.action_dim]
433
+ else:
434
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
435
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
436
+ self.last_action_chunk = action.float().cpu().numpy()
437
+
438
+ if self.use_length > 0:
439
+ action = self.last_action_chunk[self.global_step % self.use_length]
440
+ action = action[:, :self.action_dim]
441
+ print(f"on server step: {self.global_step}")
442
+ self.global_step+=1
443
+
444
+ return dict(action = action)
445
+
446
+
447
+ import argparse
448
+ from .websocket_policy_server import WebsocketPolicyServer
449
+
450
+ def main():
451
+ parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
452
+
453
+ parser.add_argument(
454
+ "--model_path",
455
+ type=str,
456
+ )
457
+
458
+ parser.add_argument(
459
+ "--use_length",
460
+ type=int,
461
+ default=50,
462
+ help="used length of action chunk"
463
+ )
464
+
465
+ parser.add_argument(
466
+ "--chunk_ret",
467
+ type=bool,
468
+ default=True,
469
+ help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
470
+ )
471
+
472
+ parser.add_argument(
473
+ "--port",
474
+ type=int,
475
+ default=8006,
476
+ help="port of WebSocket"
477
+ )
478
+
479
+ parser.add_argument(
480
+ "--debug_infer_once",
481
+ action="store_true",
482
+ help="Run one infer with dummy observation then exit (for debugging infer() without WebSocket client)",
483
+ )
484
+
485
+ args = parser.parse_args()
486
+
487
+ model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret=args.chunk_ret)
488
+ if args.debug_infer_once:
489
+ # 调试用:不启动 WebSocket,只跑一次 infer,可在 infer / select_action 里下断点
490
+ dummy_obs = {
491
+ "observation.images.cam_high": np.zeros((224, 224, 3), dtype=np.uint8),
492
+ "observation.images.cam_left_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
493
+ "observation.images.cam_right_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
494
+ "observation.state": np.zeros(model.action_dim, dtype=np.float32),
495
+ "task": "dummy task for debug",
496
+ "reset": False,
497
+ }
498
+ out = model.infer(dummy_obs)
499
+ print("debug_infer_once result keys:", out.keys())
500
+ return
501
+ model_server = WebsocketPolicyServer(model, port=args.port)
502
+ model_server.serve_forever()
503
+
504
+
505
+ if __name__ == "__main__":
506
+ main()
deploy/lingbot_robotwin_policy_rep.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ import random
5
+ import numpy as np
6
+ from collections import deque
7
+ import torchvision
8
+ import yaml
9
+ from types import SimpleNamespace
10
+ from packaging.version import Version
11
+ from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
12
+ from glob import glob
13
+ from tqdm import tqdm
14
+ from safetensors import safe_open
15
+ from safetensors.torch import load_file
16
+ from pathlib import Path
17
+ from PIL import Image
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import Tensor, nn
21
+
22
+
23
+ import transformers
24
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
25
+ from transformers import (
26
+ AutoConfig,
27
+ PretrainedConfig,
28
+ PreTrainedModel,
29
+ AutoProcessor,
30
+ )
31
+
32
+ from lerobot.configs.policies import PreTrainedConfig
33
+ from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
34
+ from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
35
+ from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
36
+ from lingbotvla.models import build_processor
37
+
38
+
39
+ def set_seed_everywhere(seed: int):
40
+ """Sets the random seed for Python, NumPy, and PyTorch functions."""
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed_all(seed)
43
+ np.random.seed(seed)
44
+ random.seed(seed)
45
+ torch.backends.cudnn.deterministic = True
46
+ torch.backends.cudnn.benchmark = False
47
+ os.environ["PYTHONHASHSEED"] = str(seed)
48
+
49
+ set_seed_everywhere(42)
50
+
51
+ BASE_MODEL_PATH = {
52
+ 'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
53
+ 'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
54
+ }
55
+
56
+ def load_model_weights(policy, path_to_pi_model, strict=True):
57
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
58
+ merged_weights = {}
59
+
60
+ for file_path in tqdm(all_safetensors):
61
+ with safe_open(file_path, framework="pt", device="cpu") as f:
62
+ for key in f.keys():
63
+ merged_weights[key] = f.get_tensor(key)
64
+ policy.load_state_dict(merged_weights, strict=strict)
65
+
66
+
67
+ def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
68
+ crop_scale = 0.9
69
+ side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0))) # side length scale
70
+ out_size = (224, 224)
71
+
72
+ # Convert input to PIL Image
73
+ if isinstance(image, np.ndarray):
74
+ arr = image
75
+ if arr.dtype.kind == "f":
76
+ # If floats likely in [0,1], map to [0,255]
77
+ if arr.max() <= 1.0 and arr.min() >= 0.0:
78
+ arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
79
+ else:
80
+ arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
81
+ elif arr.dtype == np.uint16:
82
+ # Map 16-bit to 8-bit
83
+ arr = (arr / 257).astype(np.uint8)
84
+ elif arr.dtype != np.uint8:
85
+ arr = arr.astype(np.uint8)
86
+ pil = Image.fromarray(arr)
87
+ elif isinstance(image, Image.Image):
88
+ pil = image
89
+ else:
90
+ raise TypeError("image must be a numpy array or PIL.Image.Image")
91
+
92
+ # Force RGB for consistent output
93
+ pil = pil.convert("RGB")
94
+ W, H = pil.size
95
+
96
+ # Compute centered crop box (integer pixels)
97
+ crop_w = max(1, int(round(W * side_scale)))
98
+ crop_h = max(1, int(round(H * side_scale)))
99
+ left = (W - crop_w) // 2
100
+ top = (H - crop_h) // 2
101
+ right = left + crop_w
102
+ bottom = top + crop_h
103
+
104
+ cropped = pil.crop((left, top, right, bottom))
105
+ resized = cropped.resize(out_size, resample=Image.BILINEAR)
106
+ return resized
107
+
108
+ def resize_with_pad(img, width, height, pad_value=-1):
109
+ # assume no-op when width height fits already
110
+ if img.ndim != 4:
111
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
112
+
113
+ # channel last to channel first if necessary
114
+ if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
115
+ img = img.permute(0, 3, 1, 2)
116
+
117
+ cur_height, cur_width = img.shape[2:]
118
+
119
+ ratio = max(cur_width / width, cur_height / height)
120
+ resized_height = int(cur_height / ratio)
121
+ resized_width = int(cur_width / ratio)
122
+ resized_img = F.interpolate(
123
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
124
+ )
125
+
126
+ pad_height = max(0, int(height - resized_height))
127
+ pad_width = max(0, int(width - resized_width))
128
+
129
+ # pad on left and top of image
130
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
131
+ return padded_img
132
+
133
+ class PolicyPreprocessMixin:
134
+
135
+ @torch.no_grad
136
+ def select_action(
137
+ self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
138
+ ):
139
+ self.eval()
140
+ device = 'cuda'
141
+ if use_bf16:
142
+ dtype = torch.bfloat16
143
+ else:
144
+ dtype = torch.float32
145
+ s1 = time.time()
146
+
147
+ if len(observation['images'].shape) == 4:
148
+ observation['images'] = observation['images'].unsqueeze(0)
149
+ observation['img_masks'] = observation['img_masks'].unsqueeze(0)
150
+ state_indices = list(range(12)) + list(range(73, 75)) + list(range(12, 14)) + list(range(14, 73))
151
+ observation['state'] = observation['state'][state_indices]
152
+ if 'expert_imgs' in observation:
153
+ actions = self.model.sample_actions(
154
+ observation['images'].to(dtype=dtype, device=device),
155
+ observation['img_masks'].to(device=device),
156
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
157
+ observation['lang_masks'].unsqueeze(0).to(device=device),
158
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
159
+ observation['expert_imgs'].to(dtype=dtype, device=device),
160
+ vlm_causal = vlm_causal
161
+ )
162
+ else:
163
+ actions = self.model.sample_actions(
164
+ observation['images'].to(dtype=dtype, device=device),
165
+ observation['img_masks'].to(device=device),
166
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
167
+ observation['lang_masks'].unsqueeze(0).to(device=device),
168
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
169
+ vlm_causal = vlm_causal
170
+ )
171
+ action_indices = list(range(6)) + [14] + list(range(6, 12)) + [15]
172
+ actions = actions[:, :, action_indices]
173
+ delta_time = time.time() - s1
174
+ print(f'sample_actions cost {delta_time} s')
175
+ observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
176
+ if use_bf16:
177
+ observation['state'] = observation['state'].to(dtype=torch.float32)
178
+ data = self.normalizer.unnormalize(observation)
179
+ return data
180
+
181
+ class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
182
+ pass # Only combine necessary functions
183
+
184
+ class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
185
+ pass # Only combine necessary functions
186
+
187
+
188
+ def merge_qwen_config(policy_config, qwen_config):
189
+ if hasattr(qwen_config, 'to_dict'):
190
+ config_dict = qwen_config.to_dict()
191
+ else:
192
+ config_dict = qwen_config
193
+
194
+ text_keys = {
195
+ "hidden_size",
196
+ "intermediate_size",
197
+ "num_hidden_layers",
198
+ "num_attention_heads",
199
+ "num_key_value_heads",
200
+ "rms_norm_eps",
201
+ "rope_theta",
202
+ "vocab_size",
203
+ "max_position_embeddings",
204
+ "hidden_act",
205
+ "tie_word_embeddings",
206
+ "tokenizer_path",
207
+ }
208
+
209
+ for key in text_keys:
210
+ if key in config_dict:
211
+ setattr(policy_config, key, config_dict[key])
212
+ print(f"✅ Merged LLM: {key} = {config_dict[key]}")
213
+
214
+ if "vision_config" in config_dict:
215
+ policy_config.vision_config = qwen_config.vision_config
216
+ else:
217
+ print("⚠️ Warning: 'vision_config' not found in qwen_config!")
218
+
219
+ return policy_config
220
+
221
+
222
+ class QwenPiServer:
223
+ '''
224
+ policy wrapper to support action ensemble or chunk execution
225
+ '''
226
+ def __init__(
227
+ self,
228
+ path_to_pi_model="",
229
+ adaptive_ensemble_alpha=0.1,
230
+ action_ensemble_horizon=8,
231
+ use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
232
+ chunk_ret=False,
233
+ use_bf16=True,
234
+ use_fp32=False,
235
+ ) -> None:
236
+ assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
237
+ self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
238
+ self.use_length = use_length
239
+ self.chunk_ret = chunk_ret
240
+
241
+ self.task_description = None
242
+
243
+ self.vla = self.load_vla(path_to_pi_model)
244
+ self.vla = self.vla.cuda().eval()
245
+ if use_bf16:
246
+ self.vla = self.vla.to(torch.bfloat16)
247
+ elif use_fp32:
248
+ self.vla.model.float()
249
+ self.global_step = 0
250
+ self.last_action_chunk = None
251
+ self.use_bf16 = use_bf16
252
+ self.use_fp32 = use_fp32
253
+
254
+ def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
255
+ # load model
256
+ print(f"loading model from: {path_to_pi_model}")
257
+ config = PreTrainedConfig.from_pretrained(path_to_pi_model)
258
+
259
+ # load training config
260
+ training_config_path = Path(path_to_pi_model)/'lingbotvla_cli.yaml'
261
+ with open(training_config_path, 'r') as f:
262
+ training_config = yaml.safe_load(f)
263
+ f.close()
264
+
265
+ # update model config according to training config
266
+ training_model_config = training_config['model']
267
+ training_model_config.update(training_config['train'])
268
+ for k, v in training_model_config.items():
269
+ v = getattr(config, k, training_model_config[k])
270
+ setattr(config, k, v)
271
+
272
+ # Set attention_implementation to 'eager' to speed up evaluation.
273
+ config.attention_implementation = 'eager'
274
+
275
+ # set base model according to training config
276
+ training_base_model = os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/')
277
+ if 'paligemma' in training_base_model:
278
+ model_name = 'pi0'
279
+ config.vocab_size = 257152 # set vocab size for paligamma
280
+ elif 'qwen2' in training_base_model.lower():
281
+ model_name = 'lingbotvla'
282
+ else:
283
+ raise ValueError(f"Unsupported base model of {path_to_pi_model}")
284
+ base_model_path = BASE_MODEL_PATH[model_name]
285
+ config.tokenizer_path = base_model_path
286
+ self.model_name = model_name
287
+
288
+ qwen_config = AutoConfig.from_pretrained(base_model_path)
289
+ config = merge_qwen_config(config, qwen_config)
290
+
291
+ if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
292
+ config.vocab_size = training_config['model']['vocab_size']
293
+ # load processors
294
+ self.processor = build_processor(base_model_path)
295
+ self.language_tokenizer = self.processor.tokenizer
296
+ self.image_processor = self.processor.image_processor
297
+ data_config = SimpleNamespace(**training_config['data'])
298
+
299
+ print('Initializing model ... ')
300
+
301
+ if 'paligemma' in training_base_model:
302
+ policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
303
+ else:
304
+ policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path, eval=True)
305
+
306
+ load_model_weights(policy, path_to_pi_model, strict=True)
307
+
308
+ policy.feature_transform = None
309
+ self.data_config = data_config
310
+ self.config = config
311
+ self.joint_max_dim = training_config['train']['max_action_dim']
312
+ self.action_dim = training_config['train']['action_dim']
313
+ self.chunk_size = training_config['train']['chunk_size']
314
+ policy.action_dim = self.action_dim
315
+ policy.chunk_size = self.chunk_size
316
+ self.norm_stats_file = 'assets/norm_stats/robotwin_all_new.json'
317
+ if 'align_params' in training_config['train']:
318
+ self.use_depth_align = True
319
+ else: self.use_depth_align = False
320
+ with open(self.norm_stats_file) as f:
321
+ self.norm_stats = json.load(f)
322
+ policy.normalizer = Normalizer(
323
+ norm_stats=self.norm_stats['norm_stats'],
324
+ from_file=True,
325
+ data_type='robotwin_rep',
326
+ norm_type={
327
+ "observation.images.cam_high": "identity",
328
+ "observation.images.cam_left_wrist": "identity",
329
+ "observation.images.cam_right_wrist": "identity",
330
+ "observation.state": self.data_config.norm_type,
331
+ "action": self.data_config.norm_type,
332
+ },
333
+ )
334
+
335
+ print('Model initialized ... ')
336
+
337
+ return policy
338
+
339
+ def reset(self, robo_name, path_to_pi_model = None) -> None:
340
+
341
+ if path_to_pi_model is not None:
342
+ self.vla = self.load_vla(path_to_pi_model)
343
+ self.vla = self.vla.cuda().eval()
344
+ if self.use_bf16:
345
+ self.vla = self.vla.to(torch.bfloat16)
346
+ elif self.use_fp32:
347
+ self.vla.model.float()
348
+
349
+ self.global_step = 0
350
+ self.last_action_chunk = None
351
+
352
+ if getattr(self.data_config, 'norm_type', None) is None:
353
+ self.data_config.norm_type = 'meanstd'
354
+ if getattr(self.config, 'vlm_causal', None) is None:
355
+ self.config.vlm_causal = False
356
+ if getattr(self.config, 'qwenvl_bos', None) is None:
357
+ self.config.qwenvl_bos = False
358
+
359
+ # if update ckpt path
360
+ if path_to_pi_model is not None:
361
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
362
+ merged_weights = {}
363
+
364
+ for file_path in tqdm(all_safetensors):
365
+ with safe_open(file_path, framework="pt", device="cpu") as f:
366
+ for key in f.keys():
367
+ merged_weights[key] = f.get_tensor(key)
368
+
369
+ self.vla.load_state_dict(merged_weights, strict=True)
370
+
371
+ def resize_image(self, observation):
372
+ for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
373
+ assert image_feature in observation
374
+ assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
375
+ image = observation[image_feature]
376
+ img_pil = Image.fromarray(image)
377
+ image_size = getattr(self.data_config, 'img_size', 224)
378
+ img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)
379
+
380
+ # img_resized shape: C*H*W
381
+ img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
382
+ observation[image_feature] = img_resized / 255.
383
+
384
+ def infer(self, observation, center_crop=True):
385
+ """Generates an action with the VLA policy."""
386
+
387
+ # (If trained with image augmentations) Center crop image and then resize back up to original size.
388
+ # IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
389
+ # the original height and width by sqrt(0.9) -- not 0.9!
390
+ if 'reset' in observation and observation['reset']:
391
+ self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
392
+ return dict(action = None)
393
+
394
+ self.resize_image(observation)
395
+ for k, v in observation.items():
396
+ if isinstance(v, np.ndarray):
397
+ observation[k] = torch.from_numpy(v)
398
+
399
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
400
+ joint_max_dim = getattr(self, 'joint_max_dim')
401
+ action_dim = getattr(self, 'action_dim')
402
+ chunk_size = getattr(self, 'chunk_size')
403
+ indices = list(range(6)) + list(range(7, 13)) + [6] + [13]
404
+ observation["observation.state"] = observation["observation.state"][indices]
405
+ normalized_observation = self.vla.normalizer.normalize(observation)
406
+ base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
407
+ left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
408
+ torch.uint8
409
+ )
410
+ right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
411
+ torch.uint8
412
+ )
413
+ obs_dict = {
414
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
415
+ "state": normalized_observation["observation.state"].to(torch.float32),
416
+ "prompt": [observation["task"]],
417
+ }
418
+ state = prepare_state(self.config, obs_dict)
419
+ lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
420
+ images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
421
+ observation = {
422
+ 'images': images,
423
+ 'img_masks': img_masks,
424
+ 'state': state,
425
+ 'lang_tokens': lang_tokens,
426
+ 'lang_masks': lang_masks,
427
+ }
428
+
429
+ if self.use_bf16:
430
+ observation['state'] = observation['state'].to(torch.bfloat16)
431
+
432
+ org_actions = ['action']
433
+ assert len(org_actions)==1, "Only support single action feature"
434
+ if self.chunk_ret:
435
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
436
+ action = action[:self.use_length, :self.action_dim]
437
+ else:
438
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
439
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
440
+ self.last_action_chunk = action.float().cpu().numpy()
441
+
442
+ if self.use_length > 0:
443
+ action = self.last_action_chunk[self.global_step % self.use_length]
444
+ action = action[:, :self.action_dim]
445
+ print(f"on server step: {self.global_step}")
446
+ self.global_step+=1
447
+
448
+ return dict(action = action)
449
+
450
+
451
+ import argparse
452
+ from .websocket_policy_server import WebsocketPolicyServer
453
+
454
+ def main():
455
+ parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
456
+
457
+ parser.add_argument(
458
+ "--model_path",
459
+ type=str,
460
+ )
461
+
462
+ parser.add_argument(
463
+ "--use_length",
464
+ type=int,
465
+ default=50,
466
+ help="used length of action chunk"
467
+ )
468
+
469
+ parser.add_argument(
470
+ "--chunk_ret",
471
+ type=bool,
472
+ default=True,
473
+ help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
474
+ )
475
+
476
+ parser.add_argument(
477
+ "--port",
478
+ type=int,
479
+ default=8006,
480
+ help="port of WebSocket"
481
+ )
482
+
483
+ args = parser.parse_args()
484
+
485
+ model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret = args.chunk_ret)
486
+ model_server = WebsocketPolicyServer(model, port=args.port)
487
+ model_server.serve_forever()
488
+
489
+
490
+ if __name__ == "__main__":
491
+ main()
deploy/msgpack_numpy.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adds NumPy array support to msgpack.
2
+
3
+ msgpack is good for (de)serializing data over a network for multiple reasons:
4
+ - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
5
+ - msgpack is widely used and has good cross-language support
6
+ - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
7
+ languages like Python and JavaScript
8
+ - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
9
+ than pickle for serializing large arrays using the below strategy
10
+
11
+ The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
12
+ that it falls back to pickle for object arrays.
13
+ """
14
+
15
+ import functools
16
+
17
+ import msgpack
18
+ import numpy as np
19
+
20
+
21
+ def pack_array(obj):
22
+ if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
23
+ raise ValueError(f"Unsupported dtype: {obj.dtype}")
24
+
25
+ if isinstance(obj, np.ndarray):
26
+ return {
27
+ b"__ndarray__": True,
28
+ b"data": obj.tobytes(),
29
+ b"dtype": obj.dtype.str,
30
+ b"shape": obj.shape,
31
+ }
32
+
33
+ if isinstance(obj, np.generic):
34
+ return {
35
+ b"__npgeneric__": True,
36
+ b"data": obj.item(),
37
+ b"dtype": obj.dtype.str,
38
+ }
39
+
40
+ return obj
41
+
42
+
43
+ def unpack_array(obj):
44
+ if b"__ndarray__" in obj:
45
+ return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
46
+
47
+ if b"__npgeneric__" in obj:
48
+ return np.dtype(obj[b"dtype"]).type(obj[b"data"])
49
+
50
+ return obj
51
+
52
+
53
+ Packer = functools.partial(msgpack.Packer, default=pack_array)
54
+ packb = functools.partial(msgpack.packb, default=pack_array)
55
+
56
+ Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
57
+ unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
deploy/websocket_client_policy.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ from typing_extensions import override
6
+ import websockets.sync.client
7
+ from .msgpack_numpy import Packer, unpackb
8
+
9
+
10
+ class WebsocketClientPolicy:
11
+ """Implements the Policy interface by communicating with a server over websocket.
12
+
13
+ See WebsocketPolicyServer for a corresponding server implementation.
14
+ """
15
+
16
+ def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
17
+ self._uri = f"ws://{host}"
18
+ if port is not None:
19
+ self._uri += f":{port}"
20
+ self._packer = Packer()
21
+ self._api_key = api_key
22
+ self._ws, self._server_metadata = self._wait_for_server()
23
+
24
+ def get_server_metadata(self) -> Dict:
25
+ return self._server_metadata
26
+
27
+ def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
28
+ logging.info(f"Waiting for server at {self._uri}...")
29
+ while True:
30
+ try:
31
+ headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
32
+ conn = websockets.sync.client.connect(
33
+ self._uri, compression=None, max_size=None, additional_headers=headers
34
+ )
35
+ metadata = unpackb(conn.recv())
36
+ return conn, metadata
37
+ except ConnectionRefusedError:
38
+ logging.info("Still waiting for server...")
39
+ time.sleep(5)
40
+
41
+ @override
42
+ def infer(self, obs: Dict) -> Dict: # noqa: UP006
43
+ data = self._packer.pack(obs)
44
+ self._ws.send(data)
45
+ response = self._ws.recv()
46
+ if isinstance(response, str):
47
+ # we're expecting bytes; if the server sends a string, it's an error.
48
+ raise RuntimeError(f"Error in inference server:\n{response}")
49
+ return unpackb(response)
50
+
51
+ @override
52
+ def reset(self, robo_name: str) -> None:
53
+ self.infer(dict(reset=True, robo_name=robo_name))
54
+
55
+ if __name__ == "__main__":
56
+ policy_on_device = WebsocketClientPolicy(port=8000)
57
+ import torch
58
+ import numpy as np
59
+ from PIL import Image
60
+ from .image_tools import convert_to_uint8
61
+ device = torch.device("cuda")
62
+
63
+ base_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
64
+ left_wrist_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
65
+ state = np.random.rand(1,8).astype(np.float32)
66
+ prompt = ["do something"]
67
+
68
+ # observation = {
69
+ # "image": {
70
+ # "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None],
71
+ # "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None],
72
+ # },
73
+ # "state": torch.from_numpy(state).to(device)[None],
74
+ # "prompt": prompt,
75
+ # }
76
+
77
+ observation = {
78
+ "image": {
79
+ "base_0_rgb": convert_to_uint8(base_0_rgb),
80
+ "left_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
81
+ "right_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
82
+ },
83
+ "state": state,
84
+ "prompt": prompt,
85
+ }
86
+
87
+ policy_on_device.infer(observation)
88
+ from IPython import embed;embed()
deploy/websocket_policy_server.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import http
3
+ import logging
4
+ import time
5
+ import traceback
6
+
7
+ from .msgpack_numpy import Packer, unpackb
8
+ import websockets.asyncio.server as _server
9
+ import websockets.frames
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class WebsocketPolicyServer:
15
+ """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
16
+
17
+ Currently only implements the `load` and `infer` methods.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ policy,
23
+ host: str = "0.0.0.0",
24
+ port: int | None = None,
25
+ metadata: dict | None = None,
26
+ ) -> None:
27
+ self._policy = policy
28
+ self._host = host
29
+ self._port = port
30
+ self._metadata = metadata or {}
31
+ logging.getLogger("websockets.server").setLevel(logging.INFO)
32
+
33
+ def serve_forever(self) -> None:
34
+ asyncio.run(self.run())
35
+
36
+ async def run(self):
37
+ async with _server.serve(
38
+ self._handler,
39
+ self._host,
40
+ self._port,
41
+ compression=None,
42
+ max_size=None,
43
+ process_request=_health_check,
44
+ ) as server:
45
+ await server.serve_forever()
46
+
47
+ async def _handler(self, websocket: _server.ServerConnection):
48
+ logger.info(f"Connection from {websocket.remote_address} opened")
49
+ packer = Packer()
50
+
51
+ await websocket.send(packer.pack(self._metadata))
52
+
53
+ prev_total_time = None
54
+ while True:
55
+ try:
56
+ start_time = time.monotonic()
57
+ obs = unpackb(await websocket.recv())
58
+
59
+ infer_time = time.monotonic()
60
+ action = self._policy.infer(obs)
61
+ infer_time = time.monotonic() - infer_time
62
+
63
+ action["server_timing"] = {
64
+ "infer_ms": infer_time * 1000,
65
+ }
66
+ if prev_total_time is not None:
67
+ # We can only record the last total time since we also want to include the send time.
68
+ action["server_timing"]["prev_total_ms"] = prev_total_time * 1000
69
+
70
+ await websocket.send(packer.pack(action))
71
+ prev_total_time = time.monotonic() - start_time
72
+
73
+ except websockets.ConnectionClosed:
74
+ logger.info(f"Connection from {websocket.remote_address} closed")
75
+ break
76
+ except Exception:
77
+ await websocket.send(traceback.format_exc())
78
+ await websocket.close(
79
+ code=websockets.frames.CloseCode.INTERNAL_ERROR,
80
+ reason="Internal server error. Traceback included in previous frame.",
81
+ )
82
+ raise
83
+
84
+
85
+ def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
86
+ if request.path == "/healthz":
87
+ return connection.respond(http.HTTPStatus.OK, "OK\n")
88
+ # Continue with the normal request handling.
89
+ return None
docker/Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
2
+ # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
3
+ FROM nvcr.io/nvidia/pytorch:24.08-py3
4
+
5
+ # Define environments
6
+ ENV MAX_JOBS=32
7
+ ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
8
+ ENV DEBIAN_FRONTEND=noninteractive
9
+ ENV NODE_OPTIONS=""
10
+
11
+
12
+ # Install systemctl and tini
13
+ RUN apt-get update && \
14
+ apt-get install -y -o Dpkg::Options::="--force-confdef" systemd tini && \
15
+ apt-get clean || { echo "Installation failed"; exit 1; }
16
+
17
+ RUN apt-get install -y tzdata \
18
+ && ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
19
+ && dpkg-reconfigure -f noninteractive tzdata
20
+
21
+ # Change pip source
22
+ RUN python -m pip install --upgrade pip
23
+
24
+ # Install torch-2.5.1 + vllm-0.7.3
25
+ RUN pip install --no-cache-dir vllm==0.7.3 torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \
26
+ transformers>=4.49.0 accelerate datasets peft hf-transfer diffusers \
27
+ codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb ninja liger-kernel \
28
+ pytest yapf py-spy pyext pre-commit ruff packaging
29
+
30
+ # Install flux
31
+ RUN pip install --no-cache-dir byte-flux
32
+
33
+ # Install flash-attn and triton
34
+ RUN pip install --no-cache-dir flash-attn triton>=3.1.0
docs/Makefile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal makefile for Sphinx documentation
2
+ #
3
+
4
+ # You can set these variables from the command line.
5
+ SPHINXOPTS =
6
+ SPHINXBUILD = sphinx-build
7
+ SPHINXPROJ = LingBotVLA
8
+ SOURCEDIR = .
9
+ BUILDDIR = _build
10
+
11
+ # Put it first so that "make" without argument is like "make help".
12
+ help:
13
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14
+
15
+ .PHONY: help Makefile
16
+
17
+ # Catch-all target: route all unknown targets to Sphinx using the new
18
+ # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19
+ %: Makefile
20
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
docs/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LingBotVLA documents
2
+
3
+ ## Build the docs
4
+
5
+ ```bash
6
+ # Install dependencies.
7
+ pip install -r requirements-docs.txt
8
+
9
+ # Build the docs.
10
+ make clean
11
+ make html
12
+ ```
13
+
14
+ ## Open the docs with your browser
15
+
16
+ ```bash
17
+ python -m http.server -d _build/html/
18
+ ```
19
+ Launch your browser and open localhost:8000.
docs/conf.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration file for the Sphinx documentation builder.
2
+ #
3
+ # This file only contains a selection of the most common options. For a full
4
+ # list see the documentation:
5
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
6
+
7
+ # -- Path setup --------------------------------------------------------------
8
+
9
+ # If extensions (or modules to document with autodoc) are in another directory,
10
+ # add these directories to sys.path here. If the directory is relative to the
11
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
12
+ #
13
+ # import os
14
+ # import sys
15
+ # sys.path.insert(0, os.path.abspath('.'))
16
+
17
+
18
+ # -- Project information -----------------------------------------------------
19
+
20
+ project = "LingBotVLA"
21
+ # pylint: disable=W0622
22
+ copyright = "2026 Robbyant Team, based on VeOmni by ByteDance Seed Foundation MLSys Team"
23
+
24
+ # -- General configuration ---------------------------------------------------
25
+ # The master toctree document.
26
+ master_doc = "index"
27
+
28
+ # Add any Sphinx extension module names here, as strings. They can be
29
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
30
+ # ones.
31
+ extensions = [
32
+ "recommonmark",
33
+ "sphinx.ext.autosectionlabel",
34
+ ]
35
+
36
+ # The suffix(es) of source filenames.
37
+ # You can specify multiple suffix as a list of string:
38
+ source_suffix = [".rst", "rest", ".md"]
39
+
40
+ # Add any paths that contain templates here, relative to this directory.
41
+ templates_path = ["_templates"]
42
+
43
+ # The language for content autogenerated by Sphinx. Refer to documentation
44
+ # for a list of supported languages.
45
+ #
46
+ # This is also used if you do content translation via gettext catalogs.
47
+ # Usually you set "language" from the command line for these cases.
48
+ language = "en"
49
+
50
+ # List of patterns, relative to source directory, that match files and
51
+ # directories to ignore when looking for source files.
52
+ # This pattern also affects html_static_path and html_extra_path.
53
+ exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
54
+
55
+
56
+ # -- Options for HTML output -------------------------------------------------
57
+
58
+ # The theme to use for HTML and HTML Help pages. See the documentation for
59
+ # a list of builtin themes.
60
+ #
61
+ html_theme = "sphinx_rtd_theme"
62
+
63
+ # Add any paths that contain custom static files (such as style sheets) here,
64
+ # relative to this directory. They are copied after the builtin static files,
65
+ # so a file named "default.css" will overwrite the builtin "default.css".
66
+ html_static_path = ["_static"]
docs/config/config.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Config arguments Explanation
2
+ ### Model configuration arguments
3
+ | Name | Type | Description | Default Value |
4
+ | --- | --- | --- | --- |
5
+ | model.config_path | str | Path to the model huggingface configuration, like `config.json` | model.model_path |
6
+ | model.model_path | str | Path to the model parameter file. If empty, random initialization will be performed | None |
7
+ | model.tokenizer_path | str | Path to the tokenizer | model.model_path |
8
+ | model.encoders | dict | Configuration file for multi-modal encoders | {} |
9
+ | model.decoders | dict | Configuration file for multi-modal decoders | {} |
10
+ | model.input_encoder | str: {"encoder", "decoder"} | Use the encoder of the encoder or decoder to encode the input image | encoder |
11
+ | model.output_encoder | str: {"encoder", "decoder"} | Use the encoder of the encoder or decoder to encode the output image | decoder |
12
+ | model.encode_target | bool | Used to encode the training data for the diffusion model | False |
13
+
14
+ ### Data configuration arguments
15
+
16
+ | Name | Type | Description | Default Value |
17
+ | --- | --- | --- | --- |
18
+ | data.train_path | str | Path of training dataset | Required |
19
+ | data.train_size | int | Total number of tokens in the training set | 10,000,000 |
20
+ | data.data_type | str: {"plaintext", "conversation"} | Dataset type. | conversation |
21
+ | data.dataloader_type | str: {"native"} | Use the pytorch dataloader or | native |
22
+ | data.datasets_type | str: {"mapping", "iterable"} | Dataset type. `IterativeDataset` or `MappingDataset`, or your custom datsets | mapping |
23
+ | data.text_keys | str: {"content_split", "messages"} | The key corresponding to the text samples in the data dictionary. Generally, it is "content_split" for pretraining and "messages" for SFT. | content_split |
24
+ | data.image_keys | str | The key corresponding to the image samples in the data dictionary. Generally, it is "images". | images |
25
+ | data.chat_template | str | Name of the chat template. | default |
26
+ | data.max_seq_len | int | Maximum training length. | 2048 |
27
+ | data.num_workers | int | Number of multi-process loaders for the dataloader. | 4 |
28
+ | data.drop_last | bool | Whether to discard the remaining data at the end. | True |
29
+ | data.pin_memory | bool | Whether to pin the data in the CPU memory. | True |
30
+ | data.prefetch_factor | int | Number of samples preprocessed by the dataloader. | 2 |
31
+
32
+ #### Training configuration arguments
33
+ | Name | Type | Description | Default Value |
34
+ | --- | --- | --- | --- |
35
+ | train.output_dir | str | Path to save the model. | Required |
36
+ | train.lr | float | Maximum learning rate. | 5e - 5 |
37
+ | train.lr_min | float | Minimum learning rate. | 1e - 7 |
38
+ | train.weight_decay | float | Weight decay coefficient. | 0 |
39
+ | train.optimizer | str: {"adamw", "anyprecision_adamw"} | Name of the optimizer. | adamw |
40
+ | train.max_grad_norm | float | Gradient clipping norm. | 1.0 |
41
+ | train.micro_batch_size | int | Number of samples processed simultaneously on each GPU. | 1 |
42
+ | train.global_batch_size | int | Global batch size, which must be a multiple of the number of GPUs. | train.micro_batch_size * n_gpus |
43
+ | train.num_train_epochs | int | Number of training epochs. | 1 |
44
+ | train.rmpad | bool | Whether to use rmpad training based on cu_seqlens. | False |
45
+ | train.rmpad_with_pos_ids | bool | Whether to use rmpad training based on position_ids. | False |
46
+ | train.dyn_bsz_margin | int | Number of pad tokens in the dynamic batch. | 0 |
47
+ | train.dyn_bsz_runtime | str: {"main", "worker"} | Running process of the dynamic batch. | worker |
48
+ | train.bsz_warmup_ratio | float | Proportion of batch size warmup in the total number of steps. | 0 |
49
+ | train.lr_warmup_ratio | float | Proportion of learning rate warmup in the total number of steps. | 0 |
50
+ | train.lr_decay_style | str: {"constant", "linear", "cosine"} | Name of the learning rate scheduler. | cosine |
51
+ | train.lr_decay_ratio | float | Proportion of learning rate decay in the total number of steps | 1.0 |
52
+ | train.use_doptim | bool | Whether to use the distributed optimizer during Vescale training(no use for torch fsdp) | False |
53
+ | train.enable_mixed_precision | bool | Whether to enable mixed precision training (higher memory usage but more stable) | True |
54
+ | train.enable_gradient_checkpointing | bool | Whether to enable gradient checkpointing to reduce memory usage. | True |
55
+ | train.enable_reentrant | bool | Whether to enable reentrant in gradient checkpointing. | True |
56
+ | train.enable_full_shard | bool | Whether to use full sharding FSDP (equivalent to ZeRO3). | True |
57
+ | train.enable_fsdp_offload | bool | Whether to enable FSDP CPU offloading (only supported for FSDP1). | False |
58
+ | train.enable_activation_offload | bool | Whether to enable activation value CPU offloading. | False |
59
+ | train.activation_gpu_limit | float | Size of the activation values retained on the GPU (in GB). | 0.0 |
60
+ | train.enable_manual_eager | bool | Whether to use manual eager during Vescale training. | False |
61
+ | train.init_device: meta | str | "cpu", "cuda", "meta", init device for model initialization. use "meta" or cpu for large model(>30B) | cuda |
62
+ | train.enable_full_determinism | bool | Whether to enable deterministic mode (for bitwise alignment). | False |
63
+ | train.empty_cache_steps | int | Number of steps between two cache clearings. -1 means not enabled. | 500 |
64
+ | train.data_parallel_mode | str: {"ddp", "fsdp1", "fsdp2"} | Data parallel algorithm. | ddp |
65
+ | train.tensor_parallel_size | int | Tensor parallel size (currently only supported for vescale training). | 1 |
66
+ | train.pipeline_parallel_size | int | Pipeline parallel size (currently not supported). | 1 |
67
+ | train.ulysses_parallel_size | int | Ulysses sequence parallel size (currently only supported for P6dense and Qwen2VL). | 1 |
68
+ | train.context_parallel_size | int | Ring sequence parallel size (currently not supported) | 1 |
69
+ | train.expert_parallel_size | int | Expert parallel size (currently only supported DeepseekMOE) | 1 |
70
+ | train.load_checkpoint_path | str | Path to the omnistore checkpoint for resuming training. | None |
71
+ | train.save_steps | int | Number of steps between two checkpoint saves. 0 means invalid. | 0 |
72
+ | train.save_epochs | int | Number of epochs between two checkpoint saves. 0 means invalid. | 1 |
73
+ | train.save_hf_weights | bool | Whether to save the model weights in the huggingface format. It is recommended to set it to False for models > 30B to prevent NCCL timeout. You can convert it after training. | True |
74
+ | train.seed | int | Random seed. | 42 |
75
+ | train.use_wandb | bool | Whether to enable byted wandb experiment logging. | True |
76
+ | train.wandb_project | str | Name of the wandb experiment project. | LingBotVLA |
77
+ | train.wandb_name | str | Name of the wandb experiment. | None |
78
+ | train.enable_profiling | bool | Whether to use torch profiling. | False |
79
+ | train.profile_start_step | int | Starting step of profiling. | 1 |
80
+ | train.profile_end_step | int | Ending step of profiling. | 2 |
81
+ | train.profile_trace_dir | str | Path to save the profiling results. | ./trace |
82
+ | train.profile_record_shapes | bool | Whether to record the shapes of the input tensors. | True |
83
+ | train.profile_profile_memory | bool | Whether to record the memory usage. | True |
84
+ | train.profile_with_stack | bool | Whether to record the stack information. | True |
85
+ | train.max_steps | int | Number of steps per training epoch (only used for debugging). | None |
86
+
87
+ ### Inference configuration arguments
88
+ | Name | Type | Description | Default Value |
89
+ | --- | --- | --- | --- |
90
+ | infer.model_path | str | Path to the model parameter file. | Required |
91
+ | infer.tokenizer_path | str | Path to the tokenizer. | model.model_path |
92
+ | infer.seed | int | Random seed. | 42 |
93
+ | infer.do_sample | bool | Whether to enable sampling. | True |
94
+ | infer.temperature | float | Sampling temperature. | 1.0 |
95
+ | infer.top_p | float | Sampling Top P value. | 1.0 |
96
+ | infer.max_tokens | int | Maximum number of tokens generated each time. | 1024 |
docs/examples/qwen2vl.rst ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Qwen2VL example
2
+ =========================
docs/examples/qwen3_moe.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Qwen3 MoE training guide
2
+
3
+ 1. Download qwen3 moe model
4
+
5
+ ```shell
6
+ python3 scripts/download_hf_model.py \
7
+ --repo_id Qwen/Qwen3-30B-A3B \
8
+ --local_dir .
9
+ ```
10
+
11
+ 2. Merge qwen3 moe model experts to support GroupGemm optimize
12
+ ``` shell
13
+ python3 scripts/moe_ckpt_merge/moe_merge.py --raw_hf_path Qwen3-30B-A3B --merge_hf_path Qwen3-30B-A3B-merge
14
+ ```
15
+
16
+ Most of the MoE models in Transformers referenced the open-source implementation of Mixtral MoE. In this implementation, MoE experts are divided into multiple blocks instead of being combined into a single `nn.Parameters`. Additionally, there are cpu-block operators like `torch.where()` and for loop, which are not very friendly for integrating MoE fusion operators.
17
+
18
+ Origin [Qwen3MoeMLP](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L200C1-L213C25) code
19
+ ```python
20
+ class Qwen3MoeMLP(nn.Module):
21
+ def __init__(self, config, intermediate_size=None):
22
+ super().__init__()
23
+ self.config = config
24
+ self.hidden_size = config.hidden_size
25
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
26
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
27
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
28
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
29
+ self.act_fn = ACT2FN[config.hidden_act]
30
+
31
+ def forward(self, x):
32
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
33
+ return down_proj
34
+
35
+ class Qwen3MoeSparseMoeBlock(nn.Module):
36
+ def __init__(self, config):
37
+
38
+ ...
39
+
40
+ self.experts = nn.ModuleList(
41
+ [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
42
+ )
43
+
44
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
45
+
46
+ ...
47
+
48
+ final_hidden_states = torch.zeros(
49
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
50
+ )
51
+
52
+ for expert_idx in expert_hitted:
53
+ expert_layer = self.experts[expert_idx]
54
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
55
+
56
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
57
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
58
+
59
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
60
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
61
+ return final_hidden_states, router_logits
62
+
63
+ ```
64
+
65
+ - Combine Qwen3MoeMLP to Qwen3MoeExperts, then use fused moe operator
66
+
67
+ ```python
68
+ class Qwen3MoeExperts(nn.Module):
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.num_experts = config.num_experts
72
+ self.hidden_dim = config.hidden_size
73
+ self.intermediate_size = config.moe_intermediate_size
74
+ self.gate_proj = torch.nn.Parameter(
75
+ torch.empty(self.num_experts, self.intermediate_size, self.hidden_dim),
76
+ requires_grad=True,
77
+ )
78
+ self.up_proj = torch.nn.Parameter(
79
+ torch.empty(self.num_experts, self.intermediate_size, self.hidden_dim),
80
+ requires_grad=True,
81
+ )
82
+ self.down_proj = torch.nn.Parameter(
83
+ torch.empty(self.num_experts, self.hidden_dim, self.intermediate_size),
84
+ requires_grad=True,
85
+ )
86
+ self.act_fn = ACT2FN[config.hidden_act]
87
+
88
+ def forward(self, hidden_states, expert_idx=None, cumsum=None):
89
+ gate_proj_out = torch.matmul(hidden_states, self.gate_proj[expert_idx].transpose(0, 1))
90
+ up_proj_out = torch.matmul(hidden_states, self.up_proj[expert_idx].transpose(0, 1))
91
+
92
+ out = self.act_fn(gate_proj_out) * up_proj_out
93
+ out = torch.matmul(out, self.down_proj[expert_idx].transpose(0, 1))
94
+ return out
95
+
96
+
97
+ class Qwen3MoeSparseFusedMoeBlock(nn.Module):
98
+ def __init__(self, config):
99
+
100
+ ...
101
+
102
+ self.experts = Qwen3MoeExperts(config)
103
+
104
+ def forward(self, hidden_states, expert_idx=None, routing_weights=None, selected_experts=None) -> torch.Tensor:
105
+
106
+ ...
107
+
108
+ out = fused_moe_forward(
109
+ module=self,
110
+ num_experts=self.num_experts,
111
+ routing_weights=routing_weights,
112
+ selected_experts=selected_experts,
113
+ hidden_states=hidden_states,
114
+ fc1_1_weight=self.gate_proj,
115
+ fc1_2_weight=self.up_proj,
116
+ fc2_weight=self.down_proj,
117
+ )
118
+ return out
119
+
120
+ ```
121
+
122
+ 3. Train qwen3 moe model
123
+ ```
124
+ bash train.sh tasks/train_torch.py configs/pretrain/qwen3-moe.yaml
125
+ ```
docs/index.rst ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Welcome to LingBotVLA
2
+ =========================
docs/requirements-docs.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # markdown suport
2
+ recommonmark
3
+ # markdown table suport
4
+ sphinx-markdown-tables
5
+
6
+ # theme default rtd
7
+
8
+ # crate-docs-theme
9
+ sphinx-rtd-theme
docs/start/start.rst ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Getting Started
2
+ =========================
experiment/libero/README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install official LIBERO
2
+
3
+ ```bash
4
+ git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git libero # (here)
5
+ cd libero
6
+ pip install -e .
7
+
8
+ cd experiment/libero/libero
9
+ pip install -r req.txt
10
+ ```
11
+
12
+ If can not import xxx from libero.libero please add the libero (here) path to the PYTHONPATH variable.
13
+
14
+ The results will be save to /project_root/Libero
15
+
16
+ - release_ensemble/ stores the log files (This directory can be changed by --local_log_dir variable)
17
+ - rollouts stores the videos
18
+
experiment/libero/libero/libero_utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for evaluating policies in LIBERO simulation environments."""
2
+
3
+ import math
4
+ import os
5
+
6
+ import imageio
7
+ import numpy as np
8
+ import tensorflow as tf
9
+ from libero.libero import get_libero_path
10
+ from libero.libero.envs import OffScreenRenderEnv
11
+
12
+ from experiment.libero.robot_utils import (
13
+ DATE,
14
+ DATE_TIME,
15
+ )
16
+
17
+
18
+ def get_libero_env(task, model_family, resolution=256):
19
+ """Initializes and returns the LIBERO environment, along with the task description."""
20
+ task_description = task.language
21
+ task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
22
+ env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
23
+ env = OffScreenRenderEnv(**env_args)
24
+ env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
25
+ return env, task_description
26
+
27
+
28
+ def get_libero_dummy_action(model_family: str):
29
+ """Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
30
+ return [0, 0, 0, 0, 0, 0, -1]
31
+
32
+
33
+ def resize_image(img, resize_size):
34
+ """
35
+ Takes numpy array corresponding to a single image and returns resized image as numpy array.
36
+
37
+ NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow
38
+ the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.
39
+ """
40
+ assert isinstance(resize_size, tuple)
41
+ # Resize to image size expected by model
42
+ with tf.device('/CPU:0'):
43
+ img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder
44
+ img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back
45
+ img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
46
+ img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
47
+ img = img.numpy()
48
+ return img
49
+
50
+
51
+ def get_libero_image(obs, resize_size):
52
+ """Extracts image from observations and preprocesses it."""
53
+ assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
54
+ if isinstance(resize_size, int):
55
+ resize_size = (resize_size, resize_size)
56
+ img = obs["agentview_image"]
57
+ img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
58
+ img = resize_image(img, resize_size)
59
+ return img
60
+
61
+
62
+ def get_libero_wrist_image(obs, resize_size):
63
+ """Extracts wrist camera image from observations and preprocesses it."""
64
+ assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
65
+ if isinstance(resize_size, int):
66
+ resize_size = (resize_size, resize_size)
67
+ img = obs["robot0_eye_in_hand_image"]
68
+ img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
69
+ img = resize_image(img, resize_size)
70
+ return img
71
+
72
+ def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, ckpt_index=None, task_suite_name=None, task_id=None):
73
+ """Saves an MP4 replay of an episode."""
74
+ rollout_dir = f"./Libero/rollouts/{ckpt_index}/{task_suite_name}-task{task_id}-{DATE_TIME}-{ckpt_index}"
75
+ os.makedirs(rollout_dir, exist_ok=True)
76
+ processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
77
+ mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4"
78
+ video_writer = imageio.get_writer(mp4_path, fps=30)
79
+ for img in rollout_images:
80
+ video_writer.append_data(img)
81
+ video_writer.close()
82
+ print(f"Saved rollout MP4 at path {mp4_path}")
83
+ if log_file is not None:
84
+ log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
85
+ return mp4_path
86
+
87
+
88
+ def quat2axisangle(quat):
89
+ """
90
+ Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
91
+
92
+ Converts quaternion to axis-angle format.
93
+ Returns a unit vector direction scaled by its angle in radians.
94
+
95
+ Args:
96
+ quat (np.array): (x,y,z,w) vec4 float angles
97
+
98
+ Returns:
99
+ np.array: (ax,ay,az) axis-angle exponential coordinates
100
+ """
101
+ # clip quaternion
102
+ if quat[3] > 1.0:
103
+ quat[3] = 1.0
104
+ elif quat[3] < -1.0:
105
+ quat[3] = -1.0
106
+
107
+ den = np.sqrt(1.0 - quat[3] * quat[3])
108
+ if math.isclose(den, 0.0):
109
+ # This is (close to) a zero degree rotation, immediately return
110
+ return np.zeros(3)
111
+
112
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
experiment/libero/libero/req.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ imageio[ffmpeg]
2
+ robosuite==1.4.1
3
+ bddl
4
+ easydict
5
+ cloudpickle
6
+ gym
experiment/libero/libero/run_libero_eval.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ run_libero_eval.py
3
+
4
+ Runs a model in a LIBERO simulation environment.
5
+
6
+ Usage:
7
+ # OpenVLA:
8
+ # IMPORTANT: Set `center_crop=True` if model is fine-tuned with augmentations
9
+ python Libero/robot/libero/run_libero_eval.py \
10
+ --model_family openvla \
11
+ --pretrained_checkpoint <CHECKPOINT_PATH> \
12
+ --task_suite_name [ libero_spatial | libero_object | libero_goal | libero_10 | libero_90 ] \
13
+ --center_crop [ True | False ] \
14
+ --run_id_note <OPTIONAL TAG TO INSERT INTO RUN ID FOR LOGGING> \
15
+ --use_wandb [ True | False ] \
16
+ --wandb_project <PROJECT> \
17
+ --wandb_entity <ENTITY>
18
+ """
19
+
20
+ import tensorflow as tf
21
+ import os, json, re, io, base64, threading
22
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
23
+ for g in tf.config.list_physical_devices('GPU'):
24
+ tf.config.experimental.set_memory_growth(g, True)
25
+
26
+ import os
27
+ import sys
28
+ parent_dir = os.path.dirname(os.getcwd())
29
+ sys.path.insert(0, parent_dir)
30
+ sys.path.insert(0, os.getcwd())
31
+
32
+ from dataclasses import dataclass
33
+ from pathlib import Path
34
+ from typing import Optional, Union
35
+ import torch
36
+
37
+ import draccus
38
+ import numpy as np
39
+ import tqdm
40
+ from libero.libero import benchmark
41
+
42
+ import wandb
43
+
44
+ # Append current directory so that interpreter can find Libero.robot
45
+ from experiment.libero.libero.libero_utils import (
46
+ get_libero_dummy_action,
47
+ get_libero_env,
48
+ get_libero_image,
49
+ get_libero_wrist_image,
50
+ quat2axisangle,
51
+ save_rollout_video,
52
+ )
53
+
54
+ from experiment.libero.robot_utils import (
55
+ DATE_TIME,
56
+ get_action,
57
+ get_image_resize_size,
58
+ get_model,
59
+ invert_gripper_action,
60
+ normalize_gripper_action,
61
+ set_seed_everywhere,
62
+ )
63
+
64
+
65
+ @dataclass
66
+ class GenerateConfig:
67
+ # fmt: off
68
+
69
+ #################################################################################################################
70
+ # Model-specific parameters
71
+ #################################################################################################################
72
+ model_family: str = "instruct_vla" # Model family
73
+ pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path
74
+ unnorm_key: Optional[str] = None
75
+ # image_size: list[int] = [224, 224]
76
+ action_dim: int = 7
77
+ model_port: int = 8012
78
+
79
+ #################################################################################################################
80
+ # LIBERO environment-specific parameters
81
+ #################################################################################################################
82
+ task_suite_name: str = "libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
83
+ task_id: Optional[int] = None
84
+ num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
85
+ num_trials_per_task: int = 50 # Number of rollouts per task
86
+
87
+ #################################################################################################################
88
+ # Utils
89
+ #################################################################################################################
90
+ run_id_note: Optional[str] = None # Extra note to add in run ID for logging
91
+ local_log_dir: str = "./Libero/logs" # Local directory for eval logs
92
+
93
+ use_wandb: bool = False # Whether to also log results in Weights & Biases
94
+ wandb_project: str = "YOUR_WANDB_PROJECT" # Name of W&B project to log to (use default!)
95
+ wandb_entity: str = "YOUR_WANDB_ENTITY" # Name of entity to log under
96
+
97
+ seed: int = 42 # Random Seed (for reproducibility)
98
+ use_length: int = 8
99
+ # fmt: on
100
+
101
+
102
+ @draccus.wrap()
103
+ def eval_libero(cfg: GenerateConfig) -> None:
104
+
105
+ ckpt_index = cfg.pretrained_checkpoint.split('/checkpoints/')[0].split('/')[-1]
106
+ # Set random seed
107
+ set_seed_everywhere(cfg.seed)
108
+
109
+ # [OpenVLA] Check that the model contains the action un-normalization key
110
+ if cfg.model_family == "openvla":
111
+ # [OpenVLA] Set action un-normalization key
112
+ cfg.unnorm_key = cfg.task_suite_name
113
+ model, server = get_model(cfg)
114
+ server = None
115
+ # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
116
+ # with the suffix "_no_noops" in the dataset name)
117
+ if cfg.unnorm_key not in model.norm_stats and f"{cfg.unnorm_key}_no_noops" in model.norm_stats:
118
+ cfg.unnorm_key = f"{cfg.unnorm_key}_no_noops"
119
+ assert cfg.unnorm_key in model.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"
120
+
121
+ elif cfg.model_family == "instruct_vla":
122
+ # [OpenVLA] Set action un-normalization key
123
+ cfg.unnorm_key = f"{cfg.task_suite_name}_no_noops"
124
+ model, server = get_model(cfg)
125
+
126
+ # Initialize local logging
127
+ run_id = f"EVAL-{cfg.task_suite_name}-task{cfg.task_id}-{cfg.model_family}-{DATE_TIME}-{ckpt_index}"
128
+ if cfg.run_id_note is not None:
129
+ run_id += f"--{cfg.run_id_note}"
130
+ cfg.local_log_dir = os.path.join(cfg.local_log_dir, ckpt_index)
131
+ os.makedirs(cfg.local_log_dir, exist_ok=True)
132
+ local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
133
+ log_file = open(local_log_filepath, "w")
134
+ print(f"Logging to local log file: {local_log_filepath}")
135
+
136
+ # Initialize Weights & Biases logging as well
137
+ if cfg.use_wandb:
138
+ wandb.init(
139
+ entity=cfg.wandb_entity,
140
+ project=cfg.wandb_project,
141
+ name=run_id,
142
+ )
143
+
144
+ # Initialize LIBERO task suite
145
+ benchmark_dict = benchmark.get_benchmark_dict()
146
+ task_suite = benchmark_dict[cfg.task_suite_name]()
147
+ num_tasks_in_suite = task_suite.n_tasks
148
+ print(f"Task suite: {cfg.task_suite_name}")
149
+ log_file.write(f"Task suite: {cfg.task_suite_name}\n")
150
+
151
+ # Get expected image dimensions
152
+ resize_size = get_image_resize_size(cfg)
153
+
154
+ # Start evaluation
155
+ total_episodes, total_successes = 0, 0
156
+ for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
157
+ # Get task
158
+ if cfg.task_id is not None:
159
+ if cfg.task_suite_name == 'libero_10':
160
+ if task_id != cfg.task_id:
161
+ continue
162
+ task = task_suite.get_task(task_id)
163
+
164
+ # Get default LIBERO initial states
165
+ initial_states = task_suite.get_task_init_states(task_id)
166
+
167
+ # Initialize LIBERO environment and task description
168
+ env, task_description = get_libero_env(task, cfg.model_family, resolution=256)
169
+
170
+ # Start episodes
171
+ task_episodes, task_successes = 0, 0
172
+ for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
173
+ print(f"\nTask: {task_description}")
174
+ log_file.write(f"\nTask: {task_description}\n")
175
+
176
+ # Reset environment
177
+ env.reset()
178
+ server.reset(robo_name='libero')
179
+ # Set initial states
180
+ obs = env.set_init_state(initial_states[episode_idx])
181
+
182
+ # Setup
183
+ t = 0
184
+ replay_images = []
185
+ if cfg.task_suite_name == "libero_spatial":
186
+ max_steps = 220 # longest training demo has 193 steps
187
+ elif cfg.task_suite_name == "libero_object":
188
+ max_steps = 280 # longest training demo has 254 steps
189
+ elif cfg.task_suite_name == "libero_goal":
190
+ max_steps = 300 # longest training demo has 270 steps
191
+ elif cfg.task_suite_name == "libero_10":
192
+ max_steps = 520 # longest training demo has 505 steps
193
+ elif cfg.task_suite_name == "libero_90":
194
+ max_steps = 400 # longest training demo has 373 steps
195
+
196
+ print(f"Starting episode {task_episodes+1}...")
197
+ log_file.write(f"Starting episode {task_episodes+1}...\n")
198
+ while t < max_steps + cfg.num_steps_wait:
199
+ # try:
200
+ # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
201
+ # and we need to wait for them to fall
202
+ if t < cfg.num_steps_wait:
203
+ obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family))
204
+ t += 1
205
+ continue
206
+
207
+ # Get preprocessed image
208
+ img = get_libero_image(obs, resize_size)
209
+ wrist_img = get_libero_wrist_image(obs, resize_size)
210
+
211
+ # Save preprocessed image for replay video
212
+ replay_images.append(img)
213
+
214
+ # Prepare observations dict
215
+ # Note: OpenVLA does not take proprio state as input
216
+
217
+ state = np.concatenate(
218
+ (obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"]))
219
+
220
+ observation = {
221
+ "image": img,
222
+ "wrist_image": wrist_img,
223
+ "state": state,
224
+ "task": task_description,
225
+ }
226
+
227
+ # Query model to get action
228
+ action = get_action(
229
+ server, observation
230
+ ).copy()
231
+
232
+ # Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter
233
+ # action = normalize_gripper_action(action, binarize=True)
234
+ action[..., -1] = np.sign(action[..., -1]) # binarize
235
+
236
+ # [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets
237
+ # (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action
238
+ # action = invert_gripper_action(action) # skip since we use raw action
239
+
240
+ print('==>action is',action)
241
+ # Execute action in environment
242
+ obs, reward, done, info = env.step(action.tolist())
243
+ if done:
244
+ task_successes += 1
245
+ total_successes += 1
246
+ break
247
+ t += 1
248
+
249
+ # except Exception as e:
250
+ # print(f"Caught exception: {e}")
251
+ # log_file.write(f"Caught exception: {e}\n")
252
+ # break
253
+
254
+ task_episodes += 1
255
+ total_episodes += 1
256
+
257
+ # Save a replay video of the episode
258
+ save_rollout_video(
259
+ replay_images, total_episodes, success=done, task_description=task_description, log_file=log_file, ckpt_index=ckpt_index, task_suite_name=cfg.task_suite_name, task_id=task_id
260
+ )
261
+
262
+ # Log current results
263
+ print(f"Success: {done}")
264
+ print(f"# episodes completed so far: {total_episodes}")
265
+ print(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
266
+ log_file.write(f"Success: {done}\n")
267
+ log_file.write(f"# episodes completed so far: {total_episodes}\n")
268
+ log_file.write(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)\n")
269
+ log_file.flush()
270
+
271
+ # Log final results
272
+ print(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
273
+ print(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
274
+ log_file.write(f"Current task success rate: {float(task_successes) / float(task_episodes)}\n")
275
+ log_file.write(f"Current total success rate: {float(total_successes) / float(total_episodes)}\n")
276
+ log_file.flush()
277
+ if cfg.use_wandb:
278
+ wandb.log(
279
+ {
280
+ f"success_rate/{task_description}": float(task_successes) / float(task_episodes),
281
+ f"num_episodes/{task_description}": task_episodes,
282
+ }
283
+ )
284
+
285
+ # Save local log file
286
+ log_file.close()
287
+
288
+ # Push total metrics and local log file to wandb
289
+ if cfg.use_wandb:
290
+ wandb.log(
291
+ {
292
+ "success_rate/total": float(total_successes) / float(total_episodes),
293
+ "num_episodes/total": total_episodes,
294
+ }
295
+ )
296
+ wandb.save(local_log_filepath)
297
+
298
+
299
+ if __name__ == "__main__":
300
+ eval_libero()
experiment/libero/robot_utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils for evaluating robot policies in various environments."""
2
+
3
+ import os
4
+ import random
5
+ import time
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ # Initialize important constants and pretty-printing mode in NumPy.
11
+ ACTION_DIM = 7
12
+ DATE = time.strftime("%Y_%m_%d")
13
+ DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
14
+ np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
15
+
16
+
17
+
18
+ def set_seed_everywhere(seed: int):
19
+ """Sets the random seed for Python, NumPy, and PyTorch functions."""
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed_all(seed)
22
+ np.random.seed(seed)
23
+ random.seed(seed)
24
+ torch.backends.cudnn.deterministic = True
25
+ torch.backends.cudnn.benchmark = False
26
+ os.environ["PYTHONHASHSEED"] = str(seed)
27
+
28
+
29
+ def get_model(cfg, wrap_diffusion_policy_for_droid=False):
30
+ """Load model for evaluation."""
31
+ from deploy.websocket_client_policy import WebsocketClientPolicy
32
+ cronus_server = WebsocketClientPolicy(port=cfg.model_port)
33
+ return None, cronus_server
34
+
35
+
36
+ def get_image_resize_size(cfg):
37
+ """
38
+ Gets image resize size for a model class.
39
+ If `resize_size` is an int, then the resized image will be a square.
40
+ Else, the image will be a rectangle.
41
+ """
42
+ if cfg.model_family == "openvla" or "instruct_vla" in cfg.model_family:
43
+ resize_size = 224
44
+ else:
45
+ raise ValueError("Unexpected `model_family` found in config.")
46
+ return resize_size
47
+
48
+
49
+ def get_action(server, obs):
50
+ """Queries the model to get an action."""
51
+
52
+ action = server.infer(obs)['action']
53
+ return action
54
+
55
+
56
+ def normalize_gripper_action(action, binarize=True):
57
+ """
58
+ Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1].
59
+ Necessary for some environments (not Bridge) because the dataset wrapper standardizes gripper actions to [0,1].
60
+ Note that unlike the other action dimensions, the gripper action is not normalized to [-1,+1] by default by
61
+ the dataset wrapper.
62
+
63
+ Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1
64
+ """
65
+ # Just normalize the last action to [-1,+1].
66
+ orig_low, orig_high = 0.0, 1.0
67
+ action = np.array(action, copy=True)
68
+ action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1
69
+
70
+ if binarize:
71
+ # Binarize to -1 or +1.
72
+ action[..., -1] = np.sign(action[..., -1])
73
+
74
+ return action
75
+
76
+
77
+ def invert_gripper_action(action):
78
+ """
79
+ Flips the sign of the gripper action (last dimension of action vector).
80
+ This is necessary for some environments where -1 = open, +1 = close, since
81
+ the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.
82
+ """
83
+ action[..., -1] = action[..., -1] * -1.0
84
+ return action
experiment/robotwin/README.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generate Lerobot Dataset from RoboTwin Data
2
+
3
+ This guide explains how to process raw data from **RoboTwin** and convert it into the **LerobotDataset** format following the official RoboTwin instructions.
4
+
5
+ ## 1. Clone the Official RoboTwin Repository
6
+ ```bash
7
+ git clone git@github.com:RoboTwin-Platform/RoboTwin.git
8
+ ```
9
+
10
+ ## 2. Create Required Directories
11
+ Navigate to the `policy/pi0` directory inside the cloned RoboTwin repository and create the folders:
12
+
13
+ ```bash
14
+ cd ./policy/pi0
15
+ mkdir processed_data training_data
16
+ ```
17
+
18
+ ## 3. Convert RoboTwin Raw Data to HDF5
19
+
20
+ Use the provided script [process_data_pi0.sh](https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/process_data_pi0.sh):
21
+
22
+ ```bash
23
+ bash process_data_pi0.sh ${task_name} ${task_config} ${expert_data_num}
24
+ ```
25
+
26
+ **Example (clean demo):**
27
+ ```bash
28
+ bash process_data_pi0.sh beat_block_hammer demo_clean 50
29
+ ```
30
+
31
+ **Example (randomized demo):**
32
+ ```bash
33
+ bash process_data_pi0.sh beat_block_hammer demo_randomized 50
34
+ ```
35
+
36
+ If successful, the output folder:
37
+ ```
38
+ processed_data/${task_name}-${task_config}-${expert_data_num}/
39
+ ```
40
+
41
+ ## 4. Prepare Training Data
42
+
43
+ Copy the required processed datasets into `training_data/${model_name}`:
44
+
45
+ ```bash
46
+ cp -r processed_data/${task_name}-${task_config}-${expert_data_num} \
47
+ training_data/${model_name}/
48
+ ```
49
+
50
+ ## 5. Ensure Sufficient Disk Space
51
+
52
+ The generated **LerobotDataset** will be stored under:
53
+
54
+ ```
55
+ $XDG_CACHE_HOME/huggingface/lerobot/${repo_id}
56
+ ```
57
+
58
+ By default, `XDG_CACHE_HOME` points to `~/.cache`, which must have sufficient free space.
59
+ If space is low, change the cache location:
60
+
61
+ ```bash
62
+ export XDG_CACHE_HOME=/path/to/your/cache
63
+ ```
64
+
65
+ ## 6. Generate LerobotDataset Format
66
+
67
+ Run [process_data_pi0.sh](https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/generate.sh) to convert the HDF5 datasets to Lerobot.
68
+
69
+ Parameters:
70
+ - **hdf5_path**: Path to the HDF5 training data (e.g., `./training_data/${model_name}/`)
71
+ - **repo_id**: Name for the dataset (e.g., `my_repo`)
72
+
73
+ ```bash
74
+ bash generate.sh ${hdf5_path} ${repo_id}
75
+ ```
76
+
77
+ **Example:**
78
+ ```bash
79
+ bash generate.sh ./training_data/demo_clean/ demo_clean_repo
80
+ ```
81
+
82
+ Output:
83
+ ```
84
+ ${XDG_CACHE_HOME}/huggingface/lerobot/${repo_id}
85
+ ```
lingbotvla/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ __version__ = "0.0.1"
lingbotvla/checkpoint/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .checkpointer import build_checkpointer
17
+ from .format_utils import bytecheckpoint_ckpt_to_state_dict, ckpt_to_state_dict, dcp_to_torch_state_dict
18
+
19
+
20
+ __all__ = [
21
+ "ckpt_to_state_dict",
22
+ "dcp_to_torch_state_dict",
23
+ "bytecheckpoint_ckpt_to_state_dict",
24
+ "build_checkpointer",
25
+ ]
lingbotvla/checkpoint/checkpointer.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ from abc import ABC, abstractmethod
18
+ from typing import Any, Dict
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
23
+ from ..utils.import_utils import is_torch_version_greater_than
24
+ from ..utils.logging import get_logger
25
+ from pathlib import Path
26
+
27
+ if is_torch_version_greater_than("2.4"):
28
+ import torch.distributed.checkpoint as dcp
29
+ from torch.distributed.checkpoint import (
30
+ FileSystemReader,
31
+ FileSystemWriter,
32
+ )
33
+ from torch.distributed.checkpoint.state_dict import (
34
+ get_model_state_dict,
35
+ get_optimizer_state_dict,
36
+ set_model_state_dict,
37
+ set_optimizer_state_dict,
38
+ )
39
+ from torch.distributed.checkpoint.stateful import Stateful
40
+ else:
41
+ Stateful = ABC
42
+
43
+ logger = get_logger(__name__)
44
+
45
+ _EXTRA_STATE_FORMAT = "extra_state_rank_{}.pt"
46
+ _MODEL_DIR = "model"
47
+ _EMA_DIR = "ema"
48
+ _OPTIMIZER_DIR = "optimizer"
49
+ _EXTRA_STATE_DIR = "extra_state"
50
+
51
+
52
+ class ModelState(Stateful):
53
+ """
54
+ A wrapper around a model to make it stateful.
55
+ Args:
56
+ model (Model): model to wrap.
57
+ """
58
+
59
+ def __init__(self, model):
60
+ self.model = model
61
+
62
+ def state_dict(self):
63
+ model_state_dict = get_model_state_dict(model=self.model)
64
+ return {"model": model_state_dict}
65
+
66
+ def load_state_dict(self, state_dict):
67
+ set_model_state_dict(model=self.model, model_state_dict=state_dict["model"])
68
+
69
+
70
+ class OptimizerState(Stateful):
71
+ """
72
+ A wrapper around an optimizer to make it stateful.
73
+
74
+ Args:
75
+ model (Model): model to wrap.
76
+ optimizer (Optimizer): optimizer to wrap.
77
+ """
78
+
79
+ def __init__(self, model, optimizer):
80
+ self.model = model
81
+ self.optimizer = optimizer
82
+
83
+ def state_dict(self):
84
+ optimizer_state_dict = get_optimizer_state_dict(model=self.model, optimizers=self.optimizer)
85
+ return {"optim": optimizer_state_dict}
86
+
87
+ def load_state_dict(self, state_dict):
88
+ set_optimizer_state_dict(model=self.model, optimizers=self.optimizer, optim_state_dict=state_dict["optim"])
89
+
90
+
91
+ def build_checkpointer(
92
+ dist_backend: str = "fsdp1",
93
+ ckpt_manager: str = "bytecheckpoint",
94
+ ):
95
+ """
96
+ create a checkpointer manager with given mode.
97
+ Args:
98
+ dist_backend (str, optional): checkpoint mode. Defaults to "fsdp1".
99
+ fsdp1: FSDP1 checkpoint from bytecheckpoint
100
+ fsdp2-vescale: FSDP2 checkpoint from bytecheckpoint
101
+ fsdp2: FSDP2 checkpoint from bytecheckpoint
102
+ ddp: DDP checkpoint from bytecheckpoint
103
+ dcp: DCP checkpoint from torch.distributed.checkpoint
104
+ ckpt_manager (str, optional): checkpoint manager. Defaults to "bytecheckpoint".
105
+ bytecheckpoint: bytecheckpoint checkpoint manager
106
+ dcp: torch dcp checkpoint manager
107
+ Raises:
108
+ ValueError: if ckpt_manager is not supported
109
+
110
+ Returns:
111
+ Checkpointer: checkpointer with given mode.
112
+ """
113
+
114
+ if ckpt_manager == "bytecheckpoint":
115
+ if dist_backend == "ddp":
116
+ from bytecheckpoint import DDPCheckpointer as Checkpointer
117
+ elif dist_backend == "fsdp1":
118
+ from bytecheckpoint import FSDPCheckpointer as Checkpointer
119
+ elif dist_backend == "fsdp2-vescale":
120
+ from bytecheckpoint import VeScaleCheckpointer as Checkpointer
121
+ elif dist_backend == "fsdp2":
122
+ from bytecheckpoint import FSDP2Checkpointer as Checkpointer
123
+ elif ckpt_manager == "dcp":
124
+ if not is_torch_version_greater_than("2.4"):
125
+ raise ValueError("DCP checkpoint manager requires torch version >= 2.4")
126
+ if dist_backend not in ["ddp", "fsdp1", "fsdp2"]:
127
+ raise ValueError(
128
+ f"Unsupported distributed backend: {dist_backend} for DCP checkpoint manager, supported modes are: ddp, fsdp1, fsdp2"
129
+ )
130
+ Checkpointer = DistributedCheckpointer
131
+ else:
132
+ raise ValueError(
133
+ f"Unknown checkpoint manager: {ckpt_manager}, supported modes are: bytecheckpoint, dcp, native"
134
+ )
135
+
136
+ return Checkpointer
137
+
138
+
139
+ class CheckpointerBase(ABC):
140
+ """Base class for checkpointer"""
141
+
142
+ @abstractmethod
143
+ def save(
144
+ cls,
145
+ path: str,
146
+ state: Dict[str, Any],
147
+ ):
148
+ return
149
+
150
+ @abstractmethod
151
+ def load(
152
+ cls,
153
+ path: str,
154
+ state: Dict[str, Any],
155
+ ):
156
+ return
157
+
158
+
159
+ class DistributedCheckpointer(CheckpointerBase):
160
+ """
161
+ Distributed checkpointer for torch.distributed.checkpoint
162
+ """
163
+
164
+ @classmethod
165
+ def save(
166
+ cls,
167
+ path: str,
168
+ state: Dict[str, Any],
169
+ global_steps: int = None,
170
+ save_async=False,
171
+ ) -> None:
172
+ """
173
+ save training state to distributed checkpoint
174
+
175
+ args:
176
+ path: path to save checkpoint
177
+ state: state to save
178
+ global_steps: global steps
179
+ save_async: whether to save asynchronously
180
+ return:
181
+ None
182
+ """
183
+
184
+ checkpoint_dir = f"{path}/global_step_{global_steps}" if global_steps else path
185
+ os.makedirs(checkpoint_dir, exist_ok=True)
186
+
187
+ if "model" not in state:
188
+ raise ValueError("Model must be provided to save a distributed checkpoint.")
189
+
190
+ if save_async:
191
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
192
+ dcp.async_save(
193
+ state_dict={"state": ModelState(state["model"])},
194
+ storage_writer=FileSystemWriter(
195
+ model_dir,
196
+ thread_count=16,
197
+ single_file_per_rank=True,
198
+ sync_files=False,
199
+ ),
200
+ )
201
+ if "ema" in state and state["ema"] is not None:
202
+ ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
203
+ dcp.async_save(
204
+ state_dict={"state": ModelState(state["ema"])},
205
+ storage_writer=FileSystemWriter(
206
+ ema_dir,
207
+ thread_count=16,
208
+ single_file_per_rank=True,
209
+ sync_files=False,
210
+ ),
211
+ )
212
+ if "optimizer" in state:
213
+ optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
214
+ dcp.async_save(
215
+ state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
216
+ storage_writer=FileSystemWriter(
217
+ optimizer_dir,
218
+ thread_count=16,
219
+ single_file_per_rank=True,
220
+ sync_files=False,
221
+ ),
222
+ )
223
+ else:
224
+ def safe_create_writer(output_dir):
225
+ tmp_path = Path(output_dir) / ".metadata.tmp"
226
+ if tmp_path.exists():
227
+ print(f"Warning: removing existing tmp file: {tmp_path}")
228
+ tmp_path.unlink() # remove .metadata.tmp
229
+ return FileSystemWriter(
230
+ output_dir,
231
+ thread_count=16,
232
+ single_file_per_rank=True,
233
+ sync_files=False,
234
+ )
235
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
236
+ storage_writer = safe_create_writer(model_dir)
237
+ dcp.save(
238
+ state_dict={"state": ModelState(state["model"])},
239
+ storage_writer=storage_writer,
240
+ )
241
+ if "ema" in state and state["ema"] is not None:
242
+ ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
243
+ storage_writer = safe_create_writer(ema_dir)
244
+ dcp.save(
245
+ state_dict={"state": ModelState(state["ema"])},
246
+ storage_writer=storage_writer,
247
+ )
248
+ if "optimizer" in state:
249
+ optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
250
+ dcp.save(
251
+ state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
252
+ storage_writer=FileSystemWriter(
253
+ optimizer_dir,
254
+ thread_count=16,
255
+ single_file_per_rank=True,
256
+ sync_files=False,
257
+ ),
258
+ )
259
+ # dist.barrier()
260
+
261
+ if "extra_state" in state:
262
+ extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
263
+ os.makedirs(extra_state_dir, exist_ok=True)
264
+ extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
265
+ torch.save(
266
+ state["extra_state"],
267
+ extra_state_path,
268
+ )
269
+
270
+ logger.info_rank0(f"Saved checkpoint to {checkpoint_dir}")
271
+
272
+ @classmethod
273
+ def load(
274
+ cls,
275
+ path: str,
276
+ state: Dict[str, Any],
277
+ process_group=None,
278
+ ) -> Dict[str, Any]:
279
+ """
280
+ load training state from distributed checkpoint
281
+ args:
282
+ path: path to load checkpoint
283
+ state: state to load, "model" are required, "optimizer" and "extra_state" are optional
284
+
285
+ return:
286
+ state: state loaded
287
+ """
288
+ checkpoint_dir = path
289
+
290
+ if state is None:
291
+ raise ValueError("State dict must be provided to load a distributed checkpoint.")
292
+
293
+ if "model" not in state:
294
+ raise ValueError("Model must be provided to load a distributed checkpoint.")
295
+
296
+ if "ema" in state and state["ema"] is not None:
297
+ ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
298
+ dcp.load(
299
+ state_dict={"state": ModelState(state["ema"])},
300
+ storage_reader=FileSystemReader(ema_dir),
301
+ process_group=process_group,
302
+ )
303
+
304
+ if "optimizer" in state:
305
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
306
+ dcp.load(
307
+ state_dict={"state": ModelState(state["model"])},
308
+ storage_reader=FileSystemReader(model_dir),
309
+ process_group=process_group,
310
+ )
311
+
312
+ optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
313
+ try:
314
+ dcp.load(
315
+ state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])}, # 1043
316
+ storage_reader=FileSystemReader(optimizer_dir), # 1027
317
+ planner = DefaultLoadPlanner(allow_partial_load=True),
318
+ process_group=process_group,
319
+ )
320
+ except:
321
+ logger.info_rank0(f"Skip loading Optimizer from {checkpoint_dir}")
322
+ else:
323
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
324
+ dcp.load(
325
+ state_dict={"state": ModelState(state["model"])},
326
+ storage_reader=FileSystemReader(model_dir),
327
+ process_group=process_group,
328
+ )
329
+
330
+ if "extra_state" in state:
331
+ extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
332
+ os.makedirs(extra_state_dir, exist_ok=True)
333
+ extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
334
+ state["extra_state"] = torch.load(
335
+ extra_state_path,
336
+ )
337
+
338
+ logger.info_rank0(f"Loaded checkpoint from {checkpoint_dir}")
339
+
340
+ return state