Instructions to use StrongRoboticsLab/pi05-so100-diverse with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LeRobot
How to use StrongRoboticsLab/pi05-so100-diverse with LeRobot:
- Notebooks
- Google Colab
- Kaggle
File size: 8,577 Bytes
9ad6280 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 | #!/usr/bin/env python3
"""
Run Pi0.5 inference on SO-101.
Uses LeRobot's FeetechMotorsBus with calibration for correct normalization,
but bypasses lerobot_record's problematic control loop.
Usage:
python infer_so101.py --task "pick up the blue football"
"""
import argparse
import json
import logging
import sys
import time
from pathlib import Path
import cv2
import numpy as np
import scservo_sdk as scs
import torch
# Make a LeRobot source checkout (~/lerobot/src) importable ahead of this
# script's own directory; both go in front of any installed packages.
sys.path[:0] = [
    str(Path.home() / "lerobot" / "src"),
    str(Path(__file__).parent),
]

logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.WARNING,
)
# Root logger: everything in this script logs at WARNING so it is always shown.
log = logging.getLogger()

# Joint names in the order used for the state vector and action keys.
MOTOR_NAMES = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
# Bus IDs matching MOTOR_NAMES position-for-position (shoulder_pan=1 ... gripper=6).
MOTOR_IDS = list(range(1, len(MOTOR_NAMES) + 1))
# Feetech STS3215 control-table address of the Torque_Enable register. Used for
# a raw per-motor write when the bus-level disable_torque() call fails.
_TORQUE_ENABLE_ADDR = 40


def _parse_args():
    """Parse the command-line options for one inference run."""
    parser = argparse.ArgumentParser(description="Run Pi0.5 inference on an SO-101 follower arm.")
    parser.add_argument("--task", type=str, required=True,
                        help="Language instruction passed to the policy")
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model",
                        help="Directory of the pretrained Pi0.5 checkpoint")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0",
                        help="Serial port of the Feetech motor bus")
    parser.add_argument("--cam-front", type=int, default=2,
                        help="OpenCV device index of the front camera")
    parser.add_argument("--cam-wrist", type=int, default=0,
                        help="OpenCV device index of the wrist camera")
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    return parser.parse_args()


def _connect_bus(port):
    """Create the follower-arm FeetechMotorsBus on *port* and open the connection.

    Calibration and motor configuration are applied separately by
    _calibrate_and_configure().
    """
    from lerobot.motors import Motor, MotorNormMode
    from lerobot.motors.feetech.feetech import FeetechMotorsBus

    # Arm joints normalize to [-100, 100]; the gripper normalizes to [0, 100].
    motors = {
        name: Motor(
            motor_id,
            "sts3215",
            MotorNormMode.RANGE_0_100 if name == "gripper" else MotorNormMode.RANGE_M100_100,
        )
        for name, motor_id in zip(MOTOR_NAMES, MOTOR_IDS)
    }
    bus = FeetechMotorsBus(port=port, motors=motors)
    bus.connect()
    return bus


def _calibrate_and_configure(bus):
    """Load the saved SO-101 calibration into *bus* and configure every motor.

    On return, torque is enabled with velocity and acceleration limits set.
    """
    from lerobot.motors import MotorCalibration

    cal_path = Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"
    with open(cal_path, encoding="utf-8") as f:
        cal = json.load(f)
    bus.write_calibration({name: MotorCalibration(**vals) for name, vals in cal.items()})
    log.warning("Bus connected with calibration")

    # Intended to mirror LeRobot's so_follower.configure(), plus velocity and
    # acceleration limits. torque_disabled() turns torque off for the duration
    # of the block and re-enables it on exit.
    with bus.torque_disabled():
        bus.configure_motors()
        for motor in bus.motors:
            bus.write("Operating_Mode", motor, 0)  # Position mode
            bus.write("P_Coefficient", motor, 16)
            bus.write("I_Coefficient", motor, 0)
            bus.write("D_Coefficient", motor, 32)
            bus.write("Goal_Velocity", motor, 600)  # Slow velocity limit
            bus.write("Acceleration", motor, 50)  # Gentle acceleration
            if motor == "gripper":
                bus.write("Max_Torque_Limit", motor, 500)
                bus.write("Protection_Current", motor, 250)
                bus.write("Overload_Torque", motor, 25)
    # Velocity and acceleration limits prevent snapping when torque comes back on.
    log.warning("Motors configured and torque enabled (velocity/accel limited)")


def _open_camera(index):
    """Open OpenCV capture device *index* at 640x480.

    Raises:
        RuntimeError: if the device cannot be opened. Without this check, every
        read would fail and the control loop would spin forever.
    """
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open camera {index}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    return cap


def _read_rgb(cap):
    """Grab one frame from *cap* as an RGB array, or return None if the read fails.

    OpenCV delivers BGR. LeRobot's OpenCV camera converts to RGB by default, so
    the policy presumably saw RGB during training (NOTE: confirm against the
    dataset). rr.Image also interprets 3-channel data as RGB.
    """
    ok, frame = cap.read()
    if not ok:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _reset_port(bus):
    """Clear the bus 'in use' flag and drop stale input after a failed transaction."""
    bus.port_handler.is_using = False
    bus.port_handler.ser.reset_input_buffer()


def _load_policy(checkpoint):
    """Load Pi0.5 from *checkpoint* onto CUDA, with its saved pre/post-processors.

    Returns:
        (policy, preprocessor, postprocessor)
    """
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.policies.factory import make_pre_post_processors
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    log.warning("Loading Pi0.5...")
    policy_cfg = PreTrainedConfig.from_pretrained(checkpoint)
    policy_cfg.pretrained_path = Path(checkpoint)
    policy = PI05Policy.from_pretrained(checkpoint).to("cuda")
    policy.eval()
    policy.reset()

    # Map this script's camera keys onto the image keys the checkpoint expects.
    rename_map = {
        "observation.images.front": "observation.images.base_0_rgb",
        "observation.images.wrist": "observation.images.left_wrist_0_rgb",
    }
    # Normalization stats come from the checkpoint's saved preprocessor.
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=policy_cfg,
        pretrained_path=policy_cfg.pretrained_path,
        preprocessor_overrides={
            "device_processor": {"device": "cuda"},
            "rename_observations_processor": {"rename_map": rename_map},
        },
    )
    return policy, preprocessor, postprocessor


def _start_viewer():
    """Spawn a Rerun viewer for live camera/state display.

    Returns the rerun module, or None if rerun is not installed.
    """
    try:
        import rerun as rr
    except ImportError:
        log.warning("Rerun not available, no live view")
        return None
    rr.init("so101_inference", spawn=True)
    log.warning("Rerun viewer launched — live camera feed")
    return rr


def _shutdown(bus, cameras):
    """Best-effort teardown: disable torque, close the serial port, release cameras.

    Each step is attempted even if an earlier one failed.
    """
    log.warning("Disabling torque...")
    # A Ctrl+C can land mid-transaction and leave the port flagged busy, which
    # can make the torque writes below fail; LeRobot's disconnect() clears it
    # the same way before disabling torque.
    try:
        _reset_port(bus)
    except Exception as exc:
        log.warning(f"Could not reset serial port: {exc}")

    try:
        bus.disable_torque()
    except Exception as exc:
        log.warning(f"disable_torque() failed ({exc}); writing Torque_Enable=0 per motor")
        for mid in MOTOR_IDS:
            try:
                # write1ByteTxRx reports failure via its return code, not by raising.
                comm, _ = bus.packet_handler.write1ByteTxRx(bus.port_handler, mid, _TORQUE_ENABLE_ADDR, 0)
                if comm != scs.COMM_SUCCESS:
                    log.warning(f"Could not disable torque on motor {mid} (comm={comm})")
            except Exception as motor_exc:
                log.warning(f"Could not disable torque on motor {mid}: {motor_exc}")

    try:
        # Torque was handled above. Skip disconnect()'s own disable attempt so a
        # repeat failure there cannot stop the port from being closed.
        bus.disconnect(disable_torque=False)
    except Exception as exc:
        log.warning(f"Motor bus disconnect failed: {exc}")

    for cap in cameras:
        cap.release()
    log.warning("Done")


def main():
    """Run closed-loop Pi0.5 inference on the SO-101 arm.

    Each step reads calibrated joint positions and one RGB frame per camera,
    runs the policy, and writes the resulting goal positions to the bus. There
    is no rate limiting: steps run as fast as inference allows. The loop stops
    on Ctrl+C or after --max-steps steps (0 = unlimited).

    Once the bus has connected, all later setup and the control loop run inside
    try/finally. Torque is disabled and every device released even if camera
    opening, policy loading, or motor configuration fails.
    """
    args = _parse_args()

    from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference

    bus = _connect_bus(args.port)
    cameras = []
    try:
        _calibrate_and_configure(bus)

        cap_front = _open_camera(args.cam_front)
        cameras.append(cap_front)
        cap_wrist = _open_camera(args.cam_wrist)
        cameras.append(cap_wrist)
        log.warning("Cameras open")

        policy, preprocessor, postprocessor = _load_policy(args.checkpoint)
        action_names = [f"{name}.pos" for name in MOTOR_NAMES]
        ds_features = {"action": {"names": action_names}}

        rr = _start_viewer()
        device = torch.device("cuda")

        log.warning(f"Running: '{args.task}' — Ctrl+C to stop")
        step = 0
        while args.max_steps == 0 or step < args.max_steps:
            t0 = time.perf_counter()

            # 1. Joint positions, normalized by the bus calibration.
            try:
                pos_dict = bus.sync_read("Present_Position", num_retry=5)
            except ConnectionError as exc:
                log.warning(f"Motor read failed ({exc}); resetting port")
                _reset_port(bus)
                continue
            # State vector: one float32 per joint, in MOTOR_NAMES order.
            state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

            # 2. Camera frames, converted to RGB.
            frame_front = _read_rgb(cap_front)
            frame_wrist = _read_rgb(cap_wrist)
            if frame_front is None or frame_wrist is None:
                log.warning("Camera read failed; skipping step")
                continue

            if rr is not None:
                rr.set_time_sequence("step", step)
                rr.log("camera/front", rr.Image(frame_front))
                rr.log("camera/wrist", rr.Image(frame_wrist))
                rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

            observation = {
                "observation.images.front": frame_front,
                "observation.images.wrist": frame_wrist,
                "observation.state": state_array,
            }

            # 3. Inference.
            with torch.inference_mode():
                obs = prepare_observation_for_inference(observation, device, args.task, "so101_follower")
                obs = preprocessor(obs)
                action = policy.select_action(obs)
                action = postprocessor(action)

            # 4. Convert to per-joint "<name>.pos" commands.
            robot_action = make_robot_action(action, ds_features)

            # 5. Send to motors (the bus de-normalizes with the calibration).
            goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
            try:
                bus.sync_write("Goal_Position", goal_pos)
            except ConnectionError as exc:
                log.warning(f"Motor write failed ({exc}); resetting port")
                _reset_port(bus)

            dt = time.perf_counter() - t0
            step += 1
            if step % 10 == 0:
                pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
                act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
                log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")
    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        _shutdown(bus, cameras)
if __name__ == "__main__":
    # Script entry point only: importing this module never touches the robot.
    main()
|