diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..3aa38fcf3dc958f64b12b5e599a111a13324c372
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b694934fbf9b49ee808b6dfc7292c28e2c46a97e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.venv
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..13566b81b018ad684f3a35fee301741b2734c8f4
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000000000000000000000000000000000..130bb778733e7cea88ccf9fb608fb85cf7071220
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,22 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..43a4cf52a2cb1f8e4df8fbda947e0dcc943c572c
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/openpi.iml b/.idea/openpi.iml
new file mode 100644
index 0000000000000000000000000000000000000000..ec63674cd7f4d511fb06cd63eaeba166d6bc0dd8
--- /dev/null
+++ b/.idea/openpi.iml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..812df4003c1bce12fbc63118292196d3245933c3
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
new file mode 100644
index 0000000000000000000000000000000000000000..dc83168c7eefe295339a8da8c79b200849ab54a9
--- /dev/null
+++ b/.idea/workspace.xml
@@ -0,0 +1,66 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {
+ "customColor": "",
+ "associatedIndex": 0
+}
+
+
+
+
+
+ {
+ "keyToString": {
+ "ModuleVcsDetector.initialDetectionPerformed": "true",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true",
+ "RunOnceActivity.git.unshallow": "true",
+ "git-widget-placeholder": "master",
+ "node.js.detected.package.eslint": "true",
+ "node.js.detected.package.tslint": "true",
+ "node.js.selected.package.eslint": "(autodetect)",
+ "node.js.selected.package.tslint": "(autodetect)",
+ "nodejs_package_manager_path": "npm",
+ "settings.editor.selected.configurable": "dev.sweep.assistant.settings.SweepSettingsConfigurable",
+ "vue.rearranger.settings.migration": "true"
+ }
+}
+
+
+
+
+
+
+
+
+
+
+
+ 1758629365517
+
+
+ 1758629365517
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/handler.py b/handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed59f8e789d12f64e2cda4b77932cdfcb3f6d0d6
--- /dev/null
+++ b/handler.py
@@ -0,0 +1,218 @@
+import base64
+import json
+import os
+import sys
+from io import BytesIO
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from PIL import Image
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "openpi", "src"))
+
+from openpi.policies import policy_config
+from openpi.training import config as train_config
+
+
+class EndpointHandler:
+ def __init__(self, path: str = ""):
+ """
+ Initialize the handler for pi0 model inference using openpi infrastructure.
+
+ Args:
+ path: Path to the model weights directory
+ """
+ # Set model path from environment variable or use provided path
+ model_path = os.environ.get("MODEL_PATH", path)
+ if not model_path:
+ model_path = "weights/pi0"
+
+ # Load the config.json to determine model type
+ config_path = os.path.join(model_path, "config.json")
+ with open(config_path, "r") as f:
+ model_config = json.load(f)
+
+ model_type = model_config.get("type", "pi0")
+
+ # Create training config based on model type
+ # This uses the openpi config system
+ if model_type == "pi0":
+ self.train_config = train_config.get_config("pi0")
+ else:
+ # Default to pi0 if type not recognized
+ self.train_config = train_config.get_config("pi0")
+
+ # Create trained policy using openpi infrastructure
+ # This handles all the model loading, preprocessing, etc.
+ self.policy = policy_config.create_trained_policy(
+ self.train_config,
+ model_path,
+ pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+ )
+
+ # Default number of inference steps
+ self.default_num_steps = 50
+
+ def _decode_base64_image(self, base64_str: str) -> np.ndarray:
+ """
+ Decode base64 image string to numpy array.
+
+ Args:
+ base64_str: Base64 encoded image string
+
+ Returns:
+ numpy array of shape (H, W, 3) with values in [0, 255]
+ """
+ # Remove data URL prefix if present
+ if base64_str.startswith("data:image"):
+ base64_str = base64_str.split(",", 1)[1]
+
+ # Decode base64
+ image_bytes = base64.b64decode(base64_str)
+
+ # Convert to PIL Image and then to numpy array
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
+ image_array = np.array(image)
+
+ return image_array
+
+    def _prepare_observation(self, images: Dict[str, str], state: List[float], prompt: Optional[str] = None) -> Dict[str, Any]:
+ """
+ Prepare observation dictionary in the format expected by openpi.
+
+ Args:
+ images: Dictionary mapping camera names to base64 encoded images
+ state: List of robot state values
+ prompt: Optional text prompt
+
+ Returns:
+ Observation dictionary in openpi format
+ """
+ # Decode and process images
+ processed_images = {}
+
+ # Map input camera names to expected openpi format
+ # Based on the config, pi0 expects specific camera names
+ camera_mapping = {
+ "camera0": "cam_high", # base camera
+ "camera1": "cam_left_wrist", # left wrist camera
+ "camera2": "cam_right_wrist", # right wrist camera
+ # Alternative mappings
+ "base_camera": "cam_high",
+ "left_wrist": "cam_left_wrist",
+ "right_wrist": "cam_right_wrist",
+ # Direct mappings
+ "cam_high": "cam_high",
+ "cam_left_wrist": "cam_left_wrist",
+ "cam_right_wrist": "cam_right_wrist"
+ }
+
+ for input_name, image_b64 in images.items():
+ # Map to openpi expected name
+ openpi_name = camera_mapping.get(input_name, input_name)
+
+ # Decode image
+ image_array = self._decode_base64_image(image_b64)
+
+ # Resize to expected resolution if needed
+ if image_array.shape[:2] != (224, 224):
+ image_pil = Image.fromarray(image_array)
+ image_resized = image_pil.resize((224, 224))
+ image_array = np.array(image_resized)
+
+ # Convert to format expected by openpi (H, W, C) with uint8
+ processed_images[openpi_name] = image_array.astype(np.uint8)
+
+ # Ensure we have the required cameras, create dummy ones if missing
+ required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
+ for cam_name in required_cameras:
+ if cam_name not in processed_images:
+ # Create a black dummy image
+ processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)
+
+ # Prepare state
+ state_array = np.array(state, dtype=np.float32)
+
+ # Create observation dict in openpi format
+ observation = {
+ "state": state_array,
+ "images": processed_images,
+ }
+
+ # Add prompt if provided
+ if prompt:
+ observation["prompt"] = prompt
+
+ return observation
+
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+ """
+ Main inference function called by HuggingFace endpoint.
+
+ Args:
+ data: Input data dictionary containing:
+ - inputs: Dictionary with:
+ - images: Dict mapping camera names to base64 encoded images
+ - state: List of robot state values
+ - prompt: Optional text prompt
+ - num_actions: Optional, number of actions to predict (default: 50)
+ - noise: Optional, noise array for sampling
+
+ Returns:
+ List containing prediction results
+ """
+ try:
+ inputs = data.get("inputs", {})
+
+ # Extract inputs
+ images = inputs.get("images", {})
+ state = inputs.get("state", [])
+ prompt = inputs.get("prompt", "")
+ num_actions = inputs.get("num_actions", self.default_num_steps)
+ noise_input = inputs.get("noise", None)
+
+ # Validate inputs
+ if not images:
+ raise ValueError("No images provided")
+ if not state:
+ raise ValueError("No state provided")
+
+ # Prepare observation using openpi format
+ observation = self._prepare_observation(images, state, prompt)
+
+ # Prepare noise if provided
+ noise = None
+ if noise_input is not None:
+ noise = np.array(noise_input, dtype=np.float32)
+
+ # Run inference using openpi policy
+ # This handles all the preprocessing, model inference, and postprocessing
+ result = self.policy.infer(observation, noise=noise)
+
+ # Extract actions from result
+ actions = result["actions"]
+
+ # Convert to list format for JSON serialization
+ if isinstance(actions, np.ndarray):
+ actions_list = actions.tolist()
+ else:
+ actions_list = actions
+
+ # Return in expected format
+ return [{
+ "actions": actions_list,
+ "num_actions": len(actions_list),
+ "action_horizon": len(actions_list),
+ "action_dim": len(actions_list[0]) if actions_list else 0,
+ "success": True,
+ "metadata": {
+ "model_type": self.train_config.model.model_type.value,
+ "policy_metadata": getattr(self.policy, '_metadata', {})
+ }
+ }]
+
+ except Exception as e:
+ return [{
+ "error": str(e),
+ "success": False
+ }]
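+
+
+# A minimal local smoke test for this handler (a sketch, not part of the endpoint
+# contract): the "weights/pi0" path, "sample.jpg" file, camera name, and 6-dim
+# state below are illustrative assumptions -- adjust them to your deployment.
+if __name__ == "__main__":
+    with open("sample.jpg", "rb") as f:
+        image_b64 = base64.b64encode(f.read()).decode("utf-8")
+
+    handler = EndpointHandler(path="weights/pi0")
+    payload = {
+        "inputs": {
+            "images": {"cam_high": image_b64},
+            "state": [0.0] * 6,
+            "prompt": "pick up the fork",
+            "num_actions": 50,
+        }
+    }
+    result = handler(payload)[0]
+    print("success:", result.get("success"), "action_dim:", result.get("action_dim"))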
diff --git a/openpi/CONTRIBUTING.md b/openpi/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..a8f3e80bab3c6a5fc45ed251173f3eebd829a959
--- /dev/null
+++ b/openpi/CONTRIBUTING.md
@@ -0,0 +1,33 @@
+# Contributing to openpi
+
+We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.
+
+## Issues and feature requests
+
+You are welcome to use the Github [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.
+
+If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on Github under Issues). If the issue has not yet been reported, please include this information when filing a Github issue:
+
+- Your OS type and version and the version of Python you are using
+- Code that allows us to reproduce your bug, including all dependencies
+- Traceback of any exception
+- Any other information that would help us, such as a screenshot
+
+In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.
+
+If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:
+
+- The motivation for the feature
+- A description of the problem you are trying to solve or your use case
+- Enough information for us to understand the nature of the request
+- Some information for how you intend to use it (this might help us in understanding the motivation!)
+
+We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!
+
+## Submitting a pull request
+
+If you have implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating a PR, we recommend considering the following:
+
+- Make sure that your PR has a clear title and description
+- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
+- Make sure your PR passes all tests
diff --git a/openpi/LICENSE b/openpi/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/openpi/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/openpi/README.md b/openpi/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e4f8ee24a2a6ca3488ce77273354ed1eb0c97aa6
--- /dev/null
+++ b/openpi/README.md
@@ -0,0 +1,323 @@
+# openpi
+
+openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).
+
+Currently, this repo contains three types of models:
+- the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based vision-language-action model (VLA).
+- the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer.
+- the [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05), an upgraded version of π₀ with better open-world generalization trained with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation). Note that, in this repository, we currently only support the flow matching head for both $\pi_{0.5}$ training and inference.
+
+For all models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets.
+
+This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see!
+
+## Updates
+
+- [Sept 2025] We released PyTorch support in openpi.
+- [Sept 2025] We released pi05, an upgraded version of pi0 with better open-world generalization.
+- [Sept 2025]: We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training.
+- [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID.
+
+
+## Requirements
+
+To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimates assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training.
+
+| Mode | Memory Required | Example GPU |
+| ------------------ | --------------- | ------------------ |
+| Inference | > 8 GB | RTX 4090 |
+| Fine-Tuning (LoRA) | > 22.5 GB | RTX 4090 |
+| Fine-Tuning (Full) | > 70 GB | A100 (80GB) / H100 |
+
+The repo has been tested with Ubuntu 22.04; we do not currently support other operating systems.
+
+## Installation
+
+When cloning this repo, make sure to update submodules:
+
+```bash
+git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
+
+# Or if you already cloned the repo:
+git submodule update --init --recursive
+```
+
+We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:
+
+```bash
+GIT_LFS_SKIP_SMUDGE=1 uv sync
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
+```
+
+NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
+
+**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details.
+
+
+
+
+## Model Checkpoints
+
+### Base Models
+We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning.
+
+| Model | Use Case | Description | Checkpoint Path |
+| ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
+| $\pi_0$ | Fine-Tuning | Base [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_base` |
+| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` |
+| $\pi_{0.5}$ | Fine-Tuning | Base [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05) for fine-tuning | `gs://openpi-assets/checkpoints/pi05_base` |
+
+### Fine-Tuned Models
+We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice.
+
+| Model | Use Case | Description | Checkpoint Path |
+| ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |
+| $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid` |
+| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `gs://openpi-assets/checkpoints/pi0_droid` |
+| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can fold diverse towels 0-shot on ALOHA robot platforms | `gs://openpi-assets/checkpoints/pi0_aloha_towel` |
+| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can unpack food from a tupperware container | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` |
+| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on public [ALOHA](https://dit-policy.github.io/) data: can uncap a pen | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
+| $\pi_{0.5}$-LIBERO | Inference | $\pi_{0.5}$ model fine-tuned for the [LIBERO](https://libero-project.github.io/datasets) benchmark: gets state-of-the-art performance (see [LIBERO README](examples/libero/README.md)) | `gs://openpi-assets/checkpoints/pi05_libero` |
+| $\pi_{0.5}$-DROID | Inference / Fine-Tuning | $\pi_{0.5}$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/) with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation): fast inference and good language-following | `gs://openpi-assets/checkpoints/pi05_droid` |
+
+
+By default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can override the download path by setting the `OPENPI_DATA_HOME` environment variable.
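+
+For example, to cache checkpoints on a larger data drive (the path below is purely illustrative):
+
+```bash
+export OPENPI_DATA_HOME=/mnt/data/openpi
+```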
+
+
+
+
+## Running Inference for a Pre-Trained Model
+
+Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_{0.5}$-DROID model):
+```python
+from openpi.training import config as _config
+from openpi.policies import policy_config
+from openpi.shared import download
+
+config = _config.get_config("pi05_droid")
+checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid")
+
+# Create a trained policy.
+policy = policy_config.create_trained_policy(config, checkpoint_dir)
+
+# Run inference on a dummy example.
+example = {
+ "observation/exterior_image_1_left": ...,
+ "observation/wrist_image_left": ...,
+ ...
+ "prompt": "pick up the fork"
+}
+action_chunk = policy.infer(example)["actions"]
+```
+You can also test this out in the [example notebook](examples/inference.ipynb).
+
+We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots.
+
+**Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate.
+
+**Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details.
+
+
+
+
+
+## Fine-Tuning Base Models on Your Own Data
+
+We will fine-tune the $\pi_{0.5}$ model on the [LIBERO dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps:
+1. Convert your data to a LeRobot dataset (which we use for training)
+2. Defining training configs and running training
+3. Spinning up a policy server and running inference
+
+### 1. Convert your data to a LeRobot dataset
+
+We provide a minimal example script for converting LIBERO data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw LIBERO dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with:
+
+```bash
+uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data
+```
+
+**Note:** If you just want to fine-tune on LIBERO, you can skip this step, because our LIBERO fine-tuning configs point to a pre-converted LIBERO dataset. This step is merely an example that you can adapt to your own data.
+
+### 2. Defining training configs and running training
+
+To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for LIBERO below, which you can modify for your own dataset:
+
+- [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the LIBERO environment to the model and vice versa. Used for both training and inference.
+- [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw LIBERO data from the LeRobot dataset for training.
+- [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader.
+
+We provide example fine-tuning configs for [π₀](src/openpi/training/config.py), [π₀-FAST](src/openpi/training/config.py), and [π₀.₅](src/openpi/training/config.py) on LIBERO data.
+
+Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config:
+
+```bash
+uv run scripts/compute_norm_stats.py --config-name pi05_libero
+```
+
+Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config):
+
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_libero --exp-name=my_experiment --overwrite
+```
+
+The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To make maximal use of GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this allows JAX to use up to 90% of the GPU memory (vs. the default of 75%).
+
+**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.
+
+### 3. Spinning up a policy server and running inference
+
+Once training is complete, we can run inference by spinning up a policy server and then querying it from a LIBERO evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):
+
+```bash
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_libero --policy.dir=checkpoints/pi05_libero/my_experiment/20000
+```
+
+This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run an evaluation script (or robot runtime) that queries the server.
+
+For running the LIBERO eval in particular, we provide (and recommend using) a Dockerized workflow that handles both the policy server and the evaluation script together. See the [LIBERO README](examples/libero/README.md) for more details.
+
+If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
+
+
+
+### More Examples
+
+We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
+- [ALOHA Simulator](examples/aloha_sim)
+- [ALOHA Real](examples/aloha_real)
+- [UR5](examples/ur5)
+
+## PyTorch Support
+
+openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions! The PyTorch implementation has been validated on the LIBERO benchmark (both inference and finetuning). A few features are currently not supported (this may change in the future):
+
+- The π₀-FAST model
+- Mixed precision training
+- FSDP (fully-sharded data parallelism) training
+- LoRA (low-rank adaptation) training
+- EMA (exponential moving average) weights during training
+
+### Setup
+1. Make sure that you have the latest version of all dependencies installed: `uv sync`
+
+2. Double check that you have transformers 4.53.2 installed: `uv pip show transformers`
+
+3. Apply the transformers library patches:
+ ```bash
+ cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
+ ```
+
+This overwrites several files in the transformers library with necessary model changes: 1) supporting AdaRMS, 2) correctly controlling the precision of activations, and 3) allowing the KV cache to be used without being updated.
+
+**WARNING**: With the default uv link mode (hardlink), this will permanently affect the transformers library in your uv cache, meaning the changes will survive reinstallations of transformers and could even propagate to other projects that use transformers. To fully undo this operation, you must run `uv cache clean transformers`.
+
+### Converting JAX Models to PyTorch
+
+To convert a JAX model checkpoint to PyTorch format:
+
+```bash
+uv run examples/convert_jax_model_to_pytorch.py \
+    --checkpoint_dir /path/to/jax/checkpoint \
+    --config_name <config_name> \
+    --output_path /path/to/converted/pytorch/checkpoint
+```
+
+### Running Inference with PyTorch
+
+The PyTorch implementation uses the same API as the JAX version - you only need to change the checkpoint path to point to the converted PyTorch model:
+
+```python
+from openpi.training import config as _config
+from openpi.policies import policy_config
+from openpi.shared import download
+
+config = _config.get_config("pi05_droid")
+checkpoint_dir = "/path/to/converted/pytorch/checkpoint"
+
+# Create a trained policy (automatically detects PyTorch format)
+policy = policy_config.create_trained_policy(config, checkpoint_dir)
+
+# Run inference (same API as JAX)
+action_chunk = policy.infer(example)["actions"]
+```
+
+### Policy Server with PyTorch
+
+The policy server works identically with PyTorch models - just point to the converted checkpoint directory:
+
+```bash
+uv run scripts/serve_policy.py policy:checkpoint \
+ --policy.config=pi05_droid \
+ --policy.dir=/path/to/converted/pytorch/checkpoint
+```
+
+### Finetuning with PyTorch
+
+To finetune a model in PyTorch:
+
+1. Convert the JAX base model to PyTorch format:
+ ```bash
+   uv run examples/convert_jax_model_to_pytorch.py \
+       --config_name <config_name> \
+       --checkpoint_dir /path/to/jax/base/model \
+       --output_path /path/to/pytorch/base/model
+ ```
+
+2. Specify the converted PyTorch model path in your config using `pytorch_weight_path`
+
+3. Launch training using one of these modes:
+
+```bash
+# Single GPU training:
+uv run scripts/train_pytorch.py <config_name> --exp_name <exp_name> --save_interval <save_interval>
+
+# Example:
+uv run scripts/train_pytorch.py debug --exp_name pytorch_test
+uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint
+
+# Multi-GPU training (single node):
+uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <exp_name>
+
+# Example:
+uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
+uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
+
+# Multi-Node Training:
+uv run torchrun \
+    --nnodes=<num_nodes> \
+    --nproc_per_node=<num_gpus_per_node> \
+    --node_rank=<node_rank> \
+    --master_addr=<master_addr> \
+    --master_port=<master_port> \
+    scripts/train_pytorch.py <config_name> --exp_name=<exp_name> --save_interval <save_interval>
+```
+
+### Precision Settings
+
+JAX and PyTorch implementations handle precision as follows:
+
+**JAX:**
+1. Inference: most weights and computations in bfloat16, with a few computations in float32 for stability
+2. Training: defaults to mixed precision: weights and gradients in float32, (most) activations and computations in bfloat16. You can change to full float32 training by setting `dtype` to float32 in the config.
+
+**PyTorch:**
+1. Inference: matches JAX -- most weights and computations in bfloat16, with a few weights converted to float32 for stability
+2. Training: supports either full bfloat16 (default) or full float32. You can change it by setting `pytorch_training_precision` in the config. bfloat16 uses less memory but exhibits higher losses compared to float32. Mixed precision is not yet supported.
+
+With torch.compile, inference speed is comparable between JAX and PyTorch.
+
+## Troubleshooting
+
+We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).
+
+| Issue | Resolution |
+| ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). |
+| Training runs out of GPU memory           | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` (or higher) before running training to allow JAX to use more GPU memory. You can also use `--fsdp-devices <num_gpus>`, where `<num_gpus>` is your number of GPUs, to enable [fully-sharded data parallelism](https://engineering.fb.com/2021/07/15/open-source/fsdp/), which reduces memory usage in exchange for slower training (the amount of slowdown depends on your particular setup). If you are still running out of memory, you may want to consider disabling EMA. |
+| Policy server connection errors | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server. |
+| Missing norm stats error when training | Run `scripts/compute_norm_stats.py` with your config name before starting training. |
+| Dataset download fails | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`). |
+| CUDA/GPU errors | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at a system level --- they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts. |
+| Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
+| Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
+| Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
diff --git a/openpi/config.json b/openpi/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f59100816e89991f8f1fb643666b83fd4910b15a
--- /dev/null
+++ b/openpi/config.json
@@ -0,0 +1,85 @@
+{
+ "type": "pi0",
+ "n_obs_steps": 1,
+ "input_features": {
+ "observation.state": {
+ "type": "STATE",
+ "shape": [
+ 6
+ ]
+ },
+ "observation.images.camera0": {
+ "type": "VISUAL",
+ "shape": [
+ 3,
+ 480,
+ 640
+ ]
+ },
+ "observation.images.camera1": {
+ "type": "VISUAL",
+ "shape": [
+ 3,
+ 480,
+ 640
+ ]
+ },
+ "observation.images.camera2": {
+ "type": "VISUAL",
+ "shape": [
+ 3,
+ 480,
+ 640
+ ]
+ }
+ },
+ "output_features": {
+ "action": {
+ "type": "ACTION",
+ "shape": [
+ 6
+ ]
+ }
+ },
+ "device": "cpu",
+ "use_amp": false,
+ "push_to_hub": true,
+ "repo_id": null,
+ "private": null,
+ "tags": null,
+ "license": null,
+ "chunk_size": 50,
+ "n_action_steps": 50,
+ "normalization_mapping": {
+ "VISUAL": "IDENTITY",
+ "STATE": "MEAN_STD",
+ "ACTION": "MEAN_STD"
+ },
+ "max_state_dim": 32,
+ "max_action_dim": 32,
+ "resize_imgs_with_padding": [
+ 224,
+ 224
+ ],
+ "empty_cameras": 0,
+ "adapt_to_pi_aloha": false,
+ "use_delta_joint_actions_aloha": false,
+ "tokenizer_max_length": 48,
+ "proj_width": 1024,
+ "num_steps": 10,
+ "use_cache": true,
+ "attention_implementation": "eager",
+ "freeze_vision_encoder": true,
+ "train_expert_only": false,
+ "train_state_proj": true,
+ "optimizer_lr": 2.5e-05,
+ "optimizer_betas": [
+ 0.9,
+ 0.95
+ ],
+ "optimizer_eps": 1e-08,
+ "optimizer_weight_decay": 1e-10,
+ "scheduler_warmup_steps": 1000,
+ "scheduler_decay_steps": 30000,
+ "scheduler_decay_lr": 2.5e-06
+}
\ No newline at end of file
diff --git a/openpi/docs/docker.md b/openpi/docs/docker.md
new file mode 100644
index 0000000000000000000000000000000000000000..6449278019e8c596e444def91f96771f836453cf
--- /dev/null
+++ b/openpi/docs/docker.md
@@ -0,0 +1,25 @@
+### Docker Setup
+
+All of the examples in this repo provide instructions for running them directly as well as with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
+
+- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
+- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
+- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
+- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
+
+
+If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
+
+Build the Docker image and start the container with the following command:
+```bash
+docker compose -f scripts/docker/compose.yml up --build
+```
+
+To build and run the Docker image for a specific example, use the following command:
+```bash
+docker compose -f examples/<example_name>/compose.yml up --build
+```
+where `<example_name>` is the name of the example you want to run.
+
+During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
\ No newline at end of file
diff --git a/openpi/docs/norm_stats.md b/openpi/docs/norm_stats.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc8f72c9de5b7f92a42a0ee6920770eddc606d78
--- /dev/null
+++ b/openpi/docs/norm_stats.md
@@ -0,0 +1,69 @@
+# Normalization statistics
+
+Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
+
+## Reloading normalization statistics
+
+When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
+
+**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
+
+```python
+TrainConfig(
+ ...
+ data=LeRobotAlohaDataConfig(
+ ...
+ assets=AssetsConfig(
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
+ asset_id="trossen",
+ ),
+ ),
+)
+```
+
+For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+
+**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
+
+**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend trying both (reloading the pre-training statistics and computing a fresh set of statistics on your new dataset; see the [main README](../README.md) for instructions on how to compute new statistics) and picking the one that works better for your task.
+
+
+## Provided Pre-training Normalization Statistics
+
+Below is a list of all the pre-training normalization statistics we provide. We provide them for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets`; for `pi0_fast_base`, set it to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
+
+| Robot | Description | Asset ID |
+|-------|-------------|----------|
+| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
+| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
+| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
+| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
+| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
+| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
+| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
+| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
+| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
+
+
+## Pi0 Model Action Space Definitions
+
+Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
+```
+ "dim_0:dim_5": "left arm joint angles",
+ "dim_6": "left arm gripper position",
+ "dim_7:dim_12": "right arm joint angles (for bi-manual only)",
+ "dim_13": "right arm gripper position (for bi-manual only)",
+
+ # For mobile robots:
+ "dim_14:dim_15": "x-y base velocity (for mobile robots only)",
+```
+
+The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
+
+For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
+
+General info for Pi robots:
+- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
+- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
+- Control frequencies are 20 Hz for UR5e and Franka arms, and 50 Hz for ARX and Trossen (ALOHA) arms.
+
+For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension, at a control frequency of 15 Hz.
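+
+As a minimal sketch of the 7-DoF joint-position convention described above (this assumes NumPy and uses our own variable names, which are not part of the openpi API; it does not apply to DROID's joint-velocity actions):
+
+```python
+import numpy as np
+
+# Seven joint targets in radians, followed by one gripper position in [0.0, 1.0]
+# (0.0 = fully open, 1.0 = fully closed).
+joint_action = np.zeros(7, dtype=np.float32)
+gripper_action = np.array([1.0], dtype=np.float32)  # close the gripper
+
+action = np.concatenate([joint_action, gripper_action])  # shape (8,)
+```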
diff --git a/openpi/docs/remote_inference.md b/openpi/docs/remote_inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..c5e4848402d70b757b7c80f03245124df5f2032f
--- /dev/null
+++ b/openpi/docs/remote_inference.md
@@ -0,0 +1,71 @@
+
+# Running openpi models remotely
+
+We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
+
+## Starting a remote policy server
+
+To start a remote policy server, you can simply run the following command:
+
+```bash
+uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
+```
+
+The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
+
+```bash
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
+```
+
+This starts a policy server that serves the policy specified by the `config` and `dir` arguments on the specified port (default: 8000).
+
+## Querying the remote policy server from your robot code
+
+We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
+
+First, install the `openpi-client` package in your robot environment:
+
+```bash
+cd $OPENPI_ROOT/packages/openpi-client
+pip install -e .
+```
+
+Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
+
+```python
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy
+
+# Outside of episode loop, initialize the policy client.
+# Point to the host and port of the policy server (localhost and 8000 are the defaults).
+client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
+
+for step in range(num_steps):
+ # Inside the episode loop, construct the observation.
+ # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
+    # We provide utilities for resizing and uint8 conversion so that your inputs match the training routines.
+ # The typical resize_size for pre-trained pi0 models is 224.
+ # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
+ observation = {
+ "observation/image": image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(img, 224, 224)
+ ),
+ "observation/wrist_image": image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(wrist_img, 224, 224)
+ ),
+ "observation/state": state,
+ "prompt": task_instruction,
+ }
+
+ # Call the policy server with the current observation.
+ # This returns an action chunk of shape (action_horizon, action_dim).
+ # Note that you typically only need to call the policy every N steps and execute steps
+ # from the predicted action chunk open-loop in the remaining steps.
+ action_chunk = client.infer(observation)["actions"]
+
+ # Execute the actions in the environment.
+ ...
+
+```
+
+Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
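+
+To make the open-loop chunking pattern mentioned in the code comments above concrete, here is a minimal sketch; `get_observation` and `apply_action` are hypothetical placeholders for your own robot code, and the re-planning interval is an assumption you should tune:
+
+```python
+open_loop_horizon = 10  # must not exceed the action horizon returned by the policy
+action_chunk = None
+
+for step in range(num_steps):
+    # Query the policy server only when the previous chunk has been consumed.
+    if step % open_loop_horizon == 0:
+        observation = get_observation()  # placeholder: build the observation dict as shown above
+        action_chunk = client.infer(observation)["actions"]
+
+    # Execute the next action from the current chunk without re-querying the server.
+    apply_action(action_chunk[step % open_loop_horizon])
+```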
diff --git a/openpi/examples/aloha_real/Dockerfile b/openpi/examples/aloha_real/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..0b6c65b828d89cdc5c9ca8a1bff7c77fd6c8b48b
--- /dev/null
+++ b/openpi/examples/aloha_real/Dockerfile
@@ -0,0 +1,70 @@
+# Dockerfile for the Aloha real environment.
+
+# Build the container:
+# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
+
+FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
+SHELL ["/bin/bash", "-c"]
+
+ENV DEBIAN_FRONTEND=noninteractive
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends \
+ cmake \
+ curl \
+ libffi-dev \
+ python3-rosdep \
+ python3-rosinstall \
+ python3-rosinstall-generator \
+ whiptail \
+ git \
+ wget \
+ openssh-client \
+ ros-noetic-cv-bridge \
+ ros-noetic-usb-cam \
+ ros-noetic-realsense2-camera \
+ keyboard-configuration
+
+WORKDIR /root
+RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
+RUN chmod +x xsarm_amd64_install.sh
+RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
+
+COPY ./third_party/aloha /root/interbotix_ws/src/aloha
+RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
+
+# Install python 3.10 because this ROS image comes with 3.8
+RUN mkdir /python && \
+ cd /python && \
+ wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
+ tar -zxvf Python-3.10.14.tgz && \
+ cd Python-3.10.14 && \
+ ls -lhR && \
+ ./configure --enable-optimizations && \
+ make install && \
+ echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
+ echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
+ cd ~ && rm -rf /python && \
+ rm -rf /var/lib/apt/lists/*
+
+COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
+ENV UV_HTTP_TIMEOUT=120
+ENV UV_LINK_MODE=copy
+COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+
+ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
+WORKDIR /app
+
+# Create an entrypoint script to run the setup commands, followed by the command passed in.
+RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
+#!/bin/bash
+source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
+EOF
+RUN chmod +x /usr/local/bin/entrypoint.sh
+
+ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
+CMD ["python3", "/app/examples/aloha_real/main.py"]
diff --git a/openpi/examples/aloha_real/README.md b/openpi/examples/aloha_real/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b4da86d5641d34e2fb65ad5982483a82d649922e
--- /dev/null
+++ b/openpi/examples/aloha_real/README.md
@@ -0,0 +1,126 @@
+# Run Aloha (Real Robot)
+
+This example demonstrates how to run openpi policies on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
+
+## Prerequisites
+
+This repo uses a fork of the ALOHA repo, with very minor modifications to use RealSense cameras.
+
+1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
+1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
+
+## With Docker
+
+```bash
+export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
+docker compose -f examples/aloha_real/compose.yml up --build
+```
+
+## Without Docker
+
+Terminal window 1:
+
+```bash
+# Create virtual environment
+uv venv --python 3.10 examples/aloha_real/.venv
+source examples/aloha_real/.venv/bin/activate
+uv pip sync examples/aloha_real/requirements.txt
+uv pip install -e packages/openpi-client
+
+# Run the robot
+python -m examples.aloha_real.main
+```
+
+Terminal window 2:
+
+```bash
+roslaunch aloha ros_nodes.launch
+```
+
+Terminal window 3:
+
+```bash
+uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
+```
+
+## **ALOHA Checkpoint Guide**
+
+
+The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
+
+While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.
+
+
+---
+
+### **Toast Task**
+
+This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
+
+- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
+- **Prompt**: "take the toast out of the toaster"
+- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
+- **Object Distribution**:
+ - Works on both real toast and rubber fake toast
+ - Compatible with standard 2-slice toasters
+ - Works with plates of varying colors
+
+### **Scene Setup Guidelines**
+
+
+- The toaster should be positioned in the top-left quadrant of the workspace.
+- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
+- The plate should be placed roughly in the lower-center of the workspace.
+- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
+
+
+### **Towel Task**
+
+This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
+
+- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
+- **Prompt**: "fold the towel"
+- **Object Distribution**:
+ - Works on towels of varying solid colors
+ - Performance is worse on heavily textured or striped towels
+
+### **Scene Setup Guidelines**
+
+
+- The towel should be flattened and roughly centered on the table.
+- Choose a towel that does not blend in with the table surface.
+
+
+### **Tupperware Task**
+
+This task involves opening a tupperware filled with food and pouring the contents onto a plate.
+
+- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
+- **Prompt**: "open the tupperware and put the food on the plate"
+- **Objects needed**: Tupperware, food (or food-like items), and a plate.
+- **Object Distribution**:
+ - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
+ - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
+ - The policy has seen plates of varying solid colors.
+
+### **Scene Setup Guidelines**
+
+
+- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
+- Positioning:
+ - Tupperware should be on the left.
+ - Plate should be on the right or bottom.
+ - The tupperware flap should point toward the plate.
+
+## Training on your own Aloha dataset
+
+1. Convert the dataset to the LeRobot dataset v2.0 format.
+
+   We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example, we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
+
+
+2. Define a training config that uses the custom dataset.
+
+ We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
+
+IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` asset_id and a path to the pretrained checkpoint’s asset directory within the `AssetsConfig`.
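+
+For orientation, the normalization-stats override boils down to an `AssetsConfig` entry roughly like the sketch below. The class and field names follow the example config referenced above, but the `assets_dir` path is an assumption based on the base checkpoint location; see the `pi0_aloha_pen_uncap` entry in `src/openpi/training/config.py` for the exact and complete definition.
+
+```python
+from openpi.training.config import AssetsConfig
+
+# Reuse the pretrained checkpoint's "trossen" normalization stats for the custom dataset.
+# The assets_dir below is an assumed path derived from the base checkpoint location.
+assets = AssetsConfig(
+    assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
+    asset_id="trossen",
+)
+```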
diff --git a/openpi/examples/aloha_real/compose.yml b/openpi/examples/aloha_real/compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4e1e4ba927a2ef5a7f950083b85ae158b287c456
--- /dev/null
+++ b/openpi/examples/aloha_real/compose.yml
@@ -0,0 +1,66 @@
+# Run with:
+# docker compose -f examples/aloha_real/compose.yml up --build
+services:
+ runtime:
+ image: aloha_real
+ depends_on:
+ - aloha_ros_nodes
+ - ros_master
+ - openpi_server
+ build:
+ context: ../..
+ dockerfile: examples/aloha_real/Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ privileged: true
+ volumes:
+ - $PWD:/app
+ - ../../data:/data
+
+ aloha_ros_nodes:
+ image: aloha_real
+ depends_on:
+ - ros_master
+ build:
+ context: ../..
+ dockerfile: examples/aloha_real/Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ privileged: true
+ volumes:
+ - /dev:/dev
+ command: roslaunch --wait aloha ros_nodes.launch
+
+ ros_master:
+ image: ros:noetic-robot
+ network_mode: host
+ privileged: true
+ command:
+ - roscore
+
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/openpi/examples/aloha_real/constants.py b/openpi/examples/aloha_real/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2ea59f56a73464e348792387356ce695648fbf1
--- /dev/null
+++ b/openpi/examples/aloha_real/constants.py
@@ -0,0 +1,71 @@
+# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
+# ruff: noqa
+
+### Task parameters
+
+### ALOHA fixed constants
+DT = 0.001
+JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
+START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
+
+# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
+MASTER_GRIPPER_POSITION_OPEN = 0.02417
+MASTER_GRIPPER_POSITION_CLOSE = 0.01244
+PUPPET_GRIPPER_POSITION_OPEN = 0.05800
+PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
+
+# Gripper joint limits (qpos[6])
+MASTER_GRIPPER_JOINT_OPEN = 0.3083
+MASTER_GRIPPER_JOINT_CLOSE = -0.6842
+PUPPET_GRIPPER_JOINT_OPEN = 1.4910
+PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
+
+############################ Helper functions ############################
+
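+# The *_NORMALIZE_FN helpers below map a raw gripper value from [CLOSE, OPEN] to [0, 1], the
+# *_UNNORMALIZE_FN helpers invert that mapping, and the MASTER2PUPPET_* helpers convert a
+# master-side value into the equivalent puppet-side command (normalize, then unnormalize).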
+MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
+ MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
+)
+PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
+ PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
+)
+MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
+ lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
+)
+PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
+ lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
+)
+MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
+
+MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
+ MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
+)
+PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
+ PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
+)
+MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
+ lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
+)
+PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
+ lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
+)
+MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
+
+MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
+PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
+
+MASTER_POS2JOINT = (
+ lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ + MASTER_GRIPPER_JOINT_CLOSE
+)
+MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
+ (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+)
+PUPPET_POS2JOINT = (
+ lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ + PUPPET_GRIPPER_JOINT_CLOSE
+)
+PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
+ (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+)
+
+MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
diff --git a/openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py b/openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a8ddcb24bbc5340b9cb179a0b43fccc46646f6
--- /dev/null
+++ b/openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py
@@ -0,0 +1,272 @@
+"""
+Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
+
+Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
+"""
+
+import dataclasses
+from pathlib import Path
+import shutil
+from typing import Literal
+
+import h5py
+from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
+import numpy as np
+import torch
+import tqdm
+import tyro
+
+
+@dataclasses.dataclass(frozen=True)
+class DatasetConfig:
+ use_videos: bool = True
+ tolerance_s: float = 0.0001
+ image_writer_processes: int = 10
+ image_writer_threads: int = 5
+ video_backend: str | None = None
+
+
+DEFAULT_DATASET_CONFIG = DatasetConfig()
+
+
+def create_empty_dataset(
+ repo_id: str,
+ robot_type: str,
+ mode: Literal["video", "image"] = "video",
+ *,
+ has_velocity: bool = False,
+ has_effort: bool = False,
+ dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
+) -> LeRobotDataset:
+ motors = [
+ "right_waist",
+ "right_shoulder",
+ "right_elbow",
+ "right_forearm_roll",
+ "right_wrist_angle",
+ "right_wrist_rotate",
+ "right_gripper",
+ "left_waist",
+ "left_shoulder",
+ "left_elbow",
+ "left_forearm_roll",
+ "left_wrist_angle",
+ "left_wrist_rotate",
+ "left_gripper",
+ ]
+ cameras = [
+ "cam_high",
+ "cam_low",
+ "cam_left_wrist",
+ "cam_right_wrist",
+ ]
+
+ features = {
+ "observation.state": {
+ "dtype": "float32",
+ "shape": (len(motors),),
+ "names": [
+ motors,
+ ],
+ },
+ "action": {
+ "dtype": "float32",
+ "shape": (len(motors),),
+ "names": [
+ motors,
+ ],
+ },
+ }
+
+ if has_velocity:
+ features["observation.velocity"] = {
+ "dtype": "float32",
+ "shape": (len(motors),),
+ "names": [
+ motors,
+ ],
+ }
+
+ if has_effort:
+ features["observation.effort"] = {
+ "dtype": "float32",
+ "shape": (len(motors),),
+ "names": [
+ motors,
+ ],
+ }
+
+ for cam in cameras:
+ features[f"observation.images.{cam}"] = {
+ "dtype": mode,
+ "shape": (3, 480, 640),
+ "names": [
+ "channels",
+ "height",
+ "width",
+ ],
+ }
+
+ if Path(LEROBOT_HOME / repo_id).exists():
+ shutil.rmtree(LEROBOT_HOME / repo_id)
+
+ return LeRobotDataset.create(
+ repo_id=repo_id,
+ fps=50,
+ robot_type=robot_type,
+ features=features,
+ use_videos=dataset_config.use_videos,
+ tolerance_s=dataset_config.tolerance_s,
+ image_writer_processes=dataset_config.image_writer_processes,
+ image_writer_threads=dataset_config.image_writer_threads,
+ video_backend=dataset_config.video_backend,
+ )
+
+
+def get_cameras(hdf5_files: list[Path]) -> list[str]:
+ with h5py.File(hdf5_files[0], "r") as ep:
+ # ignore depth channel, not currently handled
+ return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
+
+
+def has_velocity(hdf5_files: list[Path]) -> bool:
+ with h5py.File(hdf5_files[0], "r") as ep:
+ return "/observations/qvel" in ep
+
+
+def has_effort(hdf5_files: list[Path]) -> bool:
+ with h5py.File(hdf5_files[0], "r") as ep:
+ return "/observations/effort" in ep
+
+
+def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
+ imgs_per_cam = {}
+ for camera in cameras:
+ uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
+
+ if uncompressed:
+ # load all images in RAM
+ imgs_array = ep[f"/observations/images/{camera}"][:]
+ else:
+ import cv2
+
+ # load one compressed image after the other in RAM and uncompress
+ imgs_array = []
+ for data in ep[f"/observations/images/{camera}"]:
+ imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
+ imgs_array = np.array(imgs_array)
+
+ imgs_per_cam[camera] = imgs_array
+ return imgs_per_cam
+
+
+def load_raw_episode_data(
+ ep_path: Path,
+) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+ with h5py.File(ep_path, "r") as ep:
+ state = torch.from_numpy(ep["/observations/qpos"][:])
+ action = torch.from_numpy(ep["/action"][:])
+
+ velocity = None
+ if "/observations/qvel" in ep:
+ velocity = torch.from_numpy(ep["/observations/qvel"][:])
+
+ effort = None
+ if "/observations/effort" in ep:
+ effort = torch.from_numpy(ep["/observations/effort"][:])
+
+ imgs_per_cam = load_raw_images_per_camera(
+ ep,
+ [
+ "cam_high",
+ "cam_low",
+ "cam_left_wrist",
+ "cam_right_wrist",
+ ],
+ )
+
+ return imgs_per_cam, state, action, velocity, effort
+
+
+def populate_dataset(
+ dataset: LeRobotDataset,
+ hdf5_files: list[Path],
+ task: str,
+ episodes: list[int] | None = None,
+) -> LeRobotDataset:
+ if episodes is None:
+ episodes = range(len(hdf5_files))
+
+ for ep_idx in tqdm.tqdm(episodes):
+ ep_path = hdf5_files[ep_idx]
+
+ imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
+ num_frames = state.shape[0]
+
+ for i in range(num_frames):
+ frame = {
+ "observation.state": state[i],
+ "action": action[i],
+ }
+
+ for camera, img_array in imgs_per_cam.items():
+ frame[f"observation.images.{camera}"] = img_array[i]
+
+ if velocity is not None:
+ frame["observation.velocity"] = velocity[i]
+ if effort is not None:
+ frame["observation.effort"] = effort[i]
+
+ dataset.add_frame(frame)
+
+ dataset.save_episode(task=task)
+
+ return dataset
+
+
+def port_aloha(
+ raw_dir: Path,
+ repo_id: str,
+ raw_repo_id: str | None = None,
+ task: str = "DEBUG",
+ *,
+ episodes: list[int] | None = None,
+ push_to_hub: bool = True,
+ is_mobile: bool = False,
+ mode: Literal["video", "image"] = "image",
+ dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
+):
+ if (LEROBOT_HOME / repo_id).exists():
+ shutil.rmtree(LEROBOT_HOME / repo_id)
+
+ if not raw_dir.exists():
+ if raw_repo_id is None:
+ raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
+ download_raw(raw_dir, repo_id=raw_repo_id)
+
+ hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
+
+ dataset = create_empty_dataset(
+ repo_id,
+ robot_type="mobile_aloha" if is_mobile else "aloha",
+ mode=mode,
+ has_effort=has_effort(hdf5_files),
+ has_velocity=has_velocity(hdf5_files),
+ dataset_config=dataset_config,
+ )
+ dataset = populate_dataset(
+ dataset,
+ hdf5_files,
+ task=task,
+ episodes=episodes,
+ )
+ dataset.consolidate()
+
+ if push_to_hub:
+ dataset.push_to_hub()
+
+
+if __name__ == "__main__":
+ tyro.cli(port_aloha)
diff --git a/openpi/examples/aloha_real/env.py b/openpi/examples/aloha_real/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..399092f87e630fa85a9f32249cebce975df07007
--- /dev/null
+++ b/openpi/examples/aloha_real/env.py
@@ -0,0 +1,57 @@
+from typing import List, Optional # noqa: UP035
+
+import einops
+from openpi_client import image_tools
+from openpi_client.runtime import environment as _environment
+from typing_extensions import override
+
+from examples.aloha_real import real_env as _real_env
+
+
+class AlohaRealEnvironment(_environment.Environment):
+ """An environment for an Aloha robot on real hardware."""
+
+ def __init__(
+ self,
+ reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
+ render_height: int = 224,
+ render_width: int = 224,
+ ) -> None:
+ self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
+ self._render_height = render_height
+ self._render_width = render_width
+
+ self._ts = None
+
+ @override
+ def reset(self) -> None:
+ self._ts = self._env.reset()
+
+ @override
+ def is_episode_complete(self) -> bool:
+ return False
+
+ @override
+ def get_observation(self) -> dict:
+ if self._ts is None:
+ raise RuntimeError("Timestep is not set. Call reset() first.")
+
+ obs = self._ts.observation
+ for k in list(obs["images"].keys()):
+ if "_depth" in k:
+ del obs["images"][k]
+
+ for cam_name in obs["images"]:
+ img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
+ )
+ obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
+
+ return {
+ "state": obs["qpos"],
+ "images": obs["images"],
+ }
+
+ @override
+ def apply_action(self, action: dict) -> None:
+ self._ts = self._env.step(action["actions"])
diff --git a/openpi/examples/aloha_real/main.py b/openpi/examples/aloha_real/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..25a0631d8baba100c1fef4a20a962d82b6f1db53
--- /dev/null
+++ b/openpi/examples/aloha_real/main.py
@@ -0,0 +1,51 @@
+import dataclasses
+import logging
+
+from openpi_client import action_chunk_broker
+from openpi_client import websocket_client_policy as _websocket_client_policy
+from openpi_client.runtime import runtime as _runtime
+from openpi_client.runtime.agents import policy_agent as _policy_agent
+import tyro
+
+from examples.aloha_real import env as _env
+
+
+@dataclasses.dataclass
+class Args:
+ host: str = "0.0.0.0"
+ port: int = 8000
+
+ action_horizon: int = 25
+
+ num_episodes: int = 1
+ max_episode_steps: int = 1000
+
+
+def main(args: Args) -> None:
+ ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
+ host=args.host,
+ port=args.port,
+ )
+ logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
+
+ metadata = ws_client_policy.get_server_metadata()
+ runtime = _runtime.Runtime(
+ environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
+ agent=_policy_agent.PolicyAgent(
+ policy=action_chunk_broker.ActionChunkBroker(
+ policy=ws_client_policy,
+ action_horizon=args.action_horizon,
+ )
+ ),
+ subscribers=[],
+ max_hz=50,
+ num_episodes=args.num_episodes,
+ max_episode_steps=args.max_episode_steps,
+ )
+
+ runtime.run()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO, force=True)
+ tyro.cli(main)
diff --git a/openpi/examples/aloha_real/real_env.py b/openpi/examples/aloha_real/real_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..2073d8395b0555d47a58afeec5033f2f7c8fcc99
--- /dev/null
+++ b/openpi/examples/aloha_real/real_env.py
@@ -0,0 +1,176 @@
+# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
+# ruff: noqa
+import collections
+import time
+from typing import Optional, List
+import dm_env
+from interbotix_xs_modules.arm import InterbotixManipulatorXS
+from interbotix_xs_msgs.msg import JointSingleCommand
+import numpy as np
+
+from examples.aloha_real import constants
+from examples.aloha_real import robot_utils
+
+# This is the reset position that is used by the standard Aloha runtime.
+DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
+
+
+class RealEnv:
+ """
+ Environment for real robot bi-manual manipulation
+ Action space: [left_arm_qpos (6), # absolute joint position
+ left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
+
+ Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
+ left_gripper_position (1), # normalized gripper position (0: close, 1: open)
+ right_arm_qpos (6), # absolute joint position
+ right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
+ "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
+ left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
+ right_arm_qvel (6), # absolute joint velocity (rad)
+ right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
+ "images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
+ "cam_low": (480x640x3), # h, w, c, dtype='uint8'
+ "cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
+ "cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
+ """
+
+ def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
+ # reset_position = START_ARM_POSE[:6]
+ self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
+
+ self.puppet_bot_left = InterbotixManipulatorXS(
+ robot_model="vx300s",
+ group_name="arm",
+ gripper_name="gripper",
+ robot_name="puppet_left",
+ init_node=init_node,
+ )
+ self.puppet_bot_right = InterbotixManipulatorXS(
+ robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
+ )
+ if setup_robots:
+ self.setup_robots()
+
+ self.recorder_left = robot_utils.Recorder("left", init_node=False)
+ self.recorder_right = robot_utils.Recorder("right", init_node=False)
+ self.image_recorder = robot_utils.ImageRecorder(init_node=False)
+ self.gripper_command = JointSingleCommand(name="gripper")
+
+ def setup_robots(self):
+ robot_utils.setup_puppet_bot(self.puppet_bot_left)
+ robot_utils.setup_puppet_bot(self.puppet_bot_right)
+
+ def get_qpos(self):
+ left_qpos_raw = self.recorder_left.qpos
+ right_qpos_raw = self.recorder_right.qpos
+ left_arm_qpos = left_qpos_raw[:6]
+ right_arm_qpos = right_qpos_raw[:6]
+ left_gripper_qpos = [
+ constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
+ ] # this is position not joint
+ right_gripper_qpos = [
+ constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
+ ] # this is position not joint
+ return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
+
+ def get_qvel(self):
+ left_qvel_raw = self.recorder_left.qvel
+ right_qvel_raw = self.recorder_right.qvel
+ left_arm_qvel = left_qvel_raw[:6]
+ right_arm_qvel = right_qvel_raw[:6]
+ left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
+ right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
+ return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
+
+ def get_effort(self):
+ left_effort_raw = self.recorder_left.effort
+ right_effort_raw = self.recorder_right.effort
+ left_robot_effort = left_effort_raw[:7]
+ right_robot_effort = right_effort_raw[:7]
+ return np.concatenate([left_robot_effort, right_robot_effort])
+
+ def get_images(self):
+ return self.image_recorder.get_images()
+
+ def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
+ left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
+ self.gripper_command.cmd = left_gripper_desired_joint
+ self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
+
+ right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
+ right_gripper_desired_pos_normalized
+ )
+ self.gripper_command.cmd = right_gripper_desired_joint
+ self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
+
+ def _reset_joints(self):
+ robot_utils.move_arms(
+ [self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
+ )
+
+ def _reset_gripper(self):
+ """Set to position mode and do position resets: first close then open. Then change back to PWM mode
+
+ NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
+ was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
+ increase the frequency of motor faults.
+ """
+ robot_utils.move_grippers(
+ [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
+ )
+ robot_utils.move_grippers(
+ [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
+ )
+
+ def get_observation(self):
+ obs = collections.OrderedDict()
+ obs["qpos"] = self.get_qpos()
+ obs["qvel"] = self.get_qvel()
+ obs["effort"] = self.get_effort()
+ obs["images"] = self.get_images()
+ return obs
+
+ def get_reward(self):
+ return 0
+
+ def reset(self, *, fake=False):
+ if not fake:
+ # Reboot puppet robot gripper motors
+ self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
+ self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
+ self._reset_joints()
+ self._reset_gripper()
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
+ )
+
+ def step(self, action):
+ state_len = int(len(action) / 2)
+ left_action = action[:state_len]
+ right_action = action[state_len:]
+ self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
+ self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
+ self.set_gripper_pose(left_action[-1], right_action[-1])
+ time.sleep(constants.DT)
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
+ )
+
+
+def get_action(master_bot_left, master_bot_right):
+ action = np.zeros(14) # 6 joint + 1 gripper, for two arms
+ # Arm actions
+ action[:6] = master_bot_left.dxl.joint_states.position[:6]
+ action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
+ # Gripper actions
+ action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
+ action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
+
+ return action
+
+
+def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
+ return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
diff --git a/openpi/examples/aloha_real/requirements.in b/openpi/examples/aloha_real/requirements.in
new file mode 100644
index 0000000000000000000000000000000000000000..4a6182a26f157b96af0122c4c9344c31b5bc133e
--- /dev/null
+++ b/openpi/examples/aloha_real/requirements.in
@@ -0,0 +1,18 @@
+Pillow
+dm_control
+einops
+h5py
+matplotlib
+modern_robotics
+msgpack
+numpy>=1.22.4,<2.0.0
+opencv-python
+packaging
+pexpect
+pyquaternion
+pyrealsense2
+pyyaml
+requests
+rospkg
+tyro
+websockets
diff --git a/openpi/examples/aloha_real/requirements.txt b/openpi/examples/aloha_real/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7ab297340cb6dad81c78448e3178973576296985
--- /dev/null
+++ b/openpi/examples/aloha_real/requirements.txt
@@ -0,0 +1,156 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
+absl-py==2.1.0
+ # via
+ # dm-control
+ # dm-env
+ # labmaze
+ # mujoco
+catkin-pkg==1.0.0
+ # via rospkg
+certifi==2024.8.30
+ # via requests
+charset-normalizer==3.4.0
+ # via requests
+contourpy==1.1.1
+ # via matplotlib
+cycler==0.12.1
+ # via matplotlib
+distro==1.9.0
+ # via rospkg
+dm-control==1.0.23
+ # via -r examples/aloha_real/requirements.in
+dm-env==1.6
+ # via dm-control
+dm-tree==0.1.8
+ # via
+ # dm-control
+ # dm-env
+docstring-parser==0.16
+ # via tyro
+docutils==0.20.1
+ # via catkin-pkg
+einops==0.8.0
+ # via -r examples/aloha_real/requirements.in
+etils==1.3.0
+ # via mujoco
+fonttools==4.55.2
+ # via matplotlib
+glfw==2.8.0
+ # via
+ # dm-control
+ # mujoco
+h5py==3.11.0
+ # via -r examples/aloha_real/requirements.in
+idna==3.10
+ # via requests
+importlib-resources==6.4.5
+ # via etils
+kiwisolver==1.4.7
+ # via matplotlib
+labmaze==1.0.6
+ # via dm-control
+lxml==5.3.0
+ # via dm-control
+markdown-it-py==3.0.0
+ # via rich
+matplotlib==3.7.5
+ # via -r examples/aloha_real/requirements.in
+mdurl==0.1.2
+ # via markdown-it-py
+modern-robotics==1.1.1
+ # via -r examples/aloha_real/requirements.in
+msgpack==1.1.0
+ # via -r examples/aloha_real/requirements.in
+mujoco==3.2.3
+ # via dm-control
+numpy==1.24.4
+ # via
+ # -r examples/aloha_real/requirements.in
+ # contourpy
+ # dm-control
+ # dm-env
+ # h5py
+ # labmaze
+ # matplotlib
+ # modern-robotics
+ # mujoco
+ # opencv-python
+ # pyquaternion
+ # scipy
+opencv-python==4.10.0.84
+ # via -r examples/aloha_real/requirements.in
+packaging==24.2
+ # via
+ # -r examples/aloha_real/requirements.in
+ # matplotlib
+pexpect==4.9.0
+ # via -r examples/aloha_real/requirements.in
+pillow==10.4.0
+ # via
+ # -r examples/aloha_real/requirements.in
+ # matplotlib
+protobuf==5.29.1
+ # via dm-control
+ptyprocess==0.7.0
+ # via pexpect
+pygments==2.18.0
+ # via rich
+pyopengl==3.1.7
+ # via
+ # dm-control
+ # mujoco
+pyparsing==3.1.4
+ # via
+ # catkin-pkg
+ # dm-control
+ # matplotlib
+pyquaternion==0.9.9
+ # via -r examples/aloha_real/requirements.in
+pyrealsense2==2.55.1.6486
+ # via -r examples/aloha_real/requirements.in
+python-dateutil==2.9.0.post0
+ # via
+ # catkin-pkg
+ # matplotlib
+pyyaml==6.0.2
+ # via
+ # -r examples/aloha_real/requirements.in
+ # rospkg
+requests==2.32.3
+ # via
+ # -r examples/aloha_real/requirements.in
+ # dm-control
+rich==13.9.4
+ # via tyro
+rospkg==1.5.1
+ # via -r examples/aloha_real/requirements.in
+scipy==1.10.1
+ # via dm-control
+setuptools==75.3.0
+ # via
+ # catkin-pkg
+ # dm-control
+ # labmaze
+shtab==1.7.1
+ # via tyro
+six==1.17.0
+ # via python-dateutil
+tqdm==4.67.1
+ # via dm-control
+typeguard==4.4.0
+ # via tyro
+typing-extensions==4.12.2
+ # via
+ # etils
+ # rich
+ # typeguard
+ # tyro
+tyro==0.9.2
+ # via -r examples/aloha_real/requirements.in
+urllib3==2.2.3
+ # via requests
+websockets==14.1
+ # via -r examples/aloha_real/requirements.in
+zipp==3.20.2
+ # via etils
diff --git a/openpi/examples/aloha_real/robot_utils.py b/openpi/examples/aloha_real/robot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..965a11eaae29c7cbf4bd7233b20aeb6cc40a84df
--- /dev/null
+++ b/openpi/examples/aloha_real/robot_utils.py
@@ -0,0 +1,275 @@
+# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
+# ruff: noqa
+from collections import deque
+import datetime
+import json
+import time
+
+from aloha.msg import RGBGrayscaleImage
+from cv_bridge import CvBridge
+from interbotix_xs_msgs.msg import JointGroupCommand
+from interbotix_xs_msgs.msg import JointSingleCommand
+import numpy as np
+import rospy
+from sensor_msgs.msg import JointState
+
+from examples.aloha_real import constants
+
+
+class ImageRecorder:
+ def __init__(self, init_node=True, is_debug=False):
+ self.is_debug = is_debug
+ self.bridge = CvBridge()
+ self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
+
+ if init_node:
+ rospy.init_node("image_recorder", anonymous=True)
+ for cam_name in self.camera_names:
+ setattr(self, f"{cam_name}_rgb_image", None)
+ setattr(self, f"{cam_name}_depth_image", None)
+ setattr(self, f"{cam_name}_timestamp", 0.0)
+ if cam_name == "cam_high":
+ callback_func = self.image_cb_cam_high
+ elif cam_name == "cam_low":
+ callback_func = self.image_cb_cam_low
+ elif cam_name == "cam_left_wrist":
+ callback_func = self.image_cb_cam_left_wrist
+ elif cam_name == "cam_right_wrist":
+ callback_func = self.image_cb_cam_right_wrist
+ else:
+ raise NotImplementedError
+ rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
+ if self.is_debug:
+ setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
+
+ self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
+ time.sleep(0.5)
+
+ def image_cb(self, cam_name, data):
+ setattr(
+ self,
+ f"{cam_name}_rgb_image",
+ self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
+ )
+ # setattr(
+ # self,
+ # f"{cam_name}_depth_image",
+ # self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
+ # )
+ setattr(
+ self,
+ f"{cam_name}_timestamp",
+ data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
+ )
+ # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
+ # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
+ # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
+ if self.is_debug:
+ getattr(self, f"{cam_name}_timestamps").append(
+ data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
+ )
+
+ def image_cb_cam_high(self, data):
+ cam_name = "cam_high"
+ return self.image_cb(cam_name, data)
+
+ def image_cb_cam_low(self, data):
+ cam_name = "cam_low"
+ return self.image_cb(cam_name, data)
+
+ def image_cb_cam_left_wrist(self, data):
+ cam_name = "cam_left_wrist"
+ return self.image_cb(cam_name, data)
+
+ def image_cb_cam_right_wrist(self, data):
+ cam_name = "cam_right_wrist"
+ return self.image_cb(cam_name, data)
+
+ def get_images(self):
+ image_dict = {}
+ for cam_name in self.camera_names:
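+            # Busy-wait until a frame newer than the last one returned arrives for this camera.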
+ while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
+ time.sleep(0.00001)
+ rgb_image = getattr(self, f"{cam_name}_rgb_image")
+ depth_image = getattr(self, f"{cam_name}_depth_image")
+ self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
+ image_dict[cam_name] = rgb_image
+ image_dict[f"{cam_name}_depth"] = depth_image
+ return image_dict
+
+ def print_diagnostics(self):
+ def dt_helper(l):
+ l = np.array(l)
+ diff = l[1:] - l[:-1]
+ return np.mean(diff)
+
+ for cam_name in self.camera_names:
+ image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
+ print(f"{cam_name} {image_freq=:.2f}")
+ print()
+
+
+class Recorder:
+ def __init__(self, side, init_node=True, is_debug=False):
+ self.secs = None
+ self.nsecs = None
+ self.qpos = None
+ self.effort = None
+ self.arm_command = None
+ self.gripper_command = None
+ self.is_debug = is_debug
+
+ if init_node:
+ rospy.init_node("recorder", anonymous=True)
+ rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
+ rospy.Subscriber(
+ f"/puppet_{side}/commands/joint_group",
+ JointGroupCommand,
+ self.puppet_arm_commands_cb,
+ )
+ rospy.Subscriber(
+ f"/puppet_{side}/commands/joint_single",
+ JointSingleCommand,
+ self.puppet_gripper_commands_cb,
+ )
+ if self.is_debug:
+ self.joint_timestamps = deque(maxlen=50)
+ self.arm_command_timestamps = deque(maxlen=50)
+ self.gripper_command_timestamps = deque(maxlen=50)
+ time.sleep(0.1)
+
+ def puppet_state_cb(self, data):
+ self.qpos = data.position
+ self.qvel = data.velocity
+ self.effort = data.effort
+ self.data = data
+ if self.is_debug:
+ self.joint_timestamps.append(time.time())
+
+ def puppet_arm_commands_cb(self, data):
+ self.arm_command = data.cmd
+ if self.is_debug:
+ self.arm_command_timestamps.append(time.time())
+
+ def puppet_gripper_commands_cb(self, data):
+ self.gripper_command = data.cmd
+ if self.is_debug:
+ self.gripper_command_timestamps.append(time.time())
+
+ def print_diagnostics(self):
+ def dt_helper(l):
+ l = np.array(l)
+ diff = l[1:] - l[:-1]
+ return np.mean(diff)
+
+ joint_freq = 1 / dt_helper(self.joint_timestamps)
+ arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
+ gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
+
+ print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
+
+
+def get_arm_joint_positions(bot):
+ return bot.arm.core.joint_states.position[:6]
+
+
+def get_arm_gripper_positions(bot):
+ return bot.gripper.core.joint_states.position[6]
+
+
+def move_arms(bot_list, target_pose_list, move_time=1):
+ num_steps = int(move_time / constants.DT)
+ curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
+ traj_list = [
+ np.linspace(curr_pose, target_pose, num_steps)
+ for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
+ ]
+ for t in range(num_steps):
+ for bot_id, bot in enumerate(bot_list):
+ bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
+ time.sleep(constants.DT)
+
+
+def move_grippers(bot_list, target_pose_list, move_time):
+ print(f"Moving grippers to {target_pose_list=}")
+ gripper_command = JointSingleCommand(name="gripper")
+ num_steps = int(move_time / constants.DT)
+ curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
+ traj_list = [
+ np.linspace(curr_pose, target_pose, num_steps)
+ for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
+ ]
+
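+    # Log commanded vs. observed gripper positions to a JSONL file under /data for offline inspection.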
+ with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
+ for t in range(num_steps):
+ d = {}
+ for bot_id, bot in enumerate(bot_list):
+ gripper_command.cmd = traj_list[bot_id][t]
+ bot.gripper.core.pub_single.publish(gripper_command)
+ d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
+ f.write(json.dumps(d) + "\n")
+ time.sleep(constants.DT)
+
+
+def setup_puppet_bot(bot):
+ bot.dxl.robot_reboot_motors("single", "gripper", True)
+ bot.dxl.robot_set_operating_modes("group", "arm", "position")
+ bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
+ torque_on(bot)
+
+
+def setup_master_bot(bot):
+ bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
+ bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
+ torque_off(bot)
+
+
+def set_standard_pid_gains(bot):
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
+
+
+def set_low_pid_gains(bot):
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
+ bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
+
+
+def torque_off(bot):
+ bot.dxl.robot_torque_enable("group", "arm", False)
+ bot.dxl.robot_torque_enable("single", "gripper", False)
+
+
+def torque_on(bot):
+ bot.dxl.robot_torque_enable("group", "arm", True)
+ bot.dxl.robot_torque_enable("single", "gripper", True)
+
+
+# for DAgger
+def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
+ print("\nSyncing!")
+
+ # activate master arms
+ torque_on(master_bot_left)
+ torque_on(master_bot_right)
+
+ # get puppet arm positions
+ puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
+ puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
+
+ # get puppet gripper positions
+ puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
+ puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
+
+ # move master arms to puppet positions
+ move_arms(
+ [master_bot_left, master_bot_right],
+ [puppet_left_qpos, puppet_right_qpos],
+ move_time=1,
+ )
+
+ # move master grippers to puppet positions
+ move_grippers(
+ [master_bot_left, master_bot_right],
+ [puppet_left_gripper, puppet_right_gripper],
+ move_time=1,
+ )
diff --git a/openpi/examples/aloha_real/video_display.py b/openpi/examples/aloha_real/video_display.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad79ddd30965d842e82f2c6cf3b89fdc43bf844
--- /dev/null
+++ b/openpi/examples/aloha_real/video_display.py
@@ -0,0 +1,36 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from openpi_client.runtime import subscriber as _subscriber
+from typing_extensions import override
+
+
+class VideoDisplay(_subscriber.Subscriber):
+ """Displays video frames."""
+
+ def __init__(self) -> None:
+ self._ax: plt.Axes | None = None
+ self._plt_img: plt.Image | None = None
+
+ @override
+ def on_episode_start(self) -> None:
+ plt.ion()
+ self._ax = plt.subplot()
+ self._plt_img = None
+
+ @override
+ def on_step(self, observation: dict, action: dict) -> None:
+ assert self._ax is not None
+
+ im = observation["image"][0] # [C, H, W]
+ im = np.transpose(im, (1, 2, 0)) # [H, W, C]
+
+ if self._plt_img is None:
+ self._plt_img = self._ax.imshow(im)
+ else:
+ self._plt_img.set_data(im)
+ plt.pause(0.001)
+
+ @override
+ def on_episode_end(self) -> None:
+ plt.ioff()
+ plt.close()
diff --git a/openpi/examples/aloha_sim/Dockerfile b/openpi/examples/aloha_sim/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..1f18790a2abf42377d597162daada4fe459c6bb4
--- /dev/null
+++ b/openpi/examples/aloha_sim/Dockerfile
@@ -0,0 +1,41 @@
+# Dockerfile for the Aloha simulation environment.
+
+# Build the container:
+# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
+
+FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+RUN apt-get update && \
+ apt-get install -y \
+ libosmesa6-dev \
+ libgl1-mesa-glx \
+ libglew-dev \
+ libglfw3-dev \
+ libgles2-mesa-dev
+ENV MUJOCO_GL=egl
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
+
+CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
\ No newline at end of file
diff --git a/openpi/examples/aloha_sim/README.md b/openpi/examples/aloha_sim/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c6d4c5bc80103c0d1fdb5f7387ae2a39836bfc9
--- /dev/null
+++ b/openpi/examples/aloha_sim/README.md
@@ -0,0 +1,36 @@
+# Run Aloha Sim
+
+## With Docker
+
+```bash
+export SERVER_ARGS="--env ALOHA_SIM"
+docker compose -f examples/aloha_sim/compose.yml up --build
+```
+
+## Without Docker
+
+Terminal window 1:
+
+```bash
+# Create virtual environment
+uv venv --python 3.10 examples/aloha_sim/.venv
+source examples/aloha_sim/.venv/bin/activate
+uv pip sync examples/aloha_sim/requirements.txt
+uv pip install -e packages/openpi-client
+
+# Run the simulation
+MUJOCO_GL=egl python examples/aloha_sim/main.py
+```
+
+Note: If you are seeing EGL errors, you may need to install the following dependencies:
+
+```bash
+sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
+```
+
+Terminal window 2:
+
+```bash
+# Run the server
+uv run scripts/serve_policy.py --env ALOHA_SIM
+```
diff --git a/openpi/examples/aloha_sim/compose.yml b/openpi/examples/aloha_sim/compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c56e4dea137e0bbb84d68745047997932080b27d
--- /dev/null
+++ b/openpi/examples/aloha_sim/compose.yml
@@ -0,0 +1,42 @@
+# Run with:
+# docker compose -f examples/aloha_sim/compose.yml up --build
+services:
+ runtime:
+ image: aloha_sim
+ depends_on:
+ - openpi_server
+ build:
+ context: ../..
+ dockerfile: examples/aloha_sim/Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ privileged: true
+ volumes:
+ - $PWD:/app
+ - ../../data:/data
+
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/openpi/examples/aloha_sim/env.py b/openpi/examples/aloha_sim/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2d5b6ac38e4c5c1671dfbc9323dcc541e9cb2d
--- /dev/null
+++ b/openpi/examples/aloha_sim/env.py
@@ -0,0 +1,56 @@
+import gym_aloha # noqa: F401
+import gymnasium
+import numpy as np
+from openpi_client import image_tools
+from openpi_client.runtime import environment as _environment
+from typing_extensions import override
+
+
+class AlohaSimEnvironment(_environment.Environment):
+ """An environment for an Aloha robot in simulation."""
+
+ def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
+ np.random.seed(seed)
+ self._rng = np.random.default_rng(seed)
+
+ self._gym = gymnasium.make(task, obs_type=obs_type)
+
+ self._last_obs = None
+ self._done = True
+ self._episode_reward = 0.0
+
+ @override
+ def reset(self) -> None:
+ gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
+ self._last_obs = self._convert_observation(gym_obs) # type: ignore
+ self._done = False
+ self._episode_reward = 0.0
+
+ @override
+ def is_episode_complete(self) -> bool:
+ return self._done
+
+ @override
+ def get_observation(self) -> dict:
+ if self._last_obs is None:
+ raise RuntimeError("Observation is not set. Call reset() first.")
+
+ return self._last_obs # type: ignore
+
+ @override
+ def apply_action(self, action: dict) -> None:
+ gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
+ self._last_obs = self._convert_observation(gym_obs) # type: ignore
+ self._done = terminated or truncated
+ self._episode_reward = max(self._episode_reward, reward)
+
+ def _convert_observation(self, gym_obs: dict) -> dict:
+ img = gym_obs["pixels"]["top"]
+ img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
+ # Convert axis order from [H, W, C] --> [C, H, W]
+ img = np.transpose(img, (2, 0, 1))
+
+ return {
+ "state": gym_obs["agent_pos"],
+ "images": {"cam_high": img},
+ }
diff --git a/openpi/examples/aloha_sim/main.py b/openpi/examples/aloha_sim/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d76122ab35e98ecf8e3e9da7080f14c655435115
--- /dev/null
+++ b/openpi/examples/aloha_sim/main.py
@@ -0,0 +1,55 @@
+import dataclasses
+import logging
+import pathlib
+
+import env as _env
+from openpi_client import action_chunk_broker
+from openpi_client import websocket_client_policy as _websocket_client_policy
+from openpi_client.runtime import runtime as _runtime
+from openpi_client.runtime.agents import policy_agent as _policy_agent
+import saver as _saver
+import tyro
+
+
+@dataclasses.dataclass
+class Args:
+ out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
+
+ task: str = "gym_aloha/AlohaTransferCube-v0"
+ seed: int = 0
+
+ action_horizon: int = 10
+
+ host: str = "0.0.0.0"
+ port: int = 8000
+
+ display: bool = False
+
+
+def main(args: Args) -> None:
+ runtime = _runtime.Runtime(
+ environment=_env.AlohaSimEnvironment(
+ task=args.task,
+ seed=args.seed,
+ ),
+ agent=_policy_agent.PolicyAgent(
+ policy=action_chunk_broker.ActionChunkBroker(
+ policy=_websocket_client_policy.WebsocketClientPolicy(
+ host=args.host,
+ port=args.port,
+ ),
+ action_horizon=args.action_horizon,
+ )
+ ),
+ subscribers=[
+ _saver.VideoSaver(args.out_dir),
+ ],
+ max_hz=50,
+ )
+
+ runtime.run()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO, force=True)
+ tyro.cli(main)
diff --git a/openpi/examples/aloha_sim/requirements.in b/openpi/examples/aloha_sim/requirements.in
new file mode 100644
index 0000000000000000000000000000000000000000..d84d356eff374ac084b1e1065ee6d66017897ec6
--- /dev/null
+++ b/openpi/examples/aloha_sim/requirements.in
@@ -0,0 +1,8 @@
+gym-aloha
+imageio
+matplotlib
+msgpack
+numpy>=1.22.4,<2.0.0
+typing-extensions
+tyro
+websockets
\ No newline at end of file
diff --git a/openpi/examples/aloha_sim/requirements.txt b/openpi/examples/aloha_sim/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a1087f1b3b9cd58264eb7f46f0e02eafe3380222
--- /dev/null
+++ b/openpi/examples/aloha_sim/requirements.txt
@@ -0,0 +1,132 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
+absl-py==2.1.0
+ # via
+ # dm-control
+ # dm-env
+ # labmaze
+ # mujoco
+certifi==2024.8.30
+ # via requests
+charset-normalizer==3.4.0
+ # via requests
+cloudpickle==3.1.0
+ # via gymnasium
+contourpy==1.3.1
+ # via matplotlib
+cycler==0.12.1
+ # via matplotlib
+dm-control==1.0.14
+ # via gym-aloha
+dm-env==1.6
+ # via dm-control
+dm-tree==0.1.8
+ # via
+ # dm-control
+ # dm-env
+docstring-parser==0.16
+ # via tyro
+farama-notifications==0.0.4
+ # via gymnasium
+fonttools==4.55.2
+ # via matplotlib
+glfw==2.8.0
+ # via
+ # dm-control
+ # mujoco
+gym-aloha==0.1.1
+ # via -r examples/aloha_sim/requirements.in
+gymnasium==1.0.0
+ # via gym-aloha
+idna==3.10
+ # via requests
+imageio==2.36.1
+ # via
+ # -r examples/aloha_sim/requirements.in
+ # gym-aloha
+imageio-ffmpeg==0.5.1
+ # via imageio
+kiwisolver==1.4.7
+ # via matplotlib
+labmaze==1.0.6
+ # via dm-control
+lxml==5.3.0
+ # via dm-control
+markdown-it-py==3.0.0
+ # via rich
+matplotlib==3.9.3
+ # via -r examples/aloha_sim/requirements.in
+mdurl==0.1.2
+ # via markdown-it-py
+msgpack==1.1.0
+ # via -r examples/aloha_sim/requirements.in
+mujoco==2.3.7
+ # via
+ # dm-control
+ # gym-aloha
+numpy==1.26.4
+ # via
+ # -r examples/aloha_sim/requirements.in
+ # contourpy
+ # dm-control
+ # dm-env
+ # gymnasium
+ # imageio
+ # labmaze
+ # matplotlib
+ # mujoco
+ # scipy
+packaging==24.2
+ # via matplotlib
+pillow==11.0.0
+ # via
+ # imageio
+ # matplotlib
+protobuf==5.29.1
+ # via dm-control
+psutil==6.1.0
+ # via imageio
+pygments==2.18.0
+ # via rich
+pyopengl==3.1.7
+ # via
+ # dm-control
+ # mujoco
+pyparsing==3.2.0
+ # via
+ # dm-control
+ # matplotlib
+python-dateutil==2.9.0.post0
+ # via matplotlib
+requests==2.32.3
+ # via dm-control
+rich==13.9.4
+ # via tyro
+scipy==1.14.1
+ # via dm-control
+setuptools==75.6.0
+ # via
+ # dm-control
+ # imageio-ffmpeg
+ # labmaze
+shtab==1.7.1
+ # via tyro
+six==1.17.0
+ # via python-dateutil
+tqdm==4.67.1
+ # via dm-control
+typeguard==4.4.1
+ # via tyro
+typing-extensions==4.12.2
+ # via
+ # -r examples/aloha_sim/requirements.in
+ # gymnasium
+ # rich
+ # typeguard
+ # tyro
+tyro==0.9.2
+ # via -r examples/aloha_sim/requirements.in
+urllib3==2.2.3
+ # via requests
+websockets==14.1
+ # via -r examples/aloha_sim/requirements.in
diff --git a/openpi/examples/aloha_sim/saver.py b/openpi/examples/aloha_sim/saver.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7f2c5d3365c9bdf32a44ba4cdec9bfe195366f
--- /dev/null
+++ b/openpi/examples/aloha_sim/saver.py
@@ -0,0 +1,40 @@
+import logging
+import pathlib
+
+import imageio
+import numpy as np
+from openpi_client.runtime import subscriber as _subscriber
+from typing_extensions import override
+
+
+class VideoSaver(_subscriber.Subscriber):
+ """Saves episode data."""
+
+ def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
+ out_dir.mkdir(parents=True, exist_ok=True)
+ self._out_dir = out_dir
+ self._images: list[np.ndarray] = []
+ self._subsample = subsample
+
+ @override
+ def on_episode_start(self) -> None:
+ self._images = []
+
+ @override
+ def on_step(self, observation: dict, action: dict) -> None:
+ im = observation["images"]["cam_high"] # [C, H, W]
+ im = np.transpose(im, (1, 2, 0)) # [H, W, C]
+ self._images.append(im)
+
+ @override
+ def on_episode_end(self) -> None:
+ existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
+ next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
+ out_path = self._out_dir / f"out_{next_idx}.mp4"
+
+ logging.info(f"Saving video to {out_path}")
+ imageio.mimwrite(
+ out_path,
+ [np.asarray(x) for x in self._images[:: self._subsample]],
+ fps=50 // max(1, self._subsample),
+ )
diff --git a/openpi/examples/convert_jax_model_to_pytorch.py b/openpi/examples/convert_jax_model_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..632c0b8782c1ecb5cb380130a30a3152b220eafd
--- /dev/null
+++ b/openpi/examples/convert_jax_model_to_pytorch.py
@@ -0,0 +1,587 @@
+#!/usr/bin/env python3
+"""
+Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
+
+This script loads a JAX model checkpoint using orbax and can either:
+1. Print out all the parameter keys in a hierarchical structure for inspection
+2. Convert the JAX model to PyTorch format using our PI0Pytorch model
+
+Usage:
+ # Just inspect keys:
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
+
+ # Convert to PyTorch:
+    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
+
+Example:
+ # pi0_droid
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
+
+ # pi0_aloha_sim
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
+
+ # pi05_droid
+ python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
+"""
+
+import json
+import os
+import pathlib
+import shutil
+from typing import Literal
+
+from flax.nnx import traversals
+import numpy as np
+import orbax.checkpoint as ocp
+import safetensors.torch
+import torch
+import tyro
+
+import openpi.models.gemma
+import openpi.models.model
+import openpi.models.pi0_config
+import openpi.models_pytorch.pi0_pytorch
+from openpi.training import utils
+import openpi.training.config as _config
+
+
+def slice_paligemma_state_dict(state_dict, config):
+ """Convert PaliGemma JAX parameters to PyTorch format."""
+ suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
+
+ # patch embeddings
+ jax_key = f"img/embedding/kernel{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
+
+ jax_key = f"img/embedding/bias{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # positional embeddings
+ jax_key = f"img/pos_embedding{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
+ state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
+
+ # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
+ encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
+ encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
+ encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
+ encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
+
+ encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
+ encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
+ encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
+ encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
+
+ encoderblock_attention_0_key_kernel = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
+ )
+ encoderblock_attention_0_key_bias = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
+ )
+ encoderblock_attention_0_value_kernel = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
+ )
+ encoderblock_attention_0_value_bias = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
+ )
+ encoderblock_attention_0_query_kernel = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
+ )
+ encoderblock_attention_0_query_bias = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
+ )
+ encoderblock_attention_0_out_kernel = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
+ )
+ encoderblock_attention_0_out_bias = state_dict.pop(
+ f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
+ )
+
+ for i in range(config.vision_config.num_hidden_layers):
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
+ ] = encoderblock_layernorm0_scale[i].transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
+ ] = encoderblock_layernorm0_bias[i]
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
+ ] = encoderblock_layernorm1_scale[i].transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
+ ] = encoderblock_layernorm1_bias[i]
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
+ ] = encoderblock_mlp_dense0_kernel[i].transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
+ ] = encoderblock_mlp_dense0_bias[i]
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
+ ] = encoderblock_mlp_dense1_kernel[i].transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
+ ] = encoderblock_mlp_dense1_bias[i]
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
+ ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
+ ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
+ ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
+ ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
+ ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
+ ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
+ ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
+ ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
+
+ jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
+
+ jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # multimodal projector
+ jax_key = f"img/head/kernel{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
+ state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
+
+ jax_key = f"img/head/bias{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # text decoder (gemma)
+ jax_key = f"llm/embedder/input_embedding{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ # pop the einsum attention + mlp representations
+ llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
+ llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
+ llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
+
+ llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
+ llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
+
+ llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
+ llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
+
+ for i in range(config.text_config.num_hidden_layers):
+ q_proj_weight_reshaped = (
+ llm_attention_q_einsum[i]
+ .transpose(0, 2, 1)
+ .reshape(
+ config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
+ )
+ )
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
+ q_proj_weight_reshaped
+ )
+
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
+ k_proj_weight_reshaped
+ )
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
+ v_proj_weight_reshaped
+ )
+
+ o_proj_weight_reshaped = (
+ llm_attention_attn_vec_einsum[i]
+ .transpose(2, 0, 1)
+ .reshape(
+ config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
+ )
+ )
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
+ o_proj_weight_reshaped
+ )
+
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
+ gate_proj_weight.transpose()
+ )
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
+ up_proj_weight.transpose()
+ )
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
+ llm_mlp_linear[i].transpose()
+ )
+ state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
+ llm_input_layernorm[i]
+ )
+ state_dict[
+ f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
+ ] = llm_post_attention_layernorm[i]
+
+ jax_key = f"llm/final_norm/scale{suffix}"
+ pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
+ state_dict[pytorch_key] = state_dict.pop(jax_key)
+
+ expert_dict = {}
+ final_state_dict = {}
+
+ # Expert-related keys to extract (including pi05 Dense layer parameters)
+ expert_keys = [
+ f"llm/final_norm_1/scale{suffix}",
+ f"llm/final_norm_1/Dense_0/bias{suffix}",
+ f"llm/final_norm_1/Dense_0/kernel{suffix}",
+ f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
+ f"llm/layers/attn/kv_einsum_1/w{suffix}",
+ f"llm/layers/attn/q_einsum_1/w{suffix}",
+ f"llm/layers/mlp_1/gating_einsum{suffix}",
+ f"llm/layers/mlp_1/linear{suffix}",
+ f"llm/layers/pre_attention_norm_1/scale{suffix}",
+ f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
+ f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
+ f"llm/layers/pre_ffw_norm_1/scale{suffix}",
+ f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
+ f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
+ ]
+
+ for key, value in state_dict.items():
+ if key not in expert_keys:
+ final_state_dict[key] = torch.from_numpy(value)
+ else:
+ expert_dict[key] = value
+
+ return final_state_dict, expert_dict
+
+
+def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
+ """Convert Gemma JAX parameters to PyTorch format."""
+ # Add missing attributes to config if they don't exist
+ if not hasattr(config, "vocab_size"):
+ config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
+ if not hasattr(config, "hidden_size"):
+ config.hidden_size = config.width
+ if not hasattr(config, "num_hidden_layers"):
+ config.num_hidden_layers = config.depth
+ if not hasattr(config, "num_attention_heads"):
+ config.num_attention_heads = config.num_heads
+
+ suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
+
+ llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
+ llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
+ llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
+
+ llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
+ llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
+
+ # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
+ if "pi05" in checkpoint_dir:
+ # Pi05 with adaptive normalization
+ llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
+ llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
+ llm_input_layernorm_kernel = state_dict.pop(
+ f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
+ )
+ llm_post_attention_layernorm_kernel = state_dict.pop(
+ f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
+ )
+ else:
+ # Regular pi0 with standard RMSNorm
+ llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
+ llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
+
+ for i in range(config.num_hidden_layers):
+ q_proj_weight_reshaped = (
+ llm_attention_q_einsum[i]
+ .transpose(0, 2, 1)
+ .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
+ q_proj_weight_reshaped
+ )
+
+ k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
+ k_proj_weight_reshaped
+ )
+ v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
+ v_proj_weight_reshaped
+ )
+
+ o_proj_weight_reshaped = (
+ llm_attention_attn_vec_einsum[i]
+ .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
+ .transpose(1, 0)
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
+ o_proj_weight_reshaped
+ )
+
+ gate_proj_weight = llm_mlp_gating_einsum[i, 0]
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
+ gate_proj_weight.transpose()
+ )
+ up_proj_weight = llm_mlp_gating_einsum[i, 1]
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
+ up_proj_weight.transpose()
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
+ i
+ ].transpose()
+
+ if "pi05" in checkpoint_dir:
+ # Pi05 with adaptive normalization - use Dense layer parameters directly
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
+ llm_input_layernorm_bias[i]
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
+ llm_post_attention_layernorm_bias[i]
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
+ llm_input_layernorm_kernel[i].transpose()
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
+ llm_post_attention_layernorm_kernel[i].transpose()
+ )
+ else:
+ # Regular pi0 with standard RMSNorm
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
+ llm_input_layernorm[i]
+ )
+ state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
+ llm_post_attention_layernorm[i]
+ )
+
+ # Handle final norm layer
+ if "pi05" in checkpoint_dir:
+ # Pi05 with adaptive normalization - use Dense layer parameters directly
+ final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
+ final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
+ else:
+ # Regular pi0 with standard RMSNorm
+ state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
+ f"llm/final_norm_{num_expert}/scale{suffix}"
+ )
+
+ # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
+
+ final_state_dict = {}
+ for key, value in state_dict.items():
+ if not isinstance(value, torch.Tensor):
+ final_state_dict[key] = torch.from_numpy(value)
+ else:
+ final_state_dict[key] = value
+
+ return final_state_dict
+
+
+def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
+ """Load and process params by restoring via JAX model loader first.
+ This respects dtype conversions that occur during model restore.
+ """
+ # Use repository restore utility to load a pure dict of params (value suffix removed)
+ params = openpi.models.model.restore_params(
+ f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
+ )
+
+ return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
+
+
+def load_jax_model_and_print_keys(checkpoint_dir: str):
+ """
+ Load JAX model from checkpoint and print all parameter keys.
+
+ Args:
+ checkpoint_dir: Path to the checkpoint directory
+ """
+ checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
+ # Initialize checkpointer
+ checkpointer = ocp.PyTreeCheckpointer()
+ metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
+ print(utils.array_tree_to_info(metadata))
+
+
+def convert_pi0_checkpoint(
+ checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
+):
+ """
+ Convert PI0 JAX checkpoint to PyTorch format.
+
+ Args:
+ checkpoint_dir: Path to the JAX checkpoint
+ precision: Model precision (float32, bfloat16, float16)
+ output_path: Path to save the converted PyTorch model
+ model_config: Model config
+ """
+ print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
+ print(f"Model config: {model_config}")
+
+ # Break down orbax ckpts by restoring via JAX to respect dtype
+ initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
+
+ # Process projection params
+ if model_config.pi05:
+ keys = [
+ "action_in_proj",
+ "action_out_proj",
+ "time_mlp_in",
+ "time_mlp_out",
+ ]
+ else:
+ keys = [
+ "state_proj",
+ "action_in_proj",
+ "action_out_proj",
+ "action_time_mlp_in",
+ "action_time_mlp_out",
+ ]
+
+ projection_params = {}
+ for key in keys:
+ kernel_params = initial_params["projection_params"][key]["kernel"]
+ bias_params = initial_params["projection_params"][key]["bias"]
+ if isinstance(kernel_params, dict):
+ weight = kernel_params["value"]
+ bias = bias_params["value"]
+ else:
+ weight = kernel_params
+ bias = bias_params
+
+ pytorch_weight_key = f"{key}.weight"
+ pytorch_bias_key = f"{key}.bias"
+
+ projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
+ projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
+
+ # Create configs based on checkpoint path
+ # All models use the same PaliGemma config structure
+ class PaliGemmaConfig:
+ def __init__(self):
+ self.vision_config = type(
+ "obj",
+ (object,),
+ {
+ "hidden_size": 1152,
+ "num_hidden_layers": 27,
+ "num_attention_heads": 16,
+ "intermediate_size": 4304,
+ "patch_size": 14,
+ "projection_dim": 2048,
+ },
+ )()
+ self.text_config = type(
+ "obj",
+ (object,),
+ {
+ "hidden_size": 2048,
+ "num_hidden_layers": 18,
+ "num_attention_heads": 8,
+ "head_dim": 256,
+ "intermediate_size": 16384,
+ },
+ )()
+
+ paligemma_config = PaliGemmaConfig()
+ action_expert_config = openpi.models.gemma.get_config("gemma_300m")
+
+ # Process PaliGemma weights
+ paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
+
+ # Process Gemma weights from expert_params
+ gemma_params = slice_gemma_state_dict(
+ expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
+ )
+
+ # Instantiate model
+ pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
+
+ # Combine all parameters (no prefix needed for our model structure)
+ all_params = {**paligemma_params, **gemma_params, **projection_params}
+
+ # Load state dict
+ pi0_model.load_state_dict(all_params, strict=False)
+
+ if precision == "float32":
+ pi0_model = pi0_model.to(torch.float32)
+ elif precision == "bfloat16":
+ pi0_model = pi0_model.to(torch.bfloat16)
+ else:
+ raise ValueError(f"Invalid precision: {precision}")
+
+ # Save the converted model using safetensors
+ os.makedirs(output_path, exist_ok=True)
+
+ # Save model weights as SafeTensors using save_model to handle tied weights
+ safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
+
+ # Copy assets folder if it exists
+ assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
+ if assets_source.exists():
+ assets_dest = pathlib.Path(output_path) / "assets"
+ if assets_dest.exists():
+ shutil.rmtree(assets_dest)
+ shutil.copytree(assets_source, assets_dest)
+
+ # Save config as JSON for reference
+ config_dict = {
+ "action_dim": model_config.action_dim,
+ "action_horizon": model_config.action_horizon,
+ "paligemma_variant": model_config.paligemma_variant,
+ "action_expert_variant": model_config.action_expert_variant,
+ "precision": precision,
+ }
+ with open(os.path.join(output_path, "config.json"), "w") as f:
+ json.dump(config_dict, f, indent=2)
+
+ print("Model conversion completed successfully!")
+ print(f"Model saved to {output_path}")
+
+
+def main(
+ checkpoint_dir: str,
+ config_name: str,
+ output_path: str | None = None,
+ precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
+ *,
+ inspect_only: bool = False,
+):
+ """Load JAX model and optionally convert to PyTorch.
+
+ Args:
+        checkpoint_dir: Path to the JAX checkpoint directory
+        config_name: Name of the openpi training config whose model definition matches the checkpoint
+        output_path: Path to save converted PyTorch model (required for conversion)
+ precision: Precision for model conversion
+ inspect_only: Only inspect parameter keys, don't convert
+ """
+ model_config = _config.get_config(config_name).model
+ if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
+ raise ValueError(f"Config {config_name} is not a Pi0Config")
+ if inspect_only:
+ load_jax_model_and_print_keys(checkpoint_dir)
+ else:
+ if not output_path:
+ print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
+ return
+ convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/openpi/examples/droid/README.md b/openpi/examples/droid/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6a7edeee7d75eb0d45c87c6c8ef31c46da6f595d
--- /dev/null
+++ b/openpi/examples/droid/README.md
@@ -0,0 +1,84 @@
+# DROID Policies in openpi
+
+We offer instructions for:
+- [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
+- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-other-policies)
+- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
+- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
+
+## Running DROID Inference
+
+This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
+
+
+### Step 1: Start a policy server
+
+Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
+
+1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
+2. Start the OpenPI server via the following command:
+
+```bash
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
+```
+
+You can also run the equivalent command below:
+
+```bash
+uv run scripts/serve_policy.py --env=DROID
+```
+
+### Step 2: Run the DROID robot
+
+1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
+2. On the control laptop, activate your DROID conda environment.
+3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
+4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
+5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
+6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
+7. Run the `main.py` file. Make sure to point the remote host and port arguments to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <policy_server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera); choose from `["left", "right"]`.
+
+```bash
+python3 scripts/main.py --remote_host=<policy_server_ip> --remote_port=<policy_server_port> --external_camera="left"
+```
+
+The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
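+
+Under the hood, `main.py` talks to the policy server through the same `openpi_client` interface used by the aloha_sim example in this repo. The snippet below is a minimal sketch of that pattern, not the actual DROID client: it assumes the `infer()` call exposed by the client policies, and the host value and observation keys are placeholders (the real keys and shapes are whatever `main.py` sends to the server).
+
+```python
+import numpy as np
+from openpi_client import action_chunk_broker
+from openpi_client import websocket_client_policy
+
+# Connect to the policy server (replace "localhost" with the policy server's IP).
+policy = action_chunk_broker.ActionChunkBroker(
+    policy=websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000),
+    action_horizon=8,  # execute the first 8 actions of each predicted chunk
+)
+
+# Placeholder observation -- the real key names and shapes come from main.py.
+obs = {"example_image": np.zeros((224, 224, 3), dtype=np.uint8)}
+action = policy.infer(obs)  # queries the server and returns a dict with the next action
+```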
+
+## Troubleshooting
+
+| Issue | Solution |
+|-------|----------|
+| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <policy_server_ip>` from the DROID laptop. |
+| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. |
+| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
+| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, so make sure you are feeding the desired camera to the policy); use `ZED_Explorer` to verify this. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
+
+
+## Running Other Policies
+
+We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
+
+```
+# Trained from pi0-FAST, using FAST tokenizer
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
+
+# Trained from pi0, using flow matching
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
+
+# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
+
+# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
+
+# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
+
+# Trained from PaliGemma, using FSQ tokenizer.
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
+
+# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
+uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
+```
+
+You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
diff --git a/openpi/examples/droid/README_train.md b/openpi/examples/droid/README_train.md
new file mode 100644
index 0000000000000000000000000000000000000000..c8ad660656986eec5997d9cb86f846c1a327d11a
--- /dev/null
+++ b/openpi/examples/droid/README_train.md
@@ -0,0 +1,106 @@
+# Training on DROID
+
+Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline (with small differences in data loading and the action space used). For a tutorial on how to fine-tune a model on a smaller, custom dataset collected on the DROID platform, see below.
+
+In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough
+for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.
+
+## Install
+
+We need a few additional dependencies for RLDS data loading. Run:
+```bash
+uv sync --group rlds
+```
+
+## Download DROID dataset
+
+You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
+```
+gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_dir>/droid/1.0.1
+```
+
+Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
+
+You will need 1.8TB of disk storage to download the DROID RLDS dataset.
+
+## Run
+
+First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
+
+Then, compute normalization statistics (this will take ~10 minutes):
+```bash
+uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
+```
+
+Run training:
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
+```
+
+**Note**: The original pi0.5-DROID model was trained with joint velocity actions.
+Joint velocity actions are not compatible with simulated evaluation environments (velocity control is much harder to simulate),
+so we do not recommend training with joint velocity actions and instead use joint position actions here.
+
+
+## Compute Requirements
+
+Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, batch size 256, approx. 1 epoch).
+If you start from PaliGemma instead of pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
+
+We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
+
+
+## Data Filtering
+
+Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean", and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part an artifact of the VR teleoperation interface used during data collection; we will not go into detail here). Appropriate filtering of these idle transitions can improve policy performance.
+
+By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path=""` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
+
+**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
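+
+To make the filter format concrete, here is a minimal sketch of reading such a pre-computed index file (as produced by [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py)) and expanding one episode's kept ranges into sampleable step indices. This is not the actual loader in `droid_rlds_dataset.py`, whose plumbing may differ, and the file path is a placeholder:
+
+```python
+import json
+from pathlib import Path
+
+import numpy as np
+
+# Hypothetical local copy of the pre-computed filter file.
+filter_dict_path = Path("droid_keep_ranges.json")
+
+with filter_dict_path.open() as f:
+    # Maps "<recording_folderpath>--<file_path>" -> list of [start, end) ranges to keep.
+    keep_ranges_map = json.load(f)
+
+
+def sampleable_indices(episode_key: str) -> np.ndarray:
+    """Expand the kept ranges of one episode into the step indices allowed during training."""
+    ranges = keep_ranges_map.get(episode_key, [])
+    if not ranges:
+        return np.empty(0, dtype=int)
+    return np.concatenate([np.arange(start, end) for start, end in ranges])
+```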
+
+## RoboArena
+
+Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
+
+If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
+
+
+# Fine-Tuning on Custom DROID Datasets
+
+Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. As with other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.
+
+Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for its better efficiency (see the example above).
+
+
+## Step 1: Converting your custom DROID dataset to LeRobot
+
+We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
+```
+gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_data_dir>
+```
+
+We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
+```
+gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_data_dir>
+```
+
+For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
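+
+As a quick optional sanity check (not part of openpi; it assumes the directory layout described above), you can list episodes that still need the SVO-to-MP4 extraction:
+
+```python
+import pathlib
+
+data_dir = pathlib.Path("data/droid_raw")  # adjust to wherever your raw episodes live
+
+for traj in sorted(data_dir.glob("**/trajectory.h5")):
+    mp4_dir = traj.parent / "recordings" / "MP4"
+    if not mp4_dir.is_dir() or not any(mp4_dir.glob("*.mp4")):
+        print(f"Missing MP4 recordings for episode: {traj.parent}")
+```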
+
+Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5 min for the 30 demonstrations):
+```
+uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_data_dir>
+```
+
+## Step 2: Run fine-tuning with your custom dataset
+
+Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
+You can easily modify the config to work with other base models, or point it at your own custom DROID dataset in `config.py` (search for `pi05_droid_finetune`).
+
+To launch training:
+```
+uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
+```
+
+Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
+
diff --git a/openpi/examples/droid/compute_droid_nonidle_ranges.py b/openpi/examples/droid/compute_droid_nonidle_ranges.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd73ea52a29bd0148ffd6af996e9ee6470a66084
--- /dev/null
+++ b/openpi/examples/droid/compute_droid_nonidle_ranges.py
@@ -0,0 +1,103 @@
+"""
+Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
+that should be sampled during training (all others are filtered out).
+
+Filtering logic:
+We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
+(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
+this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
+ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
+filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).
+
+This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
+yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
+"""
+
+import json
+import os
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets as tfds
+from tqdm import tqdm
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
+
+builder = tfds.builder_from_directory(
+ # path to the `droid` directory (not its parent)
+ builder_dir="",
+)
+ds = builder.as_dataset(split="train", shuffle_files=False)
+ds = ds.apply(tf.data.experimental.ignore_errors())  # Skip episodes that fail to decode instead of raising.
+
+keep_ranges_path = ""
+
+min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
+min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
+filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
+
+keep_ranges_map = {}
+if Path(keep_ranges_path).exists():
+ with Path(keep_ranges_path).open("r") as f:
+ keep_ranges_map = json.load(f)
+ print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
+
+for ep_idx, ep in enumerate(tqdm(ds)):
+ recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
+ file_path = ep["episode_metadata"]["file_path"].numpy().decode()
+
+ key = f"{recording_folderpath}--{file_path}"
+ if key in keep_ranges_map:
+ continue
+
+ joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
+ joint_velocities = np.array(joint_velocities)
+
+ is_idle_array = np.hstack(
+ [np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
+ )
+
+ # Find what steps go from idle to non-idle and vice-versa
+ is_idle_padded = np.concatenate(
+ [[False], is_idle_array, [False]]
+ ) # Start and end with False, so idle at first step is a start of motion
+
+ is_idle_diff = np.diff(is_idle_padded.astype(int))
+ is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle
+ is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle
+
+ # Find which steps correspond to idle segments of length at least min_idle_len
+ true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
+ is_idle_true_starts = is_idle_true_starts[true_segment_masks]
+ is_idle_true_ends = is_idle_true_ends[true_segment_masks]
+
+ keep_mask = np.ones(len(joint_velocities), dtype=bool)
+ for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
+ keep_mask[start:end] = False
+
+ # Get all non-idle ranges of at least 16
+ # Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
+ keep_padded = np.concatenate([[False], keep_mask, [False]])
+
+ keep_diff = np.diff(keep_padded.astype(int))
+ keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
+ keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
+
+ # Find which steps correspond to non-idle segments of length at least min_non_idle_len
+ true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
+ keep_true_starts = keep_true_starts[true_segment_masks]
+ keep_true_ends = keep_true_ends[true_segment_masks]
+
+ # Add mapping from episode unique ID key to list of non-idle ranges to keep
+ keep_ranges_map[key] = []
+ for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
+ keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
+
+ if ep_idx % 1000 == 0:
+ with Path(keep_ranges_path).open("w") as f:
+ json.dump(keep_ranges_map, f)
+
+print("Done!")
+with Path(keep_ranges_path).open("w") as f:
+ json.dump(keep_ranges_map, f)
diff --git a/openpi/examples/droid/convert_droid_data_to_lerobot.py b/openpi/examples/droid/convert_droid_data_to_lerobot.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6078f5ca3cde13d9cf8c260c0eded0868ea28d5
--- /dev/null
+++ b/openpi/examples/droid/convert_droid_data_to_lerobot.py
@@ -0,0 +1,477 @@
+"""
+Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
+
+Usage:
+uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
+
+If you want to push your dataset to the Hugging Face Hub, you can use the following command:
+uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
+
+The resulting dataset will get saved to the $LEROBOT_HOME directory.
+"""
+
+from collections import defaultdict
+import copy
+import glob
+import json
+from pathlib import Path
+import shutil
+
+import cv2
+import h5py
+from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import tyro
+
+REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
+
+
+def resize_image(image, size):
+ image = Image.fromarray(image)
+ return np.array(image.resize(size, resample=Image.BICUBIC))
+
+
+def main(data_dir: str, *, push_to_hub: bool = False):
+ # Clean up any existing dataset in the output directory
+ output_path = HF_LEROBOT_HOME / REPO_NAME
+ if output_path.exists():
+ shutil.rmtree(output_path)
+ data_dir = Path(data_dir)
+
+ # Create LeRobot dataset, define features to store
+ # We will follow the DROID data naming conventions here.
+ # LeRobot assumes that dtype of image data is `image`
+ dataset = LeRobotDataset.create(
+ repo_id=REPO_NAME,
+ robot_type="panda",
+ fps=15, # DROID data is typically recorded at 15fps
+ features={
+ # We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
+ "exterior_image_1_left": {
+ "dtype": "image",
+ "shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
+ "names": ["height", "width", "channel"],
+ },
+ "exterior_image_2_left": {
+ "dtype": "image",
+ "shape": (180, 320, 3),
+ "names": ["height", "width", "channel"],
+ },
+ "wrist_image_left": {
+ "dtype": "image",
+ "shape": (180, 320, 3),
+ "names": ["height", "width", "channel"],
+ },
+ "joint_position": {
+ "dtype": "float32",
+ "shape": (7,),
+ "names": ["joint_position"],
+ },
+ "gripper_position": {
+ "dtype": "float32",
+ "shape": (1,),
+ "names": ["gripper_position"],
+ },
+ "actions": {
+ "dtype": "float32",
+ "shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
+ "names": ["actions"],
+ },
+ },
+ image_writer_threads=10,
+ image_writer_processes=5,
+ )
+
+ # Load language annotations
+ # Note: we load the DROID language annotations for this example, but you can manually define them for your own data
+ with (data_dir / "aggregated-annotations-030724.json").open() as f:
+ language_annotations = json.load(f)
+
+ # Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
+ # We assume the following directory structure:
+ # RAW_DROID_PATH/
+ # - <...>/
+ # - recordings/
+ # - MP4/
+ # - .mp4 # single-view video of left stereo pair camera
+    #     - trajectory.h5
+ # - <...>/
+ episode_paths = list(data_dir.glob("**/trajectory.h5"))
+ print(f"Found {len(episode_paths)} episodes for conversion")
+
+ # We will loop over each dataset_name and write episodes to the LeRobot dataset
+ for episode_path in tqdm(episode_paths, desc="Converting episodes"):
+ # Load raw data
+ recording_folderpath = episode_path.parent / "recordings" / "MP4"
+ trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
+
+ # To load the language instruction, we need to parse out the episode_id from the metadata file
+ # Again, you can modify this step for your own data, to load your own language instructions
+ metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
+ episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
+ language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
+ "language_instruction1"
+ ]
+ print(f"Converting episode with language instruction: {language_instruction}")
+
+ # Write to LeRobot dataset
+ for step in trajectory:
+ camera_type_dict = step["observation"]["camera_type"]
+ wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
+ exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
+ dataset.add_frame(
+ {
+ # Note: need to flip BGR --> RGB for loaded images
+ "exterior_image_1_left": resize_image(
+ step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
+ ),
+ "exterior_image_2_left": resize_image(
+ step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
+ ),
+ "wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
+ "joint_position": np.asarray(
+ step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
+ ),
+ "gripper_position": np.asarray(
+ step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
+ ),
+ # Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
+ "actions": np.concatenate(
+ [step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
+ ),
+ "task": language_instruction,
+ }
+ )
+ dataset.save_episode()
+
+ # Optionally push to the Hugging Face Hub
+ if push_to_hub:
+ dataset.push_to_hub(
+ tags=["libero", "panda", "rlds"],
+ private=False,
+ push_videos=True,
+ license="apache-2.0",
+ )
+
+
+##########################################################################################################
+################ The rest of this file are functions to parse the raw DROID data #########################
+################ You don't need to worry about understanding this part #########################
+################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
+##########################################################################################################
+
+
+camera_type_dict = {
+ "hand_camera_id": 0,
+ "varied_camera_1_id": 1,
+ "varied_camera_2_id": 1,
+}
+
+camera_type_to_string_dict = {
+ 0: "hand_camera",
+ 1: "varied_camera",
+ 2: "fixed_camera",
+}
+
+
+def get_camera_type(cam_id):
+ if cam_id not in camera_type_dict:
+ return None
+ type_int = camera_type_dict[cam_id]
+ return camera_type_to_string_dict[type_int]
+
+
+class MP4Reader:
+ def __init__(self, filepath, serial_number):
+ # Save Parameters #
+ self.serial_number = serial_number
+ self._index = 0
+
+ # Open Video Reader #
+ self._mp4_reader = cv2.VideoCapture(filepath)
+ if not self._mp4_reader.isOpened():
+ raise RuntimeError("Corrupted MP4 File")
+
+ def set_reading_parameters(
+ self,
+ image=True, # noqa: FBT002
+ concatenate_images=False, # noqa: FBT002
+ resolution=(0, 0),
+ resize_func=None,
+ ):
+ # Save Parameters #
+ self.image = image
+ self.concatenate_images = concatenate_images
+ self.resolution = resolution
+ self.resize_func = cv2.resize
+ self.skip_reading = not image
+ if self.skip_reading:
+ return
+
+ def get_frame_resolution(self):
+        width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
+        height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
+ return (width, height)
+
+ def get_frame_count(self):
+ if self.skip_reading:
+ return 0
+        return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ def set_frame_index(self, index):
+ if self.skip_reading:
+ return
+
+ if index < self._index:
+ self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
+ self._index = index
+
+ while self._index < index:
+ self.read_camera(ignore_data=True)
+
+ def _process_frame(self, frame):
+ frame = copy.deepcopy(frame)
+ if self.resolution == (0, 0):
+ return frame
+ return self.resize_func(frame, self.resolution)
+
+ def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
+        # Skip if Read Unnecessary #
+ if self.skip_reading:
+ return {}
+
+ # Read Camera #
+ success, frame = self._mp4_reader.read()
+
+ self._index += 1
+ if not success:
+ return None
+ if ignore_data:
+ return None
+
+ # Return Data #
+ data_dict = {}
+
+ if self.concatenate_images or "stereo" not in self.serial_number:
+ data_dict["image"] = {self.serial_number: self._process_frame(frame)}
+ else:
+ single_width = frame.shape[1] // 2
+ data_dict["image"] = {
+ self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
+ self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
+ }
+
+ return data_dict
+
+ def disable_camera(self):
+ if hasattr(self, "_mp4_reader"):
+ self._mp4_reader.release()
+
+
+class RecordedMultiCameraWrapper:
+ def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006
+ # Save Camera Info #
+ self.camera_kwargs = camera_kwargs
+
+ # Open Camera Readers #
+ mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
+ all_filepaths = mp4_filepaths
+
+ self.camera_dict = {}
+ for f in all_filepaths:
+ serial_number = f.split("/")[-1][:-4]
+ cam_type = get_camera_type(serial_number)
+ camera_kwargs.get(cam_type, {})
+
+ if f.endswith(".mp4"):
+ Reader = MP4Reader # noqa: N806
+ else:
+ raise ValueError
+
+ self.camera_dict[serial_number] = Reader(f, serial_number)
+
+ def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006
+ full_obs_dict = defaultdict(dict)
+
+ # Read Cameras In Randomized Order #
+ all_cam_ids = list(self.camera_dict.keys())
+ # random.shuffle(all_cam_ids)
+
+ for cam_id in all_cam_ids:
+ if "stereo" in cam_id:
+ continue
+ try:
+ cam_type = camera_type_dict[cam_id]
+ except KeyError:
+ print(f"{self.camera_dict} -- {camera_type_dict}")
+ raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904
+ curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
+ self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)
+
+ timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
+ if index is not None:
+ self.camera_dict[cam_id].set_frame_index(index)
+
+ data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)
+
+ # Process Returned Data #
+ if data_dict is None:
+ return None
+ for key in data_dict:
+ full_obs_dict[key].update(data_dict[key])
+
+ return full_obs_dict
+
+
+def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006
+ length = None
+
+ for key in hdf5_file:
+ if key in keys_to_ignore:
+ continue
+
+ curr_data = hdf5_file[key]
+ if isinstance(curr_data, h5py.Group):
+ curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
+ elif isinstance(curr_data, h5py.Dataset):
+ curr_length = len(curr_data)
+ else:
+ raise ValueError
+
+ if length is None:
+ length = curr_length
+ assert curr_length == length
+
+ return length
+
+
+def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006
+ data_dict = {}
+
+ for key in hdf5_file:
+ if key in keys_to_ignore:
+ continue
+
+ curr_data = hdf5_file[key]
+ if isinstance(curr_data, h5py.Group):
+ data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
+ elif isinstance(curr_data, h5py.Dataset):
+ data_dict[key] = curr_data[index]
+ else:
+ raise ValueError
+
+ return data_dict
+
+
+class TrajectoryReader:
+ def __init__(self, filepath, read_images=True): # noqa: FBT002
+ self._hdf5_file = h5py.File(filepath, "r")
+ is_video_folder = "observations/videos" in self._hdf5_file
+ self._read_images = read_images and is_video_folder
+ self._length = get_hdf5_length(self._hdf5_file)
+ self._video_readers = {}
+ self._index = 0
+
+ def length(self):
+ return self._length
+
+ def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006
+ # Make Sure We Read Within Range #
+ if index is None:
+ index = self._index
+ else:
+ assert not self._read_images
+ self._index = index
+ assert index < self._length
+
+ # Load Low Dimensional Data #
+ keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
+ timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)
+
+ # Increment Read Index #
+ self._index += 1
+
+ # Return Timestep #
+ return timestep
+
+ def close(self):
+ self._hdf5_file.close()
+
+
+def load_trajectory(
+ filepath=None,
+ read_cameras=True, # noqa: FBT002
+ recording_folderpath=None,
+ camera_kwargs={}, # noqa: B006
+ remove_skipped_steps=False, # noqa: FBT002
+ num_samples_per_traj=None,
+ num_samples_per_traj_coeff=1.5,
+):
+ read_recording_folderpath = read_cameras and (recording_folderpath is not None)
+
+ traj_reader = TrajectoryReader(filepath)
+ if read_recording_folderpath:
+ camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
+
+ horizon = traj_reader.length()
+ timestep_list = []
+
+ # Choose Timesteps To Save #
+ if num_samples_per_traj:
+ num_to_save = num_samples_per_traj
+ if remove_skipped_steps:
+ num_to_save = int(num_to_save * num_samples_per_traj_coeff)
+ max_size = min(num_to_save, horizon)
+ indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
+ else:
+ indices_to_save = np.arange(horizon)
+
+ # Iterate Over Trajectory #
+ for i in indices_to_save:
+ # Get HDF5 Data #
+ timestep = traj_reader.read_timestep(index=i)
+
+ # If Applicable, Get Recorded Data #
+ if read_recording_folderpath:
+ timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
+ camera_type_dict = {
+ k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
+ }
+ camera_obs = camera_reader.read_cameras(
+ index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
+ )
+ camera_failed = camera_obs is None
+
+ # Add Data To Timestep If Successful #
+ if camera_failed:
+ break
+ timestep["observation"].update(camera_obs)
+
+ # Filter Steps #
+ step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
+ delete_skipped_step = step_skipped and remove_skipped_steps
+
+ # Save Filtered Timesteps #
+ if delete_skipped_step:
+ del timestep
+ else:
+ timestep_list.append(timestep)
+
+ # Remove Extra Transitions #
+ timestep_list = np.array(timestep_list)
+ if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
+ ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
+ timestep_list = timestep_list[ind_to_keep]
+
+ # Close Readers #
+ traj_reader.close()
+
+ # Return Data #
+ return timestep_list
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/openpi/examples/droid/main.py b/openpi/examples/droid/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..93b76546cedb58e51487431a8ebe83cc87caee5a
--- /dev/null
+++ b/openpi/examples/droid/main.py
@@ -0,0 +1,246 @@
+# ruff: noqa
+
+import contextlib
+import dataclasses
+import datetime
+import faulthandler
+import os
+import signal
+import time
+from moviepy.editor import ImageSequenceClip
+import numpy as np
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy
+import pandas as pd
+from PIL import Image
+from droid.robot_env import RobotEnv
+import tqdm
+import tyro
+
+faulthandler.enable()
+
+# DROID data collection frequency -- we slow down execution to match this frequency
+DROID_CONTROL_FREQUENCY = 15
+
+
+@dataclasses.dataclass
+class Args:
+ # Hardware parameters
+ left_camera_id: str = "" # e.g., "24259877"
+ right_camera_id: str = "" # e.g., "24514023"
+ wrist_camera_id: str = "" # e.g., "13062452"
+
+ # Policy parameters
+ external_camera: str | None = (
+ None # which external camera should be fed to the policy, choose from ["left", "right"]
+ )
+
+ # Rollout parameters
+ max_timesteps: int = 600
+ # How many actions to execute from a predicted action chunk before querying policy server again
+ # 8 is usually a good default (equals 0.5 seconds of action execution).
+ open_loop_horizon: int = 8
+
+ # Remote server parameters
+ remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
+ remote_port: int = (
+ 8000 # point this to the port of the policy server, default server port for openpi servers is 8000
+ )
+
+
+# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
+# waiting for a new action chunk, it will raise an exception and the server connection dies.
+# This context manager temporarily suppresses Ctrl+C and delays it until after the server call is complete.
+@contextlib.contextmanager
+def prevent_keyboard_interrupt():
+ """Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
+ interrupted = False
+ original_handler = signal.getsignal(signal.SIGINT)
+
+ def handler(signum, frame):
+ nonlocal interrupted
+ interrupted = True
+
+ signal.signal(signal.SIGINT, handler)
+ try:
+ yield
+ finally:
+ signal.signal(signal.SIGINT, original_handler)
+ if interrupted:
+ raise KeyboardInterrupt
+
+
+def main(args: Args):
+ # Make sure external camera is specified by user -- we only use one external camera for the policy
+ assert (
+ args.external_camera is not None and args.external_camera in ["left", "right"]
+ ), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
+
+ # Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
+ env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
+ print("Created the droid env!")
+
+ # Connect to the policy server
+ policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
+
+ df = pd.DataFrame(columns=["success", "duration", "video_filename"])
+
+ while True:
+ instruction = input("Enter instruction: ")
+
+ # Rollout parameters
+ actions_from_chunk_completed = 0
+ pred_action_chunk = None
+
+ # Prepare to save video of rollout
+ timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
+ video = []
+ bar = tqdm.tqdm(range(args.max_timesteps))
+ print("Running rollout... press Ctrl+C to stop early.")
+ for t_step in bar:
+ start_time = time.time()
+ try:
+ # Get the current observation
+ curr_obs = _extract_observation(
+ args,
+ env.get_observation(),
+ # Save the first observation to disk
+ save_to_disk=t_step == 0,
+ )
+
+ video.append(curr_obs[f"{args.external_camera}_image"])
+
+ # Send websocket request to policy server if it's time to predict a new chunk
+ if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
+ actions_from_chunk_completed = 0
+
+ # We resize images on the robot laptop to minimize the amount of data sent to the policy server
+ # and improve latency.
+ request_data = {
+ "observation/exterior_image_1_left": image_tools.resize_with_pad(
+ curr_obs[f"{args.external_camera}_image"], 224, 224
+ ),
+ "observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
+ "observation/joint_position": curr_obs["joint_position"],
+ "observation/gripper_position": curr_obs["gripper_position"],
+ "prompt": instruction,
+ }
+
+ # Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
+ # Ctrl+C will be handled after the server call is complete
+ with prevent_keyboard_interrupt():
+ # this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
+ pred_action_chunk = policy_client.infer(request_data)["actions"]
+ assert pred_action_chunk.shape == (10, 8)
+
+ # Select current action to execute from chunk
+ action = pred_action_chunk[actions_from_chunk_completed]
+ actions_from_chunk_completed += 1
+
+ # Binarize gripper action
+ if action[-1].item() > 0.5:
+ # action[-1] = 1.0
+ action = np.concatenate([action[:-1], np.ones((1,))])
+ else:
+ # action[-1] = 0.0
+ action = np.concatenate([action[:-1], np.zeros((1,))])
+
+ # clip all dimensions of action to [-1, 1]
+ action = np.clip(action, -1, 1)
+
+ env.step(action)
+
+ # Sleep to match DROID data collection frequency
+ elapsed_time = time.time() - start_time
+ if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
+ time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
+ except KeyboardInterrupt:
+ break
+
+ video = np.stack(video)
+ save_filename = "video_" + timestamp
+ ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
+
+        success: str | float | None = None
+        while not isinstance(success, float):
+            success = input(
+                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
+            )
+            if success == "y":
+                success = 1.0
+            elif success == "n":
+                success = 0.0
+            else:
+                # Only numeric input should be rescaled from [0, 100] to [0, 1].
+                try:
+                    success = float(success) / 100
+                except ValueError:
+                    print(f"Invalid input: {success}. Please enter y, n, or a number in [0, 100].")
+                    success = None
+                    continue
+                if not (0 <= success <= 1):
+                    print(f"Success must be a number in [0, 100] but got: {success * 100}")
+                    success = None
+
+        # `DataFrame.append` was removed in pandas 2.x; build the row and concatenate instead.
+        new_row = pd.DataFrame(
+            [
+                {
+                    "success": success,
+                    "duration": t_step,
+                    "video_filename": save_filename,
+                }
+            ]
+        )
+        df = pd.concat([df, new_row], ignore_index=True)
+
+ if input("Do one more eval? (enter y or n) ").lower() != "y":
+ break
+ env.reset()
+
+ os.makedirs("results", exist_ok=True)
+ timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
+ csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
+ df.to_csv(csv_filename)
+ print(f"Results saved to {csv_filename}")
+
+
+def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
+ image_observations = obs_dict["image"]
+ left_image, right_image, wrist_image = None, None, None
+ for key in image_observations:
+ # Note the "left" below refers to the left camera in the stereo pair.
+ # The model is only trained on left stereo cams, so we only feed those.
+ if args.left_camera_id in key and "left" in key:
+ left_image = image_observations[key]
+ elif args.right_camera_id in key and "left" in key:
+ right_image = image_observations[key]
+ elif args.wrist_camera_id in key and "left" in key:
+ wrist_image = image_observations[key]
+
+ # Drop the alpha dimension
+ left_image = left_image[..., :3]
+ right_image = right_image[..., :3]
+ wrist_image = wrist_image[..., :3]
+
+ # Convert to RGB
+ left_image = left_image[..., ::-1]
+ right_image = right_image[..., ::-1]
+ wrist_image = wrist_image[..., ::-1]
+
+ # In addition to image observations, also capture the proprioceptive state
+ robot_state = obs_dict["robot_state"]
+ cartesian_position = np.array(robot_state["cartesian_position"])
+ joint_position = np.array(robot_state["joint_positions"])
+ gripper_position = np.array([robot_state["gripper_position"]])
+
+ # Save the images to disk so that they can be viewed live while the robot is running
+ # Create one combined image to make live viewing easy
+ if save_to_disk:
+ combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
+ combined_image = Image.fromarray(combined_image)
+ combined_image.save("robot_camera_views.png")
+
+ return {
+ "left_image": left_image,
+ "right_image": right_image,
+ "wrist_image": wrist_image,
+ "cartesian_position": cartesian_position,
+ "joint_position": joint_position,
+ "gripper_position": gripper_position,
+ }
+
+
+if __name__ == "__main__":
+ args: Args = tyro.cli(Args)
+ main(args)
diff --git a/openpi/examples/inference.ipynb b/openpi/examples/inference.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..2f125880f97dd433134142b2892ba324f592caa8
--- /dev/null
+++ b/openpi/examples/inference.ipynb
@@ -0,0 +1,137 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import dataclasses\n",
+ "\n",
+ "import jax\n",
+ "\n",
+ "from openpi.models import model as _model\n",
+ "from openpi.policies import droid_policy\n",
+ "from openpi.policies import policy_config as _policy_config\n",
+ "from openpi.shared import download\n",
+ "from openpi.training import config as _config\n",
+ "from openpi.training import data_loader as _data_loader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Policy inference\n",
+ "\n",
+ "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = _config.get_config(\"pi0_fast_droid\")\n",
+ "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
+ "\n",
+ "# Create a trained policy.\n",
+ "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
+ "\n",
+ "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
+ "example = droid_policy.make_droid_example()\n",
+ "result = policy.infer(example)\n",
+ "\n",
+ "# Delete the policy to free up memory.\n",
+ "del policy\n",
+ "\n",
+ "print(\"Actions shape:\", result[\"actions\"].shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Working with a live model\n",
+ "\n",
+ "\n",
+ "The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = _config.get_config(\"pi0_aloha_sim\")\n",
+ "\n",
+ "checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
+ "key = jax.random.key(0)\n",
+ "\n",
+ "# Create a model from the checkpoint.\n",
+ "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
+ "\n",
+ "# We can create fake observations and actions to test the model.\n",
+ "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
+ "\n",
+ "# Sample actions from the model.\n",
+ "loss = model.compute_loss(key, obs, act)\n",
+ "print(\"Loss shape:\", loss.shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reduce the batch size to reduce memory usage.\n",
+ "config = dataclasses.replace(config, batch_size=2)\n",
+ "\n",
+ "# Load a single batch of data. This is the same data that will be used during training.\n",
+ "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
+ "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
+ "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
+ "obs, act = next(iter(loader))\n",
+ "\n",
+ "# Sample actions from the model.\n",
+ "loss = model.compute_loss(key, obs, act)\n",
+ "\n",
+ "# Delete the model to free up memory.\n",
+ "del model\n",
+ "\n",
+ "print(\"Loss shape:\", loss.shape)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/openpi/examples/libero/Dockerfile b/openpi/examples/libero/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..9750796a2ce8854e0effc151a9d736437ca00190
--- /dev/null
+++ b/openpi/examples/libero/Dockerfile
@@ -0,0 +1,59 @@
+# Dockerfile for the LIBERO benchmark.
+
+# Build the container:
+# docker build . -t libero -f examples/libero/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
+
+FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+RUN apt-get update && \
+ apt-get install -y \
+ make \
+ g++ \
+ clang \
+ libosmesa6-dev \
+ libgl1-mesa-glx \
+ libglew-dev \
+ libglfw3-dev \
+ libgles2-mesa-dev \
+ libglib2.0-0 \
+ libsm6 \
+ libxrender1 \
+ libxext6
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/libero/requirements.txt /tmp/requirements.txt
+COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
+
+# Create a default config file to avoid an input prompt from LIBERO's init script.
+# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
+ENV LIBERO_CONFIG_PATH=/tmp/libero
+RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
+benchmark_root: /app/third_party/libero/libero/libero
+bddl_files: /app/third_party/libero/libero/libero/bddl_files
+init_states: /app/third_party/libero/libero/libero/init_files
+datasets: /app/third_party/libero/libero/datasets
+assets: /app/third_party/libero/libero/libero/assets
+EOF
+
+CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
diff --git a/openpi/examples/libero/README.md b/openpi/examples/libero/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a2d1d935ca830cd9ce0a5d363e7dd8ed0cb324e5
--- /dev/null
+++ b/openpi/examples/libero/README.md
@@ -0,0 +1,71 @@
+# LIBERO Benchmark
+
+This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
+
+Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
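+
+For reference, the full compile command would look roughly like this (a sketch assembled from the note above and the header of the generated `requirements.txt`):
+
+```bash
+uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt \
+    --python-version 3.8 --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+```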
+
+This example requires git submodules to be initialized. Don't forget to run:
+
+```bash
+git submodule update --init --recursive
+```
+
+## With Docker (recommended)
+
+```bash
+# Grant access to the X11 server:
+sudo xhost +local:docker
+
+# To run with the default checkpoint and task suite:
+SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+
+# To run with glx for Mujoco instead (use this if you have egl errors):
+MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
+```
+
+You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
+For example:
+
+```bash
+# To load a custom checkpoint (located in the top-level openpi/ directory):
+export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
+
+# To run the libero_10 task suite:
+export CLIENT_ARGS="--args.task-suite-name libero_10"
+```
+
+## Without Docker (not recommended)
+
+Terminal window 1:
+
+```bash
+# Create virtual environment
+uv venv --python 3.8 examples/libero/.venv
+source examples/libero/.venv/bin/activate
+uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
+uv pip install -e packages/openpi-client
+uv pip install -e third_party/libero
+export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
+
+# Run the simulation
+python examples/libero/main.py
+
+# To run with glx for Mujoco instead (use this if you have egl errors):
+MUJOCO_GL=glx python examples/libero/main.py
+```
+
+Terminal window 2:
+
+```bash
+# Run the server
+uv run scripts/serve_policy.py --env LIBERO
+```
+
+## Results
+
+If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
+checkpoint was trained in openpi with the `pi05_libero` config.
+
+| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
+|-------|---------------|---------------|-------------|-----------|---------|
+| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |
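+
+For example, one way to evaluate that checkpoint with the Docker setup above (a sketch; it assumes `--policy.dir` accepts a `gs://` path, which is resolved with the same download helper used for the default checkpoints):
+
+```bash
+export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir gs://openpi-assets/checkpoints/pi05_libero"
+docker compose -f examples/libero/compose.yml up --build
+```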
diff --git a/openpi/examples/libero/compose.yml b/openpi/examples/libero/compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3498ce5207628756d01d77753676648b2a6a5c4f
--- /dev/null
+++ b/openpi/examples/libero/compose.yml
@@ -0,0 +1,54 @@
+# Run with:
+# docker compose -f examples/libero/compose.yml up --build
+services:
+ runtime:
+ image: libero
+ depends_on:
+ - openpi_server
+ build:
+ context: ../..
+ dockerfile: examples/libero/Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ privileged: true
+ volumes:
+ - $PWD:/app
+ - ../../data:/data
+ - /tmp/.X11-unix:/tmp/.X11-unix:ro
+ environment:
+ - CLIENT_ARGS
+ - DISPLAY=$DISPLAY
+ - MUJOCO_GL=${MUJOCO_GL:-egl}
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
+
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/openpi/examples/libero/convert_libero_data_to_lerobot.py b/openpi/examples/libero/convert_libero_data_to_lerobot.py
new file mode 100644
index 0000000000000000000000000000000000000000..51db6f138e6a5e290402ad42d940d7e232232bbd
--- /dev/null
+++ b/openpi/examples/libero/convert_libero_data_to_lerobot.py
@@ -0,0 +1,104 @@
+"""
+Minimal example script for converting a dataset to LeRobot format.
+
+We use the Libero dataset (stored in RLDS) for this example, but it can be easily
+modified for any other data you have saved in a custom format.
+
+Usage:
+uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
+
+If you want to push your dataset to the Hugging Face Hub, you can use the following command:
+uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
+
+Note: to run the script, you need to install tensorflow_datasets:
+`uv pip install tensorflow tensorflow_datasets`
+
+You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
+The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
+Running this conversion script will take approximately 30 minutes.
+"""
+
+import shutil
+
+from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+import tensorflow_datasets as tfds
+import tyro
+
+REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
+RAW_DATASET_NAMES = [
+ "libero_10_no_noops",
+ "libero_goal_no_noops",
+ "libero_object_no_noops",
+ "libero_spatial_no_noops",
+] # For simplicity we will combine multiple Libero datasets into one training dataset
+
+
+def main(data_dir: str, *, push_to_hub: bool = False):
+ # Clean up any existing dataset in the output directory
+ output_path = HF_LEROBOT_HOME / REPO_NAME
+ if output_path.exists():
+ shutil.rmtree(output_path)
+
+ # Create LeRobot dataset, define features to store
+ # OpenPi assumes that proprio is stored in `state` and actions in `action`
+ # LeRobot assumes that dtype of image data is `image`
+ dataset = LeRobotDataset.create(
+ repo_id=REPO_NAME,
+ robot_type="panda",
+ fps=10,
+ features={
+ "image": {
+ "dtype": "image",
+ "shape": (256, 256, 3),
+ "names": ["height", "width", "channel"],
+ },
+ "wrist_image": {
+ "dtype": "image",
+ "shape": (256, 256, 3),
+ "names": ["height", "width", "channel"],
+ },
+ "state": {
+ "dtype": "float32",
+ "shape": (8,),
+ "names": ["state"],
+ },
+ "actions": {
+ "dtype": "float32",
+ "shape": (7,),
+ "names": ["actions"],
+ },
+ },
+ image_writer_threads=10,
+ image_writer_processes=5,
+ )
+
+ # Loop over raw Libero datasets and write episodes to the LeRobot dataset
+ # You can modify this for your own data format
+ for raw_dataset_name in RAW_DATASET_NAMES:
+ raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
+ for episode in raw_dataset:
+ for step in episode["steps"].as_numpy_iterator():
+ dataset.add_frame(
+ {
+ "image": step["observation"]["image"],
+ "wrist_image": step["observation"]["wrist_image"],
+ "state": step["observation"]["state"],
+ "actions": step["action"],
+ "task": step["language_instruction"].decode(),
+ }
+ )
+ dataset.save_episode()
+
+ # Optionally push to the Hugging Face Hub
+ if push_to_hub:
+ dataset.push_to_hub(
+ tags=["libero", "panda", "rlds"],
+ private=False,
+ push_videos=True,
+ license="apache-2.0",
+ )
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/openpi/examples/libero/main.py b/openpi/examples/libero/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc015a61740f2d3174152bebb60176fac52f3f40
--- /dev/null
+++ b/openpi/examples/libero/main.py
@@ -0,0 +1,219 @@
+import collections
+import dataclasses
+import logging
+import math
+import pathlib
+
+import imageio
+from libero.libero import benchmark
+from libero.libero import get_libero_path
+from libero.libero.envs import OffScreenRenderEnv
+import numpy as np
+from openpi_client import image_tools
+from openpi_client import websocket_client_policy as _websocket_client_policy
+import tqdm
+import tyro
+
+LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
+LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
+
+
+@dataclasses.dataclass
+class Args:
+ #################################################################################################################
+ # Model server parameters
+ #################################################################################################################
+ host: str = "0.0.0.0"
+ port: int = 8000
+ resize_size: int = 224
+ replan_steps: int = 5
+
+ #################################################################################################################
+ # LIBERO environment-specific parameters
+ #################################################################################################################
+ task_suite_name: str = (
+ "libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
+ )
+    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
+ num_trials_per_task: int = 50 # Number of rollouts per task
+
+ #################################################################################################################
+ # Utils
+ #################################################################################################################
+ video_out_path: str = "data/libero/videos" # Path to save videos
+
+ seed: int = 7 # Random Seed (for reproducibility)
+
+
+def eval_libero(args: Args) -> None:
+ # Set random seed
+ np.random.seed(args.seed)
+
+ # Initialize LIBERO task suite
+ benchmark_dict = benchmark.get_benchmark_dict()
+ task_suite = benchmark_dict[args.task_suite_name]()
+ num_tasks_in_suite = task_suite.n_tasks
+ logging.info(f"Task suite: {args.task_suite_name}")
+
+ pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
+
+ if args.task_suite_name == "libero_spatial":
+ max_steps = 220 # longest training demo has 193 steps
+ elif args.task_suite_name == "libero_object":
+ max_steps = 280 # longest training demo has 254 steps
+ elif args.task_suite_name == "libero_goal":
+ max_steps = 300 # longest training demo has 270 steps
+ elif args.task_suite_name == "libero_10":
+ max_steps = 520 # longest training demo has 505 steps
+ elif args.task_suite_name == "libero_90":
+ max_steps = 400 # longest training demo has 373 steps
+ else:
+ raise ValueError(f"Unknown task suite: {args.task_suite_name}")
+
+ client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
+
+ # Start evaluation
+ total_episodes, total_successes = 0, 0
+ for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
+ # Get task
+ task = task_suite.get_task(task_id)
+
+ # Get default LIBERO initial states
+ initial_states = task_suite.get_task_init_states(task_id)
+
+ # Initialize LIBERO environment and task description
+ env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
+
+ # Start episodes
+ task_episodes, task_successes = 0, 0
+ for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
+ logging.info(f"\nTask: {task_description}")
+
+ # Reset environment
+ env.reset()
+ action_plan = collections.deque()
+
+ # Set initial states
+ obs = env.set_init_state(initial_states[episode_idx])
+
+ # Setup
+ t = 0
+ replay_images = []
+
+ logging.info(f"Starting episode {task_episodes+1}...")
+ while t < max_steps + args.num_steps_wait:
+ try:
+ # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
+ # and we need to wait for them to fall
+ if t < args.num_steps_wait:
+ obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
+ t += 1
+ continue
+
+ # Get preprocessed image
+ # IMPORTANT: rotate 180 degrees to match train preprocessing
+ img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
+ wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
+ img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
+ )
+ wrist_img = image_tools.convert_to_uint8(
+ image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
+ )
+
+ # Save preprocessed image for replay video
+ replay_images.append(img)
+
+ if not action_plan:
+ # Finished executing previous action chunk -- compute new chunk
+ # Prepare observations dict
+ element = {
+ "observation/image": img,
+ "observation/wrist_image": wrist_img,
+ "observation/state": np.concatenate(
+ (
+ obs["robot0_eef_pos"],
+ _quat2axisangle(obs["robot0_eef_quat"]),
+ obs["robot0_gripper_qpos"],
+ )
+ ),
+ "prompt": str(task_description),
+ }
+
+ # Query model to get action
+ action_chunk = client.infer(element)["actions"]
+ assert (
+ len(action_chunk) >= args.replan_steps
+ ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
+ action_plan.extend(action_chunk[: args.replan_steps])
+
+ action = action_plan.popleft()
+
+ # Execute action in environment
+ obs, reward, done, info = env.step(action.tolist())
+ if done:
+ task_successes += 1
+ total_successes += 1
+ break
+ t += 1
+
+ except Exception as e:
+ logging.error(f"Caught exception: {e}")
+ break
+
+ task_episodes += 1
+ total_episodes += 1
+
+ # Save a replay video of the episode
+ suffix = "success" if done else "failure"
+ task_segment = task_description.replace(" ", "_")
+ imageio.mimwrite(
+ pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
+ [np.asarray(x) for x in replay_images],
+ fps=10,
+ )
+
+ # Log current results
+ logging.info(f"Success: {done}")
+ logging.info(f"# episodes completed so far: {total_episodes}")
+ logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
+
+ # Log final results
+ logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
+ logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
+
+ logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
+ logging.info(f"Total episodes: {total_episodes}")
+
+
+def _get_libero_env(task, resolution, seed):
+ """Initializes and returns the LIBERO environment, along with the task description."""
+ task_description = task.language
+ task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
+ env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
+ env = OffScreenRenderEnv(**env_args)
+ env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
+ return env, task_description
+
+
+def _quat2axisangle(quat):
+ """
+ Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
+ """
+ # clip quaternion
+ if quat[3] > 1.0:
+ quat[3] = 1.0
+ elif quat[3] < -1.0:
+ quat[3] = -1.0
+
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(den, 0.0):
+ # This is (close to) a zero degree rotation, immediately return
+ return np.zeros(3)
+
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ tyro.cli(eval_libero)
diff --git a/openpi/examples/libero/requirements.in b/openpi/examples/libero/requirements.in
new file mode 100644
index 0000000000000000000000000000000000000000..149006564d32ef12d0e7e48af917b9dbfa584730
--- /dev/null
+++ b/openpi/examples/libero/requirements.in
@@ -0,0 +1,11 @@
+imageio[ffmpeg]
+numpy==1.22.4
+tqdm
+tyro
+PyYaml
+opencv-python==4.6.0.66
+torch==1.11.0+cu113
+torchvision==0.12.0+cu113
+torchaudio==0.11.0+cu113
+robosuite==1.4.1
+matplotlib==3.5.3
diff --git a/openpi/examples/libero/requirements.txt b/openpi/examples/libero/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1a52b42887cefefdad0f04da876f48a7c09be3a7
--- /dev/null
+++ b/openpi/examples/libero/requirements.txt
@@ -0,0 +1,136 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
+absl-py==2.1.0
+ # via mujoco
+certifi==2024.12.14
+ # via requests
+charset-normalizer==3.4.0
+ # via requests
+cycler==0.12.1
+ # via matplotlib
+docstring-parser==0.16
+ # via tyro
+etils==1.3.0
+ # via mujoco
+eval-type-backport==0.2.0
+ # via tyro
+evdev==1.7.1
+ # via pynput
+fonttools==4.55.3
+ # via matplotlib
+glfw==1.12.0
+ # via mujoco
+idna==3.10
+ # via requests
+imageio==2.35.1
+ # via -r examples/libero/requirements.in
+imageio-ffmpeg==0.5.1
+ # via imageio
+importlib-metadata==8.5.0
+ # via typeguard
+importlib-resources==6.4.5
+ # via etils
+kiwisolver==1.4.7
+ # via matplotlib
+llvmlite==0.36.0
+ # via numba
+markdown-it-py==3.0.0
+ # via rich
+matplotlib==3.5.3
+ # via -r examples/libero/requirements.in
+mdurl==0.1.2
+ # via markdown-it-py
+mujoco==3.2.3
+ # via robosuite
+numba==0.53.1
+ # via robosuite
+numpy==1.22.4
+ # via
+ # -r examples/libero/requirements.in
+ # imageio
+ # matplotlib
+ # mujoco
+ # numba
+ # opencv-python
+ # robosuite
+ # scipy
+ # torchvision
+opencv-python==4.6.0.66
+ # via
+ # -r examples/libero/requirements.in
+ # robosuite
+packaging==24.2
+ # via matplotlib
+pillow==10.4.0
+ # via
+ # imageio
+ # matplotlib
+ # robosuite
+ # torchvision
+psutil==6.1.0
+ # via imageio
+pygments==2.18.0
+ # via rich
+pynput==1.7.7
+ # via robosuite
+pyopengl==3.1.7
+ # via mujoco
+pyparsing==3.1.4
+ # via matplotlib
+python-dateutil==2.9.0.post0
+ # via matplotlib
+python-xlib==0.33
+ # via pynput
+pyyaml==6.0.2
+ # via -r examples/libero/requirements.in
+requests==2.32.3
+ # via torchvision
+rich==13.9.4
+ # via tyro
+robosuite==1.4.1
+ # via -r examples/libero/requirements.in
+scipy==1.10.1
+ # via robosuite
+setuptools==75.3.0
+ # via
+ # imageio-ffmpeg
+ # numba
+shtab==1.7.1
+ # via tyro
+six==1.17.0
+ # via
+ # pynput
+ # python-dateutil
+ # python-xlib
+termcolor==2.4.0
+ # via robosuite
+torch==1.11.0+cu113
+ # via
+ # -r examples/libero/requirements.in
+ # torchaudio
+ # torchvision
+torchaudio==0.11.0+cu113
+ # via -r examples/libero/requirements.in
+torchvision==0.12.0+cu113
+ # via -r examples/libero/requirements.in
+tqdm==4.67.1
+ # via -r examples/libero/requirements.in
+typeguard==4.4.0
+ # via tyro
+typing-extensions==4.12.2
+ # via
+ # etils
+ # rich
+ # torch
+ # torchvision
+ # typeguard
+ # tyro
+tyro==0.9.2
+ # via -r examples/libero/requirements.in
+urllib3==2.2.3
+ # via requests
+zipp==3.20.2
+ # via
+ # etils
+ # importlib-metadata
+ # importlib-resources
diff --git a/openpi/examples/policy_records.ipynb b/openpi/examples/policy_records.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f268cf733b93fc732d2e7e754bf77b3dceb30
--- /dev/null
+++ b/openpi/examples/policy_records.ipynb
@@ -0,0 +1,134 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pathlib\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "record_path = pathlib.Path(\"../policy_records\")\n",
+ "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
+ "\n",
+ "records = []\n",
+ "for i in range(num_steps):\n",
+ " record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
+ " records.append(record)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"length of records\", len(records))\n",
+ "print(\"keys in records\", records[0].keys())\n",
+ "\n",
+ "for k in records[0]:\n",
+ " print(f\"{k} shape: {records[0][k].shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from PIL import Image\n",
+ "\n",
+ "\n",
+ "def get_image(step: int, idx: int = 0):\n",
+ " img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
+ " return img[idx].transpose(1, 2, 0)\n",
+ "\n",
+ "\n",
+ "def show_image(step: int, idx_lst: list[int]):\n",
+ " imgs = [get_image(step, idx) for idx in idx_lst]\n",
+ " return Image.fromarray(np.hstack(imgs))\n",
+ "\n",
+ "\n",
+ "for i in range(2):\n",
+ " display(show_image(i, [0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "\n",
+ "def get_axis(name, axis):\n",
+ " return np.array([record[name][axis] for record in records])\n",
+ "\n",
+ "\n",
+ "# qpos is [..., 14] of type float:\n",
+ "# 0-5: left arm joint angles\n",
+ "# 6: left arm gripper\n",
+ "# 7-12: right arm joint angles\n",
+ "# 13: right arm gripper\n",
+ "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
+ "\n",
+ "\n",
+ "def make_data():\n",
+ " cur_dim = 0\n",
+ " in_data = {}\n",
+ " out_data = {}\n",
+ " for name, dim_size in names:\n",
+ " for i in range(dim_size):\n",
+ " in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
+ " out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
+ " cur_dim += 1\n",
+ " return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
+ "\n",
+ "\n",
+ "in_data, out_data = make_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for name in in_data.columns:\n",
+ " data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
+ " data.plot()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/openpi/examples/simple_client/Dockerfile b/openpi/examples/simple_client/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..05991634a6fc01f6f7fa8b94a773a68938277767
--- /dev/null
+++ b/openpi/examples/simple_client/Dockerfile
@@ -0,0 +1,32 @@
+# Dockerfile for the simple client.
+
+# Build the container:
+# docker build . -t simple_client -f examples/simple_client/Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
+
+FROM python:3.7-slim
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+WORKDIR /app
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Copy the requirements files so we can install dependencies.
+# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
+# This strategy is best for development-style usage.
+COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
+COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
+
+# Install python dependencies.
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
+ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
+
+CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
diff --git a/openpi/examples/simple_client/README.md b/openpi/examples/simple_client/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc381c1d7a7d2ebcf60d8136a303ec9c0b67496a
--- /dev/null
+++ b/openpi/examples/simple_client/README.md
@@ -0,0 +1,30 @@
+# Simple Client
+
+A minimal client that sends observations to the server and prints the inference rate.
+
+You can specify which runtime environment to use with the `--env` flag. You can see the available options by running:
+
+```bash
+uv run examples/simple_client/main.py --help
+```
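+
+For example, to benchmark a DROID policy server for 50 steps and save per-call timings to a parquet file (flags correspond to the `Args` dataclass in `examples/simple_client/main.py`):
+
+```bash
+uv run examples/simple_client/main.py --env DROID --num-steps 50 --timing-file timing.parquet
+```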
+
+## With Docker
+
+```bash
+export SERVER_ARGS="--env ALOHA_SIM"
+docker compose -f examples/simple_client/compose.yml up --build
+```
+
+## Without Docker
+
+Terminal window 1:
+
+```bash
+uv run examples/simple_client/main.py --env DROID
+```
+
+Terminal window 2:
+
+```bash
+uv run scripts/serve_policy.py --env DROID
+```
diff --git a/openpi/examples/simple_client/compose.yml b/openpi/examples/simple_client/compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..977e361f73276502bbf42254db66b159560fefdd
--- /dev/null
+++ b/openpi/examples/simple_client/compose.yml
@@ -0,0 +1,42 @@
+# Run with:
+# docker compose -f examples/simple_client/compose.yml up --build
+services:
+ runtime:
+ image: simple_client
+ depends_on:
+ - openpi_server
+ build:
+ context: ../..
+ dockerfile: examples/simple_client/Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ environment:
+ - SERVER_ARGS
+
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/openpi/examples/simple_client/main.py b/openpi/examples/simple_client/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd7eda1090d7e5f5a738e63642a0792b571fc87c
--- /dev/null
+++ b/openpi/examples/simple_client/main.py
@@ -0,0 +1,187 @@
+import dataclasses
+import enum
+import logging
+import pathlib
+import time
+
+import numpy as np
+from openpi_client import websocket_client_policy as _websocket_client_policy
+import polars as pl
+import rich.console
+import rich.table
+import tqdm
+import tyro
+
+logger = logging.getLogger(__name__)
+
+
+class EnvMode(enum.Enum):
+ """Supported environments."""
+
+ ALOHA = "aloha"
+ ALOHA_SIM = "aloha_sim"
+ DROID = "droid"
+ LIBERO = "libero"
+
+
+@dataclasses.dataclass
+class Args:
+ """Command line arguments."""
+
+ # Host and port to connect to the server.
+ host: str = "0.0.0.0"
+ # Port to connect to the server. If None, the server will use the default port.
+ port: int | None = 8000
+ # API key to use for the server.
+ api_key: str | None = None
+ # Number of steps to run the policy for.
+ num_steps: int = 20
+ # Path to save the timings to a parquet file. (e.g., timing.parquet)
+ timing_file: pathlib.Path | None = None
+ # Environment to run the policy in.
+ env: EnvMode = EnvMode.ALOHA_SIM
+
+
+class TimingRecorder:
+ """Records timing measurements for different keys."""
+
+ def __init__(self) -> None:
+ self._timings: dict[str, list[float]] = {}
+
+ def record(self, key: str, time_ms: float) -> None:
+ """Record a timing measurement for the given key."""
+ if key not in self._timings:
+ self._timings[key] = []
+ self._timings[key].append(time_ms)
+
+ def get_stats(self, key: str) -> dict[str, float]:
+ """Get statistics for the given key."""
+ times = self._timings[key]
+ return {
+ "mean": float(np.mean(times)),
+ "std": float(np.std(times)),
+ "p25": float(np.quantile(times, 0.25)),
+ "p50": float(np.quantile(times, 0.50)),
+ "p75": float(np.quantile(times, 0.75)),
+ "p90": float(np.quantile(times, 0.90)),
+ "p95": float(np.quantile(times, 0.95)),
+ "p99": float(np.quantile(times, 0.99)),
+ }
+
+ def print_all_stats(self) -> None:
+ """Print statistics for all keys in a concise format."""
+
+ table = rich.table.Table(
+ title="[bold blue]Timing Statistics[/bold blue]",
+ show_header=True,
+ header_style="bold white",
+ border_style="blue",
+ title_justify="center",
+ )
+
+ # Add metric column with custom styling
+ table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
+
+ # Add statistical columns with consistent styling
+ stat_columns = [
+ ("Mean", "yellow", "mean"),
+ ("Std", "yellow", "std"),
+ ("P25", "magenta", "p25"),
+ ("P50", "magenta", "p50"),
+ ("P75", "magenta", "p75"),
+ ("P90", "magenta", "p90"),
+ ("P95", "magenta", "p95"),
+ ("P99", "magenta", "p99"),
+ ]
+
+ for name, style, _ in stat_columns:
+ table.add_column(name, justify="right", style=style, no_wrap=True)
+
+ # Add rows for each metric with formatted values
+ for key in sorted(self._timings.keys()):
+ stats = self.get_stats(key)
+ values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
+ table.add_row(key, *values)
+
+ # Print with custom console settings
+ console = rich.console.Console(width=None, highlight=True)
+ console.print(table)
+
+ def write_parquet(self, path: pathlib.Path) -> None:
+ """Save the timings to a parquet file."""
+ logger.info(f"Writing timings to {path}")
+ frame = pl.DataFrame(self._timings)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ frame.write_parquet(path)
+
+
+def main(args: Args) -> None:
+ obs_fn = {
+ EnvMode.ALOHA: _random_observation_aloha,
+ EnvMode.ALOHA_SIM: _random_observation_aloha,
+ EnvMode.DROID: _random_observation_droid,
+ EnvMode.LIBERO: _random_observation_libero,
+ }[args.env]
+
+ policy = _websocket_client_policy.WebsocketClientPolicy(
+ host=args.host,
+ port=args.port,
+ api_key=args.api_key,
+ )
+ logger.info(f"Server metadata: {policy.get_server_metadata()}")
+
+ # Send a few observations to make sure the model is loaded.
+ for _ in range(2):
+ policy.infer(obs_fn())
+
+ timing_recorder = TimingRecorder()
+
+ for _ in tqdm.trange(args.num_steps, desc="Running policy"):
+ inference_start = time.time()
+ action = policy.infer(obs_fn())
+ timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
+ for key, value in action.get("server_timing", {}).items():
+ timing_recorder.record(f"server_{key}", value)
+ for key, value in action.get("policy_timing", {}).items():
+ timing_recorder.record(f"policy_{key}", value)
+
+ timing_recorder.print_all_stats()
+
+ if args.timing_file is not None:
+ timing_recorder.write_parquet(args.timing_file)
+
+
+def _random_observation_aloha() -> dict:
+ return {
+ "state": np.ones((14,)),
+ "images": {
+ "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ },
+ "prompt": "do something",
+ }
+
+
+def _random_observation_droid() -> dict:
+ return {
+ "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "observation/joint_position": np.random.rand(7),
+ "observation/gripper_position": np.random.rand(1),
+ "prompt": "do something",
+ }
+
+
+def _random_observation_libero() -> dict:
+ return {
+ "observation/state": np.random.rand(8),
+ "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "prompt": "do something",
+ }
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ main(tyro.cli(Args))
diff --git a/openpi/examples/simple_client/requirements.in b/openpi/examples/simple_client/requirements.in
new file mode 100644
index 0000000000000000000000000000000000000000..f4c5c0c1b3f3b45ad65067b1b4d61818d74f7c27
--- /dev/null
+++ b/openpi/examples/simple_client/requirements.in
@@ -0,0 +1,5 @@
+numpy>=1.22.4,<2.0.0
+rich
+tqdm
+tyro
+polars
\ No newline at end of file
diff --git a/openpi/examples/simple_client/requirements.txt b/openpi/examples/simple_client/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..86143b53fd0dfafe8ffa30044a172a81bbab0122
--- /dev/null
+++ b/openpi/examples/simple_client/requirements.txt
@@ -0,0 +1,30 @@
+# This file was autogenerated by uv via the following command:
+# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
+docstring-parser==0.16
+ # via tyro
+markdown-it-py==3.0.0
+ # via rich
+mdurl==0.1.2
+ # via markdown-it-py
+numpy==1.26.4
+ # via -r examples/simple_client/requirements.in
+polars==1.30.0
+ # via -r examples/simple_client/requirements.in
+pygments==2.19.1
+ # via rich
+rich==14.0.0
+ # via
+ # -r examples/simple_client/requirements.in
+ # tyro
+shtab==1.7.2
+ # via tyro
+tqdm==4.67.1
+ # via -r examples/simple_client/requirements.in
+typeguard==4.4.2
+ # via tyro
+typing-extensions==4.13.2
+ # via
+ # typeguard
+ # tyro
+tyro==0.9.22
+ # via -r examples/simple_client/requirements.in
diff --git a/openpi/examples/ur5/README.md b/openpi/examples/ur5/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a66e91da8af8dfcd4f339668068666b55578a6a0
--- /dev/null
+++ b/openpi/examples/ur5/README.md
@@ -0,0 +1,142 @@
+# UR5 Example
+
+Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
+
+First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map UR5 observations into the model's input format and map model outputs back into UR5 actions. See the analogous Libero classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.
+
+```python
+
+@dataclasses.dataclass(frozen=True)
+class UR5Inputs(transforms.DataTransformFn):
+
+ model_type: _model.ModelType = _model.ModelType.PI0
+
+ def __call__(self, data: dict) -> dict:
+        # First, concatenate the joints and gripper into the state vector, then pad it
+        # with zeros to the model's action dimension.
+        state = np.concatenate([data["joints"], data["gripper"]])
+        state = transforms.pad_to_dim(state, self.action_dim)
+
+        # Parse images to uint8 (H,W,C) if needed. LeRobot stores images as float32 (C,H,W)
+        # during training; at policy inference time the images already arrive in the right
+        # format, so this step is a no-op there.
+ base_image = _parse_image(data["base_rgb"])
+ wrist_image = _parse_image(data["wrist_rgb"])
+
+ # Create inputs dict.
+ inputs = {
+ "state": state,
+ "image": {
+ "base_0_rgb": base_image,
+ "left_wrist_0_rgb": wrist_image,
+ # Since there is no right wrist, replace with zeros
+ "right_wrist_0_rgb": np.zeros_like(base_image),
+ },
+ "image_mask": {
+ "base_0_rgb": np.True_,
+ "left_wrist_0_rgb": np.True_,
+                # The right-wrist "slot" is unused, so it is masked out (False) for pi0.
+                # pi0-FAST does not mask padding images, so its mask stays True.
+                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
+ },
+ }
+
+ if "actions" in data:
+ inputs["actions"] = data["actions"]
+
+ # Pass the prompt (aka language instruction) to the model.
+ if "prompt" in data:
+ inputs["prompt"] = data["prompt"]
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class UR5Outputs(transforms.DataTransformFn):
+
+ def __call__(self, data: dict) -> dict:
+        # The UR5 has 7 action dimensions (6 joints + gripper), so return only the first 7 dims of the model's action output.
+ return {"actions": np.asarray(data["actions"][:, :7])}
+
+```
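+
+As a quick sanity check, you can run the input transform on dummy data. The snippet below is a minimal sketch: the key names and shapes are only illustrative and must match what your LeRobot dataset actually stores, and `action_dim=32` is the padded action dimension used by the pi0 base model.
+
+```python
+import numpy as np
+
+dummy = {
+    "joints": np.random.rand(6),
+    "gripper": np.random.rand(1),
+    "base_rgb": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+    "wrist_rgb": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+    "prompt": "pick up the block",
+}
+out = UR5Inputs(action_dim=32)(dummy)
+print(out["state"].shape)                # (32,) -- state padded to the model's action dimension
+print(out["image"]["base_0_rgb"].shape)  # (224, 224, 3)
+```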
+
+Next, we will define the `LeRobotUR5DataConfig` class, which defines how raw UR5 data from the LeRobot dataset is processed for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+
+```python
+
+@dataclasses.dataclass(frozen=True)
+class LeRobotUR5DataConfig(DataConfigFactory):
+
+ @override
+ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
+        # Repack transform: maps the keys stored in the LeRobot dataset (right-hand side)
+        # to the keys expected by the UR5 transforms above (left-hand side).
+ repack_transform = _transforms.Group(
+ inputs=[
+ _transforms.RepackTransform(
+ {
+ "base_rgb": "image",
+ "wrist_rgb": "wrist_image",
+ "joints": "joints",
+ "gripper": "gripper",
+ "prompt": "prompt",
+ }
+ )
+ ]
+ )
+
+ # These transforms are the ones we wrote earlier.
+ data_transforms = _transforms.Group(
+ inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
+ outputs=[UR5Outputs()],
+ )
+
+ # Convert absolute actions to delta actions.
+ # By convention, we do not convert the gripper action (7th dimension).
+ delta_action_mask = _transforms.make_bool_mask(6, -1)
+ data_transforms = data_transforms.push(
+ inputs=[_transforms.DeltaActions(delta_action_mask)],
+ outputs=[_transforms.AbsoluteActions(delta_action_mask)],
+ )
+
+ # Model transforms include things like tokenizing the prompt and action targets
+ # You do not need to change anything here for your own dataset.
+ model_transforms = ModelTransformFactory()(model_config)
+
+ # We return all data transforms for training and inference. No need to change anything here.
+ return dataclasses.replace(
+ self.create_base_config(assets_dirs),
+ repack_transforms=repack_transform,
+ data_transforms=data_transforms,
+ model_transforms=model_transforms,
+ )
+
+```
+
+Finally, we define the `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
+
+```python
+TrainConfig(
+ name="pi0_ur5",
+ model=pi0.Pi0Config(),
+ data=LeRobotUR5DataConfig(
+ repo_id="your_username/ur5_dataset",
+ # This config lets us reload the UR5 normalization stats from the base model checkpoint.
+ # Reloading normalization stats can help transfer pre-trained models to new environments.
+ # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
+ assets=AssetsConfig(
+ assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
+ asset_id="ur5e",
+ ),
+ base_config=DataConfig(
+ # This flag determines whether we load the prompt (i.e. the task instruction) from the
+ # ``task`` field in the LeRobot dataset. The recommended setting is True.
+ prompt_from_task=True,
+ ),
+ ),
+ # Load the pi0 base model checkpoint.
+ weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
+ num_train_steps=30_000,
+)
+```
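+
+Once this `TrainConfig` has been registered in `src/openpi/training/config.py`, you would typically compute normalization statistics and launch training with the standard scripts, e.g. `uv run scripts/compute_norm_stats.py --config-name pi0_ur5` followed by `uv run scripts/train.py pi0_ur5 --exp-name=my_ur5_experiment --overwrite` (the experiment name here is just a placeholder). See the top-level README for the full training workflow.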
+
diff --git a/openpi/model.safetensors b/openpi/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..487502d0ecc7730a17cb44073ac007001e4d891f
--- /dev/null
+++ b/openpi/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:766fa4f2a981da1e46b0be7ecb6055c4780117c6e8aef8a9f36ecdfd0b1c5da8
+size 136
diff --git a/openpi/packages/openpi-client/pyproject.toml b/openpi/packages/openpi-client/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..fba7b66f62be7ff209d46717b1eed1d1b232ca2e
--- /dev/null
+++ b/openpi/packages/openpi-client/pyproject.toml
@@ -0,0 +1,23 @@
+[project]
+name = "openpi-client"
+version = "0.1.0"
+requires-python = ">=3.7"
+dependencies = [
+ "dm-tree>=0.1.8",
+ "msgpack>=1.0.5",
+ "numpy>=1.22.4,<2.0.0",
+ "pillow>=9.0.0",
+    "typing-extensions>=4.12.2",
+ "websockets>=11.0",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.uv]
+dev-dependencies = ["pytest>=8.3.4"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py37"
diff --git a/openpi/packages/openpi-client/src/openpi_client/__init__.py b/openpi/packages/openpi-client/src/openpi_client/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc1f76bc69e3f559bee6253b24fc93acee9e1f9
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.0"
diff --git a/openpi/packages/openpi-client/src/openpi_client/action_chunk_broker.py b/openpi/packages/openpi-client/src/openpi_client/action_chunk_broker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa9d83d023b7c0c60a1d05531343af01e72d09b
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/action_chunk_broker.py
@@ -0,0 +1,50 @@
+from typing import Dict, Optional
+
+import numpy as np
+import tree
+from typing_extensions import override
+
+from openpi_client import base_policy as _base_policy
+
+
+class ActionChunkBroker(_base_policy.BasePolicy):
+ """Wraps a policy to return action chunks one-at-a-time.
+
+ Assumes that the first dimension of all action fields is the chunk size.
+
+ A new inference call to the inner policy is only made when the current
+ list of chunks is exhausted.
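+
+    A minimal usage sketch (``policy`` is any BasePolicy implementation):
+
+        broker = ActionChunkBroker(policy, action_horizon=10)
+        action = broker.infer(obs)  # runs inference, returns step 0 of the chunk
+        action = broker.infer(obs)  # returns step 1 from the cached chunk, no new inference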
+ """
+
+ def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
+ self._policy = policy
+ self._action_horizon = action_horizon
+ self._cur_step: int = 0
+
+        self._last_results: Optional[Dict[str, np.ndarray]] = None
+
+ @override
+ def infer(self, obs: Dict) -> Dict: # noqa: UP006
+ if self._last_results is None:
+ self._last_results = self._policy.infer(obs)
+ self._cur_step = 0
+
+ def slicer(x):
+ if isinstance(x, np.ndarray):
+ return x[self._cur_step, ...]
+ else:
+ return x
+
+ results = tree.map_structure(slicer, self._last_results)
+ self._cur_step += 1
+
+ if self._cur_step >= self._action_horizon:
+ self._last_results = None
+
+ return results
+
+ @override
+ def reset(self) -> None:
+ self._policy.reset()
+ self._last_results = None
+ self._cur_step = 0
diff --git a/openpi/packages/openpi-client/src/openpi_client/base_policy.py b/openpi/packages/openpi-client/src/openpi_client/base_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f4290651b1b7bab3bd9549b47876838f5b51629
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/base_policy.py
@@ -0,0 +1,12 @@
+import abc
+from typing import Dict
+
+
+class BasePolicy(abc.ABC):
+ @abc.abstractmethod
+ def infer(self, obs: Dict) -> Dict:
+ """Infer actions from observations."""
+
+ def reset(self) -> None:
+ """Reset the policy to its initial state."""
+ pass
diff --git a/openpi/packages/openpi-client/src/openpi_client/image_tools.py b/openpi/packages/openpi-client/src/openpi_client/image_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a971b9d5f6b1495fd6cdea202ffa607d8b34bf0
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/image_tools.py
@@ -0,0 +1,58 @@
+import numpy as np
+from PIL import Image
+
+
+def convert_to_uint8(img: np.ndarray) -> np.ndarray:
+ """Converts an image to uint8 if it is a float image.
+
+ This is important for reducing the size of the image when sending it over the network.
+ """
+ if np.issubdtype(img.dtype, np.floating):
+ img = (255 * img).astype(np.uint8)
+ return img
+
+
+def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
+    """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target
+    height and width without distortion by padding with zeros.
+
+ Args:
+ images: A batch of images in [..., height, width, channel] format.
+ height: The target height of the image.
+ width: The target width of the image.
+ method: The interpolation method to use. Default is bilinear.
+
+ Returns:
+ The resized images in [..., height, width, channel].
+ """
+ # If the images are already the correct size, return them as is.
+ if images.shape[-3:-1] == (height, width):
+ return images
+
+ original_shape = images.shape
+
+ images = images.reshape(-1, *original_shape[-3:])
+ resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
+ return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
+
+
+def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
+ """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
+ width without distortion by padding with zeros.
+
+    Note that PIL image sizes are ordered as (width, height), unlike the [..., height, width, channel] array convention used above.
+ """
+ cur_width, cur_height = image.size
+ if cur_width == width and cur_height == height:
+ return image # No need to resize if the image is already the correct size.
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_image = image.resize((resized_width, resized_height), resample=method)
+
+ zero_image = Image.new(resized_image.mode, (width, height), 0)
+ pad_height = max(0, int((height - resized_height) / 2))
+ pad_width = max(0, int((width - resized_width) / 2))
+ zero_image.paste(resized_image, (pad_width, pad_height))
+ assert zero_image.size == (width, height)
+ return zero_image
diff --git a/openpi/packages/openpi-client/src/openpi_client/image_tools_test.py b/openpi/packages/openpi-client/src/openpi_client/image_tools_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d4b4b92030ea869712b312581e26243035aafba
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/image_tools_test.py
@@ -0,0 +1,37 @@
+import numpy as np
+
+import openpi_client.image_tools as image_tools
+
+
+def test_resize_with_pad_shapes():
+ # Test case 1: Resize image with larger dimensions
+ images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
+ height = 20
+ width = 20
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (2, height, width, 3)
+ assert np.all(resized_images == 0)
+
+ # Test case 2: Resize image with smaller dimensions
+ images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
+ height = 15
+ width = 15
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (3, height, width, 3)
+ assert np.all(resized_images == 0)
+
+ # Test case 3: Resize image with the same dimensions
+ images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
+ height = 50
+ width = 50
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (1, height, width, 3)
+ assert np.all(resized_images == 0)
+
+    # Test case 4: Resize image with odd-numbered padding
+ images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
+ height = 60
+ width = 80
+ resized_images = image_tools.resize_with_pad(images, height, width)
+ assert resized_images.shape == (1, height, width, 3)
+ assert np.all(resized_images == 0)
diff --git a/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy.py b/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..007f755edf54565579376b077eec7f7f715e1b96
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy.py
@@ -0,0 +1,57 @@
+"""Adds NumPy array support to msgpack.
+
+msgpack is good for (de)serializing data over a network for multiple reasons:
+- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
+- msgpack is widely used and has good cross-language support
+- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
+ languages like Python and JavaScript
+- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
+ than pickle for serializing large arrays using the below strategy
+
+The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
+that it falls back to pickle for object arrays.
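+
+Example (round-tripping a dict containing a NumPy array):
+
+    packed = packb({"state": np.zeros(7), "step": 3})
+    restored = unpackb(packed)  # arrays come back with the same dtype and shape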
+"""
+
+import functools
+
+import msgpack
+import numpy as np
+
+
+def pack_array(obj):
+ if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
+ raise ValueError(f"Unsupported dtype: {obj.dtype}")
+
+ if isinstance(obj, np.ndarray):
+ return {
+ b"__ndarray__": True,
+ b"data": obj.tobytes(),
+ b"dtype": obj.dtype.str,
+ b"shape": obj.shape,
+ }
+
+ if isinstance(obj, np.generic):
+ return {
+ b"__npgeneric__": True,
+ b"data": obj.item(),
+ b"dtype": obj.dtype.str,
+ }
+
+ return obj
+
+
+def unpack_array(obj):
+ if b"__ndarray__" in obj:
+ return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
+
+ if b"__npgeneric__" in obj:
+ return np.dtype(obj[b"dtype"]).type(obj[b"data"])
+
+ return obj
+
+
+Packer = functools.partial(msgpack.Packer, default=pack_array)
+packb = functools.partial(msgpack.packb, default=pack_array)
+
+Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
+unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
diff --git a/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py b/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c774ba468aa6948e154cb2008bac8f6128a4593
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/msgpack_numpy_test.py
@@ -0,0 +1,45 @@
+import numpy as np
+import pytest
+import tree
+
+from openpi_client import msgpack_numpy
+
+
+def _check(expected, actual):
+ if isinstance(expected, np.ndarray):
+ assert expected.shape == actual.shape
+ assert expected.dtype == actual.dtype
+ assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
+ else:
+ assert expected == actual
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ 1, # int
+ 1.0, # float
+ "hello", # string
+ np.bool_(True), # boolean scalar
+ np.array([1, 2, 3])[0], # int scalar
+ np.str_("asdf"), # string scalar
+ [1, 2, 3], # list
+ {"key": "value"}, # dict
+ {"key": [1, 2, 3]}, # nested dict
+ np.array(1.0), # 0D array
+ np.array([1, 2, 3], dtype=np.int32), # 1D integer array
+ np.array(["asdf", "qwer"]), # string array
+ np.array([True, False]), # boolean array
+ np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
+ np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
+ np.array([np.nan, np.inf, -np.inf]), # special float values
+ {"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
+ [np.array([1, 2]), np.array([3, 4])], # list of arrays
+ np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
+ np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
+ ],
+)
+def test_pack_unpack(data):
+ packed = msgpack_numpy.packb(data)
+ unpacked = msgpack_numpy.unpackb(packed)
+ tree.map_structure(_check, data, unpacked)
diff --git a/openpi/packages/openpi-client/src/openpi_client/runtime/agent.py b/openpi/packages/openpi-client/src/openpi_client/runtime/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2c3ab66ef618ad9ecbff7b81ad9340a4604128c
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/runtime/agent.py
@@ -0,0 +1,17 @@
+import abc
+
+
+class Agent(abc.ABC):
+ """An Agent is the thing with agency, i.e. the entity that makes decisions.
+
+ Agents receive observations about the state of the world, and return actions
+ to take in response.
+ """
+
+ @abc.abstractmethod
+ def get_action(self, observation: dict) -> dict:
+ """Query the agent for the next action."""
+
+ @abc.abstractmethod
+ def reset(self) -> None:
+ """Reset the agent to its initial state."""
diff --git a/openpi/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py b/openpi/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..65227c44dae667d9b2743b6bc1026e791cec35c4
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/runtime/agents/policy_agent.py
@@ -0,0 +1,18 @@
+from typing_extensions import override
+
+from openpi_client import base_policy as _base_policy
+from openpi_client.runtime import agent as _agent
+
+
+class PolicyAgent(_agent.Agent):
+ """An agent that uses a policy to determine actions."""
+
+ def __init__(self, policy: _base_policy.BasePolicy) -> None:
+ self._policy = policy
+
+ @override
+ def get_action(self, observation: dict) -> dict:
+ return self._policy.infer(observation)
+
+ def reset(self) -> None:
+ self._policy.reset()
diff --git a/openpi/packages/openpi-client/src/openpi_client/runtime/environment.py b/openpi/packages/openpi-client/src/openpi_client/runtime/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..664ac4678aaaa3aecf52268a6a09d1d1fc974226
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/runtime/environment.py
@@ -0,0 +1,32 @@
+import abc
+
+
+class Environment(abc.ABC):
+ """An Environment represents the robot and the environment it inhabits.
+
+ The primary contract of environments is that they can be queried for observations
+ about their state, and have actions applied to them to change that state.
+ """
+
+ @abc.abstractmethod
+ def reset(self) -> None:
+ """Reset the environment to its initial state.
+
+ This will be called once before starting each episode.
+ """
+
+ @abc.abstractmethod
+ def is_episode_complete(self) -> bool:
+ """Allow the environment to signal that the episode is complete.
+
+ This will be called after each step. It should return `True` if the episode is
+ complete (either successfully or unsuccessfully), and `False` otherwise.
+ """
+
+ @abc.abstractmethod
+ def get_observation(self) -> dict:
+ """Query the environment for the current state."""
+
+ @abc.abstractmethod
+ def apply_action(self, action: dict) -> None:
+ """Take an action in the environment."""
diff --git a/openpi/packages/openpi-client/src/openpi_client/runtime/runtime.py b/openpi/packages/openpi-client/src/openpi_client/runtime/runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..9552be091a26e163d60cab8071df4716524bf2e8
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/runtime/runtime.py
@@ -0,0 +1,92 @@
+import logging
+import threading
+import time
+
+from openpi_client.runtime import agent as _agent
+from openpi_client.runtime import environment as _environment
+from openpi_client.runtime import subscriber as _subscriber
+
+
+class Runtime:
+ """The core module orchestrating interactions between key components of the system."""
+
+ def __init__(
+ self,
+ environment: _environment.Environment,
+ agent: _agent.Agent,
+ subscribers: list[_subscriber.Subscriber],
+ max_hz: float = 0,
+ num_episodes: int = 1,
+ max_episode_steps: int = 0,
+ ) -> None:
+ self._environment = environment
+ self._agent = agent
+ self._subscribers = subscribers
+ self._max_hz = max_hz
+ self._num_episodes = num_episodes
+ self._max_episode_steps = max_episode_steps
+
+ self._in_episode = False
+ self._episode_steps = 0
+
+ def run(self) -> None:
+ """Runs the runtime loop continuously until stop() is called or the environment is done."""
+ for _ in range(self._num_episodes):
+ self._run_episode()
+
+ # Final reset, this is important for real environments to move the robot to its home position.
+ self._environment.reset()
+
+ def run_in_new_thread(self) -> threading.Thread:
+ """Runs the runtime loop in a new thread."""
+ thread = threading.Thread(target=self.run)
+ thread.start()
+ return thread
+
+ def mark_episode_complete(self) -> None:
+ """Marks the end of an episode."""
+ self._in_episode = False
+
+ def _run_episode(self) -> None:
+ """Runs a single episode."""
+ logging.info("Starting episode...")
+ self._environment.reset()
+ self._agent.reset()
+ for subscriber in self._subscribers:
+ subscriber.on_episode_start()
+
+ self._in_episode = True
+ self._episode_steps = 0
+ step_time = 1 / self._max_hz if self._max_hz > 0 else 0
+ last_step_time = time.time()
+
+ while self._in_episode:
+ self._step()
+ self._episode_steps += 1
+
+ # Sleep to maintain the desired frame rate
+ now = time.time()
+ dt = now - last_step_time
+ if dt < step_time:
+ time.sleep(step_time - dt)
+ last_step_time = time.time()
+ else:
+ last_step_time = now
+
+ logging.info("Episode completed.")
+ for subscriber in self._subscribers:
+ subscriber.on_episode_end()
+
+ def _step(self) -> None:
+ """A single step of the runtime loop."""
+ observation = self._environment.get_observation()
+ action = self._agent.get_action(observation)
+ self._environment.apply_action(action)
+
+ for subscriber in self._subscribers:
+ subscriber.on_step(observation, action)
+
+ if self._environment.is_episode_complete() or (
+ self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
+ ):
+ self.mark_episode_complete()
diff --git a/openpi/packages/openpi-client/src/openpi_client/runtime/subscriber.py b/openpi/packages/openpi-client/src/openpi_client/runtime/subscriber.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c69edaa8e814dfcfe56b78b774578fe37f79428
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/runtime/subscriber.py
@@ -0,0 +1,20 @@
+import abc
+
+
+class Subscriber(abc.ABC):
+ """Subscribes to events in the runtime.
+
+ Subscribers can be used to save data, visualize, etc.
+ """
+
+ @abc.abstractmethod
+ def on_episode_start(self) -> None:
+ """Called when an episode starts."""
+
+ @abc.abstractmethod
+ def on_step(self, observation: dict, action: dict) -> None:
+ """Append a step to the episode."""
+
+ @abc.abstractmethod
+ def on_episode_end(self) -> None:
+ """Called when an episode ends."""
diff --git a/openpi/packages/openpi-client/src/openpi_client/websocket_client_policy.py b/openpi/packages/openpi-client/src/openpi_client/websocket_client_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6244f0f780ceb50c123500456e69340566f060a
--- /dev/null
+++ b/openpi/packages/openpi-client/src/openpi_client/websocket_client_policy.py
@@ -0,0 +1,55 @@
+import logging
+import time
+from typing import Dict, Optional, Tuple
+
+from typing_extensions import override
+import websockets.sync.client
+
+from openpi_client import base_policy as _base_policy
+from openpi_client import msgpack_numpy
+
+
+class WebsocketClientPolicy(_base_policy.BasePolicy):
+ """Implements the Policy interface by communicating with a server over websocket.
+
+ See WebsocketPolicyServer for a corresponding server implementation.
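+
+    A minimal usage sketch (host/port are whatever the policy server was started with):
+
+        policy = WebsocketClientPolicy(host="localhost", port=8000)
+        result = policy.infer(obs)  # the obs format depends on the server-side policy's input transforms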
+ """
+
+ def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
+ self._uri = f"ws://{host}"
+ if port is not None:
+ self._uri += f":{port}"
+ self._packer = msgpack_numpy.Packer()
+ self._api_key = api_key
+ self._ws, self._server_metadata = self._wait_for_server()
+
+ def get_server_metadata(self) -> Dict:
+ return self._server_metadata
+
+ def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
+ logging.info(f"Waiting for server at {self._uri}...")
+ while True:
+ try:
+ headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
+ conn = websockets.sync.client.connect(
+ self._uri, compression=None, max_size=None, additional_headers=headers
+ )
+ metadata = msgpack_numpy.unpackb(conn.recv())
+ return conn, metadata
+ except ConnectionRefusedError:
+ logging.info("Still waiting for server...")
+ time.sleep(5)
+
+ @override
+ def infer(self, obs: Dict) -> Dict: # noqa: UP006
+ data = self._packer.pack(obs)
+ self._ws.send(data)
+ response = self._ws.recv()
+ if isinstance(response, str):
+ # we're expecting bytes; if the server sends a string, it's an error.
+ raise RuntimeError(f"Error in inference server:\n{response}")
+ return msgpack_numpy.unpackb(response)
+
+ @override
+ def reset(self) -> None:
+ pass
diff --git a/openpi/policy_postprocessor.json b/openpi/policy_postprocessor.json
new file mode 100644
index 0000000000000000000000000000000000000000..7bdb69e1ef55a1228f95a0fd293d8ec1bf9f4911
--- /dev/null
+++ b/openpi/policy_postprocessor.json
@@ -0,0 +1,31 @@
+{
+ "name": "policy_postprocessor",
+ "steps": [
+ {
+ "registry_name": "unnormalizer_processor",
+ "config": {
+ "eps": 1e-08,
+ "features": {
+ "action": {
+ "type": "ACTION",
+ "shape": [
+ 6
+ ]
+ }
+ },
+ "norm_map": {
+ "VISUAL": "IDENTITY",
+ "STATE": "MEAN_STD",
+ "ACTION": "MEAN_STD"
+ }
+ }
+ },
+ {
+ "registry_name": "device_processor",
+ "config": {
+ "device": "cpu",
+ "float_dtype": null
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/openpi/policy_preprocessor.json b/openpi/policy_preprocessor.json
new file mode 100644
index 0000000000000000000000000000000000000000..da4d59c03ed262a77fda1bcb5b12eabc5f1e8325
--- /dev/null
+++ b/openpi/policy_preprocessor.json
@@ -0,0 +1,86 @@
+{
+ "name": "policy_preprocessor",
+ "steps": [
+ {
+ "registry_name": "rename_observations_processor",
+ "config": {
+ "rename_map": {}
+ }
+ },
+ {
+ "registry_name": "to_batch_processor",
+ "config": {}
+ },
+ {
+ "registry_name": "pi0_new_line_processor",
+ "config": {}
+ },
+ {
+ "registry_name": "tokenizer_processor",
+ "config": {
+ "max_length": 48,
+ "task_key": "task",
+ "padding_side": "right",
+ "padding": "max_length",
+ "truncation": true,
+ "tokenizer_name": "google/paligemma-3b-pt-224"
+ }
+ },
+ {
+ "registry_name": "device_processor",
+ "config": {
+ "device": "cpu",
+ "float_dtype": null
+ }
+ },
+ {
+ "registry_name": "normalizer_processor",
+ "config": {
+ "eps": 1e-08,
+ "features": {
+ "observation.state": {
+ "type": "STATE",
+ "shape": [
+ 6
+ ]
+ },
+ "observation.images.camera0": {
+ "type": "VISUAL",
+ "shape": [
+ 3,
+ 480,
+ 640
+ ]
+ },
+ "observation.images.camera1": {
+ "type": "VISUAL",
+ "shape": [
+ 3,
+ 480,
+ 640
+ ]
+ },
+ "observation.images.camera2": {
+ "type": "VISUAL",
+ "shape": [
+ 3,
+ 480,
+ 640
+ ]
+ },
+ "action": {
+ "type": "ACTION",
+ "shape": [
+ 6
+ ]
+ }
+ },
+ "norm_map": {
+ "VISUAL": "IDENTITY",
+ "STATE": "MEAN_STD",
+ "ACTION": "MEAN_STD"
+ }
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/openpi/pyproject.toml b/openpi/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..c4a06e53284ecc78a05923e374c15e700a20676c
--- /dev/null
+++ b/openpi/pyproject.toml
@@ -0,0 +1,137 @@
+[project]
+name = "openpi"
+version = "0.1.0"
+description = "Physical Intelligence open source repo"
+readme = "README.md"
+requires-python = ">=3.11"
+license = { file = "LICENSE" }
+dependencies = [
+ "augmax>=0.3.4",
+ "dm-tree>=0.1.8",
+ "einops>=0.8.0",
+ "equinox>=0.11.8",
+ "flatbuffers>=24.3.25",
+ "flax==0.10.2",
+ "fsspec[gcs]>=2024.6.0",
+ "gym-aloha>=0.1.1",
+ "imageio>=2.36.1",
+ "jax[cuda12]==0.5.3",
+ "jaxtyping==0.2.36",
+ "lerobot",
+ "ml_collections==1.0.0",
+ "numpy>=1.22.4,<2.0.0",
+ "numpydantic>=1.6.6",
+ "opencv-python>=4.10.0.84",
+ "openpi-client",
+ "orbax-checkpoint==0.11.13",
+ "pillow>=11.0.0",
+ "sentencepiece>=0.2.0",
+ "torch==2.7.1",
+ "tqdm-loggable>=0.2",
+ "typing-extensions>=4.12.2",
+ "tyro>=0.9.5",
+ "wandb>=0.19.1",
+ "filelock>=3.16.1",
+ "beartype==0.19.0",
+ "treescope>=0.1.7",
+ "transformers==4.53.2",
+ "rich>=14.0.0",
+ "polars>=1.30.0",
+]
+
+
+[project.urls]
+Repository = "https://github.com/Physical-Intelligence/openpi"
+
+[dependency-groups]
+dev = [
+ "pytest>=8.3.4",
+ "ruff>=0.8.6",
+ "pre-commit>=4.0.1",
+ "ipykernel>=6.29.5",
+ "ipywidgets>=8.1.5",
+ "matplotlib>=3.10.0",
+ "pynvml>=12.0.0",
+]
+rlds = [
+ "dlimp",
+ "tensorflow-cpu==2.15.0",
+ "tensorflow-datasets==4.9.9",
+]
+
+[tool.uv]
+override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
+
+[tool.uv.sources]
+openpi-client = { workspace = true }
+lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
+dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
+
+[tool.uv.workspace]
+members = ["packages/*"]
+
+[tool.ruff]
+line-length = 120
+target-version = "py311"
+extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
+
+[tool.ruff.lint]
+# https://docs.astral.sh/ruff/rules/
+select = [
+ "B",
+ "C4",
+ "DTZ",
+ "E4",
+ "E7",
+ "E9",
+ "F",
+ "FBT",
+ "FURB",
+ "I",
+ "ICN",
+ "ISC",
+ "LOG",
+ "N",
+ "PD",
+ "PERF",
+ "PIE",
+ "PLC",
+ "PLE",
+ "PLR1",
+ "PLR5",
+ "PLW",
+ "PT",
+ "Q",
+ "RET",
+ "RUF",
+ "SIM",
+ "SLF",
+ "T10",
+ "T20",
+ "UP",
+ "W",
+]
+ignore = [
+ "F722", # Conflicts with array typing.
+ "T201", # We use print statements.
+ "PD008", # Lots of false positives.
+ "ISC001", # Disabling to support ruff format.
+ "LOG015", # Use logger.info.
+]
+unfixable = [
+ "B905", # Fix defaults to strict=False, which is not what we want.
+]
+
+[tool.ruff.lint.isort]
+force-single-line = true
+force-sort-within-sections = true
+single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
+known-third-party = ["wandb"]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.pytest.ini_options]
+markers = ["manual: should be run manually."]
+testpaths = ["src", "scripts", "packages"]
diff --git a/openpi/scripts/__init__.py b/openpi/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openpi/scripts/compute_norm_stats.py b/openpi/scripts/compute_norm_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8aef8722210e21b2de41c315e2ce840f8b35e4e
--- /dev/null
+++ b/openpi/scripts/compute_norm_stats.py
@@ -0,0 +1,117 @@
+"""Compute normalization statistics for a config.
+
+This script is used to compute the normalization statistics for a given config. It
+will compute the mean and standard deviation of the data in the dataset and save it
+to the config assets directory.
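+
+Example (the config name must match one registered in openpi.training.config):
+
+    uv run scripts/compute_norm_stats.py --config-name pi0_aloha_sim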
+"""
+
+import numpy as np
+import tqdm
+import tyro
+
+import openpi.models.model as _model
+import openpi.shared.normalize as normalize
+import openpi.training.config as _config
+import openpi.training.data_loader as _data_loader
+import openpi.transforms as transforms
+
+
+class RemoveStrings(transforms.DataTransformFn):
+ def __call__(self, x: dict) -> dict:
+ return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
+
+
+def create_torch_dataloader(
+ data_config: _config.DataConfig,
+ action_horizon: int,
+ batch_size: int,
+ model_config: _model.BaseModelConfig,
+ num_workers: int,
+ max_frames: int | None = None,
+) -> tuple[_data_loader.Dataset, int]:
+ if data_config.repo_id is None:
+ raise ValueError("Data config must have a repo_id")
+ dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
+ dataset = _data_loader.TransformedDataset(
+ dataset,
+ [
+ *data_config.repack_transforms.inputs,
+ *data_config.data_transforms.inputs,
+ # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
+ RemoveStrings(),
+ ],
+ )
+ if max_frames is not None and max_frames < len(dataset):
+ num_batches = max_frames // batch_size
+ shuffle = True
+ else:
+ num_batches = len(dataset) // batch_size
+ shuffle = False
+ data_loader = _data_loader.TorchDataLoader(
+ dataset,
+ local_batch_size=batch_size,
+ num_workers=num_workers,
+ shuffle=shuffle,
+ num_batches=num_batches,
+ )
+ return data_loader, num_batches
+
+
+def create_rlds_dataloader(
+ data_config: _config.DataConfig,
+ action_horizon: int,
+ batch_size: int,
+ max_frames: int | None = None,
+) -> tuple[_data_loader.Dataset, int]:
+ dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)
+ dataset = _data_loader.IterableTransformedDataset(
+ dataset,
+ [
+ *data_config.repack_transforms.inputs,
+ *data_config.data_transforms.inputs,
+ # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
+ RemoveStrings(),
+ ],
+ is_batched=True,
+ )
+ if max_frames is not None and max_frames < len(dataset):
+ num_batches = max_frames // batch_size
+ else:
+ # NOTE: this length is currently hard-coded for DROID.
+ num_batches = len(dataset) // batch_size
+ data_loader = _data_loader.RLDSDataLoader(
+ dataset,
+ num_batches=num_batches,
+ )
+ return data_loader, num_batches
+
+
+def main(config_name: str, max_frames: int | None = None):
+ config = _config.get_config(config_name)
+ data_config = config.data.create(config.assets_dirs, config.model)
+
+ if data_config.rlds_data_dir is not None:
+ data_loader, num_batches = create_rlds_dataloader(
+ data_config, config.model.action_horizon, config.batch_size, max_frames
+ )
+ else:
+ data_loader, num_batches = create_torch_dataloader(
+ data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames
+ )
+
+ keys = ["state", "actions"]
+ stats = {key: normalize.RunningStats() for key in keys}
+
+ for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
+ for key in keys:
+ stats[key].update(np.asarray(batch[key]))
+
+ norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
+
+ output_path = config.assets_dirs / data_config.repo_id
+ print(f"Writing stats to: {output_path}")
+ normalize.save(output_path, norm_stats)
+
+
+if __name__ == "__main__":
+ tyro.cli(main)
diff --git a/openpi/scripts/docker/compose.yml b/openpi/scripts/docker/compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..564d276e26bdcfd6397e666f2c73b9b5b353a6ac
--- /dev/null
+++ b/openpi/scripts/docker/compose.yml
@@ -0,0 +1,29 @@
+# Run with:
+# docker compose -f scripts/docker/compose.yml up --build
+services:
+ openpi_server:
+ image: openpi_server
+ build:
+ context: ../..
+ dockerfile: scripts/docker/serve_policy.Dockerfile
+ init: true
+ tty: true
+ network_mode: host
+    # Mount the repository into the container.
+    # Mount the configured openpi data home (default ~/.cache/openpi) to /openpi_assets inside the container.
+ volumes:
+ - $PWD:/app
+ - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
+ environment:
+ - SERVER_ARGS
+ - OPENPI_DATA_HOME=/openpi_assets
+ - IS_DOCKER=true
+
+ # Comment out this block if not running on a machine with GPUs.
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: 1
+ capabilities: [gpu]
diff --git a/openpi/scripts/docker/install_docker_ubuntu22.sh b/openpi/scripts/docker/install_docker_ubuntu22.sh
new file mode 100644
index 0000000000000000000000000000000000000000..38873b3e379ee40e6f80fe86a88be7dae494e05b
--- /dev/null
+++ b/openpi/scripts/docker/install_docker_ubuntu22.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Add Docker's official GPG key:
+sudo apt-get update
+sudo apt-get install -y ca-certificates curl
+sudo install -m 0755 -d /etc/apt/keyrings
+sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
+sudo chmod a+r /etc/apt/keyrings/docker.asc
+
+# Add the repository to Apt sources:
+echo \
+ "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
+ $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
+ sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
+sudo apt-get update
+
+sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
+
+# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
+# See https://docs.docker.com/engine/install/linux-postinstall/
+username=$(whoami)
+sudo usermod -aG docker $username
+
+# Configure docker to start automatically on system boot.
+sudo systemctl enable docker.service
+sudo systemctl enable containerd.service
+
+# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
+if [ -f ~/.docker/config.json ]; then
+ sed -i 's/credsStore/credStore/g' ~/.docker/config.json
+fi
+
+echo ""
+echo "********************************************************************"
+echo "**** Restart to allow Docker permission changes to take effect. ****"
+echo "********************************************************************"
+echo ""
diff --git a/openpi/scripts/docker/install_nvidia_container_toolkit.sh b/openpi/scripts/docker/install_nvidia_container_toolkit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a4c67f1d5bcc6655f7ae2084a8866037b819b4f0
--- /dev/null
+++ b/openpi/scripts/docker/install_nvidia_container_toolkit.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
+# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
+
+curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
+ curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
+ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
+ sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
+
+# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
+sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
+sudo apt-get update
+sudo apt-get install -y nvidia-container-toolkit
+
+sudo nvidia-ctk runtime configure --runtime=docker
+sudo systemctl restart docker
diff --git a/openpi/scripts/docker/serve_policy.Dockerfile b/openpi/scripts/docker/serve_policy.Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..bd88a7e63455aa611ef77c2109df20a977afca6e
--- /dev/null
+++ b/openpi/scripts/docker/serve_policy.Dockerfile
@@ -0,0 +1,38 @@
+# Dockerfile for serving a PI policy.
+# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
+
+# Build the container:
+# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
+
+# Run the container:
+# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
+
+FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
+COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
+
+WORKDIR /app
+
+# Needed because LeRobot uses git-lfs.
+RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Write the virtual environment outside of the project directory so it doesn't
+# leak out of the container when we mount the application code.
+ENV UV_PROJECT_ENVIRONMENT=/.venv
+
+# Install the project's dependencies using the lockfile and settings
+RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
+RUN --mount=type=cache,target=/root/.cache/uv \
+ --mount=type=bind,source=uv.lock,target=uv.lock \
+ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
+ --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
+ --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
+ GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
+
+# Copy transformers_replace files while preserving directory structure
+COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
+RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace
+
+CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
diff --git a/openpi/scripts/serve_policy.py b/openpi/scripts/serve_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..30f121a60ba6af3d21c287e5c5582da54072ea62
--- /dev/null
+++ b/openpi/scripts/serve_policy.py
@@ -0,0 +1,122 @@
+import dataclasses
+import enum
+import logging
+import socket
+
+import tyro
+
+from openpi.policies import policy as _policy
+from openpi.policies import policy_config as _policy_config
+from openpi.serving import websocket_policy_server
+from openpi.training import config as _config
+
+
+class EnvMode(enum.Enum):
+ """Supported environments."""
+
+ ALOHA = "aloha"
+ ALOHA_SIM = "aloha_sim"
+ DROID = "droid"
+ LIBERO = "libero"
+
+
+@dataclasses.dataclass
+class Checkpoint:
+ """Load a policy from a trained checkpoint."""
+
+ # Training config name (e.g., "pi0_aloha_sim").
+ config: str
+ # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
+ dir: str
+
+
+@dataclasses.dataclass
+class Default:
+ """Use the default policy for the given environment."""
+
+
+@dataclasses.dataclass
+class Args:
+ """Arguments for the serve_policy script."""
+
+ # Environment to serve the policy for. This is only used when serving default policies.
+ env: EnvMode = EnvMode.ALOHA_SIM
+
+ # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
+ # prompt.
+ default_prompt: str | None = None
+
+ # Port to serve the policy on.
+ port: int = 8000
+ # Record the policy's behavior for debugging.
+ record: bool = False
+
+ # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
+ policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
+
+
+# Default checkpoints that should be used for each environment.
+DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
+ EnvMode.ALOHA: Checkpoint(
+ config="pi05_aloha",
+ dir="gs://openpi-assets/checkpoints/pi05_base",
+ ),
+ EnvMode.ALOHA_SIM: Checkpoint(
+ config="pi0_aloha_sim",
+ dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
+ ),
+ EnvMode.DROID: Checkpoint(
+ config="pi05_droid",
+ dir="gs://openpi-assets/checkpoints/pi05_droid",
+ ),
+ EnvMode.LIBERO: Checkpoint(
+ config="pi05_libero",
+ dir="gs://openpi-assets/checkpoints/pi05_libero",
+ ),
+}
+
+
+def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
+ """Create a default policy for the given environment."""
+ if checkpoint := DEFAULT_CHECKPOINT.get(env):
+ return _policy_config.create_trained_policy(
+ _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
+ )
+ raise ValueError(f"Unsupported environment mode: {env}")
+
+
+def create_policy(args: Args) -> _policy.Policy:
+ """Create a policy from the given arguments."""
+ match args.policy:
+ case Checkpoint():
+ return _policy_config.create_trained_policy(
+ _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
+ )
+ case Default():
+ return create_default_policy(args.env, default_prompt=args.default_prompt)
+
+
+def main(args: Args) -> None:
+ policy = create_policy(args)
+ policy_metadata = policy.metadata
+
+ # Record the policy's behavior.
+ if args.record:
+ policy = _policy.PolicyRecorder(policy, "policy_records")
+
+ hostname = socket.gethostname()
+ local_ip = socket.gethostbyname(hostname)
+ logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
+
+ server = websocket_policy_server.WebsocketPolicyServer(
+ policy=policy,
+ host="0.0.0.0",
+ port=args.port,
+ metadata=policy_metadata,
+ )
+ server.serve_forever()
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO, force=True)
+ main(tyro.cli(Args))
diff --git a/openpi/scripts/train.py b/openpi/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d289413abcac2232622a400ba54ac8cd65d4854
--- /dev/null
+++ b/openpi/scripts/train.py
@@ -0,0 +1,280 @@
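+"""JAX training entrypoint for openpi models.
+
+Example usage (the experiment name is just a placeholder; see the top-level README for details):
+    uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
+"""
+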
+import dataclasses
+import functools
+import logging
+import platform
+from typing import Any
+
+import etils.epath as epath
+import flax.nnx as nnx
+from flax.training import common_utils
+import flax.traverse_util as traverse_util
+import jax
+import jax.experimental
+import jax.numpy as jnp
+import numpy as np
+import optax
+import tqdm_loggable.auto as tqdm
+import wandb
+
+import openpi.models.model as _model
+import openpi.shared.array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+import openpi.training.checkpoints as _checkpoints
+import openpi.training.config as _config
+import openpi.training.data_loader as _data_loader
+import openpi.training.optimizer as _optimizer
+import openpi.training.sharding as sharding
+import openpi.training.utils as training_utils
+import openpi.training.weight_loaders as _weight_loaders
+
+
+def init_logging():
+ """Custom logging format for better readability."""
+ level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
+
+ class CustomFormatter(logging.Formatter):
+ def format(self, record):
+ record.levelname = level_mapping.get(record.levelname, record.levelname)
+ return super().format(record)
+
+ formatter = CustomFormatter(
+ fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
+ datefmt="%H:%M:%S",
+ )
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ logger.handlers[0].setFormatter(formatter)
+
+
+def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
+ if not enabled:
+ wandb.init(mode="disabled")
+ return
+
+ ckpt_dir = config.checkpoint_dir
+ if not ckpt_dir.exists():
+ raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
+ if resuming:
+ run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
+ wandb.init(id=run_id, resume="must", project=config.project_name)
+ else:
+ wandb.init(
+ name=config.exp_name,
+ config=dataclasses.asdict(config),
+ project=config.project_name,
+ )
+ (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
+
+ if log_code:
+ wandb.run.log_code(epath.Path(__file__).parent.parent)
+
+
+def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
+ """Loads and validates the weights. Returns a loaded subset of the weights."""
+ loaded_params = loader.load(params_shape)
+ at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
+
+ # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
+ return traverse_util.unflatten_dict(
+ {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
+ )
+
+
+@at.typecheck
+def init_train_state(
+ config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
+) -> tuple[training_utils.TrainState, Any]:
+ tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
+
+ def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
+ rng, model_rng = jax.random.split(rng)
+ # initialize the model (and its parameters).
+ model = config.model.create(model_rng)
+
+ # Merge the partial params into the model.
+ if partial_params is not None:
+ graphdef, state = nnx.split(model)
+ # This will produce an error if the partial params are not a subset of the state.
+ state.replace_by_pure_dict(partial_params)
+ model = nnx.merge(graphdef, state)
+
+ params = nnx.state(model)
+ # Convert frozen params to bfloat16.
+ params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
+
+ return training_utils.TrainState(
+ step=0,
+ params=params,
+ model_def=nnx.graphdef(model),
+ tx=tx,
+ opt_state=tx.init(params.filter(config.trainable_filter)),
+ ema_decay=config.ema_decay,
+ ema_params=None if config.ema_decay is None else params,
+ )
+
+ train_state_shape = jax.eval_shape(init, init_rng)
+ state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
+
+ if resume:
+ return train_state_shape, state_sharding
+
+ partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
+ replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+
+ # Initialize the train state and mix in the partial params.
+ train_state = jax.jit(
+ init,
+ donate_argnums=(1,), # donate the partial params buffer.
+ in_shardings=replicated_sharding,
+ out_shardings=state_sharding,
+ )(init_rng, partial_params)
+
+ return train_state, state_sharding
+
+
+@at.typecheck
+def train_step(
+ config: _config.TrainConfig,
+ rng: at.KeyArrayLike,
+ state: training_utils.TrainState,
+ batch: tuple[_model.Observation, _model.Actions],
+) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
+ model = nnx.merge(state.model_def, state.params)
+ model.train()
+
+ @at.typecheck
+ def loss_fn(
+ model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
+ ):
+ chunked_loss = model.compute_loss(rng, observation, actions, train=True)
+ return jnp.mean(chunked_loss)
+
+ train_rng = jax.random.fold_in(rng, state.step)
+ observation, actions = batch
+
+ # Filter out frozen params.
+ diff_state = nnx.DiffState(0, config.trainable_filter)
+ loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
+
+ params = state.params.filter(config.trainable_filter)
+ updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
+ new_params = optax.apply_updates(params, updates)
+
+ # Update the model in place and return the new full state.
+ nnx.update(model, new_params)
+ new_params = nnx.state(model)
+
+ new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
+ if state.ema_decay is not None:
+ new_state = dataclasses.replace(
+ new_state,
+ ema_params=jax.tree.map(
+ lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
+ ),
+ )
+
+ # Filter out params that aren't kernels.
+ kernel_params = nnx.state(
+ model,
+ nnx.All(
+ nnx.Param,
+ nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
+ lambda _, x: x.value.ndim > 1,
+ ),
+ )
+ info = {
+ "loss": loss,
+ "grad_norm": optax.global_norm(grads),
+ "param_norm": optax.global_norm(kernel_params),
+ }
+ return new_state, info
+
+
+def main(config: _config.TrainConfig):
+ init_logging()
+ logging.info(f"Running on: {platform.node()}")
+
+ if config.batch_size % jax.device_count() != 0:
+ raise ValueError(
+ f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
+ )
+
+ jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
+
+ rng = jax.random.key(config.seed)
+ train_rng, init_rng = jax.random.split(rng)
+
+ mesh = sharding.make_mesh(config.fsdp_devices)
+ data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
+ replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+
+ checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
+ config.checkpoint_dir,
+ keep_period=config.keep_period,
+ overwrite=config.overwrite,
+ resume=config.resume,
+ )
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
+
+ data_loader = _data_loader.create_data_loader(
+ config,
+ sharding=data_sharding,
+ shuffle=True,
+ )
+ data_iter = iter(data_loader)
+ batch = next(data_iter)
+ logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
+
+ # Log images from first batch to sanity check.
+ images_to_log = [
+ wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
+ for i in range(min(5, len(next(iter(batch[0].images.values())))))
+ ]
+ wandb.log({"camera_views": images_to_log}, step=0)
+
+ train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
+ jax.block_until_ready(train_state)
+ logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
+
+ if resuming:
+ train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
+
+ ptrain_step = jax.jit(
+ functools.partial(train_step, config),
+ in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
+ out_shardings=(train_state_sharding, replicated_sharding),
+ donate_argnums=(1,),
+ )
+
+ start_step = int(train_state.step)
+ pbar = tqdm.tqdm(
+ range(start_step, config.num_train_steps),
+ initial=start_step,
+ total=config.num_train_steps,
+ dynamic_ncols=True,
+ )
+
+ infos = []
+ for step in pbar:
+ with sharding.set_mesh(mesh):
+ train_state, info = ptrain_step(train_rng, train_state, batch)
+ infos.append(info)
+ if step % config.log_interval == 0:
+ stacked_infos = common_utils.stack_forest(infos)
+ reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
+ info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
+ pbar.write(f"Step {step}: {info_str}")
+ wandb.log(reduced_info, step=step)
+ infos = []
+ batch = next(data_iter)
+
+ if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
+ _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
+
+ logging.info("Waiting for checkpoint manager to finish")
+ checkpoint_manager.wait_until_finished()
+
+
+if __name__ == "__main__":
+ main(_config.cli())
diff --git a/openpi/scripts/train_pytorch.py b/openpi/scripts/train_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ddd2b5950bd4d208e4831a3abab42dcf2ccee7
--- /dev/null
+++ b/openpi/scripts/train_pytorch.py
@@ -0,0 +1,632 @@
+"""
+PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
+This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
+entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
+pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
+
+Usage:
+Single GPU:
+    python scripts/train_pytorch.py <config_name> --exp_name <exp_name> --save_interval <save_interval>
+    Example:
+    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
+    python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume  # Resume from latest checkpoint
+Multi-GPU (single node):
+    torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <exp_name>
+    Example:
+    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
+    torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
+Multi-Node Training:
+    torchrun \
+        --nnodes=<num_nodes> --nproc_per_node=<num_gpus_per_node> --node_rank=<node_rank> \
+        --master_addr=<master_addr> --master_port=<master_port> \
+        scripts/train_pytorch.py <config_name> --exp_name=<exp_name> --save_interval <save_interval>
+
+"""
+
+import dataclasses
+import gc
+import logging
+import os
+import platform
+import shutil
+import time
+
+import jax
+import numpy as np
+import safetensors.torch
+import torch
+import torch.distributed as dist
+import torch.nn.parallel
+import tqdm
+import wandb
+
+import openpi.models.pi0_config
+import openpi.models_pytorch.pi0_pytorch
+import openpi.shared.normalize as _normalize
+import openpi.training.config as _config
+import openpi.training.data_loader as _data
+
+
+def init_logging():
+ level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
+
+ class CustomFormatter(logging.Formatter):
+ def format(self, record):
+ record.levelname = level_mapping.get(record.levelname, record.levelname)
+ return super().format(record)
+
+ formatter = CustomFormatter(
+ fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
+ datefmt="%H:%M:%S",
+ )
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+ if not logger.handlers:
+ ch = logging.StreamHandler()
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+ else:
+ logger.handlers[0].setFormatter(formatter)
+
+
+def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
+ """Initialize wandb logging."""
+ if not enabled:
+ wandb.init(mode="disabled")
+ return
+
+ ckpt_dir = config.checkpoint_dir
+ if not ckpt_dir.exists():
+ raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
+
+ if resuming:
+ run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
+ wandb.init(id=run_id, resume="must", project=config.project_name)
+ else:
+ wandb.init(
+ name=config.exp_name,
+ config=dataclasses.asdict(config),
+ project=config.project_name,
+ )
+ (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
+
+
+def setup_ddp():
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ use_ddp = world_size > 1
+ if use_ddp and not torch.distributed.is_initialized():
+ backend = "nccl" if torch.cuda.is_available() else "gloo"
+ torch.distributed.init_process_group(backend=backend, init_method="env://")
+
+ # Set up debugging environment variables for DDP issues
+ if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
+ os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
+
+ local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
+ device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+ if torch.cuda.is_available():
+ torch.cuda.set_device(device)
+ return use_ddp, local_rank, device
+
+
+def cleanup_ddp():
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+ torch.distributed.destroy_process_group()
+
+
+def set_seed(seed: int, local_rank: int):
+ torch.manual_seed(seed + local_rank)
+ np.random.seed(seed + local_rank)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed + local_rank)
+
+
+def build_datasets(config: _config.TrainConfig):
+ # Use the unified data loader with PyTorch framework
+ data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
+ return data_loader, data_loader.data_config()
+
+
+def get_model_state_dict(model):
+ """Get state dict from model, handling DDP wrapper."""
+ return (
+ model.module.state_dict()
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model.state_dict()
+ )
+
+
+def get_model_parameters(model):
+ """Get parameters from model, handling DDP wrapper."""
+ return (
+ model.module.parameters()
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel)
+ else model.parameters()
+ )
+
+
+def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
+ """Save a checkpoint with model state, optimizer state, and metadata."""
+ if not is_main:
+ return
+
+ # Only save if it's time to save or if it's the final step
+ if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
+ # Create temporary directory for atomic checkpoint saving
+ final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
+ tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
+
+ # Remove any existing temp directory and create new one
+ if tmp_ckpt_dir.exists():
+ shutil.rmtree(tmp_ckpt_dir)
+ tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+ # Save model state using safetensors (handle shared tensors)
+ model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
+ safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")
+
+ # Save optimizer state using PyTorch format
+ torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
+
+ # Save training metadata (the config is stored as a plain dict via dataclasses.asdict to avoid pickling JAX/Flax objects)
+ metadata = {
+ "global_step": global_step,
+ "config": dataclasses.asdict(config),
+ "timestamp": time.time(),
+ }
+ torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
+
+ # save norm stats
+ norm_stats = data_config.norm_stats
+ if norm_stats is not None and data_config.asset_id is not None:
+ _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)
+
+ # Atomically move temp directory to final location
+ if final_ckpt_dir.exists():
+ shutil.rmtree(final_ckpt_dir)
+ tmp_ckpt_dir.rename(final_ckpt_dir)
+
+ logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
+
+ # Log checkpoint to wandb
+ if config.wandb_enabled:
+ wandb.log({"checkpoint_step": global_step}, step=global_step)
+
+
+def load_checkpoint(model, optimizer, checkpoint_dir, device):
+ """Load the latest checkpoint and return the global step."""
+ checkpoint_steps = [
+ int(d.name)
+ for d in checkpoint_dir.iterdir()
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
+ ]
+
+ if not checkpoint_steps:
+ raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
+
+ latest_step = max(checkpoint_steps)
+ ckpt_dir = checkpoint_dir / f"{latest_step}"
+
+ # Clear memory before loading checkpoints
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, "before_loading_checkpoint")
+
+ try:
+ # Load model state with error handling
+ logging.info("Loading model state...")
+ safetensors_path = ckpt_dir / "model.safetensors"
+
+ if safetensors_path.exists():
+ model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
+ safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
+ logging.info("Loaded model state from safetensors format")
+ else:
+ raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
+
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, "after_loading_model")
+
+ # Load optimizer state with error handling
+ logging.info("Loading optimizer state...")
+ optimizer_path = ckpt_dir / "optimizer.pt"
+
+ if optimizer_path.exists():
+ optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
+ logging.info("Loaded optimizer state from pt format")
+ else:
+ raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
+
+ optimizer.load_state_dict(optimizer_state_dict)
+ del optimizer_state_dict
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, "after_loading_optimizer")
+
+ # Load metadata
+ logging.info("Loading metadata...")
+ metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
+ global_step = metadata.get("global_step", latest_step)
+ del metadata
+ torch.cuda.empty_cache()
+ gc.collect()
+ log_memory_usage(device, latest_step, "after_loading_metadata")
+
+ logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
+ return global_step
+
+ except RuntimeError as e:
+ if "out of memory" in str(e):
+ # Clear memory and provide detailed error message
+ torch.cuda.empty_cache()
+ gc.collect()
+ logging.error(f"Out of memory error while loading checkpoint: {e!s}")
+ log_memory_usage(device, latest_step, "after_oom_error")
+ raise RuntimeError(
+ "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
+ ) from e
+ raise
+
+
+def get_latest_checkpoint_step(checkpoint_dir):
+ """Get the latest checkpoint step number from a checkpoint directory."""
+ checkpoint_steps = [
+ int(d.name)
+ for d in checkpoint_dir.iterdir()
+ if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
+ ]
+ return max(checkpoint_steps) if checkpoint_steps else None
+
+
+def log_memory_usage(device, step, phase="unknown"):
+ """Log detailed memory usage information."""
+ if not torch.cuda.is_available():
+ return
+
+ memory_allocated = torch.cuda.memory_allocated(device) / 1e9
+ memory_reserved = torch.cuda.memory_reserved(device) / 1e9
+ # "free" here is reserved-but-unallocated memory within the CUDA caching allocator
+ memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9
+
+ # Get more detailed memory info
+ memory_stats = torch.cuda.memory_stats(device)
+ max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
+ max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9
+
+ # Get DDP info if available
+ ddp_info = ""
+ if dist.is_initialized():
+ ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
+
+ logging.info(
+ f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
+ )
+
+
+def train_loop(config: _config.TrainConfig):
+ use_ddp, local_rank, device = setup_ddp()
+ is_main = (not use_ddp) or (dist.get_rank() == 0)
+ set_seed(config.seed, local_rank)
+
+ # Initialize checkpoint directory and wandb
+ resuming = False
+ if config.resume:
+ # Find checkpoint directory based on experiment name
+ exp_checkpoint_dir = config.checkpoint_dir
+ if exp_checkpoint_dir.exists():
+ # Use validation to find the latest working checkpoint
+ latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
+ if latest_step is not None:
+ resuming = True
+ logging.info(
+ f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
+ )
+ else:
+ raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
+ else:
+ raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
+ elif config.overwrite and config.checkpoint_dir.exists():
+ shutil.rmtree(config.checkpoint_dir)
+ logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
+
+ # Create checkpoint directory with experiment name
+ if not resuming:
+ # For new runs, create experiment-specific checkpoint directory
+ exp_checkpoint_dir = config.checkpoint_dir
+ exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+ logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
+ else:
+ # For resume, checkpoint_dir is already set to the experiment directory
+ logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
+
+ # Initialize wandb (only on main process)
+ if is_main:
+ init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
+
+ # Build data loader using the unified data loader
+ # Calculate effective batch size per GPU for DDP
+ # For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
+ world_size = torch.distributed.get_world_size() if use_ddp else 1
+ effective_batch_size = config.batch_size // world_size
+ logging.info(
+ f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
+ )
+
+ # Pass the original batch size to data loader - it will handle DDP splitting internally
+ loader, data_config = build_datasets(config)
+
+ # Log sample images to wandb on first batch
+ if is_main and config.wandb_enabled and not resuming:
+ # Create a separate data loader for sample batch to avoid consuming the main loader
+ sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
+ sample_batch = next(iter(sample_data_loader))
+ # Convert observation and actions to torch tensors
+ observation, actions = sample_batch
+ sample_batch = observation.to_dict()
+ sample_batch["actions"] = actions
+
+ # Create sample images for wandb
+ images_to_log = []
+ # Get batch size from the first image tensor
+ batch_size = next(iter(sample_batch["image"].values())).shape[0]
+ for i in range(min(5, batch_size)):
+ # Concatenate all camera views horizontally for this batch item
+ # Convert from NCHW to NHWC format for wandb
+ img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
+ img_concatenated = img_concatenated.cpu().numpy()
+ images_to_log.append(wandb.Image(img_concatenated))
+
+ wandb.log({"camera_views": images_to_log}, step=0)
+
+ # Clear sample batch from memory aggressively
+ del sample_batch, observation, actions, images_to_log, img_concatenated
+ del sample_data_loader # Also delete the sample data loader
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ logging.info("Cleared sample batch and data loader from memory")
+
+ # Build model
+ if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
+ # Convert dataclass to Pi0Config if needed
+ model_cfg = openpi.models.pi0_config.Pi0Config(
+ dtype=config.pytorch_training_precision,
+ action_dim=config.model.action_dim,
+ action_horizon=config.model.action_horizon,
+ max_token_len=config.model.max_token_len,
+ paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
+ action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
+ pi05=getattr(config.model, "pi05", False),
+ )
+ else:
+ model_cfg = config.model
+ # Update dtype to match pytorch_training_precision
+ object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
+
+ model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
+
+ if hasattr(model, "gradient_checkpointing_enable"):
+ enable_gradient_checkpointing = True
+ model.gradient_checkpointing_enable()
+ logging.info("Enabled gradient checkpointing for memory optimization")
+ else:
+ enable_gradient_checkpointing = False
+ logging.info("Gradient checkpointing is not supported for this model")
+
+ # Log initial memory usage after model creation
+ if is_main and torch.cuda.is_available():
+ log_memory_usage(device, 0, "after_model_creation")
+
+ # Enable memory optimizations for large-scale training
+ if world_size >= 8:
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ # Set memory allocation configuration
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+ logging.info("Enabled memory optimizations for 8+ GPU training")
+
+ if use_ddp:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[device.index] if device.type == "cuda" else None,
+ find_unused_parameters=True, # Some parameters may not receive gradients every step
+ gradient_as_bucket_view=True, # Enable for memory efficiency
+ static_graph=world_size >= 8, # Enable for 8+ GPUs
+ )
+
+ # Load weights from weight_loader if specified (for fine-tuning)
+ if config.pytorch_weight_path is not None:
+ logging.info(f"Loading weights from: {config.pytorch_weight_path}")
+
+ model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
+ safetensors.torch.load_model(
+ (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
+ )
+ logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
+
+ # Optimizer + learning rate schedule from config
+ warmup_steps = config.lr_schedule.warmup_steps
+ peak_lr = config.lr_schedule.peak_lr
+ decay_steps = config.lr_schedule.decay_steps
+ end_lr = config.lr_schedule.decay_lr
+
+ # Create optimizer with config parameters
+ optim = torch.optim.AdamW(
+ model.parameters(),
+ lr=peak_lr,
+ betas=(config.optimizer.b1, config.optimizer.b2),
+ eps=config.optimizer.eps,
+ weight_decay=config.optimizer.weight_decay,
+ )
+
+ # Load checkpoint if resuming
+ global_step = 0
+ if resuming:
+ global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
+ logging.info(f"Resumed training from step {global_step}")
+
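+ # Learning-rate schedule (matching the JAX trainer): linear warmup from roughly
+ # peak_lr / (warmup_steps + 1) up to peak_lr over `warmup_steps` steps, followed by
+ # cosine decay from peak_lr down to end_lr by `decay_steps`.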
+ def lr_schedule(step: int):
+ if step < warmup_steps:
+ # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
+ init_lr = peak_lr / (warmup_steps + 1)
+ return init_lr + (peak_lr - init_lr) * step / warmup_steps
+ # cosine decay
+ progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
+ cos = 0.5 * (1 + np.cos(np.pi * progress))
+ return end_lr + (peak_lr - end_lr) * cos
+
+ model.train()
+ start_time = time.time()
+ infos = [] # Collect stats over log interval
+ if is_main:
+ logging.info(
+ f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
+ )
+ logging.info(
+ f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
+ )
+ logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
+ logging.info(
+ f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
+ )
+ logging.info(
+ f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
+ )
+ logging.info("EMA is not supported for PyTorch training")
+ logging.info(f"Training precision: {model_cfg.dtype}")
+
+ # Training loop - iterate until we reach num_train_steps
+ pbar = (
+ tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
+ if is_main
+ else None
+ )
+
+ while global_step < config.num_train_steps:
+ # Set epoch for distributed training
+ if use_ddp and hasattr(loader, "set_epoch"):
+ loader.set_epoch(global_step // len(loader))
+
+ for observation, actions in loader:
+ # Check if we've reached the target number of steps
+ if global_step >= config.num_train_steps:
+ break
+
+ # The unified data loader returns (observation, actions) tuple
+ observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
+ actions = actions.to(torch.float32) # noqa: PLW2901
+ actions = actions.to(device) # noqa: PLW2901
+
+ # Update LR
+ for pg in optim.param_groups:
+ pg["lr"] = lr_schedule(global_step)
+
+ # Forward pass
+ losses = model(observation, actions)
+ # Ensure losses is a tensor and handle different return types
+ if isinstance(losses, list | tuple):
+ losses = torch.stack(losses)
+ elif not isinstance(losses, torch.Tensor):
+ losses = torch.tensor(losses, device=device, dtype=torch.float32)
+
+ loss = losses.mean()
+
+ # Backward pass
+ loss.backward()
+
+ # Log memory usage after backward pass
+ if global_step < 5 and is_main and torch.cuda.is_available():
+ log_memory_usage(device, global_step, "after_backward")
+
+ # Gradient clipping
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
+
+ # Optimizer step
+ optim.step()
+ optim.zero_grad(set_to_none=True)
+
+ # Clear gradients more aggressively
+ for param in model.parameters():
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad = None
+
+ # Collect stats
+ if is_main:
+ infos.append(
+ {
+ "loss": loss.item(),
+ "learning_rate": optim.param_groups[0]["lr"],
+ "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
+ }
+ )
+
+ if is_main and (global_step % config.log_interval == 0):
+ elapsed = time.time() - start_time
+
+ # Average stats over log interval
+ avg_loss = sum(info["loss"] for info in infos) / len(infos)
+ avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
+
+ avg_grad_norm = None
+ if any("grad_norm" in info for info in infos):
+ vals = [
+ info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
+ ]
+ if len(vals) > 0:
+ avg_grad_norm = sum(vals) / len(vals)
+ logging.info(
+ f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
+ if avg_grad_norm is not None
+ else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
+ )
+
+ # Log to wandb
+ if config.wandb_enabled and len(infos) > 0:
+ log_payload = {
+ "loss": avg_loss,
+ "learning_rate": avg_lr,
+ "step": global_step,
+ "time_per_step": elapsed / config.log_interval,
+ }
+ if avg_grad_norm is not None:
+ log_payload["grad_norm"] = avg_grad_norm
+ wandb.log(log_payload, step=global_step)
+
+ start_time = time.time()
+ infos = [] # Reset stats collection
+
+ global_step += 1
+ # Save checkpoint using the new mechanism
+ save_checkpoint(model, optim, global_step, config, is_main, data_config)
+
+ # Update progress bar
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_postfix(
+ {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
+ )
+
+ # Close progress bar
+ if pbar is not None:
+ pbar.close()
+
+ # Finish wandb run
+ if is_main and config.wandb_enabled:
+ wandb.finish()
+
+ cleanup_ddp()
+
+
+def main():
+ init_logging()
+ config = _config.cli()
+ train_loop(config)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/openpi/scripts/train_test.py b/openpi/scripts/train_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e0bb7bcfdab3ca5855794fc604dac597d96d8c8
--- /dev/null
+++ b/openpi/scripts/train_test.py
@@ -0,0 +1,30 @@
+import dataclasses
+import os
+import pathlib
+
+import pytest
+
+os.environ["JAX_PLATFORMS"] = "cpu"
+
+from openpi.training import config as _config
+
+from . import train
+
+
+@pytest.mark.parametrize("config_name", ["debug"])
+def test_train(tmp_path: pathlib.Path, config_name: str):
+ config = dataclasses.replace(
+ _config._CONFIGS_DICT[config_name], # noqa: SLF001
+ batch_size=2,
+ checkpoint_base_dir=str(tmp_path / "checkpoint"),
+ exp_name="test",
+ overwrite=False,
+ resume=False,
+ num_train_steps=2,
+ log_interval=1,
+ )
+ train.main(config)
+
+ # test resuming
+ config = dataclasses.replace(config, resume=True, num_train_steps=4)
+ train.main(config)
diff --git a/openpi/src/openpi/__init__.py b/openpi/src/openpi/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openpi/src/openpi/__pycache__/__init__.cpython-311.pyc b/openpi/src/openpi/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ea2beaa5435b4895586f0f542884fb212ff81f1
Binary files /dev/null and b/openpi/src/openpi/__pycache__/__init__.cpython-311.pyc differ
diff --git a/openpi/src/openpi/conftest.py b/openpi/src/openpi/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..5002b629de77953e03f24157f6ba4c88fc448468
--- /dev/null
+++ b/openpi/src/openpi/conftest.py
@@ -0,0 +1,17 @@
+import os
+
+import pynvml
+import pytest
+
+
+def set_jax_cpu_backend_if_no_gpu() -> None:
+ try:
+ pynvml.nvmlInit()
+ pynvml.nvmlShutdown()
+ except pynvml.NVMLError:
+ # No GPU found.
+ os.environ["JAX_PLATFORMS"] = "cpu"
+
+
+def pytest_configure(config: pytest.Config) -> None:
+ set_jax_cpu_backend_if_no_gpu()
diff --git a/openpi/src/openpi/models/__init__.py b/openpi/src/openpi/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openpi/src/openpi/models/gemma.py b/openpi/src/openpi/models/gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1623d19d12517977d55e193361c5ced481efbc8
--- /dev/null
+++ b/openpi/src/openpi/models/gemma.py
@@ -0,0 +1,459 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Gemma adaptation for Pi, taken from big_vision.
+
+We follow this einsum axis naming convention:
+ B: batch
+ T: query length
+ S: k/v length
+ N: num query heads
+ K: num k/v heads
+ G: num query heads per k/v head
+ H: head dim
+ D: d_model ("features")
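+
+For example, the output projection einsum "BTNH,NHD->BTD" maps per-head attention outputs
+(N heads of head dim H) back to the model width D for each of the T query positions.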
+"""
+
+from collections.abc import Sequence
+import dataclasses
+from typing import Literal, TypeAlias
+
+import einops
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+import openpi.models.lora as lora
+import openpi.shared.array_typing as at
+import openpi.training.sharding as sharding
+
+PALIGEMMA_VOCAB_SIZE = 257_152
+
+
+@dataclasses.dataclass
+class Config:
+ width: int
+ depth: int
+ mlp_dim: int
+ num_heads: int
+ num_kv_heads: int
+ head_dim: int
+ lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
+
+
+Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
+
+
+def get_config(variant: Variant) -> Config:
+ """Returns config for specified gemma variant."""
+ if variant == "dummy":
+ return Config(
+ width=64,
+ depth=4,
+ mlp_dim=128,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=16,
+ )
+ if variant == "gemma_300m":
+ # 311M params
+ return Config(
+ width=1024,
+ depth=18,
+ mlp_dim=4096,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ if variant == "gemma_2b":
+ return Config(
+ width=2048,
+ depth=18,
+ mlp_dim=16_384,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ )
+ if variant == "gemma_2b_lora":
+ return Config(
+ width=2048,
+ depth=18,
+ mlp_dim=16_384,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
+ )
+ if variant == "gemma_300m_lora":
+ # 311M params
+ return Config(
+ width=1024,
+ depth=18,
+ mlp_dim=4096,
+ num_heads=8,
+ num_kv_heads=1,
+ head_dim=256,
+ lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
+ )
+ raise ValueError(f"Unknown variant: {variant}")
+
+
+@at.typecheck
+class RMSNorm(nn.Module):
+ @nn.compact
+ def __call__(self, x, cond):
+ dtype = x.dtype # original dtype, could be half-precision
+ var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
+ normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
+ if cond is None:
+ # regular RMSNorm
+ scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
+ normed_inputs = normed_inputs * (
+ 1 + scale
+ ) # scale by learned parameter in float32 (matches Flax implementation)
+ return normed_inputs.astype(dtype), None # return in original dtype
+
+ # adaptive RMSNorm
+ modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
+ scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
+ normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32
+ return normed_inputs.astype(dtype), gate
+
+
+@at.typecheck
+class Embedder(nn.Module):
+ """Embedder module."""
+
+ vocab_size: int
+ embed_dim: int
+
+ def setup(self):
+ self.input_embedding_table = self.param(
+ "input_embedding",
+ nn.initializers.normal(),
+ (self.vocab_size, self.embed_dim),
+ )
+
+ def encode(self, x):
+ x = self.input_embedding_table[(x,)]
+ x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
+ return x
+
+ def decode(self, x):
+ return jnp.dot(x, self.input_embedding_table.T)
+
+
+@at.typecheck
+class Attention(nn.Module):
+ """Attention module."""
+
+ configs: Sequence[Config]
+
+ @nn.compact
+ def __call__(self, xs, positions, attn_mask, kv_cache):
+ # all experts must share the same head dim, num heads, and num kv heads for self-attention to work
+ assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
+ assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
+ assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
+
+ dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
+
+ qkvs = []
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
+ if x is None:
+ continue
+ if config.num_kv_heads == config.num_heads:
+ qkv_einsum = lora.Einsum(
+ shape=(3, config.num_heads, config.width, config.head_dim),
+ name=_name("qkv_einsum", i),
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+ lora_config=config.lora_configs.get("attn"),
+ )
+ qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
+ else:
+ q_einsum = lora.Einsum(
+ shape=(config.num_heads, config.width, config.head_dim),
+ name=_name("q_einsum", i),
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+ lora_config=config.lora_configs.get("attn"),
+ )
+ q = q_einsum("BTD,NDH->BTNH", x)
+ kv_einsum = lora.Einsum(
+ shape=(2, config.num_kv_heads, config.width, config.head_dim),
+ name=_name("kv_einsum", i),
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+ lora_config=config.lora_configs.get("attn"),
+ )
+ k, v = kv_einsum("BSD,2KDH->2BSKH", x)
+ qkvs.append((q, k, v))
+
+ q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
+
+ q = _apply_rope(q, positions=positions)
+ q *= self.configs[0].head_dim ** -0.5
+
+ k = _apply_rope(k, positions=positions)
+
+ # should still be half-precision here (if input was half-precision)
+ assert q.dtype == k.dtype == v.dtype == dtype
+
+ if kv_cache is not None:
+ cache_k, cache_v = kv_cache
+ k = jnp.concatenate([cache_k, k], axis=1)
+ v = jnp.concatenate([cache_v, v], axis=1)
+
+ q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
+ logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
+
+ if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
+ raise ValueError(
+ f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
+ )
+
+ # big_neg = jnp.finfo(logits.dtype).min
+ big_neg = -2.3819763e38 # See gemma/modules.py
+ masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
+
+ probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
+
+ encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
+ encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
+
+ out = []
+ start = 0
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
+ if x is not None:
+ end = start + x.shape[1]
+ out_einsum = lora.Einsum(
+ shape=(config.num_heads, config.head_dim, config.width),
+ name=_name("attn_vec_einsum", i),
+ init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
+ lora_config=config.lora_configs.get("attn"),
+ )
+ out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
+ start = end
+ else:
+ out.append(None)
+
+ return out, (k, v)
+
+
+@at.typecheck
+class FeedForward(nn.Module):
+ """Feed forward module."""
+
+ features: int
+ hidden_dim: int
+
+ @nn.compact
+ def __call__(self, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ w_gating = self.param(
+ "gating_einsum",
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+ (2, self.features, self.hidden_dim),
+ ).astype(dtype)
+ ff_gate = jnp.dot(x, w_gating[0])
+ gate_value = nn.gelu(ff_gate)
+
+ ff1 = jnp.dot(x, w_gating[1])
+ activations = gate_value * ff1
+
+ w_linear = self.param(
+ "linear",
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
+ (self.hidden_dim, self.features),
+ ).astype(dtype)
+ outputs = jnp.dot(activations, w_linear)
+ assert outputs.dtype == dtype
+ return outputs
+
+
+@at.typecheck
+class Block(nn.Module):
+ """Transformer block."""
+
+ configs: tuple[Config, ...]
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[int, ...] = ()
+
+ @nn.compact
+ def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002
+ xs = sharding.activation_sharding_constraint(xs)
+ drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
+
+ attn = Attention(configs=self.configs, name="attn")
+
+ pre_attn = []
+ gates = []
+ for i, x in enumerate(xs):
+ if x is not None:
+ x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
+ pre_attn.append(x)
+ gates.append(gate if x is not None else None)
+
+ pre_attn = sharding.activation_sharding_constraint(pre_attn)
+ post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
+ post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
+ post_attn = sharding.activation_sharding_constraint(post_attn)
+ xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
+ xs = sharding.activation_sharding_constraint(xs)
+
+ out = []
+ gates = []
+ for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
+ if x is not None:
+ x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
+ x = lora.FeedForward( # noqa: PLW2901
+ features=config.width,
+ hidden_dim=config.mlp_dim,
+ name=_name("mlp", i),
+ lora_config=config.lora_configs.get("ffn"),
+ )(x)
+ out.append(x)
+ gates.append(gate if x is not None else None)
+
+ out = sharding.activation_sharding_constraint(out)
+ out = jax.tree.map(lambda x: drop(x, deterministic), out)
+ xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
+ xs = sharding.activation_sharding_constraint(xs)
+
+ return xs, kv_cache
+
+
+KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
+
+
+@at.typecheck
+class Module(nn.Module):
+ """Transformer model, supporting a mixture of different weights for different tokens."""
+
+ configs: Sequence[Config] # list of configs, one for each expert
+ embed_dtype: str
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
+ adarms: bool = False
+
+ def setup(self):
+ # all experts must have the same depth
+ assert all(config.depth == self.configs[0].depth for config in self.configs)
+
+ self.embedder = Embedder(
+ vocab_size=PALIGEMMA_VOCAB_SIZE,
+ embed_dim=self.configs[0].width, # embedder for first expert only
+ name="embedder",
+ )
+ block_cls = nn.remat(
+ Block,
+ prevent_cse=False,
+ static_argnums=(5,), # 0=self, 6=deterministic
+ policy=jax.checkpoint_policies.nothing_saveable,
+ )
+ self.layers = nn.scan(
+ block_cls,
+ variable_axes={"params": 0},
+ split_rngs={"params": True, "dropout": True},
+ in_axes=(
+ 0,
+ nn.broadcast,
+ nn.broadcast,
+ nn.broadcast,
+ nn.broadcast,
+ ), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
+ length=self.configs[0].depth,
+ )(
+ configs=self.configs,
+ dropout=self.dropout,
+ dropout_bdims=self.dropout_bdims,
+ )
+ self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
+
+ @at.typecheck
+ def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
+ return self.embedder.encode(tokens).astype(self.embed_dtype)
+
+ @at.typecheck
+ def __call__(
+ self,
+ # list of token arrays, one for each expert, or None if that expert should not be run
+ embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
+ positions: at.Int[at.Array, "b t"],
+ mask: at.Bool[at.Array, "b t s"],
+ adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
+ *,
+ kv_cache: KVCache | None = None,
+ deterministic: bool = True,
+ ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
+ embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
+ mask = jnp.asarray(mask)[:, None, :, :]
+ if adarms_cond is None:
+ adarms_cond = [None] * len(self.configs)
+
+ embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
+
+ assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
+
+ return [
+ f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
+ ], kv_cache
+
+ def init(self, use_adarms: Sequence[bool]):
+ """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
+ self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
+ self(
+ [jnp.zeros((1, 1, c.width)) for c in self.configs],
+ jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
+ jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
+ adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
+ )
+
+
+def _apply_rope(x, *, positions, max_wavelength=10_000):
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
+ freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
+ timescale = max_wavelength**freq_exponents
+ radians = positions[..., None] / timescale[None, None, :]
+ radians = radians[..., None, :]
+ assert radians.dtype == jnp.float32
+ # radians.shape = [...,L,1,d=D/2]
+ sin, cos = jnp.sin(radians), jnp.cos(radians)
+ x1, x2 = jnp.split(x, 2, axis=-1)
+ res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
+ assert res.dtype == jnp.float32
+ # The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
+ # dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
+ # original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
+ # here.
+ return res.astype(x.dtype)
+
+
+def _name(name, i):
+ # we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
+ # can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
+ # "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
+ # and the action expert.
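+ # e.g., _name("attn", 0) -> "attn", while _name("attn", 1) -> "attn_1".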
+ if i == 0:
+ return name
+ return f"{name}_{i}"
+
+
+def _gated_residual(x, y, gate):
+ assert (x is None) == (y is None)
+ if x is None:
+ return None
+ if gate is None:
+ return x + y
+ return x + y * gate
diff --git a/openpi/src/openpi/models/gemma_fast.py b/openpi/src/openpi/models/gemma_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..eee39b4317dd3f42aa4d73e11dceee37e394c39d
--- /dev/null
+++ b/openpi/src/openpi/models/gemma_fast.py
@@ -0,0 +1,437 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
+Used for FAST autoregressive policies.
+"""
+
+import dataclasses
+from typing import Literal, TypeAlias
+
+import einops
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import ml_collections
+
+import openpi.models.lora as lora
+import openpi.shared.array_typing as at
+
+Variant = Literal["gemma_2b", "gemma_2b_lora"]
+
+
+def get_config(variant):
+ """Returns config for specified gemma variant."""
+ if variant == "gemma_2b":
+ return ml_collections.ConfigDict(
+ {
+ "variant": variant,
+ "width": 2048,
+ "depth": 18,
+ "mlp_dim": 16_384,
+ "num_heads": 8,
+ "num_kv_heads": 1,
+ "head_dim": 256,
+ "norm_eps": 1e-6,
+ "vocab_size": 257_152,
+ "scan": True,
+ "remat_policy": "nothing_saveable",
+ }
+ )
+ if variant == "gemma_2b_lora":
+ return ml_collections.ConfigDict(
+ {
+ "variant": variant,
+ "width": 2048,
+ "depth": 18,
+ "mlp_dim": 16_384,
+ "num_heads": 8,
+ "num_kv_heads": 1,
+ "head_dim": 256,
+ "norm_eps": 1e-6,
+ "vocab_size": 257_152,
+ "scan": True,
+ "remat_policy": "nothing_saveable",
+ "lora_configs": {
+ "attn": lora.LoRAConfig(rank=16, alpha=16.0),
+ "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
+ },
+ }
+ )
+ raise ValueError(f"Unknown variant: {variant}")
+
+
+@at.typecheck
+class Einsum(nn.Module):
+ shape: tuple[int, ...]
+
+ @nn.compact
+ def __call__(self, eqn, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
+ return jnp.einsum(eqn, x, w)
+
+
+@at.typecheck
+class RMSNorm(nn.Module):
+ @nn.compact
+ def __call__(self, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
+ var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
+ normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
+ normed_inputs = normed_inputs * (
+ 1 + scale
+ ) # scale by learned parameter in float32 (matches Flax implementation)
+ return normed_inputs.astype(dtype) # return in original dtype
+
+
+@at.typecheck
+class Embedder(nn.Module):
+ """Embedder module."""
+
+ vocab_size: int
+ embed_dim: int
+
+ def setup(self):
+ self.input_embedding_table = self.param(
+ "input_embedding",
+ nn.initializers.zeros_init(),
+ (self.vocab_size, self.embed_dim),
+ )
+
+ def encode(self, x):
+ x = self.input_embedding_table[(x,)]
+ x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
+ return x
+
+ def decode(self, x):
+ return jnp.dot(x, self.input_embedding_table.T)
+
+
+@at.typecheck
+class Attention(nn.Module):
+ """Attention module."""
+
+ num_heads: int
+ num_kv_heads: int
+ features: int
+ head_dim: int
+
+ cache_dtype: str | None = None
+
+ lora_config: lora.LoRAConfig | None = None
+
+ def setup(self):
+ if self.num_kv_heads == self.num_heads:
+ self.qkv_einsum = lora.Einsum(
+ shape=(3, self.num_heads, self.features, self.head_dim),
+ name="qkv_einsum",
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+ lora_config=self.lora_config,
+ )
+ else:
+ self.q_einsum = lora.Einsum(
+ shape=(self.num_heads, self.features, self.head_dim),
+ name="q_einsum",
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+ lora_config=self.lora_config,
+ )
+ self.kv_einsum = lora.Einsum(
+ shape=(2, self.num_kv_heads, self.features, self.head_dim),
+ name="kv_einsum",
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+ lora_config=self.lora_config,
+ )
+ self.attn_vec_einsum = lora.Einsum(
+ shape=(self.num_heads, self.head_dim, self.features),
+ name="attn_vec_einsum",
+ init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+ lora_config=self.lora_config,
+ )
+
+ def _init_cache(self, k, v, cache_size):
+ """Initialize KV cache"""
+ prefill_len = k.shape[1]
+ pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
+ cache_dtype = self.cache_dtype or k.dtype
+ k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
+ v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
+ idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
+ return idx, k_cache, v_cache
+
+ def _update_cache(self, k, v, idx, k_cache, v_cache):
+ """Update KV cache with new values"""
+ assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
+ indices = (0, idx[0], 0, 0)
+ cache_dtype = self.cache_dtype or k.dtype
+ k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
+ v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
+ idx_new = idx + 1
+ return idx_new, k_new, v_new
+
+ @nn.compact
+ def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002
+ dtype = x.dtype # original dtype, could be half-precision
+ if self.num_kv_heads == self.num_heads:
+ q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
+ else:
+ q = self.q_einsum("BTD,NDH->BTNH", x)
+ k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)
+
+ q = _apply_rope(q, positions=positions) # promotes to float32
+ q *= self.head_dim**-0.5
+
+ k = _apply_rope(k, positions=positions) # promotes to float32
+
+ if kv_cache is None:
+ idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
+ else:
+ idx, k_cache, v_cache = kv_cache
+ idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)
+
+ k, v = k_cache, v_cache
+ kv_cache = (idx, k_cache, v_cache)
+
+ q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
+ logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
+
+ if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
+ raise ValueError(
+ f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
+ )
+
+ # big_neg = jnp.finfo(logits.dtype).min
+ big_neg = -2.3819763e38 # See gemma/modules.py
+ masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
+
+ probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
+
+ encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
+ encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
+ return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
+
+
+@at.typecheck
+class Block(nn.Module):
+ """Transformer block."""
+
+ num_heads: int
+ num_kv_heads: int
+ embed_dim: int
+ head_dim: int
+ hidden_dim: int
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[int, ...] = ()
+ cache_dtype: str | None = None
+ lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
+
+ def setup(self):
+ self.pre_attention_norm = RMSNorm()
+ self.attn = Attention(
+ num_heads=self.num_heads,
+ num_kv_heads=self.num_kv_heads,
+ features=self.embed_dim,
+ head_dim=self.head_dim,
+ cache_dtype=self.cache_dtype,
+ lora_config=self.lora_configs.get("attn"),
+ )
+ self.pre_ffw_norm = RMSNorm()
+ self.mlp = lora.FeedForward(
+ features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
+ )
+ if self.dropout:
+ self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
+ else:
+ self.drop = lambda x, _: x
+
+ def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
+ x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
+ inputs_normalized = self.pre_attention_norm(x)
+ attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
+ attn_output = self.drop(attn_output, deterministic)
+ attn_output += x
+ residual = attn_output
+ attn_output = self.pre_ffw_norm(attn_output)
+ outputs = self.mlp(attn_output)
+ outputs = self.drop(outputs, deterministic)
+ outputs = residual + outputs
+ return outputs, kv_cache
+
+
+KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
+
+
+@at.typecheck
+class Module(nn.Module):
+ """gemma model."""
+
+ variant: str
+
+ width: int
+ depth: int
+ mlp_dim: int
+ num_heads: int
+ num_kv_heads: int
+ head_dim: int
+ norm_eps: float
+ vocab_size: int
+ embed_dtype: str
+
+ dropout: float = 0.0
+ dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
+ cache_dtype: str | None = None
+
+ scan: bool = False
+ remat_policy: str = "none"
+ lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
+
+ @nn.compact
+ def __call__(
+ self,
+ tokens=None,
+ embedded_prefix=None,
+ embed_only=False, # noqa: FBT002
+ pre_logits=None,
+ positions=None,
+ mask=None,
+ decode=False, # noqa: FBT002
+ kv_cache=None,
+ deterministic=True, # noqa: FBT002
+ return_prelogits=False, # noqa: FBT002
+ ):
+ """Embed only, or complete forward pass.
+
+ Args:
+ tokens: Embedded and then appended to `embedded_prefix`. Can be None.
+ embedded_prefix: Optional prefix that is already embedded.
+ embed_only: Whether to compute embeddings only.
+ pre_logits: If present computes logits from pre_logits and returns.
+ positions: Optional `[B, T]` allows to specify the absolute position of
+ the tokens.
+ mask: Optional attention mask `[B, T, S]`.
+ decode: Whether to use kv-cache. Caller must pass masks and positions.
+ deterministic: Forwarded to all dropout layers.
+ return_prelogits: Whether to return the pre-logits.
+
+ Returns:
+ If `embed_only=False`, then `(logits, out)` will be returned.
+ If `embed_only=True`, then the embeddings will be returned.
+ If `return_prelogits=True`, then the pre-logits will be returned.
+ """
+ out = {}
+
+ embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
+
+ if pre_logits is not None:
+ x = out["pre_logits"] = pre_logits
+ logits = out["logits"] = embedder.decode(x)
+ return logits, out
+
+ x = []
+ if embedded_prefix is not None:
+ x.append(embedded_prefix)
+ if tokens is not None:
+ x.append(embedder.encode(tokens))
+
+ x = jnp.concatenate(x, axis=-2)
+ x = x.astype(self.embed_dtype)
+ batch_size, seq_len, width = x.shape
+
+ if embed_only:
+ return x
+
+ if decode:
+ assert positions is not None and mask is not None, ( # noqa: PT018
+ "Must explicitly pass positions and mask for decoding."
+ )
+
+ if positions is None:
+ positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
+ assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
+
+ if mask is None:
+ mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
+ if mask.ndim == 3:
+ mask = mask[:, None, :, :]
+ cache_size = max(seq_len, mask.shape[-1])
+ assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
+
+ if self.remat_policy == "none":
+ block_cls = Block
+ else:
+ block_cls = nn.remat(
+ Block,
+ prevent_cse=not self.scan,
+ static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
+ policy=getattr(jax.checkpoint_policies, self.remat_policy),
+ )
+
+ block_kw = {
+ "num_heads": self.num_heads,
+ "head_dim": self.head_dim,
+ "num_kv_heads": self.num_kv_heads,
+ "embed_dim": width,
+ "hidden_dim": self.mlp_dim,
+ "dropout": self.dropout,
+ "dropout_bdims": self.dropout_bdims,
+ "cache_dtype": self.cache_dtype,
+ "lora_configs": self.lora_configs,
+ }
+ layers = self.scope.push("layers")
+ blocks = [
+ nn.scan(
+ block_cls,
+ variable_axes={"params": 0},
+ split_rngs={"params": True, "dropout": True},
+ in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask, 3=decode, 4=deterministic
+ length=self.depth,
+ )(parent=layers, **block_kw)
+ ]
+ for block in blocks:
+ x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
+
+ assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
+ out["encoded"] = x
+
+ x = RMSNorm(name="final_norm")(x)
+ out["pre_logits"] = x
+ if return_prelogits:
+ return x, kv_cache, out
+
+ x = embedder.decode(x)
+ out["logits"] = x
+
+ return x, kv_cache, out
+
+ def init(self):
+ """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
+ self(jnp.zeros((1, 1), dtype=jnp.int32))
+
+
+def _apply_rope(x, *, positions, max_wavelength=10_000):
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
+ freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
+ timescale = max_wavelength**freq_exponents
+ radians = positions[..., None] / timescale[None, None, :]
+ radians = radians[..., None, :]
+ assert radians.dtype == jnp.float32
+ # radians.shape = [...,L,1,d=D/2]
+ sin, cos = jnp.sin(radians), jnp.cos(radians)
+ x1, x2 = jnp.split(x, 2, axis=-1)
+ res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
+ assert res.dtype == jnp.float32
+ return res
diff --git a/openpi/src/openpi/models/lora.py b/openpi/src/openpi/models/lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dfff5b4f22ff42f73298037a77df900c6b11d75
--- /dev/null
+++ b/openpi/src/openpi/models/lora.py
@@ -0,0 +1,148 @@
+import math
+import re
+
+import flax.linen as nn
+import flax.struct as struct
+import jax.numpy as jnp
+
+import openpi.shared.array_typing as at
+
+
+@struct.dataclass
+class LoRAConfig:
+ """Configuration for LoRA."""
+
+ # LoRA rank.
+ rank: int
+ # LoRA scaling factor.
+ alpha: float = 1.0
+ # Initialization function for LoRA parameters.
+ init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
+ # Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
+ rslora: bool = False
+ # Axes in the weight to apply LoRA to. Should typically be the last two axes.
+ axes: tuple[int, int] = (-2, -1)
+ # Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
+ label: str = "L"
+
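+ # For example, with rank=16 and alpha=16.0 the scaling factor is alpha / rank = 1.0 by default,
+ # or alpha / sqrt(rank) = 4.0 when rslora is enabled.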
+ @property
+ def scaling_value(self) -> float:
+ return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
+
+
+class Einsum(nn.Module):
+ """Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
+
+ # Shape of the weight.
+ shape: tuple[int, ...]
+ # Initialization function for the weight.
+ init_fn: nn.initializers.Initializer = nn.initializers.zeros
+ # If not None, apply LoRA to the weight.
+ lora_config: LoRAConfig | None = None
+
+ def setup(self):
+ self.w = self.param("w", self.init_fn, self.shape)
+
+ if config := self.lora_config:
+ # Setup LoRA parameters.
+ shape_a, shape_b = list(self.shape), list(self.shape)
+ shape_a[config.axes[1]] = config.rank
+ shape_b[config.axes[0]] = config.rank
+ self.w_a = self.param("lora_a", config.init_fn, shape_a)
+ self.w_b = self.param("lora_b", config.init_fn, shape_b)
+
+ @nn.compact
+ def __call__(self, eqn: str, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ result = jnp.einsum(eqn, x, self.w.astype(dtype))
+
+ if config := self.lora_config:
+ eqn_a, eqn_b = self._make_lora_eqns(eqn)
+ lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
+ lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
+ result = result + lora * config.scaling_value
+
+ return result
+
+ def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
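+ # For example, for eqn "BSD,3KDH->3BSKH" with the default axes=(-2, -1) and label "L",
+ # this yields eqn_a = "BSD,3KDL->3BSKL" and eqn_b = "3BSKL,3KLH->3BSKH": x is first
+ # projected onto the rank-sized axis L by lora_a, then expanded back by lora_b.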
+ if "L" in eqn:
+ raise ValueError(f"L already in eqn: {eqn}")
+ if not (m := re.match("(.*),(.*)->(.*)", eqn)):
+ raise ValueError(f"Unsupported einsum eqn: {eqn}")
+ lhs, rhs, out = m.groups()
+
+ assert self.lora_config is not None
+ a_label, b_label = (rhs[x] for x in self.lora_config.axes)
+ label = self.lora_config.label
+
+ a_rhs = rhs.replace(b_label, label)
+ a_out = out.replace(b_label, label)
+ eqn_a = f"{lhs},{a_rhs}->{a_out}"
+
+ b_rhs = rhs.replace(a_label, label)
+ eqn_b = f"{a_out},{b_rhs}->{out}"
+
+ return eqn_a, eqn_b
+
+
+class FeedForward(nn.Module):
+ """Feed forward module."""
+
+ features: int
+ hidden_dim: int
+ # If not None, apply LoRA to the weight.
+ lora_config: LoRAConfig | None = None
+
+ def setup(self):
+ self.w_gating = self.param(
+ "gating_einsum",
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+ (2, self.features, self.hidden_dim),
+ )
+ self.w_linear = self.param(
+ "linear",
+ nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
+ (self.hidden_dim, self.features),
+ )
+ self.w_gating_lora = None
+ self.w_linear_lora = None
+ if self.lora_config:
+ # Setup LoRA parameters.
+ # TODO: follow up with a simplified init_fn api.
+ self.w_gating_lora = (
+ self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
+ self.param(
+ "gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
+ ),
+ )
+ self.w_linear_lora = (
+ self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
+ self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
+ )
+
+ @nn.compact
+ def __call__(self, x):
+ dtype = x.dtype # original dtype, could be half-precision
+ ff_gate = self._dot(
+ x,
+ self.w_gating[0],
+ None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
+ )
+ gate_value = nn.gelu(ff_gate)
+
+ ff1 = self._dot(
+ x,
+ self.w_gating[1],
+ None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
+ )
+ activations = gate_value * ff1
+
+ outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
+ assert outputs.dtype == dtype
+ return outputs
+
+ def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
+ base = jnp.dot(x, w.astype(x.dtype))
+ if lora_weights is None:
+ return base
+ return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
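+
+# Shape sketch for FeedForward (illustrative numbers only): with features=8, hidden_dim=32 and
+# rank=2, `gating_einsum` is (2, 8, 32) with LoRA factors of shape (2, 8, 2) and (2, 2, 32),
+# and `linear` is (32, 8) with LoRA factors of shape (32, 2) and (2, 8); these are the same
+# shapes asserted in lora_test.py below.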
diff --git a/openpi/src/openpi/models/lora_test.py b/openpi/src/openpi/models/lora_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..48b65b6ae282c6bb0e6a410ee71204a4837ffc48
--- /dev/null
+++ b/openpi/src/openpi/models/lora_test.py
@@ -0,0 +1,94 @@
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+import openpi.models.lora as lora
+
+
+def test_lora_einsum_params_shape():
+ shape = (3, 8, 32, 4) # (3KDH)
+ einsum = lora.Einsum(shape)
+ lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
+ lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (8, 64, 32)) # (BSD)
+ eqn = "BSD,3KDH->3BSKH"
+
+ # Ensure that lora parameters are not initialized when LoRA is not used.
+ params = einsum.init(key, eqn, x)
+ assert "lora_a" not in params["params"]
+ assert "lora_b" not in params["params"]
+
+ # Check that default axes work.
+ params_lora0 = lora0.init(key, eqn, x)
+ assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
+ assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)
+
+ # Check that user provided axes work.
+ params_lora1 = lora1.init(key, eqn, x)
+ assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
+ assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
+
+
+def test_lora_einsum_same_output():
+ shape = (3, 8, 32, 4) # (3KDH)
+ einsum = lora.Einsum(shape)
+ einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (8, 64, 32)) # (BSD)
+ eqn = "BSD,3KDH->3BSKH"
+
+ params = einsum.init(key, eqn, x)
+ output = einsum.apply(params, eqn, x)
+
+ params_lora = einsum_lora.init(key, eqn, x)
+ output_lora = einsum_lora.apply(params_lora, eqn, x)
+
+ # Results are the same since the LoRA parameters are initialized to zeros.
+ assert jnp.allclose(output, output_lora)
+
+
+def test_lora_ffn_params_shape():
+ ffn = lora.FeedForward(features=8, hidden_dim=32)
+ ffn_lora = lora.FeedForward(
+ features=8,
+ hidden_dim=32,
+ lora_config=lora.LoRAConfig(rank=2),
+ )
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (2, 8))
+
+ params = ffn.init(key, x)
+ assert params["params"]["gating_einsum"].shape == (2, 8, 32)
+ assert params["params"]["linear"].shape == (32, 8)
+
+ params_lora = ffn_lora.init(key, x)
+ assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
+ assert params_lora["params"]["linear"].shape == (32, 8)
+ assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
+ assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
+ assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
+ assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
+
+
+def test_lora_ffn_same_output():
+ ffn = lora.FeedForward(features=8, hidden_dim=32)
+ ffn_lora = lora.FeedForward(
+ features=8,
+ hidden_dim=32,
+ lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
+ )
+
+ key = jax.random.key(0)
+ x = jax.random.normal(key, (2, 8))
+
+ params = ffn.init(key, x)
+ output = ffn.apply(params, x)
+
+ params_lora = ffn_lora.init(key, x)
+ output_lora = ffn_lora.apply(params_lora, x)
+
+ assert jnp.allclose(output, output_lora)
diff --git a/openpi/src/openpi/models/model.py b/openpi/src/openpi/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..29618b49453742266fe6e4a5815ceee06d815f3b
--- /dev/null
+++ b/openpi/src/openpi/models/model.py
@@ -0,0 +1,332 @@
+import abc
+from collections.abc import Sequence
+import dataclasses
+import enum
+import logging
+import pathlib
+from typing import Generic, TypeVar
+
+import augmax
+from flax import nnx
+from flax import struct
+from flax import traverse_util
+import jax
+import jax.numpy as jnp
+import numpy as np
+import orbax.checkpoint as ocp
+import safetensors
+import torch
+
+from openpi.models_pytorch import pi0_pytorch
+from openpi.shared import image_tools
+import openpi.shared.array_typing as at
+
+logger = logging.getLogger("openpi")
+
+# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
+ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
+
+
+class ModelType(enum.Enum):
+ """Supported model types."""
+
+ PI0 = "pi0"
+ PI0_FAST = "pi0_fast"
+ PI05 = "pi05"
+
+
+# The model always expects these images
+IMAGE_KEYS = (
+ "base_0_rgb",
+ "left_wrist_0_rgb",
+ "right_wrist_0_rgb",
+)
+
+
+# This may need to change if we release a small model.
+IMAGE_RESOLUTION = (224, 224)
+
+
+# Data format
+#
+# Data transforms produce the model input as a nested dictionary which is later converted
+# into `Observation` and `Actions` objects. See below.
+#
+# In the dictionary form, this data should look like:
+# {
+# # Observation data.
+# "image": {
+# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
+# ... # Additional camera views
+# },
+# "image_mask": {
+# "base_0_rgb": bool[*b], # True if image is valid
+# ... # Masks for additional views
+# },
+# "state": float32[*b, s], # Low-dimensional robot state
+# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
+# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
+# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
+# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
+#
+# # Actions data.
+# "actions": float32[*b ah ad]
+# }
+# where:
+#   *b = batch dimensions
+#   h,w = image height/width
+#   s = state dimension
+#   l = sequence length
+#   ah,ad = action horizon / action dimension
+#
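+# A minimal sketch of such a dictionary (illustrative shapes only, not produced by any specific
+# transform in this change):
+#
+#   data = {
+#       "image": {"base_0_rgb": np.zeros((1, 224, 224, 3), dtype=np.uint8)},
+#       "image_mask": {"base_0_rgb": np.ones((1,), dtype=bool)},
+#       "state": np.zeros((1, 32), dtype=np.float32),
+#       "actions": np.zeros((1, 50, 32), dtype=np.float32),
+#   }
+#   observation = Observation.from_dict(data)  # "actions" is not read by from_dict.
+#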
+@at.typecheck
+@struct.dataclass
+class Observation(Generic[ArrayT]):
+ """Holds observations, i.e., inputs to the model.
+
+ See `Observation.from_dict` to see the expected dictionary form. This is the format
+ that should be produced by the data transforms.
+ """
+
+ # Images, in [-1, 1] float32.
+ images: dict[str, at.Float[ArrayT, "*b h w c"]]
+ # Image masks, with same keys as images.
+ image_masks: dict[str, at.Bool[ArrayT, "*b"]]
+ # Low-dimensional robot state.
+ state: at.Float[ArrayT, "*b s"]
+
+ # Tokenized prompt.
+ tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
+ # Tokenized prompt mask.
+ tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
+
+ # pi0-fast model specific fields.
+
+ # Token auto-regressive mask (for FAST autoregressive model).
+ token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
+ # Token loss mask (for FAST autoregressive model).
+ token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None
+
+ @classmethod
+ def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
+        """This method defines the mapping from unstructured data (i.e., a nested dict) to the structured Observation format."""
+ # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
+ if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
+ raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
+ # If images are uint8, convert them to [-1, 1] float32.
+ for key in data["image"]:
+ if data["image"][key].dtype == np.uint8:
+ data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
+ elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
+ data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
+ return cls(
+ images=data["image"],
+ image_masks=data["image_mask"],
+ state=data["state"],
+ tokenized_prompt=data.get("tokenized_prompt"),
+ tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
+ token_ar_mask=data.get("token_ar_mask"),
+ token_loss_mask=data.get("token_loss_mask"),
+ )
+
+ def to_dict(self) -> at.PyTree[ArrayT]:
+ """Convert the Observation to a nested dict."""
+ result = dataclasses.asdict(self)
+ result["image"] = result.pop("images")
+ result["image_mask"] = result.pop("image_masks")
+ return result
+
+
+# Defines the format of the actions. This field is included as "actions" inside the dictionary
+# produced by the data transforms.
+Actions = at.Float[ArrayT, "*b ah ad"]
+
+
+def preprocess_observation(
+ rng: at.KeyArrayLike | None,
+ observation: Observation,
+ *,
+ train: bool = False,
+ image_keys: Sequence[str] = IMAGE_KEYS,
+ image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
+) -> Observation:
+ """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
+ filling in a default image mask (if necessary).
+ """
+
+ if not set(image_keys).issubset(observation.images):
+ raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
+
+ batch_shape = observation.state.shape[:-1]
+
+ out_images = {}
+ for key in image_keys:
+ image = observation.images[key]
+ if image.shape[1:3] != image_resolution:
+ logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
+ image = image_tools.resize_with_pad(image, *image_resolution)
+
+ if train:
+ # Convert from [-1, 1] to [0, 1] for augmax.
+ image = image / 2.0 + 0.5
+
+ transforms = []
+ if "wrist" not in key:
+ height, width = image.shape[1:3]
+ transforms += [
+ augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
+ augmax.Resize(width, height),
+ augmax.Rotate((-5, 5)),
+ ]
+ transforms += [
+ augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
+ ]
+ sub_rngs = jax.random.split(rng, image.shape[0])
+ image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
+
+ # Back to [-1, 1].
+ image = image * 2.0 - 1.0
+
+ out_images[key] = image
+
+ # obtain mask
+ out_masks = {}
+ for key in out_images:
+ if key not in observation.image_masks:
+ # do not mask by default
+ out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
+ else:
+ out_masks[key] = jnp.asarray(observation.image_masks[key])
+
+ return Observation(
+ images=out_images,
+ image_masks=out_masks,
+ state=observation.state,
+ tokenized_prompt=observation.tokenized_prompt,
+ tokenized_prompt_mask=observation.tokenized_prompt_mask,
+ token_ar_mask=observation.token_ar_mask,
+ token_loss_mask=observation.token_loss_mask,
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class BaseModelConfig(abc.ABC):
+ """Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
+ method to create the corresponding model.
+ """
+
+ # Action space dimension.
+ action_dim: int
+ # Action sequence length.
+ action_horizon: int
+ # Tokenized prompt maximum length.
+ max_token_len: int
+
+ @property
+ @abc.abstractmethod
+ def model_type(self) -> ModelType:
+ """The model type."""
+
+ @abc.abstractmethod
+ def create(self, rng: at.KeyArrayLike) -> "BaseModel":
+ """Create a new model, initializing parameters."""
+
+ def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
+ """Create a model with the given parameters."""
+ model = nnx.eval_shape(self.create, jax.random.key(0))
+ graphdef, state = nnx.split(model)
+ if remove_extra_params:
+ params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
+ at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
+ state.replace_by_pure_dict(params)
+ return nnx.merge(graphdef, state)
+
+ def load_pytorch(self, train_config, weight_path: str):
+ logger.info(f"train_config: {train_config}")
+ model = pi0_pytorch.PI0Pytorch(config=train_config.model)
+ safetensors.torch.load_model(model, weight_path)
+ return model
+
+ @abc.abstractmethod
+ def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
+ """Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""
+
+ def fake_obs(self, batch_size: int = 1) -> Observation:
+ observation_spec, _ = self.inputs_spec(batch_size=batch_size)
+ return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)
+
+ def fake_act(self, batch_size: int = 1) -> Actions:
+ _, action_spec = self.inputs_spec(batch_size=batch_size)
+ return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
+
+
+@dataclasses.dataclass
+class BaseModel(nnx.Module, abc.ABC):
+ """Base class for all model implementations. Specific models should inherit from this class. They should call
+ super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
+ """
+
+ action_dim: int
+ action_horizon: int
+ max_token_len: int
+
+ @abc.abstractmethod
+ def compute_loss(
+ self,
+ rng: at.KeyArrayLike,
+ observation: Observation,
+ actions: Actions,
+ *,
+ train: bool = False,
+ ) -> at.Float[at.Array, "*b ah"]: ...
+
+ @abc.abstractmethod
+ def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...
+
+
+def restore_params(
+ params_path: pathlib.Path | str,
+ *,
+ restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
+ dtype: jnp.dtype | None = None,
+ sharding: jax.sharding.Sharding | None = None,
+) -> at.Params:
+ """Restores unstructured params PyTree from a checkpoint.
+
+ This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
+ well as pre-trained checkpoints released for openpi.
+
+ Args:
+ params_path: The local path to the checkpoint directory.
+ restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
+ dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
+ sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.
+
+ Returns:
+ The restored params.
+ """
+ params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path
+
+ if restore_type is jax.Array and sharding is None:
+ mesh = jax.sharding.Mesh(jax.devices(), ("x",))
+ sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
+
+ with ocp.PyTreeCheckpointer() as ckptr:
+ metadata = ckptr.metadata(params_path)
+ item = {"params": metadata["params"]}
+
+ params = ckptr.restore(
+ params_path,
+ ocp.args.PyTreeRestore(
+ item=item,
+ restore_args=jax.tree.map(
+ lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
+ ),
+ ),
+ )["params"]
+
+ # If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
+ # added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
+ flat_params = traverse_util.flatten_dict(params)
+ if all(kp[-1] == "value" for kp in flat_params):
+ flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
+ return traverse_util.unflatten_dict(flat_params)
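+
+# Usage sketch (illustrative local path and dtype; mirrors model_test.py later in this change):
+#
+#   params = restore_params("./checkpoints/pi0_base/params", dtype=jnp.bfloat16)
+#   model = Pi0Config().load(params)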
diff --git a/openpi/src/openpi/models/model_test.py b/openpi/src/openpi/models/model_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..495dc18b5f0fbb4405a49d33d676b0ef31e2aeef
--- /dev/null
+++ b/openpi/src/openpi/models/model_test.py
@@ -0,0 +1,94 @@
+from flax import nnx
+import jax
+import pytest
+
+from openpi.models import model as _model
+from openpi.models import pi0_config
+from openpi.models import pi0_fast
+from openpi.shared import download
+from openpi.shared import nnx_utils
+
+
+def test_pi0_model():
+ key = jax.random.key(0)
+ config = pi0_config.Pi0Config()
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size, config.action_horizon)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
+ assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
+
+
+def test_pi0_lora_model():
+ key = jax.random.key(0)
+ config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size, config.action_horizon)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
+ assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
+
+
+def test_pi0_fast_model():
+ key = jax.random.key(0)
+ config = pi0_fast.Pi0FASTConfig()
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size,)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+ assert actions.shape == (batch_size, 256)
+
+
+def test_pi0_fast_lora_model():
+ key = jax.random.key(0)
+ config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
+ model = config.create(key)
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+ assert loss.shape == (batch_size,)
+
+ actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+ assert actions.shape == (batch_size, 256)
+
+ lora_filter = nnx_utils.PathRegex(".*lora.*")
+ model_state = nnx.state(model)
+
+ lora_state_elems = list(model_state.filter(lora_filter))
+ assert len(lora_state_elems) > 0
+
+
+@pytest.mark.manual
+def test_model_restore():
+ key = jax.random.key(0)
+ config = pi0_config.Pi0Config()
+
+ batch_size = 2
+ obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+ model = config.load(
+ _model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
+ )
+
+ loss = model.compute_loss(key, obs, act)
+ assert loss.shape == (batch_size, config.action_horizon)
+
+ actions = model.sample_actions(key, obs, num_steps=10)
+ assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
diff --git a/openpi/src/openpi/models/pi0.py b/openpi/src/openpi/models/pi0.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae7c4590f330f160aa9baa713d5b77fc060120de
--- /dev/null
+++ b/openpi/src/openpi/models/pi0.py
@@ -0,0 +1,279 @@
+import logging
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+from openpi.models import pi0_config
+import openpi.models.gemma as _gemma
+import openpi.models.siglip as _siglip
+from openpi.shared import array_typing as at
+
+logger = logging.getLogger("openpi")
+
+
+def make_attn_mask(input_mask, mask_ar):
+ """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative mask_ar
+    smaller than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+      input_mask: bool[B, N] true if it's part of the input, false if padding.
+ mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+ it and false where it shares the same attention mask as the previous token.
+ """
+ mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+ cumsum = jnp.cumsum(mask_ar, axis=1)
+ attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+ valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+ return jnp.logical_and(attn_mask, valid_mask)
+
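+# A small worked example (illustrative, not from the original): with
+#   input_mask = [[1, 1, 1, 1]] and mask_ar = [[0, 0, 1, 1]],
+# the cumulative sum is [0, 0, 1, 2], so the resulting mask rows are
+#   [[1, 1, 0, 0],
+#    [1, 1, 0, 0],
+#    [1, 1, 1, 0],
+#    [1, 1, 1, 1]]
+# i.e. a 2-token prefix with full self-attention followed by causal attention.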
+
+@at.typecheck
+def posemb_sincos(
+ pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
+) -> at.Float[at.Array, "b {embedding_dim}"]:
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
+ if embedding_dim % 2 != 0:
+ raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
+
+ fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
+ period = min_period * (max_period / min_period) ** fraction
+ sinusoid_input = jnp.einsum(
+ "i,j->ij",
+ pos,
+ 1.0 / period * 2 * jnp.pi,
+ precision=jax.lax.Precision.HIGHEST,
+ )
+ return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
+
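+# Illustrative example (not from the original): with embedding_dim=4, min_period=4e-3 and
+# max_period=4.0, the periods are [4e-3, 4.0], so a scalar position t is embedded as
+#   [sin(2*pi*t/4e-3), sin(2*pi*t/4.0), cos(2*pi*t/4e-3), cos(2*pi*t/4.0)],
+# i.e. one fast and one slow sinusoid pair, matching how the flow-matching timestep in [0, 1]
+# is embedded below.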
+
+class Pi0(_model.BaseModel):
+ def __init__(self, config: pi0_config.Pi0Config, rngs: nnx.Rngs):
+ super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
+ self.pi05 = config.pi05
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
+ action_expert_config = _gemma.get_config(config.action_expert_variant)
+ # TODO: rewrite gemma in NNX. For now, use bridge.
+ llm = nnx_bridge.ToNNX(
+ _gemma.Module(
+ configs=[paligemma_config, action_expert_config],
+ embed_dtype=config.dtype,
+ adarms=config.pi05,
+ )
+ )
+ llm.lazy_init(rngs=rngs, method="init", use_adarms=[False, True] if config.pi05 else [False, False])
+ img = nnx_bridge.ToNNX(
+ _siglip.Module(
+ num_classes=paligemma_config.width,
+ variant="So400m/14",
+ pool_type="none",
+ scan=True,
+ dtype_mm=config.dtype,
+ )
+ )
+ img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
+ self.PaliGemma = nnx.Dict(llm=llm, img=img)
+ self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
+ if config.pi05:
+ self.time_mlp_in = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
+ self.time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
+ else:
+ self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
+ self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
+ self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
+ self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
+
+ # This attribute gets automatically set by model.train() and model.eval().
+ self.deterministic = True
+
+ @at.typecheck
+ def embed_prefix(
+ self, obs: _model.Observation
+ ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
+ input_mask = []
+ ar_mask = []
+ tokens = []
+ # embed images
+ for name in obs.images:
+ image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
+
+ tokens.append(image_tokens)
+ input_mask.append(
+ einops.repeat(
+ obs.image_masks[name],
+ "b -> b s",
+ s=image_tokens.shape[1],
+ )
+ )
+ # image tokens attend to each other
+ ar_mask += [False] * image_tokens.shape[1]
+
+ # add language (aka tokenized inputs)
+ if obs.tokenized_prompt is not None:
+ tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
+ tokens.append(tokenized_inputs)
+ input_mask.append(obs.tokenized_prompt_mask)
+ # full attention between image and language inputs
+ ar_mask += [False] * tokenized_inputs.shape[1]
+ tokens = jnp.concatenate(tokens, axis=1)
+ input_mask = jnp.concatenate(input_mask, axis=1)
+ ar_mask = jnp.array(ar_mask)
+ return tokens, input_mask, ar_mask
+
+ @at.typecheck
+ def embed_suffix(
+ self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
+ ) -> tuple[
+ at.Float[at.Array, "b s emb"],
+ at.Bool[at.Array, "b s"],
+ at.Bool[at.Array, " s"],
+ at.Float[at.Array, "b emb"] | None,
+ ]:
+ input_mask = []
+ ar_mask = []
+ tokens = []
+ if not self.pi05:
+ # add a single state token
+ state_token = self.state_proj(obs.state)[:, None, :]
+ tokens.append(state_token)
+ input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
+ # image/language inputs do not attend to state or actions
+ ar_mask += [True]
+
+ action_tokens = self.action_in_proj(noisy_actions)
+ # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
+ if self.pi05:
+ # time MLP (for adaRMS)
+ time_emb = self.time_mlp_in(time_emb)
+ time_emb = nnx.swish(time_emb)
+ time_emb = self.time_mlp_out(time_emb)
+ time_emb = nnx.swish(time_emb)
+ action_expert_tokens = action_tokens
+ adarms_cond = time_emb
+ else:
+ # mix timestep + action information using an MLP (no adaRMS)
+ time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
+ action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
+ action_time_tokens = self.action_time_mlp_in(action_time_tokens)
+ action_time_tokens = nnx.swish(action_time_tokens)
+ action_time_tokens = self.action_time_mlp_out(action_time_tokens)
+ action_expert_tokens = action_time_tokens
+ adarms_cond = None
+ tokens.append(action_expert_tokens)
+ input_mask.append(jnp.ones(action_expert_tokens.shape[:2], dtype=jnp.bool_))
+ # image/language/state inputs do not attend to action tokens
+ ar_mask += [True] + ([False] * (self.action_horizon - 1))
+ tokens = jnp.concatenate(tokens, axis=1)
+ input_mask = jnp.concatenate(input_mask, axis=1)
+ ar_mask = jnp.array(ar_mask)
+ return tokens, input_mask, ar_mask, adarms_cond
+
+ @override
+ def compute_loss(
+ self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
+ ) -> at.Float[at.Array, "*b ah"]:
+ preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
+ observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
+
+ batch_shape = actions.shape[:-2]
+ noise = jax.random.normal(noise_rng, actions.shape)
+ time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
+ time_expanded = time[..., None, None]
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
+
+ # one big forward pass of prefix + suffix at once
+ prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
+ suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(observation, x_t, time)
+ input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
+ ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
+ attn_mask = make_attn_mask(input_mask, ar_mask)
+ positions = jnp.cumsum(input_mask, axis=1) - 1
+ (prefix_out, suffix_out), _ = self.PaliGemma.llm(
+ [prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions, adarms_cond=[None, adarms_cond]
+ )
+ v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
+
+ return jnp.mean(jnp.square(v_t - u_t), axis=-1)
+
+ @override
+ def sample_actions(
+ self,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ *,
+ num_steps: int | at.Int[at.Array, ""] = 10,
+ noise: at.Float[at.Array, "b ah ad"] | None = None,
+ ) -> _model.Actions:
+ observation = _model.preprocess_observation(None, observation, train=False)
+ # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
+ # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
+ dt = -1.0 / num_steps
+ batch_size = observation.state.shape[0]
+ if noise is None:
+ noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
+
+ # first fill KV cache with a forward pass of the prefix
+ prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
+ prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
+ positions = jnp.cumsum(prefix_mask, axis=1) - 1
+ _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
+
+ def step(carry):
+ x_t, time = carry
+ suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
+ observation, x_t, jnp.broadcast_to(time, batch_size)
+ )
+ # `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
+ # other
+ suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
+ # `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
+ # prefix tokens
+ prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
+ # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
+ # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
+ full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
+ assert full_attn_mask.shape == (
+ batch_size,
+ suffix_tokens.shape[1],
+ prefix_tokens.shape[1] + suffix_tokens.shape[1],
+ )
+ # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
+ positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
+
+ (prefix_out, suffix_out), _ = self.PaliGemma.llm(
+ [None, suffix_tokens],
+ mask=full_attn_mask,
+ positions=positions,
+ kv_cache=kv_cache,
+ adarms_cond=[None, adarms_cond],
+ )
+ assert prefix_out is None
+ v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
+
+ return x_t + dt * v_t, time + dt
+
+ def cond(carry):
+ x_t, time = carry
+ # robust to floating-point error
+ return time >= -dt / 2
+
+ x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
+ return x_0
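+
+# Sampling sketch (restating `Pi0.sample_actions` above, not new behaviour): integration starts
+# from Gaussian noise at t=1 and takes num_steps Euler steps x <- x + dt * v_t with
+# dt = -1/num_steps, so the loop walks t down to 0 and returns x_0, an approximate sample from
+# the action distribution targeted by the flow-matching loss in `compute_loss`.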
diff --git a/openpi/src/openpi/models/pi0_config.py b/openpi/src/openpi/models/pi0_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f6b662ac1450e31721f9860145496720e26b5e
--- /dev/null
+++ b/openpi/src/openpi/models/pi0_config.py
@@ -0,0 +1,108 @@
+import dataclasses
+from typing import TYPE_CHECKING
+
+import flax.nnx as nnx
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+import openpi.models.gemma as _gemma
+from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+
+if TYPE_CHECKING:
+ from openpi.models.pi0 import Pi0
+
+
+@dataclasses.dataclass(frozen=True)
+class Pi0Config(_model.BaseModelConfig):
+ dtype: str = "bfloat16"
+ paligemma_variant: _gemma.Variant = "gemma_2b"
+ action_expert_variant: _gemma.Variant = "gemma_300m"
+
+ # Set the model specific defaults.
+ action_dim: int = 32
+ action_horizon: int = 50
+ max_token_len: int = None # type: ignore
+ # Pi05 has two differences from Pi0:
+ # - the state input is part of the discrete language tokens rather than a continuous input that is part of the suffix
+ # - the action expert uses adaRMSNorm to inject the flow matching timestep
+ pi05: bool = False
+ # This config option is not used directly by the model, but it is read by the ModelTransformFactory.
+ discrete_state_input: bool = None # type: ignore
+
+ def __post_init__(self):
+ if self.max_token_len is None:
+ object.__setattr__(self, "max_token_len", 200 if self.pi05 else 48)
+ if self.discrete_state_input is None:
+ object.__setattr__(self, "discrete_state_input", self.pi05)
+
+ @property
+ @override
+ def model_type(self) -> _model.ModelType:
+ if self.pi05:
+ return _model.ModelType.PI05
+ return _model.ModelType.PI0
+
+ @override
+ def create(self, rng: at.KeyArrayLike) -> "Pi0":
+ from openpi.models.pi0 import Pi0
+
+ return Pi0(self, rngs=nnx.Rngs(rng))
+
+ @override
+ def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
+ image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
+ image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
+
+ with at.disable_typechecking():
+ observation_spec = _model.Observation(
+ images={
+ "base_0_rgb": image_spec,
+ "left_wrist_0_rgb": image_spec,
+ "right_wrist_0_rgb": image_spec,
+ },
+ image_masks={
+ "base_0_rgb": image_mask_spec,
+ "left_wrist_0_rgb": image_mask_spec,
+ "right_wrist_0_rgb": image_mask_spec,
+ },
+ state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
+ tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
+ tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
+ )
+ action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
+
+ return observation_spec, action_spec
+
+ def get_freeze_filter(self) -> nnx.filterlib.Filter:
+ """Returns the freeze filter based on the model config."""
+ filters = []
+ has_lora = False
+ gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
+ action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
+ if "lora" in self.paligemma_variant:
+ filters.append(
+ gemma_params_filter,
+ )
+ if "lora" not in self.action_expert_variant:
+ # If only freeze gemma params, exclude action expert params.
+ filters.append(
+ nnx.Not(action_expert_params_filter),
+ )
+ has_lora = True
+ elif "lora" in self.action_expert_variant:
+ filters.append(
+ action_expert_params_filter,
+ )
+ has_lora = True
+
+ if has_lora:
+ # If any lora is used, exclude all lora params.
+ filters.append(
+ nnx.Not(nnx_utils.PathRegex(".*lora.*")),
+ )
+ if not filters:
+ return nnx.Nothing
+ return nnx.All(*filters)
diff --git a/openpi/src/openpi/models/pi0_fast.py b/openpi/src/openpi/models/pi0_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b5bd15e27a58a37e8db3f9f9b9707806fec7a5
--- /dev/null
+++ b/openpi/src/openpi/models/pi0_fast.py
@@ -0,0 +1,313 @@
+import dataclasses
+import logging
+from typing import Any
+
+import einops
+import flax.nnx as nnx
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+from typing_extensions import override
+
+from openpi.models import model as _model
+import openpi.models.gemma_fast as _gemma
+import openpi.models.siglip as _siglip
+from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils
+
+logger = logging.getLogger("openpi")
+
+PALIGEMMA_EOS_TOKEN = 1
+
+
+def make_attn_mask(input_mask, mask_ar):
+ """Adapted from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative mask_ar
+    smaller than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
+    set up several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+      input_mask: bool[B, N] true if it's part of the input, false if padding.
+ mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
+ it and false where it shares the same attention mask as the previous token.
+ """
+ mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
+ cumsum = jnp.cumsum(mask_ar, axis=1)
+ attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
+ valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
+ return jnp.logical_and(attn_mask, valid_mask)
+
+
+@jax.vmap
+def left_to_right_align(x, input_mask, attn_mask):
+    """Converts input from left-aligned to right-aligned."""
+    # Due to vmap, this operates on a single example (not the batch level).
+ assert x.ndim == 2
+ assert input_mask.ndim == 1
+ assert attn_mask.ndim == 2
+ assert x.shape[0] == input_mask.shape[0]
+ assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
+ seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
+ x = jnp.roll(x, -seqlen, axis=0)
+ input_mask = jnp.roll(input_mask, -seqlen, axis=0)
+ attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))
+ return x, input_mask, attn_mask
+
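+# Illustrative example (not from the original): for a single sequence with
+#   input_mask = [1, 1, 1, 0, 0]
+# the computed seqlen is 3, so rolling by -3 moves the three valid positions to the end,
+# giving a right-aligned mask [0, 0, 1, 1, 1]; x and attn_mask are rolled the same way, so the
+# most recent valid token always sits at the last index before decoding starts.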
+
+def put_along_last_axis(arr, indices, values):
+ """Like np.put_along_axis(..., axis=-1), since jax is missing it."""
+ assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)
+ onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
+ put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot)
+ put_values = jnp.einsum("...i,...in->...n", values, onehot)
+ return jnp.where(put_mask, put_values, arr)
+
+
+@dataclasses.dataclass(frozen=True)
+class Pi0FASTConfig(_model.BaseModelConfig):
+ dtype: str = "bfloat16"
+ paligemma_variant: _gemma.Variant = "gemma_2b"
+
+ # Set the model specific defaults.
+ action_dim: int = 32
+ action_horizon: int = 32
+ max_token_len: int = 250
+
+ # Tokenizer for the fast model.
+ fast_model_tokenizer: Any | None = None
+ # Keyword arguments for the fast model tokenizer.
+ fast_model_tokenizer_kwargs: dict[str, Any] | None = None
+
+ @property
+ @override
+ def model_type(self) -> _model.ModelType:
+ return _model.ModelType.PI0_FAST
+
+ @override
+ def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
+ return Pi0FAST(self, rngs=nnx.Rngs(rng))
+
+ @override
+ def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
+ image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
+ image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
+
+ with at.disable_typechecking():
+ observation_spec = _model.Observation(
+ images={
+ "base_0_rgb": image_spec,
+ "base_1_rgb": image_spec,
+ "wrist_0_rgb": image_spec,
+ },
+ image_masks={
+ "base_0_rgb": image_mask_spec,
+ "base_1_rgb": image_mask_spec,
+ "wrist_0_rgb": image_mask_spec,
+ },
+ state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
+ tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
+ tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
+ token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
+ token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),
+ )
+ action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
+
+ return observation_spec, action_spec
+
+ def get_freeze_filter(self) -> nnx.filterlib.Filter:
+ """Returns the freeze filter based on the model config."""
+ if "lora" in self.paligemma_variant:
+ return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
+ return nnx.Nothing
+
+
+class Pi0FAST(_model.BaseModel):
+ def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
+ super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
+ # TODO: rewrite gemma in NNX. For now, use bridge.
+ llm = nnx_bridge.ToNNX(
+ _gemma.Module(
+ **paligemma_config,
+ embed_dtype=config.dtype,
+ cache_dtype=config.dtype,
+ )
+ )
+ llm.lazy_init(rngs=rngs, method="init")
+ img = nnx_bridge.ToNNX(
+ _siglip.Module(
+ num_classes=paligemma_config.width,
+ variant="So400m/14",
+ pool_type="none",
+ scan=True,
+ dtype_mm=config.dtype,
+ )
+ )
+ img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
+ self.PaliGemma = nnx.Dict(llm=llm, img=img)
+
+ @at.typecheck
+ def embed_inputs(
+ self, obs: _model.Observation
+ ) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
+ input_mask = []
+ ar_mask = []
+ token_embeddings = []
+ # embed images
+ for name in obs.images:
+ image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)
+
+ token_embeddings.append(image_token_embeddings)
+ input_mask.append(
+ einops.repeat(
+ obs.image_masks[name],
+ "b -> b s",
+ s=image_token_embeddings.shape[1],
+ )
+ )
+ # image tokens attend to each other --> AR mask = 0
+ ar_mask.append(0 * input_mask[-1])
+
+ # add tokenized inputs
+ assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
+ assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
+ assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
+ tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)
+ token_embeddings.append(tokenized_inputs_embeddings)
+ input_mask.append(obs.tokenized_prompt_mask)
+ ar_mask.append(obs.token_ar_mask)
+
+ # return embeddings, input mask, and ar mask
+ return (
+ jnp.concatenate(token_embeddings, axis=1),
+ jnp.concatenate(input_mask, axis=1),
+ jnp.concatenate(ar_mask, axis=1),
+ )
+
+ @override
+ def compute_loss(
+ self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
+ ) -> at.Float[at.Array, "*b ah"]:
+ observation = _model.preprocess_observation(
+ rng, observation, train=train, image_keys=list(observation.images.keys())
+ )
+
+ # Compute inputs: one big forward pass of prefix + suffix at once
+ input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
+ attn_mask = make_attn_mask(input_mask, ar_mask)
+
+ # Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
+ targets = jax.nn.one_hot(
+ observation.tokenized_prompt[:, 1:],
+ self.PaliGemma.llm.module.vocab_size,
+ )
+
+ # Each input predicts *next* token, so we don't input the last token.
+ pre_logits, _, _ = self.PaliGemma.llm(
+ embedded_prefix=input_token_embeddings[:, :-1],
+ mask=attn_mask[:, :-1, :-1],
+ return_prelogits=True,
+ )
+
+ # Only decode logits for the target tokens to save memory
+ # (decoding matmul is large because it is a seq_len x vocab_size dense layer).
+ logits, _ = self.PaliGemma.llm(
+ pre_logits=pre_logits[:, -targets.shape[1] :],
+ )
+ logp = jax.nn.log_softmax(logits, axis=-1)
+
+ # Compute CE loss on token targets
+ assert observation.token_loss_mask is not None, "Token loss mask is required"
+ loss_mask = observation.token_loss_mask[:, 1:]
+ token_pplx = jnp.sum(targets * logp, axis=-1)
+ return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)
+
+ @override
+ def sample_actions(
+ self,
+ rng: at.KeyArrayLike,
+ observation: _model.Observation,
+ *,
+ max_decoding_steps: int | at.Int[at.Array, ""] = 256,
+ temperature: float = 0.0,
+ ) -> _model.Actions:
+ # TODO: this is a hack to get the image keys.
+ observation = _model.preprocess_observation(
+ None, observation, train=False, image_keys=list(observation.images.keys())
+ )
+
+ # embed inputs
+ prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
+ prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
+
+ # left to right align all input token sequences
+ prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
+ prefix_token_embeddings, prefix_mask, prefix_attn_mask
+ )
+ prefill_size = prefix_token_embeddings.shape[1]
+ prefill_len = jnp.sum(prefix_mask, axis=-1)
+ prefix_start = prefill_size - prefill_len
+
+ # first fill KV cache with a forward pass of the prefix
+ # pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
+ prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
+ prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
+ prefix_logits, kv_cache, _ = self.PaliGemma.llm(
+ embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
+ )
+
+ # prepare decoding -- final logit decodes the first token
+ last_logit = prefix_logits[:, -1:]
+ output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
+
+ def step(carry):
+ rng, last_logit, output_tokens, cache, _, step = carry
+
+ # Sample token from last logit
+ # Split RNG for this step
+ rng, rng_step = jax.random.split(rng)
+ token = jax.lax.cond(
+ temperature > 0.0,
+ lambda _: jax.random.categorical(rng_step, last_logit / temperature, axis=-1),
+ lambda _: jnp.argmax(last_logit, axis=-1),
+ operand=None,
+ )
+ output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
+
+ # Check for early stopping --> stop if all batch elements have EOS token
+ has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
+ all_eos = jnp.all(has_eos)
+
+ # Decode one step
+ token_embedding = self.PaliGemma.llm(token, embed_only=True)
+ positions = prefill_len[:, None] + step + 1
+ mask = jnp.logical_and(
+ jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
+ jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
+ < (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
+ )
+ last_logit, kv_cache, _ = self.PaliGemma.llm(
+ embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
+ )
+
+ return rng, last_logit, output_tokens, kv_cache, all_eos, step + 1
+
+ def cond(carry):
+ _, _, _, _, all_eos, step = carry
+ return (~all_eos) & (step < max_decoding_steps)
+
+ # Use lax.while_loop so we can jit the full decoding loop.
+ _, _, output_tokens, _, _, _ = jax.lax.while_loop(
+ cond, step, (rng, last_logit, output_tokens, kv_cache, False, 0)
+ )
+ return output_tokens
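+
+# Decoding sketch (restating the loop above, not new behaviour): the prefix fills a KV cache
+# sized for prefill_size + max_decoding_steps, then tokens are emitted one at a time, greedily
+# when temperature == 0 and by categorical sampling otherwise, stopping early once all batch
+# elements emit PALIGEMMA_EOS_TOKEN or after max_decoding_steps steps.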
diff --git a/openpi/src/openpi/models/pi0_test.py b/openpi/src/openpi/models/pi0_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5f0f84e7e831db37883ec3181eadf6e175840ef
--- /dev/null
+++ b/openpi/src/openpi/models/pi0_test.py
@@ -0,0 +1,46 @@
+import flax.nnx as nnx
+import jax
+
+import openpi.models.pi0_config as _pi0_config
+
+
+def _get_frozen_state(config: _pi0_config.Pi0Config) -> nnx.State:
+ abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
+
+ freeze_filter = config.get_freeze_filter()
+ return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state()
+
+
+def test_pi0_full_finetune():
+ config = _pi0_config.Pi0Config()
+ state = _get_frozen_state(config)
+ assert len(state) == 0
+
+
+def test_pi0_gemma_lora():
+ config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
+ state = _get_frozen_state(config)
+ assert len(state) == 9
+ assert all("lora" not in p for p in state)
+ assert all("llm" in p for p in state)
+ assert all("_1" not in p for p in state)
+
+
+def test_pi0_action_expert_lora():
+ config = _pi0_config.Pi0Config(action_expert_variant="gemma_300m_lora")
+ state = _get_frozen_state(config)
+ # excluding embedder, rest of the params should be same as gemma_lora.
+ assert len(state) == 8
+ assert all("lora" not in p for p in state)
+ assert all("llm" in p for p in state)
+ # all frozen params should have _1 in their path since it's the action expert.
+ assert all(any("_1" in p for p in path) for path in state)
+
+
+def test_pi0_all_lora():
+ config = _pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
+ state = _get_frozen_state(config)
+ # sum of gemma_lora and action_expert_lora's frozen params.
+ assert len(state) == 17
+ assert all("lora" not in p for p in state)
+ assert all("llm" in p for p in state)
diff --git a/openpi/src/openpi/models/siglip.py b/openpi/src/openpi/models/siglip.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74c99e75fa45308fcadff115a7fd4ed12c9e96c
--- /dev/null
+++ b/openpi/src/openpi/models/siglip.py
@@ -0,0 +1,373 @@
+# Copyright 2024 Big Vision Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A refactored and simplified ViT adoptation for Pi, taken from big_vision."""
+
+from collections.abc import Sequence
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+import openpi.training.sharding as sharding
+
+
+def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
+ """Follows the MoCo v3 logic."""
+ y, x = jnp.mgrid[:h, :w]
+
+ assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
+ omega = jnp.arange(width // 4) / (width // 4 - 1)
+ omega = 1.0 / (temperature**omega)
+ y = jnp.einsum("m,d->md", y.flatten(), omega)
+ x = jnp.einsum("m,d->md", x.flatten(), omega)
+ pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
+ return jnp.asarray(pe, dtype)[None, :, :]
+
+
+def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
+ if typ == "learn":
+ return self.param(
+ name,
+ nn.initializers.normal(stddev=1 / np.sqrt(width)),
+ (1, np.prod(seqshape), width),
+ dtype,
+ )
+ if typ == "sincos2d":
+ return posemb_sincos_2d(*seqshape, width, dtype=dtype)
+ raise ValueError(f"Unknown posemb type: {typ}")
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block."""
+
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ dropout: float = 0.0
+ dtype_mm: str = "float32"
+
+ @nn.compact
+ def __call__(self, x, deterministic=True): # noqa: FBT002
+ """Applies Transformer MlpBlock module."""
+ inits = {
+ "kernel_init": nn.initializers.xavier_uniform(),
+ "bias_init": nn.initializers.normal(stddev=1e-6),
+ }
+
+ _, _, d = x.shape # n,l,d
+ x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
+ x = nn.gelu(x)
+ x = nn.Dropout(rate=self.dropout)(x, deterministic)
+ return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
+
+
+class Encoder1DBlock(nn.Module):
+ """Single transformer encoder block (MHSA + MLP)."""
+
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ dropout: float = 0.0
+ dtype_mm: str = "float32"
+
+ @nn.compact
+ def __call__(self, x, deterministic=True): # noqa: FBT002
+ out = {}
+ x = sharding.activation_sharding_constraint(x)
+ y = nn.LayerNorm(dtype=self.dtype_mm)(x)
+ y = out["sa"] = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=deterministic,
+ dtype=self.dtype_mm,
+ )(y, y)
+ y = sharding.activation_sharding_constraint(y)
+ y = nn.Dropout(rate=self.dropout)(y, deterministic)
+ x = out["+sa"] = x + y
+
+ y = nn.LayerNorm(dtype=self.dtype_mm)(x)
+ y = out["mlp"] = MlpBlock(
+ mlp_dim=self.mlp_dim,
+ dropout=self.dropout,
+ dtype_mm=self.dtype_mm,
+ )(y, deterministic)
+ y = sharding.activation_sharding_constraint(y)
+ y = nn.Dropout(rate=self.dropout)(y, deterministic)
+ x = out["+mlp"] = x + y
+ x = sharding.activation_sharding_constraint(x)
+ return x, out
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation."""
+
+ depth: int
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ dropout: float = 0.0
+ scan: bool = False
+ remat_policy: str = "nothing_saveable"
+ dtype_mm: str = "float32"
+
+ @nn.compact
+ def __call__(self, x, deterministic=True): # noqa: FBT002
+ out = {}
+
+ if self.scan:
+ block = nn.remat(
+ Encoder1DBlock,
+ prevent_cse=False,
+ static_argnums=(2,), # 0=self, 2=deterministic
+ policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
+ )
+ x, scan_out = nn.scan(
+ block,
+ variable_axes={"params": 0},
+ split_rngs={"params": True, "dropout": True},
+ in_axes=nn.broadcast,
+ length=self.depth,
+ )(
+ name="encoderblock",
+ dtype_mm=self.dtype_mm,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ )(x, deterministic)
+ for lyr in range(self.depth):
+ out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)
+ else:
+ # Input Encoder
+ for lyr in range(self.depth):
+ block_cur = Encoder1DBlock(
+ name=f"encoderblock_{lyr}",
+ dtype_mm=self.dtype_mm,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ )
+ x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
+ out["pre_ln"] = x # Alias for last block, but without the number in it.
+
+ return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out
+
+
+class MAPHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ dtype_mm: str = "float32"
+
+ @nn.compact
+ def __call__(self, x):
+ n, _, d = x.shape # n,l,d
+ probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
+ probe = jnp.tile(probe, [n, 1, 1])
+
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ dtype=self.dtype_mm,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(probe, x)
+
+ y = nn.LayerNorm(dtype=self.dtype_mm)(x)
+ x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y)
+ return x[:, 0]
+
+
+class _Module(nn.Module):
+ """ViT model."""
+
+ num_classes: int | None = None
+ patch_size: Sequence[int] = (16, 16)
+ width: int = 768
+ depth: int = 12
+ mlp_dim: int | None = None # Defaults to 4x input dim
+ num_heads: int = 12
+ posemb: str = "learn" # Can also be "sincos2d"
+ rep_size: int | bool = False
+ dropout: float = 0.0
+ pool_type: str = "gap" # Can also be "map" or "tok"
+ head_zeroinit: bool = True
+ scan: bool = False
+ # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
+ remat_policy: str = "nothing_saveable"
+ dtype_mm: str = "float32"
+
+ @nn.compact
+ def __call__(self, image, *, train=False):
+ out = {}
+
+ # Kevin edit: do patch extraction and posemb in float32,
+ # because I feel like it's a bit safer.
+ image = jnp.asarray(image, jnp.float32)
+
+ # Patch extraction
+ x = out["stem"] = nn.Conv(
+ self.width,
+ self.patch_size,
+ strides=self.patch_size,
+ padding="VALID",
+ name="embedding",
+ dtype=jnp.float32,
+ )(image)
+
+ n, h, w, c = x.shape
+ x = jnp.reshape(x, [n, h * w, c])
+
+ # Add posemb before adding extra token.
+ x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32)
+
+ if self.pool_type == "tok":
+ cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
+ x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
+
+ n, _, c = x.shape # n,l,d
+ x = nn.Dropout(rate=self.dropout)(x, not train)
+
+ # Kevin edit: now cast back to dtype_mm (potentially half precision)
+ x = x.astype(self.dtype_mm)
+
+ x, out["encoder"] = Encoder(
+ depth=self.depth,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ scan=self.scan,
+ remat_policy=self.remat_policy,
+ dtype_mm=self.dtype_mm,
+ name="Transformer",
+ )(x, deterministic=not train)
+ encoded = out["encoded"] = x
+
+ if self.pool_type == "map":
+ x = out["head_input"] = MAPHead(
+ num_heads=self.num_heads,
+ mlp_dim=self.mlp_dim,
+ dtype=self.dtype_mm,
+ )(x)
+ elif self.pool_type == "gap":
+ x = out["head_input"] = jnp.mean(x, axis=1)
+ elif self.pool_type == "0":
+ x = out["head_input"] = x[:, 0]
+ elif self.pool_type == "tok":
+ x = out["head_input"] = x[:, 0]
+ encoded = encoded[:, 1:]
+ elif self.pool_type == "none":
+ pass
+ else:
+ raise ValueError(f"Unknown pool type: '{self.pool_type}'")
+
+ x_2d = jnp.reshape(encoded, [n, h, w, -1])
+
+ if self.rep_size:
+ rep_size = self.width if self.rep_size is True else self.rep_size
+ hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits")
+ # NOTE: In the past we did not include tanh in pre_logits.
+ # For few-shot, it should not matter much, as it whitens anyways.
+ x_2d = nn.tanh(hid(x_2d))
+ x = nn.tanh(hid(x))
+
+ out["pre_logits_2d"] = x_2d
+ out["pre_logits"] = x
+
+ if self.num_classes:
+ kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
+ head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw)
+ x_2d = out["logits_2d"] = head(x_2d)
+ x = out["logits"] = head(x)
+
+ return x, out
+
+
+def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802
+    """Factory function, because linen really doesn't like what I'm doing!"""
+ return _Module(num_classes, **{**decode_variant(variant), **kw})
+
+
+def decode_variant(variant):
+ """Converts a string like "B" or "B/32" into a params dict."""
+ if variant is None:
+ return {}
+
+ v, patch = variant, {}
+ if "/" in variant:
+ v, patch = variant.split("/")
+ patch = {"patch_size": (int(patch), int(patch))}
+
+ return {
+ # pylint:disable=line-too-long
+ # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
+ "width": {
+ "mu": 32,
+ "Ti": 192,
+ "S": 384,
+ "M": 512,
+ "B": 768,
+ "L": 1024,
+ "So400m": 1152,
+ "H": 1280,
+ "g": 1408,
+ "g-opt": 1536,
+ "G": 1664,
+ "G-opt": 1536,
+ "e": 1792,
+ }[v],
+ "depth": {
+ "mu": 1,
+ "Ti": 12,
+ "S": 12,
+ "M": 12,
+ "B": 12,
+ "L": 24,
+ "So400m": 27,
+ "H": 32,
+ "g": 40,
+ "g-opt": 40,
+ "G": 48,
+ "G-opt": 48,
+ "e": 56,
+ }[v],
+ "mlp_dim": {
+ "mu": 128,
+ "Ti": 768,
+ "S": 1536,
+ "M": 2048,
+ "B": 3072,
+ "L": 4096,
+ "So400m": 4304,
+ "H": 5120,
+ "g": 6144,
+ "g-opt": 6144,
+ "G": 8192,
+ "G-opt": 8192,
+ "e": 15360,
+ }[v],
+ "num_heads": {
+ "mu": 2,
+ "Ti": 3,
+ "S": 6,
+ "M": 8,
+ "B": 12,
+ "L": 16,
+ "So400m": 16,
+ "H": 16,
+ "g": 16,
+ "g-opt": 16,
+ "G": 16,
+ "G-opt": 16,
+ "e": 16,
+ }[v],
+ # pylint:enable=line-too-long
+ **patch,
+ }
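+
+
+# Illustrative sanity check (not exercised anywhere in this repo): decoding the
+# "B/16" variant should reproduce the ViT-Base column of Table 2 referenced above:
+#   decode_variant("B/16") == {"width": 768, "depth": 12, "mlp_dim": 3072,
+#                              "num_heads": 12, "patch_size": (16, 16)}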
diff --git a/openpi/src/openpi/models/tokenizer.py b/openpi/src/openpi/models/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a4966d6298619e52c7ba53359ccbbb8ba0b8cf6
--- /dev/null
+++ b/openpi/src/openpi/models/tokenizer.py
@@ -0,0 +1,371 @@
+import logging
+import os
+
+import jax
+import numpy as np
+import orbax.checkpoint as ocp
+import sentencepiece
+from transformers import AutoProcessor
+
+import openpi.models.utils.fsq_tokenizer as fsq_tokenizer
+import openpi.shared.download as download
+
+
+class PaligemmaTokenizer:
+ def __init__(self, max_len: int = 48):
+ self._max_len = max_len
+
+ path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+ with path.open("rb") as f:
+ self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+ def tokenize(self, prompt: str, state: np.ndarray | None = None) -> tuple[np.ndarray, np.ndarray]:
+ cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ")
+ if state is not None:
+ # This is the Pi05 format, where the state is part of the discrete language input.
+ discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+ state_str = " ".join(map(str, discretized_state))
+ full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
+ tokens = self._tokenizer.encode(full_prompt, add_bos=True)
+ else:
+ # This is the Pi0 format, where the state is part of the continuous action expert input.
+ # tokenize "\n" separately as the "start of answer" token
+ tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ mask = [True] * tokens_len + padding
+ tokens = tokens + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+ "Consider increasing the `max_token_len` in your model config if this happens frequently."
+ )
+ tokens = tokens[: self._max_len]
+ mask = [True] * self._max_len
+
+ return np.asarray(tokens), np.asarray(mask)
+
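+# A worked example of the state discretization used above (illustrative only): the
+# 256 bin edges are np.linspace(-1, 1, 257)[:-1], so after the np.digitize(...) - 1
+# step a state value of -1.0 lands in bin 0, 0.0 lands in bin 128, and values just
+# below 1.0 land in bin 255. States are assumed to already be normalized to [-1, 1].
+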
+
+class FASTTokenizer:
+ def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"):
+ self._max_len = max_len
+
+ # Download base PaliGemma tokenizer
+ path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+ with path.open("rb") as f:
+ self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+ # Instantiate FAST tokenizer
+ self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
+ self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ cleaned_text = prompt.lower().strip().replace("_", " ")
+
+ # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+ discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+ # Convention: prefix includes prompt and string-representation of state, followed by ';'
+ state_str = " ".join(map(str, discretized_state))
+ prefix = f"Task: {cleaned_text}, State: {state_str};\n"
+ prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+ if actions is not None:
+ # Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
+ action_tokens = self._fast_tokenizer(actions[None])[0]
+ action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)
+
+ # Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
+ postfix_tokens = (
+ self._paligemma_tokenizer.encode("Action: ")
+ + action_tokens_in_pg.tolist()
+ + self._paligemma_tokenizer.encode("|", add_eos=True)
+ )
+ else:
+ postfix_tokens = []
+
+ # Create output token sequence & masks
+ # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+ tokens = prefix_tokens + postfix_tokens
+ token_mask = [True] * len(tokens)
+ ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+ loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
+
+ # Pad tokens to max length
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ tokens = tokens + padding
+ token_mask = token_mask + padding
+ ar_mask = ar_mask + padding
+ loss_mask = loss_mask + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+ "Consider increasing the `max_token_len` in your model config if this happens frequently."
+ )
+ tokens = tokens[: self._max_len]
+ token_mask = token_mask[: self._max_len]
+ ar_mask = ar_mask[: self._max_len]
+ loss_mask = loss_mask[: self._max_len]
+
+ return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
+
+ def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
+ # Decode predicted output tokens
+ decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+ # Extract actions from FAST model outputs
+ if "Action: " not in decoded_tokens:
+ return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+ # Extract actions from decoded tokens
+ raw_action_tokens = np.array(
+ self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
+ )
+ action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+ return self._fast_tokenizer.decode(
+ [action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
+ )[0]
+
+ def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
+ if isinstance(tokens, list):
+ tokens = np.array(tokens)
+ return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
+
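+# A quick numeric illustration of the action-token mapping above (assuming the
+# PaliGemma sentencepiece vocab size of 257,152, the value configured for the VLM
+# elsewhere in this repo): FAST action token 0 maps to id 257,152 - 1 - 128 - 0 =
+# 257,023, so action tokens occupy the tail of the vocabulary just below the 128
+# reserved special tokens, and the same formula inverts the mapping in extract_actions.
+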
+
+###########################################################################
+## The tokenizers below are used for RoboArena baseline implementations. ##
+## They are *not* used for pi0-style models. ##
+###########################################################################
+
+
+class BinningTokenizer:
+ """
+ Standard RT-2 / OpenVLA style binning tokenizer.
+ """
+
+ def __init__(self, max_len: int = 256, n_bins: int = 256):
+ self._max_len = max_len
+ self._n_bins = n_bins
+
+ # Download base PaliGemma tokenizer
+ path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+ with path.open("rb") as f:
+ self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+ self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """Tokenize a prompt and state into a sequence of tokens.
+
+ Args:
+ prompt: The text prompt to tokenize.
+ state: The state array to discretize and tokenize.
+ actions: Must be None. Action encoding is not currently supported.
+
+ Returns:
+            A tuple of (tokens, token_mask, ar_mask, loss_mask).
+
+ Raises:
+ NotImplementedError: If actions is not None.
+ """
+ cleaned_text = prompt.lower().strip().replace("_", " ")
+
+ # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+ discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+ # Convention: prefix includes prompt and string-representation of state, followed by ';'
+ state_str = " ".join(map(str, discretized_state))
+ prefix = f"Task: {cleaned_text}, State: {state_str};\n"
+ prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+ if actions is not None:
+ raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
+ postfix_tokens = []
+
+ # Create output token sequence & masks
+ # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+ tokens = prefix_tokens + postfix_tokens
+ token_mask = [True] * len(tokens)
+ ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+ loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
+
+ # Pad tokens to max length
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ tokens = tokens + padding
+ token_mask = token_mask + padding
+ ar_mask = ar_mask + padding
+ loss_mask = loss_mask + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+ "Consider increasing the `max_token_len` in your model config if this happens frequently."
+ )
+ tokens = tokens[: self._max_len]
+ token_mask = token_mask[: self._max_len]
+ ar_mask = ar_mask[: self._max_len]
+ loss_mask = loss_mask[: self._max_len]
+
+ return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
+
+ def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
+ # Decode predicted output tokens
+ decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+ # Extract actions from FAST model outputs
+ if "Action: " not in decoded_tokens:
+ return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+ # Extract actions from decoded tokens
+ raw_action_tokens = np.array(
+ self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
+ )
+ action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+ if len(action_tokens) < action_horizon * action_dim:
+ return np.zeros([action_horizon, action_dim], dtype=np.float32)
+ action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
+ return action_tokens / self._n_bins * 2 - 1
+
+ def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
+ if isinstance(tokens, list):
+ tokens = np.array(tokens)
+ return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
+
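+# De-quantization check for BinningTokenizer.extract_actions (illustrative only):
+# with n_bins = 256, bin 0 maps back to -1.0, bin 128 to 0.0, and bin 255 to
+# 255 / 256 * 2 - 1 = 0.9921875, approximately inverting the [-1, 1] binning
+# applied in tokenize.
+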
+
+class FSQTokenizer:
+ """
+ FSQ tokenizer from the FAST paper baselines.
+ """
+
+ def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
+ self._max_len = max_len
+
+ assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
+ # Download tokenizer
+ path = download.maybe_download(fsq_tokenizer_path)
+ tok_path = os.path.join(path, os.listdir(path)[0])
+
+ # Split step from path
+ step = int(tok_path.split("/")[-1])
+ base_path = tok_path.rsplit("/", 1)[0]
+
+ mgr = ocp.CheckpointManager(
+ base_path,
+ item_handlers={
+ "params": ocp.StandardCheckpointHandler(),
+ "opt_state": ocp.StandardCheckpointHandler(),
+ "config": ocp.JsonCheckpointHandler(),
+ },
+ options=ocp.CheckpointManagerOptions(max_to_keep=1),
+ )
+
+ try:
+ restored = mgr.restore(
+ step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
+ )
+ config = restored["config"]
+ self._params = restored["params"]
+ self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
+ ) from e
+
+ # Compile tokenize and detokenize functions
+ self._tokenize_fn = jax.jit(
+ lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
+ )
+ self._detokenize_fn = jax.jit(
+ lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
+ )
+
+ # Download base PaliGemma tokenizer
+ path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
+ with path.open("rb") as f:
+ self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
+
+ self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
+
+ def tokenize(
+ self, prompt: str, state: np.ndarray, actions: np.ndarray | None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ cleaned_text = prompt.lower().strip().replace("_", " ")
+
+ # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
+ discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+ # Convention: prefix includes prompt and string-representation of state, followed by ';'
+ state_str = " ".join(map(str, discretized_state))
+ prefix = f"Task: {cleaned_text}, State: {state_str};\n"
+ prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
+
+ if actions is not None:
+ raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
+ postfix_tokens = []
+
+ # Create output token sequence & masks
+ # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
+ tokens = prefix_tokens + postfix_tokens
+ token_mask = [True] * len(tokens)
+ ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
+ loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
+
+ # Pad tokens to max length
+ tokens_len = len(tokens)
+ if tokens_len < self._max_len:
+ padding = [False] * (self._max_len - tokens_len)
+ tokens = tokens + padding
+ token_mask = token_mask + padding
+ ar_mask = ar_mask + padding
+ loss_mask = loss_mask + padding
+ else:
+ if len(tokens) > self._max_len:
+ logging.warning(
+ f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
+ "Consider increasing the `max_token_len` in your model config if this happens frequently."
+ )
+ tokens = tokens[: self._max_len]
+ token_mask = token_mask[: self._max_len]
+ ar_mask = ar_mask[: self._max_len]
+ loss_mask = loss_mask[: self._max_len]
+
+ return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
+
+ def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
+ # Decode predicted output tokens
+ decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
+
+ # Extract actions from FAST model outputs
+ if "Action: " not in decoded_tokens:
+ return np.zeros((action_horizon, action_dim), dtype=np.float32)
+
+ # Extract actions from decoded tokens
+ raw_action_tokens = np.array(
+ self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
+ )
+ action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
+ try:
+ # Move computation to CPU and compile on-demand
+ device = jax.devices("cpu")[0]
+ with jax.default_device(device):
+ detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
+ return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
+ except Exception as e:
+ logging.warning(f"Error decoding FSQ: {e}")
+ return np.zeros((action_horizon, action_dim))
+
+ def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
+ if isinstance(tokens, list):
+ tokens = np.array(tokens)
+ return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
diff --git a/openpi/src/openpi/models/tokenizer_test.py b/openpi/src/openpi/models/tokenizer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..092c4ae3b2330db9f89344dc4bae150457fa4310
--- /dev/null
+++ b/openpi/src/openpi/models/tokenizer_test.py
@@ -0,0 +1,27 @@
+import numpy as np
+
+from openpi.models import tokenizer as _tokenizer
+
+
+def test_tokenize():
+ tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)
+ tokens, masks = tokenizer.tokenize("Hello, world!")
+
+ assert tokens.shape == (10,)
+ assert masks.shape == (10,)
+
+
+def test_fast_tokenizer():
+ prompt = "Hello, world!"
+ state = np.random.rand(5).astype(np.float32)
+ action = np.random.rand(3, 2).astype(np.float32)
+ tokenizer = _tokenizer.FASTTokenizer(max_len=256)
+ tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action)
+
+ assert tokens.shape == (256,)
+ assert token_masks.shape == (256,)
+ assert ar_masks.shape == (256,)
+ assert loss_masks.shape == (256,)
+
+ act = tokenizer.extract_actions(tokens, 3, 2)
+ assert act.shape == (3, 2)
diff --git a/openpi/src/openpi/models/utils/fsq_tokenizer.py b/openpi/src/openpi/models/utils/fsq_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..574d7733de0ebd25898d8f82fdc6e37bd0aa75ab
--- /dev/null
+++ b/openpi/src/openpi/models/utils/fsq_tokenizer.py
@@ -0,0 +1,472 @@
+import math
+from typing import Any, Literal
+
+import chex
+from einops import einops
+from flax import linen as nn
+from flax.linen.module import Module
+from flax.linen.module import compact
+from flax.struct import dataclass
+from flax.typing import Array
+import jax
+import jax.numpy as jnp
+
+
+class FsqCodebook(nn.Module):
+ input_dim: int
+ target_codebook_size: int
+    codebook_type: Literal["fsq", "lfq", "custom"]
+
+    _bins_per_dim: tuple[int, ...] | None = None
+
+    @property
+    def bins_per_dim(self) -> tuple[int, ...]:
+ if self._bins_per_dim is not None:
+ return self._bins_per_dim
+
+ if self.codebook_type == "fsq":
+ return self._get_bins_fsq(self.target_codebook_size)
+ elif self.codebook_type == "lfq": # noqa: RET505
+ return self._get_bins_lfq(self.target_codebook_size)
+ elif self.codebook_type == "custom":
+ return self._get_bins_custom(self.target_codebook_size)
+ else:
+ raise ValueError(f"Codebook type {self.codebook_type} not supported.")
+
+ @property
+ def place_values(self) -> jnp.ndarray:
+ place_values = [1]
+ for b in self.bins_per_dim[:-1]:
+ place_values.append(place_values[-1] * b)
+ return jnp.array(place_values)
+
+ @staticmethod
+ def _get_bins_fsq(target_codebook_size: int) -> tuple[int]:
+ """
+ Get bins per dimension based on codebook size, from the original FSQ paper.
+ """
+ if target_codebook_size == 2**8:
+ return (8, 6, 5)
+ elif target_codebook_size == 2**10: # noqa: RET505
+ return (8, 5, 5, 5)
+ elif target_codebook_size == 2**12:
+ return (7, 5, 5, 5, 5)
+ elif target_codebook_size == 2**14:
+ return (8, 8, 8, 6, 5)
+ elif target_codebook_size == 2**16:
+ return (8, 8, 8, 5, 5, 5)
+ else:
+ raise ValueError(f"Codebook size {target_codebook_size} not supported.")
+
+ @staticmethod
+ def _get_bins_custom(target_codebook_size: int) -> tuple[int]:
+ if target_codebook_size == 2**8:
+ return (16, 16)
+ elif target_codebook_size == 2**10: # noqa: RET505
+ return (32, 32)
+ elif target_codebook_size == 2**12:
+ return (64, 64)
+ elif target_codebook_size == 2**14:
+ return (128, 128)
+ elif target_codebook_size == 2**16:
+ return (256, 256)
+        raise ValueError(f"Codebook size {target_codebook_size} not supported.")
+
+ @staticmethod
+ def _get_bins_lfq(target_codebook_size: int) -> tuple[int]:
+ """
+ Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
+ """
+ assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
+
+ return (2,) * int(math.log2(target_codebook_size))
+
+ def setup(self):
+ self.proj_down = nn.Dense(len(self.bins_per_dim))
+ self.proj_up = nn.Dense(self.input_dim)
+
+ def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
+ tokens, z = self.encode(inputs)
+ output = self.decode(tokens, z_grad=z)
+ return tokens, output
+
+ def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
+ bases = jnp.array(self.bins_per_dim)
+
+ x = self.proj_down(inputs)
+ z = jnp.tanh(x)
+
+ # Quantize
+ digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
+ tokens = self.undigitize(digits)
+
+ return tokens, z
+
+ def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
+ bases = jnp.array(self.bins_per_dim)
+ digits = self.digitize(tokens)
+
+ z_q = digits / (bases - 1) * 2 - 1
+
+ if z_grad is not None:
+ chex.assert_equal_shape([z_q, z_grad])
+ z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
+
+ return self.proj_up(z_q)
+
+ def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
+ return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
+
+ def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
+ return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
+
+ @property
+ def vocab_size(self) -> int:
+ return math.prod(self.bins_per_dim)
+
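+# Worked example of the mixed-radix encoding in FsqCodebook (illustrative only):
+# with bins_per_dim = (8, 6, 5) the place values are (1, 8, 48), so the per-dim
+# digits (3, 2, 4) undigitize to the token 3*1 + 2*8 + 4*48 = 211, and
+# digitize(211) recovers (3, 2, 4); vocab_size is then 8 * 6 * 5 = 240.
+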
+
+class ResNetDownBlock(nn.Module):
+ stride: int = 1
+ n_filters: int = 64
+ dropout_rate: float = 0.0
+ group_size: int = 32
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
+ skip = x
+
+ if self.stride > 1 or x.shape[-1] != self.n_filters:
+ skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
+
+ x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
+ x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
+ x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+ x = nn.relu(x)
+ x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)
+
+ return skip + x
+
+
+class ResNetUpBlock(nn.Module):
+ stride: int = 1
+ n_filters: int = 64
+ dropout_rate: float = 0.0
+ group_size: int = 32
+
+ @nn.compact
+ def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
+ skip = x
+
+ if self.stride > 1:
+ skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
+
+ x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
+ x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
+ x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+ x = nn.relu(x)
+ x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)
+
+ return skip + x
+
+
+@dataclass
+class LfqCodebookOutput:
+ tokens: jnp.ndarray
+ z: jnp.ndarray
+ z_q: jnp.ndarray
+ token_log_probs: jnp.ndarray
+ commit_loss: jnp.ndarray
+
+
+class LookupFreeQuantization(nn.Module):
+ num_dims: int
+ latent_dim: int
+
+ def setup(self):
+ self.codebook = jnp.array([-1, 1])
+ self.activation = nn.tanh
+
+ self.project_down = nn.Dense(self.num_dims)
+ self.project_up = nn.Dense(self.latent_dim)
+
+ def encode(self, z: jnp.ndarray) -> jnp.ndarray:
+ z = self.project_down(z)
+ token_squared_distances = jnp.square(z[..., None] - self.codebook)
+ token_bits = jnp.argmin(token_squared_distances, axis=-1)
+ return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
+
+ def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
+ token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
+ return self.project_up(self.codebook[token_bits])
+
+ def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
+ z = self.project_down(x)
+ z = self.activation(z)
+
+ token_squared_distances = jnp.square(z[..., None] - self.codebook)
+ tokens = jnp.argmin(token_squared_distances, axis=-1)
+
+ token_bit_log_probs = -token_squared_distances
+ # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
+ token_bit_expansions = jnp.bitwise_and(
+ jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
+ ).astype(jnp.int32)
+ token_log_probs = (
+ token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
+ + token_bit_log_probs[..., 1] @ token_bit_expansions
+ ) # (batch_size, num_tokens, 2 ** num_dims)
+ token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
+ chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
+
+ z_q = self.codebook[tokens]
+ commit_loss = jnp.square(z - z_q).mean()
+ z_q = jax.lax.stop_gradient(z_q - z) + z
+
+ z_q = self.project_up(z_q)
+ z = self.project_up(z)
+
+ tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
+ return LfqCodebookOutput(
+ tokens=tokens,
+ z=z,
+ z_q=z_q,
+ token_log_probs=jnp.zeros(()),
+ commit_loss=commit_loss,
+ )
+
+
+def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
+ return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
+
+
+class GeGLU(Module):
+ """Gated Linear Unit with GELU (GeGLU) activation function.
+ GeGLU is a Flax layer that combines a linear transformation with a GELU
+ activation function in a gating mechanism. It is often used in Transformer models
+ to provide non-linear capabilities while preserving a strong linear component.
+
+ Attributes:
+        output_dim: the number of output features (default: -1, i.e. match the input dimension).
+ """
+
+ output_dim: int = -1
+
+ @compact
+ def __call__(self, inputs: Array) -> Array:
+ """Applies the GeGLU activation to the inputs.
+ Args:
+ inputs: the nd-array to apply the GeGLU activation function to.
+ Returns:
+ The transformed input.
+ """
+ output_dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim
+
+ x = nn.Dense(output_dim * 2)(inputs)
+ x, gate = x[..., :output_dim], x[..., output_dim:]
+ return x * nn.gelu(gate)
+
+
+class CrossAttentionLayer(nn.Module):
+ dropout_rate: float = 0.0
+    num_heads: int | None = None
+ causal: bool = False
+ mlp_ratio: float = 4.0
+
+ @nn.compact
+ def __call__(
+ self,
+ x: jnp.ndarray,
+ y: jnp.ndarray,
+ *,
+ mask_self: jnp.ndarray | None = None,
+ mask_cross: jnp.ndarray | None = None,
+ train: bool = True,
+ ) -> jnp.ndarray:
+ d_embed = x.shape[-1]
+ seq_len_q = x.shape[-2]
+ seq_len_k = y.shape[-2]
+
+ if self.causal:
+ # One block size will be 1
+ bs_q = max(seq_len_q // seq_len_k, 1)
+ bs_k = max(seq_len_k // seq_len_q, 1)
+
+ mask_self = nn.make_causal_mask(x[..., 0])
+ mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)
+
+ # Self-attention block
+ skip = x
+ x = nn.LayerNorm()(x)
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads or d_embed // 64,
+ dropout_rate=self.dropout_rate,
+ deterministic=not train,
+ )(x, x, x, mask=mask_self)
+ x = skip + x
+
+ # Cross-attention block
+ skip = x
+ x = nn.LayerNorm()(x)
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads or d_embed // 64,
+ dropout_rate=self.dropout_rate,
+ deterministic=not train,
+ )(x, y, y, mask=mask_cross)
+ x = skip + x
+
+ # MLP block
+ skip = x
+ x = nn.LayerNorm()(x)
+ x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
+ x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
+ x = GeGLU()(x)
+ x = nn.Dense(d_embed)(x)
+ return skip + x
+
+
+def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
+ seq_len, d_embed = shape
+
+ position = jnp.arange(0, seq_len, 1)
+ div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
+ return jnp.concatenate(
+ [
+ jnp.sin(position[:, jnp.newaxis] * div_term),
+ jnp.cos(position[:, jnp.newaxis] * div_term),
+ ],
+ axis=-1,
+ )
+
+
+class TokenizerEncoderDecoder(nn.Module):
+ num_tokens: int
+ num_cross_tokens: int
+ num_layers: int
+ causal: bool
+
+ mlp_ratio: float = 4.0
+ use_state_conditioning: bool = False
+
+ @nn.compact
+ def __call__(
+ self,
+ y: jnp.ndarray,
+ *,
+ train: bool = True,
+ state_conditioning: jnp.ndarray | None = None,
+ mask: jnp.ndarray | None = None,
+ ) -> jnp.ndarray:
+ x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
+ x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
+
+ if mask is not None:
+ # mask is (batch_dims..., num_cross_tokens)
+ chex.assert_equal_shape([y[..., 0], mask])
+ attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
+ else:
+ attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))
+
+ if self.use_state_conditioning:
+ assert state_conditioning is not None, "State conditioning is required for this model."
+ state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
+ y = jnp.concatenate([y, state_embed], axis=-2)
+ attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)
+
+ y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
+
+ for _ in range(self.num_layers):
+ x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
+ x, y, train=train, mask_self=None, mask_cross=attn_mask
+ )
+
+ return x
+
+
+class FsqAttentionTokenizer(nn.Module):
+ embed_dim: int
+ data_dim: int
+ data_horizon: int
+ num_tokens: int
+ num_layers: int
+ target_codebook_size: int
+ causal: bool = False
+ mlp_ratio: float = 2.0
+
+ bound: float | None = None
+
+ use_state_conditioning: bool = False
+
+ @property
+ def vocab_size(self) -> int:
+ return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size)) # noqa: SLF001
+
+ def setup(self):
+ self.proj = nn.Dense(self.embed_dim)
+ self.encoder = TokenizerEncoderDecoder(
+ num_tokens=self.num_tokens,
+ num_cross_tokens=self.data_horizon,
+ num_layers=self.num_layers,
+ causal=self.causal,
+ use_state_conditioning=self.use_state_conditioning,
+ mlp_ratio=self.mlp_ratio,
+ )
+ self.codebook = FsqCodebook(
+ input_dim=self.embed_dim,
+ target_codebook_size=self.target_codebook_size,
+ codebook_type="custom",
+ )
+ self.decoder = TokenizerEncoderDecoder(
+ num_tokens=self.data_horizon,
+ num_cross_tokens=self.num_tokens,
+ num_layers=self.num_layers,
+ causal=self.causal,
+ use_state_conditioning=self.use_state_conditioning,
+ mlp_ratio=self.mlp_ratio,
+ )
+
+ self.proj_mean = nn.Dense(self.data_dim)
+ self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
+
+ def tokenize(
+ self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+ if self.bound is not None:
+ action = jnp.clip(action, -self.bound, self.bound)
+
+ x = self.proj(action)
+ x = self.encoder(x, train=train, state_conditioning=obs)
+
+ return self.codebook.encode(x)
+
+ def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
+ x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
+ mean = self.proj_mean(x)
+ return mean * self.out_scale
+
+ def loss(
+ self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
+ ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
+ # Encode
+ x = self.proj(action)
+ z = self.encoder(x, train=train, state_conditioning=obs)
+
+ # Quantize
+ tokens, z = self.codebook(z)
+
+ # Decode
+ x = self.decoder(z, train=train, state_conditioning=obs)
+ mean = self.proj_mean(x) * self.out_scale
+
+ mse = jnp.mean(jnp.square(action - mean))
+ mae = jnp.mean(jnp.abs(action - mean))
+
+ return mse, {
+ "mse": mse,
+ "mae": mae,
+ }
+
+ def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
+ """
+ Dummy for .init
+ """
+ return self.loss(*args, **kwargs)
diff --git a/openpi/src/openpi/models/vit.py b/openpi/src/openpi/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7901d097b86d36e3564d2b664ed50af8a194080
--- /dev/null
+++ b/openpi/src/openpi/models/vit.py
@@ -0,0 +1,307 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
+
+from collections.abc import Callable
+from typing import Any
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+from openpi.models import resnet as models_resnet
+
+Array = Any
+PRNGKey = Any
+Shape = tuple[int]
+Dtype = Any
+
+
+class IdentityLayer(nn.Module):
+ """Identity layer, convenient for giving a name to an array."""
+
+ @nn.compact
+ def __call__(self, x):
+ return x
+
+
+class AddPositionEmbs(nn.Module):
+ """Adds learned positional embeddings to the inputs.
+
+ Attributes:
+ posemb_init: positional embedding initializer.
+ """
+
+ posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
+ param_dtype: Dtype = jnp.float32
+
+ @nn.compact
+ def __call__(self, inputs):
+ """Applies the AddPositionEmbs module.
+
+ Args:
+ inputs: Inputs to the layer.
+
+ Returns:
+ Output tensor with shape `(bs, timesteps, in_dim)`.
+ """
+ # inputs.shape is (batch_size, seq_len, emb_dim).
+ assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
+ pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
+ pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
+ return inputs + pe
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block."""
+
+ mlp_dim: int
+ dtype: Dtype = jnp.float32
+ param_dtype: Dtype = jnp.float32
+ out_dim: int | None = None
+ dropout_rate: float = 0.1
+ kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
+ bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)
+
+ @nn.compact
+ def __call__(self, inputs, *, deterministic):
+ """Applies Transformer MlpBlock module."""
+ actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
+ x = nn.Dense(
+ features=self.mlp_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ kernel_init=self.kernel_init,
+ bias_init=self.bias_init,
+ )( # pytype: disable=wrong-arg-types
+ inputs
+ )
+ x = nn.gelu(x)
+ x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
+ output = nn.Dense(
+ features=actual_out_dim,
+ dtype=self.dtype,
+ param_dtype=self.param_dtype,
+ kernel_init=self.kernel_init,
+ bias_init=self.bias_init,
+ )( # pytype: disable=wrong-arg-types
+ x
+ )
+ return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
+
+
+class Encoder1DBlock(nn.Module):
+ """Transformer encoder layer.
+
+ Attributes:
+ inputs: input data.
+ mlp_dim: dimension of the mlp on top of attention block.
+ dtype: the dtype of the computation (default: float32).
+ dropout_rate: dropout rate.
+ attention_dropout_rate: dropout for attention heads.
+ deterministic: bool, deterministic or not (to apply dropout).
+ num_heads: Number of heads in nn.MultiHeadDotProductAttention
+ """
+
+ mlp_dim: int
+ num_heads: int
+ dtype: Dtype = jnp.float32
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+
+ @nn.compact
+ def __call__(self, inputs, deterministic):
+ """Applies Encoder1DBlock module.
+
+ Args:
+ inputs: Inputs to the layer.
+ deterministic: Dropout will not be applied when set to true.
+
+ Returns:
+ output after transformer encoder block.
+ """
+
+ # Attention block.
+ assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
+ x = nn.LayerNorm(dtype=self.dtype)(inputs)
+ x = nn.MultiHeadDotProductAttention(
+ dtype=self.dtype,
+ kernel_init=nn.initializers.xavier_uniform(),
+ broadcast_dropout=False,
+ deterministic=deterministic,
+ dropout_rate=self.attention_dropout_rate,
+ num_heads=self.num_heads,
+ # why isn't this true by default???
+ force_fp32_for_softmax=True,
+ )(x, x)
+ x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
+ x = x + inputs
+
+ # MLP block.
+ y = nn.LayerNorm(dtype=self.dtype)(x)
+ y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
+ y, deterministic=deterministic
+ )
+
+ return x + y, None
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation.
+
+ Attributes:
+ num_layers: number of layers
+ mlp_dim: dimension of the mlp on top of attention block
+ num_heads: Number of heads in nn.MultiHeadDotProductAttention
+ dropout_rate: dropout rate.
+ attention_dropout_rate: dropout rate in self attention.
+ """
+
+ dtype: jax.typing.DTypeLike
+ num_layers: int
+ mlp_dim: int
+ num_heads: int
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ add_position_embedding: bool = True
+
+ @nn.compact
+ def __call__(self, x, *, train):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ x: Inputs to the layer.
+ train: Set to `True` when training.
+
+ Returns:
+ output of a transformer encoder.
+ """
+ assert x.ndim == 3 # (batch, len, emb)
+
+ if self.add_position_embedding:
+ x = AddPositionEmbs(
+ posemb_init=nn.initializers.normal(stddev=0.02), # from BERT.
+ name="posembed_input",
+ )(x)
+ x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
+
+ x = x.astype(self.dtype)
+ # Input Encoder
+ block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
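+        # nn.remat recomputes the block's activations during the backward pass to
+        # save memory, and nn.scan below stacks the per-layer parameters along a
+        # leading axis and applies the block num_layers times.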
+ x, _ = nn.scan(
+ block,
+ variable_axes={"params": 0},
+ split_rngs={"params": True, "dropout": True},
+ in_axes=nn.broadcast,
+ length=self.num_layers,
+ )(
+ name="encoderblock",
+ mlp_dim=self.mlp_dim,
+ dropout_rate=self.dropout_rate,
+ attention_dropout_rate=self.attention_dropout_rate,
+ dtype=self.dtype,
+ num_heads=self.num_heads,
+ )(x, not train)
+ return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
+
+
+class VisionTransformer(nn.Module):
+ """VisionTransformer."""
+
+ dtype: jax.typing.DTypeLike
+ num_classes: int
+ patches: Any
+ transformer: Any
+ hidden_size: int
+ resnet: Any | None = None
+ representation_size: int | None = None
+ classifier: str = "token"
+ head_bias_init: float = 0.0
+ encoder: type[nn.Module] = Encoder
+ model_name: str | None = None
+
+ @nn.compact
+ def __call__(self, inputs, *, train):
+ x = inputs
+ # (Possibly partial) ResNet root.
+ if self.resnet is not None:
+ width = int(64 * self.resnet.width_factor)
+
+ # Root block.
+ x = models_resnet.StdConv(
+ features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
+ )(x)
+ x = nn.GroupNorm(name="gn_root")(x)
+ x = nn.relu(x)
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
+
+ # ResNet stages.
+ if self.resnet.num_layers:
+ x = models_resnet.ResNetStage(
+ block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
+ )(x)
+ for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
+ x = models_resnet.ResNetStage(
+ block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
+ )(x)
+
+ n, h, w, c = x.shape
+
+ # We can merge s2d+emb into a single conv; it's the same.
+ x = nn.Conv(
+ features=self.hidden_size,
+ kernel_size=self.patches.size,
+ strides=self.patches.size,
+ padding="VALID",
+ name="embedding",
+ )(x)
+
+ # Here, x is a grid of embeddings.
+
+ # (Possibly partial) Transformer.
+ if self.transformer is not None:
+ n, h, w, c = x.shape
+ x = jnp.reshape(x, [n, h * w, c])
+
+ # If we want to add a class token, add it here.
+ if self.classifier in ["token", "token_unpooled"]:
+ cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
+ cls = jnp.tile(cls, [n, 1, 1])
+ x = jnp.concatenate([cls, x], axis=1)
+
+ x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)
+
+ if self.classifier == "token":
+ x = x[:, 0]
+ elif self.classifier == "gap":
+ x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
+ elif self.classifier in ["unpooled", "token_unpooled"]:
+ pass
+ else:
+ raise ValueError(f"Invalid classifier={self.classifier}")
+
+ if self.representation_size is not None:
+ x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
+ x = nn.tanh(x)
+ else:
+ x = IdentityLayer(name="pre_logits")(x)
+
+ if self.num_classes:
+ x = nn.Dense(
+ features=self.num_classes,
+ name="head",
+ kernel_init=nn.initializers.zeros,
+ bias_init=nn.initializers.constant(self.head_bias_init),
+ )(x)
+ return x
diff --git a/openpi/src/openpi/models_pytorch/gemma_pytorch.py b/openpi/src/openpi/models_pytorch/gemma_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..203b36be8ae4c525422c93aa50926c03ded7deb5
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/gemma_pytorch.py
@@ -0,0 +1,281 @@
+from typing import Literal
+
+import torch
+from torch import nn
+from transformers import Cache
+from transformers import GemmaForCausalLM
+from transformers import PaliGemmaForConditionalGeneration
+from transformers.models.auto import CONFIG_MAPPING
+from transformers.models.gemma import modeling_gemma
+
+
+class PaliGemmaWithExpertModel(nn.Module):
+ def __init__(
+ self,
+ vlm_config,
+ action_expert_config,
+ use_adarms=None,
+ precision: Literal["bfloat16", "float32"] = "bfloat16",
+ ):
+ if use_adarms is None:
+ use_adarms = [False, False]
+ super().__init__()
+
+ vlm_config_hf = CONFIG_MAPPING["paligemma"]()
+ vlm_config_hf._vocab_size = 257152 # noqa: SLF001
+ vlm_config_hf.image_token_index = 257152
+ vlm_config_hf.text_config.hidden_size = vlm_config.width
+ vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
+ vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
+ vlm_config_hf.text_config.head_dim = vlm_config.head_dim
+ vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
+ vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
+ vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
+ vlm_config_hf.text_config.torch_dtype = "float32"
+ vlm_config_hf.text_config.vocab_size = 257152
+ vlm_config_hf.text_config.use_adarms = use_adarms[0]
+ vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
+ vlm_config_hf.vision_config.intermediate_size = 4304
+ vlm_config_hf.vision_config.projection_dim = 2048
+ vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
+ vlm_config_hf.vision_config.torch_dtype = "float32"
+
+ action_expert_config_hf = CONFIG_MAPPING["gemma"](
+ head_dim=action_expert_config.head_dim,
+ hidden_size=action_expert_config.width,
+ intermediate_size=action_expert_config.mlp_dim,
+ num_attention_heads=action_expert_config.num_heads,
+ num_hidden_layers=action_expert_config.depth,
+ num_key_value_heads=action_expert_config.num_kv_heads,
+ vocab_size=257152,
+ hidden_activation="gelu_pytorch_tanh",
+ torch_dtype="float32",
+ use_adarms=use_adarms[1],
+ adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
+ )
+
+ self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
+ self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
+ self.gemma_expert.model.embed_tokens = None
+
+ self.to_bfloat16_for_selected_params(precision)
+
+ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
+ if precision == "bfloat16":
+ self.to(dtype=torch.bfloat16)
+ elif precision == "float32":
+ self.to(dtype=torch.float32)
+ return
+ else:
+ raise ValueError(f"Invalid precision: {precision}")
+
+ params_to_keep_float32 = [
+ "vision_tower.vision_model.embeddings.patch_embedding.weight",
+ "vision_tower.vision_model.embeddings.patch_embedding.bias",
+ "vision_tower.vision_model.embeddings.position_embedding.weight",
+ "input_layernorm",
+ "post_attention_layernorm",
+ "model.norm",
+ ]
+
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_keep_float32):
+ param.data = param.data.to(dtype=torch.float32)
+
+ def embed_image(self, image: torch.Tensor):
+ return self.paligemma.model.get_image_features(image)
+
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.paligemma.language_model.embed_tokens(tokens)
+
+ def forward(
+ self,
+ attention_mask: torch.Tensor | None = None,
+ position_ids: torch.LongTensor | None = None,
+        past_key_values: list[torch.FloatTensor] | Cache | None = None,
+ inputs_embeds: list[torch.FloatTensor] | None = None,
+ use_cache: bool | None = None,
+ adarms_cond: list[torch.Tensor] | None = None,
+ ):
+ if adarms_cond is None:
+ adarms_cond = [None, None]
+ if inputs_embeds[1] is None:
+ prefix_output = self.paligemma.language_model.forward(
+ inputs_embeds=inputs_embeds[0],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
+ )
+ prefix_past_key_values = prefix_output.past_key_values
+ prefix_output = prefix_output.last_hidden_state
+ suffix_output = None
+ elif inputs_embeds[0] is None:
+ suffix_output = self.gemma_expert.model.forward(
+ inputs_embeds=inputs_embeds[1],
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
+ )
+ suffix_output = suffix_output.last_hidden_state
+ prefix_output = None
+ prefix_past_key_values = None
+ else:
+ models = [self.paligemma.language_model, self.gemma_expert.model]
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
+
+ # Check if gradient checkpointing is enabled for any of the models
+ use_gradient_checkpointing = (
+ hasattr(self.gemma_expert.model, "gradient_checkpointing")
+ and self.gemma_expert.model.gradient_checkpointing
+ and self.training
+ ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
+
+ # Force enable gradient checkpointing if we're in training mode and the model supports it
+ if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
+ if not self.gemma_expert.model.gradient_checkpointing:
+ print("Forcing gradient checkpointing to be enabled for Gemma expert model")
+ self.gemma_expert.model.gradient_checkpointing = True
+ use_gradient_checkpointing = True
+
+ # Debug gradient checkpointing status
+ if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
+ print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
+ print(f"Model training mode: {self.training}")
+ print(
+ f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
+ )
+ if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
+ print(
+ f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
+ )
+ self._debug_gc_printed = True
+
+ # Define the complete layer computation function for gradient checkpointing
+ def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
+ models = [self.paligemma.language_model, self.gemma_expert.model]
+
+ query_states = []
+ key_states = []
+ value_states = []
+ gates = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
+ gates.append(gate)
+
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
+ query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ query_states.append(query_state)
+ key_states.append(key_state)
+ value_states.append(value_state)
+
+ # Concatenate and process attention
+ query_states = torch.cat(query_states, dim=2)
+ key_states = torch.cat(key_states, dim=2)
+ value_states = torch.cat(value_states, dim=2)
+
+ dummy_tensor = torch.zeros(
+ query_states.shape[0],
+ query_states.shape[2],
+ query_states.shape[-1],
+ device=query_states.device,
+ dtype=query_states.dtype,
+ )
+ cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
+ query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, unsqueeze_dim=1
+ )
+
+ batch_size = query_states.shape[0]
+ scaling = self.paligemma.language_model.layers[layer_idx].self_attn.scaling
+
+ # Attention computation
+ att_output, _ = modeling_gemma.eager_attention_forward(
+ self.paligemma.language_model.layers[layer_idx].self_attn,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ scaling,
+ )
+ # Get head_dim from the current layer, not from the model
+ head_dim = self.paligemma.language_model.layers[layer_idx].self_attn.head_dim
+ att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
+
+ # Process layer outputs
+ outputs_embeds = []
+ start_pos = 0
+ for i, hidden_states in enumerate(inputs_embeds):
+ layer = models[i].layers[layer_idx]
+ end_pos = start_pos + hidden_states.shape[1]
+
+ if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+ out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
+
+ # first residual
+ out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
+ after_first_residual = out_emb.clone()
+ out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
+ # Convert to bfloat16 if the next layer (mlp) uses bfloat16
+ if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
+ out_emb = out_emb.to(dtype=torch.bfloat16)
+
+ out_emb = layer.mlp(out_emb)
+ # second residual
+ out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
+ outputs_embeds.append(out_emb)
+ start_pos = end_pos
+
+ return outputs_embeds
+
+ # Process all layers with gradient checkpointing if enabled
+ for layer_idx in range(num_layers):
+ if use_gradient_checkpointing:
+ inputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_layer_complete,
+ layer_idx,
+ inputs_embeds,
+ attention_mask,
+ position_ids,
+ adarms_cond,
+ use_reentrant=False,
+ preserve_rng_state=False,
+ )
+ else:
+ inputs_embeds = compute_layer_complete(
+ layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
+ )
+
+ # Old code removed - now using compute_layer_complete function above
+
+ # final norm
+ # Define final norm computation function for gradient checkpointing
+ def compute_final_norms(inputs_embeds, adarms_cond):
+ outputs_embeds = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
+ outputs_embeds.append(out_emb)
+ return outputs_embeds
+
+ # Apply gradient checkpointing to final norm if enabled
+ if use_gradient_checkpointing:
+ outputs_embeds = torch.utils.checkpoint.checkpoint(
+ compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False
+ )
+ else:
+ outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
+
+ prefix_output = outputs_embeds[0]
+ suffix_output = outputs_embeds[1]
+ prefix_past_key_values = None
+
+ return [prefix_output, suffix_output], prefix_past_key_values
diff --git a/openpi/src/openpi/models_pytorch/pi0_pytorch.py b/openpi/src/openpi/models_pytorch/pi0_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..25f0580ba818d55aadeab444a5293cac558f6804
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/pi0_pytorch.py
@@ -0,0 +1,461 @@
+import logging
+import math
+
+import torch
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F # noqa: N812
+
+import openpi.models.gemma as _gemma
+from openpi.models_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
+import openpi.models_pytorch.preprocessing_pytorch as _preprocessing
+
+
+def get_safe_dtype(target_dtype, device_type):
+ """Get a safe dtype for the given device type."""
+ if device_type == "cpu":
+ # CPU doesn't support bfloat16, use float32 instead
+ if target_dtype == torch.bfloat16:
+ return torch.float32
+ if target_dtype == torch.float64:
+ return torch.float64
+ return target_dtype
+
+
+def create_sinusoidal_pos_embedding(
+    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
+) -> Tensor:
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
+ if dimension % 2 != 0:
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
+
+ if time.ndim != 1:
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
+
+    dtype = get_safe_dtype(torch.float64, torch.device(device).type)
+ fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
+ period = min_period * (max_period / min_period) ** fraction
+
+ # Compute the outer product
+ scaling_factor = 1.0 / period * 2 * math.pi
+ sin_input = scaling_factor[None, :] * time[:, None]
+ return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+
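+# Shape sketch for create_sinusoidal_pos_embedding (illustrative only; the argument
+# values below are made up, not the ones used by the model):
+#   emb = create_sinusoidal_pos_embedding(
+#       torch.tensor([0.25, 0.5]), dimension=8, min_period=4e-3, max_period=4.0,
+#       device=torch.device("cpu"),
+#   )
+# returns a (2, 8) tensor: the first 4 columns are sin(2*pi*t/period) and the last 4
+# the matching cosines, with periods geometrically spaced between min_period and max_period.
+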
+
+def sample_beta(alpha, beta, bsize, device):
+ alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
+ beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
+ dist = torch.distributions.Beta(alpha_t, beta_t)
+ return dist.sample((bsize,))
+
+
+def make_att_2d_masks(pad_masks, att_masks):
+ """Copied from big_vision.
+
+    Tokens can attend to valid input tokens which have a cumulative att_masks
+    value smaller than or equal to theirs. This way `att_masks` (int[B, N]) can be
+    used to set up several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+      pad_masks: bool[B, N] true if it is part of the input, false if padding.
+      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
+        it and 0 where it shares the same attention mask as the previous token.
+ """
+ if att_masks.ndim != 2:
+ raise ValueError(att_masks.ndim)
+ if pad_masks.ndim != 2:
+ raise ValueError(pad_masks.ndim)
+
+ cumsum = torch.cumsum(att_masks, dim=1)
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
+ return att_2d_masks & pad_2d_masks
+
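+# A minimal prefix-LM example for make_att_2d_masks (illustrative only):
+#   pad = torch.ones(1, 4, dtype=torch.bool)
+#   ar = torch.tensor([[0, 0, 1, 1]])
+#   make_att_2d_masks(pad, ar)[0]
+# gives a mask where the first two tokens attend to each other bidirectionally and
+# the last two attend causally to every earlier token, matching the docstring above.
+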
+
+class PI0Pytorch(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pi05 = config.pi05
+
+ paligemma_config = _gemma.get_config(config.paligemma_variant)
+ action_expert_config = _gemma.get_config(config.action_expert_variant)
+
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
+ paligemma_config,
+ action_expert_config,
+ use_adarms=[False, True] if self.pi05 else [False, False],
+ precision=config.dtype,
+ )
+
+ self.action_in_proj = nn.Linear(32, action_expert_config.width)
+ self.action_out_proj = nn.Linear(action_expert_config.width, 32)
+
+ if self.pi05:
+ self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
+ self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
+ else:
+ self.state_proj = nn.Linear(32, action_expert_config.width)
+ self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
+ self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
+
+ torch.set_float32_matmul_precision("high")
+ self.sample_actions = torch.compile(self.sample_actions, mode="max-autotune")
+
+ # Initialize gradient checkpointing flag
+ self.gradient_checkpointing_enabled = False
+
+ msg = "transformers_replace is not installed correctly. Please install it with `uv pip install transformers==4.53.2` and `cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/`."
+ try:
+ from transformers.models.siglip import check
+
+ if not check.check_whether_transformers_replace_is_installed_correctly():
+ raise ValueError(msg)
+ except ImportError:
+ raise ValueError(msg) from None
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.gradient_checkpointing_enabled = True
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
+
+ logging.info("Enabled gradient checkpointing for PI0Pytorch model")
+
+ def gradient_checkpointing_disable(self):
+ """Disable gradient checkpointing."""
+ self.gradient_checkpointing_enabled = False
+ self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
+ self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
+ self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
+
+ logging.info("Disabled gradient checkpointing for PI0Pytorch model")
+
+ def is_gradient_checkpointing_enabled(self):
+ """Check if gradient checkpointing is enabled."""
+ return self.gradient_checkpointing_enabled
+
+ def _apply_checkpoint(self, func, *args, **kwargs):
+ """Helper method to apply gradient checkpointing if enabled."""
+ if self.gradient_checkpointing_enabled and self.training:
+ return torch.utils.checkpoint.checkpoint(
+ func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
+ )
+ return func(*args, **kwargs)
+
+ def _prepare_attention_masks_4d(self, att_2d_masks):
+ """Helper method to prepare 4D attention masks for transformer."""
+ att_2d_masks_4d = att_2d_masks[:, None, :, :]
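+        # Allowed positions get an additive bias of 0; disallowed positions get a very large
+        # negative value that acts as -inf after the attention softmax.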
+ return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
+
+ def _preprocess_observation(self, observation, *, train=True):
+ """Helper method to preprocess observation."""
+ observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
+ return (
+ list(observation.images.values()),
+ list(observation.image_masks.values()),
+ observation.tokenized_prompt,
+ observation.tokenized_prompt_mask,
+ observation.state,
+ )
+
+ def sample_noise(self, shape, device):
+ return torch.normal(
+ mean=0.0,
+ std=1.0,
+ size=shape,
+ dtype=torch.float32,
+ device=device,
+ )
+
+ def sample_time(self, bsize, device):
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
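+        # Map Beta(1.5, 1.0) samples from [0, 1] into [0.001, 1.0] so that t is never exactly 0.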
+ time = time_beta * 0.999 + 0.001
+ return time.to(dtype=torch.float32, device=device)
+
+ def embed_prefix(
+ self, images, img_masks, lang_tokens, lang_masks
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
+ for PaliGemma transformer processing.
+ """
+ embs = []
+ pad_masks = []
+ att_masks = []
+
+ # Process images
+ for img, img_mask in zip(images, img_masks, strict=True):
+
+ def image_embed_func(img):
+ return self.paligemma_with_expert.embed_image(img)
+
+ img_emb = self._apply_checkpoint(image_embed_func, img)
+
+ bsize, num_img_embs = img_emb.shape[:2]
+
+ embs.append(img_emb)
+ pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
+
+ # Create attention masks so that image tokens attend to each other
+ att_masks += [0] * num_img_embs
+
+ # Process language tokens
+ def lang_embed_func(lang_tokens):
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
+ lang_emb_dim = lang_emb.shape[-1]
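+            # Scale token embeddings by sqrt(hidden_dim), matching Gemma's usual input normalization.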
+ return lang_emb * math.sqrt(lang_emb_dim)
+
+ lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
+
+ embs.append(lang_emb)
+ pad_masks.append(lang_masks)
+
+ # full attention between image and language inputs
+ num_lang_embs = lang_emb.shape[1]
+ att_masks += [0] * num_lang_embs
+
+ embs = torch.cat(embs, dim=1)
+ pad_masks = torch.cat(pad_masks, dim=1)
+ att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
+
+ # Get batch size from the first dimension of the concatenated tensors
+ bsize = pad_masks.shape[0]
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+
+ return embs, pad_masks, att_masks
+
+ def embed_suffix(self, state, noisy_actions, timestep):
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
+ embs = []
+ pad_masks = []
+ att_masks = []
+
+ if not self.pi05:
+ if self.state_proj.weight.dtype == torch.float32:
+ state = state.to(torch.float32)
+
+ # Embed state
+ def state_proj_func(state):
+ return self.state_proj(state)
+
+ state_emb = self._apply_checkpoint(state_proj_func, state)
+
+ embs.append(state_emb[:, None, :])
+ bsize = state_emb.shape[0]
+ device = state_emb.device
+
+ state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
+ pad_masks.append(state_mask)
+
+ # Set attention masks so that image and language inputs do not attend to state or actions
+ att_masks += [1]
+
+ # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ time_emb = create_sinusoidal_pos_embedding(
+ timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0, device=timestep.device
+ )
+ time_emb = time_emb.type(dtype=timestep.dtype)
+
+ # Fuse timestep + action information using an MLP
+ def action_proj_func(noisy_actions):
+ return self.action_in_proj(noisy_actions)
+
+ action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
+
+ if not self.pi05:
+ time_emb = time_emb[:, None, :].expand_as(action_emb)
+ action_time_emb = torch.cat([action_emb, time_emb], dim=2)
+
+ # Apply MLP layers
+ def mlp_func(action_time_emb):
+ x = self.action_time_mlp_in(action_time_emb)
+ x = F.silu(x) # swish == silu
+ return self.action_time_mlp_out(x)
+
+ action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
+ adarms_cond = None
+ else:
+ # time MLP (for adaRMS)
+ def time_mlp_func(time_emb):
+ x = self.time_mlp_in(time_emb)
+ x = F.silu(x) # swish == silu
+ x = self.time_mlp_out(x)
+ return F.silu(x)
+
+ time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
+ action_time_emb = action_emb
+ adarms_cond = time_emb
+
+ # Add to input tokens
+ embs.append(action_time_emb)
+
+ bsize, action_time_dim = action_time_emb.shape[:2]
+ action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
+ pad_masks.append(action_time_mask)
+
+ # Set attention masks so that image, language and state inputs do not attend to action tokens
+ att_masks += [1] + ([0] * (self.config.action_horizon - 1))
+
+ embs = torch.cat(embs, dim=1)
+ pad_masks = torch.cat(pad_masks, dim=1)
+ att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
+ att_masks = att_masks[None, :].expand(bsize, len(att_masks))
+
+ return embs, pad_masks, att_masks, adarms_cond
+
+ def forward(self, observation, actions, noise=None, time=None) -> Tensor:
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
+ images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=True)
+
+ if noise is None:
+ noise = self.sample_noise(actions.shape, actions.device)
+
+ if time is None:
+ time = self.sample_time(actions.shape[0], actions.device)
+
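+        # Flow-matching setup: x_t interpolates from clean actions (t=0) to noise (t=1),
+        # and the regression target is the constant velocity u_t = d(x_t)/dt = noise - actions.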
+ time_expanded = time[:, None, None]
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
+ if (
+ self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
+ == torch.bfloat16
+ ):
+ suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
+ prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
+
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
+
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
+
+ # Prepare attention masks
+ att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
+
+ # Apply gradient checkpointing if enabled
+ def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
+ attention_mask=att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+ return suffix_out
+
+ suffix_out = self._apply_checkpoint(
+ forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
+ )
+
+ suffix_out = suffix_out[:, -self.config.action_horizon :]
+ suffix_out = suffix_out.to(dtype=torch.float32)
+
+ # Apply gradient checkpointing to final action projection if enabled
+ def action_out_proj_func(suffix_out):
+ return self.action_out_proj(suffix_out)
+
+ v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
+
+ return F.mse_loss(u_t, v_t, reduction="none")
+
+ @torch.no_grad()
+ def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
+ bsize = observation.state.shape[0]
+ if noise is None:
+ actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
+ noise = self.sample_noise(actions_shape, device)
+
+ images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, lang_tokens, lang_masks)
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+ # Compute image and language key value cache
+ prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
+ self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
+
+ _, past_key_values = self.paligemma_with_expert.forward(
+ attention_mask=prefix_att_2d_masks_4d,
+ position_ids=prefix_position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, None],
+ use_cache=True,
+ )
+
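+        # Integrate the learned velocity field backwards from t=1 (pure noise) to t=0 (actions)
+        # using num_steps Euler steps of size |dt| = 1 / num_steps.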
+ dt = -1.0 / num_steps
+ dt = torch.tensor(dt, dtype=torch.float32, device=device)
+
+ x_t = noise
+ time = torch.tensor(1.0, dtype=torch.float32, device=device)
+ while time >= -dt / 2:
+ expanded_time = time.expand(bsize)
+ v_t = self.denoise_step(
+ state,
+ prefix_pad_masks,
+ past_key_values,
+ x_t,
+ expanded_time,
+ )
+
+ # Euler step - use new tensor assignment instead of in-place operation
+ x_t = x_t + dt * v_t
+ time += dt
+ return x_t
+
+ def denoise_step(
+ self,
+ state,
+ prefix_pad_masks,
+ past_key_values,
+ x_t,
+ timestep,
+ ):
+ """Apply one denoising step of the noise `x_t` at a given timestep."""
+ suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
+
+ suffix_len = suffix_pad_masks.shape[1]
+ batch_size = prefix_pad_masks.shape[0]
+ prefix_len = prefix_pad_masks.shape[1]
+
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
+
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
+
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
+
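+        # Suffix position ids continue after each example's (unpadded) prefix length, so RoPE
+        # positions stay consistent with the cached prefix.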
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+
+ # Prepare attention masks
+ full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
+ self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
+
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
+ attention_mask=full_att_2d_masks_4d,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=[None, suffix_embs],
+ use_cache=False,
+ adarms_cond=[None, adarms_cond],
+ )
+
+ suffix_out = outputs_embeds[1]
+ suffix_out = suffix_out[:, -self.config.action_horizon :]
+ suffix_out = suffix_out.to(dtype=torch.float32)
+ return self.action_out_proj(suffix_out)
diff --git a/openpi/src/openpi/models_pytorch/preprocessing_pytorch.py b/openpi/src/openpi/models_pytorch/preprocessing_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..33c94a59b18a7e45732a02200f0daef5b0f93018
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/preprocessing_pytorch.py
@@ -0,0 +1,173 @@
+from collections.abc import Sequence
+import logging
+
+import torch
+
+from openpi.shared import image_tools
+
+logger = logging.getLogger("openpi")
+
+# Constants moved from model.py
+IMAGE_KEYS = (
+ "base_0_rgb",
+ "left_wrist_0_rgb",
+ "right_wrist_0_rgb",
+)
+
+IMAGE_RESOLUTION = (224, 224)
+
+
+def preprocess_observation_pytorch(
+ observation,
+ *,
+ train: bool = False,
+ image_keys: Sequence[str] = IMAGE_KEYS,
+ image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
+):
+ """Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
+
+ This function avoids complex type annotations that can cause torch.compile issues.
+ """
+ if not set(image_keys).issubset(observation.images):
+ raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
+
+ batch_shape = observation.state.shape[:-1]
+
+ out_images = {}
+ for key in image_keys:
+ image = observation.images[key]
+
+        # TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats.
+        is_channels_first = image.shape[1] == 3  # channels-first if dimension 1 holds the 3 color channels
+
+ if is_channels_first:
+ # Convert [B, C, H, W] to [B, H, W, C] for processing
+ image = image.permute(0, 2, 3, 1)
+
+ if image.shape[1:3] != image_resolution:
+ logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
+ image = image_tools.resize_with_pad_torch(image, *image_resolution)
+
+ if train:
+ # Convert from [-1, 1] to [0, 1] for PyTorch augmentations
+ image = image / 2.0 + 0.5
+
+ # Apply PyTorch-based augmentations
+ if "wrist" not in key:
+ # Geometric augmentations for non-wrist cameras
+ height, width = image.shape[1:3]
+
+ # Random crop and resize
+ crop_height = int(height * 0.95)
+ crop_width = int(width * 0.95)
+
+ # Random crop
+ max_h = height - crop_height
+ max_w = width - crop_width
+ if max_h > 0 and max_w > 0:
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
+ start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
+ image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
+
+ # Resize back to original size
+ image = torch.nn.functional.interpolate(
+ image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
+ size=(height, width),
+ mode="bilinear",
+ align_corners=False,
+ ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
+
+ # Random rotation (small angles)
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
+ if torch.abs(angle) > 0.1: # Only rotate if angle is significant
+ # Convert to radians
+ angle_rad = angle * torch.pi / 180.0
+
+ # Create rotation matrix
+ cos_a = torch.cos(angle_rad)
+ sin_a = torch.sin(angle_rad)
+
+ # Apply rotation using grid_sample
+ grid_x = torch.linspace(-1, 1, width, device=image.device)
+ grid_y = torch.linspace(-1, 1, height, device=image.device)
+
+ # Create meshgrid
+ grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
+
+ # Expand to batch dimension
+ grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
+ grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
+
+ # Apply rotation transformation
+ grid_x_rot = grid_x * cos_a - grid_y * sin_a
+ grid_y_rot = grid_x * sin_a + grid_y * cos_a
+
+ # Stack and reshape for grid_sample
+ grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
+
+ image = torch.nn.functional.grid_sample(
+ image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
+ grid,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
+
+ # Color augmentations for all cameras
+ # Random brightness
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ brightness_factor = 0.7 + torch.rand(1, device=image.device) * 0.6 # Random factor between 0.7 and 1.3
+ image = image * brightness_factor
+
+ # Random contrast
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ contrast_factor = 0.6 + torch.rand(1, device=image.device) * 0.8 # Random factor between 0.6 and 1.4
+ mean = image.mean(dim=[1, 2, 3], keepdim=True)
+ image = (image - mean) * contrast_factor + mean
+
+            # Random saturation: instead of a full HSV round-trip, blend between the
+            # grayscale image and the original by a random factor
+ # Use tensor operations instead of .item() for torch.compile compatibility
+ saturation_factor = 0.5 + torch.rand(1, device=image.device) * 1.0 # Random factor between 0.5 and 1.5
+ gray = image.mean(dim=-1, keepdim=True)
+ image = gray + (image - gray) * saturation_factor
+
+ # Clamp values to [0, 1]
+ image = torch.clamp(image, 0, 1)
+
+ # Back to [-1, 1]
+ image = image * 2.0 - 1.0
+
+ # Convert back to [B, C, H, W] format if it was originally channels-first
+ if is_channels_first:
+ image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
+
+ out_images[key] = image
+
+ # obtain mask
+ out_masks = {}
+ for key in out_images:
+ if key not in observation.image_masks:
+ # do not mask by default
+ out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
+ else:
+ out_masks[key] = observation.image_masks[key]
+
+ # Create a simple object with the required attributes instead of using the complex Observation class
+ class SimpleProcessedObservation:
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ return SimpleProcessedObservation(
+ images=out_images,
+ image_masks=out_masks,
+ state=observation.state,
+ tokenized_prompt=observation.tokenized_prompt,
+ tokenized_prompt_mask=observation.tokenized_prompt_mask,
+ token_ar_mask=observation.token_ar_mask,
+ token_loss_mask=observation.token_loss_mask,
+ )
diff --git a/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py b/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..a42abe1341d4cf30574cc5f239369cb5e55bfdbb
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/configuration_gemma.py
@@ -0,0 +1,173 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+
+
+class GemmaConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Gemma-7B.
+ e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 256000):
+ Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`GemmaModel`]
+ hidden_size (`int`, *optional*, defaults to 3072):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 24576):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 28):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*, defaults to 16):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
+ `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 256):
+ The attention head dimension.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
+ The legacy activation function. It is overwritten by the `hidden_activation`.
+ hidden_activation (`str` or `function`, *optional*):
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
+ if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ eos_token_id (`int`, *optional*, defaults to 1):
+ End of stream token id.
+ bos_token_id (`int`, *optional*, defaults to 2):
+ Beginning of stream token id.
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+        use_adarms (`bool`, *optional*, defaults to `False`):
+            Whether to use adaptive RMSNorm (adaRMS) conditioning.
+        adarms_cond_dim (`int`, *optional*, defaults to `None`):
+            The dimension of the adaRMS conditioning vector; defaults to `hidden_size` when `use_adarms` is `True`.
+ ```python
+ >>> from transformers import GemmaModel, GemmaConfig
+ >>> # Initializing a Gemma gemma-7b style configuration
+ >>> configuration = GemmaConfig()
+ >>> # Initializing a model from the gemma-7b style configuration
+ >>> model = GemmaModel(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "gemma"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.mlp.gate_proj": "colwise",
+ "layers.*.mlp.up_proj": "colwise",
+ "layers.*.mlp.down_proj": "rowwise",
+ }
+ base_model_pp_plan = {
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+ "norm": (["hidden_states"], ["hidden_states"]),
+ }
+
+ def __init__(
+ self,
+ vocab_size=256000,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ head_dim=256,
+ hidden_act="gelu_pytorch_tanh",
+ hidden_activation=None,
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ bos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ use_adarms: bool = False,
+ adarms_cond_dim: Optional[int] = None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.head_dim = head_dim
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.hidden_activation = hidden_activation
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self.use_adarms = use_adarms
+ self.adarms_cond_dim = adarms_cond_dim
+
+ # Set default for adarms_cond_dim if use_adarms is True
+ if self.use_adarms and self.adarms_cond_dim is None:
+ self.adarms_cond_dim = self.hidden_size
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
+__all__ = ["GemmaConfig"]
\ No newline at end of file
diff --git a/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py b/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0529778bba56e0a45279aa44e29ef9a379026d9
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/transformers_replace/models/gemma/modeling_gemma.py
@@ -0,0 +1,862 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_gemma.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...masking_utils import create_causal_mask
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
+from .configuration_gemma import GemmaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class GemmaRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6, cond_dim: Optional[int] = None):
+ super().__init__()
+ self.eps = eps
+ self.dim = dim
+ self.cond_dim = cond_dim
+
+ # Dense layer for adaptive normalization (if cond_dim is provided)
+ if cond_dim is not None:
+            self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
+ # Initialize with zeros (matches source implementation)
+ nn.init.zeros_(self.dense.weight)
+ else:
+ self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16))
+ self.dense = None
+
+ def _norm(self, x):
+ # Compute variance in float32 (like the source implementation)
+ var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
+ # Compute normalization in float32
+ normed_inputs = x * torch.rsqrt(var + self.eps)
+ return normed_inputs
+
+ def forward(self, x, cond=None):
+ dtype = x.dtype # original dtype, could be half-precision
+ normed_inputs = self._norm(x)
+
+ if cond is None or self.dense is None:
+ # regular RMSNorm
+ # scale by learned parameter in float32 (matches source implementation)
+ normed_inputs = normed_inputs * (1.0 + self.weight.float())
+ return normed_inputs.to(dtype), None # return in original dtype with None gate
+
+ # adaptive RMSNorm (if cond is provided and dense layer exists)
+ if cond.shape[-1] != self.cond_dim:
+ raise ValueError(f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}")
+
+        modulation = self.dense(cond)
+ # Reshape modulation to broadcast properly: [batch, 1, features] for [batch, seq, features]
+ if len(x.shape) == 3: # [batch, seq, features]
+ modulation = modulation.unsqueeze(1)
+
+ scale, shift, gate = torch.chunk(modulation, 3, dim=-1)
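+        # adaLN-style modulation: scale and shift adjust the normalized activations here,
+        # while gate is returned to the caller for use in the gated residual connection.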
+
+        # Apply adaptive normalization in float32; the result is cast back to the input dtype on return.
+
+ normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to(torch.float32)
+
+ return normed_inputs.to(dtype), gate.to(dtype)
+
+ def extra_repr(self):
+ repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}"
+ if self.dense is not None:
+ repr_str += f", adaptive=True, cond_dim={self.cond_dim}"
+ return repr_str
+
+
+class GemmaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class GemmaRotaryEmbedding(nn.Module):
+ def __init__(self, config: GemmaConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def _gated_residual(x, y, gate):
+ """
+ Applies gated residual connection with optional gate parameter.
+
+ Args:
+ x: Input tensor (residual)
+ y: Output tensor to be added
+ gate: Optional gate tensor to modulate the addition
+
+ Returns:
+ x + y if gate is None, otherwise x + y * gate
+ """
+ if x is None and y is None:
+ return None
+ if x is None or y is None:
+ return x if x is not None else y
+ if gate is None:
+ return x + y
+ return x + y * gate
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GemmaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
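+        # pi05 conditions the action expert on time via adaptive RMSNorm (adaRMS); the base pi0
+        # variant instead projects the state and fuses the time embedding with the action
+        # embedding through a small MLP (see embed_suffix).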
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ use_cache: bool = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ # Use cache if provided
+ if past_key_value is not None:
+ if use_cache:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ else:
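+                # Read-only reuse of a precomputed cache: prepend the cached prefix keys/values
+                # without mutating the cache (used when attending over a frozen prefix).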
+ key_states = torch.cat([past_key_value[self.layer_idx][0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[self.layer_idx][1], value_states], dim=2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GemmaDecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = GemmaMLP(config)
+ cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None
+ self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
+ self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ adarms_cond: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+ hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond)
+
+ # Self Attention
+ hidden_states, self_attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+ hidden_states = _gated_residual(residual, hidden_states, gate)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = _gated_residual(residual, hidden_states, gate)
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class GemmaPreTrainedModel(PreTrainedModel):
+ config_class = GemmaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GemmaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_3 = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, GemmaRMSNorm):
+            # Gemma's RMSNorm scales by (1 + weight), so the identity initialization is zero
+            # (matching the zeros init in GemmaRMSNorm.__init__); the adaptive variant has no `weight`.
+            if hasattr(module, 'weight'):
+                module.weight.data.zero_()
+
+
+@auto_docstring
+class GemmaModel(GemmaPreTrainedModel):
+ def __init__(self, config: GemmaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+
+ cond_dim = getattr(config, 'adarms_cond_dim', None) if getattr(config, 'use_adarms', False) else None
+ self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
+ self.rotary_emb = GemmaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ adarms_cond: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ """
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+ Condition for ADARMS.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = create_causal_mask(
+ config=self.config,
+ input_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+ # Convert to bfloat16 if the first layer uses bfloat16
+ if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # Note: unlike upstream Gemma, the sqrt(hidden_size) input-embedding normalizer is not applied
+        # here (see https://github.com/huggingface/transformers/pull/29402 for the related precision
+        # issue); embeddings are expected to arrive pre-scaled by the caller, e.g. `embed_prefix` in
+        # the PI0 model scales language embeddings by sqrt(hidden_dim).
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ adarms_cond=adarms_cond,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states, _ = self.norm(hidden_states, adarms_cond)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring
+class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = GemmaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ adarms_cond: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+ Condition for ADARMS.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+ >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ adarms_cond=adarms_cond,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The Gemma Model transformer with a sequence classification head on top (linear layer).
+
+ [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class GemmaForSequenceClassification(GemmaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = GemmaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ adarms_cond: Optional[torch.Tensor] = None,
+ ) -> SequenceClassifierOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+ Condition for ADARMS.
+ """
+
+ transformer_outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ adarms_cond=adarms_cond,
+ )
+ hidden_states = transformer_outputs.last_hidden_state
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@auto_docstring
+class GemmaForTokenClassification(GemmaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = GemmaModel(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ adarms_cond: Optional[torch.Tensor] = None,
+ ) -> TokenClassifierOutput:
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
+
+ adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+            Conditioning tensor for the adaptive RMSNorm (adaRMS) layers.
+ """
+
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ adarms_cond=adarms_cond,
+ )
+ sequence_output = outputs.last_hidden_state
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.config)
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "GemmaModel",
+ "GemmaForCausalLM",
+ "GemmaForSequenceClassification",
+ "GemmaForTokenClassification",
+ "GemmaPreTrainedModel",
+]
diff --git a/openpi/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py b/openpi/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf0e948909044f1f97ce053cce4ff0fee0378e8
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/transformers_replace/models/paligemma/modeling_paligemma.py
@@ -0,0 +1,622 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PaliGemmamodel."""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...cache_utils import Cache, HybridCache, StaticCache
+from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import LossKwargs, ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+from ..auto import AutoModel
+from .configuration_paligemma import PaliGemmaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for Paligemma outputs, with hidden states and attentions.
+ """
+)
+class PaligemmaModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+        Image hidden states of the model produced by the vision encoder after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for PaliGemma causal language model (or autoregressive) outputs.
+ """
+)
+class PaliGemmaCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+        Image hidden states of the model produced by the vision encoder after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class PaliGemmaMultiModalProjector(nn.Module):
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__()
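+        # Single linear projection from the vision hidden size to the multimodal projection dim (the language model embedding size)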
+ self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
+
+ def forward(self, image_features):
+ hidden_states = self.linear(image_features)
+
+ return hidden_states
+
+
+@auto_docstring
+class PaliGemmaPreTrainedModel(PreTrainedModel):
+ config_class = PaliGemmaConfig
+ base_model_prefix = ""
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PaliGemmaMultiModalProjector"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+        # important: this ported version of PaliGemma isn't meant for training from scratch - only
+ # inference and fine-tuning
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+
+
+@auto_docstring(
+ custom_intro="""
+    The base Paligemma model, which consists of a vision backbone and a language model without a language modeling head.
+ """
+)
+class PaliGemmaModel(PaliGemmaPreTrainedModel):
+ _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
+ # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
+ accepts_loss_kwargs = False
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
+ self.vocab_size = config.text_config.vocab_size
+
+ language_model = AutoModel.from_config(config=config.text_config)
+ self.language_model = language_model
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.get_input_embeddings with Llava->PaliGemma
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaModel.set_input_embeddings with Llava->PaliGemma
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def set_decoder(self, decoder):
+ self.language_model = decoder
+
+ def get_decoder(self):
+ return self.language_model
+
+ def _update_causal_mask(
+ self,
+ attention_mask,
+ token_type_ids=None,
+ past_key_values=None,
+ cache_position=None,
+ input_tensor=None,
+ is_training: Optional[bool] = None,
+ ):
+ if self.config.text_config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+ is_training = is_training if is_training is not None else self.training
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ min_dtype = torch.finfo(self.dtype).min
+ if input_tensor is None:
+ input_tensor = attention_mask
+
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ elif isinstance(past_key_values, HybridCache):
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else cache_position[0] + sequence_length + 1
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ return attention_mask
+
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+ )
+        # Apply the causal (upper-triangular) mask only during training; at inference the whole prefix is attendable. Training-specific prefix attention is handled below.
+ if sequence_length != 1:
+ if is_training:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ else:
+ causal_mask[:, :sequence_length] = 0.0
+
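+        # Unmask all key positions at or before each query's cache position; positions after it keep their existing masking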
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+
+ # First unmask prefix tokens during training
+ if is_training:
+ if token_type_ids is None:
+ raise ValueError("Token type ids must be provided during training")
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+ )
+
+ # Then apply padding mask (will mask pad tokens)
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ """
+        Obtains the last hidden states of the images from the vision tower and applies the multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ image_outputs = self.vision_tower(pixel_values)
+ selected_image_feature = image_outputs.last_hidden_state
+ image_features = self.multi_modal_projector(selected_image_feature)
+ return image_features
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[tuple, PaligemmaModelOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ is_training = token_type_ids is not None and labels is not None
+
+        # Replace the image token id with PAD if it is out of vocabulary, to avoid index errors
+ if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_id
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+ "tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
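+            # Scatter the projected image embeddings into the positions of the image placeholder tokens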
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return PaligemmaModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring(
+ custom_intro="""
+    The Paligemma model, consisting of a vision backbone and a language model with a language modeling head on top for conditional generation.
+ """
+)
+class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {
+ "^language_model.model": "model.language_model",
+ "^vision_tower": "model.vision_tower",
+ "^multi_modal_projector": "model.multi_modal_projector",
+ "^language_model.lm_head": "lm_head",
+ }
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.model = PaliGemmaModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def get_image_features(self, pixel_values):
+ return self.model.get_image_features(pixel_values)
+
+    # Make modules available through the conditional class for backward compatibility
+ @property
+ def language_model(self):
+ return self.model.language_model
+
+ @property
+ def vision_tower(self):
+ return self.model.vision_tower
+
+ @property
+ def multi_modal_projector(self):
+ return self.model.multi_modal_projector
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> Union[tuple, PaliGemmaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ labels=labels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return PaliGemmaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # position_ids in Paligemma are 1-indexed
+ if model_inputs.get("position_ids") is not None:
+ model_inputs["position_ids"] += 1
+        # In the cached decoding stage, pixel_values should be None because the input ids no longer contain the special image tokens.
+        # Otherwise pixel_values must be passed to the model. NOTE: use_cache=False always needs pixel_values.
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ is_training = token_type_ids is not None and labels is not None
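+        # At prefill (cache_position 0) with a HybridCache, precompute the full 4D causal mask here since the cache uses static shapes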
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+ causal_mask = self.model._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+ )
+ model_inputs["attention_mask"] = causal_mask
+
+ return model_inputs
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
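+            # Build a strictly upper-triangular mask (future tokens blocked), then unmask key positions at or before each query's cache position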
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+__all__ = ["PaliGemmaForConditionalGeneration", "PaliGemmaPreTrainedModel", "PaliGemmaModel"]
diff --git a/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py b/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bb3c96aa57aea42b20452d35eb5ecd9f50d4d59
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/check.py
@@ -0,0 +1,4 @@
+import transformers
+
+def check_whether_transformers_replace_is_installed_correctly():
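+    # The modeling files under transformers_replace target transformers 4.53.2; any other installed version is treated as an incorrect setup.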
+ return transformers.__version__ == "4.53.2"
\ No newline at end of file
diff --git a/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py b/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea8acdb31ab5f29de86f4452942137ecb788a1b
--- /dev/null
+++ b/openpi/src/openpi/models_pytorch/transformers_replace/models/siglip/modeling_siglip.py
@@ -0,0 +1,1237 @@
+# coding=utf-8
+# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Siglip model."""
+
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
+from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+
+
+def trunc_normal_tf_(
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+) -> torch.Tensor:
+ """Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \\leq \\text{mean} \\leq b`.
+
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+ and the result is subsequently scaled and shifted by the mean and std args.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ """
+ with torch.no_grad():
+ _trunc_normal_(tensor, 0, 1.0, a, b)
+ tensor.mul_(std).add_(mean)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ with torch.no_grad():
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ with torch.no_grad():
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+def default_flax_embed_init(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+ """
+)
+# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
+class SiglipVisionModelOutput(ModelOutput):
+ r"""
+    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+ """
+)
+# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
+class SiglipTextModelOutput(ModelOutput):
+ r"""
+    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when the model is initialized with `with_projection=True`):
+ The text embeddings obtained by applying the projection layer to the pooler_output.
+ """
+
+ text_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: Optional[torch.FloatTensor] = None
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+@auto_docstring
+# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
+class SiglipOutput(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+ Contrastive loss for image-text similarity.
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
+ similarity scores.
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
+ similarity scores.
+    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
+    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
+ text_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`SiglipTextModel`].
+ vision_model_output (`BaseModelOutputWithPooling`):
+ The output of the [`SiglipVisionModel`].
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits_per_image: Optional[torch.FloatTensor] = None
+ logits_per_text: Optional[torch.FloatTensor] = None
+ text_embeds: Optional[torch.FloatTensor] = None
+ image_embeds: Optional[torch.FloatTensor] = None
+ text_model_output: BaseModelOutputWithPooling = None
+ vision_model_output: BaseModelOutputWithPooling = None
+
+ def to_tuple(self) -> tuple[Any]:
+ return tuple(
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
+ for k in self.keys()
+ )
+
+
+class SiglipVisionEmbeddings(nn.Module):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+        This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+        images. It is also adapted to support torch.jit tracing and models without class embeddings.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embedding.weight.shape[0]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embedding(self.position_ids)
+
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
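+        # Bicubically resize the position-embedding grid to the new (new_height, new_width) patch grid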
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+ _, _, height, width = pixel_values.shape
+ target_dtype = self.patch_embedding.weight.dtype
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
+class SiglipTextEmbeddings(nn.Module):
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ embed_dim = config.hidden_size
+
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+ max_position_embedding = self.position_embedding.weight.shape[0]
+
+ if seq_length > max_position_embedding:
+ raise ValueError(
+ f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+ f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+ )
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.token_embedding(input_ids)
+
+ position_embeddings = self.position_embedding(position_ids)
+ embeddings = inputs_embeds + position_embeddings
+
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
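+    # Reference attention: softmax(Q @ K^T * scaling + mask) @ V; used when no fused attention backend is selected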
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class SiglipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and output_attentions:
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ queries,
+ keys,
+ values,
+ attention_mask,
+ is_causal=self.is_causal,
+ scaling=self.scale,
+ dropout=0.0 if not self.training else self.dropout,
+ )
+
+ attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
+class SiglipMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SiglipEncoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.self_attn = SiglipAttention(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+@auto_docstring
+class SiglipPreTrainedModel(PreTrainedModel):
+ config_class = SiglipConfig
+ base_model_prefix = "siglip"
+ supports_gradient_checkpointing = True
+
+ _no_split_modules = [
+ "SiglipTextEmbeddings",
+ "SiglipEncoderLayer",
+ "SiglipVisionEmbeddings",
+ "SiglipEncoderLayer",
+ "SiglipMultiheadAttentionPoolingHead",
+ ]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, SiglipVisionEmbeddings):
+ width = (
+ self.config.vision_config.hidden_size
+ if isinstance(self.config, SiglipConfig)
+ else self.config.hidden_size
+ )
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+ elif isinstance(module, nn.Embedding):
+ default_flax_embed_init(module.weight)
+ elif isinstance(module, SiglipAttention):
+ nn.init.xavier_uniform_(module.q_proj.weight)
+ nn.init.xavier_uniform_(module.k_proj.weight)
+ nn.init.xavier_uniform_(module.v_proj.weight)
+ nn.init.xavier_uniform_(module.out_proj.weight)
+ nn.init.zeros_(module.q_proj.bias)
+ nn.init.zeros_(module.k_proj.bias)
+ nn.init.zeros_(module.v_proj.bias)
+ nn.init.zeros_(module.out_proj.bias)
+ elif isinstance(module, SiglipMLP):
+ nn.init.xavier_uniform_(module.fc1.weight)
+ nn.init.xavier_uniform_(module.fc2.weight)
+ nn.init.normal_(module.fc1.bias, std=1e-6)
+ nn.init.normal_(module.fc2.bias, std=1e-6)
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
+ nn.init.xavier_uniform_(module.probe.data)
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+ nn.init.zeros_(module.attention.in_proj_bias.data)
+ elif isinstance(module, SiglipModel):
+ logit_scale_init = torch.log(torch.tensor(1.0))
+ module.logit_scale.data.fill_(logit_scale_init)
+ module.logit_bias.data.zero_()
+ elif isinstance(module, SiglipForImageClassification):
+ nn.init.normal_(
+ module.classifier.weight,
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
+ )
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
+class SiglipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SiglipEncoderLayer`].
+
+ Args:
+ config: SiglipConfig
+ """
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ # Ignore copy
+ @can_return_tuple
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutput:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions,
+ )
+
+
+class SiglipTextTransformer(nn.Module):
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+ self.embeddings = SiglipTextEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+ self.head = nn.Linear(embed_dim, config.projection_size)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutputWithPooling:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if input_ids is None:
+ raise ValueError("You have to specify input_ids")
+
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
+ # expand attention_mask
+ if attention_mask is not None and not self._use_flash_attention_2:
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
+ pooled_output = last_hidden_state[:, -1, :]
+ pooled_output = self.head(pooled_output)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The text model from SigLIP without any head or projection on top.
+ """
+)
+class SiglipTextModel(SiglipPreTrainedModel):
+ config_class = SiglipTextConfig
+
+ def __init__(self, config: SiglipTextConfig):
+ super().__init__(config)
+ self.text_model = SiglipTextTransformer(config)
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, value):
+ self.text_model.embeddings.token_embedding = value
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, SiglipTextModel
+
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+
+class SiglipVisionTransformer(nn.Module):
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SiglipVisionEmbeddings(config)
+ self.encoder = SiglipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+ if self.use_head:
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = False,
+ ) -> BaseModelOutputWithPooling:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+ # Convert to bfloat16 if the encoder uses bfloat16
+ if len(self.encoder.layers) > 0 and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.bfloat16)
+
+ encoder_outputs: BaseModelOutput = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ last_hidden_state = encoder_outputs.last_hidden_state
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooler_output = self.head(last_hidden_state) if self.use_head else None
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooler_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SiglipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SiglipMLP(config)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+@auto_docstring(
+ custom_intro="""
+ The vision model from SigLIP without any head or projection on top.
+ """
+)
+class SiglipVisionModel(SiglipPreTrainedModel):
+ config_class = SiglipVisionConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SiglipVisionConfig):
+ super().__init__(config)
+
+ self.vision_model = SiglipVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> BaseModelOutputWithPooling:
+ r"""
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, SiglipVisionModel
+
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled features
+ ```"""
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+
+@auto_docstring
+class SiglipModel(SiglipPreTrainedModel):
+ config_class = SiglipConfig
+
+ def __init__(self, config: SiglipConfig):
+ super().__init__(config)
+
+ if not isinstance(config.text_config, SiglipTextConfig):
+ raise TypeError(
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
+ f" {type(config.text_config)}."
+ )
+
+ if not isinstance(config.vision_config, SiglipVisionConfig):
+ raise TypeError(
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
+ f" {type(config.vision_config)}."
+ )
+
+ text_config = config.text_config
+ vision_config = config.vision_config
+
+ # First, initialize the text and vision models with proper attention implementation
+ text_model = SiglipTextModel._from_config(text_config)
+ vision_model = SiglipVisionModel._from_config(vision_config)
+
+ # Second, get the text and vision submodules (for backward compatibility)
+ self.text_model = text_model.text_model
+ self.vision_model = vision_model.vision_model
+
+ self.logit_scale = nn.Parameter(torch.randn(1))
+ self.logit_bias = nn.Parameter(torch.randn(1))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def get_text_features(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+        text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained from
+            the pooled output of [`SiglipTextModel`].
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
+ >>> with torch.no_grad():
+ ... text_features = model.get_text_features(**inputs)
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ pooled_output = text_outputs.pooler_output
+
+ return pooled_output
+
+ @auto_docstring
+ def get_image_features(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.FloatTensor:
+ r"""
+ Returns:
+        image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained from
+            the pooled output of [`SiglipVisionModel`].
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... image_features = model.get_image_features(**inputs)
+ ```"""
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ pooled_output = vision_outputs.pooler_output
+
+ return pooled_output
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ return_loss: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> SiglipOutput:
+ r"""
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+
+ Examples:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModel
+ >>> import torch
+
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+        >>> # important: we pass `padding="max_length"` since the model was trained with this
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> logits_per_image = outputs.logits_per_image
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
+ 31.9% that image 0 is 'a photo of 2 cats'
+ ```"""
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ vision_outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ text_outputs: BaseModelOutputWithPooling = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ image_embeds = vision_outputs.pooler_output
+ text_embeds = text_outputs.pooler_output
+
+ # normalized features
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
+
+ logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
+ logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
+
+ logits_per_image = logits_per_text.t()
+
+ loss = None
+ if return_loss:
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
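+            # Pairwise sigmoid loss: the label matrix is +1 on the diagonal (matched text-image
+            # pairs) and -1 everywhere else, so `m1_diag1 * logits_per_text` flips the sign of the
+            # logits for non-matching pairs before the log-sigmoid; the loss is the mean negative
+            # log-likelihood over all pairs.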
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+ nll = -torch.sum(loglik, dim=-1)
+ loss = nll.mean()
+
+ return SiglipOutput(
+ loss=loss,
+ logits_per_image=logits_per_image,
+ logits_per_text=logits_per_text,
+ text_embeds=text_embeds,
+ image_embeds=image_embeds,
+ text_model_output=text_outputs,
+ vision_model_output=vision_outputs,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
+ the patch tokens) e.g. for ImageNet.
+ """
+)
+class SiglipForImageClassification(SiglipPreTrainedModel):
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: SiglipConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+
+ # Create the vision model with proper attention
+ # and take only vision_model submodule (for backward compatibility)
+ vision_model = SiglipVisionModel._from_config(config.vision_config)
+ self.vision_model = vision_model.vision_model
+
+ # Classifier head
+ self.classifier = (
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> ImageClassifierOutput:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, SiglipForImageClassification
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> # note: we are loading a `SiglipModel` from the hub here,
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
+ >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
+
+ >>> inputs = image_processor(images=image, return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> # model predicts one of the two classes
+ >>> predicted_class_idx = logits.argmax(-1).item()
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
+ Predicted class: LABEL_1
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ outputs: BaseModelOutputWithPooling = self.vision_model(
+ pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ # average pool the patch tokens
+ sequence_output = torch.mean(sequence_output, dim=1)
+ # apply classifier
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "SiglipModel",
+ "SiglipPreTrainedModel",
+ "SiglipTextModel",
+ "SiglipVisionModel",
+ "SiglipForImageClassification",
+]
\ No newline at end of file
diff --git a/openpi/src/openpi/policies/__pycache__/policy_config.cpython-311.pyc b/openpi/src/openpi/policies/__pycache__/policy_config.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe415f376218c0a9a0ceab56759780682b1af2e2
Binary files /dev/null and b/openpi/src/openpi/policies/__pycache__/policy_config.cpython-311.pyc differ
diff --git a/openpi/src/openpi/policies/aloha_policy.py b/openpi/src/openpi/policies/aloha_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16be3334584692ba505553220ab20ff7b2719c5
--- /dev/null
+++ b/openpi/src/openpi/policies/aloha_policy.py
@@ -0,0 +1,202 @@
+import dataclasses
+from typing import ClassVar
+
+import einops
+import numpy as np
+
+from openpi import transforms
+
+
+def make_aloha_example() -> dict:
+ """Creates a random input example for the Aloha policy."""
+ return {
+ "state": np.ones((14,)),
+ "images": {
+ "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
+ },
+ "prompt": "do something",
+ }
+
+
+@dataclasses.dataclass(frozen=True)
+class AlohaInputs(transforms.DataTransformFn):
+ """Inputs for the Aloha policy.
+
+ Expected inputs:
+ - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
+ - state: [14]
+ - actions: [action_horizon, 14]
+ """
+
+ # If true, this will convert the joint and gripper values from the standard Aloha space to
+ # the space used by the pi internal runtime which was used to train the base model.
+ adapt_to_pi: bool = True
+
+    # The expected camera names. All input cameras must be in this set. Missing cameras will be
+ # replaced with black images and the corresponding `image_mask` will be set to False.
+ EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
+
+ def __call__(self, data: dict) -> dict:
+ data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)
+
+ in_images = data["images"]
+ if set(in_images) - set(self.EXPECTED_CAMERAS):
+ raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")
+
+ # Assume that base image always exists.
+ base_image = in_images["cam_high"]
+
+ images = {
+ "base_0_rgb": base_image,
+ }
+ image_masks = {
+ "base_0_rgb": np.True_,
+ }
+
+ # Add the extra images.
+ extra_image_names = {
+ "left_wrist_0_rgb": "cam_left_wrist",
+ "right_wrist_0_rgb": "cam_right_wrist",
+ }
+ for dest, source in extra_image_names.items():
+ if source in in_images:
+ images[dest] = in_images[source]
+ image_masks[dest] = np.True_
+ else:
+ images[dest] = np.zeros_like(base_image)
+ image_masks[dest] = np.False_
+
+ inputs = {
+ "image": images,
+ "image_mask": image_masks,
+ "state": data["state"],
+ }
+
+ # Actions are only available during training.
+ if "actions" in data:
+ actions = np.asarray(data["actions"])
+ actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)
+ inputs["actions"] = actions
+
+ if "prompt" in data:
+ inputs["prompt"] = data["prompt"]
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class AlohaOutputs(transforms.DataTransformFn):
+ """Outputs for the Aloha policy."""
+
+ # If true, this will convert the joint and gripper values from the standard Aloha space to
+ # the space used by the pi internal runtime which was used to train the base model.
+ adapt_to_pi: bool = True
+
+ def __call__(self, data: dict) -> dict:
+ # Only return the first 14 dims.
+ actions = np.asarray(data["actions"][:, :14])
+ return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
+
+
+def _joint_flip_mask() -> np.ndarray:
+ """Used to convert between aloha and pi joint angles."""
+ return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
+
+
+def _normalize(x, min_val, max_val):
+ return (x - min_val) / (max_val - min_val)
+
+
+def _unnormalize(x, min_val, max_val):
+ return x * (max_val - min_val) + min_val
+
+
+def _gripper_to_angular(value):
+ # Aloha transforms the gripper positions into a linear space. The following code
+ # reverses this transformation to be consistent with pi0 which is pretrained in
+ # angular space.
+ #
+ # These values are coming from the Aloha code:
+ # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
+ value = _unnormalize(value, min_val=0.01844, max_val=0.05800)
+
+ # This is the inverse of the angular to linear transformation inside the Interbotix code.
+ def linear_to_radian(linear_position, arm_length, horn_radius):
+ value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
+ return np.arcsin(np.clip(value, -1.0, 1.0))
+
+ # The constants are taken from the Interbotix code.
+ value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
+
+ # pi0 gripper data is normalized (0, 1) between encoder counts (2405, 3110).
+ # There are 4096 total encoder counts and aloha uses a zero of 2048.
+ # Converting this to radians means that the normalized inputs are between (0.5476, 1.6296)
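+    # (Sanity check of those bounds, assuming a full revolution spans the 4096 encoder counts:
+    #  (2405 - 2048) / 4096 * 2 * pi ≈ 0.5476 and (3110 - 2048) / 4096 * 2 * pi ≈ 1.629.)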
+ return _normalize(value, min_val=0.5476, max_val=1.6296)
+
+
+def _gripper_from_angular(value):
+ # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
+ # Note that the units are still angular but the range is different.
+
+ # We do not scale the output since the trossen model predictions are already in radians.
+ # See the comment in _gripper_to_angular for a derivation of the constant
+ value = value + 0.5476
+
+ # These values are coming from the Aloha code:
+ # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
+ return _normalize(value, min_val=-0.6213, max_val=1.4910)
+
+
+def _gripper_from_angular_inv(value):
+ # Directly inverts the gripper_from_angular function.
+ value = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
+ return value - 0.5476
+
+
+def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
+ # state is [left_arm_joint_angles, left_arm_gripper, right_arm_joint_angles, right_arm_gripper]
+ # dim sizes: [6, 1, 6, 1]
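+    # (i.e. indices 0-5: left arm joints, 6: left gripper, 7-12: right arm joints, 13: right gripper,
+    #  which is why the gripper transforms index into positions [6, 13])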
+ state = np.asarray(data["state"])
+ state = _decode_state(state, adapt_to_pi=adapt_to_pi)
+
+ def convert_image(img):
+ img = np.asarray(img)
+ # Convert to uint8 if using float images.
+ if np.issubdtype(img.dtype, np.floating):
+ img = (255 * img).astype(np.uint8)
+ # Convert from [channel, height, width] to [height, width, channel].
+ return einops.rearrange(img, "c h w -> h w c")
+
+ images = data["images"]
+ images_dict = {name: convert_image(img) for name, img in images.items()}
+
+ data["images"] = images_dict
+ data["state"] = state
+ return data
+
+
+def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
+ if adapt_to_pi:
+ # Flip the joints.
+ state = _joint_flip_mask() * state
+ # Reverse the gripper transformation that is being applied by the Aloha runtime.
+ state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
+ return state
+
+
+def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
+ if adapt_to_pi:
+ # Flip the joints.
+ actions = _joint_flip_mask() * actions
+ actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])
+ return actions
+
+
+def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
+ if adapt_to_pi:
+ actions = _joint_flip_mask() * actions
+ actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])
+ return actions
diff --git a/openpi/src/openpi/policies/droid_policy.py b/openpi/src/openpi/policies/droid_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..666b7f76fecb2decacc60df09c577c086df76f1d
--- /dev/null
+++ b/openpi/src/openpi/policies/droid_policy.py
@@ -0,0 +1,81 @@
+import dataclasses
+
+import einops
+import numpy as np
+
+from openpi import transforms
+from openpi.models import model as _model
+
+
+def make_droid_example() -> dict:
+ """Creates a random input example for the Droid policy."""
+ return {
+ "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "observation/joint_position": np.random.rand(7),
+ "observation/gripper_position": np.random.rand(1),
+ "prompt": "do something",
+ }
+
+
+def _parse_image(image) -> np.ndarray:
+ image = np.asarray(image)
+ if np.issubdtype(image.dtype, np.floating):
+ image = (255 * image).astype(np.uint8)
+ if image.shape[0] == 3:
+ image = einops.rearrange(image, "c h w -> h w c")
+ return image
+
+
+@dataclasses.dataclass(frozen=True)
+class DroidInputs(transforms.DataTransformFn):
+ # Determines which model will be used.
+ model_type: _model.ModelType
+
+ def __call__(self, data: dict) -> dict:
+ gripper_pos = np.asarray(data["observation/gripper_position"])
+ if gripper_pos.ndim == 0:
+ # Ensure gripper position is a 1D array, not a scalar, so we can concatenate with joint positions
+ gripper_pos = gripper_pos[np.newaxis]
+ state = np.concatenate([data["observation/joint_position"], gripper_pos])
+
+        # Parse images to uint8 (H,W,C) since LeRobot automatically stores them as float32 (C,H,W).
+        # The conversion is a no-op during policy inference, where images already arrive as uint8 (H,W,C).
+ base_image = _parse_image(data["observation/exterior_image_1_left"])
+ wrist_image = _parse_image(data["observation/wrist_image_left"])
+
+ match self.model_type:
+ case _model.ModelType.PI0 | _model.ModelType.PI05:
+ names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
+ images = (base_image, wrist_image, np.zeros_like(base_image))
+ image_masks = (np.True_, np.True_, np.False_)
+ case _model.ModelType.PI0_FAST:
+ names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb")
+ # We don't mask out padding images for FAST models.
+ images = (base_image, np.zeros_like(base_image), wrist_image)
+ image_masks = (np.True_, np.True_, np.True_)
+ case _:
+ raise ValueError(f"Unsupported model type: {self.model_type}")
+
+ inputs = {
+ "state": state,
+ "image": dict(zip(names, images, strict=True)),
+ "image_mask": dict(zip(names, image_masks, strict=True)),
+ }
+
+ if "actions" in data:
+ inputs["actions"] = np.asarray(data["actions"])
+
+ if "prompt" in data:
+ if isinstance(data["prompt"], bytes):
+ data["prompt"] = data["prompt"].decode("utf-8")
+ inputs["prompt"] = data["prompt"]
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class DroidOutputs(transforms.DataTransformFn):
+ def __call__(self, data: dict) -> dict:
+ # Only return the first 8 dims.
+ return {"actions": np.asarray(data["actions"][:, :8])}
diff --git a/openpi/src/openpi/policies/libero_policy.py b/openpi/src/openpi/policies/libero_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..10611f61bed1d7d3a9ca49e6ac455829fda9b61d
--- /dev/null
+++ b/openpi/src/openpi/policies/libero_policy.py
@@ -0,0 +1,100 @@
+import dataclasses
+
+import einops
+import numpy as np
+
+from openpi import transforms
+from openpi.models import model as _model
+
+
+def make_libero_example() -> dict:
+ """Creates a random input example for the Libero policy."""
+ return {
+ "observation/state": np.random.rand(8),
+ "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
+ "prompt": "do something",
+ }
+
+
+def _parse_image(image) -> np.ndarray:
+ image = np.asarray(image)
+ if np.issubdtype(image.dtype, np.floating):
+ image = (255 * image).astype(np.uint8)
+ if image.shape[0] == 3:
+ image = einops.rearrange(image, "c h w -> h w c")
+ return image
+
+
+@dataclasses.dataclass(frozen=True)
+class LiberoInputs(transforms.DataTransformFn):
+ """
+    This class is used to convert inputs into the format expected by the model. It is used for both training and inference.
+
+ For your own dataset, you can copy this class and modify the keys based on the comments below to pipe
+ the correct elements of your dataset into the model.
+ """
+
+ # Determines which model will be used.
+ # Do not change this for your own dataset.
+ model_type: _model.ModelType
+
+ def __call__(self, data: dict) -> dict:
+        # Parse images to uint8 (H,W,C) since LeRobot automatically stores them as float32 (C,H,W).
+        # The conversion is a no-op during policy inference, where images already arrive as uint8 (H,W,C).
+ # Keep this for your own dataset, but if your dataset stores the images
+ # in a different key than "observation/image" or "observation/wrist_image",
+ # you should change it below.
+ # Pi0 models support three image inputs at the moment: one third-person view,
+ # and two wrist views (left and right). If your dataset does not have a particular type
+ # of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the
+ # right wrist image below.
+ base_image = _parse_image(data["observation/image"])
+ wrist_image = _parse_image(data["observation/wrist_image"])
+
+ # Create inputs dict. Do not change the keys in the dict below.
+ inputs = {
+ "state": data["observation/state"],
+ "image": {
+ "base_0_rgb": base_image,
+ "left_wrist_0_rgb": wrist_image,
+ # Pad any non-existent images with zero-arrays of the appropriate shape.
+ "right_wrist_0_rgb": np.zeros_like(base_image),
+ },
+ "image_mask": {
+ "base_0_rgb": np.True_,
+ "left_wrist_0_rgb": np.True_,
+ # We only mask padding images for pi0 model, not pi0-FAST. Do not change this for your own dataset.
+ "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
+ },
+ }
+
+ # Pad actions to the model action dimension. Keep this for your own dataset.
+ # Actions are only available during training.
+ if "actions" in data:
+ inputs["actions"] = data["actions"]
+
+ # Pass the prompt (aka language instruction) to the model.
+ # Keep this for your own dataset (but modify the key if the instruction is not
+ # stored in "prompt"; the output dict always needs to have the key "prompt").
+ if "prompt" in data:
+ inputs["prompt"] = data["prompt"]
+
+ return inputs
+
+
+@dataclasses.dataclass(frozen=True)
+class LiberoOutputs(transforms.DataTransformFn):
+ """
+    This class is used to convert outputs from the model back to the dataset-specific format. It is
+ used for inference only.
+
+ For your own dataset, you can copy this class and modify the action dimension based on the comments below.
+ """
+
+ def __call__(self, data: dict) -> dict:
+ # Only return the first N actions -- since we padded actions above to fit the model action
+ # dimension, we need to now parse out the correct number of actions in the return dict.
+ # For Libero, we only return the first 7 actions (since the rest is padding).
+ # For your own dataset, replace `7` with the action dimension of your dataset.
+ return {"actions": np.asarray(data["actions"][:, :7])}
diff --git a/openpi/src/openpi/policies/policy.py b/openpi/src/openpi/policies/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9b708bdcaffa086be4b22b84a29b8cc00c710e4
--- /dev/null
+++ b/openpi/src/openpi/policies/policy.py
@@ -0,0 +1,135 @@
+from collections.abc import Sequence
+import logging
+import pathlib
+import time
+from typing import Any, TypeAlias
+
+import flax
+import flax.traverse_util
+import jax
+import jax.numpy as jnp
+import numpy as np
+from openpi_client import base_policy as _base_policy
+import torch
+from typing_extensions import override
+
+from openpi import transforms as _transforms
+from openpi.models import model as _model
+from openpi.shared import array_typing as at
+from openpi.shared import nnx_utils
+
+BasePolicy: TypeAlias = _base_policy.BasePolicy
+
+
+class Policy(BasePolicy):
+ def __init__(
+ self,
+ model: _model.BaseModel,
+ *,
+ rng: at.KeyArrayLike | None = None,
+ transforms: Sequence[_transforms.DataTransformFn] = (),
+ output_transforms: Sequence[_transforms.DataTransformFn] = (),
+ sample_kwargs: dict[str, Any] | None = None,
+ metadata: dict[str, Any] | None = None,
+ pytorch_device: str = "cpu",
+ is_pytorch: bool = False,
+ ):
+ """Initialize the Policy.
+
+ Args:
+ model: The model to use for action sampling.
+ rng: Random number generator key for JAX models. Ignored for PyTorch models.
+ transforms: Input data transformations to apply before inference.
+ output_transforms: Output data transformations to apply after inference.
+ sample_kwargs: Additional keyword arguments to pass to model.sample_actions.
+ metadata: Additional metadata to store with the policy.
+ pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
+ Only relevant when is_pytorch=True.
+ is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
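+
+        Example (a minimal sketch; policies are typically built via
+            `openpi.policies.policy_config.create_trained_policy` rather than constructed directly,
+            and `model`, `input_tfs`, `output_tfs`, and `obs` below are illustrative placeholders):
+
+            policy = Policy(model, transforms=input_tfs, output_transforms=output_tfs)
+            result = policy.infer(obs)  # result["actions"] holds the sampled action chunk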
+ """
+ self._model = model
+ self._input_transform = _transforms.compose(transforms)
+ self._output_transform = _transforms.compose(output_transforms)
+ self._sample_kwargs = sample_kwargs or {}
+ self._metadata = metadata or {}
+ self._is_pytorch_model = is_pytorch
+ self._pytorch_device = pytorch_device
+
+ if self._is_pytorch_model:
+ self._model = self._model.to(pytorch_device)
+ self._model.eval()
+ self._sample_actions = model.sample_actions
+ else:
+ # JAX model setup
+ self._sample_actions = nnx_utils.module_jit(model.sample_actions)
+ self._rng = rng or jax.random.key(0)
+
+ @override
+ def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: ignore[misc]
+ # Make a copy since transformations may modify the inputs in place.
+ inputs = jax.tree.map(lambda x: x, obs)
+ inputs = self._input_transform(inputs)
+ if not self._is_pytorch_model:
+ # Make a batch and convert to jax.Array.
+ inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
+ self._rng, sample_rng_or_pytorch_device = jax.random.split(self._rng)
+ else:
+ # Convert inputs to PyTorch tensors and move to correct device
+ inputs = jax.tree.map(lambda x: torch.from_numpy(np.array(x)).to(self._pytorch_device)[None, ...], inputs)
+ sample_rng_or_pytorch_device = self._pytorch_device
+
+ # Prepare kwargs for sample_actions
+ sample_kwargs = dict(self._sample_kwargs)
+ if noise is not None:
+ noise = torch.from_numpy(noise).to(self._pytorch_device) if self._is_pytorch_model else jnp.asarray(noise)
+
+ if noise.ndim == 2: # If noise is (action_horizon, action_dim), add batch dimension
+ noise = noise[None, ...] # Make it (1, action_horizon, action_dim)
+ sample_kwargs["noise"] = noise
+
+ observation = _model.Observation.from_dict(inputs)
+ start_time = time.monotonic()
+ outputs = {
+ "state": inputs["state"],
+ "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs),
+ }
+ model_time = time.monotonic() - start_time
+ if self._is_pytorch_model:
+ outputs = jax.tree.map(lambda x: np.asarray(x[0, ...].detach().cpu()), outputs)
+ else:
+ outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)
+
+ outputs = self._output_transform(outputs)
+ outputs["policy_timing"] = {
+ "infer_ms": model_time * 1000,
+ }
+ return outputs
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ return self._metadata
+
+
+class PolicyRecorder(_base_policy.BasePolicy):
+ """Records the policy's behavior to disk."""
+
+ def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
+ self._policy = policy
+
+ logging.info(f"Dumping policy records to: {record_dir}")
+ self._record_dir = pathlib.Path(record_dir)
+ self._record_dir.mkdir(parents=True, exist_ok=True)
+ self._record_step = 0
+
+ @override
+ def infer(self, obs: dict) -> dict: # type: ignore[misc]
+ results = self._policy.infer(obs)
+
+ data = {"inputs": obs, "outputs": results}
+ data = flax.traverse_util.flatten_dict(data, sep="/")
+
+ output_path = self._record_dir / f"step_{self._record_step}"
+ self._record_step += 1
+
+ np.save(output_path, np.asarray(data))
+ return results
diff --git a/openpi/src/openpi/policies/policy_config.py b/openpi/src/openpi/policies/policy_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6570df05ed068f297d48326990a1cd10ee68ac5a
--- /dev/null
+++ b/openpi/src/openpi/policies/policy_config.py
@@ -0,0 +1,94 @@
+import logging
+import os
+import pathlib
+from typing import Any
+
+import jax.numpy as jnp
+
+import openpi.models.model as _model
+import openpi.policies.policy as _policy
+import openpi.shared.download as download
+from openpi.training import checkpoints as _checkpoints
+from openpi.training import config as _config
+import openpi.transforms as transforms
+
+
+def create_trained_policy(
+ train_config: _config.TrainConfig,
+ checkpoint_dir: pathlib.Path | str,
+ *,
+ repack_transforms: transforms.Group | None = None,
+ sample_kwargs: dict[str, Any] | None = None,
+ default_prompt: str | None = None,
+ norm_stats: dict[str, transforms.NormStats] | None = None,
+ pytorch_device: str | None = None,
+) -> _policy.Policy:
+ """Create a policy from a trained checkpoint.
+
+ Args:
+ train_config: The training config to use to create the model.
+ checkpoint_dir: The directory to load the model from.
+ repack_transforms: Optional transforms that will be applied before any other transforms.
+ sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
+ kwargs will be used.
+ default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
+ data if it doesn't already exist.
+ norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
+ from the checkpoint directory.
+ pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0").
+            If None and the checkpoint is detected to be a PyTorch model, "cuda" will be used if
+            available, otherwise "cpu".
+
+ Note:
+ The function automatically detects whether the model is PyTorch-based by checking for the
+        presence of "model.safetensors" in the checkpoint directory.
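+
+    Example (a minimal sketch, mirroring the usage in `policy_test.py`):
+
+        config = _config.get_config("pi0_aloha_sim")
+        policy = create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
+        actions = policy.infer(aloha_policy.make_aloha_example())["actions"]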
+ """
+ repack_transforms = repack_transforms or transforms.Group()
+ checkpoint_dir = download.maybe_download(str(checkpoint_dir))
+
+ # Check if this is a PyTorch model by looking for model.safetensors
+ weight_path = os.path.join(checkpoint_dir, "model.safetensors")
+ is_pytorch = os.path.exists(weight_path)
+
+ logging.info("Loading model...")
+ if is_pytorch:
+ model = train_config.model.load_pytorch(train_config, weight_path)
+ model.paligemma_with_expert.to_bfloat16_for_selected_params("bfloat16")
+ else:
+ model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
+ data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
+ if norm_stats is None:
+ # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
+ # that the policy is using the same normalization stats as the original training process.
+ if data_config.asset_id is None:
+ raise ValueError("Asset id is required to load norm stats.")
+ norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)
+
+ # Determine the device to use for PyTorch models
+ if is_pytorch and pytorch_device is None:
+ try:
+ import torch
+
+ pytorch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ except ImportError:
+ pytorch_device = "cpu"
+
+ return _policy.Policy(
+ model,
+ transforms=[
+ *repack_transforms.inputs,
+ transforms.InjectDefaultPrompt(default_prompt),
+ *data_config.data_transforms.inputs,
+ transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
+ *data_config.model_transforms.inputs,
+ ],
+ output_transforms=[
+ *data_config.model_transforms.outputs,
+ transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
+ *data_config.data_transforms.outputs,
+ *repack_transforms.outputs,
+ ],
+ sample_kwargs=sample_kwargs,
+ metadata=train_config.policy_metadata,
+ is_pytorch=is_pytorch,
+ pytorch_device=pytorch_device if is_pytorch else None,
+ )
diff --git a/openpi/src/openpi/policies/policy_test.py b/openpi/src/openpi/policies/policy_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5808e5274a9d19384d4277ccbde113926da65a50
--- /dev/null
+++ b/openpi/src/openpi/policies/policy_test.py
@@ -0,0 +1,34 @@
+from openpi_client import action_chunk_broker
+import pytest
+
+from openpi.policies import aloha_policy
+from openpi.policies import policy_config as _policy_config
+from openpi.training import config as _config
+
+
+@pytest.mark.manual
+def test_infer():
+ config = _config.get_config("pi0_aloha_sim")
+ policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
+
+ example = aloha_policy.make_aloha_example()
+ result = policy.infer(example)
+
+ assert result["actions"].shape == (config.model.action_horizon, 14)
+
+
+@pytest.mark.manual
+def test_broker():
+ config = _config.get_config("pi0_aloha_sim")
+ policy = _policy_config.create_trained_policy(config, "gs://openpi-assets/checkpoints/pi0_aloha_sim")
+
+ broker = action_chunk_broker.ActionChunkBroker(
+ policy,
+ # Only execute the first half of the chunk.
+ action_horizon=config.model.action_horizon // 2,
+ )
+
+ example = aloha_policy.make_aloha_example()
+ for _ in range(config.model.action_horizon):
+ outputs = broker.infer(example)
+ assert outputs["actions"].shape == (14,)
diff --git a/openpi/src/openpi/py.typed b/openpi/src/openpi/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openpi/src/openpi/serving/websocket_policy_server.py b/openpi/src/openpi/serving/websocket_policy_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdefa98b879323aca449550bf1985ae877eddd83
--- /dev/null
+++ b/openpi/src/openpi/serving/websocket_policy_server.py
@@ -0,0 +1,90 @@
+import asyncio
+import http
+import logging
+import time
+import traceback
+
+from openpi_client import base_policy as _base_policy
+from openpi_client import msgpack_numpy
+import websockets.asyncio.server as _server
+import websockets.frames
+
+logger = logging.getLogger(__name__)
+
+
+class WebsocketPolicyServer:
+ """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
+
+ Currently only implements the `load` and `infer` methods.
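+
+    Example (a minimal sketch; the port and the `policy` instance are illustrative):
+
+        server = WebsocketPolicyServer(policy, host="0.0.0.0", port=8000)
+        server.serve_forever()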
+ """
+
+ def __init__(
+ self,
+ policy: _base_policy.BasePolicy,
+ host: str = "0.0.0.0",
+ port: int | None = None,
+ metadata: dict | None = None,
+ ) -> None:
+ self._policy = policy
+ self._host = host
+ self._port = port
+ self._metadata = metadata or {}
+ logging.getLogger("websockets.server").setLevel(logging.INFO)
+
+ def serve_forever(self) -> None:
+ asyncio.run(self.run())
+
+ async def run(self):
+ async with _server.serve(
+ self._handler,
+ self._host,
+ self._port,
+ compression=None,
+ max_size=None,
+ process_request=_health_check,
+ ) as server:
+ await server.serve_forever()
+
+ async def _handler(self, websocket: _server.ServerConnection):
+ logger.info(f"Connection from {websocket.remote_address} opened")
+ packer = msgpack_numpy.Packer()
+
+ await websocket.send(packer.pack(self._metadata))
+
+ prev_total_time = None
+ while True:
+ try:
+ start_time = time.monotonic()
+ obs = msgpack_numpy.unpackb(await websocket.recv())
+
+ infer_time = time.monotonic()
+ action = self._policy.infer(obs)
+ infer_time = time.monotonic() - infer_time
+
+ action["server_timing"] = {
+ "infer_ms": infer_time * 1000,
+ }
+ if prev_total_time is not None:
+ # We can only record the last total time since we also want to include the send time.
+ action["server_timing"]["prev_total_ms"] = prev_total_time * 1000
+
+ await websocket.send(packer.pack(action))
+ prev_total_time = time.monotonic() - start_time
+
+ except websockets.ConnectionClosed:
+ logger.info(f"Connection from {websocket.remote_address} closed")
+ break
+ except Exception:
+ await websocket.send(traceback.format_exc())
+ await websocket.close(
+ code=websockets.frames.CloseCode.INTERNAL_ERROR,
+ reason="Internal server error. Traceback included in previous frame.",
+ )
+ raise
+
+
+def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
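+    # Lightweight liveness endpoint: a GET to "/healthz" returns 200 "OK" without touching the policy.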
+ if request.path == "/healthz":
+ return connection.respond(http.HTTPStatus.OK, "OK\n")
+ # Continue with the normal request handling.
+ return None
diff --git a/openpi/src/openpi/shared/__init__.py b/openpi/src/openpi/shared/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/openpi/src/openpi/shared/array_typing.py b/openpi/src/openpi/shared/array_typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..569eafef1401d4a34b0289c0efb2125f71d21dfb
--- /dev/null
+++ b/openpi/src/openpi/shared/array_typing.py
@@ -0,0 +1,89 @@
+import contextlib
+import functools as ft
+import inspect
+from typing import TypeAlias, TypeVar, cast
+
+import beartype
+import jax
+import jax._src.tree_util as private_tree_util
+import jax.core
+from jaxtyping import ArrayLike
+from jaxtyping import Bool # noqa: F401
+from jaxtyping import DTypeLike # noqa: F401
+from jaxtyping import Float
+from jaxtyping import Int # noqa: F401
+from jaxtyping import Key # noqa: F401
+from jaxtyping import Num # noqa: F401
+from jaxtyping import PyTree
+from jaxtyping import Real # noqa: F401
+from jaxtyping import UInt8 # noqa: F401
+from jaxtyping import config
+from jaxtyping import jaxtyped
+import jaxtyping._decorator
+import torch
+
+# patch jaxtyping to handle https://github.com/patrick-kidger/jaxtyping/issues/277.
+# the problem is that custom PyTree nodes are sometimes initialized with arbitrary types (e.g., `jax.ShapeDtypeStruct`,
+# `jax.Sharding`, or even