| """ |
Load a sample batch from a LeRobot dataset and send it to a remote policy
server through ``WebsocketClientPolicy``, timing each request.
| """ |
|
|
| from starhelm.websocket_client_policy import WebsocketClientPolicy |
| import time |
| import torch |
| from vlaholo.datasets.lerobot_dataset import LeRobotDataset |
|
|
| |
| |
| import sys |
|
|
|
|
| import os |
|
|
# Remove proxy settings from the environment so the connection to the policy
# server is made directly. Readers of proxy environment variables (e.g.
# urllib.request.getproxies) accept both lower- and upper-case names, so both
# spellings are cleared; a missing variable is ignored.
for _proxy_var in ("http_proxy", "https_proxy", "all_proxy"):
    os.environ.pop(_proxy_var, None)
    os.environ.pop(_proxy_var.upper(), None)
del _proxy_var
|
|
|
|
def get_dummy_data(
    *,
    dataset_repo_id="data/qz_zz/0801_task9/pick",
    episode=10,
    frame_index=29,
):
    """Load one sample from a LeRobot dataset, collated into a batch of size 1.

    Uses the same ``LeRobotDataset`` loader as vlaholo, restricted to a single
    episode, and returns the sample at ``frame_index`` after it has passed
    through ``torch.utils.data.DataLoader`` collation (``batch_size=1``).

    Args:
        dataset_repo_id: Repo id / path passed to ``LeRobotDataset``.
        episode: Episode index to load.
        frame_index: Index of the sample within the loaded episode.

    Returns:
        ``(dataset, batch)``: the dataset object and the collated batch.

    Raises:
        IndexError: if ``frame_index`` is outside ``[0, len(dataset))``.
    """
    dataset = LeRobotDataset(dataset_repo_id, episodes=[episode], video_backend="pyav")

    num_frames = len(dataset)
    if not 0 <= frame_index < num_frames:
        raise IndexError(
            f"frame_index {frame_index} out of range for episode {episode} "
            f"({num_frames} frames)"
        )

    # Fetch only the requested index instead of iterating the whole loader and
    # discarding every earlier sample (each of which, with a video backend
    # configured, presumably decodes camera frames). With the default
    # sequential sampler and batch_size=1, this produces the same batch the
    # frame_index-th iteration of the full loader would.
    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, [frame_index]),
        num_workers=0,
        batch_size=1,
    )
    batch = next(iter(dataloader))
    return dataset, batch
|
|
|
|
def tensor_as_np_image(t):
    """Turn the first image of a batched tensor into a one-image numpy batch.

    Takes ``t[0]``, moves its leading axis to the end (presumably
    ``(C, H, W)`` -> ``(H, W, C)``; TODO confirm the input is channels-first),
    brings it to CPU, and re-adds a leading axis of size 1, giving a
    ``(1, H, W, C)``-shaped array.
    """
    first_image = t[0]
    channels_last = first_image.permute(1, 2, 0).cpu()
    return channels_last[None].numpy()
|
|
|
|
if __name__ == "__main__":
    # Hard-coded policy server endpoint; adjust for the deployment in use.
    vla_model = WebsocketClientPolicy(host="172.16.0.111", port=9001)

    ds, batch_data = get_dummy_data()
    print(batch_data.keys())

    benchmark_iters = 30
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which makes it the appropriate clock for latency measurements.
    t0 = time.perf_counter()
    for _ in range(benchmark_iters):
        t00 = time.perf_counter()

        # The observation is built inside the loop so each timed iteration
        # covers client-side preprocessing as well as the server round trip.
        # tensor_as_np_image takes element 0 of each batched tensor.
        image_cam_high = tensor_as_np_image(batch_data["observation.images.cam_high"])
        image_cam_left = tensor_as_np_image(batch_data["observation.images.cam_left_wrist"])
        image_cam_right = tensor_as_np_image(batch_data["observation.images.cam_right_wrist"])

        print(f"image_cam_high: {image_cam_high.shape}")

        obs = {
            "images": {
                "cam_high": image_cam_high,
                "cam_left_wrist": image_cam_left,
                "cam_right_wrist": image_cam_right,
            },
            "state": batch_data["observation.state"].cpu().numpy(),
            # NOTE(review): if "task" is a string per sample, DataLoader's
            # default collation yields a list of one string here; confirm the
            # server expects that form for "prompt".
            "prompt": batch_data["task"],
            "debug": True,
        }
        action = vla_model.infer(obs=obs)

        print(
            "##info, action:",
            action,
            time.perf_counter() - t00,
        )
    # Every iteration runs to completion, so dividing by benchmark_iters
    # gives the true mean latency per request.
    t1 = time.perf_counter()
    print(f"cost: {t1-t0:.3f}, avg: {(t1-t0)/benchmark_iters}")
|
|