Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .idea/.gitignore +8 -0
- .idea/workspace.xml +12 -0
- handler.py +215 -0
- openpi/.DS_Store +0 -0
- openpi/.dockerignore +3 -0
- openpi/.gitattributes +36 -0
- openpi/.github/CODEOWNERS +16 -0
- openpi/.github/workflows/pre-commit.yml +17 -0
- openpi/.github/workflows/test.yml +31 -0
- openpi/.gitignore +168 -0
- openpi/.gitmodules +6 -0
- openpi/.idea/.gitignore +8 -0
- openpi/.idea/workspace.xml +12 -0
- openpi/.pre-commit-config.yaml +16 -0
- openpi/.python-version +1 -0
- openpi/.vscode/settings.json +11 -0
- openpi/CONTRIBUTING.md +33 -0
- openpi/LICENSE +201 -0
- openpi/README.md +323 -0
- openpi/config.json +85 -0
- openpi/docs/docker.md +25 -0
- openpi/docs/norm_stats.md +69 -0
- openpi/docs/remote_inference.md +71 -0
- openpi/examples/aloha_real/Dockerfile +70 -0
- openpi/examples/aloha_real/README.md +126 -0
- openpi/examples/aloha_real/compose.yml +66 -0
- openpi/examples/aloha_real/constants.py +71 -0
- openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py +272 -0
- openpi/examples/aloha_real/env.py +57 -0
- openpi/examples/aloha_real/main.py +51 -0
- openpi/examples/aloha_real/real_env.py +176 -0
- openpi/examples/aloha_real/requirements.in +18 -0
- openpi/examples/aloha_real/requirements.txt +156 -0
- openpi/examples/aloha_real/robot_utils.py +275 -0
- openpi/examples/aloha_real/video_display.py +36 -0
- openpi/examples/aloha_sim/Dockerfile +41 -0
- openpi/examples/aloha_sim/README.md +36 -0
- openpi/examples/aloha_sim/compose.yml +42 -0
- openpi/examples/aloha_sim/env.py +56 -0
- openpi/examples/aloha_sim/main.py +55 -0
- openpi/examples/aloha_sim/requirements.in +8 -0
- openpi/examples/aloha_sim/requirements.txt +132 -0
- openpi/examples/aloha_sim/saver.py +40 -0
- openpi/examples/convert_jax_model_to_pytorch.py +587 -0
- openpi/examples/droid/README.md +84 -0
- openpi/examples/droid/README_train.md +106 -0
- openpi/examples/droid/compute_droid_nonidle_ranges.py +103 -0
- openpi/examples/droid/convert_droid_data_to_lerobot.py +477 -0
- openpi/examples/droid/main.py +246 -0
.DS_Store
ADDED
|
Binary file (10.2 kB). View file
|
|
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/workspace.xml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectViewState">
|
| 4 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
| 5 |
+
<option name="showLibraryContents" value="true" />
|
| 6 |
+
</component>
|
| 7 |
+
<component name="PropertiesComponent">{
|
| 8 |
+
"keyToString": {
|
| 9 |
+
"settings.editor.selected.configurable": "dev.sweep.assistant.settings.SweepSettingsConfigurable"
|
| 10 |
+
}
|
| 11 |
+
}</component>
|
| 12 |
+
</project>
|
handler.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from openpi.policies import policy_config
|
| 11 |
+
from openpi.training import config as train_config
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EndpointHandler:
|
| 15 |
+
def __init__(self, path: str = ""):
|
| 16 |
+
"""
|
| 17 |
+
Initialize the handler for pi0 model inference using openpi infrastructure.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
path: Path to the model weights directory
|
| 21 |
+
"""
|
| 22 |
+
# Set model path from environment variable or use provided path
|
| 23 |
+
model_path = os.environ.get("MODEL_PATH", path)
|
| 24 |
+
if not model_path:
|
| 25 |
+
model_path = "weights/pi0"
|
| 26 |
+
|
| 27 |
+
# Load the config.json to determine model type
|
| 28 |
+
config_path = os.path.join(model_path, "config.json")
|
| 29 |
+
with open(config_path, "r") as f:
|
| 30 |
+
model_config = json.load(f)
|
| 31 |
+
|
| 32 |
+
model_type = model_config.get("type", "pi0")
|
| 33 |
+
|
| 34 |
+
# Create training config based on model type
|
| 35 |
+
# This uses the openpi config system
|
| 36 |
+
if model_type == "pi0":
|
| 37 |
+
self.train_config = train_config.get_config("pi0")
|
| 38 |
+
else:
|
| 39 |
+
# Default to pi0 if type not recognized
|
| 40 |
+
self.train_config = train_config.get_config("pi0")
|
| 41 |
+
|
| 42 |
+
# Create trained policy using openpi infrastructure
|
| 43 |
+
# This handles all the model loading, preprocessing, etc.
|
| 44 |
+
self.policy = policy_config.create_trained_policy(
|
| 45 |
+
self.train_config,
|
| 46 |
+
model_path,
|
| 47 |
+
pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Default number of inference steps
|
| 51 |
+
self.default_num_steps = 50
|
| 52 |
+
|
| 53 |
+
def _decode_base64_image(self, base64_str: str) -> np.ndarray:
|
| 54 |
+
"""
|
| 55 |
+
Decode base64 image string to numpy array.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
base64_str: Base64 encoded image string
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
numpy array of shape (H, W, 3) with values in [0, 255]
|
| 62 |
+
"""
|
| 63 |
+
# Remove data URL prefix if present
|
| 64 |
+
if base64_str.startswith("data:image"):
|
| 65 |
+
base64_str = base64_str.split(",", 1)[1]
|
| 66 |
+
|
| 67 |
+
# Decode base64
|
| 68 |
+
image_bytes = base64.b64decode(base64_str)
|
| 69 |
+
|
| 70 |
+
# Convert to PIL Image and then to numpy array
|
| 71 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 72 |
+
image_array = np.array(image)
|
| 73 |
+
|
| 74 |
+
return image_array
|
| 75 |
+
|
| 76 |
+
def _prepare_observation(self, images: Dict[str, str], state: List[float], prompt: str = None) -> Dict[str, Any]:
|
| 77 |
+
"""
|
| 78 |
+
Prepare observation dictionary in the format expected by openpi.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
images: Dictionary mapping camera names to base64 encoded images
|
| 82 |
+
state: List of robot state values
|
| 83 |
+
prompt: Optional text prompt
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Observation dictionary in openpi format
|
| 87 |
+
"""
|
| 88 |
+
# Decode and process images
|
| 89 |
+
processed_images = {}
|
| 90 |
+
|
| 91 |
+
# Map input camera names to expected openpi format
|
| 92 |
+
# Based on the config, pi0 expects specific camera names
|
| 93 |
+
camera_mapping = {
|
| 94 |
+
"camera0": "cam_high", # base camera
|
| 95 |
+
"camera1": "cam_left_wrist", # left wrist camera
|
| 96 |
+
"camera2": "cam_right_wrist", # right wrist camera
|
| 97 |
+
# Alternative mappings
|
| 98 |
+
"base_camera": "cam_high",
|
| 99 |
+
"left_wrist": "cam_left_wrist",
|
| 100 |
+
"right_wrist": "cam_right_wrist",
|
| 101 |
+
# Direct mappings
|
| 102 |
+
"cam_high": "cam_high",
|
| 103 |
+
"cam_left_wrist": "cam_left_wrist",
|
| 104 |
+
"cam_right_wrist": "cam_right_wrist"
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
for input_name, image_b64 in images.items():
|
| 108 |
+
# Map to openpi expected name
|
| 109 |
+
openpi_name = camera_mapping.get(input_name, input_name)
|
| 110 |
+
|
| 111 |
+
# Decode image
|
| 112 |
+
image_array = self._decode_base64_image(image_b64)
|
| 113 |
+
|
| 114 |
+
# Resize to expected resolution if needed
|
| 115 |
+
if image_array.shape[:2] != (224, 224):
|
| 116 |
+
image_pil = Image.fromarray(image_array)
|
| 117 |
+
image_resized = image_pil.resize((224, 224))
|
| 118 |
+
image_array = np.array(image_resized)
|
| 119 |
+
|
| 120 |
+
# Convert to format expected by openpi (H, W, C) with uint8
|
| 121 |
+
processed_images[openpi_name] = image_array.astype(np.uint8)
|
| 122 |
+
|
| 123 |
+
# Ensure we have the required cameras, create dummy ones if missing
|
| 124 |
+
required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
|
| 125 |
+
for cam_name in required_cameras:
|
| 126 |
+
if cam_name not in processed_images:
|
| 127 |
+
# Create a black dummy image
|
| 128 |
+
processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)
|
| 129 |
+
|
| 130 |
+
# Prepare state
|
| 131 |
+
state_array = np.array(state, dtype=np.float32)
|
| 132 |
+
|
| 133 |
+
# Create observation dict in openpi format
|
| 134 |
+
observation = {
|
| 135 |
+
"state": state_array,
|
| 136 |
+
"images": processed_images,
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
# Add prompt if provided
|
| 140 |
+
if prompt:
|
| 141 |
+
observation["prompt"] = prompt
|
| 142 |
+
|
| 143 |
+
return observation
|
| 144 |
+
|
| 145 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
| 146 |
+
"""
|
| 147 |
+
Main inference function called by HuggingFace endpoint.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
data: Input data dictionary containing:
|
| 151 |
+
- inputs: Dictionary with:
|
| 152 |
+
- images: Dict mapping camera names to base64 encoded images
|
| 153 |
+
- state: List of robot state values
|
| 154 |
+
- prompt: Optional text prompt
|
| 155 |
+
- num_actions: Optional, number of actions to predict (default: 50)
|
| 156 |
+
- noise: Optional, noise array for sampling
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
List containing prediction results
|
| 160 |
+
"""
|
| 161 |
+
try:
|
| 162 |
+
inputs = data.get("inputs", {})
|
| 163 |
+
|
| 164 |
+
# Extract inputs
|
| 165 |
+
images = inputs.get("images", {})
|
| 166 |
+
state = inputs.get("state", [])
|
| 167 |
+
prompt = inputs.get("prompt", "")
|
| 168 |
+
num_actions = inputs.get("num_actions", self.default_num_steps)
|
| 169 |
+
noise_input = inputs.get("noise", None)
|
| 170 |
+
|
| 171 |
+
# Validate inputs
|
| 172 |
+
if not images:
|
| 173 |
+
raise ValueError("No images provided")
|
| 174 |
+
if not state:
|
| 175 |
+
raise ValueError("No state provided")
|
| 176 |
+
|
| 177 |
+
# Prepare observation using openpi format
|
| 178 |
+
observation = self._prepare_observation(images, state, prompt)
|
| 179 |
+
|
| 180 |
+
# Prepare noise if provided
|
| 181 |
+
noise = None
|
| 182 |
+
if noise_input is not None:
|
| 183 |
+
noise = np.array(noise_input, dtype=np.float32)
|
| 184 |
+
|
| 185 |
+
# Run inference using openpi policy
|
| 186 |
+
# This handles all the preprocessing, model inference, and postprocessing
|
| 187 |
+
result = self.policy.infer(observation, noise=noise)
|
| 188 |
+
|
| 189 |
+
# Extract actions from result
|
| 190 |
+
actions = result["actions"]
|
| 191 |
+
|
| 192 |
+
# Convert to list format for JSON serialization
|
| 193 |
+
if isinstance(actions, np.ndarray):
|
| 194 |
+
actions_list = actions.tolist()
|
| 195 |
+
else:
|
| 196 |
+
actions_list = actions
|
| 197 |
+
|
| 198 |
+
# Return in expected format
|
| 199 |
+
return [{
|
| 200 |
+
"actions": actions_list,
|
| 201 |
+
"num_actions": len(actions_list),
|
| 202 |
+
"action_horizon": len(actions_list),
|
| 203 |
+
"action_dim": len(actions_list[0]) if actions_list else 0,
|
| 204 |
+
"success": True,
|
| 205 |
+
"metadata": {
|
| 206 |
+
"model_type": self.train_config.model.model_type.value,
|
| 207 |
+
"policy_metadata": getattr(self.policy, '_metadata', {})
|
| 208 |
+
}
|
| 209 |
+
}]
|
| 210 |
+
|
| 211 |
+
except Exception as e:
|
| 212 |
+
return [{
|
| 213 |
+
"error": str(e),
|
| 214 |
+
"success": False
|
| 215 |
+
}]
|
openpi/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
openpi/.dockerignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
checkpoints
|
| 3 |
+
data
|
openpi/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
openpi/.github/CODEOWNERS
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The CODEOWNERS file defines individuals or teams that are automatically requested for
|
| 2 |
+
# review when someone opens a pull request that modifies certain code. When a draft pull
|
| 3 |
+
# request is marked as ready for review, code owners are automatically notified.
|
| 4 |
+
#
|
| 5 |
+
# See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
| 6 |
+
#
|
| 7 |
+
# This is a comment.
|
| 8 |
+
# Each line is a file pattern followed by one or more owners.
|
| 9 |
+
|
| 10 |
+
# Global owners.
|
| 11 |
+
* @jimmyt857 @Michael-Equi @uzhilinsky
|
| 12 |
+
|
| 13 |
+
src/openpi/models/ @kvablack @uzhilinsky
|
| 14 |
+
src/openpi/training/ @kvablack @uzhilinsky
|
| 15 |
+
|
| 16 |
+
scripts/ @jimmyt857 @kvablack @uzhilinsky
|
openpi/.github/workflows/pre-commit.yml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: pre-commit
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
branches:
|
| 5 |
+
- main
|
| 6 |
+
pull_request:
|
| 7 |
+
branches:
|
| 8 |
+
- "*"
|
| 9 |
+
jobs:
|
| 10 |
+
pre-commit:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
env:
|
| 13 |
+
GIT_LFS_SKIP_SMUDGE: true
|
| 14 |
+
steps:
|
| 15 |
+
- uses: actions/checkout@v4
|
| 16 |
+
- uses: actions/setup-python@v3
|
| 17 |
+
- uses: pre-commit/action@v3.0.1
|
openpi/.github/workflows/test.yml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Test
|
| 2 |
+
on:
|
| 3 |
+
pull_request:
|
| 4 |
+
branches:
|
| 5 |
+
- "*"
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
run_tests:
|
| 9 |
+
name: Run Tests
|
| 10 |
+
runs-on: openpi-verylarge
|
| 11 |
+
env:
|
| 12 |
+
GIT_LFS_SKIP_SMUDGE: true
|
| 13 |
+
steps:
|
| 14 |
+
- uses: actions/checkout@v4
|
| 15 |
+
|
| 16 |
+
- name: Install FFmpeg dependencies
|
| 17 |
+
run: |
|
| 18 |
+
sudo apt-get update
|
| 19 |
+
sudo apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev
|
| 20 |
+
|
| 21 |
+
- name: Install uv
|
| 22 |
+
uses: astral-sh/setup-uv@v5
|
| 23 |
+
|
| 24 |
+
- name: Set up Python
|
| 25 |
+
run: uv python install
|
| 26 |
+
|
| 27 |
+
- name: Install the project
|
| 28 |
+
run: uv sync --all-extras --dev
|
| 29 |
+
|
| 30 |
+
- name: Run tests
|
| 31 |
+
run: uv run pytest --strict-markers -m "not manual"
|
openpi/.gitignore
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data directories.
|
| 2 |
+
assets/
|
| 3 |
+
checkpoints/
|
| 4 |
+
data/
|
| 5 |
+
wandb/
|
| 6 |
+
|
| 7 |
+
# Byte-compiled / optimized / DLL files
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.py[cod]
|
| 10 |
+
*$py.class
|
| 11 |
+
|
| 12 |
+
# C extensions
|
| 13 |
+
*.so
|
| 14 |
+
|
| 15 |
+
# Distribution / packaging
|
| 16 |
+
.Python
|
| 17 |
+
build/
|
| 18 |
+
develop-eggs/
|
| 19 |
+
dist/
|
| 20 |
+
downloads/
|
| 21 |
+
eggs/
|
| 22 |
+
.eggs/
|
| 23 |
+
lib/
|
| 24 |
+
lib64/
|
| 25 |
+
parts/
|
| 26 |
+
sdist/
|
| 27 |
+
var/
|
| 28 |
+
wheels/
|
| 29 |
+
share/python-wheels/
|
| 30 |
+
*.egg-info/
|
| 31 |
+
.installed.cfg
|
| 32 |
+
*.egg
|
| 33 |
+
MANIFEST
|
| 34 |
+
|
| 35 |
+
# PyInstaller
|
| 36 |
+
# Usually these files are written by a python script from a template
|
| 37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 38 |
+
*.manifest
|
| 39 |
+
*.spec
|
| 40 |
+
|
| 41 |
+
# Installer logs
|
| 42 |
+
pip-log.txt
|
| 43 |
+
pip-delete-this-directory.txt
|
| 44 |
+
|
| 45 |
+
# Unit test / coverage reports
|
| 46 |
+
htmlcov/
|
| 47 |
+
.tox/
|
| 48 |
+
.nox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*.cover
|
| 55 |
+
*.py,cover
|
| 56 |
+
.hypothesis/
|
| 57 |
+
.pytest_cache/
|
| 58 |
+
cover/
|
| 59 |
+
|
| 60 |
+
# Translations
|
| 61 |
+
*.mo
|
| 62 |
+
*.pot
|
| 63 |
+
|
| 64 |
+
# Django stuff:
|
| 65 |
+
*.log
|
| 66 |
+
local_settings.py
|
| 67 |
+
db.sqlite3
|
| 68 |
+
db.sqlite3-journal
|
| 69 |
+
|
| 70 |
+
# Flask stuff:
|
| 71 |
+
instance/
|
| 72 |
+
.webassets-cache
|
| 73 |
+
|
| 74 |
+
# Scrapy stuff:
|
| 75 |
+
.scrapy
|
| 76 |
+
|
| 77 |
+
# Sphinx documentation
|
| 78 |
+
docs/_build/
|
| 79 |
+
|
| 80 |
+
# PyBuilder
|
| 81 |
+
.pybuilder/
|
| 82 |
+
target/
|
| 83 |
+
|
| 84 |
+
# Jupyter Notebook
|
| 85 |
+
.ipynb_checkpoints
|
| 86 |
+
|
| 87 |
+
# IPython
|
| 88 |
+
profile_default/
|
| 89 |
+
ipython_config.py
|
| 90 |
+
|
| 91 |
+
# pyenv
|
| 92 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 93 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 94 |
+
# .python-version
|
| 95 |
+
|
| 96 |
+
# pipenv
|
| 97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 100 |
+
# install all needed dependencies.
|
| 101 |
+
#Pipfile.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
openpi/.gitmodules
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "third_party/aloha"]
|
| 2 |
+
path = third_party/aloha
|
| 3 |
+
url = https://github.com/Physical-Intelligence/aloha.git
|
| 4 |
+
[submodule "third_party/libero"]
|
| 5 |
+
path = third_party/libero
|
| 6 |
+
url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
|
openpi/.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
openpi/.idea/workspace.xml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectViewState">
|
| 4 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
| 5 |
+
<option name="showLibraryContents" value="true" />
|
| 6 |
+
</component>
|
| 7 |
+
<component name="PropertiesComponent">{
|
| 8 |
+
"keyToString": {
|
| 9 |
+
"settings.editor.selected.configurable": "dev.sweep.assistant.settings.SweepSettingsConfigurable"
|
| 10 |
+
}
|
| 11 |
+
}</component>
|
| 12 |
+
</project>
|
openpi/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exclude: third_party/
|
| 2 |
+
|
| 3 |
+
repos:
|
| 4 |
+
- repo: https://github.com/astral-sh/uv-pre-commit
|
| 5 |
+
# uv version.
|
| 6 |
+
rev: 0.5.14
|
| 7 |
+
hooks:
|
| 8 |
+
- id: uv-lock
|
| 9 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 10 |
+
# Ruff version.
|
| 11 |
+
rev: v0.8.6
|
| 12 |
+
hooks:
|
| 13 |
+
# Run the linter.
|
| 14 |
+
- id: ruff
|
| 15 |
+
args: [--fix]
|
| 16 |
+
- id: ruff-format
|
openpi/.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11
|
openpi/.vscode/settings.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"[python]": {
|
| 3 |
+
"editor.defaultFormatter": "charliermarsh.ruff",
|
| 4 |
+
"editor.formatOnSave": true,
|
| 5 |
+
},
|
| 6 |
+
"python.testing.pytestArgs": [
|
| 7 |
+
"src"
|
| 8 |
+
],
|
| 9 |
+
"python.testing.unittestEnabled": false,
|
| 10 |
+
"python.testing.pytestEnabled": true
|
| 11 |
+
}
|
openpi/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to openpi
|
| 2 |
+
|
| 3 |
+
We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance to the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.
|
| 4 |
+
|
| 5 |
+
## Issues and feature requests
|
| 6 |
+
|
| 7 |
+
You are welcome to use the Github [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.
|
| 8 |
+
|
| 9 |
+
If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on Github under Issues). If the issue has not yet been reported, please include this information when filing a Github issue:
|
| 10 |
+
|
| 11 |
+
- Your OS type and version and the version of Python you are using
|
| 12 |
+
- Code that allows us to reproduce your bug, including all dependencies
|
| 13 |
+
- Traceback of any exception
|
| 14 |
+
- Any other information that would help us, such as a screenshot
|
| 15 |
+
|
| 16 |
+
In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.
|
| 17 |
+
|
| 18 |
+
If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:
|
| 19 |
+
|
| 20 |
+
- The motivation for the feature
|
| 21 |
+
- A description of the problem you are trying to solve or your use case
|
| 22 |
+
- Enough information for us to understand the nature of the request
|
| 23 |
+
- Some information for how you intend to use it (this might help us in understanding the motivation!)
|
| 24 |
+
|
| 25 |
+
We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!
|
| 26 |
+
|
| 27 |
+
## Submitting a pull request
|
| 28 |
+
|
| 29 |
+
If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend every contribution to consider the following:
|
| 30 |
+
|
| 31 |
+
- Make sure that your PR has a clear title and description
|
| 32 |
+
- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
|
| 33 |
+
- Make sure your PR passes all tests
|
openpi/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
openpi/README.md
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# openpi
|
| 2 |
+
|
| 3 |
+
openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).
|
| 4 |
+
|
| 5 |
+
Currently, this repo contains three types of models:
|
| 6 |
+
- the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based vision-language-action model (VLA).
|
| 7 |
+
- the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer.
|
| 8 |
+
- the [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05), an upgraded version of π₀ with better open-world generalization trained with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation). Note that, in this repository, we currently only support the flow matching head for both $\pi_{0.5}$ training and inference.
|
| 9 |
+
|
| 10 |
+
For all models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets.
|
| 11 |
+
|
| 12 |
+
This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see!
|
| 13 |
+
|
| 14 |
+
## Updates
|
| 15 |
+
|
| 16 |
+
- [Sept 2025] We released PyTorch support in openpi.
|
| 17 |
+
- [Sept 2025] We released pi05, an upgraded version of pi0 with better open-world generalization.
|
| 18 |
+
- [Sept 2025]: We have added an [improved idle filter](examples/droid/README_train.md#data-filtering) for DROID training.
|
| 19 |
+
- [Jun 2025]: We have added [instructions](examples/droid/README_train.md) for using `openpi` to train VLAs on the full [DROID dataset](https://droid-dataset.github.io/). This is an approximate open-source implementation of the training pipeline used to train pi0-FAST-DROID.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Requirements
|
| 23 |
+
|
| 24 |
+
To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimations assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training.
|
| 25 |
+
|
| 26 |
+
| Mode | Memory Required | Example GPU |
|
| 27 |
+
| ------------------ | --------------- | ------------------ |
|
| 28 |
+
| Inference | > 8 GB | RTX 4090 |
|
| 29 |
+
| Fine-Tuning (LoRA) | > 22.5 GB | RTX 4090 |
|
| 30 |
+
| Fine-Tuning (Full) | > 70 GB | A100 (80GB) / H100 |
|
| 31 |
+
|
| 32 |
+
The repo has been tested with Ubuntu 22.04, we do not currently support other operating systems.
|
| 33 |
+
|
| 34 |
+
## Installation
|
| 35 |
+
|
| 36 |
+
When cloning this repo, make sure to update submodules:
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
|
| 40 |
+
|
| 41 |
+
# Or if you already cloned the repo:
|
| 42 |
+
git submodule update --init --recursive
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
GIT_LFS_SKIP_SMUDGE=1 uv sync
|
| 49 |
+
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
|
| 53 |
+
|
| 54 |
+
**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details.
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
## Model Checkpoints
|
| 60 |
+
|
| 61 |
+
### Base Models
|
| 62 |
+
We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning.
|
| 63 |
+
|
| 64 |
+
| Model | Use Case | Description | Checkpoint Path |
|
| 65 |
+
| ------------ | ----------- | ----------------------------------------------------------------------------------------------------------- | ---------------------------------------------- |
|
| 66 |
+
| $\pi_0$ | Fine-Tuning | Base [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_base` |
|
| 67 |
+
| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `gs://openpi-assets/checkpoints/pi0_fast_base` |
|
| 68 |
+
| $\pi_{0.5}$ | Fine-Tuning | Base [π₀.₅ model](https://www.physicalintelligence.company/blog/pi05) for fine-tuning | `gs://openpi-assets/checkpoints/pi05_base` |
|
| 69 |
+
|
| 70 |
+
### Fine-Tuned Models
|
| 71 |
+
We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice.
|
| 72 |
+
|
| 73 |
+
| Model | Use Case | Description | Checkpoint Path |
|
| 74 |
+
| ------------------------ | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |
|
| 75 |
+
| $\pi_0$-FAST-DROID | Inference | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `gs://openpi-assets/checkpoints/pi0_fast_droid` |
|
| 76 |
+
| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/): faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `gs://openpi-assets/checkpoints/pi0_droid` |
|
| 77 |
+
| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can fold diverse towels 0-shot on ALOHA robot platforms | `gs://openpi-assets/checkpoints/pi0_aloha_towel` |
|
| 78 |
+
| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal [ALOHA](https://tonyzhaozh.github.io/aloha/) data: can unpack food from a tupperware container | `gs://openpi-assets/checkpoints/pi0_aloha_tupperware` |
|
| 79 |
+
| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on public [ALOHA](https://dit-policy.github.io/) data: can uncap a pen | `gs://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
|
| 80 |
+
| $\pi_{0.5}$-LIBERO | Inference | $\pi_{0.5}$ model fine-tuned for the [LIBERO](https://libero-project.github.io/datasets) benchmark: gets state-of-the-art performance (see [LIBERO README](examples/libero/README.md)) | `gs://openpi-assets/checkpoints/pi05_libero` |
|
| 81 |
+
| $\pi_{0.5}$-DROID | Inference / Fine-Tuning | $\pi_{0.5}$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/) with [knowledge insulation](https://www.physicalintelligence.company/research/knowledge_insulation): fast inference and good language-following | `gs://openpi-assets/checkpoints/pi05_droid` |
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
By default, checkpoints are automatically downloaded from `gs://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable.
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
## Running Inference for a Pre-Trained Model
|
| 90 |
+
|
| 91 |
+
Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_0$-FAST-DROID model):
|
| 92 |
+
```python
|
| 93 |
+
from openpi.training import config as _config
|
| 94 |
+
from openpi.policies import policy_config
|
| 95 |
+
from openpi.shared import download
|
| 96 |
+
|
| 97 |
+
config = _config.get_config("pi05_droid")
|
| 98 |
+
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi05_droid")
|
| 99 |
+
|
| 100 |
+
# Create a trained policy.
|
| 101 |
+
policy = policy_config.create_trained_policy(config, checkpoint_dir)
|
| 102 |
+
|
| 103 |
+
# Run inference on a dummy example.
|
| 104 |
+
example = {
|
| 105 |
+
"observation/exterior_image_1_left": ...,
|
| 106 |
+
"observation/wrist_image_left": ...,
|
| 107 |
+
...
|
| 108 |
+
"prompt": "pick up the fork"
|
| 109 |
+
}
|
| 110 |
+
action_chunk = policy.infer(example)["actions"]
|
| 111 |
+
```
|
| 112 |
+
You can also test this out in the [example notebook](examples/inference.ipynb).
|
| 113 |
+
|
| 114 |
+
We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots.
|
| 115 |
+
|
| 116 |
+
**Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate.
|
| 117 |
+
|
| 118 |
+
**Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details.
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
## Fine-Tuning Base Models on Your Own Data
|
| 125 |
+
|
| 126 |
+
We will fine-tune the $\pi_{0.5}$ model on the [LIBERO dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps:
|
| 127 |
+
1. Convert your data to a LeRobot dataset (which we use for training)
|
| 128 |
+
2. Defining training configs and running training
|
| 129 |
+
3. Spinning up a policy server and running inference
|
| 130 |
+
|
| 131 |
+
### 1. Convert your data to a LeRobot dataset
|
| 132 |
+
|
| 133 |
+
We provide a minimal example script for converting LIBERO data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw LIBERO dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with:
|
| 134 |
+
|
| 135 |
+
```bash
|
| 136 |
+
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
**Note:** If you just want to fine-tune on LIBERO, you can skip this step, because our LIBERO fine-tuning configs point to a pre-converted LIBERO dataset. This step is merely an example that you can adapt to your own data.
|
| 140 |
+
|
| 141 |
+
### 2. Defining training configs and running training
|
| 142 |
+
|
| 143 |
+
To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for LIBERO below, which you can modify for your own dataset:
|
| 144 |
+
|
| 145 |
+
- [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the LIBERO environment to the model and vice versa. Will be used for both, training and inference.
|
| 146 |
+
- [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw LIBERO data from LeRobot dataset for training.
|
| 147 |
+
- [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader.
|
| 148 |
+
|
| 149 |
+
We provide example fine-tuning configs for [π₀](src/openpi/training/config.py), [π₀-FAST](src/openpi/training/config.py), and [π₀.₅](src/openpi/training/config.py) on LIBERO data.
|
| 150 |
+
|
| 151 |
+
Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config:
|
| 152 |
+
|
| 153 |
+
```bash
|
| 154 |
+
uv run scripts/compute_norm_stats.py --config-name pi05_libero
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config):
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_libero --exp-name=my_experiment --overwrite
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. For maximally using the GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%).
|
| 164 |
+
|
| 165 |
+
**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.
|
| 166 |
+
|
| 167 |
+
### 3. Spinning up a policy server and running inference
|
| 168 |
+
|
| 169 |
+
Once training is complete, we can run inference by spinning up a policy server and then querying it from a LIBERO evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_libero --policy.dir=checkpoints/pi05_libero/my_experiment/20000
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run an evaluation script (or robot runtime) that queries the server.
|
| 176 |
+
|
| 177 |
+
For running the LIBERO eval in particular, we provide (and recommend using) a Dockerized workflow that handles both the policy server and the evaluation script together. See the [LIBERO README](examples/libero/README.md) for more details.
|
| 178 |
+
|
| 179 |
+
If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
### More Examples
|
| 184 |
+
|
| 185 |
+
We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
|
| 186 |
+
- [ALOHA Simulator](examples/aloha_sim)
|
| 187 |
+
- [ALOHA Real](examples/aloha_real)
|
| 188 |
+
- [UR5](examples/ur5)
|
| 189 |
+
|
| 190 |
+
## PyTorch Support
|
| 191 |
+
|
| 192 |
+
openpi now provides PyTorch implementations of π₀ and π₀.₅ models alongside the original JAX versions! The PyTorch implementation has been validated on the LIBERO benchmark (both inference and finetuning). A few features are currently not supported (this may change in the future):
|
| 193 |
+
|
| 194 |
+
- The π₀-FAST model
|
| 195 |
+
- Mixed precision training
|
| 196 |
+
- FSDP (fully-sharded data parallelism) training
|
| 197 |
+
- LoRA (low-rank adaptation) training
|
| 198 |
+
- EMA (exponential moving average) weights during training
|
| 199 |
+
|
| 200 |
+
### Setup
|
| 201 |
+
1. Make sure that you have the latest version of all dependencies installed: `uv sync`
|
| 202 |
+
|
| 203 |
+
2. Double check that you have transformers 4.53.2 installed: `uv pip show transformers`
|
| 204 |
+
|
| 205 |
+
3. Apply the transformers library patches:
|
| 206 |
+
```bash
|
| 207 |
+
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
This overwrites several files in the transformers library with necessary model changes: 1) supporting AdaRMS, 2) correctly controlling the precision of activations, and 3) allowing the KV cache to be used without being updated.
|
| 211 |
+
|
| 212 |
+
**WARNING**: With the default uv link mode (hardlink), this will permanently affect the transformers library in your uv cache, meaning the changes will survive reinstallations of transformers and could even propagate to other projects that use transformers. To fully undo this operation, you must run `uv cache clean transformers`.
|
| 213 |
+
|
| 214 |
+
### Converting JAX Models to PyTorch
|
| 215 |
+
|
| 216 |
+
To convert a JAX model checkpoint to PyTorch format:
|
| 217 |
+
|
| 218 |
+
```bash
|
| 219 |
+
uv run examples/convert_jax_model_to_pytorch.py \
|
| 220 |
+
--checkpoint_dir /path/to/jax/checkpoint \
|
| 221 |
+
--config_name <config name> \
|
| 222 |
+
--output_path /path/to/converted/pytorch/checkpoint
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
### Running Inference with PyTorch
|
| 226 |
+
|
| 227 |
+
The PyTorch implementation uses the same API as the JAX version - you only need to change the checkpoint path to point to the converted PyTorch model:
|
| 228 |
+
|
| 229 |
+
```python
|
| 230 |
+
from openpi.training import config as _config
|
| 231 |
+
from openpi.policies import policy_config
|
| 232 |
+
from openpi.shared import download
|
| 233 |
+
|
| 234 |
+
config = _config.get_config("pi05_droid")
|
| 235 |
+
checkpoint_dir = "/path/to/converted/pytorch/checkpoint"
|
| 236 |
+
|
| 237 |
+
# Create a trained policy (automatically detects PyTorch format)
|
| 238 |
+
policy = policy_config.create_trained_policy(config, checkpoint_dir)
|
| 239 |
+
|
| 240 |
+
# Run inference (same API as JAX)
|
| 241 |
+
action_chunk = policy.infer(example)["actions"]
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
### Policy Server with PyTorch
|
| 245 |
+
|
| 246 |
+
The policy server works identically with PyTorch models - just point to the converted checkpoint directory:
|
| 247 |
+
|
| 248 |
+
```bash
|
| 249 |
+
uv run scripts/serve_policy.py policy:checkpoint \
|
| 250 |
+
--policy.config=pi05_droid \
|
| 251 |
+
--policy.dir=/path/to/converted/pytorch/checkpoint
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
### Finetuning with PyTorch
|
| 255 |
+
|
| 256 |
+
To finetune a model in PyTorch:
|
| 257 |
+
|
| 258 |
+
1. Convert the JAX base model to PyTorch format:
|
| 259 |
+
```bash
|
| 260 |
+
uv run examples/convert_jax_model_to_pytorch.py \
|
| 261 |
+
--config_name <config name> \
|
| 262 |
+
--checkpoint_dir /path/to/jax/base/model \
|
| 263 |
+
--output_path /path/to/pytorch/base/model
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
2. Specify the converted PyTorch model path in your config using `pytorch_weight_path`
|
| 267 |
+
|
| 268 |
+
3. Launch training using one of these modes:
|
| 269 |
+
|
| 270 |
+
```bash
|
| 271 |
+
# Single GPU training:
|
| 272 |
+
uv run scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
|
| 273 |
+
|
| 274 |
+
# Example:
|
| 275 |
+
uv run scripts/train_pytorch.py debug --exp_name pytorch_test
|
| 276 |
+
uv run scripts/train_pytorch.py debug --exp_name pytorch_test --resume # Resume from latest checkpoint
|
| 277 |
+
|
| 278 |
+
# Multi-GPU training (single node):
|
| 279 |
+
uv run torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
|
| 280 |
+
|
| 281 |
+
# Example:
|
| 282 |
+
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
|
| 283 |
+
uv run torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
|
| 284 |
+
|
| 285 |
+
# Multi-Node Training:
|
| 286 |
+
uv run torchrun \
|
| 287 |
+
--nnodes=<num_nodes> \
|
| 288 |
+
--nproc_per_node=<gpus_per_node> \
|
| 289 |
+
--node_rank=<rank_of_node> \
|
| 290 |
+
--master_addr=<master_ip> \
|
| 291 |
+
--master_port=<port> \
|
| 292 |
+
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
### Precision Settings
|
| 296 |
+
|
| 297 |
+
JAX and PyTorch implementations handle precision as follows:
|
| 298 |
+
|
| 299 |
+
**JAX:**
|
| 300 |
+
1. Inference: most weights and computations in bfloat16, with a few computations in float32 for stability
|
| 301 |
+
2. Training: defaults to mixed precision: weights and gradients in float32, (most) activations and computations in bfloat16. You can change to full float32 training by setting `dtype` to float32 in the config.
|
| 302 |
+
|
| 303 |
+
**PyTorch:**
|
| 304 |
+
1. Inference: matches JAX -- most weights and computations in bfloat16, with a few weights converted to float32 for stability
|
| 305 |
+
2. Training: supports either full bfloat16 (default) or full float32. You can change it by setting `pytorch_training_precision` in the config. bfloat16 uses less memory but exhibits higher losses compared to float32. Mixed precision is not yet supported.
|
| 306 |
+
|
| 307 |
+
With torch.compile, inference speed is comparable between JAX and PyTorch.
|
| 308 |
+
|
| 309 |
+
## Troubleshooting
|
| 310 |
+
|
| 311 |
+
We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).
|
| 312 |
+
|
| 313 |
+
| Issue | Resolution |
|
| 314 |
+
| ----------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| 315 |
+
| `uv sync` fails with dependency conflicts | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). |
|
| 316 |
+
| Training runs out of GPU memory | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` (or higher) before running training to allow JAX to use more GPU memory. You can also use `--fsdp-devices <n>` where `<n>` is your number of GPUs, to enable [fully-sharded data parallelism](https://engineering.fb.com/2021/07/15/open-source/fsdp/), which reduces memory usage in exchange for slower training (the amount of slowdown depends on your particular setup). If you are still running out of memory, you may way to consider disabling EMA. |
|
| 317 |
+
| Policy server connection errors | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server. |
|
| 318 |
+
| Missing norm stats error when training | Run `scripts/compute_norm_stats.py` with your config name before starting training. |
|
| 319 |
+
| Dataset download fails | Check your internet connection. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`). |
|
| 320 |
+
| CUDA/GPU errors | Verify NVIDIA drivers are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility. You do NOT need CUDA libraries installed at a system level --- they will be installed via uv. You may even want to try *uninstalling* system CUDA libraries if you run into CUDA issues, since system libraries can sometimes cause conflicts. |
|
| 321 |
+
| Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
|
| 322 |
+
| Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
|
| 323 |
+
| Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
|
openpi/config.json
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"type": "pi0",
|
| 3 |
+
"n_obs_steps": 1,
|
| 4 |
+
"input_features": {
|
| 5 |
+
"observation.state": {
|
| 6 |
+
"type": "STATE",
|
| 7 |
+
"shape": [
|
| 8 |
+
6
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
"observation.images.camera0": {
|
| 12 |
+
"type": "VISUAL",
|
| 13 |
+
"shape": [
|
| 14 |
+
3,
|
| 15 |
+
480,
|
| 16 |
+
640
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
"observation.images.camera1": {
|
| 20 |
+
"type": "VISUAL",
|
| 21 |
+
"shape": [
|
| 22 |
+
3,
|
| 23 |
+
480,
|
| 24 |
+
640
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
"observation.images.camera2": {
|
| 28 |
+
"type": "VISUAL",
|
| 29 |
+
"shape": [
|
| 30 |
+
3,
|
| 31 |
+
480,
|
| 32 |
+
640
|
| 33 |
+
]
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"output_features": {
|
| 37 |
+
"action": {
|
| 38 |
+
"type": "ACTION",
|
| 39 |
+
"shape": [
|
| 40 |
+
6
|
| 41 |
+
]
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"device": "cpu",
|
| 45 |
+
"use_amp": false,
|
| 46 |
+
"push_to_hub": true,
|
| 47 |
+
"repo_id": null,
|
| 48 |
+
"private": null,
|
| 49 |
+
"tags": null,
|
| 50 |
+
"license": null,
|
| 51 |
+
"chunk_size": 50,
|
| 52 |
+
"n_action_steps": 50,
|
| 53 |
+
"normalization_mapping": {
|
| 54 |
+
"VISUAL": "IDENTITY",
|
| 55 |
+
"STATE": "MEAN_STD",
|
| 56 |
+
"ACTION": "MEAN_STD"
|
| 57 |
+
},
|
| 58 |
+
"max_state_dim": 32,
|
| 59 |
+
"max_action_dim": 32,
|
| 60 |
+
"resize_imgs_with_padding": [
|
| 61 |
+
224,
|
| 62 |
+
224
|
| 63 |
+
],
|
| 64 |
+
"empty_cameras": 0,
|
| 65 |
+
"adapt_to_pi_aloha": false,
|
| 66 |
+
"use_delta_joint_actions_aloha": false,
|
| 67 |
+
"tokenizer_max_length": 48,
|
| 68 |
+
"proj_width": 1024,
|
| 69 |
+
"num_steps": 10,
|
| 70 |
+
"use_cache": true,
|
| 71 |
+
"attention_implementation": "eager",
|
| 72 |
+
"freeze_vision_encoder": true,
|
| 73 |
+
"train_expert_only": false,
|
| 74 |
+
"train_state_proj": true,
|
| 75 |
+
"optimizer_lr": 2.5e-05,
|
| 76 |
+
"optimizer_betas": [
|
| 77 |
+
0.9,
|
| 78 |
+
0.95
|
| 79 |
+
],
|
| 80 |
+
"optimizer_eps": 1e-08,
|
| 81 |
+
"optimizer_weight_decay": 1e-10,
|
| 82 |
+
"scheduler_warmup_steps": 1000,
|
| 83 |
+
"scheduler_decay_steps": 30000,
|
| 84 |
+
"scheduler_decay_lr": 2.5e-06
|
| 85 |
+
}
|
openpi/docs/docker.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### Docker Setup
|
| 2 |
+
|
| 3 |
+
All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended as this will simplify software installation, produce a more stable environment, and also allow you to avoid installing ROS and cluttering your machine, for examples which depend on ROS.
|
| 4 |
+
|
| 5 |
+
- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
|
| 6 |
+
- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
|
| 7 |
+
- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
|
| 8 |
+
- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
|
| 9 |
+
- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
If starting from scratch and your host machine is Ubuntu 22.04, you can use accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
|
| 13 |
+
|
| 14 |
+
Build the Docker image and start the container with the following command:
|
| 15 |
+
```bash
|
| 16 |
+
docker compose -f scripts/docker/compose.yml up --build
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
To build and run the Docker image for a specific example, use the following command:
|
| 20 |
+
```bash
|
| 21 |
+
docker compose -f examples/<example_name>/compose.yml up --build
|
| 22 |
+
```
|
| 23 |
+
where `<example_name>` is the name of the example you want to run.
|
| 24 |
+
|
| 25 |
+
During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
|
openpi/docs/norm_stats.md
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Normalization statistics
|
| 2 |
+
|
| 3 |
+
Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
|
| 4 |
+
|
| 5 |
+
## Reloading normalization statistics
|
| 6 |
+
|
| 7 |
+
When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
|
| 8 |
+
|
| 9 |
+
**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
|
| 10 |
+
|
| 11 |
+
```python
|
| 12 |
+
TrainConfig(
|
| 13 |
+
...
|
| 14 |
+
data=LeRobotAlohaDataConfig(
|
| 15 |
+
...
|
| 16 |
+
assets=AssetsConfig(
|
| 17 |
+
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
|
| 18 |
+
asset_id="trossen",
|
| 19 |
+
),
|
| 20 |
+
),
|
| 21 |
+
)
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
|
| 25 |
+
|
| 26 |
+
**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
|
| 27 |
+
|
| 28 |
+
**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
## Provided Pre-training Normalization Statistics
|
| 32 |
+
|
| 33 |
+
Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
|
| 34 |
+
| Robot | Description | Asset ID |
|
| 35 |
+
|-------|-------------|----------|
|
| 36 |
+
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
|
| 37 |
+
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
|
| 38 |
+
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
|
| 39 |
+
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
|
| 40 |
+
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
|
| 41 |
+
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
|
| 42 |
+
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
|
| 43 |
+
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
|
| 44 |
+
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
## Pi0 Model Action Space Definitions
|
| 48 |
+
|
| 49 |
+
Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
|
| 50 |
+
```
|
| 51 |
+
"dim_0:dim_5": "left arm joint angles",
|
| 52 |
+
"dim_6": "left arm gripper position",
|
| 53 |
+
"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
|
| 54 |
+
"dim_13": "right arm gripper position (for bi-manual only)",
|
| 55 |
+
|
| 56 |
+
# For mobile robots:
|
| 57 |
+
"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
|
| 61 |
+
|
| 62 |
+
For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
|
| 63 |
+
|
| 64 |
+
General info for Pi robots:
|
| 65 |
+
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
|
| 66 |
+
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
|
| 67 |
+
- Control frequencies are either 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms.
|
| 68 |
+
|
| 69 |
+
For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
|
openpi/docs/remote_inference.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Running openpi models remotely
|
| 3 |
+
|
| 4 |
+
We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
|
| 5 |
+
|
| 6 |
+
## Starting a remote policy server
|
| 7 |
+
|
| 8 |
+
To start a remote policy server, you can simply run the following command:
|
| 9 |
+
|
| 10 |
+
```bash
|
| 11 |
+
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
|
| 21 |
+
|
| 22 |
+
## Querying the remote policy server from your robot code
|
| 23 |
+
|
| 24 |
+
We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
|
| 25 |
+
|
| 26 |
+
First, install the `openpi-client` package in your robot environment:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
cd $OPENPI_ROOT/packages/openpi-client
|
| 30 |
+
pip install -e .
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
from openpi_client import image_tools
|
| 37 |
+
from openpi_client import websocket_client_policy
|
| 38 |
+
|
| 39 |
+
# Outside of episode loop, initialize the policy client.
|
| 40 |
+
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
|
| 41 |
+
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
|
| 42 |
+
|
| 43 |
+
for step in range(num_steps):
|
| 44 |
+
# Inside the episode loop, construct the observation.
|
| 45 |
+
# Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
|
| 46 |
+
# We provide utilities for resizing images + uint8 conversion so you match the training routines.
|
| 47 |
+
# The typical resize_size for pre-trained pi0 models is 224.
|
| 48 |
+
# Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
|
| 49 |
+
observation = {
|
| 50 |
+
"observation/image": image_tools.convert_to_uint8(
|
| 51 |
+
image_tools.resize_with_pad(img, 224, 224)
|
| 52 |
+
),
|
| 53 |
+
"observation/wrist_image": image_tools.convert_to_uint8(
|
| 54 |
+
image_tools.resize_with_pad(wrist_img, 224, 224)
|
| 55 |
+
),
|
| 56 |
+
"observation/state": state,
|
| 57 |
+
"prompt": task_instruction,
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
# Call the policy server with the current observation.
|
| 61 |
+
# This returns an action chunk of shape (action_horizon, action_dim).
|
| 62 |
+
# Note that you typically only need to call the policy every N steps and execute steps
|
| 63 |
+
# from the predicted action chunk open-loop in the remaining steps.
|
| 64 |
+
action_chunk = client.infer(observation)["actions"]
|
| 65 |
+
|
| 66 |
+
# Execute the actions in the environment.
|
| 67 |
+
...
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
|
openpi/examples/aloha_real/Dockerfile
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for the Aloha real environment.
|
| 2 |
+
|
| 3 |
+
# Build the container:
|
| 4 |
+
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
|
| 5 |
+
|
| 6 |
+
# Run the container:
|
| 7 |
+
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
|
| 8 |
+
|
| 9 |
+
FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
|
| 10 |
+
SHELL ["/bin/bash", "-c"]
|
| 11 |
+
|
| 12 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 13 |
+
RUN apt-get update && \
|
| 14 |
+
apt-get install -y --no-install-recommends \
|
| 15 |
+
cmake \
|
| 16 |
+
curl \
|
| 17 |
+
libffi-dev \
|
| 18 |
+
python3-rosdep \
|
| 19 |
+
python3-rosinstall \
|
| 20 |
+
python3-rosinstall-generator \
|
| 21 |
+
whiptail \
|
| 22 |
+
git \
|
| 23 |
+
wget \
|
| 24 |
+
openssh-client \
|
| 25 |
+
ros-noetic-cv-bridge \
|
| 26 |
+
ros-noetic-usb-cam \
|
| 27 |
+
ros-noetic-realsense2-camera \
|
| 28 |
+
keyboard-configuration
|
| 29 |
+
|
| 30 |
+
WORKDIR /root
|
| 31 |
+
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
|
| 32 |
+
RUN chmod +x xsarm_amd64_install.sh
|
| 33 |
+
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
|
| 34 |
+
|
| 35 |
+
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
|
| 36 |
+
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
|
| 37 |
+
|
| 38 |
+
# Install python 3.10 because this ROS image comes with 3.8
|
| 39 |
+
RUN mkdir /python && \
|
| 40 |
+
cd /python && \
|
| 41 |
+
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
|
| 42 |
+
tar -zxvf Python-3.10.14.tgz && \
|
| 43 |
+
cd Python-3.10.14 && \
|
| 44 |
+
ls -lhR && \
|
| 45 |
+
./configure --enable-optimizations && \
|
| 46 |
+
make install && \
|
| 47 |
+
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
| 48 |
+
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
| 49 |
+
cd ~ && rm -rf /python && \
|
| 50 |
+
rm -rf /var/lib/apt/lists/*
|
| 51 |
+
|
| 52 |
+
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
|
| 53 |
+
ENV UV_HTTP_TIMEOUT=120
|
| 54 |
+
ENV UV_LINK_MODE=copy
|
| 55 |
+
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
|
| 56 |
+
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
| 57 |
+
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
| 58 |
+
|
| 59 |
+
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Create an entrypoint script to run the setup commands, followed by the command passed in.
|
| 63 |
+
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
|
| 64 |
+
#!/bin/bash
|
| 65 |
+
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
|
| 66 |
+
EOF
|
| 67 |
+
RUN chmod +x /usr/local/bin/entrypoint.sh
|
| 68 |
+
|
| 69 |
+
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
| 70 |
+
CMD ["python3", "/app/examples/aloha_real/main.py"]
|
openpi/examples/aloha_real/README.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run Aloha (Real Robot)
|
| 2 |
+
|
| 3 |
+
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
|
| 8 |
+
|
| 9 |
+
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
|
| 10 |
+
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
|
| 11 |
+
|
| 12 |
+
## With Docker
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
|
| 16 |
+
docker compose -f examples/aloha_real/compose.yml up --build
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Without Docker
|
| 20 |
+
|
| 21 |
+
Terminal window 1:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
# Create virtual environment
|
| 25 |
+
uv venv --python 3.10 examples/aloha_real/.venv
|
| 26 |
+
source examples/aloha_real/.venv/bin/activate
|
| 27 |
+
uv pip sync examples/aloha_real/requirements.txt
|
| 28 |
+
uv pip install -e packages/openpi-client
|
| 29 |
+
|
| 30 |
+
# Run the robot
|
| 31 |
+
python -m examples.aloha_real.main
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Terminal window 2:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
roslaunch aloha ros_nodes.launch
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Terminal window 3:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## **ALOHA Checkpoint Guide**
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
|
| 50 |
+
|
| 51 |
+
While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
---
|
| 55 |
+
|
| 56 |
+
### **Toast Task**
|
| 57 |
+
|
| 58 |
+
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
|
| 59 |
+
|
| 60 |
+
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
|
| 61 |
+
- **Prompt**: "take the toast out of the toaster"
|
| 62 |
+
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
|
| 63 |
+
- **Object Distribution**:
|
| 64 |
+
- Works on both real toast and rubber fake toast
|
| 65 |
+
- Compatible with standard 2-slice toasters
|
| 66 |
+
- Works with plates of varying colors
|
| 67 |
+
|
| 68 |
+
### **Scene Setup Guidelines**
|
| 69 |
+
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
|
| 70 |
+
|
| 71 |
+
- The toaster should be positioned in the top-left quadrant of the workspace.
|
| 72 |
+
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
|
| 73 |
+
- The plate should be placed roughly in the lower-center of the workspace.
|
| 74 |
+
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
### **Towel Task**
|
| 78 |
+
|
| 79 |
+
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
|
| 80 |
+
|
| 81 |
+
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
|
| 82 |
+
- **Prompt**: "fold the towel"
|
| 83 |
+
- **Object Distribution**:
|
| 84 |
+
- Works on towels of varying solid colors
|
| 85 |
+
- Performance is worse on heavily textured or striped towels
|
| 86 |
+
|
| 87 |
+
### **Scene Setup Guidelines**
|
| 88 |
+
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
|
| 89 |
+
|
| 90 |
+
- The towel should be flattened and roughly centered on the table.
|
| 91 |
+
- Choose a towel that does not blend in with the table surface.
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
### **Tupperware Task**
|
| 95 |
+
|
| 96 |
+
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
|
| 97 |
+
|
| 98 |
+
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
|
| 99 |
+
- **Prompt**: "open the tupperware and put the food on the plate"
|
| 100 |
+
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
|
| 101 |
+
- **Object Distribution**:
|
| 102 |
+
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
|
| 103 |
+
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
|
| 104 |
+
- The policy has seen plates of varying solid colors.
|
| 105 |
+
|
| 106 |
+
### **Scene Setup Guidelines**
|
| 107 |
+
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
|
| 108 |
+
|
| 109 |
+
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
|
| 110 |
+
- Positioning:
|
| 111 |
+
- Tupperware should be on the left.
|
| 112 |
+
- Plate should be on the right or bottom.
|
| 113 |
+
- The tupperware flap should point toward the plate.
|
| 114 |
+
|
| 115 |
+
## Training on your own Aloha dataset
|
| 116 |
+
|
| 117 |
+
1. Convert the dataset to the LeRobot dataset v2.0 format.
|
| 118 |
+
|
| 119 |
+
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
2. Define a training config that uses the custom dataset.
|
| 123 |
+
|
| 124 |
+
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
|
| 125 |
+
|
| 126 |
+
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
|
openpi/examples/aloha_real/compose.yml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run with:
|
| 2 |
+
# docker compose -f examples/aloha_real/compose.yml up --build
|
| 3 |
+
services:
|
| 4 |
+
runtime:
|
| 5 |
+
image: aloha_real
|
| 6 |
+
depends_on:
|
| 7 |
+
- aloha_ros_nodes
|
| 8 |
+
- ros_master
|
| 9 |
+
- openpi_server
|
| 10 |
+
build:
|
| 11 |
+
context: ../..
|
| 12 |
+
dockerfile: examples/aloha_real/Dockerfile
|
| 13 |
+
init: true
|
| 14 |
+
tty: true
|
| 15 |
+
network_mode: host
|
| 16 |
+
privileged: true
|
| 17 |
+
volumes:
|
| 18 |
+
- $PWD:/app
|
| 19 |
+
- ../../data:/data
|
| 20 |
+
|
| 21 |
+
aloha_ros_nodes:
|
| 22 |
+
image: aloha_real
|
| 23 |
+
depends_on:
|
| 24 |
+
- ros_master
|
| 25 |
+
build:
|
| 26 |
+
context: ../..
|
| 27 |
+
dockerfile: examples/aloha_real/Dockerfile
|
| 28 |
+
init: true
|
| 29 |
+
tty: true
|
| 30 |
+
network_mode: host
|
| 31 |
+
privileged: true
|
| 32 |
+
volumes:
|
| 33 |
+
- /dev:/dev
|
| 34 |
+
command: roslaunch --wait aloha ros_nodes.launch
|
| 35 |
+
|
| 36 |
+
ros_master:
|
| 37 |
+
image: ros:noetic-robot
|
| 38 |
+
network_mode: host
|
| 39 |
+
privileged: true
|
| 40 |
+
command:
|
| 41 |
+
- roscore
|
| 42 |
+
|
| 43 |
+
openpi_server:
|
| 44 |
+
image: openpi_server
|
| 45 |
+
build:
|
| 46 |
+
context: ../..
|
| 47 |
+
dockerfile: scripts/docker/serve_policy.Dockerfile
|
| 48 |
+
init: true
|
| 49 |
+
tty: true
|
| 50 |
+
network_mode: host
|
| 51 |
+
volumes:
|
| 52 |
+
- $PWD:/app
|
| 53 |
+
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
| 54 |
+
environment:
|
| 55 |
+
- SERVER_ARGS
|
| 56 |
+
- OPENPI_DATA_HOME=/openpi_assets
|
| 57 |
+
- IS_DOCKER=true
|
| 58 |
+
|
| 59 |
+
# Comment out this block if not running on a machine with GPUs.
|
| 60 |
+
deploy:
|
| 61 |
+
resources:
|
| 62 |
+
reservations:
|
| 63 |
+
devices:
|
| 64 |
+
- driver: nvidia
|
| 65 |
+
count: 1
|
| 66 |
+
capabilities: [gpu]
|
openpi/examples/aloha_real/constants.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
| 2 |
+
# ruff: noqa
|
| 3 |
+
|
| 4 |
+
### Task parameters
|
| 5 |
+
|
| 6 |
+
### ALOHA fixed constants
|
| 7 |
+
DT = 0.001
|
| 8 |
+
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
| 9 |
+
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
| 10 |
+
|
| 11 |
+
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
| 12 |
+
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
| 13 |
+
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
| 14 |
+
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
| 15 |
+
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
| 16 |
+
|
| 17 |
+
# Gripper joint limits (qpos[6])
|
| 18 |
+
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
| 19 |
+
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
| 20 |
+
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
| 21 |
+
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
| 22 |
+
|
| 23 |
+
############################ Helper functions ############################
|
| 24 |
+
|
| 25 |
+
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
| 26 |
+
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
| 27 |
+
)
|
| 28 |
+
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
| 29 |
+
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
| 30 |
+
)
|
| 31 |
+
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
| 32 |
+
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
| 33 |
+
)
|
| 34 |
+
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
| 35 |
+
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
| 36 |
+
)
|
| 37 |
+
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
| 38 |
+
|
| 39 |
+
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
|
| 40 |
+
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
|
| 41 |
+
)
|
| 42 |
+
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
|
| 43 |
+
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
|
| 44 |
+
)
|
| 45 |
+
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
| 46 |
+
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
| 47 |
+
)
|
| 48 |
+
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
| 49 |
+
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
| 50 |
+
)
|
| 51 |
+
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
| 52 |
+
|
| 53 |
+
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
| 54 |
+
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
| 55 |
+
|
| 56 |
+
MASTER_POS2JOINT = (
|
| 57 |
+
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
| 58 |
+
+ MASTER_GRIPPER_JOINT_CLOSE
|
| 59 |
+
)
|
| 60 |
+
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
|
| 61 |
+
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
| 62 |
+
)
|
| 63 |
+
PUPPET_POS2JOINT = (
|
| 64 |
+
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
| 65 |
+
+ PUPPET_GRIPPER_JOINT_CLOSE
|
| 66 |
+
)
|
| 67 |
+
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
|
| 68 |
+
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
openpi/examples/aloha_real/convert_aloha_data_to_lerobot.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
| 3 |
+
|
| 4 |
+
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import dataclasses
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import shutil
|
| 10 |
+
from typing import Literal
|
| 11 |
+
|
| 12 |
+
import h5py
|
| 13 |
+
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
| 14 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 15 |
+
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import tqdm
|
| 19 |
+
import tyro
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclasses.dataclass(frozen=True)
|
| 23 |
+
class DatasetConfig:
|
| 24 |
+
use_videos: bool = True
|
| 25 |
+
tolerance_s: float = 0.0001
|
| 26 |
+
image_writer_processes: int = 10
|
| 27 |
+
image_writer_threads: int = 5
|
| 28 |
+
video_backend: str | None = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DEFAULT_DATASET_CONFIG = DatasetConfig()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_empty_dataset(
|
| 35 |
+
repo_id: str,
|
| 36 |
+
robot_type: str,
|
| 37 |
+
mode: Literal["video", "image"] = "video",
|
| 38 |
+
*,
|
| 39 |
+
has_velocity: bool = False,
|
| 40 |
+
has_effort: bool = False,
|
| 41 |
+
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
| 42 |
+
) -> LeRobotDataset:
|
| 43 |
+
motors = [
|
| 44 |
+
"right_waist",
|
| 45 |
+
"right_shoulder",
|
| 46 |
+
"right_elbow",
|
| 47 |
+
"right_forearm_roll",
|
| 48 |
+
"right_wrist_angle",
|
| 49 |
+
"right_wrist_rotate",
|
| 50 |
+
"right_gripper",
|
| 51 |
+
"left_waist",
|
| 52 |
+
"left_shoulder",
|
| 53 |
+
"left_elbow",
|
| 54 |
+
"left_forearm_roll",
|
| 55 |
+
"left_wrist_angle",
|
| 56 |
+
"left_wrist_rotate",
|
| 57 |
+
"left_gripper",
|
| 58 |
+
]
|
| 59 |
+
cameras = [
|
| 60 |
+
"cam_high",
|
| 61 |
+
"cam_low",
|
| 62 |
+
"cam_left_wrist",
|
| 63 |
+
"cam_right_wrist",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
features = {
|
| 67 |
+
"observation.state": {
|
| 68 |
+
"dtype": "float32",
|
| 69 |
+
"shape": (len(motors),),
|
| 70 |
+
"names": [
|
| 71 |
+
motors,
|
| 72 |
+
],
|
| 73 |
+
},
|
| 74 |
+
"action": {
|
| 75 |
+
"dtype": "float32",
|
| 76 |
+
"shape": (len(motors),),
|
| 77 |
+
"names": [
|
| 78 |
+
motors,
|
| 79 |
+
],
|
| 80 |
+
},
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
if has_velocity:
|
| 84 |
+
features["observation.velocity"] = {
|
| 85 |
+
"dtype": "float32",
|
| 86 |
+
"shape": (len(motors),),
|
| 87 |
+
"names": [
|
| 88 |
+
motors,
|
| 89 |
+
],
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
if has_effort:
|
| 93 |
+
features["observation.effort"] = {
|
| 94 |
+
"dtype": "float32",
|
| 95 |
+
"shape": (len(motors),),
|
| 96 |
+
"names": [
|
| 97 |
+
motors,
|
| 98 |
+
],
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
for cam in cameras:
|
| 102 |
+
features[f"observation.images.{cam}"] = {
|
| 103 |
+
"dtype": mode,
|
| 104 |
+
"shape": (3, 480, 640),
|
| 105 |
+
"names": [
|
| 106 |
+
"channels",
|
| 107 |
+
"height",
|
| 108 |
+
"width",
|
| 109 |
+
],
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
if Path(LEROBOT_HOME / repo_id).exists():
|
| 113 |
+
shutil.rmtree(LEROBOT_HOME / repo_id)
|
| 114 |
+
|
| 115 |
+
return LeRobotDataset.create(
|
| 116 |
+
repo_id=repo_id,
|
| 117 |
+
fps=50,
|
| 118 |
+
robot_type=robot_type,
|
| 119 |
+
features=features,
|
| 120 |
+
use_videos=dataset_config.use_videos,
|
| 121 |
+
tolerance_s=dataset_config.tolerance_s,
|
| 122 |
+
image_writer_processes=dataset_config.image_writer_processes,
|
| 123 |
+
image_writer_threads=dataset_config.image_writer_threads,
|
| 124 |
+
video_backend=dataset_config.video_backend,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_cameras(hdf5_files: list[Path]) -> list[str]:
|
| 129 |
+
with h5py.File(hdf5_files[0], "r") as ep:
|
| 130 |
+
# ignore depth channel, not currently handled
|
| 131 |
+
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def has_velocity(hdf5_files: list[Path]) -> bool:
|
| 135 |
+
with h5py.File(hdf5_files[0], "r") as ep:
|
| 136 |
+
return "/observations/qvel" in ep
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def has_effort(hdf5_files: list[Path]) -> bool:
|
| 140 |
+
with h5py.File(hdf5_files[0], "r") as ep:
|
| 141 |
+
return "/observations/effort" in ep
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
|
| 145 |
+
imgs_per_cam = {}
|
| 146 |
+
for camera in cameras:
|
| 147 |
+
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
|
| 148 |
+
|
| 149 |
+
if uncompressed:
|
| 150 |
+
# load all images in RAM
|
| 151 |
+
imgs_array = ep[f"/observations/images/{camera}"][:]
|
| 152 |
+
else:
|
| 153 |
+
import cv2
|
| 154 |
+
|
| 155 |
+
# load one compressed image after the other in RAM and uncompress
|
| 156 |
+
imgs_array = []
|
| 157 |
+
for data in ep[f"/observations/images/{camera}"]:
|
| 158 |
+
imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
|
| 159 |
+
imgs_array = np.array(imgs_array)
|
| 160 |
+
|
| 161 |
+
imgs_per_cam[camera] = imgs_array
|
| 162 |
+
return imgs_per_cam
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_raw_episode_data(
|
| 166 |
+
ep_path: Path,
|
| 167 |
+
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
| 168 |
+
with h5py.File(ep_path, "r") as ep:
|
| 169 |
+
state = torch.from_numpy(ep["/observations/qpos"][:])
|
| 170 |
+
action = torch.from_numpy(ep["/action"][:])
|
| 171 |
+
|
| 172 |
+
velocity = None
|
| 173 |
+
if "/observations/qvel" in ep:
|
| 174 |
+
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
| 175 |
+
|
| 176 |
+
effort = None
|
| 177 |
+
if "/observations/effort" in ep:
|
| 178 |
+
effort = torch.from_numpy(ep["/observations/effort"][:])
|
| 179 |
+
|
| 180 |
+
imgs_per_cam = load_raw_images_per_camera(
|
| 181 |
+
ep,
|
| 182 |
+
[
|
| 183 |
+
"cam_high",
|
| 184 |
+
"cam_low",
|
| 185 |
+
"cam_left_wrist",
|
| 186 |
+
"cam_right_wrist",
|
| 187 |
+
],
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
return imgs_per_cam, state, action, velocity, effort
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def populate_dataset(
|
| 194 |
+
dataset: LeRobotDataset,
|
| 195 |
+
hdf5_files: list[Path],
|
| 196 |
+
task: str,
|
| 197 |
+
episodes: list[int] | None = None,
|
| 198 |
+
) -> LeRobotDataset:
|
| 199 |
+
if episodes is None:
|
| 200 |
+
episodes = range(len(hdf5_files))
|
| 201 |
+
|
| 202 |
+
for ep_idx in tqdm.tqdm(episodes):
|
| 203 |
+
ep_path = hdf5_files[ep_idx]
|
| 204 |
+
|
| 205 |
+
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
|
| 206 |
+
num_frames = state.shape[0]
|
| 207 |
+
|
| 208 |
+
for i in range(num_frames):
|
| 209 |
+
frame = {
|
| 210 |
+
"observation.state": state[i],
|
| 211 |
+
"action": action[i],
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
for camera, img_array in imgs_per_cam.items():
|
| 215 |
+
frame[f"observation.images.{camera}"] = img_array[i]
|
| 216 |
+
|
| 217 |
+
if velocity is not None:
|
| 218 |
+
frame["observation.velocity"] = velocity[i]
|
| 219 |
+
if effort is not None:
|
| 220 |
+
frame["observation.effort"] = effort[i]
|
| 221 |
+
|
| 222 |
+
dataset.add_frame(frame)
|
| 223 |
+
|
| 224 |
+
dataset.save_episode(task=task)
|
| 225 |
+
|
| 226 |
+
return dataset
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def port_aloha(
|
| 230 |
+
raw_dir: Path,
|
| 231 |
+
repo_id: str,
|
| 232 |
+
raw_repo_id: str | None = None,
|
| 233 |
+
task: str = "DEBUG",
|
| 234 |
+
*,
|
| 235 |
+
episodes: list[int] | None = None,
|
| 236 |
+
push_to_hub: bool = True,
|
| 237 |
+
is_mobile: bool = False,
|
| 238 |
+
mode: Literal["video", "image"] = "image",
|
| 239 |
+
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
| 240 |
+
):
|
| 241 |
+
if (LEROBOT_HOME / repo_id).exists():
|
| 242 |
+
shutil.rmtree(LEROBOT_HOME / repo_id)
|
| 243 |
+
|
| 244 |
+
if not raw_dir.exists():
|
| 245 |
+
if raw_repo_id is None:
|
| 246 |
+
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
|
| 247 |
+
download_raw(raw_dir, repo_id=raw_repo_id)
|
| 248 |
+
|
| 249 |
+
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
| 250 |
+
|
| 251 |
+
dataset = create_empty_dataset(
|
| 252 |
+
repo_id,
|
| 253 |
+
robot_type="mobile_aloha" if is_mobile else "aloha",
|
| 254 |
+
mode=mode,
|
| 255 |
+
has_effort=has_effort(hdf5_files),
|
| 256 |
+
has_velocity=has_velocity(hdf5_files),
|
| 257 |
+
dataset_config=dataset_config,
|
| 258 |
+
)
|
| 259 |
+
dataset = populate_dataset(
|
| 260 |
+
dataset,
|
| 261 |
+
hdf5_files,
|
| 262 |
+
task=task,
|
| 263 |
+
episodes=episodes,
|
| 264 |
+
)
|
| 265 |
+
dataset.consolidate()
|
| 266 |
+
|
| 267 |
+
if push_to_hub:
|
| 268 |
+
dataset.push_to_hub()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
tyro.cli(port_aloha)
|
openpi/examples/aloha_real/env.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional # noqa: UP035
|
| 2 |
+
|
| 3 |
+
import einops
|
| 4 |
+
from openpi_client import image_tools
|
| 5 |
+
from openpi_client.runtime import environment as _environment
|
| 6 |
+
from typing_extensions import override
|
| 7 |
+
|
| 8 |
+
from examples.aloha_real import real_env as _real_env
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AlohaRealEnvironment(_environment.Environment):
|
| 12 |
+
"""An environment for an Aloha robot on real hardware."""
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
|
| 17 |
+
render_height: int = 224,
|
| 18 |
+
render_width: int = 224,
|
| 19 |
+
) -> None:
|
| 20 |
+
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
|
| 21 |
+
self._render_height = render_height
|
| 22 |
+
self._render_width = render_width
|
| 23 |
+
|
| 24 |
+
self._ts = None
|
| 25 |
+
|
| 26 |
+
@override
|
| 27 |
+
def reset(self) -> None:
|
| 28 |
+
self._ts = self._env.reset()
|
| 29 |
+
|
| 30 |
+
@override
|
| 31 |
+
def is_episode_complete(self) -> bool:
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
@override
|
| 35 |
+
def get_observation(self) -> dict:
|
| 36 |
+
if self._ts is None:
|
| 37 |
+
raise RuntimeError("Timestep is not set. Call reset() first.")
|
| 38 |
+
|
| 39 |
+
obs = self._ts.observation
|
| 40 |
+
for k in list(obs["images"].keys()):
|
| 41 |
+
if "_depth" in k:
|
| 42 |
+
del obs["images"][k]
|
| 43 |
+
|
| 44 |
+
for cam_name in obs["images"]:
|
| 45 |
+
img = image_tools.convert_to_uint8(
|
| 46 |
+
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
|
| 47 |
+
)
|
| 48 |
+
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
|
| 49 |
+
|
| 50 |
+
return {
|
| 51 |
+
"state": obs["qpos"],
|
| 52 |
+
"images": obs["images"],
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
@override
|
| 56 |
+
def apply_action(self, action: dict) -> None:
|
| 57 |
+
self._ts = self._env.step(action["actions"])
|
openpi/examples/aloha_real/main.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from openpi_client import action_chunk_broker
|
| 5 |
+
from openpi_client import websocket_client_policy as _websocket_client_policy
|
| 6 |
+
from openpi_client.runtime import runtime as _runtime
|
| 7 |
+
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
| 8 |
+
import tyro
|
| 9 |
+
|
| 10 |
+
from examples.aloha_real import env as _env
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclasses.dataclass
|
| 14 |
+
class Args:
|
| 15 |
+
host: str = "0.0.0.0"
|
| 16 |
+
port: int = 8000
|
| 17 |
+
|
| 18 |
+
action_horizon: int = 25
|
| 19 |
+
|
| 20 |
+
num_episodes: int = 1
|
| 21 |
+
max_episode_steps: int = 1000
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def main(args: Args) -> None:
|
| 25 |
+
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
|
| 26 |
+
host=args.host,
|
| 27 |
+
port=args.port,
|
| 28 |
+
)
|
| 29 |
+
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
|
| 30 |
+
|
| 31 |
+
metadata = ws_client_policy.get_server_metadata()
|
| 32 |
+
runtime = _runtime.Runtime(
|
| 33 |
+
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
|
| 34 |
+
agent=_policy_agent.PolicyAgent(
|
| 35 |
+
policy=action_chunk_broker.ActionChunkBroker(
|
| 36 |
+
policy=ws_client_policy,
|
| 37 |
+
action_horizon=args.action_horizon,
|
| 38 |
+
)
|
| 39 |
+
),
|
| 40 |
+
subscribers=[],
|
| 41 |
+
max_hz=50,
|
| 42 |
+
num_episodes=args.num_episodes,
|
| 43 |
+
max_episode_steps=args.max_episode_steps,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
runtime.run()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
logging.basicConfig(level=logging.INFO, force=True)
|
| 51 |
+
tyro.cli(main)
|
openpi/examples/aloha_real/real_env.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
| 2 |
+
# ruff: noqa
|
| 3 |
+
import collections
|
| 4 |
+
import time
|
| 5 |
+
from typing import Optional, List
|
| 6 |
+
import dm_env
|
| 7 |
+
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
| 8 |
+
from interbotix_xs_msgs.msg import JointSingleCommand
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from examples.aloha_real import constants
|
| 12 |
+
from examples.aloha_real import robot_utils
|
| 13 |
+
|
| 14 |
+
# This is the reset position that is used by the standard Aloha runtime.
|
| 15 |
+
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RealEnv:
|
| 19 |
+
"""
|
| 20 |
+
Environment for real robot bi-manual manipulation
|
| 21 |
+
Action space: [left_arm_qpos (6), # absolute joint position
|
| 22 |
+
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
| 23 |
+
right_arm_qpos (6), # absolute joint position
|
| 24 |
+
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
| 25 |
+
|
| 26 |
+
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
| 27 |
+
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
| 28 |
+
right_arm_qpos (6), # absolute joint position
|
| 29 |
+
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
| 30 |
+
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
| 31 |
+
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
| 32 |
+
right_arm_qvel (6), # absolute joint velocity (rad)
|
| 33 |
+
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
| 34 |
+
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
| 35 |
+
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
| 36 |
+
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
| 37 |
+
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
|
| 41 |
+
# reset_position = START_ARM_POSE[:6]
|
| 42 |
+
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
|
| 43 |
+
|
| 44 |
+
self.puppet_bot_left = InterbotixManipulatorXS(
|
| 45 |
+
robot_model="vx300s",
|
| 46 |
+
group_name="arm",
|
| 47 |
+
gripper_name="gripper",
|
| 48 |
+
robot_name="puppet_left",
|
| 49 |
+
init_node=init_node,
|
| 50 |
+
)
|
| 51 |
+
self.puppet_bot_right = InterbotixManipulatorXS(
|
| 52 |
+
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
|
| 53 |
+
)
|
| 54 |
+
if setup_robots:
|
| 55 |
+
self.setup_robots()
|
| 56 |
+
|
| 57 |
+
self.recorder_left = robot_utils.Recorder("left", init_node=False)
|
| 58 |
+
self.recorder_right = robot_utils.Recorder("right", init_node=False)
|
| 59 |
+
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
|
| 60 |
+
self.gripper_command = JointSingleCommand(name="gripper")
|
| 61 |
+
|
| 62 |
+
def setup_robots(self):
|
| 63 |
+
robot_utils.setup_puppet_bot(self.puppet_bot_left)
|
| 64 |
+
robot_utils.setup_puppet_bot(self.puppet_bot_right)
|
| 65 |
+
|
| 66 |
+
def get_qpos(self):
|
| 67 |
+
left_qpos_raw = self.recorder_left.qpos
|
| 68 |
+
right_qpos_raw = self.recorder_right.qpos
|
| 69 |
+
left_arm_qpos = left_qpos_raw[:6]
|
| 70 |
+
right_arm_qpos = right_qpos_raw[:6]
|
| 71 |
+
left_gripper_qpos = [
|
| 72 |
+
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
|
| 73 |
+
] # this is position not joint
|
| 74 |
+
right_gripper_qpos = [
|
| 75 |
+
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
|
| 76 |
+
] # this is position not joint
|
| 77 |
+
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
| 78 |
+
|
| 79 |
+
def get_qvel(self):
|
| 80 |
+
left_qvel_raw = self.recorder_left.qvel
|
| 81 |
+
right_qvel_raw = self.recorder_right.qvel
|
| 82 |
+
left_arm_qvel = left_qvel_raw[:6]
|
| 83 |
+
right_arm_qvel = right_qvel_raw[:6]
|
| 84 |
+
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
| 85 |
+
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
| 86 |
+
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
| 87 |
+
|
| 88 |
+
def get_effort(self):
|
| 89 |
+
left_effort_raw = self.recorder_left.effort
|
| 90 |
+
right_effort_raw = self.recorder_right.effort
|
| 91 |
+
left_robot_effort = left_effort_raw[:7]
|
| 92 |
+
right_robot_effort = right_effort_raw[:7]
|
| 93 |
+
return np.concatenate([left_robot_effort, right_robot_effort])
|
| 94 |
+
|
| 95 |
+
def get_images(self):
|
| 96 |
+
return self.image_recorder.get_images()
|
| 97 |
+
|
| 98 |
+
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
| 99 |
+
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
| 100 |
+
self.gripper_command.cmd = left_gripper_desired_joint
|
| 101 |
+
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
| 102 |
+
|
| 103 |
+
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
|
| 104 |
+
right_gripper_desired_pos_normalized
|
| 105 |
+
)
|
| 106 |
+
self.gripper_command.cmd = right_gripper_desired_joint
|
| 107 |
+
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
| 108 |
+
|
| 109 |
+
def _reset_joints(self):
|
| 110 |
+
robot_utils.move_arms(
|
| 111 |
+
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def _reset_gripper(self):
|
| 115 |
+
"""Set to position mode and do position resets: first close then open. Then change back to PWM mode
|
| 116 |
+
|
| 117 |
+
NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
|
| 118 |
+
was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
|
| 119 |
+
increase the frequency of motor faults.
|
| 120 |
+
"""
|
| 121 |
+
robot_utils.move_grippers(
|
| 122 |
+
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
| 123 |
+
)
|
| 124 |
+
robot_utils.move_grippers(
|
| 125 |
+
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def get_observation(self):
|
| 129 |
+
obs = collections.OrderedDict()
|
| 130 |
+
obs["qpos"] = self.get_qpos()
|
| 131 |
+
obs["qvel"] = self.get_qvel()
|
| 132 |
+
obs["effort"] = self.get_effort()
|
| 133 |
+
obs["images"] = self.get_images()
|
| 134 |
+
return obs
|
| 135 |
+
|
| 136 |
+
def get_reward(self):
|
| 137 |
+
return 0
|
| 138 |
+
|
| 139 |
+
def reset(self, *, fake=False):
|
| 140 |
+
if not fake:
|
| 141 |
+
# Reboot puppet robot gripper motors
|
| 142 |
+
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
| 143 |
+
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
| 144 |
+
self._reset_joints()
|
| 145 |
+
self._reset_gripper()
|
| 146 |
+
return dm_env.TimeStep(
|
| 147 |
+
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def step(self, action):
|
| 151 |
+
state_len = int(len(action) / 2)
|
| 152 |
+
left_action = action[:state_len]
|
| 153 |
+
right_action = action[state_len:]
|
| 154 |
+
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
| 155 |
+
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
| 156 |
+
self.set_gripper_pose(left_action[-1], right_action[-1])
|
| 157 |
+
time.sleep(constants.DT)
|
| 158 |
+
return dm_env.TimeStep(
|
| 159 |
+
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_action(master_bot_left, master_bot_right):
|
| 164 |
+
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
| 165 |
+
# Arm actions
|
| 166 |
+
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
| 167 |
+
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
| 168 |
+
# Gripper actions
|
| 169 |
+
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
| 170 |
+
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
| 171 |
+
|
| 172 |
+
return action
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
|
| 176 |
+
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
|
openpi/examples/aloha_real/requirements.in
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Pillow
|
| 2 |
+
dm_control
|
| 3 |
+
einops
|
| 4 |
+
h5py
|
| 5 |
+
matplotlib
|
| 6 |
+
modern_robotics
|
| 7 |
+
msgpack
|
| 8 |
+
numpy>=1.22.4,<2.0.0
|
| 9 |
+
opencv-python
|
| 10 |
+
packaging
|
| 11 |
+
pexpect
|
| 12 |
+
pyquaternion
|
| 13 |
+
pyrealsense2
|
| 14 |
+
pyyaml
|
| 15 |
+
requests
|
| 16 |
+
rospkg
|
| 17 |
+
tyro
|
| 18 |
+
websockets
|
openpi/examples/aloha_real/requirements.txt
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
| 3 |
+
absl-py==2.1.0
|
| 4 |
+
# via
|
| 5 |
+
# dm-control
|
| 6 |
+
# dm-env
|
| 7 |
+
# labmaze
|
| 8 |
+
# mujoco
|
| 9 |
+
catkin-pkg==1.0.0
|
| 10 |
+
# via rospkg
|
| 11 |
+
certifi==2024.8.30
|
| 12 |
+
# via requests
|
| 13 |
+
charset-normalizer==3.4.0
|
| 14 |
+
# via requests
|
| 15 |
+
contourpy==1.1.1
|
| 16 |
+
# via matplotlib
|
| 17 |
+
cycler==0.12.1
|
| 18 |
+
# via matplotlib
|
| 19 |
+
distro==1.9.0
|
| 20 |
+
# via rospkg
|
| 21 |
+
dm-control==1.0.23
|
| 22 |
+
# via -r examples/aloha_real/requirements.in
|
| 23 |
+
dm-env==1.6
|
| 24 |
+
# via dm-control
|
| 25 |
+
dm-tree==0.1.8
|
| 26 |
+
# via
|
| 27 |
+
# dm-control
|
| 28 |
+
# dm-env
|
| 29 |
+
docstring-parser==0.16
|
| 30 |
+
# via tyro
|
| 31 |
+
docutils==0.20.1
|
| 32 |
+
# via catkin-pkg
|
| 33 |
+
einops==0.8.0
|
| 34 |
+
# via -r examples/aloha_real/requirements.in
|
| 35 |
+
etils==1.3.0
|
| 36 |
+
# via mujoco
|
| 37 |
+
fonttools==4.55.2
|
| 38 |
+
# via matplotlib
|
| 39 |
+
glfw==2.8.0
|
| 40 |
+
# via
|
| 41 |
+
# dm-control
|
| 42 |
+
# mujoco
|
| 43 |
+
h5py==3.11.0
|
| 44 |
+
# via -r examples/aloha_real/requirements.in
|
| 45 |
+
idna==3.10
|
| 46 |
+
# via requests
|
| 47 |
+
importlib-resources==6.4.5
|
| 48 |
+
# via etils
|
| 49 |
+
kiwisolver==1.4.7
|
| 50 |
+
# via matplotlib
|
| 51 |
+
labmaze==1.0.6
|
| 52 |
+
# via dm-control
|
| 53 |
+
lxml==5.3.0
|
| 54 |
+
# via dm-control
|
| 55 |
+
markdown-it-py==3.0.0
|
| 56 |
+
# via rich
|
| 57 |
+
matplotlib==3.7.5
|
| 58 |
+
# via -r examples/aloha_real/requirements.in
|
| 59 |
+
mdurl==0.1.2
|
| 60 |
+
# via markdown-it-py
|
| 61 |
+
modern-robotics==1.1.1
|
| 62 |
+
# via -r examples/aloha_real/requirements.in
|
| 63 |
+
msgpack==1.1.0
|
| 64 |
+
# via -r examples/aloha_real/requirements.in
|
| 65 |
+
mujoco==3.2.3
|
| 66 |
+
# via dm-control
|
| 67 |
+
numpy==1.24.4
|
| 68 |
+
# via
|
| 69 |
+
# -r examples/aloha_real/requirements.in
|
| 70 |
+
# contourpy
|
| 71 |
+
# dm-control
|
| 72 |
+
# dm-env
|
| 73 |
+
# h5py
|
| 74 |
+
# labmaze
|
| 75 |
+
# matplotlib
|
| 76 |
+
# modern-robotics
|
| 77 |
+
# mujoco
|
| 78 |
+
# opencv-python
|
| 79 |
+
# pyquaternion
|
| 80 |
+
# scipy
|
| 81 |
+
opencv-python==4.10.0.84
|
| 82 |
+
# via -r examples/aloha_real/requirements.in
|
| 83 |
+
packaging==24.2
|
| 84 |
+
# via
|
| 85 |
+
# -r examples/aloha_real/requirements.in
|
| 86 |
+
# matplotlib
|
| 87 |
+
pexpect==4.9.0
|
| 88 |
+
# via -r examples/aloha_real/requirements.in
|
| 89 |
+
pillow==10.4.0
|
| 90 |
+
# via
|
| 91 |
+
# -r examples/aloha_real/requirements.in
|
| 92 |
+
# matplotlib
|
| 93 |
+
protobuf==5.29.1
|
| 94 |
+
# via dm-control
|
| 95 |
+
ptyprocess==0.7.0
|
| 96 |
+
# via pexpect
|
| 97 |
+
pygments==2.18.0
|
| 98 |
+
# via rich
|
| 99 |
+
pyopengl==3.1.7
|
| 100 |
+
# via
|
| 101 |
+
# dm-control
|
| 102 |
+
# mujoco
|
| 103 |
+
pyparsing==3.1.4
|
| 104 |
+
# via
|
| 105 |
+
# catkin-pkg
|
| 106 |
+
# dm-control
|
| 107 |
+
# matplotlib
|
| 108 |
+
pyquaternion==0.9.9
|
| 109 |
+
# via -r examples/aloha_real/requirements.in
|
| 110 |
+
pyrealsense2==2.55.1.6486
|
| 111 |
+
# via -r examples/aloha_real/requirements.in
|
| 112 |
+
python-dateutil==2.9.0.post0
|
| 113 |
+
# via
|
| 114 |
+
# catkin-pkg
|
| 115 |
+
# matplotlib
|
| 116 |
+
pyyaml==6.0.2
|
| 117 |
+
# via
|
| 118 |
+
# -r examples/aloha_real/requirements.in
|
| 119 |
+
# rospkg
|
| 120 |
+
requests==2.32.3
|
| 121 |
+
# via
|
| 122 |
+
# -r examples/aloha_real/requirements.in
|
| 123 |
+
# dm-control
|
| 124 |
+
rich==13.9.4
|
| 125 |
+
# via tyro
|
| 126 |
+
rospkg==1.5.1
|
| 127 |
+
# via -r examples/aloha_real/requirements.in
|
| 128 |
+
scipy==1.10.1
|
| 129 |
+
# via dm-control
|
| 130 |
+
setuptools==75.3.0
|
| 131 |
+
# via
|
| 132 |
+
# catkin-pkg
|
| 133 |
+
# dm-control
|
| 134 |
+
# labmaze
|
| 135 |
+
shtab==1.7.1
|
| 136 |
+
# via tyro
|
| 137 |
+
six==1.17.0
|
| 138 |
+
# via python-dateutil
|
| 139 |
+
tqdm==4.67.1
|
| 140 |
+
# via dm-control
|
| 141 |
+
typeguard==4.4.0
|
| 142 |
+
# via tyro
|
| 143 |
+
typing-extensions==4.12.2
|
| 144 |
+
# via
|
| 145 |
+
# etils
|
| 146 |
+
# rich
|
| 147 |
+
# typeguard
|
| 148 |
+
# tyro
|
| 149 |
+
tyro==0.9.2
|
| 150 |
+
# via -r examples/aloha_real/requirements.in
|
| 151 |
+
urllib3==2.2.3
|
| 152 |
+
# via requests
|
| 153 |
+
websockets==14.1
|
| 154 |
+
# via -r examples/aloha_real/requirements.in
|
| 155 |
+
zipp==3.20.2
|
| 156 |
+
# via etils
|
openpi/examples/aloha_real/robot_utils.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
| 2 |
+
# ruff: noqa
|
| 3 |
+
from collections import deque
|
| 4 |
+
import datetime
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
from aloha.msg import RGBGrayscaleImage
|
| 9 |
+
from cv_bridge import CvBridge
|
| 10 |
+
from interbotix_xs_msgs.msg import JointGroupCommand
|
| 11 |
+
from interbotix_xs_msgs.msg import JointSingleCommand
|
| 12 |
+
import numpy as np
|
| 13 |
+
import rospy
|
| 14 |
+
from sensor_msgs.msg import JointState
|
| 15 |
+
|
| 16 |
+
from examples.aloha_real import constants
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ImageRecorder:
|
| 20 |
+
def __init__(self, init_node=True, is_debug=False):
|
| 21 |
+
self.is_debug = is_debug
|
| 22 |
+
self.bridge = CvBridge()
|
| 23 |
+
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
| 24 |
+
|
| 25 |
+
if init_node:
|
| 26 |
+
rospy.init_node("image_recorder", anonymous=True)
|
| 27 |
+
for cam_name in self.camera_names:
|
| 28 |
+
setattr(self, f"{cam_name}_rgb_image", None)
|
| 29 |
+
setattr(self, f"{cam_name}_depth_image", None)
|
| 30 |
+
setattr(self, f"{cam_name}_timestamp", 0.0)
|
| 31 |
+
if cam_name == "cam_high":
|
| 32 |
+
callback_func = self.image_cb_cam_high
|
| 33 |
+
elif cam_name == "cam_low":
|
| 34 |
+
callback_func = self.image_cb_cam_low
|
| 35 |
+
elif cam_name == "cam_left_wrist":
|
| 36 |
+
callback_func = self.image_cb_cam_left_wrist
|
| 37 |
+
elif cam_name == "cam_right_wrist":
|
| 38 |
+
callback_func = self.image_cb_cam_right_wrist
|
| 39 |
+
else:
|
| 40 |
+
raise NotImplementedError
|
| 41 |
+
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
| 42 |
+
if self.is_debug:
|
| 43 |
+
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
| 44 |
+
|
| 45 |
+
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
| 46 |
+
time.sleep(0.5)
|
| 47 |
+
|
| 48 |
+
def image_cb(self, cam_name, data):
|
| 49 |
+
setattr(
|
| 50 |
+
self,
|
| 51 |
+
f"{cam_name}_rgb_image",
|
| 52 |
+
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
| 53 |
+
)
|
| 54 |
+
# setattr(
|
| 55 |
+
# self,
|
| 56 |
+
# f"{cam_name}_depth_image",
|
| 57 |
+
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
| 58 |
+
# )
|
| 59 |
+
setattr(
|
| 60 |
+
self,
|
| 61 |
+
f"{cam_name}_timestamp",
|
| 62 |
+
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
| 63 |
+
)
|
| 64 |
+
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
| 65 |
+
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
| 66 |
+
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
| 67 |
+
if self.is_debug:
|
| 68 |
+
getattr(self, f"{cam_name}_timestamps").append(
|
| 69 |
+
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
def image_cb_cam_high(self, data):
|
| 73 |
+
cam_name = "cam_high"
|
| 74 |
+
return self.image_cb(cam_name, data)
|
| 75 |
+
|
| 76 |
+
def image_cb_cam_low(self, data):
|
| 77 |
+
cam_name = "cam_low"
|
| 78 |
+
return self.image_cb(cam_name, data)
|
| 79 |
+
|
| 80 |
+
def image_cb_cam_left_wrist(self, data):
|
| 81 |
+
cam_name = "cam_left_wrist"
|
| 82 |
+
return self.image_cb(cam_name, data)
|
| 83 |
+
|
| 84 |
+
def image_cb_cam_right_wrist(self, data):
|
| 85 |
+
cam_name = "cam_right_wrist"
|
| 86 |
+
return self.image_cb(cam_name, data)
|
| 87 |
+
|
| 88 |
+
def get_images(self):
|
| 89 |
+
image_dict = {}
|
| 90 |
+
for cam_name in self.camera_names:
|
| 91 |
+
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
| 92 |
+
time.sleep(0.00001)
|
| 93 |
+
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
| 94 |
+
depth_image = getattr(self, f"{cam_name}_depth_image")
|
| 95 |
+
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
| 96 |
+
image_dict[cam_name] = rgb_image
|
| 97 |
+
image_dict[f"{cam_name}_depth"] = depth_image
|
| 98 |
+
return image_dict
|
| 99 |
+
|
| 100 |
+
def print_diagnostics(self):
|
| 101 |
+
def dt_helper(l):
|
| 102 |
+
l = np.array(l)
|
| 103 |
+
diff = l[1:] - l[:-1]
|
| 104 |
+
return np.mean(diff)
|
| 105 |
+
|
| 106 |
+
for cam_name in self.camera_names:
|
| 107 |
+
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
| 108 |
+
print(f"{cam_name} {image_freq=:.2f}")
|
| 109 |
+
print()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Recorder:
|
| 113 |
+
def __init__(self, side, init_node=True, is_debug=False):
|
| 114 |
+
self.secs = None
|
| 115 |
+
self.nsecs = None
|
| 116 |
+
self.qpos = None
|
| 117 |
+
self.effort = None
|
| 118 |
+
self.arm_command = None
|
| 119 |
+
self.gripper_command = None
|
| 120 |
+
self.is_debug = is_debug
|
| 121 |
+
|
| 122 |
+
if init_node:
|
| 123 |
+
rospy.init_node("recorder", anonymous=True)
|
| 124 |
+
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
| 125 |
+
rospy.Subscriber(
|
| 126 |
+
f"/puppet_{side}/commands/joint_group",
|
| 127 |
+
JointGroupCommand,
|
| 128 |
+
self.puppet_arm_commands_cb,
|
| 129 |
+
)
|
| 130 |
+
rospy.Subscriber(
|
| 131 |
+
f"/puppet_{side}/commands/joint_single",
|
| 132 |
+
JointSingleCommand,
|
| 133 |
+
self.puppet_gripper_commands_cb,
|
| 134 |
+
)
|
| 135 |
+
if self.is_debug:
|
| 136 |
+
self.joint_timestamps = deque(maxlen=50)
|
| 137 |
+
self.arm_command_timestamps = deque(maxlen=50)
|
| 138 |
+
self.gripper_command_timestamps = deque(maxlen=50)
|
| 139 |
+
time.sleep(0.1)
|
| 140 |
+
|
| 141 |
+
def puppet_state_cb(self, data):
|
| 142 |
+
self.qpos = data.position
|
| 143 |
+
self.qvel = data.velocity
|
| 144 |
+
self.effort = data.effort
|
| 145 |
+
self.data = data
|
| 146 |
+
if self.is_debug:
|
| 147 |
+
self.joint_timestamps.append(time.time())
|
| 148 |
+
|
| 149 |
+
def puppet_arm_commands_cb(self, data):
|
| 150 |
+
self.arm_command = data.cmd
|
| 151 |
+
if self.is_debug:
|
| 152 |
+
self.arm_command_timestamps.append(time.time())
|
| 153 |
+
|
| 154 |
+
def puppet_gripper_commands_cb(self, data):
|
| 155 |
+
self.gripper_command = data.cmd
|
| 156 |
+
if self.is_debug:
|
| 157 |
+
self.gripper_command_timestamps.append(time.time())
|
| 158 |
+
|
| 159 |
+
def print_diagnostics(self):
|
| 160 |
+
def dt_helper(l):
|
| 161 |
+
l = np.array(l)
|
| 162 |
+
diff = l[1:] - l[:-1]
|
| 163 |
+
return np.mean(diff)
|
| 164 |
+
|
| 165 |
+
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
| 166 |
+
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
| 167 |
+
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
| 168 |
+
|
| 169 |
+
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_arm_joint_positions(bot):
|
| 173 |
+
return bot.arm.core.joint_states.position[:6]
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_arm_gripper_positions(bot):
|
| 177 |
+
return bot.gripper.core.joint_states.position[6]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def move_arms(bot_list, target_pose_list, move_time=1):
|
| 181 |
+
num_steps = int(move_time / constants.DT)
|
| 182 |
+
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
| 183 |
+
traj_list = [
|
| 184 |
+
np.linspace(curr_pose, target_pose, num_steps)
|
| 185 |
+
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
| 186 |
+
]
|
| 187 |
+
for t in range(num_steps):
|
| 188 |
+
for bot_id, bot in enumerate(bot_list):
|
| 189 |
+
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
| 190 |
+
time.sleep(constants.DT)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def move_grippers(bot_list, target_pose_list, move_time):
|
| 194 |
+
print(f"Moving grippers to {target_pose_list=}")
|
| 195 |
+
gripper_command = JointSingleCommand(name="gripper")
|
| 196 |
+
num_steps = int(move_time / constants.DT)
|
| 197 |
+
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
| 198 |
+
traj_list = [
|
| 199 |
+
np.linspace(curr_pose, target_pose, num_steps)
|
| 200 |
+
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
| 201 |
+
]
|
| 202 |
+
|
| 203 |
+
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
| 204 |
+
for t in range(num_steps):
|
| 205 |
+
d = {}
|
| 206 |
+
for bot_id, bot in enumerate(bot_list):
|
| 207 |
+
gripper_command.cmd = traj_list[bot_id][t]
|
| 208 |
+
bot.gripper.core.pub_single.publish(gripper_command)
|
| 209 |
+
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
| 210 |
+
f.write(json.dumps(d) + "\n")
|
| 211 |
+
time.sleep(constants.DT)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def setup_puppet_bot(bot):
|
| 215 |
+
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
| 216 |
+
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
| 217 |
+
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
| 218 |
+
torque_on(bot)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def setup_master_bot(bot):
|
| 222 |
+
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
| 223 |
+
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
| 224 |
+
torque_off(bot)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def set_standard_pid_gains(bot):
|
| 228 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
| 229 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def set_low_pid_gains(bot):
|
| 233 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
| 234 |
+
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def torque_off(bot):
|
| 238 |
+
bot.dxl.robot_torque_enable("group", "arm", False)
|
| 239 |
+
bot.dxl.robot_torque_enable("single", "gripper", False)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def torque_on(bot):
|
| 243 |
+
bot.dxl.robot_torque_enable("group", "arm", True)
|
| 244 |
+
bot.dxl.robot_torque_enable("single", "gripper", True)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# for DAgger
|
| 248 |
+
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
| 249 |
+
print("\nSyncing!")
|
| 250 |
+
|
| 251 |
+
# activate master arms
|
| 252 |
+
torque_on(master_bot_left)
|
| 253 |
+
torque_on(master_bot_right)
|
| 254 |
+
|
| 255 |
+
# get puppet arm positions
|
| 256 |
+
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
| 257 |
+
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
| 258 |
+
|
| 259 |
+
# get puppet gripper positions
|
| 260 |
+
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
| 261 |
+
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
| 262 |
+
|
| 263 |
+
# move master arms to puppet positions
|
| 264 |
+
move_arms(
|
| 265 |
+
[master_bot_left, master_bot_right],
|
| 266 |
+
[puppet_left_qpos, puppet_right_qpos],
|
| 267 |
+
move_time=1,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# move master grippers to puppet positions
|
| 271 |
+
move_grippers(
|
| 272 |
+
[master_bot_left, master_bot_right],
|
| 273 |
+
[puppet_left_gripper, puppet_right_gripper],
|
| 274 |
+
move_time=1,
|
| 275 |
+
)
|
openpi/examples/aloha_real/video_display.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
from openpi_client.runtime import subscriber as _subscriber
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class VideoDisplay(_subscriber.Subscriber):
|
| 8 |
+
"""Displays video frames."""
|
| 9 |
+
|
| 10 |
+
def __init__(self) -> None:
|
| 11 |
+
self._ax: plt.Axes | None = None
|
| 12 |
+
self._plt_img: plt.Image | None = None
|
| 13 |
+
|
| 14 |
+
@override
|
| 15 |
+
def on_episode_start(self) -> None:
|
| 16 |
+
plt.ion()
|
| 17 |
+
self._ax = plt.subplot()
|
| 18 |
+
self._plt_img = None
|
| 19 |
+
|
| 20 |
+
@override
|
| 21 |
+
def on_step(self, observation: dict, action: dict) -> None:
|
| 22 |
+
assert self._ax is not None
|
| 23 |
+
|
| 24 |
+
im = observation["image"][0] # [C, H, W]
|
| 25 |
+
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
| 26 |
+
|
| 27 |
+
if self._plt_img is None:
|
| 28 |
+
self._plt_img = self._ax.imshow(im)
|
| 29 |
+
else:
|
| 30 |
+
self._plt_img.set_data(im)
|
| 31 |
+
plt.pause(0.001)
|
| 32 |
+
|
| 33 |
+
@override
|
| 34 |
+
def on_episode_end(self) -> None:
|
| 35 |
+
plt.ioff()
|
| 36 |
+
plt.close()
|
openpi/examples/aloha_sim/Dockerfile
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for the Aloha simulation environment.
|
| 2 |
+
|
| 3 |
+
# Build the container:
|
| 4 |
+
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
|
| 5 |
+
|
| 6 |
+
# Run the container:
|
| 7 |
+
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
|
| 8 |
+
|
| 9 |
+
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
|
| 10 |
+
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
| 11 |
+
|
| 12 |
+
RUN apt-get update && \
|
| 13 |
+
apt-get install -y \
|
| 14 |
+
libosmesa6-dev \
|
| 15 |
+
libgl1-mesa-glx \
|
| 16 |
+
libglew-dev \
|
| 17 |
+
libglfw3-dev \
|
| 18 |
+
libgles2-mesa-dev
|
| 19 |
+
ENV MUJOCO_GL=egl
|
| 20 |
+
|
| 21 |
+
WORKDIR /app
|
| 22 |
+
|
| 23 |
+
# Copy from the cache instead of linking since it's a mounted volume
|
| 24 |
+
ENV UV_LINK_MODE=copy
|
| 25 |
+
|
| 26 |
+
# Write the virtual environment outside of the project directory so it doesn't
|
| 27 |
+
# leak out of the container when we mount the application code.
|
| 28 |
+
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
| 29 |
+
|
| 30 |
+
# Copy the requirements files so we can install dependencies.
|
| 31 |
+
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
| 32 |
+
# This strategy is best for development-style usage.
|
| 33 |
+
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
|
| 34 |
+
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
| 35 |
+
|
| 36 |
+
# Install python dependencies.
|
| 37 |
+
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
| 38 |
+
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
| 39 |
+
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
| 40 |
+
|
| 41 |
+
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
|
openpi/examples/aloha_sim/README.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run Aloha Sim
|
| 2 |
+
|
| 3 |
+
## With Docker
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
export SERVER_ARGS="--env ALOHA_SIM"
|
| 7 |
+
docker compose -f examples/aloha_sim/compose.yml up --build
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
## Without Docker
|
| 11 |
+
|
| 12 |
+
Terminal window 1:
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
# Create virtual environment
|
| 16 |
+
uv venv --python 3.10 examples/aloha_sim/.venv
|
| 17 |
+
source examples/aloha_sim/.venv/bin/activate
|
| 18 |
+
uv pip sync examples/aloha_sim/requirements.txt
|
| 19 |
+
uv pip install -e packages/openpi-client
|
| 20 |
+
|
| 21 |
+
# Run the simulation
|
| 22 |
+
MUJOCO_GL=egl python examples/aloha_sim/main.py
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Note: If you are seeing EGL errors, you may need to install the following dependencies:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Terminal window 2:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
# Run the server
|
| 35 |
+
uv run scripts/serve_policy.py --env ALOHA_SIM
|
| 36 |
+
```
|
openpi/examples/aloha_sim/compose.yml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run with:
|
| 2 |
+
# docker compose -f examples/aloha_sim/compose.yml up --build
|
| 3 |
+
services:
|
| 4 |
+
runtime:
|
| 5 |
+
image: aloha_sim
|
| 6 |
+
depends_on:
|
| 7 |
+
- openpi_server
|
| 8 |
+
build:
|
| 9 |
+
context: ../..
|
| 10 |
+
dockerfile: examples/aloha_sim/Dockerfile
|
| 11 |
+
init: true
|
| 12 |
+
tty: true
|
| 13 |
+
network_mode: host
|
| 14 |
+
privileged: true
|
| 15 |
+
volumes:
|
| 16 |
+
- $PWD:/app
|
| 17 |
+
- ../../data:/data
|
| 18 |
+
|
| 19 |
+
openpi_server:
|
| 20 |
+
image: openpi_server
|
| 21 |
+
build:
|
| 22 |
+
context: ../..
|
| 23 |
+
dockerfile: scripts/docker/serve_policy.Dockerfile
|
| 24 |
+
init: true
|
| 25 |
+
tty: true
|
| 26 |
+
network_mode: host
|
| 27 |
+
volumes:
|
| 28 |
+
- $PWD:/app
|
| 29 |
+
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
| 30 |
+
environment:
|
| 31 |
+
- SERVER_ARGS
|
| 32 |
+
- OPENPI_DATA_HOME=/openpi_assets
|
| 33 |
+
- IS_DOCKER=true
|
| 34 |
+
|
| 35 |
+
# Comment out this block if not running on a machine with GPUs.
|
| 36 |
+
deploy:
|
| 37 |
+
resources:
|
| 38 |
+
reservations:
|
| 39 |
+
devices:
|
| 40 |
+
- driver: nvidia
|
| 41 |
+
count: 1
|
| 42 |
+
capabilities: [gpu]
|
openpi/examples/aloha_sim/env.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gym_aloha # noqa: F401
|
| 2 |
+
import gymnasium
|
| 3 |
+
import numpy as np
|
| 4 |
+
from openpi_client import image_tools
|
| 5 |
+
from openpi_client.runtime import environment as _environment
|
| 6 |
+
from typing_extensions import override
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AlohaSimEnvironment(_environment.Environment):
|
| 10 |
+
"""An environment for an Aloha robot in simulation."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
|
| 13 |
+
np.random.seed(seed)
|
| 14 |
+
self._rng = np.random.default_rng(seed)
|
| 15 |
+
|
| 16 |
+
self._gym = gymnasium.make(task, obs_type=obs_type)
|
| 17 |
+
|
| 18 |
+
self._last_obs = None
|
| 19 |
+
self._done = True
|
| 20 |
+
self._episode_reward = 0.0
|
| 21 |
+
|
| 22 |
+
@override
|
| 23 |
+
def reset(self) -> None:
|
| 24 |
+
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
|
| 25 |
+
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
| 26 |
+
self._done = False
|
| 27 |
+
self._episode_reward = 0.0
|
| 28 |
+
|
| 29 |
+
@override
|
| 30 |
+
def is_episode_complete(self) -> bool:
|
| 31 |
+
return self._done
|
| 32 |
+
|
| 33 |
+
@override
|
| 34 |
+
def get_observation(self) -> dict:
|
| 35 |
+
if self._last_obs is None:
|
| 36 |
+
raise RuntimeError("Observation is not set. Call reset() first.")
|
| 37 |
+
|
| 38 |
+
return self._last_obs # type: ignore
|
| 39 |
+
|
| 40 |
+
@override
|
| 41 |
+
def apply_action(self, action: dict) -> None:
|
| 42 |
+
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
|
| 43 |
+
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
| 44 |
+
self._done = terminated or truncated
|
| 45 |
+
self._episode_reward = max(self._episode_reward, reward)
|
| 46 |
+
|
| 47 |
+
def _convert_observation(self, gym_obs: dict) -> dict:
|
| 48 |
+
img = gym_obs["pixels"]["top"]
|
| 49 |
+
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
|
| 50 |
+
# Convert axis order from [H, W, C] --> [C, H, W]
|
| 51 |
+
img = np.transpose(img, (2, 0, 1))
|
| 52 |
+
|
| 53 |
+
return {
|
| 54 |
+
"state": gym_obs["agent_pos"],
|
| 55 |
+
"images": {"cam_high": img},
|
| 56 |
+
}
|
openpi/examples/aloha_sim/main.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
import pathlib
|
| 4 |
+
|
| 5 |
+
import env as _env
|
| 6 |
+
from openpi_client import action_chunk_broker
|
| 7 |
+
from openpi_client import websocket_client_policy as _websocket_client_policy
|
| 8 |
+
from openpi_client.runtime import runtime as _runtime
|
| 9 |
+
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
| 10 |
+
import saver as _saver
|
| 11 |
+
import tyro
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclasses.dataclass
|
| 15 |
+
class Args:
|
| 16 |
+
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
|
| 17 |
+
|
| 18 |
+
task: str = "gym_aloha/AlohaTransferCube-v0"
|
| 19 |
+
seed: int = 0
|
| 20 |
+
|
| 21 |
+
action_horizon: int = 10
|
| 22 |
+
|
| 23 |
+
host: str = "0.0.0.0"
|
| 24 |
+
port: int = 8000
|
| 25 |
+
|
| 26 |
+
display: bool = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def main(args: Args) -> None:
|
| 30 |
+
runtime = _runtime.Runtime(
|
| 31 |
+
environment=_env.AlohaSimEnvironment(
|
| 32 |
+
task=args.task,
|
| 33 |
+
seed=args.seed,
|
| 34 |
+
),
|
| 35 |
+
agent=_policy_agent.PolicyAgent(
|
| 36 |
+
policy=action_chunk_broker.ActionChunkBroker(
|
| 37 |
+
policy=_websocket_client_policy.WebsocketClientPolicy(
|
| 38 |
+
host=args.host,
|
| 39 |
+
port=args.port,
|
| 40 |
+
),
|
| 41 |
+
action_horizon=args.action_horizon,
|
| 42 |
+
)
|
| 43 |
+
),
|
| 44 |
+
subscribers=[
|
| 45 |
+
_saver.VideoSaver(args.out_dir),
|
| 46 |
+
],
|
| 47 |
+
max_hz=50,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
runtime.run()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
logging.basicConfig(level=logging.INFO, force=True)
|
| 55 |
+
tyro.cli(main)
|
openpi/examples/aloha_sim/requirements.in
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gym-aloha
|
| 2 |
+
imageio
|
| 3 |
+
matplotlib
|
| 4 |
+
msgpack
|
| 5 |
+
numpy>=1.22.4,<2.0.0
|
| 6 |
+
typing-extensions
|
| 7 |
+
tyro
|
| 8 |
+
websockets
|
openpi/examples/aloha_sim/requirements.txt
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
|
| 3 |
+
absl-py==2.1.0
|
| 4 |
+
# via
|
| 5 |
+
# dm-control
|
| 6 |
+
# dm-env
|
| 7 |
+
# labmaze
|
| 8 |
+
# mujoco
|
| 9 |
+
certifi==2024.8.30
|
| 10 |
+
# via requests
|
| 11 |
+
charset-normalizer==3.4.0
|
| 12 |
+
# via requests
|
| 13 |
+
cloudpickle==3.1.0
|
| 14 |
+
# via gymnasium
|
| 15 |
+
contourpy==1.3.1
|
| 16 |
+
# via matplotlib
|
| 17 |
+
cycler==0.12.1
|
| 18 |
+
# via matplotlib
|
| 19 |
+
dm-control==1.0.14
|
| 20 |
+
# via gym-aloha
|
| 21 |
+
dm-env==1.6
|
| 22 |
+
# via dm-control
|
| 23 |
+
dm-tree==0.1.8
|
| 24 |
+
# via
|
| 25 |
+
# dm-control
|
| 26 |
+
# dm-env
|
| 27 |
+
docstring-parser==0.16
|
| 28 |
+
# via tyro
|
| 29 |
+
farama-notifications==0.0.4
|
| 30 |
+
# via gymnasium
|
| 31 |
+
fonttools==4.55.2
|
| 32 |
+
# via matplotlib
|
| 33 |
+
glfw==2.8.0
|
| 34 |
+
# via
|
| 35 |
+
# dm-control
|
| 36 |
+
# mujoco
|
| 37 |
+
gym-aloha==0.1.1
|
| 38 |
+
# via -r examples/aloha_sim/requirements.in
|
| 39 |
+
gymnasium==1.0.0
|
| 40 |
+
# via gym-aloha
|
| 41 |
+
idna==3.10
|
| 42 |
+
# via requests
|
| 43 |
+
imageio==2.36.1
|
| 44 |
+
# via
|
| 45 |
+
# -r examples/aloha_sim/requirements.in
|
| 46 |
+
# gym-aloha
|
| 47 |
+
imageio-ffmpeg==0.5.1
|
| 48 |
+
# via imageio
|
| 49 |
+
kiwisolver==1.4.7
|
| 50 |
+
# via matplotlib
|
| 51 |
+
labmaze==1.0.6
|
| 52 |
+
# via dm-control
|
| 53 |
+
lxml==5.3.0
|
| 54 |
+
# via dm-control
|
| 55 |
+
markdown-it-py==3.0.0
|
| 56 |
+
# via rich
|
| 57 |
+
matplotlib==3.9.3
|
| 58 |
+
# via -r examples/aloha_sim/requirements.in
|
| 59 |
+
mdurl==0.1.2
|
| 60 |
+
# via markdown-it-py
|
| 61 |
+
msgpack==1.1.0
|
| 62 |
+
# via -r examples/aloha_sim/requirements.in
|
| 63 |
+
mujoco==2.3.7
|
| 64 |
+
# via
|
| 65 |
+
# dm-control
|
| 66 |
+
# gym-aloha
|
| 67 |
+
numpy==1.26.4
|
| 68 |
+
# via
|
| 69 |
+
# -r examples/aloha_sim/requirements.in
|
| 70 |
+
# contourpy
|
| 71 |
+
# dm-control
|
| 72 |
+
# dm-env
|
| 73 |
+
# gymnasium
|
| 74 |
+
# imageio
|
| 75 |
+
# labmaze
|
| 76 |
+
# matplotlib
|
| 77 |
+
# mujoco
|
| 78 |
+
# scipy
|
| 79 |
+
packaging==24.2
|
| 80 |
+
# via matplotlib
|
| 81 |
+
pillow==11.0.0
|
| 82 |
+
# via
|
| 83 |
+
# imageio
|
| 84 |
+
# matplotlib
|
| 85 |
+
protobuf==5.29.1
|
| 86 |
+
# via dm-control
|
| 87 |
+
psutil==6.1.0
|
| 88 |
+
# via imageio
|
| 89 |
+
pygments==2.18.0
|
| 90 |
+
# via rich
|
| 91 |
+
pyopengl==3.1.7
|
| 92 |
+
# via
|
| 93 |
+
# dm-control
|
| 94 |
+
# mujoco
|
| 95 |
+
pyparsing==3.2.0
|
| 96 |
+
# via
|
| 97 |
+
# dm-control
|
| 98 |
+
# matplotlib
|
| 99 |
+
python-dateutil==2.9.0.post0
|
| 100 |
+
# via matplotlib
|
| 101 |
+
requests==2.32.3
|
| 102 |
+
# via dm-control
|
| 103 |
+
rich==13.9.4
|
| 104 |
+
# via tyro
|
| 105 |
+
scipy==1.14.1
|
| 106 |
+
# via dm-control
|
| 107 |
+
setuptools==75.6.0
|
| 108 |
+
# via
|
| 109 |
+
# dm-control
|
| 110 |
+
# imageio-ffmpeg
|
| 111 |
+
# labmaze
|
| 112 |
+
shtab==1.7.1
|
| 113 |
+
# via tyro
|
| 114 |
+
six==1.17.0
|
| 115 |
+
# via python-dateutil
|
| 116 |
+
tqdm==4.67.1
|
| 117 |
+
# via dm-control
|
| 118 |
+
typeguard==4.4.1
|
| 119 |
+
# via tyro
|
| 120 |
+
typing-extensions==4.12.2
|
| 121 |
+
# via
|
| 122 |
+
# -r examples/aloha_sim/requirements.in
|
| 123 |
+
# gymnasium
|
| 124 |
+
# rich
|
| 125 |
+
# typeguard
|
| 126 |
+
# tyro
|
| 127 |
+
tyro==0.9.2
|
| 128 |
+
# via -r examples/aloha_sim/requirements.in
|
| 129 |
+
urllib3==2.2.3
|
| 130 |
+
# via requests
|
| 131 |
+
websockets==14.1
|
| 132 |
+
# via -r examples/aloha_sim/requirements.in
|
openpi/examples/aloha_sim/saver.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import pathlib
|
| 3 |
+
|
| 4 |
+
import imageio
|
| 5 |
+
import numpy as np
|
| 6 |
+
from openpi_client.runtime import subscriber as _subscriber
|
| 7 |
+
from typing_extensions import override
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class VideoSaver(_subscriber.Subscriber):
|
| 11 |
+
"""Saves episode data."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
|
| 14 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 15 |
+
self._out_dir = out_dir
|
| 16 |
+
self._images: list[np.ndarray] = []
|
| 17 |
+
self._subsample = subsample
|
| 18 |
+
|
| 19 |
+
@override
|
| 20 |
+
def on_episode_start(self) -> None:
|
| 21 |
+
self._images = []
|
| 22 |
+
|
| 23 |
+
@override
|
| 24 |
+
def on_step(self, observation: dict, action: dict) -> None:
|
| 25 |
+
im = observation["images"]["cam_high"] # [C, H, W]
|
| 26 |
+
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
| 27 |
+
self._images.append(im)
|
| 28 |
+
|
| 29 |
+
@override
|
| 30 |
+
def on_episode_end(self) -> None:
|
| 31 |
+
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
|
| 32 |
+
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
|
| 33 |
+
out_path = self._out_dir / f"out_{next_idx}.mp4"
|
| 34 |
+
|
| 35 |
+
logging.info(f"Saving video to {out_path}")
|
| 36 |
+
imageio.mimwrite(
|
| 37 |
+
out_path,
|
| 38 |
+
[np.asarray(x) for x in self._images[:: self._subsample]],
|
| 39 |
+
fps=50 // max(1, self._subsample),
|
| 40 |
+
)
|
openpi/examples/convert_jax_model_to_pytorch.py
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
|
| 4 |
+
|
| 5 |
+
This script loads a JAX model checkpoint using orbax and can either:
|
| 6 |
+
1. Print out all the parameter keys in a hierarchical structure for inspection
|
| 7 |
+
2. Convert the JAX model to PyTorch format using our PI0Pytorch model
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
# Just inspect keys:
|
| 11 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 12 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
| 13 |
+
|
| 14 |
+
# Convert to PyTorch:
|
| 15 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 16 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
# pi0_droid
|
| 20 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
|
| 21 |
+
|
| 22 |
+
# pi0_aloha_sim
|
| 23 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
| 24 |
+
|
| 25 |
+
# pi05_droid
|
| 26 |
+
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import json
|
| 30 |
+
import os
|
| 31 |
+
import pathlib
|
| 32 |
+
import shutil
|
| 33 |
+
from typing import Literal
|
| 34 |
+
|
| 35 |
+
from flax.nnx import traversals
|
| 36 |
+
import numpy as np
|
| 37 |
+
import orbax.checkpoint as ocp
|
| 38 |
+
import safetensors
|
| 39 |
+
import torch
|
| 40 |
+
import tyro
|
| 41 |
+
|
| 42 |
+
import openpi.models.gemma
|
| 43 |
+
import openpi.models.model
|
| 44 |
+
import openpi.models.pi0_config
|
| 45 |
+
import openpi.models_pytorch.pi0_pytorch
|
| 46 |
+
from openpi.training import utils
|
| 47 |
+
import openpi.training.config as _config
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def slice_paligemma_state_dict(state_dict, config):
|
| 51 |
+
"""Convert PaliGemma JAX parameters to PyTorch format."""
|
| 52 |
+
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
| 53 |
+
|
| 54 |
+
# patch embeddings
|
| 55 |
+
jax_key = f"img/embedding/kernel{suffix}"
|
| 56 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
|
| 57 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
|
| 58 |
+
|
| 59 |
+
jax_key = f"img/embedding/bias{suffix}"
|
| 60 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
|
| 61 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 62 |
+
|
| 63 |
+
# positional embeddings
|
| 64 |
+
jax_key = f"img/pos_embedding{suffix}"
|
| 65 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
|
| 66 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
|
| 67 |
+
|
| 68 |
+
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
| 69 |
+
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
| 70 |
+
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
| 71 |
+
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
| 72 |
+
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
| 73 |
+
|
| 74 |
+
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
| 75 |
+
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
| 76 |
+
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
| 77 |
+
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
| 78 |
+
|
| 79 |
+
encoderblock_attention_0_key_kernel = state_dict.pop(
|
| 80 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
|
| 81 |
+
)
|
| 82 |
+
encoderblock_attention_0_key_bias = state_dict.pop(
|
| 83 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
|
| 84 |
+
)
|
| 85 |
+
encoderblock_attention_0_value_kernel = state_dict.pop(
|
| 86 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
|
| 87 |
+
)
|
| 88 |
+
encoderblock_attention_0_value_bias = state_dict.pop(
|
| 89 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
|
| 90 |
+
)
|
| 91 |
+
encoderblock_attention_0_query_kernel = state_dict.pop(
|
| 92 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
|
| 93 |
+
)
|
| 94 |
+
encoderblock_attention_0_query_bias = state_dict.pop(
|
| 95 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
|
| 96 |
+
)
|
| 97 |
+
encoderblock_attention_0_out_kernel = state_dict.pop(
|
| 98 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
|
| 99 |
+
)
|
| 100 |
+
encoderblock_attention_0_out_bias = state_dict.pop(
|
| 101 |
+
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
for i in range(config.vision_config.num_hidden_layers):
|
| 105 |
+
state_dict[
|
| 106 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
|
| 107 |
+
] = encoderblock_layernorm0_scale[i].transpose()
|
| 108 |
+
state_dict[
|
| 109 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
|
| 110 |
+
] = encoderblock_layernorm0_bias[i]
|
| 111 |
+
state_dict[
|
| 112 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
|
| 113 |
+
] = encoderblock_layernorm1_scale[i].transpose()
|
| 114 |
+
state_dict[
|
| 115 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
|
| 116 |
+
] = encoderblock_layernorm1_bias[i]
|
| 117 |
+
state_dict[
|
| 118 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
|
| 119 |
+
] = encoderblock_mlp_dense0_kernel[i].transpose()
|
| 120 |
+
state_dict[
|
| 121 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
|
| 122 |
+
] = encoderblock_mlp_dense0_bias[i]
|
| 123 |
+
state_dict[
|
| 124 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
|
| 125 |
+
] = encoderblock_mlp_dense1_kernel[i].transpose()
|
| 126 |
+
state_dict[
|
| 127 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
|
| 128 |
+
] = encoderblock_mlp_dense1_bias[i]
|
| 129 |
+
state_dict[
|
| 130 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
|
| 131 |
+
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 132 |
+
state_dict[
|
| 133 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
|
| 134 |
+
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 135 |
+
state_dict[
|
| 136 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
|
| 137 |
+
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 138 |
+
state_dict[
|
| 139 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
|
| 140 |
+
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 141 |
+
state_dict[
|
| 142 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
|
| 143 |
+
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 144 |
+
state_dict[
|
| 145 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
|
| 146 |
+
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 147 |
+
state_dict[
|
| 148 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
|
| 149 |
+
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
| 150 |
+
state_dict[
|
| 151 |
+
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
|
| 152 |
+
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
| 153 |
+
|
| 154 |
+
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
|
| 155 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
|
| 156 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 157 |
+
|
| 158 |
+
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
|
| 159 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
|
| 160 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 161 |
+
|
| 162 |
+
# multimodal projector
|
| 163 |
+
jax_key = f"img/head/kernel{suffix}"
|
| 164 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
|
| 165 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
| 166 |
+
|
| 167 |
+
jax_key = f"img/head/bias{suffix}"
|
| 168 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
|
| 169 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 170 |
+
|
| 171 |
+
# text decoder (gemma)
|
| 172 |
+
jax_key = f"llm/embedder/input_embedding{suffix}"
|
| 173 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
| 174 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 175 |
+
|
| 176 |
+
# pop the einsum attention + mlp representations
|
| 177 |
+
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
| 178 |
+
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
| 179 |
+
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
| 180 |
+
|
| 181 |
+
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
| 182 |
+
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
| 183 |
+
|
| 184 |
+
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
| 185 |
+
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
| 186 |
+
|
| 187 |
+
for i in range(config.text_config.num_hidden_layers):
|
| 188 |
+
q_proj_weight_reshaped = (
|
| 189 |
+
llm_attention_q_einsum[i]
|
| 190 |
+
.transpose(0, 2, 1)
|
| 191 |
+
.reshape(
|
| 192 |
+
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
| 193 |
+
)
|
| 194 |
+
)
|
| 195 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
|
| 196 |
+
q_proj_weight_reshaped
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 200 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
|
| 201 |
+
k_proj_weight_reshaped
|
| 202 |
+
)
|
| 203 |
+
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 204 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
|
| 205 |
+
v_proj_weight_reshaped
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
o_proj_weight_reshaped = (
|
| 209 |
+
llm_attention_attn_vec_einsum[i]
|
| 210 |
+
.transpose(2, 0, 1)
|
| 211 |
+
.reshape(
|
| 212 |
+
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
| 213 |
+
)
|
| 214 |
+
)
|
| 215 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
|
| 216 |
+
o_proj_weight_reshaped
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 220 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
|
| 221 |
+
gate_proj_weight.transpose()
|
| 222 |
+
)
|
| 223 |
+
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 224 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
|
| 225 |
+
up_proj_weight.transpose()
|
| 226 |
+
)
|
| 227 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
|
| 228 |
+
llm_mlp_linear[i].transpose()
|
| 229 |
+
)
|
| 230 |
+
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
|
| 231 |
+
llm_input_layernorm[i]
|
| 232 |
+
)
|
| 233 |
+
state_dict[
|
| 234 |
+
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
|
| 235 |
+
] = llm_post_attention_layernorm[i]
|
| 236 |
+
|
| 237 |
+
jax_key = f"llm/final_norm/scale{suffix}"
|
| 238 |
+
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
|
| 239 |
+
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
| 240 |
+
|
| 241 |
+
expert_dict = {}
|
| 242 |
+
final_state_dict = {}
|
| 243 |
+
|
| 244 |
+
# Expert-related keys to extract (including pi05 Dense layer parameters)
|
| 245 |
+
expert_keys = [
|
| 246 |
+
f"llm/final_norm_1/scale{suffix}",
|
| 247 |
+
f"llm/final_norm_1/Dense_0/bias{suffix}",
|
| 248 |
+
f"llm/final_norm_1/Dense_0/kernel{suffix}",
|
| 249 |
+
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
| 250 |
+
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
| 251 |
+
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
| 252 |
+
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
| 253 |
+
f"llm/layers/mlp_1/linear{suffix}",
|
| 254 |
+
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
| 255 |
+
f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
|
| 256 |
+
f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
|
| 257 |
+
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
| 258 |
+
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
|
| 259 |
+
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
|
| 260 |
+
]
|
| 261 |
+
|
| 262 |
+
for key, value in state_dict.items():
|
| 263 |
+
if key not in expert_keys:
|
| 264 |
+
final_state_dict[key] = torch.from_numpy(value)
|
| 265 |
+
else:
|
| 266 |
+
expert_dict[key] = value
|
| 267 |
+
|
| 268 |
+
return final_state_dict, expert_dict
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
|
| 272 |
+
"""Convert Gemma JAX parameters to PyTorch format."""
|
| 273 |
+
# Add missing attributes to config if they don't exist
|
| 274 |
+
if not hasattr(config, "vocab_size"):
|
| 275 |
+
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
|
| 276 |
+
if not hasattr(config, "hidden_size"):
|
| 277 |
+
config.hidden_size = config.width
|
| 278 |
+
if not hasattr(config, "num_hidden_layers"):
|
| 279 |
+
config.num_hidden_layers = config.depth
|
| 280 |
+
if not hasattr(config, "num_attention_heads"):
|
| 281 |
+
config.num_attention_heads = config.num_heads
|
| 282 |
+
|
| 283 |
+
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
| 284 |
+
|
| 285 |
+
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
| 286 |
+
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
| 287 |
+
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
| 288 |
+
|
| 289 |
+
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
| 290 |
+
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
| 291 |
+
|
| 292 |
+
# Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
|
| 293 |
+
if "pi05" in checkpoint_dir:
|
| 294 |
+
# Pi05 with adaptive normalization
|
| 295 |
+
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 296 |
+
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 297 |
+
llm_input_layernorm_kernel = state_dict.pop(
|
| 298 |
+
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
|
| 299 |
+
)
|
| 300 |
+
llm_post_attention_layernorm_kernel = state_dict.pop(
|
| 301 |
+
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
|
| 302 |
+
)
|
| 303 |
+
else:
|
| 304 |
+
# Regular pi0 with standard RMSNorm
|
| 305 |
+
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
| 306 |
+
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
| 307 |
+
|
| 308 |
+
for i in range(config.num_hidden_layers):
|
| 309 |
+
q_proj_weight_reshaped = (
|
| 310 |
+
llm_attention_q_einsum[i]
|
| 311 |
+
.transpose(0, 2, 1)
|
| 312 |
+
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
| 313 |
+
)
|
| 314 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
|
| 315 |
+
q_proj_weight_reshaped
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
| 319 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
|
| 320 |
+
k_proj_weight_reshaped
|
| 321 |
+
)
|
| 322 |
+
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
| 323 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
|
| 324 |
+
v_proj_weight_reshaped
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
o_proj_weight_reshaped = (
|
| 328 |
+
llm_attention_attn_vec_einsum[i]
|
| 329 |
+
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
| 330 |
+
.transpose(1, 0)
|
| 331 |
+
)
|
| 332 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
|
| 333 |
+
o_proj_weight_reshaped
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
| 337 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
|
| 338 |
+
gate_proj_weight.transpose()
|
| 339 |
+
)
|
| 340 |
+
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
| 341 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
|
| 342 |
+
up_proj_weight.transpose()
|
| 343 |
+
)
|
| 344 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
|
| 345 |
+
i
|
| 346 |
+
].transpose()
|
| 347 |
+
|
| 348 |
+
if "pi05" in checkpoint_dir:
|
| 349 |
+
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
| 350 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
|
| 351 |
+
llm_input_layernorm_bias[i]
|
| 352 |
+
)
|
| 353 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
|
| 354 |
+
llm_post_attention_layernorm_bias[i]
|
| 355 |
+
)
|
| 356 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
|
| 357 |
+
llm_input_layernorm_kernel[i].transpose()
|
| 358 |
+
)
|
| 359 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
|
| 360 |
+
llm_post_attention_layernorm_kernel[i].transpose()
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
# Regular pi0 with standard RMSNorm
|
| 364 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
|
| 365 |
+
llm_input_layernorm[i]
|
| 366 |
+
)
|
| 367 |
+
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
|
| 368 |
+
llm_post_attention_layernorm[i]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
# Handle final norm layer
|
| 372 |
+
if "pi05" in checkpoint_dir:
|
| 373 |
+
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
| 374 |
+
final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
|
| 375 |
+
final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
|
| 376 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
|
| 377 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
|
| 378 |
+
else:
|
| 379 |
+
# Regular pi0 with standard RMSNorm
|
| 380 |
+
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
|
| 381 |
+
f"llm/final_norm_{num_expert}/scale{suffix}"
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
|
| 385 |
+
|
| 386 |
+
final_state_dict = {}
|
| 387 |
+
for key, value in state_dict.items():
|
| 388 |
+
if not isinstance(value, torch.Tensor):
|
| 389 |
+
final_state_dict[key] = torch.from_numpy(value)
|
| 390 |
+
else:
|
| 391 |
+
final_state_dict[key] = value
|
| 392 |
+
|
| 393 |
+
return final_state_dict
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
|
| 397 |
+
"""Load and process params by restoring via JAX model loader first.
|
| 398 |
+
This respects dtype conversions that occur during model restore.
|
| 399 |
+
"""
|
| 400 |
+
# Use repository restore utility to load a pure dict of params (value suffix removed)
|
| 401 |
+
params = openpi.models.model.restore_params(
|
| 402 |
+
f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def load_jax_model_and_print_keys(checkpoint_dir: str):
|
| 409 |
+
"""
|
| 410 |
+
Load JAX model from checkpoint and print all parameter keys.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
checkpoint_dir: Path to the checkpoint directory
|
| 414 |
+
"""
|
| 415 |
+
checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
|
| 416 |
+
# Initialize checkpointer
|
| 417 |
+
checkpointer = ocp.PyTreeCheckpointer()
|
| 418 |
+
metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
|
| 419 |
+
print(utils.array_tree_to_info(metadata))
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def convert_pi0_checkpoint(
|
| 423 |
+
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
|
| 424 |
+
):
|
| 425 |
+
"""
|
| 426 |
+
Convert PI0 JAX checkpoint to PyTorch format.
|
| 427 |
+
|
| 428 |
+
Args:
|
| 429 |
+
checkpoint_dir: Path to the JAX checkpoint
|
| 430 |
+
precision: Model precision (float32, bfloat16, float16)
|
| 431 |
+
output_path: Path to save the converted PyTorch model
|
| 432 |
+
model_config: Model config
|
| 433 |
+
"""
|
| 434 |
+
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
|
| 435 |
+
print(f"Model config: {model_config}")
|
| 436 |
+
|
| 437 |
+
# Break down orbax ckpts by restoring via JAX to respect dtype
|
| 438 |
+
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
|
| 439 |
+
|
| 440 |
+
# Process projection params
|
| 441 |
+
if model_config.pi05:
|
| 442 |
+
keys = [
|
| 443 |
+
"action_in_proj",
|
| 444 |
+
"action_out_proj",
|
| 445 |
+
"time_mlp_in",
|
| 446 |
+
"time_mlp_out",
|
| 447 |
+
]
|
| 448 |
+
else:
|
| 449 |
+
keys = [
|
| 450 |
+
"state_proj",
|
| 451 |
+
"action_in_proj",
|
| 452 |
+
"action_out_proj",
|
| 453 |
+
"action_time_mlp_in",
|
| 454 |
+
"action_time_mlp_out",
|
| 455 |
+
]
|
| 456 |
+
|
| 457 |
+
projection_params = {}
|
| 458 |
+
for key in keys:
|
| 459 |
+
kernel_params = initial_params["projection_params"][key]["kernel"]
|
| 460 |
+
bias_params = initial_params["projection_params"][key]["bias"]
|
| 461 |
+
if isinstance(kernel_params, dict):
|
| 462 |
+
weight = kernel_params["value"]
|
| 463 |
+
bias = bias_params["value"]
|
| 464 |
+
else:
|
| 465 |
+
weight = kernel_params
|
| 466 |
+
bias = bias_params
|
| 467 |
+
|
| 468 |
+
pytorch_weight_key = f"{key}.weight"
|
| 469 |
+
pytorch_bias_key = f"{key}.bias"
|
| 470 |
+
|
| 471 |
+
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
|
| 472 |
+
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
|
| 473 |
+
|
| 474 |
+
# Create configs based on checkpoint path
|
| 475 |
+
# All models use the same PaliGemma config structure
|
| 476 |
+
class PaliGemmaConfig:
|
| 477 |
+
def __init__(self):
|
| 478 |
+
self.vision_config = type(
|
| 479 |
+
"obj",
|
| 480 |
+
(object,),
|
| 481 |
+
{
|
| 482 |
+
"hidden_size": 1152,
|
| 483 |
+
"num_hidden_layers": 27,
|
| 484 |
+
"num_attention_heads": 16,
|
| 485 |
+
"intermediate_size": 4304,
|
| 486 |
+
"patch_size": 14,
|
| 487 |
+
"projection_dim": 2048,
|
| 488 |
+
},
|
| 489 |
+
)()
|
| 490 |
+
self.text_config = type(
|
| 491 |
+
"obj",
|
| 492 |
+
(object,),
|
| 493 |
+
{
|
| 494 |
+
"hidden_size": 2048,
|
| 495 |
+
"num_hidden_layers": 18,
|
| 496 |
+
"num_attention_heads": 8,
|
| 497 |
+
"head_dim": 256,
|
| 498 |
+
"intermediate_size": 16384,
|
| 499 |
+
},
|
| 500 |
+
)()
|
| 501 |
+
|
| 502 |
+
paligemma_config = PaliGemmaConfig()
|
| 503 |
+
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
|
| 504 |
+
|
| 505 |
+
# Process PaliGemma weights
|
| 506 |
+
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
|
| 507 |
+
|
| 508 |
+
# Process Gemma weights from expert_params
|
| 509 |
+
gemma_params = slice_gemma_state_dict(
|
| 510 |
+
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Instantiate model
|
| 514 |
+
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
|
| 515 |
+
|
| 516 |
+
# Combine all parameters (no prefix needed for our model structure)
|
| 517 |
+
all_params = {**paligemma_params, **gemma_params, **projection_params}
|
| 518 |
+
|
| 519 |
+
# Load state dict
|
| 520 |
+
pi0_model.load_state_dict(all_params, strict=False)
|
| 521 |
+
|
| 522 |
+
if precision == "float32":
|
| 523 |
+
pi0_model = pi0_model.to(torch.float32)
|
| 524 |
+
elif precision == "bfloat16":
|
| 525 |
+
pi0_model = pi0_model.to(torch.bfloat16)
|
| 526 |
+
else:
|
| 527 |
+
raise ValueError(f"Invalid precision: {precision}")
|
| 528 |
+
|
| 529 |
+
# Save the converted model using safetensors
|
| 530 |
+
os.makedirs(output_path, exist_ok=True)
|
| 531 |
+
|
| 532 |
+
# Save model weights as SafeTensors using save_model to handle tied weights
|
| 533 |
+
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
|
| 534 |
+
|
| 535 |
+
# Copy assets folder if it exists
|
| 536 |
+
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
|
| 537 |
+
if assets_source.exists():
|
| 538 |
+
assets_dest = pathlib.Path(output_path) / "assets"
|
| 539 |
+
if assets_dest.exists():
|
| 540 |
+
shutil.rmtree(assets_dest)
|
| 541 |
+
shutil.copytree(assets_source, assets_dest)
|
| 542 |
+
|
| 543 |
+
# Save config as JSON for reference
|
| 544 |
+
config_dict = {
|
| 545 |
+
"action_dim": model_config.action_dim,
|
| 546 |
+
"action_horizon": model_config.action_horizon,
|
| 547 |
+
"paligemma_variant": model_config.paligemma_variant,
|
| 548 |
+
"action_expert_variant": model_config.action_expert_variant,
|
| 549 |
+
"precision": precision,
|
| 550 |
+
}
|
| 551 |
+
with open(os.path.join(output_path, "config.json"), "w") as f:
|
| 552 |
+
json.dump(config_dict, f, indent=2)
|
| 553 |
+
|
| 554 |
+
print("Model conversion completed successfully!")
|
| 555 |
+
print(f"Model saved to {output_path}")
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def main(
|
| 559 |
+
checkpoint_dir: str,
|
| 560 |
+
config_name: str,
|
| 561 |
+
output_path: str | None = None,
|
| 562 |
+
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
|
| 563 |
+
*,
|
| 564 |
+
inspect_only: bool = False,
|
| 565 |
+
):
|
| 566 |
+
"""Load JAX model and optionally convert to PyTorch.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
checkpoint_dir: Path to the JAX checkpoint directory
|
| 570 |
+
output_path: Path to save converted PyTorch model (required for conversion)
|
| 571 |
+
precision: Precision for model conversion
|
| 572 |
+
inspect_only: Only inspect parameter keys, don't convert
|
| 573 |
+
"""
|
| 574 |
+
model_config = _config.get_config(config_name).model
|
| 575 |
+
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
|
| 576 |
+
raise ValueError(f"Config {config_name} is not a Pi0Config")
|
| 577 |
+
if inspect_only:
|
| 578 |
+
load_jax_model_and_print_keys(checkpoint_dir)
|
| 579 |
+
else:
|
| 580 |
+
if not output_path:
|
| 581 |
+
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
|
| 582 |
+
return
|
| 583 |
+
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
if __name__ == "__main__":
|
| 587 |
+
tyro.cli(main)
|
openpi/examples/droid/README.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DROID Policies in openpi
|
| 2 |
+
|
| 3 |
+
We offer instructions for:
|
| 4 |
+
- [Running inference for our best $pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
|
| 5 |
+
- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies)
|
| 6 |
+
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
|
| 7 |
+
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
|
| 8 |
+
|
| 9 |
+
## Running DROID Inference
|
| 10 |
+
|
| 11 |
+
This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
### Step 1: Start a policy server
|
| 15 |
+
|
| 16 |
+
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
|
| 17 |
+
|
| 18 |
+
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
|
| 19 |
+
2. Start the OpenPI server via the following command:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
You can also run the equivalent command below:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
uv run scripts/serve_policy.py --env=DROID
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### Step 2: Run the DROID robot
|
| 32 |
+
|
| 33 |
+
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
|
| 34 |
+
2. On the control laptop, activate your DROID conda environment.
|
| 35 |
+
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
|
| 36 |
+
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
|
| 37 |
+
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
|
| 38 |
+
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
|
| 39 |
+
7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
|
| 46 |
+
|
| 47 |
+
## Troubleshooting
|
| 48 |
+
|
| 49 |
+
| Issue | Solution |
|
| 50 |
+
|-------|----------|
|
| 51 |
+
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
|
| 52 |
+
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
|
| 53 |
+
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
|
| 54 |
+
| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
## Running Other Policies
|
| 58 |
+
|
| 59 |
+
We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
|
| 60 |
+
|
| 61 |
+
```
|
| 62 |
+
# Train from pi0-FAST, using FAST tokenizer
|
| 63 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
|
| 64 |
+
|
| 65 |
+
# Train from pi0, using flow matching
|
| 66 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
|
| 67 |
+
|
| 68 |
+
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
|
| 69 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
|
| 70 |
+
|
| 71 |
+
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
|
| 72 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
|
| 73 |
+
|
| 74 |
+
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
|
| 75 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
|
| 76 |
+
|
| 77 |
+
# Trained from PaliGemma, using FSQ tokenizer.
|
| 78 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
|
| 79 |
+
|
| 80 |
+
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
|
| 81 |
+
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
|
openpi/examples/droid/README_train.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training on DROID
|
| 2 |
+
|
| 3 |
+
Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline.
|
| 4 |
+
(small differences in data loading and the used action space) -- For a tutorial on how to fine-tune your model with a smaller, custom dataset collected on the DROID platform, see below.
|
| 5 |
+
|
| 6 |
+
In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough
|
| 7 |
+
for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.
|
| 8 |
+
|
| 9 |
+
## Install
|
| 10 |
+
|
| 11 |
+
We need a few additional dependencies for RLDS data loading. Run:
|
| 12 |
+
```bash
|
| 13 |
+
uv sync --group rlds
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Download DROID dataset
|
| 17 |
+
|
| 18 |
+
You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
|
| 19 |
+
```
|
| 20 |
+
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
|
| 24 |
+
|
| 25 |
+
You will need 1.8TB of disk storage to download the DROID RLDS dataset.
|
| 26 |
+
|
| 27 |
+
## Run
|
| 28 |
+
|
| 29 |
+
First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
|
| 30 |
+
|
| 31 |
+
Then, compute normalization statistics (this will take ~10 minutes):
|
| 32 |
+
```bash
|
| 33 |
+
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Run training:
|
| 37 |
+
```bash
|
| 38 |
+
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
**Note**: The original pi0.5-DROID model was trained with joint velocity actions.
|
| 42 |
+
Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate).
|
| 43 |
+
Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
## Compute Requirements
|
| 47 |
+
|
| 48 |
+
Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch).
|
| 49 |
+
If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
|
| 50 |
+
|
| 51 |
+
We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
## Data Filtering
|
| 55 |
+
|
| 56 |
+
Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance.
|
| 57 |
+
|
| 58 |
+
By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
|
| 59 |
+
|
| 60 |
+
**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
|
| 61 |
+
|
| 62 |
+
## RoboArena
|
| 63 |
+
|
| 64 |
+
Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
|
| 65 |
+
|
| 66 |
+
If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Fine-Tuning on Custom DROID Datasets
|
| 70 |
+
|
| 71 |
+
Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. Like for other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.
|
| 72 |
+
|
| 73 |
+
Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for it's better efficiency (see the example above).
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## Step 1: Converting your custom DROID dataset to LeRobot
|
| 77 |
+
|
| 78 |
+
We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
|
| 79 |
+
```
|
| 80 |
+
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
|
| 84 |
+
```
|
| 85 |
+
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
|
| 89 |
+
|
| 90 |
+
Now, we will use the `convert_droid_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):
|
| 91 |
+
```
|
| 92 |
+
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
## Step 2: Run fine-tuning with your custom dataset
|
| 96 |
+
|
| 97 |
+
Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
|
| 98 |
+
You can modify the config easily to work with other base models, or use your custom DROID dataset in `config.py` (seach for `pi05_droid_finetune`).
|
| 99 |
+
|
| 100 |
+
To launch training:
|
| 101 |
+
```
|
| 102 |
+
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
|
| 106 |
+
|
openpi/examples/droid/compute_droid_nonidle_ranges.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
|
| 3 |
+
that should be sampled during training (all others are filtered out).
|
| 4 |
+
|
| 5 |
+
Filtering logic:
|
| 6 |
+
We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
|
| 7 |
+
(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
|
| 8 |
+
this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
|
| 9 |
+
ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
|
| 10 |
+
filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).
|
| 11 |
+
|
| 12 |
+
This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
|
| 13 |
+
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import tensorflow as tf
|
| 22 |
+
import tensorflow_datasets as tfds
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
|
| 26 |
+
|
| 27 |
+
builder = tfds.builder_from_directory(
|
| 28 |
+
# path to the `droid` directory (not its parent)
|
| 29 |
+
builder_dir="<path_to_droid_dataset_tfds_files>",
|
| 30 |
+
)
|
| 31 |
+
ds = builder.as_dataset(split="train", shuffle_files=False)
|
| 32 |
+
tf.data.experimental.ignore_errors(ds)
|
| 33 |
+
|
| 34 |
+
keep_ranges_path = "<path_to_where_to_save_the_json>"
|
| 35 |
+
|
| 36 |
+
min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
|
| 37 |
+
min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
|
| 38 |
+
filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
|
| 39 |
+
|
| 40 |
+
keep_ranges_map = {}
|
| 41 |
+
if Path(keep_ranges_path).exists():
|
| 42 |
+
with Path(keep_ranges_path).open("r") as f:
|
| 43 |
+
keep_ranges_map = json.load(f)
|
| 44 |
+
print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
|
| 45 |
+
|
| 46 |
+
for ep_idx, ep in enumerate(tqdm(ds)):
|
| 47 |
+
recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
|
| 48 |
+
file_path = ep["episode_metadata"]["file_path"].numpy().decode()
|
| 49 |
+
|
| 50 |
+
key = f"{recording_folderpath}--{file_path}"
|
| 51 |
+
if key in keep_ranges_map:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
|
| 55 |
+
joint_velocities = np.array(joint_velocities)
|
| 56 |
+
|
| 57 |
+
is_idle_array = np.hstack(
|
| 58 |
+
[np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# Find what steps go from idle to non-idle and vice-versa
|
| 62 |
+
is_idle_padded = np.concatenate(
|
| 63 |
+
[[False], is_idle_array, [False]]
|
| 64 |
+
) # Start and end with False, so idle at first step is a start of motion
|
| 65 |
+
|
| 66 |
+
is_idle_diff = np.diff(is_idle_padded.astype(int))
|
| 67 |
+
is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle
|
| 68 |
+
is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle
|
| 69 |
+
|
| 70 |
+
# Find which steps correspond to idle segments of length at least min_idle_len
|
| 71 |
+
true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
|
| 72 |
+
is_idle_true_starts = is_idle_true_starts[true_segment_masks]
|
| 73 |
+
is_idle_true_ends = is_idle_true_ends[true_segment_masks]
|
| 74 |
+
|
| 75 |
+
keep_mask = np.ones(len(joint_velocities), dtype=bool)
|
| 76 |
+
for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
|
| 77 |
+
keep_mask[start:end] = False
|
| 78 |
+
|
| 79 |
+
# Get all non-idle ranges of at least 16
|
| 80 |
+
# Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
|
| 81 |
+
keep_padded = np.concatenate([[False], keep_mask, [False]])
|
| 82 |
+
|
| 83 |
+
keep_diff = np.diff(keep_padded.astype(int))
|
| 84 |
+
keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
|
| 85 |
+
keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
|
| 86 |
+
|
| 87 |
+
# Find which steps correspond to non-idle segments of length at least min_non_idle_len
|
| 88 |
+
true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
|
| 89 |
+
keep_true_starts = keep_true_starts[true_segment_masks]
|
| 90 |
+
keep_true_ends = keep_true_ends[true_segment_masks]
|
| 91 |
+
|
| 92 |
+
# Add mapping from episode unique ID key to list of non-idle ranges to keep
|
| 93 |
+
keep_ranges_map[key] = []
|
| 94 |
+
for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
|
| 95 |
+
keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
|
| 96 |
+
|
| 97 |
+
if ep_idx % 1000 == 0:
|
| 98 |
+
with Path(keep_ranges_path).open("w") as f:
|
| 99 |
+
json.dump(keep_ranges_map, f)
|
| 100 |
+
|
| 101 |
+
print("Done!")
|
| 102 |
+
with Path(keep_ranges_path).open("w") as f:
|
| 103 |
+
json.dump(keep_ranges_map, f)
|
openpi/examples/droid/convert_droid_data_to_lerobot.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
|
| 6 |
+
|
| 7 |
+
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
| 8 |
+
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
| 9 |
+
|
| 10 |
+
The resulting dataset will get saved to the $LEROBOT_HOME directory.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
import copy
|
| 15 |
+
import glob
|
| 16 |
+
import json
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import shutil
|
| 19 |
+
|
| 20 |
+
import cv2
|
| 21 |
+
import h5py
|
| 22 |
+
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
|
| 23 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 24 |
+
import numpy as np
|
| 25 |
+
from PIL import Image
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
import tyro
|
| 28 |
+
|
| 29 |
+
REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def resize_image(image, size):
|
| 33 |
+
image = Image.fromarray(image)
|
| 34 |
+
return np.array(image.resize(size, resample=Image.BICUBIC))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(data_dir: str, *, push_to_hub: bool = False):
|
| 38 |
+
# Clean up any existing dataset in the output directory
|
| 39 |
+
output_path = HF_LEROBOT_HOME / REPO_NAME
|
| 40 |
+
if output_path.exists():
|
| 41 |
+
shutil.rmtree(output_path)
|
| 42 |
+
data_dir = Path(data_dir)
|
| 43 |
+
|
| 44 |
+
# Create LeRobot dataset, define features to store
|
| 45 |
+
# We will follow the DROID data naming conventions here.
|
| 46 |
+
# LeRobot assumes that dtype of image data is `image`
|
| 47 |
+
dataset = LeRobotDataset.create(
|
| 48 |
+
repo_id=REPO_NAME,
|
| 49 |
+
robot_type="panda",
|
| 50 |
+
fps=15, # DROID data is typically recorded at 15fps
|
| 51 |
+
features={
|
| 52 |
+
# We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
|
| 53 |
+
"exterior_image_1_left": {
|
| 54 |
+
"dtype": "image",
|
| 55 |
+
"shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
|
| 56 |
+
"names": ["height", "width", "channel"],
|
| 57 |
+
},
|
| 58 |
+
"exterior_image_2_left": {
|
| 59 |
+
"dtype": "image",
|
| 60 |
+
"shape": (180, 320, 3),
|
| 61 |
+
"names": ["height", "width", "channel"],
|
| 62 |
+
},
|
| 63 |
+
"wrist_image_left": {
|
| 64 |
+
"dtype": "image",
|
| 65 |
+
"shape": (180, 320, 3),
|
| 66 |
+
"names": ["height", "width", "channel"],
|
| 67 |
+
},
|
| 68 |
+
"joint_position": {
|
| 69 |
+
"dtype": "float32",
|
| 70 |
+
"shape": (7,),
|
| 71 |
+
"names": ["joint_position"],
|
| 72 |
+
},
|
| 73 |
+
"gripper_position": {
|
| 74 |
+
"dtype": "float32",
|
| 75 |
+
"shape": (1,),
|
| 76 |
+
"names": ["gripper_position"],
|
| 77 |
+
},
|
| 78 |
+
"actions": {
|
| 79 |
+
"dtype": "float32",
|
| 80 |
+
"shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
|
| 81 |
+
"names": ["actions"],
|
| 82 |
+
},
|
| 83 |
+
},
|
| 84 |
+
image_writer_threads=10,
|
| 85 |
+
image_writer_processes=5,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Load language annotations
|
| 89 |
+
# Note: we load the DROID language annotations for this example, but you can manually define them for your own data
|
| 90 |
+
with (data_dir / "aggregated-annotations-030724.json").open() as f:
|
| 91 |
+
language_annotations = json.load(f)
|
| 92 |
+
|
| 93 |
+
# Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
|
| 94 |
+
# We assume the following directory structure:
|
| 95 |
+
# RAW_DROID_PATH/
|
| 96 |
+
# - <...>/
|
| 97 |
+
# - recordings/
|
| 98 |
+
# - MP4/
|
| 99 |
+
# - <camera_id>.mp4 # single-view video of left stereo pair camera
|
| 100 |
+
# - trajectory.hdf5
|
| 101 |
+
# - <...>/
|
| 102 |
+
episode_paths = list(data_dir.glob("**/trajectory.h5"))
|
| 103 |
+
print(f"Found {len(episode_paths)} episodes for conversion")
|
| 104 |
+
|
| 105 |
+
# We will loop over each dataset_name and write episodes to the LeRobot dataset
|
| 106 |
+
for episode_path in tqdm(episode_paths, desc="Converting episodes"):
|
| 107 |
+
# Load raw data
|
| 108 |
+
recording_folderpath = episode_path.parent / "recordings" / "MP4"
|
| 109 |
+
trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
|
| 110 |
+
|
| 111 |
+
# To load the language instruction, we need to parse out the episode_id from the metadata file
|
| 112 |
+
# Again, you can modify this step for your own data, to load your own language instructions
|
| 113 |
+
metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
|
| 114 |
+
episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
|
| 115 |
+
language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
|
| 116 |
+
"language_instruction1"
|
| 117 |
+
]
|
| 118 |
+
print(f"Converting episode with language instruction: {language_instruction}")
|
| 119 |
+
|
| 120 |
+
# Write to LeRobot dataset
|
| 121 |
+
for step in trajectory:
|
| 122 |
+
camera_type_dict = step["observation"]["camera_type"]
|
| 123 |
+
wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
|
| 124 |
+
exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
|
| 125 |
+
dataset.add_frame(
|
| 126 |
+
{
|
| 127 |
+
# Note: need to flip BGR --> RGB for loaded images
|
| 128 |
+
"exterior_image_1_left": resize_image(
|
| 129 |
+
step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
|
| 130 |
+
),
|
| 131 |
+
"exterior_image_2_left": resize_image(
|
| 132 |
+
step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
|
| 133 |
+
),
|
| 134 |
+
"wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
|
| 135 |
+
"joint_position": np.asarray(
|
| 136 |
+
step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
|
| 137 |
+
),
|
| 138 |
+
"gripper_position": np.asarray(
|
| 139 |
+
step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
|
| 140 |
+
),
|
| 141 |
+
# Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
|
| 142 |
+
"actions": np.concatenate(
|
| 143 |
+
[step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
|
| 144 |
+
),
|
| 145 |
+
"task": language_instruction,
|
| 146 |
+
}
|
| 147 |
+
)
|
| 148 |
+
dataset.save_episode()
|
| 149 |
+
|
| 150 |
+
# Optionally push to the Hugging Face Hub
|
| 151 |
+
if push_to_hub:
|
| 152 |
+
dataset.push_to_hub(
|
| 153 |
+
tags=["libero", "panda", "rlds"],
|
| 154 |
+
private=False,
|
| 155 |
+
push_videos=True,
|
| 156 |
+
license="apache-2.0",
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
##########################################################################################################
|
| 161 |
+
################ The rest of this file are functions to parse the raw DROID data #########################
|
| 162 |
+
################ You don't need to worry about understanding this part #########################
|
| 163 |
+
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
|
| 164 |
+
##########################################################################################################
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
camera_type_dict = {
|
| 168 |
+
"hand_camera_id": 0,
|
| 169 |
+
"varied_camera_1_id": 1,
|
| 170 |
+
"varied_camera_2_id": 1,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
camera_type_to_string_dict = {
|
| 174 |
+
0: "hand_camera",
|
| 175 |
+
1: "varied_camera",
|
| 176 |
+
2: "fixed_camera",
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_camera_type(cam_id):
|
| 181 |
+
if cam_id not in camera_type_dict:
|
| 182 |
+
return None
|
| 183 |
+
type_int = camera_type_dict[cam_id]
|
| 184 |
+
return camera_type_to_string_dict[type_int]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class MP4Reader:
|
| 188 |
+
def __init__(self, filepath, serial_number):
|
| 189 |
+
# Save Parameters #
|
| 190 |
+
self.serial_number = serial_number
|
| 191 |
+
self._index = 0
|
| 192 |
+
|
| 193 |
+
# Open Video Reader #
|
| 194 |
+
self._mp4_reader = cv2.VideoCapture(filepath)
|
| 195 |
+
if not self._mp4_reader.isOpened():
|
| 196 |
+
raise RuntimeError("Corrupted MP4 File")
|
| 197 |
+
|
| 198 |
+
def set_reading_parameters(
|
| 199 |
+
self,
|
| 200 |
+
image=True, # noqa: FBT002
|
| 201 |
+
concatenate_images=False, # noqa: FBT002
|
| 202 |
+
resolution=(0, 0),
|
| 203 |
+
resize_func=None,
|
| 204 |
+
):
|
| 205 |
+
# Save Parameters #
|
| 206 |
+
self.image = image
|
| 207 |
+
self.concatenate_images = concatenate_images
|
| 208 |
+
self.resolution = resolution
|
| 209 |
+
self.resize_func = cv2.resize
|
| 210 |
+
self.skip_reading = not image
|
| 211 |
+
if self.skip_reading:
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
def get_frame_resolution(self):
|
| 215 |
+
width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
|
| 216 |
+
height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
|
| 217 |
+
return (width, height)
|
| 218 |
+
|
| 219 |
+
def get_frame_count(self):
|
| 220 |
+
if self.skip_reading:
|
| 221 |
+
return 0
|
| 222 |
+
return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))
|
| 223 |
+
|
| 224 |
+
def set_frame_index(self, index):
|
| 225 |
+
if self.skip_reading:
|
| 226 |
+
return
|
| 227 |
+
|
| 228 |
+
if index < self._index:
|
| 229 |
+
self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
|
| 230 |
+
self._index = index
|
| 231 |
+
|
| 232 |
+
while self._index < index:
|
| 233 |
+
self.read_camera(ignore_data=True)
|
| 234 |
+
|
| 235 |
+
def _process_frame(self, frame):
|
| 236 |
+
frame = copy.deepcopy(frame)
|
| 237 |
+
if self.resolution == (0, 0):
|
| 238 |
+
return frame
|
| 239 |
+
return self.resize_func(frame, self.resolution)
|
| 240 |
+
|
| 241 |
+
def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
|
| 242 |
+
# Skip if Read Unnecesary #
|
| 243 |
+
if self.skip_reading:
|
| 244 |
+
return {}
|
| 245 |
+
|
| 246 |
+
# Read Camera #
|
| 247 |
+
success, frame = self._mp4_reader.read()
|
| 248 |
+
|
| 249 |
+
self._index += 1
|
| 250 |
+
if not success:
|
| 251 |
+
return None
|
| 252 |
+
if ignore_data:
|
| 253 |
+
return None
|
| 254 |
+
|
| 255 |
+
# Return Data #
|
| 256 |
+
data_dict = {}
|
| 257 |
+
|
| 258 |
+
if self.concatenate_images or "stereo" not in self.serial_number:
|
| 259 |
+
data_dict["image"] = {self.serial_number: self._process_frame(frame)}
|
| 260 |
+
else:
|
| 261 |
+
single_width = frame.shape[1] // 2
|
| 262 |
+
data_dict["image"] = {
|
| 263 |
+
self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
|
| 264 |
+
self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
return data_dict
|
| 268 |
+
|
| 269 |
+
def disable_camera(self):
|
| 270 |
+
if hasattr(self, "_mp4_reader"):
|
| 271 |
+
self._mp4_reader.release()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class RecordedMultiCameraWrapper:
|
| 275 |
+
def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006
|
| 276 |
+
# Save Camera Info #
|
| 277 |
+
self.camera_kwargs = camera_kwargs
|
| 278 |
+
|
| 279 |
+
# Open Camera Readers #
|
| 280 |
+
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
|
| 281 |
+
all_filepaths = mp4_filepaths
|
| 282 |
+
|
| 283 |
+
self.camera_dict = {}
|
| 284 |
+
for f in all_filepaths:
|
| 285 |
+
serial_number = f.split("/")[-1][:-4]
|
| 286 |
+
cam_type = get_camera_type(serial_number)
|
| 287 |
+
camera_kwargs.get(cam_type, {})
|
| 288 |
+
|
| 289 |
+
if f.endswith(".mp4"):
|
| 290 |
+
Reader = MP4Reader # noqa: N806
|
| 291 |
+
else:
|
| 292 |
+
raise ValueError
|
| 293 |
+
|
| 294 |
+
self.camera_dict[serial_number] = Reader(f, serial_number)
|
| 295 |
+
|
| 296 |
+
def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006
|
| 297 |
+
full_obs_dict = defaultdict(dict)
|
| 298 |
+
|
| 299 |
+
# Read Cameras In Randomized Order #
|
| 300 |
+
all_cam_ids = list(self.camera_dict.keys())
|
| 301 |
+
# random.shuffle(all_cam_ids)
|
| 302 |
+
|
| 303 |
+
for cam_id in all_cam_ids:
|
| 304 |
+
if "stereo" in cam_id:
|
| 305 |
+
continue
|
| 306 |
+
try:
|
| 307 |
+
cam_type = camera_type_dict[cam_id]
|
| 308 |
+
except KeyError:
|
| 309 |
+
print(f"{self.camera_dict} -- {camera_type_dict}")
|
| 310 |
+
raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904
|
| 311 |
+
curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
|
| 312 |
+
self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)
|
| 313 |
+
|
| 314 |
+
timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
|
| 315 |
+
if index is not None:
|
| 316 |
+
self.camera_dict[cam_id].set_frame_index(index)
|
| 317 |
+
|
| 318 |
+
data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)
|
| 319 |
+
|
| 320 |
+
# Process Returned Data #
|
| 321 |
+
if data_dict is None:
|
| 322 |
+
return None
|
| 323 |
+
for key in data_dict:
|
| 324 |
+
full_obs_dict[key].update(data_dict[key])
|
| 325 |
+
|
| 326 |
+
return full_obs_dict
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006
|
| 330 |
+
length = None
|
| 331 |
+
|
| 332 |
+
for key in hdf5_file:
|
| 333 |
+
if key in keys_to_ignore:
|
| 334 |
+
continue
|
| 335 |
+
|
| 336 |
+
curr_data = hdf5_file[key]
|
| 337 |
+
if isinstance(curr_data, h5py.Group):
|
| 338 |
+
curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
|
| 339 |
+
elif isinstance(curr_data, h5py.Dataset):
|
| 340 |
+
curr_length = len(curr_data)
|
| 341 |
+
else:
|
| 342 |
+
raise ValueError
|
| 343 |
+
|
| 344 |
+
if length is None:
|
| 345 |
+
length = curr_length
|
| 346 |
+
assert curr_length == length
|
| 347 |
+
|
| 348 |
+
return length
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006
|
| 352 |
+
data_dict = {}
|
| 353 |
+
|
| 354 |
+
for key in hdf5_file:
|
| 355 |
+
if key in keys_to_ignore:
|
| 356 |
+
continue
|
| 357 |
+
|
| 358 |
+
curr_data = hdf5_file[key]
|
| 359 |
+
if isinstance(curr_data, h5py.Group):
|
| 360 |
+
data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
|
| 361 |
+
elif isinstance(curr_data, h5py.Dataset):
|
| 362 |
+
data_dict[key] = curr_data[index]
|
| 363 |
+
else:
|
| 364 |
+
raise ValueError
|
| 365 |
+
|
| 366 |
+
return data_dict
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class TrajectoryReader:
|
| 370 |
+
def __init__(self, filepath, read_images=True): # noqa: FBT002
|
| 371 |
+
self._hdf5_file = h5py.File(filepath, "r")
|
| 372 |
+
is_video_folder = "observations/videos" in self._hdf5_file
|
| 373 |
+
self._read_images = read_images and is_video_folder
|
| 374 |
+
self._length = get_hdf5_length(self._hdf5_file)
|
| 375 |
+
self._video_readers = {}
|
| 376 |
+
self._index = 0
|
| 377 |
+
|
| 378 |
+
def length(self):
|
| 379 |
+
return self._length
|
| 380 |
+
|
| 381 |
+
def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006
|
| 382 |
+
# Make Sure We Read Within Range #
|
| 383 |
+
if index is None:
|
| 384 |
+
index = self._index
|
| 385 |
+
else:
|
| 386 |
+
assert not self._read_images
|
| 387 |
+
self._index = index
|
| 388 |
+
assert index < self._length
|
| 389 |
+
|
| 390 |
+
# Load Low Dimensional Data #
|
| 391 |
+
keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
|
| 392 |
+
timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)
|
| 393 |
+
|
| 394 |
+
# Increment Read Index #
|
| 395 |
+
self._index += 1
|
| 396 |
+
|
| 397 |
+
# Return Timestep #
|
| 398 |
+
return timestep
|
| 399 |
+
|
| 400 |
+
def close(self):
|
| 401 |
+
self._hdf5_file.close()
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def load_trajectory(
|
| 405 |
+
filepath=None,
|
| 406 |
+
read_cameras=True, # noqa: FBT002
|
| 407 |
+
recording_folderpath=None,
|
| 408 |
+
camera_kwargs={}, # noqa: B006
|
| 409 |
+
remove_skipped_steps=False, # noqa: FBT002
|
| 410 |
+
num_samples_per_traj=None,
|
| 411 |
+
num_samples_per_traj_coeff=1.5,
|
| 412 |
+
):
|
| 413 |
+
read_recording_folderpath = read_cameras and (recording_folderpath is not None)
|
| 414 |
+
|
| 415 |
+
traj_reader = TrajectoryReader(filepath)
|
| 416 |
+
if read_recording_folderpath:
|
| 417 |
+
camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
|
| 418 |
+
|
| 419 |
+
horizon = traj_reader.length()
|
| 420 |
+
timestep_list = []
|
| 421 |
+
|
| 422 |
+
# Choose Timesteps To Save #
|
| 423 |
+
if num_samples_per_traj:
|
| 424 |
+
num_to_save = num_samples_per_traj
|
| 425 |
+
if remove_skipped_steps:
|
| 426 |
+
num_to_save = int(num_to_save * num_samples_per_traj_coeff)
|
| 427 |
+
max_size = min(num_to_save, horizon)
|
| 428 |
+
indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
|
| 429 |
+
else:
|
| 430 |
+
indices_to_save = np.arange(horizon)
|
| 431 |
+
|
| 432 |
+
# Iterate Over Trajectory #
|
| 433 |
+
for i in indices_to_save:
|
| 434 |
+
# Get HDF5 Data #
|
| 435 |
+
timestep = traj_reader.read_timestep(index=i)
|
| 436 |
+
|
| 437 |
+
# If Applicable, Get Recorded Data #
|
| 438 |
+
if read_recording_folderpath:
|
| 439 |
+
timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
|
| 440 |
+
camera_type_dict = {
|
| 441 |
+
k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
|
| 442 |
+
}
|
| 443 |
+
camera_obs = camera_reader.read_cameras(
|
| 444 |
+
index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
|
| 445 |
+
)
|
| 446 |
+
camera_failed = camera_obs is None
|
| 447 |
+
|
| 448 |
+
# Add Data To Timestep If Successful #
|
| 449 |
+
if camera_failed:
|
| 450 |
+
break
|
| 451 |
+
timestep["observation"].update(camera_obs)
|
| 452 |
+
|
| 453 |
+
# Filter Steps #
|
| 454 |
+
step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
|
| 455 |
+
delete_skipped_step = step_skipped and remove_skipped_steps
|
| 456 |
+
|
| 457 |
+
# Save Filtered Timesteps #
|
| 458 |
+
if delete_skipped_step:
|
| 459 |
+
del timestep
|
| 460 |
+
else:
|
| 461 |
+
timestep_list.append(timestep)
|
| 462 |
+
|
| 463 |
+
# Remove Extra Transitions #
|
| 464 |
+
timestep_list = np.array(timestep_list)
|
| 465 |
+
if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
|
| 466 |
+
ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
|
| 467 |
+
timestep_list = timestep_list[ind_to_keep]
|
| 468 |
+
|
| 469 |
+
# Close Readers #
|
| 470 |
+
traj_reader.close()
|
| 471 |
+
|
| 472 |
+
# Return Data #
|
| 473 |
+
return timestep_list
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
if __name__ == "__main__":
|
| 477 |
+
tyro.cli(main)
|
openpi/examples/droid/main.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ruff: noqa
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import dataclasses
|
| 5 |
+
import datetime
|
| 6 |
+
import faulthandler
|
| 7 |
+
import os
|
| 8 |
+
import signal
|
| 9 |
+
import time
|
| 10 |
+
from moviepy.editor import ImageSequenceClip
|
| 11 |
+
import numpy as np
|
| 12 |
+
from openpi_client import image_tools
|
| 13 |
+
from openpi_client import websocket_client_policy
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from droid.robot_env import RobotEnv
|
| 17 |
+
import tqdm
|
| 18 |
+
import tyro
|
| 19 |
+
|
| 20 |
+
faulthandler.enable()
|
| 21 |
+
|
| 22 |
+
# DROID data collection frequency -- we slow down execution to match this frequency
|
| 23 |
+
DROID_CONTROL_FREQUENCY = 15
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclasses.dataclass
|
| 27 |
+
class Args:
|
| 28 |
+
# Hardware parameters
|
| 29 |
+
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
|
| 30 |
+
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
|
| 31 |
+
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
|
| 32 |
+
|
| 33 |
+
# Policy parameters
|
| 34 |
+
external_camera: str | None = (
|
| 35 |
+
None # which external camera should be fed to the policy, choose from ["left", "right"]
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Rollout parameters
|
| 39 |
+
max_timesteps: int = 600
|
| 40 |
+
# How many actions to execute from a predicted action chunk before querying policy server again
|
| 41 |
+
# 8 is usually a good default (equals 0.5 seconds of action execution).
|
| 42 |
+
open_loop_horizon: int = 8
|
| 43 |
+
|
| 44 |
+
# Remote server parameters
|
| 45 |
+
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
|
| 46 |
+
remote_port: int = (
|
| 47 |
+
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
|
| 52 |
+
# waiting for a new action chunk, it will raise an exception and the server connection dies.
|
| 53 |
+
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
|
| 54 |
+
@contextlib.contextmanager
|
| 55 |
+
def prevent_keyboard_interrupt():
|
| 56 |
+
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
|
| 57 |
+
interrupted = False
|
| 58 |
+
original_handler = signal.getsignal(signal.SIGINT)
|
| 59 |
+
|
| 60 |
+
def handler(signum, frame):
|
| 61 |
+
nonlocal interrupted
|
| 62 |
+
interrupted = True
|
| 63 |
+
|
| 64 |
+
signal.signal(signal.SIGINT, handler)
|
| 65 |
+
try:
|
| 66 |
+
yield
|
| 67 |
+
finally:
|
| 68 |
+
signal.signal(signal.SIGINT, original_handler)
|
| 69 |
+
if interrupted:
|
| 70 |
+
raise KeyboardInterrupt
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def main(args: Args):
|
| 74 |
+
# Make sure external camera is specified by user -- we only use one external camera for the policy
|
| 75 |
+
assert (
|
| 76 |
+
args.external_camera is not None and args.external_camera in ["left", "right"]
|
| 77 |
+
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
|
| 78 |
+
|
| 79 |
+
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
|
| 80 |
+
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
|
| 81 |
+
print("Created the droid env!")
|
| 82 |
+
|
| 83 |
+
# Connect to the policy server
|
| 84 |
+
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
|
| 85 |
+
|
| 86 |
+
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
|
| 87 |
+
|
| 88 |
+
while True:
|
| 89 |
+
instruction = input("Enter instruction: ")
|
| 90 |
+
|
| 91 |
+
# Rollout parameters
|
| 92 |
+
actions_from_chunk_completed = 0
|
| 93 |
+
pred_action_chunk = None
|
| 94 |
+
|
| 95 |
+
# Prepare to save video of rollout
|
| 96 |
+
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
|
| 97 |
+
video = []
|
| 98 |
+
bar = tqdm.tqdm(range(args.max_timesteps))
|
| 99 |
+
print("Running rollout... press Ctrl+C to stop early.")
|
| 100 |
+
for t_step in bar:
|
| 101 |
+
start_time = time.time()
|
| 102 |
+
try:
|
| 103 |
+
# Get the current observation
|
| 104 |
+
curr_obs = _extract_observation(
|
| 105 |
+
args,
|
| 106 |
+
env.get_observation(),
|
| 107 |
+
# Save the first observation to disk
|
| 108 |
+
save_to_disk=t_step == 0,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
video.append(curr_obs[f"{args.external_camera}_image"])
|
| 112 |
+
|
| 113 |
+
# Send websocket request to policy server if it's time to predict a new chunk
|
| 114 |
+
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
|
| 115 |
+
actions_from_chunk_completed = 0
|
| 116 |
+
|
| 117 |
+
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
|
| 118 |
+
# and improve latency.
|
| 119 |
+
request_data = {
|
| 120 |
+
"observation/exterior_image_1_left": image_tools.resize_with_pad(
|
| 121 |
+
curr_obs[f"{args.external_camera}_image"], 224, 224
|
| 122 |
+
),
|
| 123 |
+
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
|
| 124 |
+
"observation/joint_position": curr_obs["joint_position"],
|
| 125 |
+
"observation/gripper_position": curr_obs["gripper_position"],
|
| 126 |
+
"prompt": instruction,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
|
| 130 |
+
# Ctrl+C will be handled after the server call is complete
|
| 131 |
+
with prevent_keyboard_interrupt():
|
| 132 |
+
# this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
|
| 133 |
+
pred_action_chunk = policy_client.infer(request_data)["actions"]
|
| 134 |
+
assert pred_action_chunk.shape == (10, 8)
|
| 135 |
+
|
| 136 |
+
# Select current action to execute from chunk
|
| 137 |
+
action = pred_action_chunk[actions_from_chunk_completed]
|
| 138 |
+
actions_from_chunk_completed += 1
|
| 139 |
+
|
| 140 |
+
# Binarize gripper action
|
| 141 |
+
if action[-1].item() > 0.5:
|
| 142 |
+
# action[-1] = 1.0
|
| 143 |
+
action = np.concatenate([action[:-1], np.ones((1,))])
|
| 144 |
+
else:
|
| 145 |
+
# action[-1] = 0.0
|
| 146 |
+
action = np.concatenate([action[:-1], np.zeros((1,))])
|
| 147 |
+
|
| 148 |
+
# clip all dimensions of action to [-1, 1]
|
| 149 |
+
action = np.clip(action, -1, 1)
|
| 150 |
+
|
| 151 |
+
env.step(action)
|
| 152 |
+
|
| 153 |
+
# Sleep to match DROID data collection frequency
|
| 154 |
+
elapsed_time = time.time() - start_time
|
| 155 |
+
if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
|
| 156 |
+
time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
|
| 157 |
+
except KeyboardInterrupt:
|
| 158 |
+
break
|
| 159 |
+
|
| 160 |
+
video = np.stack(video)
|
| 161 |
+
save_filename = "video_" + timestamp
|
| 162 |
+
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
|
| 163 |
+
|
| 164 |
+
success: str | float | None = None
|
| 165 |
+
while not isinstance(success, float):
|
| 166 |
+
success = input(
|
| 167 |
+
"Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
|
| 168 |
+
)
|
| 169 |
+
if success == "y":
|
| 170 |
+
success = 1.0
|
| 171 |
+
elif success == "n":
|
| 172 |
+
success = 0.0
|
| 173 |
+
|
| 174 |
+
success = float(success) / 100
|
| 175 |
+
if not (0 <= success <= 1):
|
| 176 |
+
print(f"Success must be a number in [0, 100] but got: {success * 100}")
|
| 177 |
+
|
| 178 |
+
df = df.append(
|
| 179 |
+
{
|
| 180 |
+
"success": success,
|
| 181 |
+
"duration": t_step,
|
| 182 |
+
"video_filename": save_filename,
|
| 183 |
+
},
|
| 184 |
+
ignore_index=True,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
if input("Do one more eval? (enter y or n) ").lower() != "y":
|
| 188 |
+
break
|
| 189 |
+
env.reset()
|
| 190 |
+
|
| 191 |
+
os.makedirs("results", exist_ok=True)
|
| 192 |
+
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
|
| 193 |
+
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
|
| 194 |
+
df.to_csv(csv_filename)
|
| 195 |
+
print(f"Results saved to {csv_filename}")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
|
| 199 |
+
image_observations = obs_dict["image"]
|
| 200 |
+
left_image, right_image, wrist_image = None, None, None
|
| 201 |
+
for key in image_observations:
|
| 202 |
+
# Note the "left" below refers to the left camera in the stereo pair.
|
| 203 |
+
# The model is only trained on left stereo cams, so we only feed those.
|
| 204 |
+
if args.left_camera_id in key and "left" in key:
|
| 205 |
+
left_image = image_observations[key]
|
| 206 |
+
elif args.right_camera_id in key and "left" in key:
|
| 207 |
+
right_image = image_observations[key]
|
| 208 |
+
elif args.wrist_camera_id in key and "left" in key:
|
| 209 |
+
wrist_image = image_observations[key]
|
| 210 |
+
|
| 211 |
+
# Drop the alpha dimension
|
| 212 |
+
left_image = left_image[..., :3]
|
| 213 |
+
right_image = right_image[..., :3]
|
| 214 |
+
wrist_image = wrist_image[..., :3]
|
| 215 |
+
|
| 216 |
+
# Convert to RGB
|
| 217 |
+
left_image = left_image[..., ::-1]
|
| 218 |
+
right_image = right_image[..., ::-1]
|
| 219 |
+
wrist_image = wrist_image[..., ::-1]
|
| 220 |
+
|
| 221 |
+
# In addition to image observations, also capture the proprioceptive state
|
| 222 |
+
robot_state = obs_dict["robot_state"]
|
| 223 |
+
cartesian_position = np.array(robot_state["cartesian_position"])
|
| 224 |
+
joint_position = np.array(robot_state["joint_positions"])
|
| 225 |
+
gripper_position = np.array([robot_state["gripper_position"]])
|
| 226 |
+
|
| 227 |
+
# Save the images to disk so that they can be viewed live while the robot is running
|
| 228 |
+
# Create one combined image to make live viewing easy
|
| 229 |
+
if save_to_disk:
|
| 230 |
+
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
|
| 231 |
+
combined_image = Image.fromarray(combined_image)
|
| 232 |
+
combined_image.save("robot_camera_views.png")
|
| 233 |
+
|
| 234 |
+
return {
|
| 235 |
+
"left_image": left_image,
|
| 236 |
+
"right_image": right_image,
|
| 237 |
+
"wrist_image": wrist_image,
|
| 238 |
+
"cartesian_position": cartesian_position,
|
| 239 |
+
"joint_position": joint_position,
|
| 240 |
+
"gripper_position": gripper_position,
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
|
| 245 |
+
args: Args = tyro.cli(Args)
|
| 246 |
+
main(args)
|