| import torch |
|
|
| from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig |
| from lerobot.datasets.feature_utils import hw_to_dataset_features |
| from lerobot.policies.factory import make_pre_post_processors |
| from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy |
| from lerobot.policies.utils import build_inference_frame, make_robot_action |
| from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig |
|
|
# Loop bounds used by main(): how many episodes to run, and how many
# observe -> act iterations each episode performs.
MAX_EPISODES: int = 5
MAX_STEPS_PER_EPISODE: int = 20
|
|
|
|
def _select_device(preferred=None):
    """Return ``torch.device(preferred)`` if given, else the first available of CUDA, MPS, CPU."""
    if preferred is not None:
        return torch.device(preferred)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main(device=None):
    """Run the pretrained SmolVLA policy on an SO100 follower arm.

    For each of ``MAX_EPISODES`` episodes, repeat ``MAX_STEPS_PER_EPISODE`` times:
    read an observation from the robot, convert it into a policy input frame,
    preprocess it, ask the policy for an action, postprocess that action and
    send it to the robot.

    Args:
        device: Torch device (or device string such as ``"cuda"``, ``"mps"``,
            ``"cpu"``) used for inference. When ``None`` (default), the first
            available of CUDA, MPS and CPU is chosen.

    Raises:
        ValueError: if ``follower_port`` or ``follower_id`` below are still the
            ``...`` placeholders.
    """
    # --- User configuration: fill these in before running. -------------------
    follower_port = ...  # TODO: serial port of the follower arm
    follower_id = ...  # TODO: arm identifier, passed as SO100FollowerConfig(id=...)
    # Fail fast, before downloading the model or touching hardware.
    if follower_port is ... or follower_id is ...:
        raise ValueError(
            "Set `follower_port` and `follower_id` in main() before running this script."
        )

    task = ""  # NOTE(review): language instruction for the policy; left empty here
    robot_type = ""

    # NOTE(review): camera names/resolutions presumably must match what the
    # policy was trained with -- confirm against the model card.
    camera_config = {
        "camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
        "camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
    }

    # --- Policy and processors. ------------------------------------------------
    device = _select_device(device)
    model_id = "lerobot/smolvla_base"

    model = SmolVLAPolicy.from_pretrained(model_id)
    # Put the weights on the same device the preprocessor moves inputs to;
    # otherwise the forward pass would mix devices.
    model.to(device)
    model.eval()

    preprocess, postprocess = make_pre_post_processors(
        model.config,
        model_id,
        preprocessor_overrides={"device_processor": {"device": str(device)}},
    )

    # --- Robot. ------------------------------------------------------------------
    robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
    robot = SO100Follower(robot_cfg)
    robot.connect()
    try:
        # Map the robot's raw observation/action keys to the dataset-style
        # feature names consumed by build_inference_frame / make_robot_action.
        action_features = hw_to_dataset_features(robot.action_features, "action")
        obs_features = hw_to_dataset_features(robot.observation_features, "observation")
        dataset_features = {**action_features, **obs_features}

        for episode in range(MAX_EPISODES):
            # Start each episode fresh; presumably clears policy state such as
            # queued action chunks from the previous episode -- see
            # SmolVLAPolicy.reset.
            model.reset()
            for _ in range(MAX_STEPS_PER_EPISODE):
                raw_obs = robot.get_observation()
                obs_frame = build_inference_frame(
                    observation=raw_obs,
                    ds_features=dataset_features,
                    device=device,
                    task=task,
                    robot_type=robot_type,
                )
                batch = preprocess(obs_frame)

                action = model.select_action(batch)
                action = postprocess(action)
                robot.send_action(make_robot_action(action, dataset_features))

            if episode < MAX_EPISODES - 1:
                print("Episode finished! Starting new episode...")
            else:
                print("All episodes finished.")
    finally:
        # Always release the hardware connection, even on error or Ctrl-C.
        robot.disconnect()
|
|
|
|
# Script entry point: only runs when executed directly, not when imported.
# main() returns None, so SystemExit(None) ends the process with exit status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|
|