s3y committed on
Commit
2cff6b0
·
verified ·
1 Parent(s): 5c55bd1

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. openpi/.dockerignore +3 -0
  2. openpi/.gitattributes +36 -0
  3. openpi/.github/CODEOWNERS +16 -0
  4. openpi/.github/workflows/pre-commit.yml +17 -0
  5. openpi/.github/workflows/test.yml +31 -0
  6. openpi/.gitignore +168 -0
  7. openpi/.gitmodules +6 -0
  8. openpi/.idea/.gitignore +8 -0
  9. openpi/.idea/workspace.xml +12 -0
  10. openpi/.pre-commit-config.yaml +16 -0
  11. openpi/.python-version +1 -0
  12. openpi/.vscode/settings.json +11 -0
  13. openpi/CONTRIBUTING.md +33 -0
  14. openpi/LICENSE +201 -0
  15. openpi/README.md +323 -0
  16. openpi/config.json +85 -0
  17. openpi/docs/docker.md +25 -0
  18. openpi/docs/norm_stats.md +69 -0
  19. openpi/docs/remote_inference.md +71 -0
  20. openpi/examples/aloha_real/Dockerfile +70 -0
  21. openpi/examples/aloha_real/README.md +126 -0
  22. openpi/examples/aloha_real/compose.yml +66 -0
  23. openpi/examples/aloha_real/constants.py +71 -0
  24. openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py +272 -0
  25. openpi/examples/aloha_real/env.py +57 -0
  26. openpi/examples/aloha_real/main.py +51 -0
  27. openpi/examples/aloha_real/real_env.py +176 -0
  28. openpi/examples/aloha_real/requirements.in +18 -0
  29. openpi/examples/aloha_real/requirements.txt +156 -0
  30. openpi/examples/aloha_real/robot_utils.py +275 -0
  31. openpi/examples/aloha_real/video_display.py +36 -0
  32. openpi/examples/aloha_sim/Dockerfile +41 -0
  33. openpi/examples/aloha_sim/README.md +36 -0
  34. openpi/examples/aloha_sim/compose.yml +42 -0
  35. openpi/examples/aloha_sim/env.py +56 -0
  36. openpi/examples/aloha_sim/main.py +55 -0
  37. openpi/examples/aloha_sim/requirements.in +8 -0
  38. openpi/examples/aloha_sim/requirements.txt +132 -0
  39. openpi/examples/aloha_sim/saver.py +40 -0
  40. openpi/examples/convert_jax_model_to_pytorch.py +587 -0
  41. openpi/examples/droid/README.md +84 -0
  42. openpi/examples/droid/README_train.md +106 -0
  43. openpi/examples/droid/compute_droid_nonidle_ranges.py +103 -0
  44. openpi/examples/droid/convert_droid_data_to_lerobot.py +477 -0
  45. openpi/examples/droid/main.py +246 -0
  46. openpi/examples/inference.ipynb +137 -0
  47. openpi/examples/libero/Dockerfile +59 -0
  48. openpi/examples/libero/README.md +71 -0
  49. openpi/examples/libero/compose.yml +54 -0
  50. openpi/examples/libero/convert_libero_data_to_lerobot.py +104 -0
openpi/.dockerignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ checkpoints
3
+ data
openpi/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
openpi/.github/CODEOWNERS ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The CODEOWNERS file defines individuals or teams that are automatically requested for
2
+ # review when someone opens a pull request that modifies certain code. When a draft pull
3
+ # request is marked as ready for review, code owners are automatically notified.
4
+ #
5
+ # See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
6
+ #
7
+ # This is a comment.
8
+ # Each line is a file pattern followed by one or more owners.
9
+
10
+ # Global owners.
11
+ * @jimmyt857 @Michael-Equi @uzhilinsky
12
+
13
+ src/openpi/models/ @kvablack @uzhilinsky
14
+ src/openpi/training/ @kvablack @uzhilinsky
15
+
16
+ scripts/ @jimmyt857 @kvablack @uzhilinsky
openpi/.github/workflows/pre-commit.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pre-commit
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ pull_request:
7
+ branches:
8
+ - "*"
9
+ jobs:
10
+ pre-commit:
11
+ runs-on: ubuntu-latest
12
+ env:
13
+ GIT_LFS_SKIP_SMUDGE: true
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - uses: actions/setup-python@v3
17
+ - uses: pre-commit/action@v3.0.1
openpi/.github/workflows/test.yml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Test
2
+ on:
3
+ pull_request:
4
+ branches:
5
+ - "*"
6
+
7
+ jobs:
8
+ run_tests:
9
+ name: Run Tests
10
+ runs-on: openpi-verylarge
11
+ env:
12
+ GIT_LFS_SKIP_SMUDGE: true
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+
16
+ - name: Install FFmpeg dependencies
17
+ run: |
18
+ sudo apt-get update
19
+ sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev
20
+
21
+ - name: Install uv
22
+ uses: astral-sh/setup-uv@v5
23
+
24
+ - name: Set up Python
25
+ run: uv python install
26
+
27
+ - name: Install the project
28
+ run: uv sync --all-extras --dev
29
+
30
+ - name: Run tests
31
+ run: uv run pytest --strict-markers -m "not manual"
openpi/.gitignore ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data directories.
2
+ assets/
3
+ checkpoints/
4
+ data/
5
+ wandb/
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ build/
18
+ develop-eggs/
19
+ dist/
20
+ downloads/
21
+ eggs/
22
+ .eggs/
23
+ lib/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+ cover/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ .pybuilder/
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ # For a library or package, you might want to ignore these files since the code is
93
+ # intended to run in multiple environments; otherwise, check them in:
94
+ # .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
openpi/.gitmodules ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [submodule "third_party/aloha"]
2
+ path = third_party/aloha
3
+ url = https://github.com/Physical-Intelligence/aloha.git
4
+ [submodule "third_party/libero"]
5
+ path = third_party/libero
6
+ url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
openpi/.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
openpi/.idea/workspace.xml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectViewState">
4
+ <option name="hideEmptyMiddlePackages" value="true" />
5
+ <option name="showLibraryContents" value="true" />
6
+ </component>
7
+ <component name="PropertiesComponent">{
8
+ &quot;keyToString&quot;: {
9
+ &quot;settings.editor.selected.configurable&quot;: &quot;dev.sweep.assistant.settings.SweepSettingsConfigurable&quot;
10
+ }
11
+ }</component>
12
+ </project>
openpi/.pre-commit-config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: third_party/
2
+
3
+ repos:
4
+ - repo: https://github.com/astral-sh/uv-pre-commit
5
+ # uv version.
6
+ rev: 0.5.14
7
+ hooks:
8
+ - id: uv-lock
9
+ - repo: https://github.com/astral-sh/ruff-pre-commit
10
+ # Ruff version.
11
+ rev: v0.8.6
12
+ hooks:
13
+ # Run the linter.
14
+ - id: ruff
15
+ args: [--fix]
16
+ - id: ruff-format
openpi/.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
openpi/.vscode/settings.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "charliermarsh.ruff",
4
+ "editor.formatOnSave": true,
5
+ },
6
+ "python.testing.pytestArgs": [
7
+ "src"
8
+ ],
9
+ "python.testing.unittestEnabled": false,
10
+ "python.testing.pytestEnabled": true
11
+ }
openpi/CONTRIBUTING.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to openpi
2
+
3
+ We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.
4
+
5
+ ## Issues and feature requests
6
+
7
+ You are welcome to use the GitHub [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.
8
+
9
+ If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on GitHub under Issues). If the issue has not yet been reported, please include this information when filing a GitHub issue:
10
+
11
+ - Your OS type and version and the version of Python you are using
12
+ - Code that allows us to reproduce your bug, including all dependencies
13
+ - Traceback of any exception
14
+ - Any other information that would help us, such as a screenshot
15
+
16
+ In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.
17
+
18
+ If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:
19
+
20
+ - The motivation for the feature
21
+ - A description of the problem you are trying to solve or your use case
22
+ - Enough information for us to understand the nature of the request
23
+ - Some information for how you intend to use it (this might help us in understanding the motivation!)
24
+
25
+ We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!
26
+
27
+ ## Submitting a pull request
28
+
29
+ If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend every contribution to consider the following:
30
+
31
+ - Make sure that your PR has a clear title and description
32
+ - Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
33
+ - Make sure your PR passes all tests
openpi/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
openpi/README.md ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # openpi
2
+
3
+ openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).
4
+
5
+ Currently, this repo contains three types of models:
6
+ - the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based vision-language-action model (VLA).
7
+ - the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer.
8
+ - the [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05), an upgraded version of π₀ with better open-world generalization trained with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation). Note that, in this repository, we currently only support the flow matching head for both $\pi_{0.5}$ training and inference.
9
+
10
+ For all models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets.
11
+
12
+ This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see!
13
+
14
+ ## Updates
15
+
16
+ - [Sept 2025] We released PyTorch support in openpi.
17
+ - [Sept 2025] We released pi05, an upgraded version of pi0 with better open-world generalization.
18
+ - [Sept 2025]: We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training.
19
+ - [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID.
20
+
21
+
22
+ ## Requirements
23
+
24
+ To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimations assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training.
25
+
26
+ | Mode | Memory Required | Example GPU |
27
+ | ------------------ | --------------- | ------------------ |
28
+ | Inference | > 8 GB | RTX 4090 |
29
+ | Fine-Tuning (LoRA) | > 22.5 GB | RTX 4090 |
30
+ | Fine-Tuning (Full) | > 70 GB | A100 (80GB) / H100 |
31
+
32
+ The repo has been tested with Ubuntu 22.04; we do not currently support other operating systems.
33
+
34
+ ## Installation
35
+
36
+ When cloning this repo, make sure to update submodules:
37
+
38
+ ```bash
39
+ git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
40
+
41
+ # Or if you already cloned the repo:
42
+ git submodule update --init --recursive
43
+ ```
44
+
45
+ We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:
46
+
47
+ ```bash
48
+ GIT_LFS_SKIP_SMUDGE=1 uv sync
49
+ GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
50
+ ```
51
+
52
+ NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
53
+
54
+ **Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details.
55
+
56
+
57
+
58
+
59
+ ## Model Checkpoints
60
+
61
+ ### Base Models
62
+ We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning.
63
+
64
+ | Model | Use Case | Description | Checkpoint Path |
65
+ | ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
66
+ | $\pi_0$ | Fine-Tuning | Base [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_base` |
67
+ | $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` |
68
+ | $\pi_{0.5}$ | Fine-Tuning | Base [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05) for fine-tuning | `gs://openpi-assets/checkpoints/pi05_base` |
69
+
70
+ ### Fine-Tuned Models
71
+ We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice.
72
+
73
+ | Model | Use Case | Description | Checkpoint Path |
74
+ | ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |
75
+ | $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid` |
76
+ | $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `gs://openpi-assets/checkpoints/pi0_droid` |
77
+ | $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can fold diverse towels 0-shot on ALOHA robot platforms | `gs://openpi-assets/checkpoints/pi0_aloha_towel` |
78
+ | $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can unpack food from a tupperware container | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` |
79
+ | $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on public [ALOHA](https://dit-policy.github.io/) data: can uncap a pen | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
80
+ | $\pi_{0.5}$-LIBERO | Inference | $\pi_{0.5}$ model fine-tuned for the [LIBERO](https://libero-project.github.io/datasets) benchmark: gets state-of-the-art performance (see [LIBERO README](examples/libero/README.md)) | `gs://openpi-assets/checkpoints/pi05_libero` |
81
+ | $\pi_{0.5}$-DROID | Inference / Fine-Tuning | $\pi_{0.5}$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/) with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation): fast inference and good language-following | `gs://openpi-assets/checkpoints/pi05_droid` |
82
+
83
+
84
+ By default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable.
85
+
86
+
87
+
88
+
89
+ ## Running Inference for a Pre-Trained Model
90
+
91
+ Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_{0.5}$-DROID model):
92
+ ```python
93
+ from openpi.training import config as _config
94
+ from openpi.policies import policy_config
95
+ from openpi.shared import download
96
+
97
+ config = _config.get_config("pi05_droid")
98
+ checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid")
99
+
100
+ # Create a trained policy.
101
+ policy = policy_config.create_trained_policy(config, checkpoint_dir)
102
+
103
+ # Run inference on a dummy example.
104
+ example = {
105
+ "observation/exterior_image_1_left": ...,
106
+ "observation/wrist_image_left": ...,
107
+ ...
108
+ "prompt": "pick up the fork"
109
+ }
110
+ action_chunk = policy.infer(example)["actions"]
111
+ ```
112
+ You can also test this out in the [example notebook](examples/inference.ipynb).
113
+
114
+ We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots.
115
+
116
+ **Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate.
117
+
118
+ **Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details.
119
+
120
+
121
+
122
+
123
+
124
+ ## Fine-Tuning Base Models on Your Own Data
125
+
126
+ We will fine-tune the $\pi_{0.5}$ model on the [LIBERO dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps:
127
+ 1. Convert your data to a LeRobot dataset (which we use for training)
128
+ 2. Define training configs and run training
129
+ 3. Spin up a policy server and run inference
130
+
131
+ ### 1. Convert your data to a LeRobot dataset
132
+
133
+ We provide a minimal example script for converting LIBERO data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw LIBERO dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with:
134
+
135
+ ```bash
136
+ uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data
137
+ ```
138
+
139
+ **Note:** If you just want to fine-tune on LIBERO, you can skip this step, because our LIBERO fine-tuning configs point to a pre-converted LIBERO dataset. This step is merely an example that you can adapt to your own data.
140
+
141
+ ### 2. Defining training configs and running training
142
+
143
+ To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for LIBERO below, which you can modify for your own dataset:
144
+
145
+ - [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the LIBERO environment to the model and vice versa. Will be used for both training and inference.
146
+ - [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw LIBERO data from LeRobot dataset for training.
147
+ - [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader.
148
+
149
+ We provide example fine-tuning configs for [π₀](src/openpi/training/config.py), [π₀-FAST](src/openpi/training/config.py), and [π₀.₅](src/openpi/training/config.py) on LIBERO data.
150
+
151
+ Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config:
152
+
153
+ ```bash
154
+ uv run scripts/compute_norm_stats.py --config-name pi05_libero
155
+ ```
156
+
157
+ Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config):
158
+
159
+ ```bash
160
+ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_libero --exp-name=my_experiment --overwrite
161
+ ```
162
+
163
+ The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. For maximally using the GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%).
164
+
165
+ **Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.
166
+
167
+ ### 3. Spinning up a policy server and running inference
168
+
169
+ Once training is complete, we can run inference by spinning up a policy server and then querying it from a LIBERO evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):
170
+
171
+ ```bash
172
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_libero --policy.dir=checkpoints/pi05_libero/my_experiment/20000
173
+ ```
174
+
175
+ This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run an evaluation script (or robot runtime) that queries the server.
176
+
177
+ For running the LIBERO eval in particular, we provide (and recommend using) a Dockerized workflow that handles both the policy server and the evaluation script together. See the [LIBERO README](examples/libero/README.md) for more details.
178
+
179
+ If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
180
+
181
+
182
+
183
+ ### More Examples
184
+
185
+ We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
186
+ - [ALOHA Simulator](examples/aloha_sim)
187
+ - [ALOHA Real](examples/aloha_real)
188
+ - [UR5](examples/ur5)
189
+
190
+ ## PyTorch Support
191
+
192
+ openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions! The PyTorch implementation has been validated on the LIBERO benchmark (both inference and finetuning). A few features are currently not supported (this may change in the future):
193
+
194
+ - The π₀-FAST model
195
+ - Mixed precision training
196
+ - FSDP (fully-sharded data parallelism) training
197
+ - LoRA (low-rank adaptation) training
198
+ - EMA (exponential moving average) weights during training
199
+
200
+ ### Setup
201
+ 1. Make sure that you have the latest version of all dependencies installed: `uv sync`
202
+
203
+ 2. Double check that you have transformers 4.53.2 installed: `uv pip show transformers`
204
+
205
+ 3. Apply the transformers library patches:
206
+ ```bash
207
+ cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
208
+ ```
209
+
210
+ This overwrites several files in the transformers library with necessary model changes: 1) supporting AdaRMS, 2) correctly controlling the precision of activations, and 3) allowing the KV cache to be used without being updated.
211
+
212
+ **WARNING**: With the default uv link mode (hardlink), this will permanently affect the transformers library in your uv cache, meaning the changes will survive reinstallations of transformers and could even propagate to other projects that use transformers. To fully undo this operation, you must run `uv cache clean transformers`.
213
+
214
+ ### Converting JAX Models to PyTorch
215
+
216
+ To convert a JAX model checkpoint to PyTorch format:
217
+
218
+ ```bash
219
+ uv run examples/convert_jax_model_to_pytorch.py \
220
+ --checkpoint_dir /path/to/jax/checkpoint \
221
+ --config_name <config name> \
222
+ --output_path /path/to/converted/pytorch/checkpoint
223
+ ```
224
+
225
+ ### Running Inference with PyTorch
226
+
227
+ The PyTorch implementation uses the same API as the JAX version - you only need to change the checkpoint path to point to the converted PyTorch model:
228
+
229
+ ```python
230
+ from openpi.training import config as _config
231
+ from openpi.policies import policy_config
232
+ from openpi.shared import download
233
+
234
+ config = _config.get_config("pi05_droid")
235
+ checkpoint_dir = "/path/to/converted/pytorch/checkpoint"
236
+
237
+ # Create a trained policy (automatically detects PyTorch format)
238
+ policy = policy_config.create_trained_policy(config, checkpoint_dir)
239
+
240
+ # Run inference (same API as JAX)
241
+ action_chunk = policy.infer(example)["actions"]
242
+ ```
243
+
244
+ ### Policy Server with PyTorch
245
+
246
+ The policy server works identically with PyTorch models - just point to the converted checkpoint directory:
247
+
248
+ ```bash
249
+ uv run scripts/serve_policy.py policy:checkpoint \
250
+ --policy.config=pi05_droid \
251
+ --policy.dir=/path/to/converted/pytorch/checkpoint
252
+ ```
253
+
254
+ ### Finetuning with PyTorch
255
+
256
+ To finetune a model in PyTorch:
257
+
258
+ 1. Convert the JAX base model to PyTorch format:
259
+ ```bash
260
+ uv run examples/convert_jax_model_to_pytorch.py \
261
+ --config_name <config name> \
262
+ --checkpoint_dir /path/to/jax/base/model \
263
+ --output_path /path/to/pytorch/base/model
264
+ ```
265
+
266
+ 2. Specify the converted PyTorch model path in your config using `pytorch_weight_path`
267
+
268
+ 3. Launch training using one of these modes:
269
+
270
+ ```bash
271
+ # Single GPU training:
272
+ uv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
273
+
274
+ # Example:
275
+ uv run scripts/train_pytorch.py debug --exp_name pytorch_test
276
+ uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint
277
+
278
+ # Multi-GPU training (single node):
279
+ uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
280
+
281
+ # Example:
282
+ uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
283
+ uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
284
+
285
+ # Multi-Node Training:
286
+ uv run torchrun \
287
+ --nnodes=<num_nodes> \
288
+ --nproc_per_node=<gpus_per_node> \
289
+ --node_rank=<rank_of_node> \
290
+ --master_addr=<master_ip> \
291
+ --master_port=<port> \
292
+ scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
293
+ ```
294
+
295
+ ### Precision Settings
296
+
297
+ JAX and PyTorch implementations handle precision as follows:
298
+
299
+ **JAX:**
300
+ 1. Inference: most weights and computations in bfloat16, with a few computations in float32 for stability
301
+ 2. Training: defaults to mixed precision: weights and gradients in float32, (most) activations and computations in bfloat16. You can change to full float32 training by setting `dtype` to float32 in the config.
302
+
303
+ **PyTorch:**
304
+ 1. Inference: matches JAX -- most weights and computations in bfloat16, with a few weights converted to float32 for stability
305
+ 2. Training: supports either full bfloat16 (default) or full float32. You can change it by setting `pytorch_training_precision` in the config. bfloat16 uses less memory but exhibits higher losses compared to float32. Mixed precision is not yet supported.
306
+
307
+ With torch.compile, inference speed is comparable between JAX and PyTorch.
308
+
309
+ ## Troubleshooting
310
+
311
+ We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).
312
+
313
+ | Issue | Resolution |
314
+ | ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
315
+ | `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). |
316
+ | Training runs out of GPU memory | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` (or higher) before running training to allow JAX to use more GPU memory. You can also use `--fsdp-devices <n>` where `<n>` is your number of GPUs, to enable [fully-sharded data parallelism](https://engineering.fb.com/2021/07/15/open-source/fsdp/), which reduces memory usage in exchange for slower training (the amount of slowdown depends on your particular setup). If you are still running out of memory, you may want to consider disabling EMA. |
317
+ | Policy server connection errors | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server. |
318
+ | Missing norm stats error when training | Run `scripts/compute_norm_stats.py` with your config name before starting training. |
319
+ | Dataset download fails | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`). |
320
+ | CUDA/GPU errors | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at a system level --- they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts. |
321
+ | Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
322
+ | Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
323
+ | Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
openpi/config.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "pi0",
3
+ "n_obs_steps": 1,
4
+ "input_features": {
5
+ "observation.state": {
6
+ "type": "STATE",
7
+ "shape": [
8
+ 6
9
+ ]
10
+ },
11
+ "observation.images.camera0": {
12
+ "type": "VISUAL",
13
+ "shape": [
14
+ 3,
15
+ 480,
16
+ 640
17
+ ]
18
+ },
19
+ "observation.images.camera1": {
20
+ "type": "VISUAL",
21
+ "shape": [
22
+ 3,
23
+ 480,
24
+ 640
25
+ ]
26
+ },
27
+ "observation.images.camera2": {
28
+ "type": "VISUAL",
29
+ "shape": [
30
+ 3,
31
+ 480,
32
+ 640
33
+ ]
34
+ }
35
+ },
36
+ "output_features": {
37
+ "action": {
38
+ "type": "ACTION",
39
+ "shape": [
40
+ 6
41
+ ]
42
+ }
43
+ },
44
+ "device": "cpu",
45
+ "use_amp": false,
46
+ "push_to_hub": true,
47
+ "repo_id": null,
48
+ "private": null,
49
+ "tags": null,
50
+ "license": null,
51
+ "chunk_size": 50,
52
+ "n_action_steps": 50,
53
+ "normalization_mapping": {
54
+ "VISUAL": "IDENTITY",
55
+ "STATE": "MEAN_STD",
56
+ "ACTION": "MEAN_STD"
57
+ },
58
+ "max_state_dim": 32,
59
+ "max_action_dim": 32,
60
+ "resize_imgs_with_padding": [
61
+ 224,
62
+ 224
63
+ ],
64
+ "empty_cameras": 0,
65
+ "adapt_to_pi_aloha": false,
66
+ "use_delta_joint_actions_aloha": false,
67
+ "tokenizer_max_length": 48,
68
+ "proj_width": 1024,
69
+ "num_steps": 10,
70
+ "use_cache": true,
71
+ "attention_implementation": "eager",
72
+ "freeze_vision_encoder": true,
73
+ "train_expert_only": false,
74
+ "train_state_proj": true,
75
+ "optimizer_lr": 2.5e-05,
76
+ "optimizer_betas": [
77
+ 0.9,
78
+ 0.95
79
+ ],
80
+ "optimizer_eps": 1e-08,
81
+ "optimizer_weight_decay": 1e-10,
82
+ "scheduler_warmup_steps": 1000,
83
+ "scheduler_decay_steps": 30000,
84
+ "scheduler_decay_lr": 2.5e-06
85
+ }
openpi/docs/docker.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Docker Setup
2
+
3
+ All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended as this will simplify software installation, produce a more stable environment, and also allow you to avoid installing ROS and cluttering your machine, for examples which depend on ROS.
4
+
5
+ - Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
6
+ - Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
7
+ - To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
8
+ - The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
9
+ - Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
10
+
11
+
12
+ If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
13
+
14
+ Build the Docker image and start the container with the following command:
15
+ ```bash
16
+ docker compose -f scripts/docker/compose.yml up --build
17
+ ```
18
+
19
+ To build and run the Docker image for a specific example, use the following command:
20
+ ```bash
21
+ docker compose -f examples/<example_name>/compose.yml up --build
22
+ ```
23
+ where `<example_name>` is the name of the example you want to run.
24
+
25
+ During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
openpi/docs/norm_stats.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Normalization statistics
2
+
3
+ Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
4
+
5
+ ## Reloading normalization statistics
6
+
7
+ When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
8
+
9
+ **If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
10
+
11
+ ```python
12
+ TrainConfig(
13
+ ...
14
+ data=LeRobotAlohaDataConfig(
15
+ ...
16
+ assets=AssetsConfig(
17
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
18
+ asset_id="trossen",
19
+ ),
20
+ ),
21
+ )
22
+ ```
23
+
24
+ For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
25
+
26
+ **Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
27
+
28
+ **Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend always trying both — reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics) — and picking the one that works better for your task.
29
+
30
+
31
+ ## Provided Pre-training Normalization Statistics
32
+
33
+ Below is a list of all the pre-training normalization statistics we provide. We provide them for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
34
+ | Robot | Description | Asset ID |
35
+ |-------|-------------|----------|
36
+ | ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
37
+ | Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
38
+ | Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
39
+ | Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
40
+ | UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
41
+ | UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
42
+ | ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
43
+ | ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
44
+ | Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
45
+
46
+
47
+ ## Pi0 Model Action Space Definitions
48
+
49
+ Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
50
+ ```
51
+ "dim_0:dim_5": "left arm joint angles",
52
+ "dim_6": "left arm gripper position",
53
+ "dim_7:dim_12": "right arm joint angles (for bi-manual only)",
54
+ "dim_13": "right arm gripper position (for bi-manual only)",
55
+
56
+ # For mobile robots:
57
+ "dim_14:dim_15": "x-y base velocity (for mobile robots only)",
58
+ ```
59
+
60
+ The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
61
+
62
+ For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
63
+
64
+ General info for Pi robots:
65
+ - Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
66
+ - Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
67
+ - Control frequencies are 20 Hz for UR5e and Franka arms, and 50 Hz for ARX and Trossen (ALOHA) arms.
68
+
69
+ For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
openpi/docs/remote_inference.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Running openpi models remotely
3
+
4
+ We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
5
+
6
+ ## Starting a remote policy server
7
+
8
+ To start a remote policy server, you can simply run the following command:
9
+
10
+ ```bash
11
+ uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
12
+ ```
13
+
14
+ The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
15
+
16
+ ```bash
17
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
18
+ ```
19
+
20
+ This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
21
+
22
+ ## Querying the remote policy server from your robot code
23
+
24
+ We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
25
+
26
+ First, install the `openpi-client` package in your robot environment:
27
+
28
+ ```bash
29
+ cd $OPENPI_ROOT/packages/openpi-client
30
+ pip install -e .
31
+ ```
32
+
33
+ Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
34
+
35
+ ```python
36
+ from openpi_client import image_tools
37
+ from openpi_client import websocket_client_policy
38
+
39
+ # Outside of episode loop, initialize the policy client.
40
+ # Point to the host and port of the policy server (localhost and 8000 are the defaults).
41
+ client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
42
+
43
+ for step in range(num_steps):
44
+ # Inside the episode loop, construct the observation.
45
+ # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
46
+ # We provide utilities for resizing images + uint8 conversion so you match the training routines.
47
+ # The typical resize_size for pre-trained pi0 models is 224.
48
+ # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
49
+ observation = {
50
+ "observation/image": image_tools.convert_to_uint8(
51
+ image_tools.resize_with_pad(img, 224, 224)
52
+ ),
53
+ "observation/wrist_image": image_tools.convert_to_uint8(
54
+ image_tools.resize_with_pad(wrist_img, 224, 224)
55
+ ),
56
+ "observation/state": state,
57
+ "prompt": task_instruction,
58
+ }
59
+
60
+ # Call the policy server with the current observation.
61
+ # This returns an action chunk of shape (action_horizon, action_dim).
62
+ # Note that you typically only need to call the policy every N steps and execute steps
63
+ # from the predicted action chunk open-loop in the remaining steps.
64
+ action_chunk = client.infer(observation)["actions"]
65
+
66
+ # Execute the actions in the environment.
67
+ ...
68
+
69
+ ```
70
+
71
+ Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](../examples/simple_client/main.py).
openpi/examples/aloha_real/Dockerfile ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for the Aloha real environment.
2
+
3
+ # Build the container:
4
+ # docker build . -t aloha_real -f examples/aloha_real/Dockerfile
5
+
6
+ # Run the container:
7
+ # docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
8
+
9
+ FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
10
+ SHELL ["/bin/bash", "-c"]
11
+
12
+ ENV DEBIAN_FRONTEND=noninteractive
13
+ RUN apt-get update && \
14
+ apt-get install -y --no-install-recommends \
15
+ cmake \
16
+ curl \
17
+ libffi-dev \
18
+ python3-rosdep \
19
+ python3-rosinstall \
20
+ python3-rosinstall-generator \
21
+ whiptail \
22
+ git \
23
+ wget \
24
+ openssh-client \
25
+ ros-noetic-cv-bridge \
26
+ ros-noetic-usb-cam \
27
+ ros-noetic-realsense2-camera \
28
+ keyboard-configuration
29
+
30
+ WORKDIR /root
31
+ RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
32
+ RUN chmod +x xsarm_amd64_install.sh
33
+ RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
34
+
35
+ COPY ./third_party/aloha /root/interbotix_ws/src/aloha
36
+ RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
37
+
38
+ # Install python 3.10 because this ROS image comes with 3.8
39
+ RUN mkdir /python && \
40
+ cd /python && \
41
+ wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
42
+ tar -zxvf Python-3.10.14.tgz && \
43
+ cd Python-3.10.14 && \
44
+ ls -lhR && \
45
+ ./configure --enable-optimizations && \
46
+ make install && \
47
+ echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
48
+ echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
49
+ cd ~ && rm -rf /python && \
50
+ rm -rf /var/lib/apt/lists/*
51
+
52
+ COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
53
+ ENV UV_HTTP_TIMEOUT=120
54
+ ENV UV_LINK_MODE=copy
55
+ COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
56
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
57
+ RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
58
+
59
+ ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
60
+ WORKDIR /app
61
+
62
+ # Create an entrypoint script to run the setup commands, followed by the command passed in.
63
+ RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
64
+ #!/bin/bash
65
+ source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
66
+ EOF
67
+ RUN chmod +x /usr/local/bin/entrypoint.sh
68
+
69
+ ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
70
+ CMD ["python3", "/app/examples/aloha_real/main.py"]
openpi/examples/aloha_real/README.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run Aloha (Real Robot)
2
+
3
+ This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
4
+
5
+ ## Prerequisites
6
+
7
+ This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
8
+
9
+ 1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
10
+ 1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
11
+
12
+ ## With Docker
13
+
14
+ ```bash
15
+ export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
16
+ docker compose -f examples/aloha_real/compose.yml up --build
17
+ ```
18
+
19
+ ## Without Docker
20
+
21
+ Terminal window 1:
22
+
23
+ ```bash
24
+ # Create virtual environment
25
+ uv venv --python 3.10 examples/aloha_real/.venv
26
+ source examples/aloha_real/.venv/bin/activate
27
+ uv pip sync examples/aloha_real/requirements.txt
28
+ uv pip install -e packages/openpi-client
29
+
30
+ # Run the robot
31
+ python -m examples.aloha_real.main
32
+ ```
33
+
34
+ Terminal window 2:
35
+
36
+ ```bash
37
+ roslaunch aloha ros_nodes.launch
38
+ ```
39
+
40
+ Terminal window 3:
41
+
42
+ ```bash
43
+ uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
44
+ ```
45
+
46
+ ## **ALOHA Checkpoint Guide**
47
+
48
+
49
+ The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
50
+
51
+ While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
52
+
53
+
54
+ ---
55
+
56
+ ### **Toast Task**
57
+
58
+ This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
59
+
60
+ - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
61
+ - **Prompt**: "take the toast out of the toaster"
62
+ - **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
63
+ - **Object Distribution**:
64
+ - Works on both real toast and rubber fake toast
65
+ - Compatible with standard 2-slice toasters
66
+ - Works with plates of varying colors
67
+
68
+ ### **Scene Setup Guidelines**
69
+ <img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
70
+
71
+ - The toaster should be positioned in the top-left quadrant of the workspace.
72
+ - Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
73
+ - The plate should be placed roughly in the lower-center of the workspace.
74
+ - Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
75
+
76
+
77
+ ### **Towel Task**
78
+
79
+ This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
80
+
81
+ - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
82
+ - **Prompt**: "fold the towel"
83
+ - **Object Distribution**:
84
+ - Works on towels of varying solid colors
85
+ - Performance is worse on heavily textured or striped towels
86
+
87
+ ### **Scene Setup Guidelines**
88
+ <img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
89
+
90
+ - The towel should be flattened and roughly centered on the table.
91
+ - Choose a towel that does not blend in with the table surface.
92
+
93
+
94
+ ### **Tupperware Task**
95
+
96
+ This task involves opening a tupperware filled with food and pouring the contents onto a plate.
97
+
98
+ - **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
99
+ - **Prompt**: "open the tupperware and put the food on the plate"
100
+ - **Objects needed**: Tupperware, food (or food-like items), and a plate.
101
+ - **Object Distribution**:
102
+ - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
103
+ - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
104
+ - The policy has seen plates of varying solid colors.
105
+
106
+ ### **Scene Setup Guidelines**
107
+ <img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
108
+
109
+ - Best performance observed when both the tupperware and plate are roughly centered in the workspace.
110
+ - Positioning:
111
+ - Tupperware should be on the left.
112
+ - Plate should be on the right or bottom.
113
+ - The tupperware flap should point toward the plate.
114
+
115
+ ## Training on your own Aloha dataset
116
+
117
+ 1. Convert the dataset to the LeRobot dataset v2.0 format.
118
+
119
+ We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
120
+
121
+
122
+ 2. Define a training config that uses the custom dataset.
123
+
124
+ We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
125
+
126
+ IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
openpi/examples/aloha_real/compose.yml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run with:
2
+ # docker compose -f examples/aloha_real/compose.yml up --build
3
+ services:
4
+ runtime:
5
+ image: aloha_real
6
+ depends_on:
7
+ - aloha_ros_nodes
8
+ - ros_master
9
+ - openpi_server
10
+ build:
11
+ context: ../..
12
+ dockerfile: examples/aloha_real/Dockerfile
13
+ init: true
14
+ tty: true
15
+ network_mode: host
16
+ privileged: true
17
+ volumes:
18
+ - $PWD:/app
19
+ - ../../data:/data
20
+
21
+ aloha_ros_nodes:
22
+ image: aloha_real
23
+ depends_on:
24
+ - ros_master
25
+ build:
26
+ context: ../..
27
+ dockerfile: examples/aloha_real/Dockerfile
28
+ init: true
29
+ tty: true
30
+ network_mode: host
31
+ privileged: true
32
+ volumes:
33
+ - /dev:/dev
34
+ command: roslaunch --wait aloha ros_nodes.launch
35
+
36
+ ros_master:
37
+ image: ros:noetic-robot
38
+ network_mode: host
39
+ privileged: true
40
+ command:
41
+ - roscore
42
+
43
+ openpi_server:
44
+ image: openpi_server
45
+ build:
46
+ context: ../..
47
+ dockerfile: scripts/docker/serve_policy.Dockerfile
48
+ init: true
49
+ tty: true
50
+ network_mode: host
51
+ volumes:
52
+ - $PWD:/app
53
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
54
+ environment:
55
+ - SERVER_ARGS
56
+ - OPENPI_DATA_HOME=/openpi_assets
57
+ - IS_DOCKER=true
58
+
59
+ # Comment out this block if not running on a machine with GPUs.
60
+ deploy:
61
+ resources:
62
+ reservations:
63
+ devices:
64
+ - driver: nvidia
65
+ count: 1
66
+ capabilities: [gpu]
openpi/examples/aloha_real/constants.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa

### Task parameters

### ALOHA fixed constants

# Control timestep in seconds.
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
# Reset pose: the same 8 values repeated for each arm — 6 joint angles followed
# by the two gripper finger positions.
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################
# NOTE: these were lambda assignments in the original ACT code; they are plain
# functions here (same names, same call signatures) for readability and better
# tracebacks (PEP 8 E731 discourages assigning lambdas to names).


def MASTER_GRIPPER_POSITION_NORMALIZE_FN(x):
    """Map a master gripper position to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - MASTER_GRIPPER_POSITION_CLOSE) / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x):
    """Map a puppet gripper position to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(x):
    """Inverse of MASTER_GRIPPER_POSITION_NORMALIZE_FN."""
    return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE


def PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(x):
    """Inverse of PUPPET_GRIPPER_POSITION_NORMALIZE_FN."""
    return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE


def MASTER2PUPPET_POSITION_FN(x):
    """Convert a master gripper position to the equivalent puppet position."""
    return PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))


def MASTER_GRIPPER_JOINT_NORMALIZE_FN(x):
    """Map a master gripper joint angle to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)


def PUPPET_GRIPPER_JOINT_NORMALIZE_FN(x):
    """Map a puppet gripper joint angle to [0, 1] (CLOSE -> 0, OPEN -> 1)."""
    return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)


def MASTER_GRIPPER_JOINT_UNNORMALIZE_FN(x):
    """Inverse of MASTER_GRIPPER_JOINT_NORMALIZE_FN."""
    return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE


def PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(x):
    """Inverse of PUPPET_GRIPPER_JOINT_NORMALIZE_FN."""
    return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE


def MASTER2PUPPET_JOINT_FN(x):
    """Convert a master gripper joint angle to the equivalent puppet joint angle."""
    return PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))


def MASTER_GRIPPER_VELOCITY_NORMALIZE_FN(x):
    """Normalize a master gripper velocity by the master position range."""
    return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)


def PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(x):
    """Normalize a puppet gripper velocity by the puppet position range."""
    return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)


def MASTER_POS2JOINT(x):
    """Convert a master gripper position to the corresponding joint angle."""
    return (
        MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
        + MASTER_GRIPPER_JOINT_CLOSE
    )


def MASTER_JOINT2POS(x):
    """Convert a master gripper joint angle to the corresponding position."""
    return MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
        (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    )


def PUPPET_POS2JOINT(x):
    """Convert a puppet gripper position to the corresponding joint angle."""
    return (
        PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
        + PUPPET_GRIPPER_JOINT_CLOSE
    )


def PUPPET_JOINT2POS(x):
    """Convert a puppet gripper joint angle to the corresponding position."""
    return PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
        (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    )


# Joint angle halfway between open and closed for the master gripper.
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
3
+
4
+ Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
5
+ """
6
+
7
+ import dataclasses
8
+ from pathlib import Path
9
+ import shutil
10
+ from typing import Literal
11
+
12
+ import h5py
13
+ from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
14
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
15
+ from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
16
+ import numpy as np
17
+ import torch
18
+ import tqdm
19
+ import tyro
20
+
21
+
22
+ @dataclasses.dataclass(frozen=True)
23
+ class DatasetConfig:
24
+ use_videos: bool = True
25
+ tolerance_s: float = 0.0001
26
+ image_writer_processes: int = 10
27
+ image_writer_threads: int = 5
28
+ video_backend: str | None = None
29
+
30
+
31
+ DEFAULT_DATASET_CONFIG = DatasetConfig()
32
+
33
+
34
def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
    """Create a fresh 50 Hz LeRobotDataset with the bi-manual Aloha feature schema.

    Any existing local dataset directory for `repo_id` is deleted first.
    Velocity/effort features are added only when the raw data provides them.
    """
    # 14 motors: 7 per arm (6 joints + gripper), right arm first.
    motors = [
        f"{side}_{joint}"
        for side in ("right", "left")
        for joint in ("waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "gripper")
    ]
    cameras = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

    def motor_feature() -> dict:
        # Fresh dict per feature so entries never share state.
        return {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [motors],
        }

    features = {
        "observation.state": motor_feature(),
        "action": motor_feature(),
    }
    if has_velocity:
        features["observation.velocity"] = motor_feature()
    if has_effort:
        features["observation.effort"] = motor_feature()
    features.update(
        {
            f"observation.images.{cam}": {
                "dtype": mode,
                "shape": (3, 480, 640),
                "names": ["channels", "height", "width"],
            }
            for cam in cameras
        }
    )

    # Start from a clean slate: remove any previously created local copy.
    local_dir = Path(LEROBOT_HOME / repo_id)
    if local_dir.exists():
        shutil.rmtree(local_dir)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=50,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )
126
+
127
+
128
def get_cameras(hdf5_files: list[Path]) -> list[str]:
    """Return the camera names present in the first episode file.

    Depth streams (keys containing "depth") are skipped — they are not
    currently handled.
    """
    with h5py.File(hdf5_files[0], "r") as episode:
        image_group = episode["/observations/images"]
        return [name for name in image_group if "depth" not in name]
132
+
133
+
134
def has_velocity(hdf5_files: list[Path]) -> bool:
    """Whether the first episode file contains joint velocities (qvel)."""
    with h5py.File(hdf5_files[0], "r") as episode:
        return "/observations/qvel" in episode
137
+
138
+
139
def has_effort(hdf5_files: list[Path]) -> bool:
    """Whether the first episode file contains joint efforts."""
    with h5py.File(hdf5_files[0], "r") as episode:
        return "/observations/effort" in episode
142
+
143
+
144
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every frame for each requested camera into memory.

    A 4-D dataset is treated as uncompressed frames and read in one shot;
    otherwise each entry is an encoded byte buffer that is decoded with
    OpenCV and converted from BGR to RGB, one frame at a time.
    """
    imgs_per_cam = {}
    for camera in cameras:
        frames = ep[f"/observations/images/{camera}"]
        if frames.ndim == 4:
            # Uncompressed: read the whole array at once.
            imgs_per_cam[camera] = frames[:]
        else:
            import cv2

            # Compressed: decode frame by frame to keep peak memory bounded.
            decoded = [cv2.cvtColor(cv2.imdecode(buf, 1), cv2.COLOR_BGR2RGB) for buf in frames]
            imgs_per_cam[camera] = np.array(decoded)
    return imgs_per_cam
163
+
164
+
165
def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Read one hdf5 episode.

    Returns (images_per_camera, qpos, action, qvel-or-None, effort-or-None).
    """
    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        def optional_tensor(key: str) -> torch.Tensor | None:
            # qvel / effort are not recorded in every dataset.
            return torch.from_numpy(ep[key][:]) if key in ep else None

        velocity = optional_tensor("/observations/qvel")
        effort = optional_tensor("/observations/effort")

        imgs_per_cam = load_raw_images_per_camera(
            ep,
            ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"],
        )

    return imgs_per_cam, state, action, velocity, effort
191
+
192
+
193
def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Append the selected hdf5 episodes to `dataset`, one saved episode each.

    When `episodes` is None, every file in `hdf5_files` is converted.
    """
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes):
        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(hdf5_files[ep_idx])

        for t in range(state.shape[0]):
            frame = {
                "observation.state": state[t],
                "action": action[t],
            }
            frame.update({f"observation.images.{camera}": frames[t] for camera, frames in imgs_per_cam.items()})
            if velocity is not None:
                frame["observation.velocity"] = velocity[t]
            if effort is not None:
                frame["observation.effort"] = effort[t]
            dataset.add_frame(frame)

        dataset.save_episode(task=task)

    return dataset
227
+
228
+
229
def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = True,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    """Convert a directory of Aloha `episode_*.hdf5` files to LeRobot v2.0.

    Downloads the raw data from `raw_repo_id` when `raw_dir` does not exist,
    then builds, populates, consolidates, and (optionally) pushes the dataset.
    """
    # Remove any stale local copy before recreating the dataset.
    local_copy = LEROBOT_HOME / repo_id
    if local_copy.exists():
        shutil.rmtree(local_copy)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()
269
+
270
+
271
if __name__ == "__main__":
    # tyro builds the CLI (flags, defaults, help text) from port_aloha's signature.
    tyro.cli(port_aloha)
openpi/examples/aloha_real/env.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional # noqa: UP035
2
+
3
+ import einops
4
+ from openpi_client import image_tools
5
+ from openpi_client.runtime import environment as _environment
6
+ from typing_extensions import override
7
+
8
+ from examples.aloha_real import real_env as _real_env
9
+
10
+
11
class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(
        self,
        reset_position: Optional[List[float]] = None,  # noqa: UP006,UP007
        render_height: int = 224,
        render_width: int = 224,
    ) -> None:
        self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
        self._render_height = render_height
        self._render_width = render_width
        # Latest timestep from the underlying env; None until reset() is called.
        self._ts = None

    @override
    def reset(self) -> None:
        self._ts = self._env.reset()

    @override
    def is_episode_complete(self) -> bool:
        # Episodes never self-terminate; the runtime enforces the step limit.
        return False

    @override
    def get_observation(self) -> dict:
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation
        images = obs["images"]

        # Drop depth streams (keys containing "_depth").
        for name in [k for k in images if "_depth" in k]:
            del images[name]

        # Resize with aspect-preserving padding, cast to uint8, then HWC -> CHW.
        for name, frame in images.items():
            resized = image_tools.resize_with_pad(frame, self._render_height, self._render_width)
            images[name] = einops.rearrange(image_tools.convert_to_uint8(resized), "h w c -> c h w")

        return {
            "state": obs["qpos"],
            "images": images,
        }

    @override
    def apply_action(self, action: dict) -> None:
        self._ts = self._env.step(action["actions"])
openpi/examples/aloha_real/main.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+
4
+ from openpi_client import action_chunk_broker
5
+ from openpi_client import websocket_client_policy as _websocket_client_policy
6
+ from openpi_client.runtime import runtime as _runtime
7
+ from openpi_client.runtime.agents import policy_agent as _policy_agent
8
+ import tyro
9
+
10
+ from examples.aloha_real import env as _env
11
+
12
+
13
@dataclasses.dataclass
class Args:
    """Command-line arguments for the Aloha real-robot runtime."""

    # Policy server to connect to.
    host: str = "0.0.0.0"
    port: int = 8000

    # Chunk size handed to the ActionChunkBroker (steps executed per inference).
    action_horizon: int = 25

    num_episodes: int = 1
    max_episode_steps: int = 1000
22
+
23
+
24
def main(args: Args) -> None:
    """Run the Aloha real-robot runtime against a remote policy server.

    Connects to the websocket policy server at ``args.host:args.port``, wraps
    it in an ActionChunkBroker (one inference serves ``args.action_horizon``
    control steps), and drives the robot at 50 Hz.
    """
    ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    # Fetch the server metadata once; the original code issued this request
    # twice (once for logging, once for the reset pose).
    metadata = ws_client_policy.get_server_metadata()
    logging.info("Server metadata: %s", metadata)

    runtime = _runtime.Runtime(
        # The server may supply a robot-specific reset pose in its metadata.
        environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=ws_client_policy,
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[],
        max_hz=50,
        num_episodes=args.num_episodes,
        max_episode_steps=args.max_episode_steps,
    )

    runtime.run()
47
+
48
+
49
if __name__ == "__main__":
    # force=True replaces any handlers already attached to the root logger.
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
openpi/examples/aloha_real/real_env.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
2
+ # ruff: noqa
3
+ import collections
4
+ import time
5
+ from typing import Optional, List
6
+ import dm_env
7
+ from interbotix_xs_modules.arm import InterbotixManipulatorXS
8
+ from interbotix_xs_msgs.msg import JointSingleCommand
9
+ import numpy as np
10
+
11
+ from examples.aloha_real import constants
12
+ from examples.aloha_real import robot_utils
13
+
14
+ # This is the reset position that is used by the standard Aloha runtime.
15
+ DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
16
+
17
+
18
class RealEnv:
    """
    Environment for real robot bi-manual manipulation.

    Action space: [left_arm_qpos (6),             # absolute joint position
                   left_gripper_positions (1),    # normalized gripper position (0: close, 1: open)
                   right_arm_qpos (6),            # absolute joint position
                   right_gripper_positions (1)]   # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[ left_arm_qpos (6),          # absolute joint position
                                        left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                        right_arm_qpos (6),         # absolute joint position
                                        right_gripper_qpos (1)]     # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[ left_arm_qvel (6),          # absolute joint velocity (rad)
                                        left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                        right_arm_qvel (6),         # absolute joint velocity (rad)
                                        right_gripper_qvel (1)]     # normalized gripper velocity (pos: opening, neg: closing)
                        "images": {"cam_high": (480x640x3),         # h, w, c, dtype='uint8'
                                   "cam_low": (480x640x3),          # h, w, c, dtype='uint8'
                                   "cam_left_wrist": (480x640x3),   # h, w, c, dtype='uint8'
                                   "cam_right_wrist": (480x640x3)}  # h, w, c, dtype='uint8'
    """

    def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
        # Only the first 6 entries (one arm's joints) of a provided reset pose
        # are used; both arms are reset to the same pose in _reset_joints().
        # reset_position = START_ARM_POSE[:6]
        self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION

        # Only the first Interbotix handle initializes the ROS node; the
        # second must pass init_node=False to avoid re-initialization.
        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        # Recorders cache the latest joint states / images from ROS topics.
        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        # Reused message object for publishing gripper commands.
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        # Reboot gripper motors, set operating modes, and enable torque.
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        """Return the 14-dim position vector described in the class docstring."""
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        # Index 7 of the raw joint-state message is the gripper position
        # reading, normalized to [0, 1] via the constants mapping.
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        """Return the 14-dim velocity vector (arm joints raw, grippers normalized)."""
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        """Return raw motor efforts (7 values per arm, left then right)."""
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        """Publish gripper joint commands for both arms from normalized [0, 1] targets."""
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        # Both arms are interpolated to the same reset pose over ~1 second.
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Set to position mode and do position resets: first close then open. Then change back to PWM mode

        NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
        was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
        increase the frequency of motor faults.
        """
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
        )
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
        )

    def get_observation(self):
        """Assemble the observation dict described in the class docstring."""
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def get_reward(self):
        # No task reward is computed on the real robot; always 0.
        return 0

    def reset(self, *, fake=False):
        """Physically reset arms and grippers (unless fake) and return a FIRST timestep."""
        if not fake:
            # Reboot puppet robot gripper motors
            self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
            self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )

    def step(self, action):
        """Apply one 14-dim action (left 7, right 7), wait one control period, return MID timestep."""
        state_len = int(len(action) / 2)
        left_action = action[:state_len]
        right_action = action[state_len:]
        # Arm commands are non-blocking; the sleep below paces the control loop.
        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
        self.set_gripper_pose(left_action[-1], right_action[-1])
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )
161
+
162
+
163
def get_action(master_bot_left, master_bot_right):
    """Read the current 14-dim teleop action from the two master robots.

    Layout: [left joints (6), left gripper (1), right joints (6), right gripper (1)].
    Gripper joint readings are normalized via the master gripper mapping.
    """
    left_positions = master_bot_left.dxl.joint_states.position
    right_positions = master_bot_right.dxl.joint_states.position

    action = np.zeros(14)  # 6 joint + 1 gripper, for two arms
    action[0:6] = left_positions[:6]
    action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(left_positions[6])
    action[7:13] = right_positions[:6]
    action[13] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(right_positions[6])
    return action
173
+
174
+
175
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
    """Factory for :class:`RealEnv`, mirroring the original ACT entry point."""
    env = RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
    return env
openpi/examples/aloha_real/requirements.in ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Pillow
2
+ dm_control
3
+ einops
4
+ h5py
5
+ matplotlib
6
+ modern_robotics
7
+ msgpack
8
+ numpy>=1.22.4,<2.0.0
9
+ opencv-python
10
+ packaging
11
+ pexpect
12
+ pyquaternion
13
+ pyrealsense2
14
+ pyyaml
15
+ requests
16
+ rospkg
17
+ tyro
18
+ websockets
openpi/examples/aloha_real/requirements.txt ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
3
+ absl-py==2.1.0
4
+ # via
5
+ # dm-control
6
+ # dm-env
7
+ # labmaze
8
+ # mujoco
9
+ catkin-pkg==1.0.0
10
+ # via rospkg
11
+ certifi==2024.8.30
12
+ # via requests
13
+ charset-normalizer==3.4.0
14
+ # via requests
15
+ contourpy==1.1.1
16
+ # via matplotlib
17
+ cycler==0.12.1
18
+ # via matplotlib
19
+ distro==1.9.0
20
+ # via rospkg
21
+ dm-control==1.0.23
22
+ # via -r examples/aloha_real/requirements.in
23
+ dm-env==1.6
24
+ # via dm-control
25
+ dm-tree==0.1.8
26
+ # via
27
+ # dm-control
28
+ # dm-env
29
+ docstring-parser==0.16
30
+ # via tyro
31
+ docutils==0.20.1
32
+ # via catkin-pkg
33
+ einops==0.8.0
34
+ # via -r examples/aloha_real/requirements.in
35
+ etils==1.3.0
36
+ # via mujoco
37
+ fonttools==4.55.2
38
+ # via matplotlib
39
+ glfw==2.8.0
40
+ # via
41
+ # dm-control
42
+ # mujoco
43
+ h5py==3.11.0
44
+ # via -r examples/aloha_real/requirements.in
45
+ idna==3.10
46
+ # via requests
47
+ importlib-resources==6.4.5
48
+ # via etils
49
+ kiwisolver==1.4.7
50
+ # via matplotlib
51
+ labmaze==1.0.6
52
+ # via dm-control
53
+ lxml==5.3.0
54
+ # via dm-control
55
+ markdown-it-py==3.0.0
56
+ # via rich
57
+ matplotlib==3.7.5
58
+ # via -r examples/aloha_real/requirements.in
59
+ mdurl==0.1.2
60
+ # via markdown-it-py
61
+ modern-robotics==1.1.1
62
+ # via -r examples/aloha_real/requirements.in
63
+ msgpack==1.1.0
64
+ # via -r examples/aloha_real/requirements.in
65
+ mujoco==3.2.3
66
+ # via dm-control
67
+ numpy==1.24.4
68
+ # via
69
+ # -r examples/aloha_real/requirements.in
70
+ # contourpy
71
+ # dm-control
72
+ # dm-env
73
+ # h5py
74
+ # labmaze
75
+ # matplotlib
76
+ # modern-robotics
77
+ # mujoco
78
+ # opencv-python
79
+ # pyquaternion
80
+ # scipy
81
+ opencv-python==4.10.0.84
82
+ # via -r examples/aloha_real/requirements.in
83
+ packaging==24.2
84
+ # via
85
+ # -r examples/aloha_real/requirements.in
86
+ # matplotlib
87
+ pexpect==4.9.0
88
+ # via -r examples/aloha_real/requirements.in
89
+ pillow==10.4.0
90
+ # via
91
+ # -r examples/aloha_real/requirements.in
92
+ # matplotlib
93
+ protobuf==5.29.1
94
+ # via dm-control
95
+ ptyprocess==0.7.0
96
+ # via pexpect
97
+ pygments==2.18.0
98
+ # via rich
99
+ pyopengl==3.1.7
100
+ # via
101
+ # dm-control
102
+ # mujoco
103
+ pyparsing==3.1.4
104
+ # via
105
+ # catkin-pkg
106
+ # dm-control
107
+ # matplotlib
108
+ pyquaternion==0.9.9
109
+ # via -r examples/aloha_real/requirements.in
110
+ pyrealsense2==2.55.1.6486
111
+ # via -r examples/aloha_real/requirements.in
112
+ python-dateutil==2.9.0.post0
113
+ # via
114
+ # catkin-pkg
115
+ # matplotlib
116
+ pyyaml==6.0.2
117
+ # via
118
+ # -r examples/aloha_real/requirements.in
119
+ # rospkg
120
+ requests==2.32.3
121
+ # via
122
+ # -r examples/aloha_real/requirements.in
123
+ # dm-control
124
+ rich==13.9.4
125
+ # via tyro
126
+ rospkg==1.5.1
127
+ # via -r examples/aloha_real/requirements.in
128
+ scipy==1.10.1
129
+ # via dm-control
130
+ setuptools==75.3.0
131
+ # via
132
+ # catkin-pkg
133
+ # dm-control
134
+ # labmaze
135
+ shtab==1.7.1
136
+ # via tyro
137
+ six==1.17.0
138
+ # via python-dateutil
139
+ tqdm==4.67.1
140
+ # via dm-control
141
+ typeguard==4.4.0
142
+ # via tyro
143
+ typing-extensions==4.12.2
144
+ # via
145
+ # etils
146
+ # rich
147
+ # typeguard
148
+ # tyro
149
+ tyro==0.9.2
150
+ # via -r examples/aloha_real/requirements.in
151
+ urllib3==2.2.3
152
+ # via requests
153
+ websockets==14.1
154
+ # via -r examples/aloha_real/requirements.in
155
+ zipp==3.20.2
156
+ # via etils
openpi/examples/aloha_real/robot_utils.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
2
+ # ruff: noqa
3
+ from collections import deque
4
+ import datetime
5
+ import json
6
+ import time
7
+
8
+ from aloha.msg import RGBGrayscaleImage
9
+ from cv_bridge import CvBridge
10
+ from interbotix_xs_msgs.msg import JointGroupCommand
11
+ from interbotix_xs_msgs.msg import JointSingleCommand
12
+ import numpy as np
13
+ import rospy
14
+ from sensor_msgs.msg import JointState
15
+
16
+ from examples.aloha_real import constants
17
+
18
+
19
class ImageRecorder:
    """Caches the latest RGB frame and timestamp from each camera ROS topic.

    One subscriber per camera stores the decoded image and message timestamp
    on attributes named ``{cam_name}_rgb_image`` / ``{cam_name}_timestamp``.
    """

    def __init__(self, init_node=True, is_debug=False):
        # When is_debug is set, per-camera timestamp deques are kept so that
        # print_diagnostics() can estimate incoming frame rates.
        self.is_debug = is_debug
        self.bridge = CvBridge()
        self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

        if init_node:
            rospy.init_node("image_recorder", anonymous=True)
        for cam_name in self.camera_names:
            setattr(self, f"{cam_name}_rgb_image", None)
            setattr(self, f"{cam_name}_depth_image", None)
            setattr(self, f"{cam_name}_timestamp", 0.0)
            # rospy passes a single message argument to callbacks, so each
            # camera is dispatched to its own named wrapper method.
            if cam_name == "cam_high":
                callback_func = self.image_cb_cam_high
            elif cam_name == "cam_low":
                callback_func = self.image_cb_cam_low
            elif cam_name == "cam_left_wrist":
                callback_func = self.image_cb_cam_left_wrist
            elif cam_name == "cam_right_wrist":
                callback_func = self.image_cb_cam_right_wrist
            else:
                raise NotImplementedError
            rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
            if self.is_debug:
                setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))

        # Timestamp of the last frame handed out by get_images(), per camera.
        self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
        # Give subscribers a moment to receive their first messages.
        time.sleep(0.5)

    def image_cb(self, cam_name, data):
        """Store the first image of the message as BGR plus its header timestamp."""
        setattr(
            self,
            f"{cam_name}_rgb_image",
            self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
        )
        # Depth decoding is disabled; {cam_name}_depth_image stays None as written.
        # setattr(
        #     self,
        #     f"{cam_name}_depth_image",
        #     self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
        # )
        setattr(
            self,
            f"{cam_name}_timestamp",
            data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
        )
        # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
        # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
        # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
        if self.is_debug:
            getattr(self, f"{cam_name}_timestamps").append(
                data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
            )

    def image_cb_cam_high(self, data):
        cam_name = "cam_high"
        return self.image_cb(cam_name, data)

    def image_cb_cam_low(self, data):
        cam_name = "cam_low"
        return self.image_cb(cam_name, data)

    def image_cb_cam_left_wrist(self, data):
        cam_name = "cam_left_wrist"
        return self.image_cb(cam_name, data)

    def image_cb_cam_right_wrist(self, data):
        cam_name = "cam_right_wrist"
        return self.image_cb(cam_name, data)

    def get_images(self):
        """Return a fresh frame per camera, busy-waiting until each is newer than the last returned.

        NOTE(review): there is no timeout — if a camera stops publishing this
        loop blocks forever. The `{cam_name}_depth` entries are always None
        because depth decoding in image_cb is commented out.
        """
        image_dict = {}
        for cam_name in self.camera_names:
            while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
                time.sleep(0.00001)
            rgb_image = getattr(self, f"{cam_name}_rgb_image")
            depth_image = getattr(self, f"{cam_name}_depth_image")
            self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
            image_dict[cam_name] = rgb_image
            image_dict[f"{cam_name}_depth"] = depth_image
        return image_dict

    def print_diagnostics(self):
        """Print the observed frame rate per camera (requires is_debug=True)."""

        def dt_helper(l):
            # Mean gap between consecutive timestamps.
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        for cam_name in self.camera_names:
            image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
            print(f"{cam_name} {image_freq=:.2f}")
        print()
110
+
111
+
112
class Recorder:
    """Subscribes to one puppet arm's ROS topics and caches the latest messages.

    ``side`` selects the topic namespace (``/puppet_{side}/...``). The latest
    joint state (qpos/qvel/effort), arm group command, and gripper command are
    stored as attributes and overwritten on every incoming message.
    """

    def __init__(self, side, init_node=True, is_debug=False):
        self.secs = None
        self.nsecs = None
        self.qpos = None
        # Fix: qvel and data were previously assigned only inside
        # puppet_state_cb, so reading them before the first joint-state
        # message raised AttributeError. Initialize them like qpos/effort.
        self.qvel = None
        self.data = None
        self.effort = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        if self.is_debug:
            # Rolling windows of receive times for frequency diagnostics.
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)
        # Give subscribers a moment to receive their first messages.
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        """Cache the latest joint-state message (position/velocity/effort)."""
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        """Cache the latest arm group command."""
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        """Cache the latest gripper command."""
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        """Print observed message frequencies (requires is_debug=True)."""

        def dt_helper(l):
            # Mean gap between consecutive receive times.
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
170
+
171
+
172
def get_arm_joint_positions(bot):
    """Return the six arm joint positions from the robot's core joint state."""
    joint_states = bot.arm.core.joint_states
    return joint_states.position[:6]
174
+
175
+
176
def get_arm_gripper_positions(bot):
    """Return the gripper joint position (index 6) from the gripper's core joint state."""
    joint_states = bot.gripper.core.joint_states
    return joint_states.position[6]
178
+
179
+
180
def move_arms(bot_list, target_pose_list, move_time=1):
    """Linearly interpolate each arm from its current pose to its target pose.

    Joint commands are streamed once per control period (``constants.DT``),
    so the whole motion takes approximately ``move_time`` seconds.
    """
    num_steps = int(move_time / constants.DT)
    start_poses = [get_arm_joint_positions(bot) for bot in bot_list]
    trajectories = [
        np.linspace(start, target, num_steps)
        for start, target in zip(start_poses, target_pose_list)
    ]
    for step in range(num_steps):
        for bot, trajectory in zip(bot_list, trajectories):
            bot.arm.set_joint_positions(trajectory[step], blocking=False)
        time.sleep(constants.DT)
191
+
192
+
193
def move_grippers(bot_list, target_pose_list, move_time):
    """Linearly interpolate each gripper to its target joint value over move_time seconds.

    Also appends a per-step {bot_id: {"obs", "act"}} record to a timestamped
    jsonl file for debugging.

    NOTE(review): the log path is hard-coded to /data — this fails unless a
    writable /data directory exists (e.g. the Docker volume mount); confirm
    whether this debug logging should stay enabled.
    """
    print(f"Moving grippers to {target_pose_list=}")
    gripper_command = JointSingleCommand(name="gripper")
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]

    with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
        for t in range(num_steps):
            d = {}
            for bot_id, bot in enumerate(bot_list):
                # Publish the interpolated command, then record the commanded
                # value alongside the currently observed gripper position.
                gripper_command.cmd = traj_list[bot_id][t]
                bot.gripper.core.pub_single.publish(gripper_command)
                d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
            f.write(json.dumps(d) + "\n")
            time.sleep(constants.DT)
212
+
213
+
214
def setup_puppet_bot(bot):
    """Prepare a puppet robot for control: reboot gripper, set modes, enable torque."""
    dxl = bot.dxl
    dxl.robot_reboot_motors("single", "gripper", True)
    dxl.robot_set_operating_modes("group", "arm", "position")
    dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_on(bot)
219
+
220
+
221
def setup_master_bot(bot):
    """Prepare a master (teleop) robot: PWM arm mode, gripper position mode, torque off."""
    dxl = bot.dxl
    dxl.robot_set_operating_modes("group", "arm", "pwm")
    dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
    torque_off(bot)
225
+
226
+
227
def set_standard_pid_gains(bot):
    """Apply the standard arm position-loop gains (P=800, I=0)."""
    for register, value in (("Position_P_Gain", 800), ("Position_I_Gain", 0)):
        bot.dxl.robot_set_motor_registers("group", "arm", register, value)
230
+
231
+
232
def set_low_pid_gains(bot):
    """Apply softened arm position-loop gains (P=100, I=0) for gentler motion."""
    for register, value in (("Position_P_Gain", 100), ("Position_I_Gain", 0)):
        bot.dxl.robot_set_motor_registers("group", "arm", register, value)
235
+
236
+
237
def torque_off(bot):
    """Disable torque on the arm group and the gripper motor."""
    for cmd_type, name in (("group", "arm"), ("single", "gripper")):
        bot.dxl.robot_torque_enable(cmd_type, name, False)
240
+
241
+
242
def torque_on(bot):
    """Enable torque on the arm group and the gripper motor."""
    for cmd_type, name in (("group", "arm"), ("single", "gripper")):
        bot.dxl.robot_torque_enable(cmd_type, name, True)
245
+
246
+
247
# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
    """Drive the master arms and grippers to match the current puppet poses."""
    print("\nSyncing!")

    # The master arms must be torqued on before they can be commanded.
    torque_on(master_bot_left)
    torque_on(master_bot_right)

    # Snapshot the puppet state to mirror (arms first, then grippers).
    arm_targets = [get_arm_joint_positions(bot) for bot in (puppet_bot_left, puppet_bot_right)]
    gripper_targets = [get_arm_gripper_positions(bot) for bot in (puppet_bot_left, puppet_bot_right)]

    masters = [master_bot_left, master_bot_right]
    move_arms(masters, arm_targets, move_time=1)
    move_grippers(masters, gripper_targets, move_time=1)
openpi/examples/aloha_real/video_display.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from openpi_client.runtime import subscriber as _subscriber
4
+ from typing_extensions import override
5
+
6
+
7
class VideoDisplay(_subscriber.Subscriber):
    """Displays video frames in an interactive matplotlib window."""

    def __init__(self) -> None:
        # Both attributes are created lazily on episode start: _ax holds the
        # matplotlib Axes, _plt_img the AxesImage returned by imshow().
        #
        # Fix: the original annotated these as `plt.Axes | None` and
        # `plt.Image | None`. Attribute annotations are evaluated at runtime
        # (PEP 526), `plt.Image` is not a public pyplot attribute, and the
        # `X | None` syntax requires Python 3.10 — plain assignments avoid
        # both failure modes without changing behavior.
        self._ax = None
        self._plt_img = None

    @override
    def on_episode_start(self) -> None:
        # Interactive mode so plt.pause() refreshes the window without blocking.
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        assert self._ax is not None

        # NOTE(review): assumes observation["image"] is a batch of [C, H, W]
        # frames and displays the first one — confirm against the runtime's
        # observation schema (environments in this repo emit an "images" dict).
        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]

        if self._plt_img is None:
            # First frame of the episode: create the image artist.
            self._plt_img = self._ax.imshow(im)
        else:
            # Later frames: update pixel data in place instead of re-drawing.
            self._plt_img.set_data(im)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        plt.ioff()
        plt.close()
openpi/examples/aloha_sim/Dockerfile ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for the Aloha simulation environment.
2
+
3
+ # Build the container:
4
+ # docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
5
+
6
+ # Run the container:
7
+ # docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
8
+
9
+ FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
10
+ COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
11
+
12
+ RUN apt-get update && \
13
+ apt-get install -y \
14
+ libosmesa6-dev \
15
+ libgl1-mesa-glx \
16
+ libglew-dev \
17
+ libglfw3-dev \
18
+ libgles2-mesa-dev
19
+ ENV MUJOCO_GL=egl
20
+
21
+ WORKDIR /app
22
+
23
+ # Copy from the cache instead of linking since it's a mounted volume
24
+ ENV UV_LINK_MODE=copy
25
+
26
+ # Write the virtual environment outside of the project directory so it doesn't
27
+ # leak out of the container when we mount the application code.
28
+ ENV UV_PROJECT_ENVIRONMENT=/.venv
29
+
30
+ # Copy the requirements files so we can install dependencies.
31
+ # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
32
+ # This strategy is best for development-style usage.
33
+ COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
34
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
35
+
36
+ # Install python dependencies.
37
+ RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
38
+ RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
39
+ ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
40
+
41
+ CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
openpi/examples/aloha_sim/README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run Aloha Sim
2
+
3
+ ## With Docker
4
+
5
+ ```bash
6
+ export SERVER_ARGS="--env ALOHA_SIM"
7
+ docker compose -f examples/aloha_sim/compose.yml up --build
8
+ ```
9
+
10
+ ## Without Docker
11
+
12
+ Terminal window 1:
13
+
14
+ ```bash
15
+ # Create virtual environment
16
+ uv venv --python 3.10 examples/aloha_sim/.venv
17
+ source examples/aloha_sim/.venv/bin/activate
18
+ uv pip sync examples/aloha_sim/requirements.txt
19
+ uv pip install -e packages/openpi-client
20
+
21
+ # Run the simulation
22
+ MUJOCO_GL=egl python examples/aloha_sim/main.py
23
+ ```
24
+
25
+ Note: If you are seeing EGL errors, you may need to install the following dependencies:
26
+
27
+ ```bash
28
+ sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
29
+ ```
30
+
31
+ Terminal window 2:
32
+
33
+ ```bash
34
+ # Run the server
35
+ uv run scripts/serve_policy.py --env ALOHA_SIM
36
+ ```
openpi/examples/aloha_sim/compose.yml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run with:
2
+ # docker compose -f examples/aloha_sim/compose.yml up --build
3
+ services:
4
+ runtime:
5
+ image: aloha_sim
6
+ depends_on:
7
+ - openpi_server
8
+ build:
9
+ context: ../..
10
+ dockerfile: examples/aloha_sim/Dockerfile
11
+ init: true
12
+ tty: true
13
+ network_mode: host
14
+ privileged: true
15
+ volumes:
16
+ - $PWD:/app
17
+ - ../../data:/data
18
+
19
+ openpi_server:
20
+ image: openpi_server
21
+ build:
22
+ context: ../..
23
+ dockerfile: scripts/docker/serve_policy.Dockerfile
24
+ init: true
25
+ tty: true
26
+ network_mode: host
27
+ volumes:
28
+ - $PWD:/app
29
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
30
+ environment:
31
+ - SERVER_ARGS
32
+ - OPENPI_DATA_HOME=/openpi_assets
33
+ - IS_DOCKER=true
34
+
35
+ # Comment out this block if not running on a machine with GPUs.
36
+ deploy:
37
+ resources:
38
+ reservations:
39
+ devices:
40
+ - driver: nvidia
41
+ count: 1
42
+ capabilities: [gpu]
openpi/examples/aloha_sim/env.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym_aloha # noqa: F401
2
+ import gymnasium
3
+ import numpy as np
4
+ from openpi_client import image_tools
5
+ from openpi_client.runtime import environment as _environment
6
+ from typing_extensions import override
7
+
8
+
9
class AlohaSimEnvironment(_environment.Environment):
    """An environment for an Aloha robot in simulation."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        # Seed the legacy global NumPy RNG as well as a local generator used
        # for per-episode reset seeds.
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        # Draw a fresh episode seed from the local generator.
        episode_seed = int(self._rng.integers(2**32 - 1))
        raw_obs, _ = self._gym.reset(seed=episode_seed)
        self._last_obs = self._convert_observation(raw_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def is_episode_complete(self) -> bool:
        return self._done

    @override
    def get_observation(self) -> dict:
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        raw_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
        self._last_obs = self._convert_observation(raw_obs)  # type: ignore
        self._done = terminated or truncated
        # Track the best reward seen so far in the episode.
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        """Convert a gym observation into the policy-server format."""
        top = gym_obs["pixels"]["top"]
        top = image_tools.convert_to_uint8(image_tools.resize_with_pad(top, 224, 224))
        # Convert axis order from [H, W, C] --> [C, H, W]
        top = np.transpose(top, (2, 0, 1))

        return {
            "state": gym_obs["agent_pos"],
            "images": {"cam_high": top},
        }
openpi/examples/aloha_sim/main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import logging
3
+ import pathlib
4
+
5
+ import env as _env
6
+ from openpi_client import action_chunk_broker
7
+ from openpi_client import websocket_client_policy as _websocket_client_policy
8
+ from openpi_client.runtime import runtime as _runtime
9
+ from openpi_client.runtime.agents import policy_agent as _policy_agent
10
+ import saver as _saver
11
+ import tyro
12
+
13
+
14
@dataclasses.dataclass
class Args:
    """Command-line arguments for the Aloha simulation client runtime."""

    # Directory where episode videos are written.
    out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")

    # Gymnasium task id and RNG seed for the simulated environment.
    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    # Number of actions executed from each chunk before re-querying the server.
    action_horizon: int = 10

    # Policy server address.
    host: str = "0.0.0.0"
    port: int = 8000

    # NOTE(review): this flag is currently not read by main() — confirm
    # whether a display subscriber was meant to be wired up.
    display: bool = False


def main(args: Args) -> None:
    """Run simulated episodes against a remote policy server, saving videos."""
    environment = _env.AlohaSimEnvironment(
        task=args.task,
        seed=args.seed,
    )
    agent = _policy_agent.PolicyAgent(
        policy=action_chunk_broker.ActionChunkBroker(
            policy=_websocket_client_policy.WebsocketClientPolicy(
                host=args.host,
                port=args.port,
            ),
            action_horizon=args.action_horizon,
        )
    )
    runtime = _runtime.Runtime(
        environment=environment,
        agent=agent,
        subscribers=[
            _saver.VideoSaver(args.out_dir),
        ],
        max_hz=50,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
+ tyro.cli(main)
openpi/examples/aloha_sim/requirements.in ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gym-aloha
2
+ imageio
3
+ matplotlib
4
+ msgpack
5
+ numpy>=1.22.4,<2.0.0
6
+ typing-extensions
7
+ tyro
8
+ websockets
openpi/examples/aloha_sim/requirements.txt ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
3
+ absl-py==2.1.0
4
+ # via
5
+ # dm-control
6
+ # dm-env
7
+ # labmaze
8
+ # mujoco
9
+ certifi==2024.8.30
10
+ # via requests
11
+ charset-normalizer==3.4.0
12
+ # via requests
13
+ cloudpickle==3.1.0
14
+ # via gymnasium
15
+ contourpy==1.3.1
16
+ # via matplotlib
17
+ cycler==0.12.1
18
+ # via matplotlib
19
+ dm-control==1.0.14
20
+ # via gym-aloha
21
+ dm-env==1.6
22
+ # via dm-control
23
+ dm-tree==0.1.8
24
+ # via
25
+ # dm-control
26
+ # dm-env
27
+ docstring-parser==0.16
28
+ # via tyro
29
+ farama-notifications==0.0.4
30
+ # via gymnasium
31
+ fonttools==4.55.2
32
+ # via matplotlib
33
+ glfw==2.8.0
34
+ # via
35
+ # dm-control
36
+ # mujoco
37
+ gym-aloha==0.1.1
38
+ # via -r examples/aloha_sim/requirements.in
39
+ gymnasium==1.0.0
40
+ # via gym-aloha
41
+ idna==3.10
42
+ # via requests
43
+ imageio==2.36.1
44
+ # via
45
+ # -r examples/aloha_sim/requirements.in
46
+ # gym-aloha
47
+ imageio-ffmpeg==0.5.1
48
+ # via imageio
49
+ kiwisolver==1.4.7
50
+ # via matplotlib
51
+ labmaze==1.0.6
52
+ # via dm-control
53
+ lxml==5.3.0
54
+ # via dm-control
55
+ markdown-it-py==3.0.0
56
+ # via rich
57
+ matplotlib==3.9.3
58
+ # via -r examples/aloha_sim/requirements.in
59
+ mdurl==0.1.2
60
+ # via markdown-it-py
61
+ msgpack==1.1.0
62
+ # via -r examples/aloha_sim/requirements.in
63
+ mujoco==2.3.7
64
+ # via
65
+ # dm-control
66
+ # gym-aloha
67
+ numpy==1.26.4
68
+ # via
69
+ # -r examples/aloha_sim/requirements.in
70
+ # contourpy
71
+ # dm-control
72
+ # dm-env
73
+ # gymnasium
74
+ # imageio
75
+ # labmaze
76
+ # matplotlib
77
+ # mujoco
78
+ # scipy
79
+ packaging==24.2
80
+ # via matplotlib
81
+ pillow==11.0.0
82
+ # via
83
+ # imageio
84
+ # matplotlib
85
+ protobuf==5.29.1
86
+ # via dm-control
87
+ psutil==6.1.0
88
+ # via imageio
89
+ pygments==2.18.0
90
+ # via rich
91
+ pyopengl==3.1.7
92
+ # via
93
+ # dm-control
94
+ # mujoco
95
+ pyparsing==3.2.0
96
+ # via
97
+ # dm-control
98
+ # matplotlib
99
+ python-dateutil==2.9.0.post0
100
+ # via matplotlib
101
+ requests==2.32.3
102
+ # via dm-control
103
+ rich==13.9.4
104
+ # via tyro
105
+ scipy==1.14.1
106
+ # via dm-control
107
+ setuptools==75.6.0
108
+ # via
109
+ # dm-control
110
+ # imageio-ffmpeg
111
+ # labmaze
112
+ shtab==1.7.1
113
+ # via tyro
114
+ six==1.17.0
115
+ # via python-dateutil
116
+ tqdm==4.67.1
117
+ # via dm-control
118
+ typeguard==4.4.1
119
+ # via tyro
120
+ typing-extensions==4.12.2
121
+ # via
122
+ # -r examples/aloha_sim/requirements.in
123
+ # gymnasium
124
+ # rich
125
+ # typeguard
126
+ # tyro
127
+ tyro==0.9.2
128
+ # via -r examples/aloha_sim/requirements.in
129
+ urllib3==2.2.3
130
+ # via requests
131
+ websockets==14.1
132
+ # via -r examples/aloha_sim/requirements.in
openpi/examples/aloha_sim/saver.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pathlib
3
+
4
+ import imageio
5
+ import numpy as np
6
+ from openpi_client.runtime import subscriber as _subscriber
7
+ from typing_extensions import override
8
+
9
+
10
class VideoSaver(_subscriber.Subscriber):
    """Saves episode data."""

    def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
        out_dir.mkdir(parents=True, exist_ok=True)
        self._out_dir = out_dir
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        # Drop any frames left over from the previous episode.
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        frame = observation["images"]["cam_high"]  # [C, H, W]
        # Store in [H, W, C] order, which is what the video writer expects.
        self._images.append(np.transpose(frame, (1, 2, 0)))

    @override
    def on_episode_end(self) -> None:
        # Continue the out_<N>.mp4 numbering from whatever already exists.
        used_indices = [int(path.stem.split("_")[1]) for path in self._out_dir.glob("out_[0-9]*.mp4")]
        next_idx = max(used_indices, default=-1) + 1
        out_path = self._out_dir / f"out_{next_idx}.mp4"

        logging.info(f"Saving video to {out_path}")
        frames = [np.asarray(frame) for frame in self._images[:: self._subsample]]
        imageio.mimwrite(
            out_path,
            frames,
            fps=50 // max(1, self._subsample),
        )
openpi/examples/convert_jax_model_to_pytorch.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
4
+
5
+ This script loads a JAX model checkpoint using orbax and can either:
6
+ 1. Print out all the parameter keys in a hierarchical structure for inspection
7
+ 2. Convert the JAX model to PyTorch format using our PI0Pytorch model
8
+
9
+ Usage:
10
+ # Just inspect keys:
11
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
12
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
13
+
14
+ # Convert to PyTorch:
15
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
16
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
17
+
18
+ Example:
19
+ # pi0_droid
20
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
21
+
22
+ # pi0_aloha_sim
23
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
24
+
25
+ # pi05_droid
26
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
27
+ """
28
+
29
+ import json
30
+ import os
31
+ import pathlib
32
+ import shutil
33
+ from typing import Literal
34
+
35
+ from flax.nnx import traversals
36
+ import numpy as np
37
+ import orbax.checkpoint as ocp
38
+ import safetensors
39
+ import torch
40
+ import tyro
41
+
42
+ import openpi.models.gemma
43
+ import openpi.models.model
44
+ import openpi.models.pi0_config
45
+ import openpi.models_pytorch.pi0_pytorch
46
+ from openpi.training import utils
47
+ import openpi.training.config as _config
48
+
49
+
50
def slice_paligemma_state_dict(state_dict, config):
    """Convert PaliGemma JAX parameters to PyTorch format.

    Pops flat JAX parameter entries (keys such as ``img/...`` and ``llm/...``)
    out of ``state_dict`` and re-inserts them under HuggingFace-style
    ``paligemma_with_expert.paligemma...`` names, transposing/reshaping each
    array into PyTorch layout.

    Args:
        state_dict: Flat mapping of JAX parameter names to arrays. Mutated in
            place (original entries are popped, converted entries inserted).
        config: Object exposing ``vision_config`` and ``text_config`` with the
            hidden sizes / layer counts / head dims used for the reshapes.

    Returns:
        Tuple ``(final_state_dict, expert_dict)``: ``final_state_dict`` maps
        PyTorch names to ``torch.Tensor``; ``expert_dict`` holds the raw
        action-expert arrays (the keys in ``expert_keys``) untouched, for later
        processing by ``slice_gemma_state_dict``.
    """
    # Some checkpoints nest every leaf under a trailing "/value" component.
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # patch embeddings
    # transpose(3, 2, 0, 1): reorder conv kernel axes into PyTorch's
    # [out_channels, in_channels, H, W] layout.
    jax_key = f"img/embedding/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)

    jax_key = f"img/embedding/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # positional embeddings
    jax_key = f"img/pos_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)

    # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
    # JAX stacks all encoder blocks into one leading "layer" axis; each array
    # below is indexed per-layer in the loop that follows.
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
    )
    encoderblock_attention_0_key_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
    )
    encoderblock_attention_0_value_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
    )
    encoderblock_attention_0_value_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
    )
    encoderblock_attention_0_query_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
    )
    encoderblock_attention_0_query_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
    )
    encoderblock_attention_0_out_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
    )
    encoderblock_attention_0_out_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
    )

    for i in range(config.vision_config.num_hidden_layers):
        # Attention kernels are flattened from per-head form to
        # [hidden, hidden] then transposed into nn.Linear's [out, in] layout.
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
        ] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
        ] = encoderblock_layernorm0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
        ] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
        ] = encoderblock_layernorm1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
        ] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
        ] = encoderblock_mlp_dense0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
        ] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
        ] = encoderblock_mlp_dense1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
        ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
        ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
        ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
        ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
        ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

    jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # multimodal projector
    jax_key = f"img/head/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/head/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # text decoder (gemma)
    jax_key = f"llm/embedder/input_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # pop the einsum attention + mlp representations
    # (stacked across all text layers; sliced per-layer below)
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")

    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    for i in range(config.text_config.num_hidden_layers):
        # q einsum weights are per-head; flatten to [heads*head_dim, hidden].
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        # kv_einsum axis 1 selects k (0) vs v (1).
        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .transpose(2, 0, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        # gating_einsum axis 1 selects gate (0) vs up (1) projections.
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
            llm_input_layernorm[i]
        )
        state_dict[
            f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
        ] = llm_post_attention_layernorm[i]

    jax_key = f"llm/final_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    expert_dict = {}
    final_state_dict = {}

    # Expert-related keys to extract (including pi05 Dense layer parameters)
    expert_keys = [
        f"llm/final_norm_1/scale{suffix}",
        f"llm/final_norm_1/Dense_0/bias{suffix}",
        f"llm/final_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
        f"llm/layers/attn/kv_einsum_1/w{suffix}",
        f"llm/layers/attn/q_einsum_1/w{suffix}",
        f"llm/layers/mlp_1/gating_einsum{suffix}",
        f"llm/layers/mlp_1/linear{suffix}",
        f"llm/layers/pre_attention_norm_1/scale{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
    ]

    # Split: everything except the action-expert arrays is converted to torch
    # tensors now; expert arrays stay raw for slice_gemma_state_dict.
    for key, value in state_dict.items():
        if key not in expert_keys:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict
269
+
270
+
271
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
    """Convert Gemma action-expert JAX parameters to PyTorch format.

    Args:
        state_dict: Flat mapping of JAX parameter names
            (``llm/..._{num_expert}/...``) to arrays — e.g. the ``expert_dict``
            returned by ``slice_paligemma_state_dict``. Mutated in place: expert
            entries are popped and PyTorch-named entries are inserted.
        config: Gemma config. Missing HF-style attributes (``vocab_size``,
            ``hidden_size``, ``num_hidden_layers``, ``num_attention_heads``) are
            filled in from the openpi names (``width``/``depth``/``num_heads``).
        num_expert: Integer suffix identifying the expert's parameters in the
            JAX key names (e.g. ``1`` -> ``llm/layers/mlp_1/...``).
        checkpoint_dir: Unused; kept for interface compatibility. The
            normalization variant used to be sniffed from this path
            (``"pi05" in checkpoint_dir``), which silently broke for pi05
            checkpoints stored under paths not containing "pi05".
        pi05: True when the checkpoint uses pi05 adaptive normalization (Dense
            layers inside the norm modules) instead of plain RMSNorm scales.

    Returns:
        Dict mapping PyTorch parameter names to ``torch.Tensor``.
    """
    del checkpoint_dir  # The explicit pi05 flag is authoritative; never sniff the path.

    # Add missing attributes to config if they don't exist
    if not hasattr(config, "vocab_size"):
        config.vocab_size = 257152  # PALIGEMMA_VOCAB_SIZE
    if not hasattr(config, "hidden_size"):
        config.hidden_size = config.width
    if not hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = config.depth
    if not hasattr(config, "num_attention_heads"):
        config.num_attention_heads = config.num_heads

    # Some checkpoints nest every leaf under a trailing "/value" component.
    suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""

    # Stacked-across-layers einsum weights; sliced per-layer below.
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")

    # pi05 (adaptive normalization) stores Dense kernels/biases inside the
    # norms; regular pi0 stores plain RMSNorm scale vectors.
    if pi05:
        # Pi05 with adaptive normalization
        llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_input_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
        llm_post_attention_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
    else:
        # Regular pi0 with standard RMSNorm
        llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
        llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")

    for i in range(config.num_hidden_layers):
        # Per-head q weights flattened to nn.Linear's [heads*head_dim, hidden].
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        # kv_einsum axis 1 selects k (0) vs v (1).
        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
            .transpose(1, 0)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        # gating_einsum axis 1 selects gate (0) vs up (1) projections.
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
            i
        ].transpose()

        if pi05:
            # Pi05 with adaptive normalization - use Dense layer parameters directly
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
                llm_input_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
                llm_post_attention_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
                llm_input_layernorm_kernel[i].transpose()
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
                llm_post_attention_layernorm_kernel[i].transpose()
            )
        else:
            # Regular pi0 with standard RMSNorm
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
                llm_input_layernorm[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
                llm_post_attention_layernorm[i]
            )

    # Handle final norm layer
    if pi05:
        # Pi05 with adaptive normalization - use Dense layer parameters directly
        final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
        final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
    else:
        # Regular pi0 with standard RMSNorm
        state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
            f"llm/final_norm_{num_expert}/scale{suffix}"
        )

    # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.

    # Convert any remaining numpy arrays to torch tensors.
    final_state_dict = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            final_state_dict[key] = torch.from_numpy(value)
        else:
            final_state_dict[key] = value

    return final_state_dict
394
+
395
+
396
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
    """Load and process params by restoring via JAX model loader first.
    This respects dtype conversions that occur during model restore.

    Args:
        checkpoint_dir: Checkpoint root; the "params/" subdirectory is read.
        restore_precision: Optional dtype name forwarded to restore_params
            (e.g. "float32"); None keeps the stored dtype.

    Returns:
        Dict with "paligemma_params" (the PaliGemma subtree flattened to
        "a/b/c"-style keys) and "projection_params" (the full restored params).
    """
    # Use repository restore utility to load a pure dict of params (value suffix removed)
    params = openpi.models.model.restore_params(
        f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
    )

    return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
406
+
407
+
408
def load_jax_model_and_print_keys(checkpoint_dir: str):
    """
    Load JAX model from checkpoint and print all parameter keys.

    Args:
        checkpoint_dir: Path to the checkpoint directory (local path or gs:// URI)
    """
    # gs:// paths are passed through untouched; local paths are made absolute.
    checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
    # Initialize checkpointer
    checkpointer = ocp.PyTreeCheckpointer()
    # Reading only the metadata avoids materializing the full parameter arrays.
    metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
    print(utils.array_tree_to_info(metadata))
420
+
421
+
422
def convert_pi0_checkpoint(
    checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
    """
    Convert PI0 JAX checkpoint to PyTorch format.

    Restores the orbax checkpoint, remaps PaliGemma / action-expert / projection
    parameters to PyTorch names, loads them into PI0Pytorch, and writes
    model.safetensors plus a config.json (and copies the assets folder) to
    output_path.

    Args:
        checkpoint_dir: Path to the JAX checkpoint
        precision: Model precision ("float32" or "bfloat16"; anything else raises ValueError)
        output_path: Path to save the converted PyTorch model
        model_config: Model config
    """
    print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
    print(f"Model config: {model_config}")

    # Break down orbax ckpts by restoring via JAX to respect dtype
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")

    # Process projection params
    # pi05 checkpoints expose a different set of projection heads:
    # no state_proj, and time_mlp_* instead of action_time_mlp_*.
    if model_config.pi05:
        keys = [
            "action_in_proj",
            "action_out_proj",
            "time_mlp_in",
            "time_mlp_out",
        ]
    else:
        keys = [
            "state_proj",
            "action_in_proj",
            "action_out_proj",
            "action_time_mlp_in",
            "action_time_mlp_out",
        ]

    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
        # Some checkpoints nest each leaf under a "value" key.
        if isinstance(kernel_params, dict):
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params

        pytorch_weight_key = f"{key}.weight"
        pytorch_bias_key = f"{key}.bias"

        # .T converts the kernel into nn.Linear's [out_features, in_features] layout.
        projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
        projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))

    # Create configs based on checkpoint path
    # All models use the same PaliGemma config structure
    class PaliGemmaConfig:
        # Minimal stand-in for the HF PaliGemma config: only the attributes read
        # by slice_paligemma_state_dict (vision_config / text_config sizes).
        def __init__(self):
            self.vision_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 1152,
                    "num_hidden_layers": 27,
                    "num_attention_heads": 16,
                    "intermediate_size": 4304,
                    "patch_size": 14,
                    "projection_dim": 2048,
                },
            )()
            self.text_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 2048,
                    "num_hidden_layers": 18,
                    "num_attention_heads": 8,
                    "head_dim": 256,
                    "intermediate_size": 16384,
                },
            )()

    paligemma_config = PaliGemmaConfig()
    action_expert_config = openpi.models.gemma.get_config("gemma_300m")

    # Process PaliGemma weights
    paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)

    # Process Gemma weights from expert_params
    gemma_params = slice_gemma_state_dict(
        expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
    )

    # Instantiate model
    pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)

    # Combine all parameters (no prefix needed for our model structure)
    all_params = {**paligemma_params, **gemma_params, **projection_params}

    # Load state dict
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # re-verify key coverage whenever the mappings above change.
    pi0_model.load_state_dict(all_params, strict=False)

    if precision == "float32":
        pi0_model = pi0_model.to(torch.float32)
    elif precision == "bfloat16":
        pi0_model = pi0_model.to(torch.bfloat16)
    else:
        raise ValueError(f"Invalid precision: {precision}")

    # Save the converted model using safetensors
    os.makedirs(output_path, exist_ok=True)

    # Save model weights as SafeTensors using save_model to handle tied weights
    safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))

    # Copy assets folder if it exists
    assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
    if assets_source.exists():
        assets_dest = pathlib.Path(output_path) / "assets"
        if assets_dest.exists():
            shutil.rmtree(assets_dest)
        shutil.copytree(assets_source, assets_dest)

    # Save config as JSON for reference
    config_dict = {
        "action_dim": model_config.action_dim,
        "action_horizon": model_config.action_horizon,
        "paligemma_variant": model_config.paligemma_variant,
        "action_expert_variant": model_config.action_expert_variant,
        "precision": precision,
    }
    with open(os.path.join(output_path, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    print("Model conversion completed successfully!")
    print(f"Model saved to {output_path}")
556
+
557
+
558
def main(
    checkpoint_dir: str,
    config_name: str,
    output_path: str | None = None,
    precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
    *,
    inspect_only: bool = False,
):
    """Load JAX model and optionally convert to PyTorch.

    Args:
        checkpoint_dir: Path to the JAX checkpoint directory
        config_name: Name of a training config; its model must be a Pi0Config
        output_path: Path to save converted PyTorch model (required for conversion)
        precision: Precision for model conversion. NOTE: "float16" is accepted by
            the CLI, but convert_pi0_checkpoint() only handles float32/bfloat16
            and raises ValueError for anything else.
        inspect_only: Only inspect parameter keys, don't convert
    """
    model_config = _config.get_config(config_name).model
    if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
        raise ValueError(f"Config {config_name} is not a Pi0Config")
    if inspect_only:
        load_jax_model_and_print_keys(checkpoint_dir)
    else:
        if not output_path:
            print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
            return
        convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
584
+
585
+
586
if __name__ == "__main__":
    # Expose main()'s signature as CLI flags (e.g. --checkpoint_dir, --inspect_only).
    tyro.cli(main)
openpi/examples/droid/README.md ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DROID Policies in openpi
2
+
3
+ We offer instructions for:
4
+ - [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
5
+ - [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies)
6
+ - [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
7
+ - [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
8
+
9
+ ## Running DROID Inference
10
+
11
+ This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
12
+
13
+
14
+ ### Step 1: Start a policy server
15
+
16
+ Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
17
+
18
+ 1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
19
+ 2. Start the OpenPI server via the following command:
20
+
21
+ ```bash
22
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
23
+ ```
24
+
25
+ You can also run the equivalent command below:
26
+
27
+ ```bash
28
+ uv run scripts/serve_policy.py --env=DROID
29
+ ```
30
+
31
+ ### Step 2: Run the DROID robot
32
+
33
+ 1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
34
+ 2. On the control laptop, activate your DROID conda environment.
35
+ 3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
36
+ 4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
37
+ 5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
38
+ 6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
39
+ 7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
40
+
41
+ ```bash
42
+ python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
43
+ ```
44
+
45
+ The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
46
+
47
+ ## Troubleshooting
48
+
49
+ | Issue | Solution |
50
+ |-------|----------|
51
+ | Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
52
+ | Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. |
53
+ | Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
54
+ | Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explorer` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
55
+
56
+
57
+ ## Running Other Policies
58
+
59
+ We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
60
+
61
+ ```
62
+ # Train from pi0-FAST, using FAST tokenizer
63
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
64
+
65
+ # Train from pi0, using flow matching
66
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
67
+
68
+ # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
69
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
70
+
71
+ # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
72
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
73
+
74
+ # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
75
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
76
+
77
+ # Trained from PaliGemma, using FSQ tokenizer.
78
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
79
+
80
+ # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
81
+ uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
82
+ ```
83
+
84
+ You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
openpi/examples/droid/README_train.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Training on DROID
2
+
3
+ Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline
4
+ (with small differences in data loading and the used action space). For a tutorial on how to fine-tune your model with a smaller, custom dataset collected on the DROID platform, see below.
5
+
6
+ In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough
7
+ for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.
8
+
9
+ ## Install
10
+
11
+ We need a few additional dependencies for RLDS data loading. Run:
12
+ ```bash
13
+ uv sync --group rlds
14
+ ```
15
+
16
+ ## Download DROID dataset
17
+
18
+ You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
19
+ ```
20
+ gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
21
+ ```
22
+
23
+ Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
24
+
25
+ You will need 1.8TB of disk storage to download the DROID RLDS dataset.
26
+
27
+ ## Run
28
+
29
+ First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
30
+
31
+ Then, compute normalization statistics (this will take ~10 minutes):
32
+ ```bash
33
+ uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
34
+ ```
35
+
36
+ Run training:
37
+ ```bash
38
+ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
39
+ ```
40
+
41
+ **Note**: The original pi0.5-DROID model was trained with joint velocity actions.
42
+ Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate).
43
+ Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.
44
+
45
+
46
+ ## Compute Requirements
47
+
48
+ Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch).
49
+ If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
50
+
51
+ We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
52
+
53
+
54
+ ## Data Filtering
55
+
56
+ Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance.
57
+
58
+ By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
59
+
60
+ **Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
61
+
62
+ ## RoboArena
63
+
64
+ Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
65
+
66
+ If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
67
+
68
+
69
+ # Fine-Tuning on Custom DROID Datasets
70
+
71
+ Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. Like for other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.
72
+
73
+ Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).
74
+
75
+
76
+ ## Step 1: Converting your custom DROID dataset to LeRobot
77
+
78
+ We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
79
+ ```
80
+ gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
81
+ ```
82
+
83
+ We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
84
+ ```
85
+ gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
86
+ ```
87
+
88
+ For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
89
+
90
+ Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):
91
+ ```
92
+ uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
93
+ ```
94
+
95
+ ## Step 2: Run fine-tuning with your custom dataset
96
+
97
+ Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
98
+ You can modify the config easily to work with other base models, or use your custom DROID dataset in `config.py` (search for `pi05_droid_finetune`).
99
+
100
+ To launch training:
101
+ ```
102
+ uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
103
+ ```
104
+
105
+ Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
106
+
openpi/examples/droid/compute_droid_nonidle_ranges.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
3
+ that should be sampled during training (all others are filtered out).
4
+
5
+ Filtering logic:
6
+ We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
7
+ (default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
8
+ this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
9
+ ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
10
+ filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).
11
+
12
+ This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
13
+ yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
14
+ """
15
+
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+
20
+ import numpy as np
21
+ import tensorflow as tf
22
+ import tensorflow_datasets as tfds
23
+ from tqdm import tqdm
24
+
25
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
26
+
27
+ builder = tfds.builder_from_directory(
28
+ # path to the `droid` directory (not its parent)
29
+ builder_dir="<path_to_droid_dataset_tfds_files>",
30
+ )
31
+ ds = builder.as_dataset(split="train", shuffle_files=False)
32
+ tf.data.experimental.ignore_errors(ds)
33
+
34
+ keep_ranges_path = "<path_to_where_to_save_the_json>"
35
+
36
+ min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
37
+ min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
38
+ filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
39
+
40
+ keep_ranges_map = {}
41
+ if Path(keep_ranges_path).exists():
42
+ with Path(keep_ranges_path).open("r") as f:
43
+ keep_ranges_map = json.load(f)
44
+ print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
45
+
46
+ for ep_idx, ep in enumerate(tqdm(ds)):
47
+ recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
48
+ file_path = ep["episode_metadata"]["file_path"].numpy().decode()
49
+
50
+ key = f"{recording_folderpath}--{file_path}"
51
+ if key in keep_ranges_map:
52
+ continue
53
+
54
+ joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
55
+ joint_velocities = np.array(joint_velocities)
56
+
57
+ is_idle_array = np.hstack(
58
+ [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
59
+ )
60
+
61
+ # Find what steps go from idle to non-idle and vice-versa
62
+ is_idle_padded = np.concatenate(
63
+ [[False], is_idle_array, [False]]
64
+ ) # Start and end with False, so idle at first step is a start of motion
65
+
66
+ is_idle_diff = np.diff(is_idle_padded.astype(int))
67
+ is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle
68
+ is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle
69
+
70
+ # Find which steps correspond to idle segments of length at least min_idle_len
71
+ true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
72
+ is_idle_true_starts = is_idle_true_starts[true_segment_masks]
73
+ is_idle_true_ends = is_idle_true_ends[true_segment_masks]
74
+
75
+ keep_mask = np.ones(len(joint_velocities), dtype=bool)
76
+ for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
77
+ keep_mask[start:end] = False
78
+
79
+ # Get all non-idle ranges of at least 16
80
+ # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
81
+ keep_padded = np.concatenate([[False], keep_mask, [False]])
82
+
83
+ keep_diff = np.diff(keep_padded.astype(int))
84
+ keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
85
+ keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
86
+
87
+ # Find which steps correspond to non-idle segments of length at least min_non_idle_len
88
+ true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
89
+ keep_true_starts = keep_true_starts[true_segment_masks]
90
+ keep_true_ends = keep_true_ends[true_segment_masks]
91
+
92
+ # Add mapping from episode unique ID key to list of non-idle ranges to keep
93
+ keep_ranges_map[key] = []
94
+ for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
95
+ keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
96
+
97
+ if ep_idx % 1000 == 0:
98
+ with Path(keep_ranges_path).open("w") as f:
99
+ json.dump(keep_ranges_map, f)
100
+
101
+ print("Done!")
102
+ with Path(keep_ranges_path).open("w") as f:
103
+ json.dump(keep_ranges_map, f)
openpi/examples/droid/convert_droid_data_to_lerobot.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
3
+
4
+ Usage:
5
+ uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
6
+
7
+ If you want to push your dataset to the Hugging Face Hub, you can use the following command:
8
+ uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
9
+
10
+ The resulting dataset will get saved to the $LEROBOT_HOME directory.
11
+ """
12
+
13
+ from collections import defaultdict
14
+ import copy
15
+ import glob
16
+ import json
17
+ from pathlib import Path
18
+ import shutil
19
+
20
+ import cv2
21
+ import h5py
22
+ from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
23
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
24
+ import numpy as np
25
+ from PIL import Image
26
+ from tqdm import tqdm
27
+ import tyro
28
+
29
+ REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
30
+
31
+
32
+ def resize_image(image, size):
33
+ image = Image.fromarray(image)
34
+ return np.array(image.resize(size, resample=Image.BICUBIC))
35
+
36
+
37
+ def main(data_dir: str, *, push_to_hub: bool = False):
38
+ # Clean up any existing dataset in the output directory
39
+ output_path = HF_LEROBOT_HOME / REPO_NAME
40
+ if output_path.exists():
41
+ shutil.rmtree(output_path)
42
+ data_dir = Path(data_dir)
43
+
44
+ # Create LeRobot dataset, define features to store
45
+ # We will follow the DROID data naming conventions here.
46
+ # LeRobot assumes that dtype of image data is `image`
47
+ dataset = LeRobotDataset.create(
48
+ repo_id=REPO_NAME,
49
+ robot_type="panda",
50
+ fps=15, # DROID data is typically recorded at 15fps
51
+ features={
52
+ # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
53
+ "exterior_image_1_left": {
54
+ "dtype": "image",
55
+ "shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
56
+ "names": ["height", "width", "channel"],
57
+ },
58
+ "exterior_image_2_left": {
59
+ "dtype": "image",
60
+ "shape": (180, 320, 3),
61
+ "names": ["height", "width", "channel"],
62
+ },
63
+ "wrist_image_left": {
64
+ "dtype": "image",
65
+ "shape": (180, 320, 3),
66
+ "names": ["height", "width", "channel"],
67
+ },
68
+ "joint_position": {
69
+ "dtype": "float32",
70
+ "shape": (7,),
71
+ "names": ["joint_position"],
72
+ },
73
+ "gripper_position": {
74
+ "dtype": "float32",
75
+ "shape": (1,),
76
+ "names": ["gripper_position"],
77
+ },
78
+ "actions": {
79
+ "dtype": "float32",
80
+ "shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
81
+ "names": ["actions"],
82
+ },
83
+ },
84
+ image_writer_threads=10,
85
+ image_writer_processes=5,
86
+ )
87
+
88
+ # Load language annotations
89
+ # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
90
+ with (data_dir / "aggregated-annotations-030724.json").open() as f:
91
+ language_annotations = json.load(f)
92
+
93
+ # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
94
+ # We assume the following directory structure:
95
+ # RAW_DROID_PATH/
96
+ # - <...>/
97
+ # - recordings/
98
+ # - MP4/
99
+ # - <camera_id>.mp4 # single-view video of left stereo pair camera
100
+ # - trajectory.hdf5
101
+ # - <...>/
102
+ episode_paths = list(data_dir.glob("**/trajectory.h5"))
103
+ print(f"Found {len(episode_paths)} episodes for conversion")
104
+
105
+ # We will loop over each dataset_name and write episodes to the LeRobot dataset
106
+ for episode_path in tqdm(episode_paths, desc="Converting episodes"):
107
+ # Load raw data
108
+ recording_folderpath = episode_path.parent / "recordings" / "MP4"
109
+ trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
110
+
111
+ # To load the language instruction, we need to parse out the episode_id from the metadata file
112
+ # Again, you can modify this step for your own data, to load your own language instructions
113
+ metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
114
+ episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
115
+ language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
116
+ "language_instruction1"
117
+ ]
118
+ print(f"Converting episode with language instruction: {language_instruction}")
119
+
120
+ # Write to LeRobot dataset
121
+ for step in trajectory:
122
+ camera_type_dict = step["observation"]["camera_type"]
123
+ wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
124
+ exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
125
+ dataset.add_frame(
126
+ {
127
+ # Note: need to flip BGR --> RGB for loaded images
128
+ "exterior_image_1_left": resize_image(
129
+ step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
130
+ ),
131
+ "exterior_image_2_left": resize_image(
132
+ step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
133
+ ),
134
+ "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
135
+ "joint_position": np.asarray(
136
+ step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
137
+ ),
138
+ "gripper_position": np.asarray(
139
+ step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
140
+ ),
141
+ # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
142
+ "actions": np.concatenate(
143
+ [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
144
+ ),
145
+ "task": language_instruction,
146
+ }
147
+ )
148
+ dataset.save_episode()
149
+
150
+ # Optionally push to the Hugging Face Hub
151
+ if push_to_hub:
152
+ dataset.push_to_hub(
153
+ tags=["libero", "panda", "rlds"],
154
+ private=False,
155
+ push_videos=True,
156
+ license="apache-2.0",
157
+ )
158
+
159
+
160
+ ##########################################################################################################
161
+ ################ The rest of this file are functions to parse the raw DROID data #########################
162
+ ################ You don't need to worry about understanding this part #########################
163
+ ################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
164
+ ##########################################################################################################
165
+
166
+
167
+ camera_type_dict = {
168
+ "hand_camera_id": 0,
169
+ "varied_camera_1_id": 1,
170
+ "varied_camera_2_id": 1,
171
+ }
172
+
173
+ camera_type_to_string_dict = {
174
+ 0: "hand_camera",
175
+ 1: "varied_camera",
176
+ 2: "fixed_camera",
177
+ }
178
+
179
+
180
+ def get_camera_type(cam_id):
181
+ if cam_id not in camera_type_dict:
182
+ return None
183
+ type_int = camera_type_dict[cam_id]
184
+ return camera_type_to_string_dict[type_int]
185
+
186
+
187
+ class MP4Reader:
188
+ def __init__(self, filepath, serial_number):
189
+ # Save Parameters #
190
+ self.serial_number = serial_number
191
+ self._index = 0
192
+
193
+ # Open Video Reader #
194
+ self._mp4_reader = cv2.VideoCapture(filepath)
195
+ if not self._mp4_reader.isOpened():
196
+ raise RuntimeError("Corrupted MP4 File")
197
+
198
+ def set_reading_parameters(
199
+ self,
200
+ image=True, # noqa: FBT002
201
+ concatenate_images=False, # noqa: FBT002
202
+ resolution=(0, 0),
203
+ resize_func=None,
204
+ ):
205
+ # Save Parameters #
206
+ self.image = image
207
+ self.concatenate_images = concatenate_images
208
+ self.resolution = resolution
209
+ self.resize_func = cv2.resize
210
+ self.skip_reading = not image
211
+ if self.skip_reading:
212
+ return
213
+
214
+ def get_frame_resolution(self):
215
+ width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
216
+ height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
217
+ return (width, height)
218
+
219
+ def get_frame_count(self):
220
+ if self.skip_reading:
221
+ return 0
222
+ return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))
223
+
224
+ def set_frame_index(self, index):
225
+ if self.skip_reading:
226
+ return
227
+
228
+ if index < self._index:
229
+ self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
230
+ self._index = index
231
+
232
+ while self._index < index:
233
+ self.read_camera(ignore_data=True)
234
+
235
+ def _process_frame(self, frame):
236
+ frame = copy.deepcopy(frame)
237
+ if self.resolution == (0, 0):
238
+ return frame
239
+ return self.resize_func(frame, self.resolution)
240
+
241
+ def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
242
+ # Skip if Read Unnecesary #
243
+ if self.skip_reading:
244
+ return {}
245
+
246
+ # Read Camera #
247
+ success, frame = self._mp4_reader.read()
248
+
249
+ self._index += 1
250
+ if not success:
251
+ return None
252
+ if ignore_data:
253
+ return None
254
+
255
+ # Return Data #
256
+ data_dict = {}
257
+
258
+ if self.concatenate_images or "stereo" not in self.serial_number:
259
+ data_dict["image"] = {self.serial_number: self._process_frame(frame)}
260
+ else:
261
+ single_width = frame.shape[1] // 2
262
+ data_dict["image"] = {
263
+ self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
264
+ self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
265
+ }
266
+
267
+ return data_dict
268
+
269
+ def disable_camera(self):
270
+ if hasattr(self, "_mp4_reader"):
271
+ self._mp4_reader.release()
272
+
273
+
274
class RecordedMultiCameraWrapper:
    """Replays a folder of per-camera MP4 recordings through the live-camera interface."""

    def __init__(self, recording_folderpath, camera_kwargs=None):
        """Open one MP4 reader per recording in `recording_folderpath`.

        camera_kwargs maps camera type -> reading-parameter kwargs applied on
        each read_cameras() call. (None default avoids the shared-mutable-default
        pitfall of the previous `camera_kwargs={}`.)
        """
        self.camera_kwargs = {} if camera_kwargs is None else camera_kwargs

        mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")

        self.camera_dict = {}
        for filepath in mp4_filepaths:
            # Filename stem (without ".mp4") is the camera serial number.
            serial_number = filepath.split("/")[-1][:-4]
            # Validates that the serial maps to a known camera type; the result
            # itself is only needed later, at read time.
            get_camera_type(serial_number)

            if not filepath.endswith(".mp4"):
                # Unreachable given the glob above; kept as a guard for future formats.
                raise ValueError(f"Unsupported recording format: {filepath}")
            self.camera_dict[serial_number] = MP4Reader(filepath, serial_number)

    def read_cameras(self, index=None, camera_type_dict=None, timestamp_dict=None):
        """Read one frame from every (non-stereo-keyed) camera.

        Returns a dict of observation dicts keyed by data type (e.g. "image"),
        or None if any camera failed to produce a frame.
        """
        camera_type_dict = camera_type_dict or {}
        timestamp_dict = timestamp_dict or {}
        full_obs_dict = defaultdict(dict)

        for cam_id, reader in self.camera_dict.items():
            # Stereo pairs are emitted by their base reader as _left/_right keys.
            if "stereo" in cam_id:
                continue
            try:
                cam_type = camera_type_dict[cam_id]
            except KeyError as e:
                print(f"{self.camera_dict} -- {camera_type_dict}")
                raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") from e
            curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
            reader.set_reading_parameters(**curr_cam_kwargs)

            timestamp = timestamp_dict.get(cam_id + "_frame_received")
            if index is not None:
                reader.set_frame_index(index)

            data_dict = reader.read_camera(correct_timestamp=timestamp)

            # Any single camera failure invalidates the whole multi-camera read.
            if data_dict is None:
                return None
            for key, value in data_dict.items():
                full_obs_dict[key].update(value)

        return full_obs_dict
327
+
328
+
329
def get_hdf5_length(hdf5_file, keys_to_ignore=[]):  # noqa: B006
    """Return the common length of every dataset under `hdf5_file` (recursing into groups).

    Asserts that all (non-ignored) datasets agree on their length; returns None
    for an empty group.
    """
    length = None
    for key, node in hdf5_file.items():
        if key in keys_to_ignore:
            continue

        if isinstance(node, h5py.Group):
            node_length = get_hdf5_length(node, keys_to_ignore=keys_to_ignore)
        elif isinstance(node, h5py.Dataset):
            node_length = len(node)
        else:
            raise ValueError

        if length is None:
            length = node_length
        # Every dataset in the trajectory must have the same number of steps.
        assert node_length == length

    return length
349
+
350
+
351
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]):  # noqa: B006
    """Recursively extract element `index` of every dataset under `hdf5_file` into a nested dict."""
    out = {}
    for key, node in hdf5_file.items():
        if key in keys_to_ignore:
            continue

        if isinstance(node, h5py.Group):
            # Groups become nested dicts with the same ignore list applied.
            out[key] = load_hdf5_to_dict(node, index, keys_to_ignore=keys_to_ignore)
        elif isinstance(node, h5py.Dataset):
            out[key] = node[index]
        else:
            raise ValueError

    return out
367
+
368
+
369
class TrajectoryReader:
    """Sequential reader for a single robot trajectory stored as an HDF5 file."""

    def __init__(self, filepath, read_images=True):  # noqa: FBT002
        self._hdf5_file = h5py.File(filepath, "r")
        # Images are only readable when the file actually carries a videos group.
        has_videos = "observations/videos" in self._hdf5_file
        self._read_images = read_images and has_videos
        self._length = get_hdf5_length(self._hdf5_file)
        self._video_readers = {}
        self._index = 0

    def length(self):
        """Number of timesteps in the trajectory."""
        return self._length

    def read_timestep(self, index=None, keys_to_ignore=[]):  # noqa: B006
        """Read one timestep (videos excluded) and advance the read cursor.

        When `index` is given, jumps to it first — only allowed when images are
        not being streamed.
        """
        if index is None:
            index = self._index
        else:
            # Random access would desynchronize the video readers.
            assert not self._read_images
            self._index = index
        assert index < self._length

        # Videos are handled by the camera readers, never loaded here.
        skip_keys = [*keys_to_ignore, "videos"]
        timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=skip_keys)

        self._index += 1
        return timestep

    def close(self):
        """Close the underlying HDF5 file handle."""
        self._hdf5_file.close()
402
+
403
+
404
def load_trajectory(
    filepath=None,
    read_cameras=True,  # noqa: FBT002
    recording_folderpath=None,
    camera_kwargs={},  # noqa: B006
    remove_skipped_steps=False,  # noqa: FBT002
    num_samples_per_traj=None,
    num_samples_per_traj_coeff=1.5,
):
    """Load a trajectory from HDF5 (plus optional MP4 camera recordings).

    Returns a numpy object array of timestep dicts. When num_samples_per_traj
    is set, a random subset of timesteps is loaded (oversampled by
    num_samples_per_traj_coeff when skipped steps may be dropped) and then
    subsampled back down to num_samples_per_traj.
    """
    use_recordings = read_cameras and (recording_folderpath is not None)

    traj_reader = TrajectoryReader(filepath)
    if use_recordings:
        camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)

    horizon = traj_reader.length()

    # Decide which timestep indices to load.
    if num_samples_per_traj:
        num_to_save = num_samples_per_traj
        if remove_skipped_steps:
            # Oversample so that enough steps survive the skip filter below.
            num_to_save = int(num_to_save * num_samples_per_traj_coeff)
        sample_count = min(num_to_save, horizon)
        indices_to_save = np.sort(np.random.choice(horizon, size=sample_count, replace=False))
    else:
        indices_to_save = np.arange(horizon)

    timestep_list = []
    for i in indices_to_save:
        timestep = traj_reader.read_timestep(index=i)

        if use_recordings:
            timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
            camera_type_dict = {
                k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
            }
            camera_obs = camera_reader.read_cameras(
                index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
            )
            # A failed camera read truncates the trajectory at this point.
            if camera_obs is None:
                break
            timestep["observation"].update(camera_obs)

        # Optionally drop steps recorded while the controller was disengaged.
        skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
        if not (skipped and remove_skipped_steps):
            timestep_list.append(timestep)

    # Subsample down to the requested number of transitions.
    timestep_list = np.array(timestep_list)
    if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
        keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
        timestep_list = timestep_list[keep]

    traj_reader.close()

    return timestep_list
474
+
475
+
476
if __name__ == "__main__":
    # CLI entry point: tyro builds an argument parser from main()'s signature.
    tyro.cli(main)
openpi/examples/droid/main.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa
2
+
3
+ import contextlib
4
+ import dataclasses
5
+ import datetime
6
+ import faulthandler
7
+ import os
8
+ import signal
9
+ import time
10
+ from moviepy.editor import ImageSequenceClip
11
+ import numpy as np
12
+ from openpi_client import image_tools
13
+ from openpi_client import websocket_client_policy
14
+ import pandas as pd
15
+ from PIL import Image
16
+ from droid.robot_env import RobotEnv
17
+ import tqdm
18
+ import tyro
19
+
20
+ faulthandler.enable()
21
+
22
+ # DROID data collection frequency -- we slow down execution to match this frequency
23
+ DROID_CONTROL_FREQUENCY = 15
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class Args:
28
+ # Hardware parameters
29
+ left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
30
+ right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
31
+ wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
32
+
33
+ # Policy parameters
34
+ external_camera: str | None = (
35
+ None # which external camera should be fed to the policy, choose from ["left", "right"]
36
+ )
37
+
38
+ # Rollout parameters
39
+ max_timesteps: int = 600
40
+ # How many actions to execute from a predicted action chunk before querying policy server again
41
+ # 8 is usually a good default (equals 0.5 seconds of action execution).
42
+ open_loop_horizon: int = 8
43
+
44
+ # Remote server parameters
45
+ remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
46
+ remote_port: int = (
47
+ 8000 # point this to the port of the policy server, default server port for openpi servers is 8000
48
+ )
49
+
50
+
51
+ # We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
52
+ # waiting for a new action chunk, it will raise an exception and the server connection dies.
53
+ # This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
54
@contextlib.contextmanager
def prevent_keyboard_interrupt():
    """Defer Ctrl+C: record SIGINT during the block and re-raise KeyboardInterrupt afterwards.

    This protects the in-flight policy-server call — interrupting it mid-request
    would kill the server connection — while still honoring the user's Ctrl+C
    once the call has completed.
    """
    pending = False
    previous_handler = signal.getsignal(signal.SIGINT)

    def _defer(signum, frame):
        nonlocal pending
        pending = True

    signal.signal(signal.SIGINT, _defer)
    try:
        yield
    finally:
        # Always restore the original handler, then deliver the deferred interrupt.
        signal.signal(signal.SIGINT, previous_handler)
        if pending:
            raise KeyboardInterrupt
71
+
72
+
73
def main(args: Args):
    """Run interactive DROID evaluation rollouts against a remote openpi policy server.

    For each rollout: stream observations, query the policy for action chunks,
    execute them at the DROID control frequency, save a video, prompt the
    operator for a success score, and finally write all results to a CSV.
    """
    # We only feed one external camera to the policy; the user must pick which.
    assert (
        args.external_camera is not None and args.external_camera in ["left", "right"]
    ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"

    # Initialize the Panda environment. Using joint velocity action space and
    # gripper position action space is very important (matches training).
    env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
    print("Created the droid env!")

    # Connect to the policy server.
    policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)

    # One dict per rollout; turned into a DataFrame at the end.
    # (DataFrame.append was removed in pandas 2.0, so we no longer grow a frame row by row.)
    results = []

    while True:
        instruction = input("Enter instruction: ")

        # Rollout state.
        actions_from_chunk_completed = 0
        pred_action_chunk = None

        # Prepare to save a video of the rollout.
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
        video = []
        bar = tqdm.tqdm(range(args.max_timesteps))
        print("Running rollout... press Ctrl+C to stop early.")
        for t_step in bar:
            start_time = time.time()
            try:
                # Get the current observation; save the first one to disk for live viewing.
                curr_obs = _extract_observation(
                    args,
                    env.get_observation(),
                    save_to_disk=t_step == 0,
                )

                video.append(curr_obs[f"{args.external_camera}_image"])

                # Query the policy server when the previous chunk is exhausted.
                if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
                    actions_from_chunk_completed = 0

                    # Resize images on the robot laptop to minimize the amount of
                    # data sent to the policy server and improve latency.
                    request_data = {
                        "observation/exterior_image_1_left": image_tools.resize_with_pad(
                            curr_obs[f"{args.external_camera}_image"], 224, 224
                        ),
                        "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
                        "observation/joint_position": curr_obs["joint_position"],
                        "observation/gripper_position": curr_obs["gripper_position"],
                        "prompt": instruction,
                    }

                    # Shield the server call from Ctrl+C; the interrupt is re-raised afterwards.
                    with prevent_keyboard_interrupt():
                        # Returns an action chunk [10, 8]: 10 x (7 joint velocities + 1 gripper position).
                        pred_action_chunk = policy_client.infer(request_data)["actions"]
                    assert pred_action_chunk.shape == (10, 8)

                # Execute the next action from the current chunk.
                action = pred_action_chunk[actions_from_chunk_completed]
                actions_from_chunk_completed += 1

                # Binarize the gripper command (fully open or fully closed only).
                gripper = np.ones((1,)) if action[-1].item() > 0.5 else np.zeros((1,))
                action = np.concatenate([action[:-1], gripper])

                # Clip all dimensions of the action to [-1, 1].
                action = np.clip(action, -1, 1)

                env.step(action)

                # Sleep to match the DROID data-collection frequency.
                elapsed_time = time.time() - start_time
                if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
                    time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
            except KeyboardInterrupt:
                break

        video = np.stack(video)
        save_filename = "video_" + timestamp
        ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")

        success = _prompt_for_success()

        results.append(
            {
                "success": success,
                "duration": t_step,
                "video_filename": save_filename,
            }
        )

        if input("Do one more eval? (enter y or n) ").lower() != "y":
            break
        env.reset()

    os.makedirs("results", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
    csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
    pd.DataFrame(results, columns=["success", "duration", "video_filename"]).to_csv(csv_filename)
    print(f"Results saved to {csv_filename}")


def _prompt_for_success() -> float:
    """Prompt until the operator enters y, n, or a number in [0, 100]; return a score in [0, 1].

    Fixes of the previous inline logic: "y" returned 1.0 but was then divided by
    100 (recording 0.01 instead of 1.0); non-numeric input crashed with
    ValueError; out-of-range numbers were warned about but recorded anyway.
    """
    while True:
        raw = input(
            "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
        )
        if raw == "y":
            return 1.0
        if raw == "n":
            return 0.0
        try:
            score = float(raw) / 100
        except ValueError:
            print(f"Invalid input: {raw!r}. Enter y, n, or a number in [0, 100].")
            continue
        if 0 <= score <= 1:
            return score
        print(f"Success must be a number in [0, 100] but got: {score * 100}")
+ print(f"Results saved to {csv_filename}")
196
+
197
+
198
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
    """Convert a raw DROID observation dict into the flat dict used by the eval loop.

    Picks the left-lens image of each configured stereo camera (the model is
    trained on left lenses only), drops alpha and reverses channel order
    (BGR -> RGB), and extracts the proprioceptive state. Optionally saves a
    combined camera view to disk for live monitoring.
    """
    image_observations = obs_dict["image"]
    left_image, right_image, wrist_image = None, None, None
    for key in image_observations:
        # Note the "left" below refers to the left camera in the stereo pair.
        # The model is only trained on left stereo cams, so we only feed those.
        if args.left_camera_id in key and "left" in key:
            left_image = image_observations[key]
        elif args.right_camera_id in key and "left" in key:
            right_image = image_observations[key]
        elif args.wrist_camera_id in key and "left" in key:
            wrist_image = image_observations[key]

    # Fail loudly if any stream is missing; otherwise the slicing below would
    # raise an opaque "'NoneType' object is not subscriptable".
    if left_image is None or right_image is None or wrist_image is None:
        raise ValueError(
            f"Missing camera image(s): left={left_image is not None}, "
            f"right={right_image is not None}, wrist={wrist_image is not None}; "
            f"check the configured camera IDs against keys {list(image_observations)}"
        )

    # Drop the alpha channel, then reverse channel order (BGR -> RGB).
    left_image = left_image[..., :3][..., ::-1]
    right_image = right_image[..., :3][..., ::-1]
    wrist_image = wrist_image[..., :3][..., ::-1]

    # In addition to image observations, also capture the proprioceptive state.
    robot_state = obs_dict["robot_state"]
    cartesian_position = np.array(robot_state["cartesian_position"])
    joint_position = np.array(robot_state["joint_positions"])
    gripper_position = np.array([robot_state["gripper_position"]])

    # Save one combined image to disk so the cameras can be monitored live
    # while the robot is running.
    if save_to_disk:
        combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
        Image.fromarray(combined_image).save("robot_camera_views.png")

    return {
        "left_image": left_image,
        "right_image": right_image,
        "wrist_image": wrist_image,
        "cartesian_position": cartesian_position,
        "joint_position": joint_position,
        "gripper_position": gripper_position,
    }
+ }
242
+
243
+
244
if __name__ == "__main__":
    # Parse command-line flags into an Args instance, then run the eval loop.
    args: Args = tyro.cli(Args)
    main(args)
openpi/examples/inference.ipynb ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import dataclasses\n",
10
+ "\n",
11
+ "import jax\n",
12
+ "\n",
13
+ "from openpi.models import model as _model\n",
14
+ "from openpi.policies import droid_policy\n",
15
+ "from openpi.policies import policy_config as _policy_config\n",
16
+ "from openpi.shared import download\n",
17
+ "from openpi.training import config as _config\n",
18
+ "from openpi.training import data_loader as _data_loader"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "# Policy inference\n",
26
+ "\n",
27
+ "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "config = _config.get_config(\"pi0_fast_droid\")\n",
37
+ "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
38
+ "\n",
39
+ "# Create a trained policy.\n",
40
+ "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
41
+ "\n",
42
+ "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
43
+ "example = droid_policy.make_droid_example()\n",
44
+ "result = policy.infer(example)\n",
45
+ "\n",
46
+ "# Delete the policy to free up memory.\n",
47
+ "del policy\n",
48
+ "\n",
49
+ "print(\"Actions shape:\", result[\"actions\"].shape)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {},
55
+ "source": [
56
+ "# Working with a live model\n",
57
+ "\n",
58
+ "\n",
59
+ "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "config = _config.get_config(\"pi0_aloha_sim\")\n",
69
+ "\n",
70
+ "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
71
+ "key = jax.random.key(0)\n",
72
+ "\n",
73
+ "# Create a model from the checkpoint.\n",
74
+ "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
75
+ "\n",
76
+ "# We can create fake observations and actions to test the model.\n",
77
+ "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
78
+ "\n",
79
+ "# Sample actions from the model.\n",
80
+ "loss = model.compute_loss(key, obs, act)\n",
81
+ "print(\"Loss shape:\", loss.shape)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {},
87
+ "source": [
88
+ "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# Reduce the batch size to reduce memory usage.\n",
98
+ "config = dataclasses.replace(config, batch_size=2)\n",
99
+ "\n",
100
+ "# Load a single batch of data. This is the same data that will be used during training.\n",
101
+ "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
102
+ "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
103
+ "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
104
+ "obs, act = next(iter(loader))\n",
105
+ "\n",
106
+ "# Sample actions from the model.\n",
107
+ "loss = model.compute_loss(key, obs, act)\n",
108
+ "\n",
109
+ "# Delete the model to free up memory.\n",
110
+ "del model\n",
111
+ "\n",
112
+ "print(\"Loss shape:\", loss.shape)"
113
+ ]
114
+ }
115
+ ],
116
+ "metadata": {
117
+ "kernelspec": {
118
+ "display_name": ".venv",
119
+ "language": "python",
120
+ "name": "python3"
121
+ },
122
+ "language_info": {
123
+ "codemirror_mode": {
124
+ "name": "ipython",
125
+ "version": 3
126
+ },
127
+ "file_extension": ".py",
128
+ "mimetype": "text/x-python",
129
+ "name": "python",
130
+ "nbconvert_exporter": "python",
131
+ "pygments_lexer": "ipython3",
132
+ "version": "3.11.9"
133
+ }
134
+ },
135
+ "nbformat": 4,
136
+ "nbformat_minor": 2
137
+ }
openpi/examples/libero/Dockerfile ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for the LIBERO benchmark.
2
+
3
+ # Build the container:
4
+ # docker build . -t libero -f examples/libero/Dockerfile
5
+
6
+ # Run the container:
7
+ # docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
8
+
9
+ FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
10
+ COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
11
+
12
+ RUN apt-get update && \
13
+ apt-get install -y \
14
+ make \
15
+ g++ \
16
+ clang \
17
+ libosmesa6-dev \
18
+ libgl1-mesa-glx \
19
+ libglew-dev \
20
+ libglfw3-dev \
21
+ libgles2-mesa-dev \
22
+ libglib2.0-0 \
23
+ libsm6 \
24
+ libxrender1 \
25
+ libxext6
26
+
27
+ WORKDIR /app
28
+
29
+ # Copy from the cache instead of linking since it's a mounted volume
30
+ ENV UV_LINK_MODE=copy
31
+
32
+ # Write the virtual environment outside of the project directory so it doesn't
33
+ # leak out of the container when we mount the application code.
34
+ ENV UV_PROJECT_ENVIRONMENT=/.venv
35
+
36
+ # Copy the requirements files so we can install dependencies.
37
+ # The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
38
+ # This strategy is best for development-style usage.
39
+ COPY ./examples/libero/requirements.txt /tmp/requirements.txt
40
+ COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
41
+ COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
42
+
43
+ # Install python dependencies.
44
+ RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
45
+ RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
46
+ ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
47
+
48
+ # Create a default config file to avoid an input prompt from LIBERO's init script.
49
+ # https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
50
+ ENV LIBERO_CONFIG_PATH=/tmp/libero
51
+ RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
52
+ benchmark_root: /app/third_party/libero/libero/libero
53
+ bddl_files: /app/third_party/libero/libero/libero/bddl_files
54
+ init_states: /app/third_party/libero/libero/libero/init_files
55
+ datasets: /app/third_party/libero/libero/datasets
56
+ assets: /app/third_party/libero/libero/libero/assets
57
+ EOF
58
+
59
+ CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
openpi/examples/libero/README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LIBERO Benchmark
2
+
3
+ This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
4
+
5
+ Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
6
+
7
+ This example requires git submodules to be initialized. Don't forget to run:
8
+
9
+ ```bash
10
+ git submodule update --init --recursive
11
+ ```
12
+
13
+ ## With Docker (recommended)
14
+
15
+ ```bash
16
+ # Grant access to the X11 server:
17
+ sudo xhost +local:docker
18
+
19
+ # To run with the default checkpoint and task suite:
20
+ SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
21
+
22
+ # To run with glx for Mujoco instead (use this if you have egl errors):
23
+ MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
24
+ ```
25
+
26
+ You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
27
+ For example:
28
+
29
+ ```bash
30
+ # To load a custom checkpoint (located in the top-level openpi/ directory):
31
+ export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
32
+
33
+ # To run the libero_10 task suite:
34
+ export CLIENT_ARGS="--args.task-suite-name libero_10"
35
+ ```
36
+
37
+ ## Without Docker (not recommended)
38
+
39
+ Terminal window 1:
40
+
41
+ ```bash
42
+ # Create virtual environment
43
+ uv venv --python 3.8 examples/libero/.venv
44
+ source examples/libero/.venv/bin/activate
45
+ uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
46
+ uv pip install -e packages/openpi-client
47
+ uv pip install -e third_party/libero
48
+ export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
49
+
50
+ # Run the simulation
51
+ python examples/libero/main.py
52
+
53
+ # To run with glx for Mujoco instead (use this if you have egl errors):
54
+ MUJOCO_GL=glx python examples/libero/main.py
55
+ ```
56
+
57
+ Terminal window 2:
58
+
59
+ ```bash
60
+ # Run the server
61
+ uv run scripts/serve_policy.py --env LIBERO
62
+ ```
63
+
64
+ ## Results
65
+
66
+ If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
67
+ checkpoint was trained in openpi with the `pi05_libero` config.
68
+
69
+ | Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
70
+ |-------|---------------|---------------|-------------|-----------|---------|
71
+ | π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85
openpi/examples/libero/compose.yml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run with:
2
+ # docker compose -f examples/libero/compose.yml up --build
3
+ services:
4
+ runtime:
5
+ image: libero
6
+ depends_on:
7
+ - openpi_server
8
+ build:
9
+ context: ../..
10
+ dockerfile: examples/libero/Dockerfile
11
+ init: true
12
+ tty: true
13
+ network_mode: host
14
+ privileged: true
15
+ volumes:
16
+ - $PWD:/app
17
+ - ../../data:/data
18
+ - /tmp/.X11-unix:/tmp/.X11-unix:ro
19
+ environment:
20
+ - CLIENT_ARGS
21
+ - DISPLAY=$DISPLAY
22
+ - MUJOCO_GL=${MUJOCO_GL:-egl}
23
+ deploy:
24
+ resources:
25
+ reservations:
26
+ devices:
27
+ - driver: nvidia
28
+ count: 1
29
+ capabilities: [gpu]
30
+
31
+ openpi_server:
32
+ image: openpi_server
33
+ build:
34
+ context: ../..
35
+ dockerfile: scripts/docker/serve_policy.Dockerfile
36
+ init: true
37
+ tty: true
38
+ network_mode: host
39
+ volumes:
40
+ - $PWD:/app
41
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
42
+ environment:
43
+ - SERVER_ARGS
44
+ - OPENPI_DATA_HOME=/openpi_assets
45
+ - IS_DOCKER=true
46
+
47
+ # Comment out this block if not running on a machine with GPUs.
48
+ deploy:
49
+ resources:
50
+ reservations:
51
+ devices:
52
+ - driver: nvidia
53
+ count: 1
54
+ capabilities: [gpu]
openpi/examples/libero/convert_libero_data_to_lerobot.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal example script for converting a dataset to LeRobot format.
3
+
4
+ We use the Libero dataset (stored in RLDS) for this example, but it can be easily
5
+ modified for any other data you have saved in a custom format.
6
+
7
+ Usage:
8
+ uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
9
+
10
+ If you want to push your dataset to the Hugging Face Hub, you can use the following command:
11
+ uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
12
+
13
+ Note: to run the script, you need to install tensorflow_datasets:
14
+ `uv pip install tensorflow tensorflow_datasets`
15
+
16
+ You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
17
+ The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
18
+ Running this conversion script will take approximately 30 minutes.
19
+ """
20
+
21
+ import shutil
22
+
23
+ from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
24
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
25
+ import tensorflow_datasets as tfds
26
+ import tyro
27
+
28
# Name of the output dataset; this repo id is also used when pushing to the Hugging Face Hub.
REPO_NAME = "your_hf_username/libero"
# Raw RLDS dataset names to convert. For simplicity we combine multiple Libero
# suites into a single training dataset.
RAW_DATASET_NAMES = [
    "libero_10_no_noops",
    "libero_goal_no_noops",
    "libero_object_no_noops",
    "libero_spatial_no_noops",
]
35
+
36
+
37
def main(data_dir: str, *, push_to_hub: bool = False):
    """Convert the raw Libero RLDS datasets into a single LeRobot dataset.

    Args:
        data_dir: Directory containing the raw RLDS Libero datasets.
        push_to_hub: If True, also upload the converted dataset to the Hugging Face Hub.
    """
    # Start from a clean slate: remove any previous conversion output.
    target_dir = HF_LEROBOT_HOME / REPO_NAME
    if target_dir.exists():
        shutil.rmtree(target_dir)

    # Both cameras produce 256x256 RGB frames.
    image_spec = {
        "dtype": "image",
        "shape": (256, 256, 3),
        "names": ["height", "width", "channel"],
    }
    # OpenPi expects proprioception under `state` and actions under `actions`;
    # LeRobot expects image features to use the `image` dtype.
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=10,
        features={
            "image": dict(image_spec),
            "wrist_image": dict(image_spec),
            "state": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["state"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Stream every episode of every raw Libero dataset into the LeRobot dataset.
    # Adapt this loop if your data lives in a different format.
    for source_name in RAW_DATASET_NAMES:
        rlds_dataset = tfds.load(source_name, data_dir=data_dir, split="train")
        for episode in rlds_dataset:
            for step in episode["steps"].as_numpy_iterator():
                observation = step["observation"]
                dataset.add_frame(
                    {
                        "image": observation["image"],
                        "wrist_image": observation["wrist_image"],
                        "state": observation["state"],
                        "actions": step["action"],
                        "task": step["language_instruction"].decode(),
                    }
                )
            dataset.save_episode()

    # Optionally publish the converted dataset.
    if push_to_hub:
        dataset.push_to_hub(
            tags=["libero", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )
101
+
102
+
103
if __name__ == "__main__":
    # tyro turns `main`'s signature into a CLI (--data_dir, --push_to_hub).
    tyro.cli(main)