abdul004 committed
Commit 724ba0f · verified · 1 Parent(s): edf9d93

Upload folder using huggingface_hub

Files changed (5):
  1. README.md +181 -0
  2. evaluate_pi05.py +122 -0
  3. so101_config.py +117 -0
  4. so101_policy.py +109 -0
  5. test_config_local.py +275 -0
README.md ADDED
# Pi0.5 Fine-tuning for SO-101

Fine-tune Physical Intelligence's Pi0.5 on the SO-101 ball-in-cup task.

## Overview

| Item | Value |
|------|-------|
| **Base Model** | Pi0.5 (`gs://openpi-assets/checkpoints/pi05_base`) |
| **Dataset** | `abdul004/so101_ball_in_cup_v5` (72 episodes) |
| **GPU Required** | A100 80GB (~$1.50/hr on Vast.ai) |
| **Training Time** | ~2-3 hours for 5K steps |

## Files in This Directory

```
pi0_so101/
├── README.md            # This file
├── so101_policy.py      # Input/output transforms (copy to openpi/src/openpi/policies/)
├── so101_config.py      # Config template (add to openpi/src/openpi/training/config.py)
├── evaluate_pi05.py     # Inference script for the robot (work in progress)
└── test_config_local.py # Local sanity checks (no GPU required)
```

## Step-by-Step Setup on Vast.ai

### 1. Rent a GPU Instance

On [Vast.ai](https://vast.ai), search for:
- **GPU:** A100 80GB or H100
- **Disk:** 100GB+
- **Image:** Any image with CUDA (a PyTorch image works)

### 2. SSH In and Clone OpenPi

```bash
# Clone with submodules
git clone --recurse-submodules https://github.com/Physical-Intelligence/openpi.git
cd openpi

# Install the uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.local/bin/env

# Install dependencies
GIT_LFS_SKIP_SMUDGE=1 uv sync

# Log in to HuggingFace (for dataset access)
huggingface-cli login
```

### 3. Add the SO-101 Config

```bash
# Copy the policy file
# (upload so101_policy.py from your local machine, or create it)
cp /path/to/so101_policy.py src/openpi/policies/so101_policy.py
```

Then edit `src/openpi/training/config.py`:

**Add this import at the top:**
```python
import openpi.policies.so101_policy as so101_policy
```

**Add a DataConfig class** (after `LeRobotLiberoDataConfig`):
```python
@dataclasses.dataclass(frozen=True)
class LeRobotSO101DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform({
                    "observation/images/overhead": "observation.images.overhead",
                    "observation/images/wrist": "observation.images.wrist",
                    "observation/state": "observation.state",
                    "action": "action",
                    "prompt": "prompt",
                })
            ]
        )

        data_transforms = _transforms.Group(
            inputs=[so101_policy.SO101Inputs(
                action_dim=model_config.action_dim,
                model_type=model_config.model_type,
            )],
            outputs=[so101_policy.SO101Outputs()],
        )

        # Delta mask: the 5 arm joints use delta actions, the gripper stays absolute
        delta_action_mask = _transforms.make_bool_mask(5, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        model_transforms = ModelTransformFactory()(model_config)

        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            action_sequence_keys=("action",),
        )
```

**Add a TrainConfig** to the `_CONFIGS` list:
```python
TrainConfig(
    name="pi05_so101",
    model=pi0_config.Pi0Config(pi05=True, action_horizon=15),
    data=LeRobotSO101DataConfig(
        repo_id="abdul004/so101_ball_in_cup_v5",
        base_config=DataConfig(prompt_from_task=True),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    num_train_steps=5_000,
    batch_size=32,
),
```
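The `RepackTransform` dict above only renames keys: each OpenPi key on the left is filled from the LeRobot key on the right. A tiny pure-Python sketch of that behavior (a stand-in for illustration, not OpenPi's actual `RepackTransform`):

```python
# Stand-in for _transforms.RepackTransform: build a new sample keyed by
# OpenPi names, reading values from the LeRobot dataset names.
REPACK_MAP = {
    "observation/images/overhead": "observation.images.overhead",
    "observation/images/wrist": "observation.images.wrist",
    "observation/state": "observation.state",
    "action": "action",
    "prompt": "prompt",
}

def repack(sample: dict) -> dict:
    """Return a dict keyed by OpenPi names, reading from LeRobot names."""
    return {openpi_key: sample[lerobot_key]
            for openpi_key, lerobot_key in REPACK_MAP.items()}

lerobot_sample = {
    "observation.images.overhead": "img_a",
    "observation.images.wrist": "img_b",
    "observation.state": [0.0] * 6,
    "action": [0.0] * 6,
    "prompt": "pick up the orange ball and put it in the pink cup",
}
repacked = repack(lerobot_sample)
```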

### 4. Compute Normalization Stats

```bash
uv run scripts/compute_norm_stats.py --config-name pi05_so101
```

### 5. Train

```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup
```

Training progress is logged to the console and to Weights & Biases.

### 6. Download the Checkpoint

After training, checkpoints are saved to `checkpoints/pi05_so101/ball_in_cup/`.

Download to your local machine:
```bash
# On your local machine
scp -r vast_instance:openpi/checkpoints/pi05_so101/ball_in_cup/5000 ./pi05_so101_checkpoint
```

## Inference on Robot

(Coming soon: the LeRobot inference script still needs adapting; see `evaluate_pi05.py` for a draft.)

## Key Adaptations from LeKiwi

| Aspect | LeKiwi | SO-101 |
|--------|--------|--------|
| Action dim | 9 | 6 |
| Cameras | 3 (top, wrist, front) | 2 (overhead, wrist) |
| Camera keys | `observation.images.top` | `observation.images.overhead` |
| Delta mask | `make_bool_mask(5, -4)` | `make_bool_mask(5, -1)` |
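
The delta-mask row is the key change: `make_bool_mask(5, -1)` marks the first 5 dimensions (arm joints) as delta and the last dimension (gripper) as absolute. A small numpy sketch of the round trip, under the assumed semantics of `make_bool_mask` (positive count appends that many `True`, negative appends `abs(count)` `False`); this is an illustration, not OpenPi's implementation:

```python
import numpy as np

def make_bool_mask(*counts):
    # Assumed semantics: 5 -> five True entries, -1 -> one False entry.
    mask = []
    for c in counts:
        mask.extend([c > 0] * abs(c))
    return np.array(mask)

mask = make_bool_mask(5, -1)  # [True, True, True, True, True, False]
state  = np.array([0.10, 0.20, 0.30, 0.40, 0.50, 0.90])
action = np.array([0.15, 0.25, 0.35, 0.45, 0.55, 1.00])

# DeltaActions: masked joints are stored relative to the current state;
# the gripper (False) stays absolute.
delta = np.where(mask, action - state, action)

# AbsoluteActions: the inverse, applied to model outputs at inference time.
recovered = np.where(mask, state + delta, delta)
assert np.allclose(recovered, action)
```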

## Troubleshooting

### Out of Memory
Set the XLA memory fraction higher:
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 uv run scripts/train.py ...
```

### Dataset Not Found
Make sure you're logged in to HuggingFace:
```bash
huggingface-cli login
```

### Missing Norm Stats
Run `compute_norm_stats.py` before training:
```bash
uv run scripts/compute_norm_stats.py --config-name pi05_so101
```
evaluate_pi05.py ADDED
#!/usr/bin/env python3
"""
Pi0.5 Inference for SO-101 Robot
Adapted from Ilia Larchenko's LeKiwi evaluation script

Usage:
    python evaluate_pi05.py --checkpoint checkpoints/pi05_so101/params
"""

import argparse
import time


def run_inference(checkpoint_path: str, robot_type: str = "so101",
                  task_prompt: str = "pick up the orange ball and put it in the pink cup"):
    """Run Pi0 inference on the SO-101 robot."""

    # Import OpenPi (only when actually running inference)
    from openpi.policies import policy_config

    # Import LeRobot for robot control
    from lerobot.common.robot_devices.robots.so101 import SO101Robot

    print(f"Loading checkpoint from: {checkpoint_path}")

    # Load the fine-tuned Pi0/Pi0.5 model.
    # This auto-detects whether it is Pi0 or Pi0.5 based on the checkpoint.
    policy = policy_config.create_trained_policy(checkpoint_path)

    # Connect to the robot
    print("Connecting to SO-101 robot...")
    robot = SO101Robot()
    robot.connect()

    # Inference parameters
    FPS = 30  # Match training FPS
    ACTIONS_TO_EXECUTE = 15  # Execute fewer than predicted for better precision

    print(f"Task: {task_prompt}")
    print(f"FPS: {FPS}, Actions per chunk: {ACTIONS_TO_EXECUTE}")
    print("Starting inference loop... Press Ctrl+C to stop")

    try:
        action_queue = []
        step = 0

        while True:
            loop_start = time.perf_counter()

            # Get the current observation from the robot
            observation = robot.get_observation()

            # If the action queue is empty, get new predictions
            if len(action_queue) == 0:
                # Prepare the observation for Pi0
                obs_dict = {
                    "observation/state": observation["state"],
                    "observation/images/overhead": observation["images"]["overhead"],
                    "observation/images/wrist": observation["images"]["wrist"],
                    "prompt": task_prompt,
                }

                # Run inference
                inference_start = time.perf_counter()
                predicted_actions = policy.infer(obs_dict)["actions"]
                inference_time = time.perf_counter() - inference_start

                # Only use the first N actions for better precision
                action_queue = list(predicted_actions[:ACTIONS_TO_EXECUTE])

                print(f"Step {step}: Inference took {inference_time*1000:.0f}ms, "
                      f"queued {len(action_queue)} actions")

            # Execute the next action
            action = action_queue.pop(0)
            robot.send_action(action)

            step += 1

            # Maintain the target FPS
            elapsed = time.perf_counter() - loop_start
            sleep_time = max(0, (1.0 / FPS) - elapsed)
            time.sleep(sleep_time)

    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        robot.disconnect()
        print("Robot disconnected")


def main():
    parser = argparse.ArgumentParser(description="Run Pi0/Pi0.5 on the SO-101 robot")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to fine-tuned checkpoint (e.g., checkpoints/pi05_so101/params)",
    )
    parser.add_argument(
        "--robot",
        type=str,
        default="so101",
        choices=["so101"],
        help="Robot type",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="pick up the orange ball and put it in the pink cup",
        help="Task prompt for the model",
    )

    args = parser.parse_args()
    run_inference(args.checkpoint, args.robot, args.prompt)


if __name__ == "__main__":
    main()
so101_config.py ADDED
# SO-101 Training Config for OpenPi Pi0.5
# Adapted from Ilia Larchenko's LeKiwi config
#
# HOW TO USE:
# 1. Copy so101_policy.py to openpi/src/openpi/policies/
# 2. Add the imports and config class below to openpi/src/openpi/training/config.py
# 3. Add the TrainConfig to the _CONFIGS list in config.py
# 4. Run: uv run scripts/compute_norm_stats.py --config-name pi05_so101
# 5. Run: XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=my_experiment

# =============================================================================
# ADD THESE IMPORTS to the top of config.py:
# =============================================================================
# import openpi.policies.so101_policy as so101_policy

# =============================================================================
# ADD THIS CLASS to config.py (after the other DataConfig classes):
# =============================================================================

"""
@dataclasses.dataclass(frozen=True)
class LeRobotSO101DataConfig(DataConfigFactory):
    '''
    Data config for the SO-101 ball-in-cup task.

    Dataset: abdul004/so101_ball_in_cup_v5
    - 72 episodes of teleoperated demonstrations
    - 6 DOF actions (5 arm joints + 1 gripper)
    - 2 cameras (overhead + wrist)
    '''

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Remap LeRobot dataset keys to the OpenPi format.
        # Left side = OpenPi expected keys, right side = LeRobot dataset keys.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "observation/images/overhead": "observation.images.overhead",
                        "observation/images/wrist": "observation.images.wrist",
                        "observation/state": "observation.state",
                        "action": "action",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # Data transforms using the SO-101 policy
        data_transforms = _transforms.Group(
            inputs=[so101_policy.SO101Inputs(
                action_dim=model_config.action_dim,
                model_type=model_config.model_type,
            )],
            outputs=[so101_policy.SO101Outputs()],
        )

        # Delta action mask:
        # - First 5 dimensions (arm joints): convert to delta actions
        # - Last dimension (gripper): keep absolute
        # make_bool_mask(5, -1) = [True, True, True, True, True, False]
        delta_action_mask = _transforms.make_bool_mask(5, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms (tokenization, etc.) - standard, no changes needed
        model_transforms = ModelTransformFactory()(model_config)

        return dataclasses.replace(
            self.create_base_config(assets_dirs, model_config),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            action_sequence_keys=("action",),  # LeRobot uses "action", not "actions"
        )
"""

# =============================================================================
# ADD THIS TrainConfig to the _CONFIGS list in config.py:
# =============================================================================

"""
TrainConfig(
    name="pi05_so101",
    model=pi0_config.Pi0Config(
        pi05=True,
        action_horizon=15,  # Shorter horizon for Pi0.5
    ),
    data=LeRobotSO101DataConfig(
        repo_id="abdul004/so101_ball_in_cup_v5",
        base_config=DataConfig(prompt_from_task=True),
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "gs://openpi-assets/checkpoints/pi05_base/params"
    ),
    num_train_steps=5_000,  # Ilia found 5K steps sufficient for simple tasks
    batch_size=32,
),
"""

# =============================================================================
# FULL EXAMPLE: what the config.py changes look like
# =============================================================================

# Near the top of config.py, add:
#     import openpi.policies.so101_policy as so101_policy

# After the LeRobotLiberoDataConfig class, add the LeRobotSO101DataConfig class above.

# In the _CONFIGS list, add the TrainConfig above.

# Then run:
#     uv run scripts/compute_norm_stats.py --config-name pi05_so101
#     XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi05_so101 --exp-name=ball_in_cup
so101_policy.py ADDED
# SO-101 Policy transforms for OpenPi Pi0.5
# Adapted from Ilia Larchenko's LeKiwi implementation:
# https://github.com/IliaLarchenko/lerobot_random/blob/main/vla/pi/lekiwi_policy.py
#
# Copy this file to: openpi/src/openpi/policies/so101_policy.py

import dataclasses

import einops
import numpy as np

from openpi import transforms
from openpi.models import model as _model


# SO-101 has 6 DOF: 5 arm joints + 1 gripper
SO101_ACTION_DIM = 6


def make_so101_example() -> dict:
    """Create a random input example for testing the SO-101 policy."""
    return {
        "observation/state": np.random.rand(SO101_ACTION_DIM).astype(np.float32),
        "observation/images/overhead": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "observation/images/wrist": np.random.randint(256, size=(480, 640, 3), dtype=np.uint8),
        "prompt": "pick up the orange ball and put it in the pink cup",
    }


def _parse_image(image) -> np.ndarray:
    """Convert an image to the HWC uint8 format expected by Pi0."""
    image = np.asarray(image)
    # LeRobot stores images as float32 CHW; convert to uint8 HWC.
    if np.issubdtype(image.dtype, np.floating):
        image = (255 * image).astype(np.uint8)
    if image.shape[0] == 3:
        image = einops.rearrange(image, "c h w -> h w c")
    return image


@dataclasses.dataclass(frozen=True)
class SO101Inputs(transforms.DataTransformFn):
    """
    Convert SO-101 observations to the Pi0 model input format.

    SO-101 has:
    - a 6 DOF state (5 arm joints + 1 gripper)
    - 2 cameras (overhead + wrist)

    Pi0 expects 3 camera slots, so we duplicate the overhead image for the third slot.
    """

    # The model's action dimension (SO-101 actions are padded to this)
    action_dim: int

    # Model type (PI0, PI05, PI0_FAST)
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # Pad the state from 6 DOF to the model's action_dim
        state = transforms.pad_to_dim(data["observation/state"], self.action_dim)

        # Parse images from SO-101's camera keys
        overhead_image = _parse_image(data["observation/images/overhead"])
        wrist_image = _parse_image(data["observation/images/wrist"])

        # Map to Pi0's expected camera slots:
        # - base_0_rgb: overhead camera (top-down view)
        # - left_wrist_0_rgb: wrist camera
        # - right_wrist_0_rgb: duplicated overhead (we only have 2 cameras)
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": overhead_image,
                "left_wrist_0_rgb": wrist_image,
                "right_wrist_0_rgb": overhead_image,  # Duplicate overhead
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # For Pi0 (not FAST), mask out the duplicated camera
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Pad actions during training
        if "action" in data:
            inputs["actions"] = transforms.pad_to_dim(data["action"], self.action_dim)

        # Pass the language prompt through to the model
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class SO101Outputs(transforms.DataTransformFn):
    """
    Convert Pi0 model outputs back to the SO-101 action format.

    Only the first 6 action dimensions (5 arm joints + 1 gripper) are
    returned; any padding is discarded.
    """

    def __call__(self, data: dict) -> dict:
        # Return only the first 6 action dimensions for SO-101
        return {"actions": np.asarray(data["actions"][:, :SO101_ACTION_DIM])}
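
Taken together, `SO101Inputs` pads 6-DOF vectors up to the model's action width and `SO101Outputs` truncates model outputs back to 6 DOF. A small numpy sketch of that round trip, with a local zero-padding helper standing in for `transforms.pad_to_dim` (assumed behavior, for illustration only):

```python
import numpy as np

SO101_ACTION_DIM = 6
MODEL_ACTION_DIM = 32  # Pi0.5's default action width

def pad_to_dim(arr: np.ndarray, target: int) -> np.ndarray:
    """Zero-pad the last axis up to `target` (stand-in for transforms.pad_to_dim)."""
    pad = target - arr.shape[-1]
    return np.pad(arr, [(0, 0)] * (arr.ndim - 1) + [(0, pad)])

# Training side: a chunk of 6-DOF actions is padded to the model's width...
chunk = np.random.rand(15, SO101_ACTION_DIM).astype(np.float32)  # action_horizon=15
padded = pad_to_dim(chunk, MODEL_ACTION_DIM)
assert padded.shape == (15, MODEL_ACTION_DIM)

# ...and at output time the first 6 dimensions are sliced back out, which is
# what SO101Outputs does with data["actions"][:, :SO101_ACTION_DIM].
recovered = padded[:, :SO101_ACTION_DIM]
assert np.allclose(recovered, chunk)
```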
test_config_local.py ADDED
#!/usr/bin/env python3
"""
Test the SO-101 Pi0.5 config locally, without a GPU.

This verifies that:
1. The dataset loads correctly
2. Keys match the expected format
3. The transforms work (simulated)
4. Shapes are correct for Pi0.5

Run: python test_config_local.py
"""

import numpy as np


def test_dataset_structure():
    """Test that the dataset has the expected structure."""
    print("=" * 60)
    print("1. Testing Dataset Structure")
    print("=" * 60)

    # Use LeRobot's dataset loader, which handles videos properly.
    # NOTE: adjust this path to your local LeRobot checkout.
    import sys
    sys.path.insert(0, "/Users/abdul/repo/lerobot")
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    # Load the dataset (uses the local cache)
    ds = LeRobotDataset("abdul004/so101_ball_in_cup_v5")
    sample = ds[0]  # First sample

    print(f"\nDataset keys: {list(sample.keys())}")
    print(f"Total samples: {len(ds)}")

    # Check the expected keys
    expected_keys = [
        "action",
        "observation.state",
        "observation.images.overhead",
        "observation.images.wrist",
        "timestamp",
        "frame_index",
        "episode_index",
    ]

    for key in expected_keys:
        if key in sample:
            val = sample[key]
            if hasattr(val, "shape"):
                print(f"  ✅ {key}: shape={val.shape}, dtype={val.dtype}")
            elif hasattr(val, "__len__") and not isinstance(val, (str, dict)):
                print(f"  ✅ {key}: len={len(val)}")
            else:
                print(f"  ✅ {key}: {type(val).__name__}")
        else:
            print(f"  ❌ {key}: MISSING!")

    return sample


def test_image_parsing(sample):
    """Test image format conversion."""
    print("\n" + "=" * 60)
    print("2. Testing Image Parsing")
    print("=" * 60)

    import einops

    def _parse_image(image) -> np.ndarray:
        """Convert an image to the HWC uint8 format expected by Pi0."""
        image = np.asarray(image)
        original_shape = image.shape
        original_dtype = image.dtype

        if np.issubdtype(image.dtype, np.floating):
            image = (255 * image).astype(np.uint8)
        if image.shape[0] == 3:
            image = einops.rearrange(image, "c h w -> h w c")

        print(f"  Input:  shape={original_shape}, dtype={original_dtype}")
        print(f"  Output: shape={image.shape}, dtype={image.dtype}")
        return image

    print("\nOverhead camera:")
    overhead = _parse_image(sample["observation.images.overhead"])

    print("\nWrist camera:")
    wrist = _parse_image(sample["observation.images.wrist"])

    # Verify the final shapes
    assert overhead.shape[2] == 3, f"Overhead should be HWC, got {overhead.shape}"
    assert wrist.shape[2] == 3, f"Wrist should be HWC, got {wrist.shape}"
    assert overhead.dtype == np.uint8, f"Should be uint8, got {overhead.dtype}"

    print("\n  ✅ Images correctly converted to HWC uint8 format")

    return overhead, wrist


def test_state_and_action(sample):
    """Test state and action dimensions."""
    print("\n" + "=" * 60)
    print("3. Testing State and Action Dimensions")
    print("=" * 60)

    state = np.asarray(sample["observation.state"])
    action = np.asarray(sample["action"])

    print(f"\n  State:  shape={state.shape}, values={state}")
    print(f"  Action: shape={action.shape}, values={action}")

    # SO-101 should have 6 DOF
    assert len(state) == 6, f"State should be 6 DOF, got {len(state)}"
    assert len(action) == 6, f"Action should be 6 DOF, got {len(action)}"

    print("\n  ✅ State and action are 6 DOF as expected")

    # Test padding to the model's action_dim (Pi0.5 uses 32 by default)
    def pad_to_dim(arr, target_dim):
        """Pad an array to the target dimension."""
        arr = np.asarray(arr)
        if len(arr) >= target_dim:
            return arr[:target_dim]
        return np.pad(arr, (0, target_dim - len(arr)), mode="constant")

    model_action_dim = 32  # Pi0.5 default
    padded_state = pad_to_dim(state, model_action_dim)
    padded_action = pad_to_dim(action, model_action_dim)

    print(f"\n  Padded state:  shape={padded_state.shape}")
    print(f"  Padded action: shape={padded_action.shape}")
    print(f"  ✅ Padding to model action_dim={model_action_dim} works")

    return state, action


def test_delta_transform(state, action):
    """Test the delta action transformation."""
    print("\n" + "=" * 60)
    print("4. Testing Delta Action Transform")
    print("=" * 60)

    # Delta mask: the first 5 joints use delta actions, the gripper stays absolute.
    # make_bool_mask(5, -1) = [True, True, True, True, True, False]
    delta_mask = [True, True, True, True, True, False]

    print(f"\n  Delta mask: {delta_mask}")
    print("  (5 joints use delta, gripper stays absolute)")

    # Simulate the delta transform
    delta_action = np.zeros_like(action)
    for i, use_delta in enumerate(delta_mask):
        if use_delta:
            delta_action[i] = action[i] - state[i]  # Convert to delta
        else:
            delta_action[i] = action[i]  # Keep absolute (gripper)

    print(f"\n  Original action: {action}")
    print(f"  Current state:   {state}")
    print(f"  Delta action:    {delta_action}")

    # Verify that we can convert back
    recovered_action = np.zeros_like(delta_action)
    for i, use_delta in enumerate(delta_mask):
        if use_delta:
            recovered_action[i] = state[i] + delta_action[i]  # Delta to absolute
        else:
            recovered_action[i] = delta_action[i]  # Already absolute

    np.testing.assert_array_almost_equal(action, recovered_action)
    print(f"  Recovered:       {recovered_action}")
    print("\n  ✅ Delta transform is reversible")


def test_repack_transform():
    """Test the repack transform key mapping."""
    print("\n" + "=" * 60)
    print("5. Testing Repack Transform (Key Mapping)")
    print("=" * 60)

    # This mirrors what OpenPi's RepackTransform does
    repack_map = {
        "observation/images/overhead": "observation.images.overhead",
        "observation/images/wrist": "observation.images.wrist",
        "observation/state": "observation.state",
        "action": "action",
        "prompt": "prompt",
    }

    print("\n  LeRobot key → OpenPi key:")
    for openpi_key, lerobot_key in repack_map.items():
        print(f"    {lerobot_key} → {openpi_key}")

    print("\n  ✅ Key mapping defined correctly")


def test_pi0_input_format(overhead, wrist, state, action):
    """Test the final Pi0 input format."""
    print("\n" + "=" * 60)
    print("6. Testing Pi0.5 Input Format")
    print("=" * 60)

    # Simulate what SO101Inputs produces
    model_action_dim = 32

    def pad_to_dim(arr, target_dim):
        arr = np.asarray(arr)
        if len(arr) >= target_dim:
            return arr[:target_dim]
        return np.pad(arr, (0, target_dim - len(arr)), mode="constant")

    inputs = {
        "state": pad_to_dim(state, model_action_dim),
        "image": {
            "base_0_rgb": overhead,         # Overhead → base
            "left_wrist_0_rgb": wrist,      # Wrist → left_wrist
            "right_wrist_0_rgb": overhead,  # Duplicated overhead
        },
        "image_mask": {
            "base_0_rgb": True,
            "left_wrist_0_rgb": True,
            "right_wrist_0_rgb": False,  # Masked for Pi0 (not FAST)
        },
        "actions": pad_to_dim(action, model_action_dim),
        "prompt": "pick up the orange ball and put it in the pink cup",
    }

    print("\n  Pi0.5 input structure:")
    print(f"    state: shape={inputs['state'].shape}")
    print(f"    image.base_0_rgb: shape={inputs['image']['base_0_rgb'].shape}")
    print(f"    image.left_wrist_0_rgb: shape={inputs['image']['left_wrist_0_rgb'].shape}")
    print(f"    image.right_wrist_0_rgb: shape={inputs['image']['right_wrist_0_rgb'].shape}")
    print(f"    image_mask: {inputs['image_mask']}")
    print(f"    actions: shape={inputs['actions'].shape}")
    print(f"    prompt: '{inputs['prompt']}'")

    print("\n  ✅ Pi0.5 input format is correct!")


def main():
    print("\n🧪 Testing SO-101 Pi0.5 Config Locally\n")

    try:
        # Test 1: dataset structure
        sample = test_dataset_structure()

        # Test 2: image parsing
        overhead, wrist = test_image_parsing(sample)

        # Test 3: state and action
        state, action = test_state_and_action(sample)

        # Test 4: delta transform
        test_delta_transform(state, action)

        # Test 5: repack transform
        test_repack_transform()

        # Test 6: final Pi0 format
        test_pi0_input_format(overhead, wrist, state, action)

        print("\n" + "=" * 60)
        print("✅ ALL TESTS PASSED!")
        print("=" * 60)
        print("\nConfig should work on Vast.ai. Ready to train!")

    except Exception as e:
        print(f"\n❌ TEST FAILED: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()