| | |
| | import time |
| | import sys |
| | import queue |
| | import inspect |
| | import torch |
| |
|
| | from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy |
| | from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig |
| | from lerobot.datasets.utils import hw_to_dataset_features |
| | from lerobot.policies.factory import make_pre_post_processors |
| | from lerobot.policies.utils import build_inference_frame, make_robot_action |
| |
|
| | from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower |
| | from lerobot.teleoperators.so101_leader import SO101LeaderConfig |
| | from lerobot.teleoperators import make_teleoperator_from_config |
| |
|
| | from lerobot.datasets.lerobot_dataset import LeRobotDataset |
| |
|
| |
|
| | |
| | |
| | |
# --- Hardware: serial ports of the two arms and OpenCV camera indices -------
FOLLOWER_PORT = "/dev/ttyACM1"
LEADER_PORT = "/dev/ttyACM2"

TOP_CAM_INDEX = 4
WRIST_CAM_INDEX = 9

# --- Policy / task -----------------------------------------------------------
MODEL_ID = "lerobot/smolvla_base"
TASK = "Pick up the red block."
ROBOT_TYPE = "so101_follower"

# --- Control loop ------------------------------------------------------------
FPS = 20  # control + recording rate in Hz
POLICY_SCALE = 1  # multiplier applied to tensor actions coming out of the policy
EPISODE_SECONDS = 10.0  # length of each policy rollout, in seconds

# --- Dataset -----------------------------------------------------------------
# Local-time timestamp so every run records into its own dataset repo id.
curr_time = time.strftime("%Y%m%d_%H%M%S")
DATASET_REPO_ID = "HenryZhang/so101_smolvla_eval_" + curr_time

# Passed as ``root`` to LeRobotDataset.create; None presumably lets lerobot pick
# its default local directory (TODO confirm against the installed version).
DATASET_ROOT = None
USE_VIDEOS = True
PUSH_TO_HUB_ON_EXIT = True
PRIVATE_ON_HUB = False
DATASET_TAGS = ["LeRobot", "so101", "smolvla", "policy-eval"]
| | |
| |
|
| |
|
def log(msg):
    """Write a status line to stdout and flush at once so it is never held in a buffer."""
    print(msg, end="\n", flush=True)
| |
|
| |
|
def start_enter_listener(cmd_q: "queue.Queue[str]", stream=None):
    """Press Enter to start one episode: queue one start command per input line.

    Blocks reading lines from ``stream``; each line read (typically the user
    pressing Enter) puts ``"start_episode"`` on ``cmd_q``. Meant to run in a
    background thread.

    Args:
        cmd_q: Queue receiving ``"start_episode"`` once per line read.
        stream: Text stream to read from. Defaults to ``sys.stdin``, looked up
            at call time so a redirected stdin is honoured.

    Returns when the stream hits EOF or a read fails. At EOF ``readline()``
    returns ``""`` immediately and forever, so continuing would busy-loop and
    flood ``cmd_q`` with endless start commands.
    """
    if stream is None:
        stream = sys.stdin
    while True:
        try:
            line = stream.readline()
        except Exception:  # stdin closed/detached: stop listening quietly (thread top level)
            break
        if not line:
            # EOF: no more user input will ever arrive.
            break
        cmd_q.put("start_episode")
| |
|
| |
|
def send_leader_action(robot, leader_action):
    """Mirror the leader arm's action onto the follower.

    Only the keys the follower lists in ``robot.action_features`` are forwarded,
    each converted to ``float``. Nothing is sent when ``leader_action`` is not a
    dict or shares no keys with the follower.
    """
    if isinstance(leader_action, dict):
        command = {}
        for joint in robot.action_features.keys():
            if joint in leader_action:
                command[joint] = float(leader_action[joint])
        if command:
            robot.send_action(command)
| |
|
| |
|
| | def _import_build_dataset_frame(): |
| | try: |
| | from lerobot.common.datasets.utils import build_dataset_frame |
| | return build_dataset_frame |
| | except Exception: |
| | from lerobot.datasets.utils import build_dataset_frame |
| | return build_dataset_frame |
| |
|
| |
|
def create_dataset(repo_id, fps, root, robot_type, features, use_videos, num_cameras):
    """Create a fresh LeRobotDataset for recording evaluation episodes.

    Version-dependent keyword arguments (``single_task``, ``exist_ok``) are
    passed only when ``LeRobotDataset.create`` declares them. Only when its
    signature cannot be introspected does this fall back to trying
    ``exist_ok=True`` and retrying without it on ``TypeError``.

    Args:
        repo_id: Hub repo id for the dataset.
        fps: recording frame rate.
        root: local dataset directory (forwarded as-is; may be None).
        robot_type: robot type string stored with the dataset.
        features: dataset feature spec.
        use_videos: forwarded to ``create``.
        num_cameras: number of camera streams; sizes the image-writer thread pool.

    Returns:
        The created dataset object.
    """
    kwargs = dict(
        repo_id=repo_id,
        fps=fps,
        root=root,
        robot_type=robot_type,
        features=features,
        use_videos=use_videos,
        image_writer_processes=0,
        image_writer_threads=4 * max(num_cameras, 1),
    )

    try:
        create_params = inspect.signature(LeRobotDataset.create).parameters
    except (TypeError, ValueError):
        create_params = None  # signature not introspectable; see fallback below

    if create_params is not None and "single_task" in create_params:
        kwargs["single_task"] = TASK

    if create_params is None:
        # Unknown signature: try exist_ok and drop it if the call rejects it.
        try:
            ds = LeRobotDataset.create(**kwargs, exist_ok=True)
        except TypeError:
            ds = LeRobotDataset.create(**kwargs)
    elif "exist_ok" in create_params:
        ds = LeRobotDataset.create(**kwargs, exist_ok=True)
    else:
        # Deciding up front means a TypeError raised *inside* create() is
        # reported as-is instead of triggering a second creation attempt.
        ds = LeRobotDataset.create(**kwargs)

    # create() receives image_writer_threads above and may already have started
    # an image writer (version-dependent). Starting another would replace it and
    # leave the first writer's worker threads behind, so only start one if none exists.
    needs_writer = getattr(ds, "image_writer", None) is None
    if num_cameras > 0 and needs_writer and hasattr(ds, "start_image_writer"):
        ds.start_image_writer(num_processes=0, num_threads=4 * num_cameras)

    log(f"[INFO] Dataset ready: {repo_id}")
    return ds
| |
|
| |
|
def dataset_add_frame_compat(dataset, frame, task):
    """Append one frame to ``dataset``, passing ``task`` the way its ``add_frame`` expects.

    If ``add_frame`` declares a ``task`` parameter, the task is passed as a
    keyword. Otherwise it is stored in a copy of the frame under ``"task"``.
    Only the signature probe is guarded: an error raised by ``add_frame`` itself
    propagates. It is not swallowed and retried with the other calling
    convention, which would hide the real error and could add the frame twice.

    Args:
        dataset: object exposing ``add_frame``.
        frame: feature-name -> value mapping for one timestep (not mutated).
        task: task description recorded with the frame.
    """
    try:
        takes_task_kwarg = "task" in inspect.signature(dataset.add_frame).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to embedding the task in the frame.
        takes_task_kwarg = False

    if takes_task_kwarg:
        dataset.add_frame(frame, task=task)
    else:
        dataset.add_frame({**frame, "task": task})
| |
|
| |
|
def dataset_push_compat(dataset, repo_id, tags, private):
    """Push ``dataset`` to the Hugging Face Hub across lerobot API versions.

    ``repo_id`` is forwarded, as a keyword, only when ``push_to_hub`` declares a
    parameter named ``repo_id``. Otherwise the dataset is assumed to push to the
    repo it was created with. It is never passed positionally, because the first
    parameter of ``push_to_hub`` is not necessarily the repo id (in some lerobot
    versions it appears to be ``branch`` — confirm against the installed version).
    Errors raised by the push itself propagate rather than being swallowed and retried.

    Args:
        dataset: object exposing ``push_to_hub``.
        repo_id: Hub repo id, used only if ``push_to_hub`` accepts it by name.
        tags: tags forwarded to ``push_to_hub``.
        private: forwarded to ``push_to_hub``.
    """
    try:
        accepts_repo_id = "repo_id" in inspect.signature(dataset.push_to_hub).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: use the repo the dataset already knows.
        accepts_repo_id = False

    if accepts_repo_id:
        dataset.push_to_hub(repo_id=repo_id, tags=tags, private=private)
    else:
        dataset.push_to_hub(tags=tags, private=private)
| |
|
| |
|
def _disconnect_quietly(device, label):
    """Disconnect ``device`` if it is set, logging (not raising) any failure.

    Used at shutdown so a failure on one arm cannot prevent the other arm from
    being disconnected.
    """
    if device is None:
        return
    try:
        device.disconnect()
    except Exception as exc:  # best-effort cleanup boundary: log and keep shutting down
        log(f"[WARN] Failed to disconnect {label}: {exc!r}")


def main():
    """Run interactive SmolVLA evaluation on an SO101 arm, recording each rollout.

    Steps:
      1. Load the SmolVLA policy and its pre/post-processors (CUDA if available).
      2. Connect the follower (with two OpenCV cameras) and the leader arm.
      3. Create a dataset matching the robot's action/observation features.
      4. Loop at ``FPS``:
         - RESET mode: the leader arm drives the follower (not recorded). Each
           Enter press queues one episode.
         - POLICY mode: for ``EPISODE_SECONDS``, run the policy, send its action,
           and record the observation plus the action returned by
           ``robot.send_action``. Then save the episode.
      5. On exit (Ctrl+C or error): disconnect both arms. If pushing is enabled
         and at least one episode was saved, push the dataset to the Hub.

    An episode still running when Ctrl+C arrives is not saved.
    """
    import threading  # only used to run the stdin listener in the background

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"[INFO] Device: {device}")

    build_dataset_frame = _import_build_dataset_frame()

    # --- Policy --------------------------------------------------------------
    log(f"[INFO] Loading SmolVLA: {MODEL_ID}")
    policy = SmolVLAPolicy.from_pretrained(MODEL_ID).to(device)
    policy.eval()

    preprocess, postprocess = make_pre_post_processors(
        policy.config,
        MODEL_ID,
        preprocessor_overrides={"device_processor": {"device": str(device)}},
        postprocessor_overrides={"device_processor": {"device": str(device)}},
    )

    # --- Hardware configuration ----------------------------------------------
    camera_cfg = {
        "camera1": OpenCVCameraConfig(index_or_path=TOP_CAM_INDEX, width=640, height=480, fps=30),
        "camera2": OpenCVCameraConfig(index_or_path=WRIST_CAM_INDEX, width=640, height=480, fps=30),
    }
    robot_cfg = SO101FollowerConfig(port=FOLLOWER_PORT, id="so101_follower_arm", cameras=camera_cfg)
    leader_cfg = SO101LeaderConfig(port=LEADER_PORT, id="so101_leader_arm")

    log("[INFO] Connecting follower...")
    robot = SO101Follower(robot_cfg)
    robot.connect()

    # Everything after the follower connects runs under try/finally, so both arms
    # are disconnected even if later setup (leader, dataset) or the loop fails.
    teleop = None
    dataset = None
    saved_episodes = 0
    try:
        log("[INFO] Connecting leader...")
        teleop = make_teleoperator_from_config(leader_cfg)
        teleop.connect()

        # --- Dataset ---------------------------------------------------------
        action_features = hw_to_dataset_features(robot.action_features, "action", USE_VIDEOS)
        obs_features = hw_to_dataset_features(robot.observation_features, "observation", USE_VIDEOS)
        dataset_features = {**action_features, **obs_features}

        dataset = create_dataset(
            DATASET_REPO_ID,
            FPS,
            DATASET_ROOT,
            robot.name,
            dataset_features,
            USE_VIDEOS,
            len(getattr(robot, "cameras", [])),
        )

        # --- Keyboard trigger: each Enter press queues one episode start -----
        cmd_q = queue.Queue()
        threading.Thread(target=start_enter_listener, args=(cmd_q,), daemon=True).start()

        log("\n[INFO] Press Enter to run ONE episode. Ctrl+C to exit.\n")

        dt = 1.0 / FPS
        mode = "RESET"
        episode_idx = 0
        episode_end_time = None

        policy.reset()

        try:
            while True:
                # perf_counter is monotonic: loop pacing and episode deadlines are
                # unaffected by wall-clock adjustments.
                t0 = time.perf_counter()

                if mode == "RESET" and not cmd_q.empty():
                    cmd_q.get_nowait()
                    episode_idx += 1
                    policy.reset()
                    if hasattr(dataset, "clear_episode_buffer"):
                        dataset.clear_episode_buffer()
                    episode_end_time = time.perf_counter() + EPISODE_SECONDS
                    mode = "POLICY"
                    log(f"[INFO] Episode {episode_idx} START")

                if mode == "RESET":
                    # Between episodes the leader drives the follower; nothing is recorded.
                    send_leader_action(robot, teleop.get_action())
                elif time.perf_counter() >= episode_end_time:
                    log(f"[INFO] Episode {episode_idx} END — saving...")
                    t_save = time.perf_counter()
                    dataset.save_episode()
                    saved_episodes += 1
                    log(f"[INFO] Saved in {time.perf_counter() - t_save:.1f}s")
                    mode = "RESET"
                    episode_end_time = None
                else:
                    obs = robot.get_observation()

                    obs_frame = build_inference_frame(
                        observation=obs,
                        ds_features=dataset_features,
                        device=device,
                        task=TASK,
                        robot_type=ROBOT_TYPE,
                    )

                    with torch.no_grad():
                        batch = preprocess(obs_frame)
                        action = policy.select_action(batch)
                        action = postprocess(action)

                    if isinstance(action, torch.Tensor):
                        action = action.squeeze(0) * POLICY_SCALE

                    robot_action = make_robot_action(action, dataset_features)
                    # Record what send_action returns rather than the raw prediction.
                    sent_action = robot.send_action(robot_action)
                    print("Predicted:", action, "robot:", robot_action, "sent:", sent_action)
                    frame = {
                        **build_dataset_frame(dataset.features, obs, "observation"),
                        **build_dataset_frame(dataset.features, sent_action, "action"),
                    }

                    dataset_add_frame_compat(dataset, frame, TASK)

                time.sleep(max(0.0, dt - (time.perf_counter() - t0)))

        except KeyboardInterrupt:
            log("\n[INFO] Ctrl+C received.")
            if mode == "POLICY":
                log(f"[INFO] Episode {episode_idx} was interrupted; its frames were not saved.")

    finally:
        _disconnect_quietly(teleop, "leader")
        _disconnect_quietly(robot, "follower")

        if PUSH_TO_HUB_ON_EXIT and dataset is not None:
            if saved_episodes == 0:
                log("[INFO] No episodes saved; skipping Hub push.")
            else:
                log("[INFO] Pushing dataset to Hub...")
                dataset_push_compat(dataset, DATASET_REPO_ID, DATASET_TAGS, PRIVATE_ON_HUB)

    log("[INFO] Done.")
| |
|
| |
|
if __name__ == "__main__":
    # Start the interactive evaluation only when run as a script, not on import.
    main()
| |
|