gr00t1.5_starforce / test_starhelm.py
nnh-pbbb's picture
Add files using upload-large-folder tool
cd793b5 verified
"""
Read one sample from a LeRobot dataset and send it to the policy server
over a websocket as an inference request.
"""
from starhelm.websocket_client_policy import WebsocketClientPolicy
import time
import torch
from vlaholo.datasets.lerobot_dataset import LeRobotDataset
# from lerobot.datasets.lerobot_dataset import LeRobotDataset
# from starforce.data.dataset import LeRobotSingleDataset as LeRobotDataset
import sys
import os
# Remove proxy settings from the environment so clients that honor them
# connect to the policy server directly. Both spellings are removed because
# urllib.request.getproxies() on POSIX also reads the upper-case variables
# when the lower-case ones are absent.
for _proxy_var in (
    "http_proxy",
    "https_proxy",
    "all_proxy",
    "HTTP_PROXY",
    "HTTPS_PROXY",
    "ALL_PROXY",
):
    os.environ.pop(_proxy_var, None)
def get_dummy_data(
    dataset_repo_id="data/qz_zz/0801_task9/pick",
    episode=10,
    frame_index=29,
):
    """
    Load one batched sample from a LeRobot dataset to use as a test request.

    Args:
        dataset_repo_id: dataset repo id / path passed to ``LeRobotDataset``.
        episode: episode to load; passed as ``episodes=[episode]``.
        frame_index: position of the sample within the loaded dataset, i.e.
            the same index a sequential ``DataLoader`` would reach.

    Returns:
        ``(dataset, batch)`` where ``batch`` is ``dataset[frame_index]``
        collated exactly as ``DataLoader(batch_size=1)`` would: tensors gain
        a leading batch dimension of size 1 and strings become 1-element
        lists.

    Raises:
        IndexError: if ``frame_index`` is outside ``[0, len(dataset))``.
    """
    dataset = LeRobotDataset(dataset_repo_id, episodes=[episode], video_backend="pyav")
    n_frames = len(dataset)
    # A sequential DataLoader only visits indices range(len(dataset)); fail
    # with a clear message instead of a bare StopIteration for short episodes.
    if not 0 <= frame_index < n_frames:
        raise IndexError(
            f"frame_index {frame_index} out of range for episode {episode} "
            f"of {dataset_repo_id!r} ({n_frames} samples)"
        )
    # Fetch the one sample directly: iterating a DataLoader up to it would
    # also load every earlier sample, including its camera frames.
    # default_collate is what DataLoader(batch_size=1) applies to [dataset[i]].
    batch = torch.utils.data.default_collate([dataset[frame_index]])
    return dataset, batch
def tensor_as_np_image(t):
    """
    Convert the first image of a channel-first batch to a channels-last array.

    Takes ``t[0]`` of a rank-4 ``(B, C, H, W)`` tensor and returns it as a
    CPU numpy array of shape ``(1, H, W, C)``; the other batch items are
    dropped.
    """
    first_only = t[:1]  # (1, C, H, W): slicing keeps the batch axis
    channels_last = first_only.permute(0, 2, 3, 1)  # (1, H, W, C)
    return channels_last.cpu().numpy()
if __name__ == "__main__":
    # Send one dataset sample to the policy server and time the round trip.
    # vla_model = WebsocketClientPolicy(host="172.16.0.171", port=9001)
    vla_model = WebsocketClientPolicy(host="172.16.0.111", port=9001)
    ds, batch_data = get_dummy_data()
    print(batch_data.keys())

    # Number of inference requests to send and time. The average reported
    # below divides by this, so it must equal the iterations actually run;
    # raise it for a longer benchmark.
    benchmark_iters = 1
    # perf_counter is monotonic and high-resolution, the right clock for
    # measuring durations.
    t0 = time.perf_counter()
    for _ in range(benchmark_iters):
        t00 = time.perf_counter()
        # hwc numpy arrays, shape [1, H, W, 3]; 0-1 pixel values per the
        # original author (not checked here)
        image_cam_high = tensor_as_np_image(batch_data["observation.images.cam_high"])
        image_cam_left = tensor_as_np_image(batch_data["observation.images.cam_left_wrist"])
        image_cam_right = tensor_as_np_image(batch_data["observation.images.cam_right_wrist"])
        print(f"image_cam_high: {image_cam_high.shape}")
        # obs format expected by the server
        obs = {
            "images": {
                "cam_high": image_cam_high,
                "cam_left_wrist": image_cam_left,
                "cam_right_wrist": image_cam_right,
            },
            # state: [1, 14] per the original comment — TODO confirm
            "state": batch_data["observation.state"].cpu().numpy(),
            # language instruction.
            # NOTE(review): batch_data comes from default collation, which
            # turns strings into a 1-element list, so this is list[str], not
            # str — confirm the server accepts that.
            "prompt": batch_data["task"],
            # debug flag for the server; its effect is not visible here
            "debug": True,
        }
        action = vla_model.infer(obs=obs)
        print(
            "##info, action:",
            action,
            time.perf_counter() - t00,
        )
    t1 = time.perf_counter()
    print(f"cost: {t1-t0:.3f}, avg: {(t1-t0)/benchmark_iters}")