| import torch |
|
|
| from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig |
| from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata |
| from lerobot.policies.act.modeling_act import ACTPolicy |
| from lerobot.policies.factory import make_pre_post_processors |
| from lerobot.policies.utils import build_inference_frame, make_robot_action |
| from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig |
|
|
# Rollout budget: how many episodes to run, and how many control steps
# (one observation -> one action) make up each episode.
MAX_EPISODES: int = 5
MAX_STEPS_PER_EPISODE: int = 20
|
|
|
|
def main():
    """Roll out a pretrained ACT policy on an SO-100 follower arm.

    Loads the policy from the Hub and builds the pre/post-processors from the
    training dataset's statistics. It then connects to the robot and runs
    ``MAX_EPISODES`` episodes of ``MAX_STEPS_PER_EPISODE`` control steps each.
    Each step reads an observation, runs the policy and sends the resulting
    action to the robot. Steps are paced at the dataset's fps. The robot is
    always disconnected on exit, including on error or interruption.

    Raises:
        ValueError: If ``follower_port`` or ``follower_id`` is still the ``...``
            placeholder.
    """
    import time  # local import: only this script entry point needs it

    # TODO: fill these in for your hardware before running.
    follower_port = ...  # serial port of the follower arm
    follower_id = ...  # id passed to SO100FollowerConfig
    # Fail fast, before the (slow) model download, if the placeholders were
    # never replaced.
    if follower_port is ... or follower_id is ...:
        raise ValueError(
            "Set `follower_port` and `follower_id` in main() before running."
        )

    model_id = "<user>/robot_learning_tutorial_act"
    model = ACTPolicy.from_pretrained(model_id)
    # Build inference frames on the device the policy weights already live on.
    # A hard-coded "mps" only works on Apple hardware and may disagree with
    # the model's placement.
    device = next(model.parameters()).device

    dataset_id = "lerobot/svla_so101_pickplace"
    dataset_metadata = LeRobotDatasetMetadata(dataset_id)
    preprocess, postprocess = make_pre_post_processors(
        model.config, dataset_stats=dataset_metadata.stats
    )
    # Send one action per dataset frame. Without pacing, the loop runs as fast
    # as inference returns rather than at the rate the data was recorded.
    step_period_s = 1.0 / dataset_metadata.fps

    # NOTE(review): camera keys and resolutions presumably must match those
    # used when the policy was trained — confirm against the model card.
    camera_config = {
        "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
        "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
    }

    robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
    robot = SO100Follower(robot_cfg)
    robot.connect()
    try:
        for _ in range(MAX_EPISODES):
            # Clear any per-episode policy state so a new episode does not
            # continue from the previous one. Presumably this includes ACT's
            # queued action chunk — confirm in ACTPolicy.reset.
            model.reset()
            for _ in range(MAX_STEPS_PER_EPISODE):
                step_start = time.perf_counter()

                raw_obs = robot.get_observation()
                obs_frame = build_inference_frame(
                    observation=raw_obs, ds_features=dataset_metadata.features, device=device
                )
                policy_input = preprocess(obs_frame)

                action = model.select_action(policy_input)
                action = postprocess(action)
                # Convert the policy output into the robot's action format
                # using the dataset's feature layout.
                robot_action = make_robot_action(action, dataset_metadata.features)
                robot.send_action(robot_action)

                # Sleep off whatever is left of this step's time budget.
                remaining_s = step_period_s - (time.perf_counter() - step_start)
                if remaining_s > 0:
                    time.sleep(remaining_s)

            print("Episode finished! Starting new episode...")
    finally:
        # Always disconnect, even if the loop raises or is interrupted (Ctrl-C).
        robot.disconnect()
|
|
|
|
if __name__ == "__main__":
    # Run the robot rollout only when executed as a script, never on import.
    main()
|
|