lsnu committed on
Commit
a431815
·
verified ·
1 Parent(s): 724cd54

Add custom TWIN conversion + configs; document experimental status

Browse files
openpi/README.md CHANGED
@@ -321,3 +321,15 @@ We will collect common issues and their solutions here. If you encounter an issu
321
  | Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
322
  | Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
323
  | Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  | Import errors when running examples | Make sure you've installed all dependencies with `uv sync`. Some examples may have additional requirements listed in their READMEs. |
322
  | Action dimensions mismatch | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes. |
323
  | Diverging training loss | Check the `q01`, `q99`, and `std` values in `norm_stats.json` for your dataset. Certain dimensions that are rarely used can end up with very small `q01`, `q99`, or `std` values, leading to huge states and actions after normalization. You can manually adjust the norm stats as a workaround. |
324
+
325
+ ## Multiarm/TWIN Custom Additions (Experimental)
326
+
327
+ These multiarm + TWIN additions were custom-integrated for research experiments and are **not** an official upstream openpi release.
328
+
329
+ - Includes a custom TWIN->LeRobot conversion script at `openpi/scripts/convert_twin_squashfs_to_lerobot.py`.
330
+ - Includes custom training config entries in `openpi/src/openpi/training/config.py` (e.g. `pi05_twin_bimanual_parallel_finetune`).
331
+ - Intended for rapid experimentation; edge cases may still exist.
332
+ - Behavior is not guaranteed to be flawless across all datasets/environments without further validation.
333
+
334
+ Updated: 2026-03-05
335
+
openpi/scripts/convert_twin_squashfs_to_lerobot.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert TWIN bimanual RLBench squashfs files to a LeRobot dataset.
4
+
5
+ This script intentionally supports bounded conversion via --max-episodes/--max-frames
6
+ for local stress testing before running full conversion on larger machines.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import dataclasses
13
+ import logging
14
+ import os
15
+ import pickle
16
+ from pathlib import Path
17
+ import re
18
+ import shutil
19
+ import subprocess
20
+ import tempfile
21
+ from typing import Any
22
+
23
+ from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
24
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
25
+ import numpy as np
26
+ from PIL import Image
27
+
28
+
29
@dataclasses.dataclass(frozen=True)
class ConverterStats:
    """Immutable running counters for a single conversion run.

    Instances are frozen; callers advance the counters with
    ``dataclasses.replace`` so each snapshot stays immutable.
    """

    # Episode directories visited (after any --max-episodes truncation).
    episodes_seen: int = 0
    # Episodes that produced at least one frame and were saved to the dataset.
    episodes_written: int = 0
    # Total frames added across all episodes.
    frames_written: int = 0
    # Frames dropped because at least one camera image was missing on disk.
    frames_skipped: int = 0
36
+
37
+ class _FallbackUnpickler(pickle.Unpickler):
38
+ """Loads RLBench pickles even when rlbench is not installed."""
39
+
40
+ _cache: dict[tuple[str, str], type] = {}
41
+
42
+ def find_class(self, module: str, name: str) -> Any:
43
+ try:
44
+ return super().find_class(module, name)
45
+ except Exception:
46
+ key = (module, name)
47
+ if key not in self._cache:
48
+ cls = type(name, (), {})
49
+ cls.__module__ = module
50
+ self._cache[key] = cls
51
+ return self._cache[key]
52
+
53
+
54
def _load_pickle(path: Path) -> Any:
    """Deserialize a pickle file, tolerating classes that cannot be imported."""
    with open(path, "rb") as handle:
        return _FallbackUnpickler(handle).load()
57
+
58
+
59
def _run_unsquashfs(squashfs_path: Path, dest_dir: Path, patterns: list[str]) -> None:
    """Extract the paths matching *patterns* from a squashfs image into *dest_dir*.

    Raises subprocess.CalledProcessError if the unsquashfs binary fails.
    """
    # -f: write into an existing destination directory without complaint.
    cmd = ["unsquashfs", "-f", "-d", str(dest_dir), str(squashfs_path), *patterns]
    logging.info("Running: %s", " ".join(cmd))
    subprocess.run(cmd, check=True)
70
+
71
+
72
+ def _episode_sort_key(ep_name: str) -> int:
73
+ m = re.match(r"episode(\d+)$", ep_name)
74
+ if not m:
75
+ return 10**9
76
+ return int(m.group(1))
77
+
78
+
79
def _collect_episode_dirs(extract_root: Path) -> list[Path]:
    """Return ``episode*`` directories under all_variations/episodes, numerically sorted."""
    episodes_root = extract_root / "all_variations" / "episodes"
    if not episodes_root.exists():
        return []
    found = [
        entry
        for entry in episodes_root.iterdir()
        if entry.is_dir() and entry.name.startswith("episode")
    ]
    found.sort(key=lambda entry: _episode_sort_key(entry.name))
    return found
85
+
86
+
87
def _image_array(path: Path) -> np.ndarray:
    """Load an image file and return it as an (H, W, 3) RGB numpy array.

    The context manager closes the underlying file handle promptly; PIL
    otherwise keeps it open until garbage collection, which can exhaust
    file descriptors when converting thousands of frames.
    """
    with Image.open(path) as img:
        # convert("RGB") forces a full decode before the file is closed.
        return np.asarray(img.convert("RGB"))
89
+
90
+
91
+ def _to_state(obs: Any) -> np.ndarray:
92
+ left = obs.left
93
+ right = obs.right
94
+ return np.concatenate(
95
+ [
96
+ np.asarray(left.joint_positions, dtype=np.float32),
97
+ np.asarray([left.gripper_open], dtype=np.float32),
98
+ np.asarray(right.joint_positions, dtype=np.float32),
99
+ np.asarray([right.gripper_open], dtype=np.float32),
100
+ ],
101
+ dtype=np.float32,
102
+ )
103
+
104
+
105
+ def _to_action(obs: Any) -> np.ndarray:
106
+ left = obs.left
107
+ right = obs.right
108
+ return np.concatenate(
109
+ [
110
+ np.asarray(left.joint_velocities, dtype=np.float32),
111
+ np.asarray([left.gripper_open], dtype=np.float32),
112
+ np.asarray(right.joint_velocities, dtype=np.float32),
113
+ np.asarray([right.gripper_open], dtype=np.float32),
114
+ ],
115
+ dtype=np.float32,
116
+ )
117
+
118
+
119
+ def _select_prompt(variation_descriptions: Any) -> str:
120
+ if isinstance(variation_descriptions, (list, tuple)) and variation_descriptions:
121
+ for item in variation_descriptions:
122
+ if isinstance(item, str) and item.strip():
123
+ return item.strip()
124
+ return "perform the task"
125
+
126
+
127
def convert(
    squashfs_path: Path,
    repo_id: str,
    *,
    cameras: list[str],
    max_episodes: int | None,
    max_frames: int | None,
    fps: int,
    push_to_hub: bool,
    private: bool,
    cleanup_output: bool,
) -> ConverterStats:
    """Convert one TWIN squashfs archive into a LeRobot dataset.

    Args:
        squashfs_path: Path to the squashfs file to read.
        repo_id: Dataset repo id; output is created at HF_LEROBOT_HOME / repo_id.
        cameras: Camera names (without the "_rgb" suffix) to include.
        max_episodes: If set, convert only the first N episodes.
        max_frames: If set, cap the number of frames taken per episode.
        fps: Frame rate recorded in the dataset metadata.
        push_to_hub: If True, push the finished dataset to the Hugging Face Hub.
        private: Whether a pushed hub repo is private.
        cleanup_output: If True, delete any pre-existing dataset at the output path.

    Returns:
        ConverterStats with episode/frame counters for the run.

    Raises:
        RuntimeError: If no episodes or no RGB frames are found, or the first
            decoded image does not have 3 channels.
        subprocess.CalledProcessError: If unsquashfs fails (via _run_unsquashfs).
    """
    output_path = HF_LEROBOT_HOME / repo_id
    if cleanup_output and output_path.exists():
        shutil.rmtree(output_path)

    with tempfile.TemporaryDirectory(prefix="twin_unsquash_") as tmp:
        extract_root = Path(tmp)

        # 1) Extract episode metadata only.
        # (variation_number.pkl is extracted for completeness; it is not read below.)
        _run_unsquashfs(
            squashfs_path,
            extract_root,
            [
                "all_variations/episodes/episode*/low_dim_obs.pkl",
                "all_variations/episodes/episode*/variation_descriptions.pkl",
                "all_variations/episodes/episode*/variation_number.pkl",
            ],
        )

        episode_dirs = _collect_episode_dirs(extract_root)
        if max_episodes is not None:
            episode_dirs = episode_dirs[:max_episodes]
        if not episode_dirs:
            raise RuntimeError("No episodes found after metadata extraction.")

        # 2) Extract RGB images only for selected episodes/cameras.
        image_patterns: list[str] = []
        for ep in episode_dirs:
            for camera in cameras:
                image_patterns.append(f"all_variations/episodes/{ep.name}/{camera}_rgb/*.png")
        _run_unsquashfs(squashfs_path, extract_root, image_patterns)

        # Determine image shapes from first valid frame.
        first_frame = None
        for ep in episode_dirs:
            for camera in cameras:
                candidate = ep / f"{camera}_rgb" / "rgb_0000.png"
                if candidate.exists():
                    first_frame = candidate
                    break
            if first_frame is not None:
                break
        if first_frame is None:
            raise RuntimeError("No RGB frames found in extracted episodes.")
        h, w, c = _image_array(first_frame).shape
        if c != 3:
            raise RuntimeError(f"Expected RGB images with 3 channels, got shape {(h, w, c)}")

        features: dict[str, dict[str, Any]] = {
            f"{camera}_image": {
                "dtype": "image",
                "shape": (h, w, 3),
                "names": ["height", "width", "channel"],
            }
            for camera in cameras
        }
        # NOTE(review): shape (16,) assumes 7 joints + 1 gripper per arm,
        # matching _to_state/_to_action output — confirm for the actual robot.
        features["state"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [f"state_{i}" for i in range(16)],
        }
        features["action"] = {
            "dtype": "float32",
            "shape": (16,),
            "names": [f"action_{i}" for i in range(16)],
        }

        dataset = LeRobotDataset.create(
            repo_id=repo_id,
            robot_type="rlbench_bimanual",
            fps=fps,
            features=features,
            image_writer_threads=min(8, os.cpu_count() or 4),
            image_writer_processes=1,
        )

        stats = ConverterStats()
        for ep_dir in episode_dirs:
            stats = dataclasses.replace(stats, episodes_seen=stats.episodes_seen + 1)
            low_dim_obs = _load_pickle(ep_dir / "low_dim_obs.pkl")
            variation_descriptions = _load_pickle(ep_dir / "variation_descriptions.pkl")
            prompt = _select_prompt(variation_descriptions)

            # RLBench demos store the per-step observations on a private
            # attribute; episodes without it cannot be converted.
            observations = getattr(low_dim_obs, "_observations", None)
            if observations is None:
                logging.warning("Skipping %s: missing _observations.", ep_dir.name)
                continue

            frame_limit = len(observations)
            if max_frames is not None:
                frame_limit = min(frame_limit, max_frames)

            written_in_episode = 0
            for i in range(frame_limit):
                frame_data: dict[str, Any] = {}
                missing = False
                # A frame is kept only if every requested camera has an image.
                for camera in cameras:
                    frame_path = ep_dir / f"{camera}_rgb" / f"rgb_{i:04d}.png"
                    if not frame_path.exists():
                        missing = True
                        break
                    frame_data[f"{camera}_image"] = _image_array(frame_path)
                if missing:
                    stats = dataclasses.replace(stats, frames_skipped=stats.frames_skipped + 1)
                    continue

                obs = observations[i]
                frame_data["state"] = _to_state(obs)
                frame_data["action"] = _to_action(obs)
                frame_data["task"] = prompt

                dataset.add_frame(frame_data)
                stats = dataclasses.replace(stats, frames_written=stats.frames_written + 1)
                written_in_episode += 1

            # Only finalize episodes that contributed at least one frame.
            if written_in_episode > 0:
                dataset.save_episode()
                stats = dataclasses.replace(stats, episodes_written=stats.episodes_written + 1)

        if push_to_hub:
            dataset.push_to_hub(
                private=private,
                push_videos=True,
                tags=["twin", "bimanual", "rlbench", "lerobot", "openpi"],
                license="apache-2.0",
            )

        return stats
266
+
267
+
268
+ def _parse_args() -> argparse.Namespace:
269
+ parser = argparse.ArgumentParser()
270
+ parser.add_argument("--squashfs-path", type=Path, required=True, help="Path to one TWIN squashfs file.")
271
+ parser.add_argument(
272
+ "--repo-id",
273
+ required=True,
274
+ help="LeRobot repo id (e.g. your_hf_username/twin_bimanual_dual_push_train).",
275
+ )
276
+ parser.add_argument(
277
+ "--cameras",
278
+ default="front,wrist_left,wrist_right",
279
+ help="Comma-separated camera names, without '_rgb' suffix.",
280
+ )
281
+ parser.add_argument("--max-episodes", type=int, default=None, help="Limit number of episodes to convert.")
282
+ parser.add_argument("--max-frames", type=int, default=None, help="Limit number of frames per episode.")
283
+ parser.add_argument("--fps", type=int, default=10)
284
+ parser.add_argument("--push-to-hub", action="store_true")
285
+ parser.add_argument("--private", action="store_true")
286
+ parser.add_argument("--no-cleanup-output", action="store_true")
287
+ parser.add_argument("--verbose", action="store_true")
288
+ return parser.parse_args()
289
+
290
+
291
def main() -> int:
    """CLI entry point: parse arguments, run the conversion, print a summary."""
    args = _parse_args()
    log_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    camera_names = [name.strip() for name in args.cameras.split(",") if name.strip()]
    if not camera_names:
        raise ValueError("At least one camera must be specified in --cameras.")

    result = convert(
        squashfs_path=args.squashfs_path,
        repo_id=args.repo_id,
        cameras=camera_names,
        max_episodes=args.max_episodes,
        max_frames=args.max_frames,
        fps=args.fps,
        push_to_hub=args.push_to_hub,
        private=args.private,
        cleanup_output=not args.no_cleanup_output,
    )
    summary = (
        "Conversion complete:",
        f"episodes_seen={result.episodes_seen}",
        f"episodes_written={result.episodes_written}",
        f"frames_written={result.frames_written}",
        f"frames_skipped={result.frames_skipped}",
    )
    print(*summary)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
openpi/src/openpi/training/config.py CHANGED
@@ -462,6 +462,47 @@ class LeRobotDROIDDataConfig(DataConfigFactory):
462
  )
463
 
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  @dataclasses.dataclass(frozen=True)
466
  class TrainConfig:
467
  # Name of the config. Must be unique. Will be used to reference this config.
@@ -938,6 +979,39 @@ _CONFIGS = [
938
  num_train_steps=20_000,
939
  batch_size=16,
940
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
  #
942
  # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
943
  #
@@ -1010,6 +1084,33 @@ _CONFIGS = [
1010
  wandb_enabled=False,
1011
  pytorch_training_precision="float32",
1012
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1013
  # RoboArena & PolaRiS configs.
1014
  *roboarena_config.get_roboarena_configs(),
1015
  *polaris_config.get_polaris_configs(),
 
462
  )
463
 
464
 
465
@dataclasses.dataclass(frozen=True)
class LeRobotTWINBimanualDataConfig(DataConfigFactory):
    """
    Data config for TWIN bimanual datasets converted to LeRobot format via
    scripts/convert_twin_squashfs_to_lerobot.py.
    """

    # Optional fixed prompt passed to the model transforms; None keeps
    # whatever prompt behavior ModelTransformFactory defaults to.
    default_prompt: str | None = None

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Map the converted LeRobot feature names onto the openpi/ALOHA keys
        # expected by the downstream transforms.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "images": {
                            "cam_high": "front_image",
                            "cam_left_wrist": "wrist_left_image",
                            "cam_right_wrist": "wrist_right_image",
                        },
                        "state": "state",
                        "actions": "action",
                        "prompt": "task",
                    }
                )
            ]
        )
        # NOTE(review): reuses AlohaInputs for the TWIN 16-dim state/action
        # layout — confirm AlohaInputs accepts this dimensionality (stock
        # ALOHA is 14-dim per the upstream configs).
        data_transforms = _transforms.Group(
            inputs=[aloha_policy.AlohaInputs(adapt_to_pi=False)],
            outputs=[],
        )
        model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)
        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            action_sequence_keys=("action",),
        )
+
505
+
506
  @dataclasses.dataclass(frozen=True)
507
  class TrainConfig:
508
  # Name of the config. Must be unique. Will be used to reference this config.
 
979
  num_train_steps=20_000,
980
  batch_size=16,
981
  ),
982
    TrainConfig(
        # Baseline pi05 fine-tuning on TWIN bimanual LeRobot data (single action head).
        name="pi05_twin_bimanual_finetune",
        model=pi0_config.Pi0Config(
            pi05=True,
            action_dim=32,  # Keep pi05 pretraining action dimensionality.
            action_horizon=16,
        ),
        data=LeRobotTWINBimanualDataConfig(
            # NOTE(review): placeholder repo id — point this at the dataset
            # produced by scripts/convert_twin_squashfs_to_lerobot.py.
            repo_id="your_hf_username/twin_bimanual_lerobot_train",
            base_config=DataConfig(prompt_from_task=False),
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
        num_train_steps=20_000,
        batch_size=16,
    ),
    TrainConfig(
        # Parallel per-arm action-head pi05 fine-tuning on TWIN bimanual LeRobot data.
        name="pi05_twin_bimanual_parallel_finetune",
        model=pi0_config.Pi0Config(
            pi05=True,
            action_dim=32,
            action_horizon=16,
            # Two 16-dim per-arm action heads; assumes Pi0Config supports the
            # custom arm_action_dims extension — TODO confirm it is defined.
            arm_action_dims=(16, 16),
        ),
        data=LeRobotTWINBimanualDataConfig(
            repo_id="your_hf_username/twin_bimanual_lerobot_train",
            base_config=DataConfig(prompt_from_task=False),
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"),
        num_train_steps=20_000,
        batch_size=16,
    ),
1015
  #
1016
  # ALOHA Sim configs. This config is used to demonstrate how to train on a simple simulated environment.
1017
  #
 
1084
  wandb_enabled=False,
1085
  pytorch_training_precision="float32",
1086
  ),
1087
    TrainConfig(
        # Local smoke-test for converted TWIN LeRobot data: dummy model
        # variants, batch size 1, two steps — verifies the data pipeline
        # end-to-end without real compute.
        name="debug_pi05_twin_bimanual_parallel_local_smoke",
        model=pi0_config.Pi0Config(
            pi05=True,
            paligemma_variant="dummy",
            action_expert_variant="dummy",
            action_dim=32,
            action_horizon=8,
            max_token_len=64,
            # Per-arm action heads; assumes the custom arm_action_dims
            # extension exists on Pi0Config — TODO confirm.
            arm_action_dims=(16, 16),
        ),
        data=LeRobotTWINBimanualDataConfig(
            # This repo id is produced by scripts/convert_twin_squashfs_to_lerobot.py in local smoke mode.
            repo_id="local/twin_bimanual_dual_push_smoke",
            base_config=DataConfig(prompt_from_task=False),
        ),
        batch_size=1,
        num_workers=0,
        num_train_steps=2,
        log_interval=1,
        save_interval=1,
        overwrite=True,
        exp_name="debug_pi05_twin_bimanual_parallel_local_smoke",
        wandb_enabled=False,
        pytorch_training_precision="float32",
    ),
  # RoboArena & PolaRiS configs.
1115
  *roboarena_config.get_roboarena_configs(),
1116
  *polaris_config.get_polaris_configs(),