Add custom TWIN conversion + configs; document experimental status
Browse files
openpi/README.md
CHANGED
|
@@ -321,3 +321,15 @@ We will collect common issues and their solutions here. If you encounter an issu
|
|
| 321 |
| Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
|
| 322 |
| Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
|
| 323 |
| Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
| Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
|
| 322 |
| Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
|
| 323 |
| Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
|
| 324 |
+
|
| 325 |
+
## Multiarm/TWIN Custom Additions (Experimental)
|
| 326 |
+
|
| 327 |
+
These multiarm + TWIN additions were custom-integrated for research experiments and are **not** an official upstream openpi release.
|
| 328 |
+
|
| 329 |
+
- Includes a custom TWIN->LeRobot conversion script at `openpi/scripts/convert_twin_squashfs_to_lerobot.py`.
|
| 330 |
+
- Includes custom training config entries in `openpi/src/openpi/training/config.py` (e.g. `pi05_twin_bimanual_parallel_finetune`).
|
| 331 |
+
- Intended for rapid experimentation; edge cases may still exist.
|
| 332 |
+
- Behavior is not guaranteed to be flawless across all datasets/environments without further validation.
|
| 333 |
+
|
| 334 |
+
Updated: 2026-03-05
|
| 335 |
+
|
openpi/scripts/convert_twin_squashfs_to_lerobot.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Convert TWIN bimanual RLBench squashfs files to a LeRobot dataset.
|
| 4 |
+
|
| 5 |
+
This script intentionally supports bounded conversion via --max-episodes/--max-frames
|
| 6 |
+
for local stress testing before running full conversion on larger machines.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import dataclasses
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
import pickle
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
import re
|
| 18 |
+
import shutil
|
| 19 |
+
import subprocess
|
| 20 |
+
import tempfile
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
|
| 24 |
+
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
| 25 |
+
import numpy as np
|
| 26 |
+
from PIL import Image
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclasses.dataclass(frozen=True)
class ConverterStats:
    """Immutable conversion counters; advanced via ``dataclasses.replace``."""

    # Episode directories processed, including ones skipped for bad metadata.
    episodes_seen: int = 0
    # Episodes that contributed at least one frame to the output dataset.
    episodes_written: int = 0
    # Frames successfully added to the LeRobot dataset.
    frames_written: int = 0
    # Frames dropped because an expected camera PNG was missing on disk.
    frames_skipped: int = 0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class _FallbackUnpickler(pickle.Unpickler):
|
| 38 |
+
"""Loads RLBench pickles even when rlbench is not installed."""
|
| 39 |
+
|
| 40 |
+
_cache: dict[tuple[str, str], type] = {}
|
| 41 |
+
|
| 42 |
+
def find_class(self, module: str, name: str) -> Any:
|
| 43 |
+
try:
|
| 44 |
+
return super().find_class(module, name)
|
| 45 |
+
except Exception:
|
| 46 |
+
key = (module, name)
|
| 47 |
+
if key not in self._cache:
|
| 48 |
+
cls = type(name, (), {})
|
| 49 |
+
cls.__module__ = module
|
| 50 |
+
self._cache[key] = cls
|
| 51 |
+
return self._cache[key]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _load_pickle(path: Path) -> Any:
    """Deserialize *path* using the fallback unpickler (tolerates missing rlbench)."""
    with path.open("rb") as stream:
        unpickler = _FallbackUnpickler(stream)
        return unpickler.load()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _run_unsquashfs(squashfs_path: Path, dest_dir: Path, patterns: list[str]) -> None:
    """Extract the given glob patterns from a squashfs image into *dest_dir*.

    Invokes the system ``unsquashfs`` binary (``-f`` to overwrite existing
    files); raises ``subprocess.CalledProcessError`` on a non-zero exit.
    """
    base = ["unsquashfs", "-f", "-d", str(dest_dir), str(squashfs_path)]
    cmd = base + list(patterns)
    logging.info("Running: %s", " ".join(cmd))
    subprocess.run(cmd, check=True)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _episode_sort_key(ep_name: str) -> int:
|
| 73 |
+
m = re.match(r"episode(\d+)$", ep_name)
|
| 74 |
+
if not m:
|
| 75 |
+
return 10**9
|
| 76 |
+
return int(m.group(1))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _collect_episode_dirs(extract_root: Path) -> list[Path]:
|
| 80 |
+
base = extract_root / "all_variations" / "episodes"
|
| 81 |
+
if not base.exists():
|
| 82 |
+
return []
|
| 83 |
+
episodes = [p for p in base.iterdir() if p.is_dir() and p.name.startswith("episode")]
|
| 84 |
+
return sorted(episodes, key=lambda p: _episode_sort_key(p.name))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _image_array(path: Path) -> np.ndarray:
    """Decode *path* as an RGB image and return it as an (H, W, 3) array.

    Fix: the original left the PIL file handle to be released by GC; the
    context manager now closes it deterministically. ``convert("RGB")``
    forces PIL's lazy load before the file is closed, so the returned
    array is unaffected.
    """
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB"))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _to_state(obs: Any) -> np.ndarray:
|
| 92 |
+
left = obs.left
|
| 93 |
+
right = obs.right
|
| 94 |
+
return np.concatenate(
|
| 95 |
+
[
|
| 96 |
+
np.asarray(left.joint_positions, dtype=np.float32),
|
| 97 |
+
np.asarray([left.gripper_open], dtype=np.float32),
|
| 98 |
+
np.asarray(right.joint_positions, dtype=np.float32),
|
| 99 |
+
np.asarray([right.gripper_open], dtype=np.float32),
|
| 100 |
+
],
|
| 101 |
+
dtype=np.float32,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _to_action(obs: Any) -> np.ndarray:
|
| 106 |
+
left = obs.left
|
| 107 |
+
right = obs.right
|
| 108 |
+
return np.concatenate(
|
| 109 |
+
[
|
| 110 |
+
np.asarray(left.joint_velocities, dtype=np.float32),
|
| 111 |
+
np.asarray([left.gripper_open], dtype=np.float32),
|
| 112 |
+
np.asarray(right.joint_velocities, dtype=np.float32),
|
| 113 |
+
np.asarray([right.gripper_open], dtype=np.float32),
|
| 114 |
+
],
|
| 115 |
+
dtype=np.float32,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _select_prompt(variation_descriptions: Any) -> str:
|
| 120 |
+
if isinstance(variation_descriptions, (list, tuple)) and variation_descriptions:
|
| 121 |
+
for item in variation_descriptions:
|
| 122 |
+
if isinstance(item, str) and item.strip():
|
| 123 |
+
return item.strip()
|
| 124 |
+
return "perform the task"
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def convert(
    squashfs_path: Path,
    repo_id: str,
    *,
    cameras: list[str],
    max_episodes: int | None,
    max_frames: int | None,
    fps: int,
    push_to_hub: bool,
    private: bool,
    cleanup_output: bool,
) -> ConverterStats:
    """Convert one TWIN squashfs archive into a LeRobot dataset.

    Args:
        squashfs_path: Path to the TWIN squashfs file to extract.
        repo_id: LeRobot repo id; output lands under ``HF_LEROBOT_HOME / repo_id``.
        cameras: Camera names (without the ``_rgb`` suffix) to include.
        max_episodes: If set, convert only the first N episodes (numeric order).
        max_frames: If set, cap the number of frames taken per episode.
        fps: Frame rate recorded in the LeRobot dataset metadata.
        push_to_hub: If True, upload the finished dataset to the Hugging Face hub.
        private: Hub visibility flag (only used when push_to_hub is True).
        cleanup_output: If True, delete any pre-existing output directory first.

    Returns:
        A ``ConverterStats`` with episode/frame counters.

    Raises:
        RuntimeError: If no episodes or no RGB frames are found, or if the
            first image is not 3-channel RGB.
        subprocess.CalledProcessError: If ``unsquashfs`` fails.
    """
    output_path = HF_LEROBOT_HOME / repo_id
    if cleanup_output and output_path.exists():
        shutil.rmtree(output_path)

    # All extraction happens in a throwaway directory that is removed on exit.
    with tempfile.TemporaryDirectory(prefix="twin_unsquash_") as tmp:
        extract_root = Path(tmp)

        # 1) Extract episode metadata only. Pulling pickles first keeps this
        # cheap and lets us decide which episodes/images to extract next.
        _run_unsquashfs(
            squashfs_path,
            extract_root,
            [
                "all_variations/episodes/episode*/low_dim_obs.pkl",
                "all_variations/episodes/episode*/variation_descriptions.pkl",
                "all_variations/episodes/episode*/variation_number.pkl",
            ],
        )

        episode_dirs = _collect_episode_dirs(extract_root)
        if max_episodes is not None:
            episode_dirs = episode_dirs[:max_episodes]
        if not episode_dirs:
            raise RuntimeError("No episodes found after metadata extraction.")

        # 2) Extract RGB images only for selected episodes/cameras.
        image_patterns: list[str] = []
        for ep in episode_dirs:
            for camera in cameras:
                image_patterns.append(f"all_variations/episodes/{ep.name}/{camera}_rgb/*.png")
        _run_unsquashfs(squashfs_path, extract_root, image_patterns)

        # Determine image shapes from first valid frame.
        # Assumes all cameras/episodes share one resolution — TODO confirm.
        first_frame = None
        for ep in episode_dirs:
            for camera in cameras:
                candidate = ep / f"{camera}_rgb" / "rgb_0000.png"
                if candidate.exists():
                    first_frame = candidate
                    break
            if first_frame is not None:
                break
        if first_frame is None:
            raise RuntimeError("No RGB frames found in extracted episodes.")
        h, w, c = _image_array(first_frame).shape
        if c != 3:
            raise RuntimeError(f"Expected RGB images with 3 channels, got shape {(h, w, c)}")

        # Feature schema: one image feature per camera plus 16-dim state/action
        # vectors (layout defined by _to_state/_to_action: 7 joints + gripper
        # per arm, left then right).
        features: dict[str, dict[str, Any]] = {
            f"{camera}_image": {
                "dtype": "image",
                "shape": (h, w, 3),
                "names": ["height", "width", "channel"],
            }
            for camera in cameras
        }
        features["state"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [f"state_{i}" for i in range(16)],
        }
        features["action"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [f"action_{i}" for i in range(16)],
        }

        dataset = LeRobotDataset.create(
            repo_id=repo_id,
            robot_type="rlbench_bimanual",
            fps=fps,
            features=features,
            # Threaded image writing, capped to keep small machines responsive.
            image_writer_threads=min(8, os.cpu_count() or 4),
            image_writer_processes=1,
        )

        # Stats are accumulated immutably via dataclasses.replace.
        stats = ConverterStats()
        for ep_dir in episode_dirs:
            stats = dataclasses.replace(stats, episodes_seen=stats.episodes_seen + 1)
            low_dim_obs = _load_pickle(ep_dir / "low_dim_obs.pkl")
            variation_descriptions = _load_pickle(ep_dir / "variation_descriptions.pkl")
            prompt = _select_prompt(variation_descriptions)

            # RLBench Demo objects store their frames on a private attribute.
            observations = getattr(low_dim_obs, "_observations", None)
            if observations is None:
                logging.warning("Skipping %s: missing _observations.", ep_dir.name)
                continue

            frame_limit = len(observations)
            if max_frames is not None:
                frame_limit = min(frame_limit, max_frames)

            written_in_episode = 0
            for i in range(frame_limit):
                frame_data: dict[str, Any] = {}
                missing = False
                # A frame is only written when every requested camera has a PNG.
                for camera in cameras:
                    frame_path = ep_dir / f"{camera}_rgb" / f"rgb_{i:04d}.png"
                    if not frame_path.exists():
                        missing = True
                        break
                    frame_data[f"{camera}_image"] = _image_array(frame_path)
                if missing:
                    stats = dataclasses.replace(stats, frames_skipped=stats.frames_skipped + 1)
                    continue

                obs = observations[i]
                frame_data["state"] = _to_state(obs)
                frame_data["action"] = _to_action(obs)
                frame_data["task"] = prompt

                dataset.add_frame(frame_data)
                stats = dataclasses.replace(stats, frames_written=stats.frames_written + 1)
                written_in_episode += 1

            # Only commit episodes that contributed at least one frame.
            if written_in_episode > 0:
                dataset.save_episode()
                stats = dataclasses.replace(stats, episodes_written=stats.episodes_written + 1)

        if push_to_hub:
            dataset.push_to_hub(
                private=private,
                push_videos=True,
                tags=["twin", "bimanual", "rlbench", "lerobot", "openpi"],
                license="apache-2.0",
            )

        return stats
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _parse_args() -> argparse.Namespace:
|
| 269 |
+
parser = argparse.ArgumentParser()
|
| 270 |
+
parser.add_argument("--squashfs-path", type=Path, required=True, help="Path to one TWIN squashfs file.")
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--repo-id",
|
| 273 |
+
required=True,
|
| 274 |
+
help="LeRobot repo id (e.g. your_hf_username/twin_bimanual_dual_push_train).",
|
| 275 |
+
)
|
| 276 |
+
parser.add_argument(
|
| 277 |
+
"--cameras",
|
| 278 |
+
default="front,wrist_left,wrist_right",
|
| 279 |
+
help="Comma-separated camera names, without '_rgb' suffix.",
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument("--max-episodes", type=int, default=None, help="Limit number of episodes to convert.")
|
| 282 |
+
parser.add_argument("--max-frames", type=int, default=None, help="Limit number of frames per episode.")
|
| 283 |
+
parser.add_argument("--fps", type=int, default=10)
|
| 284 |
+
parser.add_argument("--push-to-hub", action="store_true")
|
| 285 |
+
parser.add_argument("--private", action="store_true")
|
| 286 |
+
parser.add_argument("--no-cleanup-output", action="store_true")
|
| 287 |
+
parser.add_argument("--verbose", action="store_true")
|
| 288 |
+
return parser.parse_args()
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def main() -> int:
    """CLI entry point: parse args, run the conversion, print a summary.

    Returns 0 on success; raises ValueError when --cameras is empty.
    """
    args = _parse_args()
    level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=level, format="%(levelname)s: %(message)s")

    cameras = [name.strip() for name in args.cameras.split(",") if name.strip()]
    if not cameras:
        raise ValueError("At least one camera must be specified in --cameras.")

    stats = convert(
        squashfs_path=args.squashfs_path,
        repo_id=args.repo_id,
        cameras=cameras,
        max_episodes=args.max_episodes,
        max_frames=args.max_frames,
        fps=args.fps,
        push_to_hub=args.push_to_hub,
        private=args.private,
        cleanup_output=not args.no_cleanup_output,
    )
    summary = (
        "Conversion complete:",
        f"episodes_seen={stats.episodes_seen}",
        f"episodes_written={stats.episodes_written}",
        f"frames_written={stats.frames_written}",
        f"frames_skipped={stats.frames_skipped}",
    )
    print(*summary)
    return 0
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
|
openpi/src/openpi/training/config.py
CHANGED
|
@@ -462,6 +462,47 @@ class LeRobotDROIDDataConfig(DataConfigFactory):
|
|
| 462 |
)
|
| 463 |
|
| 464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
@dataclasses.dataclass(frozen=True)
|
| 466 |
class TrainConfig:
|
| 467 |
# Name of the config. Must be unique. Will be used to reference this config.
|
|
@@ -938,6 +979,39 @@ _CONFIGS = [
|
|
| 938 |
num_train_steps=20_000,
|
| 939 |
batch_size=16,
|
| 940 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
#
|
| 942 |
# ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
|
| 943 |
#
|
|
@@ -1010,6 +1084,33 @@ _CONFIGS = [
|
|
| 1010 |
wandb_enabled=False,
|
| 1011 |
pytorch_training_precision="float32",
|
| 1012 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1013 |
# RoboArena & PolaRiS configs.
|
| 1014 |
*roboarena_config.get_roboarena_configs(),
|
| 1015 |
*polaris_config.get_polaris_configs(),
|
|
|
|
| 462 |
)
|
| 463 |
|
| 464 |
|
| 465 |
+
@dataclasses.dataclass(frozen=True)
class LeRobotTWINBimanualDataConfig(DataConfigFactory):
    """
    Data config for TWIN bimanual datasets converted to LeRobot format via
    scripts/convert_twin_squashfs_to_lerobot.py.
    """

    # Prompt injected when the dataset itself provides no task string.
    default_prompt: str | None = None

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Map the converted LeRobot feature names onto the ALOHA-style input keys.
        key_mapping = {
            "images": {
                "cam_high": "front_image",
                "cam_left_wrist": "wrist_left_image",
                "cam_right_wrist": "wrist_right_image",
            },
            "state": "state",
            "actions": "action",
            "prompt": "task",
        }
        repack_transform = _transforms.Group(inputs=[_transforms.RepackTransform(key_mapping)])
        data_transforms = _transforms.Group(
            inputs=[aloha_policy.AlohaInputs(adapt_to_pi=False)],
            outputs=[],
        )
        model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)
        base = self.create_base_config(assets_dirs, model_config)
        return dataclasses.replace(
            base,
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            action_sequence_keys=("action",),
        )
|
| 504 |
+
|
| 505 |
+
|
| 506 |
@dataclasses.dataclass(frozen=True)
|
| 507 |
class TrainConfig:
|
| 508 |
# Name of the config. Must be unique. Will be used to reference this config.
|
|
|
|
| 979 |
num_train_steps=20_000,
|
| 980 |
batch_size=16,
|
| 981 |
),
|
| 982 |
+
TrainConfig(
|
| 983 |
+
# Baseline pi05 fine-tuning on TWIN bimanual LeRobot data (single action head).
|
| 984 |
+
name="pi05_twin_bimanual_finetune",
|
| 985 |
+
model=pi0_config.Pi0Config(
|
| 986 |
+
pi05=True,
|
| 987 |
+
action_dim=32, # Keep pi05 pretraining action dimensionality.
|
| 988 |
+
action_horizon=16,
|
| 989 |
+
),
|
| 990 |
+
data=LeRobotTWINBimanualDataConfig(
|
| 991 |
+
repo_id="your_hf_username/twin_bimanual_lerobot_train",
|
| 992 |
+
base_config=DataConfig(prompt_from_task=False),
|
| 993 |
+
),
|
| 994 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 995 |
+
num_train_steps=20_000,
|
| 996 |
+
batch_size=16,
|
| 997 |
+
),
|
| 998 |
+
TrainConfig(
|
| 999 |
+
# Parallel per-arm action-head pi05 fine-tuning on TWIN bimanual LeRobot data.
|
| 1000 |
+
name="pi05_twin_bimanual_parallel_finetune",
|
| 1001 |
+
model=pi0_config.Pi0Config(
|
| 1002 |
+
pi05=True,
|
| 1003 |
+
action_dim=32,
|
| 1004 |
+
action_horizon=16,
|
| 1005 |
+
arm_action_dims=(16, 16),
|
| 1006 |
+
),
|
| 1007 |
+
data=LeRobotTWINBimanualDataConfig(
|
| 1008 |
+
repo_id="your_hf_username/twin_bimanual_lerobot_train",
|
| 1009 |
+
base_config=DataConfig(prompt_from_task=False),
|
| 1010 |
+
),
|
| 1011 |
+
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
|
| 1012 |
+
num_train_steps=20_000,
|
| 1013 |
+
batch_size=16,
|
| 1014 |
+
),
|
| 1015 |
#
|
| 1016 |
# ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
|
| 1017 |
#
|
|
|
|
| 1084 |
wandb_enabled=False,
|
| 1085 |
pytorch_training_precision="float32",
|
| 1086 |
),
|
| 1087 |
+
TrainConfig(
|
| 1088 |
+
# Local smoke-test for converted TWIN LeRobot data.
|
| 1089 |
+
name="debug_pi05_twin_bimanual_parallel_local_smoke",
|
| 1090 |
+
model=pi0_config.Pi0Config(
|
| 1091 |
+
pi05=True,
|
| 1092 |
+
paligemma_variant="dummy",
|
| 1093 |
+
action_expert_variant="dummy",
|
| 1094 |
+
action_dim=32,
|
| 1095 |
+
action_horizon=8,
|
| 1096 |
+
max_token_len=64,
|
| 1097 |
+
arm_action_dims=(16, 16),
|
| 1098 |
+
),
|
| 1099 |
+
data=LeRobotTWINBimanualDataConfig(
|
| 1100 |
+
# This repo id is produced by scripts/convert_twin_squashfs_to_lerobot.py in local smoke mode.
|
| 1101 |
+
repo_id="local/twin_bimanual_dual_push_smoke",
|
| 1102 |
+
base_config=DataConfig(prompt_from_task=False),
|
| 1103 |
+
),
|
| 1104 |
+
batch_size=1,
|
| 1105 |
+
num_workers=0,
|
| 1106 |
+
num_train_steps=2,
|
| 1107 |
+
log_interval=1,
|
| 1108 |
+
save_interval=1,
|
| 1109 |
+
overwrite=True,
|
| 1110 |
+
exp_name="debug_pi05_twin_bimanual_parallel_local_smoke",
|
| 1111 |
+
wandb_enabled=False,
|
| 1112 |
+
pytorch_training_precision="float32",
|
| 1113 |
+
),
|
| 1114 |
# RoboArena & PolaRiS configs.
|
| 1115 |
*roboarena_config.get_roboarena_configs(),
|
| 1116 |
*polaris_config.get_polaris_configs(),
|