Instructions to use StrongRoboticsLab/pi05-so100-diverse with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- LeRobot
How to use StrongRoboticsLab/pi05-so100-diverse with LeRobot:
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| Run Pi0.5 inference on SO-101. | |
| Uses LeRobot's FeetechMotorsBus with calibration for correct normalization, | |
| but bypasses lerobot_record's problematic control loop. | |
| Usage: | |
| python infer_so101.py --task "pick up the blue football" | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import scservo_sdk as scs | |
| import torch | |
# Make a source checkout of LeRobot (~/lerobot/src) and this script's own
# directory importable. Slice-assigning at index 0 leaves the LeRobot checkout
# first on the search path, followed by the script directory.
sys.path[0:0] = [str(Path.home() / "lerobot" / "src"), str(Path(__file__).parent)]

# Timestamped, message-only format on the root logger. The script reports its
# progress through log.warning, so WARNING is the level that keeps it visible.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.WARNING,
)
log = logging.getLogger()

# Joint names in bus order; MOTOR_IDS holds the matching servo IDs (1..6).
MOTOR_NAMES = [
    "shoulder_pan",
    "shoulder_lift",
    "elbow_flex",
    "wrist_flex",
    "wrist_roll",
    "gripper",
]
MOTOR_IDS = list(range(1, len(MOTOR_NAMES) + 1))
def _recover_port(bus) -> None:
    """Reset the serial port after a failed sync transaction.

    Clears the port handler's ``is_using`` flag and flushes stale bytes from
    the serial input buffer so the next read/write starts clean.
    """
    bus.port_handler.is_using = False
    bus.port_handler.ser.reset_input_buffer()


def _make_bus(port: str):
    """Build (but do not connect) the Feetech bus for the SO-101 arm.

    Servo IDs are taken from ``MOTOR_IDS`` in ``MOTOR_NAMES`` order. Every
    joint uses ``MotorNormMode.RANGE_M100_100`` except the gripper, which uses
    ``MotorNormMode.RANGE_0_100``.
    """
    from lerobot.motors import Motor, MotorNormMode
    from lerobot.motors.feetech.feetech import FeetechMotorsBus

    motors = {
        name: Motor(
            motor_id,
            "sts3215",
            MotorNormMode.RANGE_0_100 if name == "gripper" else MotorNormMode.RANGE_M100_100,
        )
        for name, motor_id in zip(MOTOR_NAMES, MOTOR_IDS)
    }
    return FeetechMotorsBus(port=port, motors=motors)


def _calibrate_and_configure(bus) -> None:
    """Upload the saved calibration and configure every motor on ``bus``.

    The motor settings mirror (per the original author) LeRobot's
    ``so_follower.configure()``, plus conservative velocity/acceleration
    limits so the arm does not snap to commanded positions.

    Raises:
        FileNotFoundError: if the calibration JSON is missing.
    """
    from lerobot.motors import MotorCalibration

    cal_path = Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"
    # Context manager so the calibration file handle is always closed.
    with open(cal_path, encoding="utf-8") as f:
        cal = json.load(f)
    bus.write_calibration({name: MotorCalibration(**vals) for name, vals in cal.items()})
    log.warning("Bus connected with calibration")

    # torque_disabled() disables torque for the block and re-enables it on exit.
    with bus.torque_disabled():
        bus.configure_motors()
        for motor in bus.motors:
            bus.write("Operating_Mode", motor, 0)  # Position mode
            bus.write("P_Coefficient", motor, 16)
            bus.write("I_Coefficient", motor, 0)
            bus.write("D_Coefficient", motor, 32)
            bus.write("Goal_Velocity", motor, 600)  # Slow velocity limit
            bus.write("Acceleration", motor, 50)  # Gentle acceleration
            if motor == "gripper":
                bus.write("Max_Torque_Limit", motor, 500)
                bus.write("Protection_Current", motor, 250)
                bus.write("Overload_Torque", motor, 25)
    log.warning("Motors configured and torque enabled (velocity/accel limited)")


def _open_camera(index: int, width: int = 640, height: int = 480):
    """Open OpenCV camera ``index`` and request a ``width`` x ``height`` capture.

    Raises:
        RuntimeError: if the device cannot be opened. Without this check a bad
            index makes every ``read()`` fail and the control loop spins forever.
    """
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open camera index {index}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    return cap


def _read_frame(cap, keep_bgr: bool = False):
    """Grab one frame from ``cap``, or return ``None`` if the read failed.

    OpenCV delivers BGR. LeRobot's own OpenCV camera wrapper converts to RGB by
    default, so a policy trained on LeRobot-recorded data presumably expects
    RGB; Rerun also interprets 3-channel images as RGB. Pass ``keep_bgr=True``
    to feed raw BGR frames instead.
    """
    ok, frame = cap.read()
    if not ok:
        return None
    return frame if keep_bgr else cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _load_policy(checkpoint: str):
    """Load the Pi0.5 policy and its pre/post-processors onto CUDA.

    Returns:
        ``(policy, preprocessor, postprocessor)``. The preprocessor renames the
        script's camera keys to the image keys used by the checkpoint.
    """
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.policies.factory import make_pre_post_processors
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    log.warning("Loading Pi0.5...")
    policy_cfg = PreTrainedConfig.from_pretrained(checkpoint)
    policy_cfg.pretrained_path = Path(checkpoint)
    policy = PI05Policy.from_pretrained(checkpoint).to("cuda")
    policy.eval()
    policy.reset()

    # Map this script's camera names onto the checkpoint's image feature names.
    rename_map = {
        "observation.images.front": "observation.images.base_0_rgb",
        "observation.images.wrist": "observation.images.left_wrist_0_rgb",
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=policy_cfg,
        pretrained_path=policy_cfg.pretrained_path,
        preprocessor_overrides={
            "device_processor": {"device": "cuda"},
            "rename_observations_processor": {"rename_map": rename_map},
        },
    )
    return policy, preprocessor, postprocessor


def _init_rerun():
    """Spawn a Rerun viewer for the live camera feed.

    Returns the ``rerun`` module, or ``None`` if Rerun is not installed.
    """
    try:
        import rerun as rr
    except ImportError:
        log.warning("Rerun not available, no live view")
        return None
    rr.init("so101_inference", spawn=True)
    log.warning("Rerun viewer launched — live camera feed")
    return rr


def _control_loop(args, bus, cap_front, cap_wrist, policy, preprocessor, postprocessor, rr) -> None:
    """Run the read-state -> infer -> write-goal loop.

    Runs until Ctrl+C, or for ``args.max_steps`` successful steps when that is
    non-zero. A step is skipped (and not counted) when the motor read or
    either camera read fails.
    """
    from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference

    device = torch.device("cuda")
    ds_features = {"action": {"names": [f"{name}.pos" for name in MOTOR_NAMES]}}

    log.warning(f"Running: '{args.task}' — Ctrl+C to stop")
    step = 0
    while args.max_steps == 0 or step < args.max_steps:
        t0 = time.perf_counter()

        # 1. Read motor positions (calibrated/normalized by bus)
        try:
            pos_dict = bus.sync_read("Present_Position", num_retry=5)
        except ConnectionError:
            log.warning("Present_Position read failed; resetting port and retrying")
            _recover_port(bus)
            continue
        state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

        # 2. Capture camera images (both are read before checking, as before)
        frame_front = _read_frame(cap_front, args.keep_bgr)
        frame_wrist = _read_frame(cap_wrist, args.keep_bgr)
        if frame_front is None or frame_wrist is None:
            continue

        if rr is not None:
            rr.set_time_sequence("step", step)
            rr.log("camera/front", rr.Image(frame_front))
            rr.log("camera/wrist", rr.Image(frame_wrist))
            rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

        observation = {
            "observation.images.front": frame_front,
            "observation.images.wrist": frame_wrist,
            "observation.state": state_array,
        }

        # 3. Inference
        with torch.inference_mode():
            obs = prepare_observation_for_inference(observation, device, args.task, "so101_follower")
            obs = preprocessor(obs)
            action = policy.select_action(obs)
            action = postprocessor(action)

        # 4. Convert to motor commands and 5. send (calibrated/normalized by bus)
        robot_action = make_robot_action(action, ds_features)
        goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
        try:
            bus.sync_write("Goal_Position", goal_pos)
        except ConnectionError:
            log.warning("Goal_Position write failed; command dropped, resetting port")
            _recover_port(bus)

        dt = time.perf_counter() - t0
        step += 1
        if step % 10 == 0:
            pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
            act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
            log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")


def _shutdown(bus, caps) -> None:
    """Disable torque, disconnect the bus and release cameras.

    Each stage is attempted even if an earlier one fails. ``None`` entries in
    ``caps`` (cameras that were never opened) are skipped.
    """
    log.warning("Disabling torque...")
    try:
        bus.disable_torque()
    except Exception:
        log.exception("disable_torque failed; falling back to raw torque-off writes")
        # Raw write of 0 to register 40 (Torque_Enable on STS3215 servos).
        for mid in MOTOR_IDS:
            try:
                bus.packet_handler.write1ByteTxRx(bus.port_handler, mid, 40, 0)
            except Exception:
                pass  # best effort: keep trying the remaining motors
    try:
        bus.disconnect()
    except Exception:
        log.exception("Bus disconnect failed")
    for cap in caps:
        if cap is not None:
            cap.release()
    log.warning("Done")


def main():
    """Run Pi0.5 inference on an SO-101 arm until Ctrl+C or ``--max-steps``.

    CLI arguments:
        --task:       language instruction passed to the policy (required).
        --checkpoint: directory of the pretrained Pi0.5 checkpoint.
        --port:       serial port of the Feetech motor bus.
        --cam-front / --cam-wrist: OpenCV camera indices.
        --max-steps:  number of successful steps to run; 0 runs until Ctrl+C.
        --keep-bgr:   feed raw OpenCV BGR frames instead of converting to RGB.

    Cleanup (torque off, bus disconnect, camera release) runs for any failure
    after the bus connects, including errors while loading the calibration,
    cameras or policy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0")
    parser.add_argument("--cam-front", type=int, default=2)
    parser.add_argument("--cam-wrist", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    parser.add_argument("--keep-bgr", action="store_true",
                        help="feed raw BGR frames (default converts OpenCV BGR to RGB)")
    args = parser.parse_args()

    bus = _make_bus(args.port)
    bus.connect()
    cap_front = cap_wrist = None
    try:
        _calibrate_and_configure(bus)
        cap_front = _open_camera(args.cam_front)
        cap_wrist = _open_camera(args.cam_wrist)
        log.warning("Cameras open")
        policy, preprocessor, postprocessor = _load_policy(args.checkpoint)
        rr = _init_rerun()
        _control_loop(args, bus, cap_front, cap_wrist, policy, preprocessor, postprocessor, rr)
    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        _shutdown(bus, (cap_front, cap_wrist))


if __name__ == "__main__":
    main()