justinstrong commited on
Commit
081355b
·
verified ·
1 Parent(s): 5d1e4a3

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. build_index.py +33 -4
  2. eval_kitchen.py +263 -0
  3. eval_sim.py +171 -0
  4. filtered_index.json +3 -3
  5. infer_so101.py +223 -0
  6. so100_dataset.py +13 -0
build_index.py CHANGED
@@ -20,9 +20,22 @@ import random
20
  from collections import defaultdict
21
  from pathlib import Path
22
 
 
23
  import pandas as pd
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def load_dataset_meta(dataset_root: Path) -> dict | None:
27
  """Load and validate a single dataset's metadata."""
28
  info_path = dataset_root / "meta" / "info.json"
@@ -122,6 +135,7 @@ def build_index(
122
  datasets_passed = 0
123
  datasets_rejected = 0
124
  skipped_missing = 0
 
125
 
126
  for contrib_dir in contributors:
127
  if not contrib_dir.is_dir():
@@ -152,9 +166,9 @@ def build_index(
152
  skipped_missing += 1
153
  continue
154
 
155
- # Read actual row count from parquet (fast — just reads footer)
156
- pf = pd.read_parquet(parquet_path, columns=["frame_index"])
157
- actual_length = len(pf)
158
 
159
  if actual_length < min_episode_frames or actual_length > max_episode_frames:
160
  continue
@@ -166,6 +180,21 @@ def build_index(
166
  skipped_missing += 1
167
  continue
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # Get task from episodes.jsonl if available, else default
170
  task_idx = 0
171
  if meta["episodes"]:
@@ -178,7 +207,7 @@ def build_index(
178
  all_episodes.append((contributor, dataset_name, ep_idx, task, actual_length))
179
 
180
  print(f"Datasets: {datasets_passed} passed, {datasets_rejected} rejected")
181
- print(f"Episodes verified: {len(all_episodes)}, skipped (missing files): {skipped_missing}")
182
  print(f"Episodes before caps: {len(all_episodes)}")
183
 
184
  # Phase 2: Apply per-task cap
 
20
  from collections import defaultdict
21
  from pathlib import Path
22
 
23
+ import av
24
  import pandas as pd
25
 
26
 
27
def get_video_duration(video_path: Path) -> float:
    """Return video duration in seconds from container metadata (fast, no decoding).

    Returns 0.0 when the file cannot be opened or carries no duration
    metadata, so callers can treat "unknown" and "broken" uniformly.
    """
    try:
        container = av.open(str(video_path))
    except Exception:
        return 0.0
    try:
        stream = container.streams.video[0]
        # Some containers report no stream duration; treat as unknown.
        if stream.duration is None or stream.time_base is None:
            return 0.0
        return float(stream.duration * stream.time_base)
    except Exception:
        return 0.0
    finally:
        # Always release the container — the original leaked it whenever an
        # exception fired between open() and close().
        container.close()
37
+
38
+
39
  def load_dataset_meta(dataset_root: Path) -> dict | None:
40
  """Load and validate a single dataset's metadata."""
41
  info_path = dataset_root / "meta" / "info.json"
 
135
  datasets_passed = 0
136
  datasets_rejected = 0
137
  skipped_missing = 0
138
+ skipped_video_mismatch = 0
139
 
140
  for contrib_dir in contributors:
141
  if not contrib_dir.is_dir():
 
166
  skipped_missing += 1
167
  continue
168
 
169
+ # Read actual row count and timestamps from parquet
170
+ pf_full = pd.read_parquet(parquet_path, columns=["frame_index", "timestamp"])
171
+ actual_length = len(pf_full)
172
 
173
  if actual_length < min_episode_frames or actual_length > max_episode_frames:
174
  continue
 
180
  skipped_missing += 1
181
  continue
182
 
183
+ # Verify video duration covers all parquet timestamps
184
+ # The last frame's timestamp must be within the video duration
185
+ last_timestamp = float(pf_full["timestamp"].iloc[-1])
186
+ vid1_duration = get_video_duration(vid1)
187
+ vid2_duration = get_video_duration(vid2)
188
+ min_vid_duration = min(vid1_duration, vid2_duration)
189
+ if min_vid_duration > 0 and last_timestamp > min_vid_duration:
190
+ # Video is shorter than parquet claims — truncate to what the video covers
191
+ # Find the last frame index where timestamp <= video duration
192
+ valid_mask = pf_full["timestamp"] <= min_vid_duration
193
+ actual_length = int(valid_mask.sum())
194
+ if actual_length < min_episode_frames:
195
+ skipped_video_mismatch += 1
196
+ continue
197
+
198
  # Get task from episodes.jsonl if available, else default
199
  task_idx = 0
200
  if meta["episodes"]:
 
207
  all_episodes.append((contributor, dataset_name, ep_idx, task, actual_length))
208
 
209
  print(f"Datasets: {datasets_passed} passed, {datasets_rejected} rejected")
210
+ print(f"Episodes verified: {len(all_episodes)}, skipped missing: {skipped_missing}, skipped video mismatch: {skipped_video_mismatch}")
211
  print(f"Episodes before caps: {len(all_episodes)}")
212
 
213
  # Phase 2: Apply per-task cap
eval_kitchen.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluate Pi0.5 checkpoints in the RoboCasa kitchen sim.
4
+ Compares base model vs finetuned model side by side.
5
+
6
+ Runs on CPU only (GPU is used by training).
7
+
8
+ Usage:
9
+ python eval_kitchen.py --checkpoint /mnt/hdd/pi05-training/full_run/checkpoints/004000/pretrained_model
10
+ python eval_kitchen.py --checkpoint lerobot/pi05_base # base model comparison
11
+ python eval_kitchen.py --compare # run both and save side-by-side
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ # EGL rendering for headless MuJoCo
21
+ os.environ["MUJOCO_GL"] = "egl"
22
+
23
+ import imageio
24
+ import numpy as np
25
+ import torch
26
+
27
+ sys.path.insert(0, str(Path(__file__).parent))
28
+ sys.path.insert(0, str(Path.home() / "lerobot" / "src"))
29
+ sys.path.insert(0, "/mnt/hdd/pi05-training/robocasa_test")
30
+
31
+ from so100_kitchen_env import SO100KitchenEnv
32
+
33
+
34
def load_policy(checkpoint_path, device="cuda"):
    """Load a Pi0.5 policy checkpoint and move it to *device* in eval mode."""
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    print(f"Loading policy from {checkpoint_path} ({device})...")
    model = PI05Policy.from_pretrained(str(checkpoint_path)).to(device)
    model.eval()
    return model
42
+
43
+
44
def build_batch(env_obs, camera_image, task, stats, device="cuda"):
    """Convert a kitchen-env observation into the Pi0.5 batch format.

    Args:
        env_obs: dict with "joint_pos" — 6 joint angles in radians.
        camera_image: (H, W, 3) uint8 RGB frame.
        task: natural-language task string embedded in the prompt.
        stats: normalization stats with "observation.state" mean/std
            (degree scale — TODO confirm against norm_stats.json).
        device: torch device for the returned tensors.

    Returns:
        Batch dict with image, state and language tensors on *device*.
    """
    import torchvision.transforms.functional as TF

    # Image: (H, W, 3) uint8 -> (1, 3, 224, 224) float32
    image = torch.from_numpy(camera_image).permute(2, 0, 1).float() / 255.0
    image = image.unsqueeze(0)
    image_224 = TF.resize(image, [224, 224], antialias=True)

    # ImageNet normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    image_224 = (image_224 - mean) / std

    # State: joint positions in radians -> degrees (LeRobot scale), then normalize
    joint_pos = env_obs["joint_pos"]
    state_degrees = np.degrees(joint_pos)
    state = torch.tensor(state_degrees, dtype=torch.float32).unsqueeze(0)

    state_mean = torch.tensor(stats["observation.state"]["mean"], dtype=torch.float32)
    state_std = torch.tensor(stats["observation.state"]["std"], dtype=torch.float32)
    state = (state - state_mean) / (state_std + 1e-8)

    # Pad to the 32-dim state vector the model expects
    state_padded = torch.zeros(1, 32)
    state_padded[:, :6] = state

    # Tokenize. Cache the tokenizer on the function object: the original
    # re-instantiated it from the hub on every call, which is expensive in a
    # per-step control loop. Behavior is unchanged for callers.
    tokenizer = getattr(build_batch, "_tokenizer", None)
    if tokenizer is None:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        build_batch._tokenizer = tokenizer

    # Pi0.5 prompt format: task plus state discretized to 0..255 buckets
    state_discrete = ((state[0].clamp(-1, 1) + 1) / 2 * 255).int()
    state_str = " ".join(str(v.item()) for v in state_discrete)
    prompt = f"Task: {task}, State: {state_str};\nAction: "

    tokens = tokenizer(
        prompt, padding="max_length", max_length=200,
        truncation=True, return_tensors="pt",
    )

    # Same frame is fed to both camera slots (single-camera setup)
    return {
        "observation.images.base_0_rgb": image_224.to(device),
        "observation.images.left_wrist_0_rgb": image_224.to(device),
        "observation.state": state_padded.to(device),
        "observation.language.tokens": tokens["input_ids"].to(device),
        "observation.language.attention_mask": tokens["attention_mask"].bool().to(device),
    }
91
+
92
+
93
def decode_actions(raw_actions, stats):
    """Un-normalize a model action chunk and convert it to joint-angle radians."""
    # First batch element, first 6 dims (the remainder is padding).
    chunk = raw_actions[0, :, :6].cpu().numpy()
    mean = np.asarray(stats["action"]["mean"])
    std = np.asarray(stats["action"]["std"])
    # Invert mean/std normalization back to LeRobot degree scale, then radians.
    return np.radians(chunk * std + mean)
100
+
101
+
102
def run_episode(policy, env, task, stats, num_steps=200, camera="robot_workspace", show_live=True):
    """Run one episode, return frames and joint trajectories.

    The policy emits action chunks; a new chunk is requested only when the
    previous one is exhausted. Returns (frames, joint_history) where frames
    is a list of rendered RGB images and joint_history a (steps, 6) array.
    """
    obs = env.reset()
    frames = []
    joint_history = []
    chunk_actions = None   # current decoded action chunk (radians)
    chunk_idx = 0          # cursor into chunk_actions

    for step in range(num_steps):
        # Re-query the policy once the current chunk is used up.
        if chunk_actions is None or chunk_idx >= len(chunk_actions):
            camera_image = env.render(camera)
            with torch.no_grad():
                batch = build_batch(obs, camera_image, task, stats, device=next(policy.parameters()).device)
                action = policy.select_action(batch)
                chunk_actions = decode_actions(action.unsqueeze(0), stats)
            chunk_idx = 0

        action = chunk_actions[chunk_idx]
        chunk_idx += 1

        obs, reward, done, info = env.step(action)
        frame = env.render(camera)
        frames.append(frame)
        joint_history.append(obs["joint_pos"].copy())

        # Live display via cv2 (static camera); best-effort — any cv2 failure
        # (headless, missing package) is silently ignored.
        if show_live:
            try:
                import cv2
                cv2.imshow("SO-100 Kitchen Sim", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    print("Quit by user")
                    break
            except Exception:
                pass

        if step % 25 == 0:
            pos = obs["joint_pos"]
            print(f" step {step:>3}: joints=[{pos[0]:.2f} {pos[1]:.2f} {pos[2]:.2f} {pos[3]:.2f} {pos[4]:.2f} {pos[5]:.3f}]")

    return frames, np.array(joint_history)
143
+
144
+
145
def main():
    """Entry point: run viewer, base-vs-finetuned comparison, or single rollout."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--task", type=str, default="pick up the mug and place it on the plate")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--output-dir", type=str, default="/mnt/hdd/pi05-training/eval_kitchen")
    parser.add_argument("--compare", action="store_true", help="Run base vs finetuned comparison")
    parser.add_argument("--viewer", action="store_true", help="Use MuJoCo interactive viewer (mouse orbit/pan/zoom)")
    parser.add_argument("--finetuned-checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/004000/pretrained_model")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Normalization stats must sit next to this script.
    with open(Path(__file__).parent / "norm_stats.json") as f:
        stats = json.load(f)

    env = SO100KitchenEnv()

    if args.viewer:
        # Interactive MuJoCo viewer with mouse controls; steps physics manually
        # instead of going through run_episode().
        import mujoco.viewer
        import time as _time
        policy = load_policy(args.checkpoint or "lerobot/pi05_base")
        obs = env.reset()
        chunk_actions = None
        chunk_idx = 0
        device = next(policy.parameters()).device

        print(f"Launching interactive viewer. Task: '{args.task}'")
        print("Mouse: Left=rotate, Right=pan, Scroll=zoom")
        print("Close window to exit.")

        viewer = mujoco.viewer.launch_passive(env.model, env.data)
        step = 0
        while viewer.is_running():
            # Get action from policy (new chunk only when the old one runs out)
            if chunk_actions is None or chunk_idx >= len(chunk_actions):
                camera_image = env.render("overview")
                with torch.no_grad():
                    batch = build_batch(obs, camera_image, args.task, stats, device=device)
                    action = policy.select_action(batch)
                    chunk_actions = decode_actions(action.unsqueeze(0), stats)
                chunk_idx = 0

            act = chunk_actions[chunk_idx]
            chunk_idx += 1

            # Apply action to actuators
            from so100_kitchen_env import JOINT_NAMES
            for i, name in enumerate(JOINT_NAMES):
                aid = env.actuator_ids.get(name)
                if aid is not None:
                    env.data.ctrl[aid] = act[i]

            # Step physics
            mujoco.mj_step(env.model, env.data)
            viewer.sync()

            # Update obs: read joint angles straight out of MuJoCo qpos
            joint_pos = np.array([env.data.qpos[env.model.jnt_qposadr[env.joint_ids[n]]] for n in JOINT_NAMES])
            obs = {"joint_pos": joint_pos}

            step += 1
            if step % 50 == 0:
                print(f" step {step}: joints=[{' '.join(f'{j:.2f}' for j in joint_pos)}]")

            _time.sleep(0.02)  # ~50Hz

        viewer.close()

    elif args.compare:
        # Run both base and finetuned sequentially (del frees GPU memory between runs)
        print("=== BASE MODEL ===")
        base_policy = load_policy("lerobot/pi05_base")
        base_frames, base_joints = run_episode(base_policy, env, args.task, stats, args.steps)
        del base_policy

        print("\n=== FINETUNED MODEL ===")
        ft_policy = load_policy(args.finetuned_checkpoint)
        ft_frames, ft_joints = run_episode(ft_policy, env, args.task, stats, args.steps)
        del ft_policy

        # Save videos
        imageio.mimsave(f"{args.output_dir}/base_model.mp4", base_frames, fps=25)
        imageio.mimsave(f"{args.output_dir}/finetuned_model.mp4", ft_frames, fps=25)

        # Save side-by-side frames at key timesteps
        for t in [0, 50, 100, 150, 199]:
            if t < len(base_frames) and t < len(ft_frames):
                combined = np.concatenate([base_frames[t], ft_frames[t]], axis=1)
                imageio.imwrite(f"{args.output_dir}/compare_step_{t:03d}.png", combined)

        # Print joint trajectory summary
        print("\n=== COMPARISON ===")
        print(f"Base model - joint range: {base_joints.min(axis=0)} to {base_joints.max(axis=0)}")
        print(f"Finetuned - joint range: {ft_joints.min(axis=0)} to {ft_joints.max(axis=0)}")
        print(f"Base model - total motion: {np.abs(np.diff(base_joints, axis=0)).sum():.2f} rad")
        print(f"Finetuned - total motion: {np.abs(np.diff(ft_joints, axis=0)).sum():.2f} rad")

        print(f"\nSaved to {args.output_dir}/")

    elif args.checkpoint:
        # Single-model rollout
        policy = load_policy(args.checkpoint)
        frames, joints = run_episode(policy, env, args.task, stats, args.steps)

        name = Path(args.checkpoint).parent.name if "checkpoint" in args.checkpoint else "model"
        imageio.mimsave(f"{args.output_dir}/{name}.mp4", frames, fps=25)

        for t in [0, len(frames)//2, len(frames)-1]:
            imageio.imwrite(f"{args.output_dir}/{name}_step_{t:03d}.png", frames[t])

        print(f"Saved {len(frames)} frames to {args.output_dir}/")
    else:
        print("Specify --checkpoint or --compare")
260
+
261
+
262
# Script entry point.
if __name__ == "__main__":
    main()
eval_sim.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Evaluate a Pi0.5 checkpoint in the SO-100 MuJoCo sim.
4
+ Renders a video of the model controlling the arm.
5
+
6
+ Usage:
7
+ python eval_sim.py --checkpoint outputs/scale_up_1k/checkpoints/000500/pretrained_model
8
+ python eval_sim.py --checkpoint lerobot/pi05_base # test base model
9
+ """
10
+
11
+ import argparse
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ import imageio
16
+ import numpy as np
17
+ import torch
18
+
19
+ sys.path.insert(0, str(Path(__file__).parent))
20
+ sys.path.insert(0, str(Path.home() / "lerobot" / "src"))
21
+
22
+ from gym_so100.env import SO100Env
23
+ from gym_so100.constants import normalize_lerobot_to_gym_so100
24
+
25
+
26
def load_policy(checkpoint_path, device="cuda"):
    """Load a Pi0.5 policy from *checkpoint_path* onto *device* in eval mode."""
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    print(f"Loading policy from {checkpoint_path}...")
    model = PI05Policy.from_pretrained(str(checkpoint_path)).to(device)
    model.eval()
    return model
35
+
36
+
37
def build_batch(obs, task, stats, device="cuda"):
    """Convert a sim observation to the Pi0.5 batch format.

    Args:
        obs: dict with "pixels" ((H, W, 3) uint8) and "agent_pos"
            (6 joint angles in radians).
        task: natural-language task string for the prompt.
        stats: normalization stats with "observation.state" mean/std.
        device: torch device for the returned tensors.

    Returns:
        Batch dict with image, state and language tensors on *device*.
    """
    # Image: sim gives (H, W, 3) uint8 -> (1, 3, H, W) float32 [0,1]
    image = torch.from_numpy(obs["pixels"]).permute(2, 0, 1).float() / 255.0
    image = image.unsqueeze(0)  # add batch dim

    # Resize to 224x224
    import torchvision.transforms.functional as TF
    image_224 = TF.resize(image, [224, 224], antialias=True)

    # ImageNet normalization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    image_224 = (image_224 - mean) / std

    # State: sim gives radians, convert to degrees (LeRobot scale)
    agent_pos = obs["agent_pos"].copy()
    agent_pos_degrees = np.degrees(agent_pos)
    state = torch.tensor(agent_pos_degrees, dtype=torch.float32).unsqueeze(0)

    # Normalize state with our stats
    state_mean = torch.tensor(stats["observation.state"]["mean"], dtype=torch.float32)
    state_std = torch.tensor(stats["observation.state"]["std"], dtype=torch.float32)
    state = (state - state_mean) / (state_std + 1e-8)

    # Pad state to 32 dims
    state_padded = torch.zeros(1, 32)
    state_padded[:, :6] = state

    # Tokenize task. Cache the tokenizer on the function object: the original
    # re-instantiated it from the hub on every call, dominating per-step latency.
    tokenizer = getattr(build_batch, "_tokenizer", None)
    if tokenizer is None:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        build_batch._tokenizer = tokenizer

    # Discretize state for prompt (Pi0.5 format)
    state_discrete = ((state[0].clamp(-1, 1) + 1) / 2 * 255).int()
    state_str = " ".join(str(v.item()) for v in state_discrete)
    prompt = f"Task: {task}, State: {state_str};\nAction: "

    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=200,
        truncation=True,
        return_tensors="pt",
    )

    batch = {
        "observation.images.base_0_rgb": image_224.to(device),
        "observation.images.left_wrist_0_rgb": image_224.to(device),
        "observation.state": state_padded.to(device),
        "observation.language.tokens": tokens["input_ids"].to(device),
        "observation.language.attention_mask": tokens["attention_mask"].bool().to(device),
    }
    return batch
91
+
92
+
93
def decode_actions(raw_actions, stats):
    """Map normalized model actions back to LeRobot scale and on to sim radians."""
    # (batch, chunk, dims) -> first batch element, 6 real joint dims
    degree_actions = raw_actions[0, :, :6].cpu().numpy()

    # Undo MEAN_STD normalization
    mu = np.asarray(stats["action"]["mean"])
    sigma = np.asarray(stats["action"]["std"])
    degree_actions = degree_actions * sigma + mu

    # LeRobot works in degrees; the sim expects radians.
    return np.radians(degree_actions)
105
+
106
+
107
def main():
    """Roll out a Pi0.5 checkpoint in the SO-100 sim and save an MP4."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--task", type=str, default="pick up the cube and place it in the bin")
    parser.add_argument("--steps", type=int, default=200)
    parser.add_argument("--output", type=str, default="sim_eval.mp4")
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    # Normalization stats must sit next to this script.
    import json
    with open(Path(__file__).parent / "norm_stats.json") as f:
        stats = json.load(f)

    # Load policy
    policy = load_policy(args.checkpoint, args.device)

    # Create sim
    env = SO100Env(task="so100_cube_to_bin", obs_type="so100_pixels_agent_pos")
    obs, info = env.reset()

    frames = []
    print(f"Running {args.steps} sim steps with task: '{args.task}'")

    chunk_actions = None   # current decoded action chunk (radians)
    chunk_idx = 0          # cursor into chunk_actions

    for step in range(args.steps):
        # Get new action chunk from policy every N steps
        if chunk_actions is None or chunk_idx >= len(chunk_actions):
            with torch.no_grad():
                batch = build_batch(obs, args.task, stats, args.device)
                action = policy.select_action(batch)
                chunk_actions = decode_actions(action.unsqueeze(0), stats)
            chunk_idx = 0

        # Apply one action from the chunk
        action = chunk_actions[chunk_idx]
        chunk_idx += 1

        # Normalize radians to sim's [-1, 1] action space.
        # NOTE(review): joint limits are hard-coded here — presumably they match
        # the SO-100 model; confirm against the env's actuator ranges.
        joint_mins = np.array([-1.92, -3.32, -0.174, -1.66, -2.79, -0.174])
        joint_maxs = np.array([1.92, 0.174, 3.14, 1.66, 2.79, 1.75])
        sim_action = 2.0 * (action - joint_mins) / (joint_maxs - joint_mins) - 1.0
        sim_action = np.clip(sim_action, -1.0, 1.0)

        obs, reward, terminated, truncated, info = env.step(sim_action.astype(np.float32))

        frame = env.render()
        frames.append(frame)

        if step % 20 == 0:
            pos = obs["agent_pos"]
            print(f" step {step:>3}: pos=[{pos[0]:.2f} {pos[1]:.2f} {pos[2]:.2f} {pos[3]:.2f} {pos[4]:.2f} {pos[5]:.3f}] reward={reward:.3f}")

        if terminated or truncated:
            print(f"Episode ended at step {step}")
            break

    # Save video
    imageio.mimsave(args.output, frames, fps=25)
    print(f"Saved {len(frames)} frames to {args.output}")
168
+
169
+
170
# Script entry point.
if __name__ == "__main__":
    main()
filtered_index.json CHANGED
@@ -11,8 +11,8 @@
11
  "datasets": 376,
12
  "episodes": 10155,
13
  "unique_tasks": 215,
14
- "total_frames": 5431807,
15
- "est_hours": 50.29450925925926
16
  },
17
  "tasks": [
18
  "Build a Hanoi Tower.",
@@ -1014,7 +1014,7 @@
1014
  "episode_index": 10,
1015
  "task": "Grasp a lego block and put it in the bin.",
1016
  "task_index": 55,
1017
- "num_frames": 558
1018
  },
1019
  {
1020
  "dataset": "1lyz123576/so101_test-1",
 
11
  "datasets": 376,
12
  "episodes": 10155,
13
  "unique_tasks": 215,
14
+ "total_frames": 5431590,
15
+ "est_hours": 50.2925
16
  },
17
  "tasks": [
18
  "Build a Hanoi Tower.",
 
1014
  "episode_index": 10,
1015
  "task": "Grasp a lego block and put it in the bin.",
1016
  "task_index": 55,
1017
+ "num_frames": 341
1018
  },
1019
  {
1020
  "dataset": "1lyz123576/so101_test-1",
infer_so101.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Run Pi0.5 inference on SO-101.
4
+
5
+ Uses LeRobot's FeetechMotorsBus with calibration for correct normalization,
6
+ but bypasses lerobot_record's problematic control loop.
7
+
8
+ Usage:
9
+ python infer_so101.py --task "pick up the blue football"
10
+ """
11
+ import argparse
12
+ import json
13
+ import logging
14
+ import sys
15
+ import time
16
+ from pathlib import Path
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import scservo_sdk as scs
21
+ import torch
22
+
23
+ sys.path.insert(0, str(Path(__file__).parent))
24
+ sys.path.insert(0, str(Path.home() / "lerobot" / "src"))
25
+
26
+ logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s', datefmt='%H:%M:%S')
27
+ log = logging.getLogger()
28
+
29
+ MOTOR_NAMES = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
30
+ MOTOR_IDS = [1, 2, 3, 4, 5, 6]
31
+
32
+
33
def main():
    """Run Pi0.5 inference on a real SO-101 arm until --max-steps or Ctrl+C.

    Control loop: read motor positions -> capture cameras -> policy inference
    -> write goal positions. Torque is always disabled on exit, even on error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0")
    parser.add_argument("--cam-front", type=int, default=2)
    parser.add_argument("--cam-wrist", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    args = parser.parse_args()

    # --- Connect motors using LeRobot's bus (for calibration/normalization) ---
    from lerobot.motors.feetech.feetech import FeetechMotorsBus
    from lerobot.motors import Motor, MotorNormMode, MotorCalibration

    bus = FeetechMotorsBus(
        port=args.port,
        motors={
            'shoulder_pan': Motor(1, 'sts3215', MotorNormMode.RANGE_M100_100),
            'shoulder_lift': Motor(2, 'sts3215', MotorNormMode.RANGE_M100_100),
            'elbow_flex': Motor(3, 'sts3215', MotorNormMode.RANGE_M100_100),
            'wrist_flex': Motor(4, 'sts3215', MotorNormMode.RANGE_M100_100),
            'wrist_roll': Motor(5, 'sts3215', MotorNormMode.RANGE_M100_100),
            'gripper': Motor(6, 'sts3215', MotorNormMode.RANGE_0_100),
        },
    )
    bus.connect()

    # Load calibration
    cal_path = Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"
    cal = json.load(open(cal_path))
    cal_dict = {name: MotorCalibration(**vals) for name, vals in cal.items()}
    bus.write_calibration(cal_dict)
    log.warning("Bus connected with calibration")

    # Configure motors the same way LeRobot does in so_follower.configure()
    # This uses torque_disabled() context which disables torque, configures, re-enables
    with bus.torque_disabled():
        bus.configure_motors()
        for motor in bus.motors:
            bus.write("Operating_Mode", motor, 0)  # Position mode
            bus.write("P_Coefficient", motor, 16)
            bus.write("I_Coefficient", motor, 0)
            bus.write("D_Coefficient", motor, 32)
            bus.write("Goal_Velocity", motor, 600)  # Slow velocity limit
            bus.write("Acceleration", motor, 50)  # Gentle acceleration
            if motor == "gripper":
                # Lower limits so the gripper cannot crush objects or stall hot
                bus.write("Max_Torque_Limit", motor, 500)
                bus.write("Protection_Current", motor, 250)
                bus.write("Overload_Torque", motor, 25)
    # torque_disabled() re-enables torque on exit
    # Velocity and acceleration limits prevent snapping
    log.warning("Motors configured and torque enabled (velocity/accel limited)")

    # --- Open cameras ---
    cap_front = cv2.VideoCapture(args.cam_front)
    cap_wrist = cv2.VideoCapture(args.cam_wrist)
    for cap in [cap_front, cap_wrist]:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    log.warning("Cameras open")

    # --- Load policy + preprocessor + postprocessor ---
    from lerobot.policies.factory import make_pre_post_processors
    from lerobot.policies.utils import prepare_observation_for_inference, make_robot_action
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.processor.rename_processor import rename_stats
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    log.warning("Loading Pi0.5...")
    policy_cfg = PreTrainedConfig.from_pretrained(args.checkpoint)
    policy_cfg.pretrained_path = Path(args.checkpoint)

    policy = PI05Policy.from_pretrained(args.checkpoint)
    policy = policy.to("cuda")
    policy.eval()
    policy.reset()

    # Build stats from checkpoint's saved preprocessor; local camera names are
    # renamed to the camera keys the checkpoint was trained with.
    rename_map = {
        "observation.images.front": "observation.images.base_0_rgb",
        "observation.images.wrist": "observation.images.left_wrist_0_rgb",
    }

    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=policy_cfg,
        pretrained_path=policy_cfg.pretrained_path,
        preprocessor_overrides={
            "device_processor": {"device": "cuda"},
            "rename_observations_processor": {"rename_map": rename_map},
        },
    )

    action_names = [f"{name}.pos" for name in MOTOR_NAMES]
    ds_features = {"action": {"names": action_names}}

    # --- Set up live camera display (optional; rerun may not be installed) ---
    try:
        import rerun as rr
        rr.init("so101_inference", spawn=True)
        use_rerun = True
        log.warning("Rerun viewer launched — live camera feed")
    except ImportError:
        use_rerun = False
        log.warning("Rerun not available, no live view")

    log.warning(f"Running: '{args.task}' — Ctrl+C to stop")

    step = 0
    try:
        while args.max_steps == 0 or step < args.max_steps:
            t0 = time.perf_counter()

            # 1. Read motor positions (calibrated/normalized by bus)
            try:
                pos_dict = bus.sync_read("Present_Position", num_retry=5)
            except ConnectionError:
                # Bus glitch: clear the port state and retry the loop iteration
                bus.port_handler.is_using = False
                bus.port_handler.ser.reset_input_buffer()
                continue

            # Build observation dict
            state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

            # 2. Capture camera images
            ret_f, frame_front = cap_front.read()
            ret_w, frame_wrist = cap_wrist.read()
            if not ret_f or not ret_w:
                continue

            # Live display
            if use_rerun:
                rr.set_time_sequence("step", step)
                rr.log("camera/front", rr.Image(frame_front))
                rr.log("camera/wrist", rr.Image(frame_wrist))
                rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

            observation = {
                "observation.images.front": frame_front,
                "observation.images.wrist": frame_wrist,
                "observation.state": state_array,
            }

            # 3. Inference
            with torch.inference_mode():
                obs = prepare_observation_for_inference(
                    observation, torch.device("cuda"), args.task, "so101_follower"
                )
                obs = preprocessor(obs)
                action = policy.select_action(obs)
                action = postprocessor(action)

            # 4. Convert to motor commands
            robot_action = make_robot_action(action, ds_features)

            # 5. Send to motors (calibrated/normalized by bus)
            goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
            try:
                bus.sync_write("Goal_Position", goal_pos)
            except ConnectionError:
                bus.port_handler.is_using = False
                bus.port_handler.ser.reset_input_buffer()

            dt = time.perf_counter() - t0
            step += 1

            if step % 10 == 0:
                pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
                act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
                log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")

    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        # Safety: always cut torque so the arm goes limp instead of fighting
        # its last goal position after the script exits.
        log.warning("Disabling torque...")
        try:
            bus.disable_torque()
        except Exception:
            # Fallback: raw register write (addr 40 = Torque_Enable) per motor
            for mid in MOTOR_IDS:
                try:
                    bus.packet_handler.write1ByteTxRx(bus.port_handler, mid, 40, 0)
                except Exception:
                    pass
        bus.disconnect()
        cap_front.release()
        cap_wrist.release()
        log.warning("Done")
+ log.warning("Done")
220
+
221
+
222
# Script entry point.
if __name__ == "__main__":
    main()
so100_dataset.py CHANGED
@@ -251,6 +251,19 @@ class SO100Dataset(Dataset):
251
  raise RuntimeError(f"Could not decode frame at t={timestamp} from {video_path}")
252
 
253
  def __getitem__(self, idx: int) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  dataset_path, ep_idx, frame_idx, num_frames, task, task_idx = self._frame_index[idx]
255
 
256
  df = self._load_parquet(dataset_path, ep_idx)
 
251
  raise RuntimeError(f"Could not decode frame at t={timestamp} from {video_path}")
252
 
253
  def __getitem__(self, idx: int) -> dict:
254
+ # Retry with a different sample if this one has corrupt/mismatched video
255
+ for _attempt in range(5):
256
+ try:
257
+ return self._get_sample(idx)
258
+ except (IndexError, RuntimeError, OSError) as e:
259
+ # Video duration doesn't match parquet timestamps, or file is corrupt.
260
+ # Pick a random different index and try again.
261
+ import random
262
+ idx = random.randint(0, len(self._frame_index) - 1)
263
+ # If all retries fail, raise
264
+ return self._get_sample(idx)
265
+
266
+ def _get_sample(self, idx: int) -> dict:
267
  dataset_path, ep_idx, frame_idx, num_frames, task, task_idx = self._frame_index[idx]
268
 
269
  df = self._load_parquet(dataset_path, ep_idx)