File size: 8,577 Bytes
9ad6280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#!/usr/bin/env python3
"""
Run Pi0.5 inference on SO-101.

Uses LeRobot's FeetechMotorsBus with calibration for correct normalization,
but bypasses lerobot_record's problematic control loop.

Usage:
  python infer_so101.py --task "pick up the blue football"
"""
import argparse
import json
import logging
import sys
import time
from pathlib import Path

import cv2
import numpy as np
import scservo_sdk as scs
import torch

# Make this script's directory and the local LeRobot checkout importable.
# The LeRobot source tree ends up first on sys.path, followed by the
# directory that contains this script.
sys.path[:0] = [
    str(Path.home() / "lerobot" / "src"),
    str(Path(__file__).parent),
]

# Status messages in this script are emitted at WARNING level on the root
# logger, so WARNING is the threshold configured here.
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.WARNING,
)
log = logging.getLogger()

# Joint names in bus-ID order; the servo IDs are 1..6 in that same order.
MOTOR_NAMES = [
    "shoulder_pan",
    "shoulder_lift",
    "elbow_flex",
    "wrist_flex",
    "wrist_roll",
    "gripper",
]
MOTOR_IDS = list(range(1, len(MOTOR_NAMES) + 1))


# Abort the control loop after this many back-to-back bus/camera failures
# instead of spinning forever on a device that has gone away.
_MAX_CONSECUTIVE_IO_FAILURES = 50

# Register address written with 0 by the raw torque-off fallback when
# bus.disable_torque() itself fails.
_TORQUE_ENABLE_ADDR = 40


def _parse_args():
    """Parse the command-line options for an inference run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--checkpoint", type=str,
                        default="/mnt/hdd/pi05-training/full_run/checkpoints/015000/pretrained_model")
    parser.add_argument("--port", type=str, default="/dev/ttyACM0")
    parser.add_argument("--cam-front", type=int, default=2)
    parser.add_argument("--cam-wrist", type=int, default=0)
    parser.add_argument("--max-steps", type=int, default=0, help="0 = run until Ctrl+C")
    return parser.parse_args()


def _connect_bus(port):
    """Create the Feetech bus for the six SO-101 joints and connect to it."""
    from lerobot.motors.feetech.feetech import FeetechMotorsBus
    from lerobot.motors import Motor, MotorNormMode

    # The gripper is normalized to 0..100; every other joint to -100..100.
    motors = {
        name: Motor(
            motor_id,
            'sts3215',
            MotorNormMode.RANGE_0_100 if name == "gripper" else MotorNormMode.RANGE_M100_100,
        )
        for name, motor_id in zip(MOTOR_NAMES, MOTOR_IDS)
    }
    bus = FeetechMotorsBus(port=port, motors=motors)
    bus.connect()
    return bus


def _calibrate_and_configure(bus):
    """Write the saved calibration to the bus and apply the motor settings.

    Torque is enabled when this function returns, because the
    ``torque_disabled()`` context re-enables it on exit.
    """
    from lerobot.motors import MotorCalibration

    cal_path = Path.home() / ".cache/huggingface/lerobot/calibration/robots/so_follower/my_so101.json"
    # Context manager so the calibration file handle is always closed.
    with open(cal_path, encoding="utf-8") as cal_file:
        cal = json.load(cal_file)
    cal_dict = {name: MotorCalibration(**vals) for name, vals in cal.items()}
    bus.write_calibration(cal_dict)
    log.warning("Bus connected with calibration")

    # Configure motors the same way LeRobot does in so_follower.configure().
    with bus.torque_disabled():
        bus.configure_motors()
        for motor in bus.motors:
            bus.write("Operating_Mode", motor, 0)  # Position mode
            bus.write("P_Coefficient", motor, 16)
            bus.write("I_Coefficient", motor, 0)
            bus.write("D_Coefficient", motor, 32)
            bus.write("Goal_Velocity", motor, 600)  # Slow velocity limit
            bus.write("Acceleration", motor, 50)     # Gentle acceleration
            if motor == "gripper":
                bus.write("Max_Torque_Limit", motor, 500)
                bus.write("Protection_Current", motor, 250)
                bus.write("Overload_Torque", motor, 25)
    # Velocity and acceleration limits prevent snapping.
    log.warning("Motors configured and torque enabled (velocity/accel limited)")


def _open_camera(index, label):
    """Open one camera at 640x480, raising RuntimeError if it cannot be opened."""
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open {label} camera (index {index})")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    return cap


def _load_policy(checkpoint):
    """Load the Pi0.5 policy and its pre/post-processors from ``checkpoint``.

    Returns:
        A ``(policy, preprocessor, postprocessor)`` tuple; the policy is on
        CUDA, in eval mode, and has been reset.
    """
    from lerobot.policies.factory import make_pre_post_processors
    from lerobot.configs.policies import PreTrainedConfig
    from lerobot.policies.pi05.modeling_pi05 import PI05Policy

    log.warning("Loading Pi0.5...")
    policy_cfg = PreTrainedConfig.from_pretrained(checkpoint)
    policy_cfg.pretrained_path = Path(checkpoint)

    policy = PI05Policy.from_pretrained(checkpoint)
    policy = policy.to("cuda")
    policy.eval()
    policy.reset()

    # Map this script's camera keys onto the image keys passed to the
    # checkpoint's saved preprocessor.
    rename_map = {
        "observation.images.front": "observation.images.base_0_rgb",
        "observation.images.wrist": "observation.images.left_wrist_0_rgb",
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=policy_cfg,
        pretrained_path=policy_cfg.pretrained_path,
        preprocessor_overrides={
            "device_processor": {"device": "cuda"},
            "rename_observations_processor": {"rename_map": rename_map},
        },
    )
    return policy, preprocessor, postprocessor


def _init_rerun():
    """Start the rerun viewer; return the module, or None if not installed."""
    try:
        import rerun as rr
        rr.init("so101_inference", spawn=True)
    except ImportError:
        log.warning("Rerun not available, no live view")
        return None
    log.warning("Rerun viewer launched — live camera feed")
    return rr


def _reset_port(bus):
    """Clear the serial port's busy flag and input buffer after a bus error."""
    bus.port_handler.is_using = False
    bus.port_handler.ser.reset_input_buffer()


def _run_control_loop(bus, cap_front, cap_wrist, policy, preprocessor, postprocessor,
                      rr, task, max_steps):
    """Run the read -> infer -> write loop until ``max_steps`` or an interrupt.

    Args:
        bus: Connected, calibrated motor bus.
        cap_front, cap_wrist: Opened ``cv2.VideoCapture`` objects.
        policy, preprocessor, postprocessor: As returned by ``_load_policy``.
        rr: The rerun module for live display, or None to disable it.
        task: Language instruction passed to the policy.
        max_steps: Number of successful steps to run; 0 means unbounded.

    Raises:
        RuntimeError: After ``_MAX_CONSECUTIVE_IO_FAILURES`` consecutive
            bus/camera failures.
    """
    from lerobot.policies.utils import prepare_observation_for_inference, make_robot_action

    device = torch.device("cuda")
    ds_features = {"action": {"names": [f"{name}.pos" for name in MOTOR_NAMES]}}

    step = 0
    failures = 0
    while max_steps == 0 or step < max_steps:
        if failures >= _MAX_CONSECUTIVE_IO_FAILURES:
            raise RuntimeError(
                f"Aborting after {failures} consecutive motor-bus/camera failures"
            )
        t0 = time.perf_counter()

        # 1. Read motor positions (calibrated/normalized by bus)
        try:
            pos_dict = bus.sync_read("Present_Position", num_retry=5)
        except ConnectionError as exc:
            failures += 1
            log.warning("Position read failed (%d in a row): %s", failures, exc)
            _reset_port(bus)
            continue
        state_array = np.array([pos_dict[name] for name in MOTOR_NAMES], dtype=np.float32)

        # 2. Capture camera images
        ret_f, frame_front = cap_front.read()
        ret_w, frame_wrist = cap_wrist.read()
        if not (ret_f and ret_w):
            failures += 1
            log.warning("Camera read failed (%d in a row)", failures)
            continue

        # OpenCV delivers BGR frames; convert to RGB for both rerun and the
        # policy. NOTE(review): assumes the training data was recorded in RGB
        # (LeRobot's default camera color mode) — confirm for this checkpoint.
        frame_front = cv2.cvtColor(frame_front, cv2.COLOR_BGR2RGB)
        frame_wrist = cv2.cvtColor(frame_wrist, cv2.COLOR_BGR2RGB)

        # Live display
        if rr is not None:
            rr.set_time_sequence("step", step)
            rr.log("camera/front", rr.Image(frame_front))
            rr.log("camera/wrist", rr.Image(frame_wrist))
            rr.log("state", rr.BarChart([pos_dict[n] for n in MOTOR_NAMES]))

        observation = {
            "observation.images.front": frame_front,
            "observation.images.wrist": frame_wrist,
            "observation.state": state_array,
        }

        # 3. Inference
        with torch.inference_mode():
            obs = prepare_observation_for_inference(
                observation, device, task, "so101_follower"
            )
            obs = preprocessor(obs)
            action = policy.select_action(obs)
            action = postprocessor(action)

        # 4. Convert to motor commands
        robot_action = make_robot_action(action, ds_features)

        # 5. Send to motors (calibrated/normalized by bus)
        goal_pos = {name: robot_action[f"{name}.pos"] for name in MOTOR_NAMES}
        try:
            bus.sync_write("Goal_Position", goal_pos)
        except ConnectionError as exc:
            failures += 1
            log.warning("Goal write failed (%d in a row): %s", failures, exc)
            _reset_port(bus)
        else:
            failures = 0

        dt = time.perf_counter() - t0
        step += 1

        if step % 10 == 0:
            pos_str = " ".join(f"{pos_dict[n]:>7.1f}" for n in MOTOR_NAMES)
            act_str = " ".join(f"{robot_action[f'{n}.pos']:>7.1f}" for n in MOTOR_NAMES)
            log.warning(f"step {step:>4} | state=[{pos_str}] | action=[{act_str}] | {dt*1000:.0f}ms")


def _shutdown(bus, cameras):
    """Best-effort cleanup: torque off, bus disconnected, cameras released.

    Each step is isolated so that a failure in one does not skip the others.
    """
    log.warning("Disabling torque...")
    try:
        bus.disable_torque()
    except Exception as exc:
        # Fall back to raw per-motor writes so the arm is not left energized.
        log.warning("disable_torque() failed (%s); trying raw per-motor writes", exc)
        for mid in MOTOR_IDS:
            try:
                bus.packet_handler.write1ByteTxRx(bus.port_handler, mid, _TORQUE_ENABLE_ADDR, 0)
            except Exception as raw_exc:
                log.warning("Could not disable torque on motor %d: %s", mid, raw_exc)
    try:
        bus.disconnect()
    except Exception as exc:
        log.warning("Bus disconnect failed: %s", exc)
    for cap in cameras:
        cap.release()
    log.warning("Done")


def main():
    """Run Pi0.5 inference on the SO-101 arm until stopped.

    Connects and configures the motor bus, opens both cameras, loads the
    policy, then runs the control loop. Everything after the bus connection
    is covered by a ``finally`` clause, so torque is disabled and devices are
    released even when setup (camera open, policy load) fails.
    """
    args = _parse_args()

    # --- Connect motors using LeRobot's bus (for calibration/normalization) ---
    bus = _connect_bus(args.port)
    cameras = []
    try:
        _calibrate_and_configure(bus)

        # --- Open cameras ---
        cameras.append(_open_camera(args.cam_front, "front"))
        cameras.append(_open_camera(args.cam_wrist, "wrist"))
        log.warning("Cameras open")

        # --- Load policy + preprocessor + postprocessor ---
        policy, preprocessor, postprocessor = _load_policy(args.checkpoint)

        # --- Set up live camera display ---
        rr = _init_rerun()

        log.warning(f"Running: '{args.task}' — Ctrl+C to stop")
        _run_control_loop(
            bus, cameras[0], cameras[1], policy, preprocessor, postprocessor,
            rr, args.task, args.max_steps,
        )
    except KeyboardInterrupt:
        log.warning("Stopped by user")
    finally:
        _shutdown(bus, cameras)


# Script entry point: start the inference run only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()