import inspect
import os
import time

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
from lerobot.processor import PolicyProcessorPipeline
from lerobot.datasets.utils import hw_to_dataset_features

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Hugging Face Hub repository the policy, preprocessor and action statistics
# are loaded from.
MODEL_ID = "lerobot/smolvla_base"
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Serial port passed to the SO101 follower configuration.
FOLLOWER_PORT = "/dev/ttyACM3"

# OpenCV device indices for the two cameras attached to the robot config.
TOP_CAM_INDEX = 4
WRIST_CAM_INDEX = 9

# Task string and robot type handed to ``build_inference_frame`` every step.
TASK = "Pick up the red block."
ROBOT_TYPE = "so101_follower"

# Control-loop rate (iterations per second) and total run time in seconds.
FPS = 10
EPISODE_SECONDS = 5.0

# Key prefix used to look up ``<BUFFER>.buffer.action.mean`` / ``.std`` in the
# downloaded un-normalizer state file.
# NOTE(review): the prefix is "so100" while the robot is an SO101 -- confirm
# this matches the keys actually stored in the checkpoint.
BUFFER = "so100"

# Only sets the variable when the user has not already chosen a value.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


def hw_to_dataset_features_compat(hw_feats, prefix: str, use_videos: bool = True):
    """Call ``hw_to_dataset_features`` while tolerating signature differences.

    The signature of the imported ``hw_to_dataset_features`` is inspected on
    every call and the media flag is passed in the first form it accepts:

    1. as keyword ``use_videos`` when a parameter of that name exists;
    2. as keyword ``use_images`` when that name exists instead (the value of
       ``use_videos`` is forwarded unchanged);
    3. positionally, when the function declares three or more parameters;
    4. not at all -- only ``hw_feats`` and ``prefix`` are passed.

    Args:
        hw_feats: Hardware feature description, forwarded as the first
            positional argument.
        prefix: Feature-name prefix, forwarded as the second positional
            argument.
        use_videos: Flag forwarded according to the rules above.

    Returns:
        Whatever ``hw_to_dataset_features`` returns.
    """
    params = inspect.signature(hw_to_dataset_features).parameters

    if "use_videos" in params:
        return hw_to_dataset_features(hw_feats, prefix, use_videos=use_videos)
    if "use_images" in params:
        # NOTE(review): the flag is passed through as-is under the other
        # name -- confirm both names share the same meaning.
        return hw_to_dataset_features(hw_feats, prefix, use_images=use_videos)
    if len(params) >= 3:
        # Unknown third-parameter name: fall back to a positional argument.
        return hw_to_dataset_features(hw_feats, prefix, use_videos)
    return hw_to_dataset_features(hw_feats, prefix)


# ---------------------------------------------------------------------------
# Policy
# ---------------------------------------------------------------------------
print("[INFO] Loading SmolVLA...")
policy = SmolVLAPolicy.from_pretrained(MODEL_ID).to(DEVICE)
policy.eval()

# ---------------------------------------------------------------------------
# Preprocessor
# ---------------------------------------------------------------------------
print("[INFO] Loading pretrained preprocessor...")
preprocess = PolicyProcessorPipeline.from_pretrained(
    MODEL_ID,
    config_filename="policy_preprocessor.json",
    # Override the device stored in the saved pipeline with the local one.
    overrides={"device_processor": {"device": DEVICE}},
)

# ---------------------------------------------------------------------------
# Action statistics (used by ``decode_action``)
# ---------------------------------------------------------------------------
print("[INFO] Loading pretrained action stats...")
state_path = hf_hub_download(
    repo_id=MODEL_ID,
    filename="policy_postprocessor_step_0_unnormalizer_processor.safetensors",
)
state = load_file(state_path)

_mean_key = f"{BUFFER}.buffer.action.mean"
_std_key = f"{BUFFER}.buffer.action.std"
_missing_keys = [key for key in (_mean_key, _std_key) if key not in state]
if _missing_keys:
    # Fail with the list of available keys instead of a bare KeyError, so a
    # wrong BUFFER prefix is easy to diagnose.
    raise KeyError(
        f"Action stats {_missing_keys} not found in {state_path}; "
        f"available keys: {sorted(state)}"
    )

# Keep the statistics on the same device as the policy output.
mean = state[_mean_key].to(DEVICE)
std = state[_std_key].to(DEVICE)

print(f"[INFO] Action dim = {mean.numel()}")


def decode_action(action_norm: torch.Tensor) -> torch.Tensor:
    """Un-normalize an action with the pretrained action statistics.

    Computes ``action_norm * std + mean`` using the module-level ``std`` and
    ``mean`` tensors loaded from the un-normalizer state file.

    Args:
        action_norm: Action tensor as returned by the policy.
            NOTE(review): must be broadcastable against ``mean``/``std`` and
            on the same device -- confirm against the policy output shape.

    Returns:
        The rescaled action tensor.
    """
    return action_norm * std + mean


# ---------------------------------------------------------------------------
# Cameras
# ---------------------------------------------------------------------------
# Both cameras are configured at 640x480 @ 30 fps; the control loop below runs
# at FPS, independently of the camera rate.
camera_cfg = {
    "camera1": OpenCVCameraConfig(index_or_path=TOP_CAM_INDEX, width=640, height=480, fps=30),
    "camera2": OpenCVCameraConfig(index_or_path=WRIST_CAM_INDEX, width=640, height=480, fps=30),
}

# ---------------------------------------------------------------------------
# Robot
# ---------------------------------------------------------------------------
print("[INFO] Connecting SO101 follower...")
robot_cfg = SO101FollowerConfig(
    port=FOLLOWER_PORT,
    id="so101_follower_arm",
    cameras=camera_cfg,
)
robot = SO101Follower(robot_cfg)
robot.connect()

USE_VIDEOS = True
# Target duration of one control-loop iteration, in seconds.
dt = 1.0 / FPS

# Everything after a successful connect runs inside try/finally so the robot
# is disconnected even when feature construction or validation fails.
try:
    # -----------------------------------------------------------------------
    # Dataset-style feature descriptions for observations and actions
    # -----------------------------------------------------------------------
    action_features = hw_to_dataset_features_compat(robot.action_features, "action", use_videos=USE_VIDEOS)
    obs_features = hw_to_dataset_features_compat(robot.observation_features, "observation", use_videos=USE_VIDEOS)
    ds_features = {**obs_features, **action_features}

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if "action" not in ds_features or "names" not in ds_features["action"]:
        raise ValueError(f"ds_features['action'] missing names: {ds_features.get('action')}")

    # -----------------------------------------------------------------------
    # Control loop
    # -----------------------------------------------------------------------
    print("[INFO] Starting evaluation...")
    policy.reset()

    # Monotonic clock: immune to wall-clock adjustments while timing the run.
    t_end = time.monotonic() + EPISODE_SECONDS
    while time.monotonic() < t_end:
        t0 = time.monotonic()

        obs = robot.get_observation()

        obs_frame = build_inference_frame(
            observation=obs,
            ds_features=ds_features,
            device=DEVICE,
            task=TASK,
            robot_type=ROBOT_TYPE,
        )

        batch = preprocess(obs_frame)

        with torch.no_grad():
            action_norm = policy.select_action(batch)

        # Rescale with the pretrained stats, then drop the leading dimension
        # if it has size 1.
        action_real = decode_action(action_norm).squeeze(0)

        robot_action = make_robot_action(action_real, ds_features)
        robot.send_action(robot_action)

        # Sleep for whatever is left of this iteration's time budget.
        time.sleep(max(0.0, dt - (time.monotonic() - t0)))

except KeyboardInterrupt:
    print("\n[INFO] Ctrl+C received.")

finally:
    try:
        robot.disconnect()
    except Exception as exc:
        # Best-effort cleanup: report the failure instead of hiding it, but
        # do not mask an exception that is already propagating.
        print(f"[WARN] Robot disconnect failed: {exc!r}")
    print("[INFO] Done.")