Add files using upload-large-folder tool
Browse files- robot/cam.py +144 -0
- starforce.egg-info/dependency_links.txt +1 -0
- starhelm/starhelm/__init__.py +3 -0
- starhelm/starhelm/image_tools_test.py +37 -0
- test_starhelm.py +88 -0
- tests/alter_lerobot_key.py +53 -0
- tests/async_client.py +602 -0
- tests/install_av_opencv.sh +5 -0
- tests/modality.json +54 -0
- tests/replay_sl.py +628 -0
- tests/save_s1.py +70 -0
- tests/save_s1_7B.py +69 -0
- tests/test_cv.py +22 -0
- tests/test_hf.py +19 -0
- tests/test_pi0.py +58 -0
- tests/test_starhelm.py +0 -0
- tests/test_tensor.py +21 -0
- tests/vis_lerobot_data.py +200 -0
- tests/vis_lerobot_data_v1.py +125 -0
- wandb/debug.log +30 -0
robot/cam.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
import json_numpy
|
| 3 |
+
import numpy as np
|
| 4 |
+
import requests
|
| 5 |
+
import queue
|
| 6 |
+
import cv2
|
| 7 |
+
import pickle, os
|
| 8 |
+
from matplotlib.pyplot import step
|
| 9 |
+
from typing import Any, cast
|
| 10 |
+
|
| 11 |
+
json_numpy.patch()
|
| 12 |
+
|
| 13 |
+
from starforce.experiment.data_config import DATA_CONFIG_MAP
|
| 14 |
+
from starforce.model.policy import Gr00tPolicy
|
| 15 |
+
|
| 16 |
+
# Import RTCController from separate module
|
| 17 |
+
from loguru import logger
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import pyrealsense2 as rs
|
| 21 |
+
except ImportError:
|
| 22 |
+
print("Warning: pyrealsense2 not available. Camera functionality will be limited.")
|
| 23 |
+
rs = None
|
| 24 |
+
|
| 25 |
+
# Help static analyzers: treat rs as dynamic Any when available
|
| 26 |
+
if rs is not None:
|
| 27 |
+
rs = cast(Any, rs)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CameraWrapper:
    """Manage a mixed set of RealSense and OpenCV cameras as one unit.

    The first ``num_realsense`` entries of ``devices`` are treated as
    RealSense serial numbers; the remaining entries are OpenCV device
    indices.  Cameras that fail to open are skipped with a log message,
    and ``get_images`` falls back to gray placeholder frames, so
    downstream code always receives a predictable list of images.
    """

    def __init__(
        self, devices=None, width=640, height=480, fps=30, num_realsense=0, cv_format="MJPEG"
    ):
        """
        Args:
            devices: camera identifiers — RealSense serials first, then
                OpenCV device indices. ``None`` means no cameras.
            width, height, fps: stream settings requested from every camera.
            num_realsense: number of leading ``devices`` entries that are
                RealSense serials.
            cv_format: FOURCC requested from OpenCV cameras ("MJPEG" or
                "YUYV"); any other value keeps the driver default.
        """
        self.width = width
        self.height = height
        self.fps = fps
        self.num_realsense = max(0, int(num_realsense))
        self.cv_format = cv_format
        # Each entry: {"type": "rs" | "cv", "handle": rs pipeline | cv2 capture}
        self.cameras = []
        self.device_ids = devices if devices is not None else []
        self._open_cameras()
        print(f"successfully opened {len(self.cameras)} cameras!")

    def _dummy_image(self):
        """Return a uniform gray (H, W, 3) uint8 placeholder frame."""
        return np.full((self.height, self.width, 3), 128, dtype=np.uint8)

    def _open_cameras(self):
        """Open every device in self.device_ids, appending successes to self.cameras."""
        if not self.device_ids:
            print("No devices provided for CameraWrapper")
            return

        for idx, dev in enumerate(self.device_ids):
            # Leading entries are RealSense serials; the rest are OpenCV indices.
            if idx < self.num_realsense:
                if rs is None:
                    print(
                        f"pyrealsense2 not available, skipping RealSense device at index {idx} (id: {dev})"
                    )
                    continue
                try:
                    serial = str(dev)
                    pipeline = rs.pipeline()  # type: ignore[attr-defined]
                    config = rs.config()  # type: ignore[attr-defined]
                    config.enable_device(serial)
                    config.enable_stream(rs.stream.color, self.width, self.height, rs.format.bgr8, self.fps)  # type: ignore[attr-defined]
                    pipeline.start(config)
                    self.cameras.append({"type": "rs", "handle": pipeline})
                    print(f"RealSense camera {serial} opened successfully")
                except Exception as e:
                    print(f"Failed to open RealSense camera {dev}: {e}")
            else:
                try:
                    device_index = int(dev)
                    print(f"Ready to read device: {device_index}")  # fixed typo "deive"
                    cap = cv2.VideoCapture(device_index)

                    # Fail early: configuring a capture that never opened is pointless.
                    if not cap.isOpened():
                        raise ValueError(f"Cannot open OpenCV camera {device_index}")

                    if self.cv_format == "MJPEG":
                        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))  # type: ignore[attr-defined]
                    elif self.cv_format == "YUYV":
                        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))  # type: ignore[attr-defined]

                    cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
                    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
                    cap.set(cv2.CAP_PROP_FPS, self.fps)

                    self.cameras.append({"type": "cv", "handle": cap})
                    print(f"OpenCV camera {device_index} opened successfully")
                except Exception as e:
                    print(f"Failed to open OpenCV camera {dev}: {e}")

    def get_images(self):
        """Grab one frame per opened camera.

        Returns:
            list of (H, W, 3) uint8 arrays. Failed grabs — and the
            no-camera case — are replaced by gray placeholder frames.
        """
        if not self.cameras:
            # No cameras: one placeholder per requested device (at least one).
            return [self._dummy_image() for _ in range(max(1, len(self.device_ids)))]

        images = []
        for cam in self.cameras:
            if cam["type"] == "rs":
                try:
                    frames = cam["handle"].wait_for_frames()
                    color_frame = frames.get_color_frame()
                    if color_frame:
                        images.append(np.asanyarray(color_frame.get_data()))
                    else:
                        images.append(self._dummy_image())
                except Exception as e:
                    print(f"Error reading from RealSense: {e}")
                    images.append(self._dummy_image())
            elif cam["type"] == "cv":
                ret, frame = cam["handle"].read()
                images.append(frame if ret and frame is not None else self._dummy_image())
        return images

    def release(self):
        """Stop/release every camera handle, swallowing per-device errors."""
        for cam in self.cameras:
            if cam["type"] == "rs":
                try:
                    cam["handle"].stop()
                except Exception:
                    pass
            elif cam["type"] == "cv":
                try:
                    cam["handle"].release()
                except Exception:
                    pass
        self.cameras = []
|
starforce.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
starhelm/starhelm/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""starhelm package root.

Re-exports WebsocketClientPolicy as the package's public entry point.
"""
from .websocket_client_policy import WebsocketClientPolicy

__version__ = "0.1.0"
|
starhelm/starhelm/image_tools_test.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import openpi_client.image_tools as image_tools
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_resize_with_pad_shapes():
    """resize_with_pad must return the requested (B, H, W, C) shape and
    keep all-zero inputs all-zero for upscale, downscale, identity, and
    odd-padding resizes."""
    cases = [
        # (batch, in_h, in_w, out_h, out_w)
        (2, 10, 10, 20, 20),    # larger target dimensions
        (3, 30, 30, 15, 15),    # smaller target dimensions
        (1, 50, 50, 50, 50),    # identical dimensions
        (1, 256, 320, 60, 80),  # odd-numbered padding
    ]
    for batch, in_h, in_w, out_h, out_w in cases:
        images = np.zeros((batch, in_h, in_w, 3), dtype=np.uint8)
        resized = image_tools.resize_with_pad(images, out_h, out_w)
        assert resized.shape == (batch, out_h, out_w, 3)
        assert np.all(resized == 0)
|
test_starhelm.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
read data from lerobot
|
| 3 |
+
|
| 4 |
+
requesting the server
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from starhelm.websocket_client_policy import WebsocketClientPolicy
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
from vlaholo.datasets.lerobot_dataset import LeRobotDataset
|
| 11 |
+
|
| 12 |
+
# from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
| 13 |
+
# from starforce.data.dataset import LeRobotSingleDataset as LeRobotDataset
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
os.environ.pop("http_proxy", None)
|
| 20 |
+
os.environ.pop("https_proxy", None)
|
| 21 |
+
os.environ.pop("all_proxy", None)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_dummy_data():
    """Load one LeRobot episode and return ``(dataset, batch)``.

    Uses the same data loader as vlaholo; the 30th batch (index 29) is
    picked as a representative frame.
    """
    repo_id = "data/qz_zz/0801_task9/pick"
    dataset = LeRobotDataset(repo_id, episodes=[10], video_backend="pyav")

    loader = torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=1)

    target_idx = 29
    for idx, sample in enumerate(loader):
        if idx == target_idx:
            return dataset, sample
    # Mirror next()-on-exhausted-generator behavior from the original.
    raise StopIteration
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def tensor_as_np_image(t):
    """Convert a batched (B, C, H, W) image tensor to a (1, H, W, C) numpy
    array built from the first batch element only."""
    first = t[0]                   # (C, H, W)
    hwc = first.permute(1, 2, 0)   # (H, W, C)
    return hwc.cpu().unsqueeze(0).numpy()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
if __name__ == "__main__":

    # Benchmark script: read one cached batch from a LeRobot episode and
    # send it repeatedly to the remote VLA policy server.
    # vla_model = WebsocketClientPolicy(host="172.16.0.171", port=9001)
    vla_model = WebsocketClientPolicy(host="172.16.0.111", port=9001)

    ds, batch_data = get_dummy_data()
    print(batch_data.keys())
    t0 = time.time()
    benchmark_iters = 30
    for _ in range(benchmark_iters):
        # print(batch)
        t00 = time.time()

        # hwc 0-1 numpy array
        image_cam_high = tensor_as_np_image(batch_data["observation.images.cam_high"])
        image_cam_left = tensor_as_np_image(batch_data["observation.images.cam_left_wrist"])
        image_cam_right = tensor_as_np_image(batch_data["observation.images.cam_right_wrist"])
        # [1, H, W, 3] 0-1 pixelvalue
        print(f"image_cam_high: {image_cam_high.shape}")
        # obs format expected by the server: three camera views, the robot
        # state vector, the task string, and a debug flag.
        obs = {
            "images": {
                "cam_high": image_cam_high,
                "cam_left_wrist": image_cam_left,
                "cam_right_wrist": image_cam_right,
            },
            # state: [1, 14]
            "state": batch_data["observation.state"].cpu().numpy(),
            # str language
            "prompt": batch_data["task"],
            # for verbose
            "debug": True,
        }
        action = vla_model.infer(obs=obs)

        print(
            "##info, action:",
            action,
            time.time() - t00,
        )
        # NOTE(review): this break exits after ONE iteration, so the "avg"
        # printed below divides a single inference by benchmark_iters —
        # remove the break (or change the divisor) for a real benchmark.
        break
    t1 = time.time()
    print(f"cost: {t1-t0:.3f}, avg: {(t1-t0)/benchmark_iters}")
|
tests/alter_lerobot_key.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
|
| 3 |
+
change a key of an existed lerobot dataset
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
from vlaholo.datasets.lerobot_dataset import LeRobotDataset
|
| 11 |
+
|
| 12 |
+
# Paths
|
| 13 |
+
# old_root = Path("data/qz/lerobot_data/airbot_datasets/airbot_data_0724/pick")
|
| 14 |
+
old_root = Path("data/sl/0721pre_data_v3")
|
| 15 |
+
old_root = Path("/pfs/data/yangcheng/.cache/huggingface/lerobot/pick")
|
| 16 |
+
new_root = old_root / 'new'
|
| 17 |
+
|
| 18 |
+
# 1. Load the existing dataset using LeRobotDataset
|
| 19 |
+
old_ds = LeRobotDataset(repo_id=old_root)
|
| 20 |
+
print(f'done')
|
| 21 |
+
|
| 22 |
+
# 2. Create a fresh LeRobotDataset at the new location with the same metadata
|
| 23 |
+
new_ds = LeRobotDataset.create(
|
| 24 |
+
repo_id="converted_dataset",
|
| 25 |
+
fps=old_ds.meta.info["fps"],
|
| 26 |
+
root=new_root,
|
| 27 |
+
features=old_ds.meta.info["features"],
|
| 28 |
+
use_videos=old_ds.meta.info.get("video", True),
|
| 29 |
+
)
|
| 30 |
+
# 3. Copy auxiliary folders (videos, tasks, stats, info)
|
| 31 |
+
for folder in ["videos", "meta"]:
|
| 32 |
+
src = old_root / folder
|
| 33 |
+
dst = new_root / folder
|
| 34 |
+
if src.exists():
|
| 35 |
+
shutil.copytree(src, dst, dirs_exist_ok=True)
|
| 36 |
+
|
| 37 |
+
# 4. Iterate through episodes and rewrite parquet with renamed key
|
| 38 |
+
for ep_idx in range(old_ds.meta.total_episodes):
|
| 39 |
+
# Load episode data via Hugging Face datasets
|
| 40 |
+
data_path = old_root / old_ds.meta.get_data_file_path(ep_idx)
|
| 41 |
+
ds = load_dataset(
|
| 42 |
+
"parquet",
|
| 43 |
+
data_files=[str(data_path)],
|
| 44 |
+
split="train",
|
| 45 |
+
)
|
| 46 |
+
# Rename column
|
| 47 |
+
ds = ds.rename_column("state", "observation.state")
|
| 48 |
+
|
| 49 |
+
# Write back via LeRobotDataset's internal method
|
| 50 |
+
out_path = new_root / new_ds.meta.get_data_file_path(ep_idx)
|
| 51 |
+
ds.to_parquet(str(out_path))
|
| 52 |
+
|
| 53 |
+
print(f"✅ New dataset with renamed key saved to {new_root}")
|
tests/async_client.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import threading
|
| 3 |
+
import json_numpy
|
| 4 |
+
import numpy as np
|
| 5 |
+
import requests
|
| 6 |
+
import queue
|
| 7 |
+
import cv2
|
| 8 |
+
import pickle, os
|
| 9 |
+
from matplotlib.pyplot import step
|
| 10 |
+
from typing import Any, cast
|
| 11 |
+
|
| 12 |
+
json_numpy.patch()
|
| 13 |
+
|
| 14 |
+
from airbot_py.arm import AIRBOTPlay, RobotMode, SpeedProfile
|
| 15 |
+
|
| 16 |
+
# Import RTCController from separate module
|
| 17 |
+
from rtc_controller import RTCController
|
| 18 |
+
from ruckig_planner import RuckigPlanner
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
import pyrealsense2 as rs
|
| 23 |
+
except ImportError:
|
| 24 |
+
print("Warning: pyrealsense2 not available. Camera functionality will be limited.")
|
| 25 |
+
rs = None
|
| 26 |
+
|
| 27 |
+
# Help static analyzers: treat rs as dynamic Any when available
|
| 28 |
+
if rs is not None:
|
| 29 |
+
rs = cast(Any, rs)
|
| 30 |
+
|
| 31 |
+
class CameraWrapper:
    """Manage a mixed set of RealSense and OpenCV cameras as one unit.

    The first ``num_realsense`` entries of ``devices`` are RealSense
    serial numbers; the rest are OpenCV device indices.  Cameras that
    fail to open are skipped with a log message, and ``get_images``
    substitutes gray placeholder frames on failure.
    """

    def __init__(self, devices=None, width=640, height=480, fps=30, num_realsense=0, cv_format="MJPEG"):
        """
        Args:
            devices: camera identifiers — RealSense serials first, then
                OpenCV device indices. ``None`` means no cameras.
            width, height, fps: stream settings requested from every camera.
            num_realsense: number of leading ``devices`` entries that are
                RealSense serials.
            cv_format: FOURCC requested from OpenCV cameras ("MJPEG" or
                "YUYV"); any other value keeps the driver default.
        """
        self.width = width
        self.height = height
        self.fps = fps
        self.num_realsense = max(0, int(num_realsense))
        self.cv_format = cv_format
        # Each entry: {"type": "rs" | "cv", "handle": rs pipeline | cv2 capture}
        self.cameras = []
        self.device_ids = devices if devices is not None else []
        self._open_cameras()
        print(f'successfully opened {len(self.cameras)} cameras!')

    def _dummy_image(self):
        """Return a uniform gray (H, W, 3) uint8 placeholder frame."""
        return np.full((self.height, self.width, 3), 128, dtype=np.uint8)

    def _open_cameras(self):
        """Open every device in self.device_ids, appending successes to self.cameras."""
        if not self.device_ids:
            print("No devices provided for CameraWrapper")
            return

        for idx, dev in enumerate(self.device_ids):
            # Leading entries are RealSense serials; the rest are OpenCV indices.
            if idx < self.num_realsense:
                if rs is None:
                    print(f"pyrealsense2 not available, skipping RealSense device at index {idx} (id: {dev})")
                    continue
                try:
                    serial = str(dev)
                    pipeline = rs.pipeline()  # type: ignore[attr-defined]
                    config = rs.config()  # type: ignore[attr-defined]
                    config.enable_device(serial)
                    config.enable_stream(rs.stream.color, self.width, self.height, rs.format.bgr8, self.fps)  # type: ignore[attr-defined]
                    pipeline.start(config)
                    self.cameras.append({"type": "rs", "handle": pipeline})
                    print(f"RealSense camera {serial} opened successfully")
                except Exception as e:
                    print(f"Failed to open RealSense camera {dev}: {e}")
            else:
                try:
                    device_index = int(dev)
                    cap = cv2.VideoCapture(device_index)

                    # BUGFIX: check isOpened() BEFORE configuring — the original
                    # set FOURCC/size/fps on captures that had failed to open.
                    if not cap.isOpened():
                        print(f"Error: Cannot open OpenCV camera {device_index}")
                        continue

                    if self.cv_format == "MJPEG":
                        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))  # type: ignore[attr-defined]
                    elif self.cv_format == "YUYV":
                        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))  # type: ignore[attr-defined]

                    cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
                    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
                    cap.set(cv2.CAP_PROP_FPS, self.fps)

                    self.cameras.append({"type": "cv", "handle": cap})
                    print(f"OpenCV camera {device_index} opened successfully")
                except Exception as e:
                    print(f"Failed to open OpenCV camera {dev}: {e}")

    def get_images(self):
        """Grab one frame per opened camera.

        Returns:
            list of (H, W, 3) uint8 arrays. Failed grabs — and the
            no-camera case — are replaced by gray placeholder frames.
        """
        if not self.cameras:
            # No cameras: one placeholder per requested device (at least one).
            return [self._dummy_image() for _ in range(max(1, len(self.device_ids)))]

        images = []
        for cam in self.cameras:
            if cam["type"] == "rs":
                try:
                    frames = cam["handle"].wait_for_frames()
                    color_frame = frames.get_color_frame()
                    if color_frame:
                        images.append(np.asanyarray(color_frame.get_data()))
                    else:
                        images.append(self._dummy_image())
                except Exception as e:
                    print(f"Error reading from RealSense: {e}")
                    images.append(self._dummy_image())
            elif cam["type"] == "cv":
                ret, frame = cam["handle"].read()
                images.append(frame if ret and frame is not None else self._dummy_image())
        return images

    def release(self):
        """Stop/release every camera handle, swallowing per-device errors."""
        for cam in self.cameras:
            if cam["type"] == "rs":
                try:
                    cam["handle"].stop()
                except Exception:
                    pass
            elif cam["type"] == "cv":
                try:
                    cam["handle"].release()
                except Exception:
                    pass
        self.cameras = []
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def normalization(state):
    """Normalize a 7-dof robot state for model input.

    Joints 0-5 pass through unchanged; the gripper (index 6) is
    binarized: widths above 0.04 become 1.0, otherwise -1.0.

    Returns:
        np.ndarray of shape (7,), dtype float32.
    """
    vec = np.asarray(state, dtype=np.float32).copy()
    gripper_open = vec[6] > 0.04
    vec[6] = 1.0 if gripper_open else -1.0
    return vec
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def unnormalization(action):
    """Map a model action back to robot command space.

    Joints 0-5 pass through unchanged; the gripper (index 6) is mapped to
    a physical width: values > 0 -> 0.06850814 (open), else -> 0.025
    (close).  (Docstring previously claimed a 0.5 threshold, but the code
    thresholds at 0 — consistent with ``normalization`` emitting +/-1.)

    Returns:
        np.ndarray of shape (7,), dtype float32.
    """
    arr = np.array(action, dtype=np.float32).copy()
    arr[6] = 0.06850814 if arr[6] > 0 else 0.025
    return arr
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class RobotWrapper:
    # Ports dict example: {'left_arm': 50051, 'right_arm': None}
    # Only arms with a non-None port will be initialized
    def __init__(self, url='localhost', ports=None, arm_speed='slow', type='move'):
        """Connect to each configured AIRBOT arm, print its product info,
        and move it to a hard-coded neutral pose with the gripper open.

        Args:
            url: host of the AIRBOT service.
            ports: mapping arm name -> port; a None port disables that arm.
            arm_speed: 'slow' | 'default' | 'fast'.
            type: 'move' (blocking planned motion) or 'servo' (streaming
                joint commands).
        """
        if ports is None:
            ports = {'left_arm': None, 'right_arm': None}
        assert any(p is not None for p in ports.values()), "at least one arm port is required"
        # NOTE(review): the message below is missing an opening quote before
        # 'default' — cosmetic only, cannot be fixed in a doc-only pass.
        assert arm_speed in ['slow', 'default', 'fast'], "arm_speed must be in ['slow', default', 'fast']"
        assert type in ['move', 'servo'], "type must be in ['move', 'servo']"
        self.type = type

        # arm name -> connected AIRBOTPlay handle
        self.robots = {}

        for arm_name in ['left_arm', 'right_arm']:
            port = ports.get(arm_name)
            if port is None:
                continue

            robot = AIRBOTPlay(url=url, port=port)
            robot.connect()
            # 'default' and 'fast' both map to FAST here — TODO confirm intended.
            robot.set_speed_profile(SpeedProfile.SLOW if arm_speed == 'slow' else SpeedProfile.FAST)

            product_info = robot.get_product_info()
            print("---------------------------------------------------------")
            print(f"Arm: {arm_name}")
            print(f"Product name: {product_info['product_type']}")
            print(f"Serial number: {product_info['sn']}")
            print(f"Simulation mode: {product_info['is_sim']}")
            print(f"Using interfaces: {product_info['interfaces']}")
            print(f"Installed end effectors: {product_info['eef_types']}")
            print(f"Firmware versions: {product_info['fw_versions']}")
            print("---------------------------------------------------------")

            # Default initial joints per arm (optional; can be customized)
            if arm_name == 'left_arm':
                joints = [0.0, 0.0, 0.15, -1.7, 0.1, 1.7]
            else:
                joints = [0.0, 0.0, 0.15, 1.7, -0.1, -1.7]

            # using move mode to move to initial pose
            robot.switch_mode(RobotMode.PLANNING_POS)
            # Overwrite with a safe neutral pose
            # NOTE(review): this overwrites the per-arm joints chosen above for
            # BOTH arms — confirm the left arm should use this pose as well.
            joints = [0.00019073777366429567, 0.17948424816131592, 0.027656977996230125, 1.4654383659362793, -0.3435187339782715, -1.4288166761398315]
            # joints = [0.03604944050312042, 0.17948424816131592, 0.029564354568719864, 1.6039139032363892, -0.3419928252696991, -1.5939955711364746]
            robot.move_to_joint_pos(joints)
            robot.move_eef_pos([0.06850814])  # gripper open width (see unnormalization)
            print(f'arm: {arm_name}, joints: {joints}')

            # Switch to the runtime control mode requested by the caller.
            if self.type == 'move':
                robot.switch_mode(RobotMode.PLANNING_POS)
            elif self.type == 'servo':
                robot.switch_mode(RobotMode.SERVO_JOINT_POS)

            init_joint_pos = robot.get_joint_pos()
            init_eef_pos = robot.get_eef_pos()
            print(f'[{arm_name}] init_joint_pos: {init_joint_pos}, init_eef_pos: {init_eef_pos}')
            self.robots[arm_name] = robot
            print(f"robot arm {arm_name} (port: {port}) init success!")
            # Give the arm time to settle before the next arm / first command.
            time.sleep(2)

    def move_to_pos(self, pos, arm='right_arm'):
        """Send a 7-dof target (6 joints + gripper width) to one arm,
        blocking in 'move' mode, streaming in 'servo' mode."""
        assert arm in self.robots, f"arm '{arm}' not initialized"
        if self.type == 'move':
            self.robots[arm].move_to_joint_pos(pos[:6], blocking=True)
            self.robots[arm].move_eef_pos([pos[6]], blocking=True)
        elif self.type == 'servo':
            self.robots[arm].servo_joint_pos(pos[:6])
            self.robots[arm].servo_eef_pos([pos[6]])

    def get_joint_pos(self, arm='right_arm'):
        """Return the joint positions of one arm."""
        assert arm in self.robots, f"arm '{arm}' not initialized"
        return self.robots[arm].get_joint_pos()

    def get_eef_pos(self, arm='right_arm'):
        """Return the end-effector (gripper) position of one arm."""
        assert arm in self.robots, f"arm '{arm}' not initialized"
        return self.robots[arm].get_eef_pos()

    def get_state_pos(self, arm='right_arm'):
        """Return joints followed by gripper position, concatenated."""
        assert arm in self.robots, f"arm '{arm}' not initialized"
        pos = self.robots[arm].get_joint_pos()
        eef_pos = self.robots[arm].get_eef_pos()
        result = pos + eef_pos
        return result
| 251 |
+
|
| 252 |
+
class ActionSmoother:
    """Smooth a stream of action vectors to reduce jitter on the robot.

    Methods:
      - 'exponential': EMA with a scalar or per-dimension alpha.
      - 'moving_average' / 'average': mean over the last ``window_size``
        raw actions (the two names were duplicate branches; now merged).
      - 'linear_interpolation': fixed 0.7*raw + 0.3*previous blend.
      - 'identity': pass-through (history is still recorded).

    Only dimensions listed in ``smooth_dims`` are smoothed; ``None``
    smooths every dimension.
    """

    def __init__(self, method='exponential', alpha=0.3, window_size=5, smooth_dims=None):
        self.method = method
        self.window_size = window_size
        self.smooth_dims = smooth_dims

        # alpha may be a scalar or a per-dimension sequence.
        if isinstance(alpha, (list, tuple, np.ndarray)):
            self.alpha = np.array(alpha, dtype=np.float32)
        else:
            self.alpha = alpha

        self.history = []            # recent raw actions, bounded by window_size
        self.smoothed_action = None  # last smoothed output (None until first call)

    def smooth_action(self, raw_action):
        """Record raw_action and return its smoothed version (float32 array).

        The first call returns the raw action unchanged (nothing to blend
        against yet).
        """
        raw_action = np.array(raw_action, dtype=np.float32)

        self.history.append(raw_action.copy())
        if len(self.history) > self.window_size:
            self.history = self.history[-self.window_size:]

        if self.smoothed_action is None:
            self.smoothed_action = raw_action.copy()
            return self.smoothed_action

        result_action = raw_action.copy()

        if self.smooth_dims is None:
            dims_to_smooth = list(range(len(raw_action)))
        else:
            # Ignore configured dims that fall outside this action's length.
            dims_to_smooth = [d for d in self.smooth_dims if d < len(raw_action)]

        if self.method == 'exponential':
            if isinstance(self.alpha, np.ndarray):
                for dim in dims_to_smooth:
                    # Per-dim alpha; reuse the last entry if alpha is shorter.
                    alpha = self.alpha[dim] if dim < len(self.alpha) else self.alpha[-1]
                    result_action[dim] = alpha * raw_action[dim] + (1 - alpha) * self.smoothed_action[dim]
            else:
                for dim in dims_to_smooth:
                    result_action[dim] = self.alpha * raw_action[dim] + (1 - self.alpha) * self.smoothed_action[dim]

        elif self.method in ('moving_average', 'average'):
            # 'average' was a byte-for-byte duplicate of 'moving_average';
            # the branches are merged to remove the duplication.
            history_array = np.array(self.history)
            for dim in dims_to_smooth:
                result_action[dim] = np.mean(history_array[:, dim])

        elif self.method == 'linear_interpolation':
            for dim in dims_to_smooth:
                result_action[dim] = 0.7 * raw_action[dim] + 0.3 * self.smoothed_action[dim]

        elif self.method == 'identity':
            # Intentionally returns before updating smoothed_action,
            # matching the original behavior.
            return result_action

        else:
            raise ValueError(f"Unknown smoothing method: {self.method}")

        self.smoothed_action = result_action.copy()

        return result_action

    def reset(self):
        """Clear history and smoothed state so the next action passes through."""
        self.history = []
        self.smoothed_action = None
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class VLAClient:
    """Thin HTTP client for the remote VLA action server."""

    def __init__(self, server_url):
        # Full endpoint URL, e.g. "http://127.0.0.1:10090/act".
        self.server_url = server_url

    @staticmethod
    def _parse_action_chunk(action_chunk):
        """Convert a server reply into an (N, 7) float32 array.

        Each row is the 6 right-arm dims followed by the scalar gripper value.
        """
        actions = [
            np.asarray(list(arm) + [float(gripper)], dtype=np.float32)
            for arm, gripper in zip(action_chunk['action.right_arm'],
                                    action_chunk['action.right_gripper'])
        ]
        return np.array(actions)

    def predict(self, obs):
        """POST `obs` to the server and return the parsed action chunk.

        Returns:
            np.ndarray of shape (N, 7), or None on any error (callers treat
            None as "no action this cycle").
        """
        try:
            response = requests.post(
                self.server_url,
                json={"observation": obs},
                timeout=30,  # avoid hanging the control loop forever on a dead server
            )
            # Surface HTTP errors explicitly instead of a confusing KeyError below.
            response.raise_for_status()
            return self._parse_action_chunk(response.json())
        except Exception as e:
            print(f"VLA prediction error: {e}")
            return None
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def predict_actions(server_url, obs):
    """POST the observation to the VLA server and return the decoded JSON reply."""
    payload = {"observation": obs}
    reply = requests.post(server_url, json=payload)
    return reply.json()
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == "__main__":
    # Entry point: drive the right arm from a remote VLA action server, either
    # asynchronously (RTCController) or synchronously (blocking chunk requests).

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_rtc", action='store_true', default=False)
    parser.add_argument("--fast", action='store_true', default=False)
    args = parser.parse_args()

    USE_RTC_CONTROLLER = args.use_rtc
    print(f"Use RTC Controller: {USE_RTC_CONTROLLER}")

    # Last assignment wins — earlier prompts are kept for quick manual switching.
    task_description = "Pick up each block one by one and place them all into the bowl."
    task_description = 'Pick up each block one by one and place them all into the blue bowl.'
    task_description = 'Pick up each block one by one and place them all into the right bowl.'

    print("Robot System Init...")
    robots = RobotWrapper(
        url='localhost',
        ports={'left_arm': None, 'right_arm': 50053},  # left arm unused in this script
        arm_speed='default' if args.fast else 'slow',
        type='servo'
    )

    if args.fast:
        # Fast mode: stream targets through a Ruckig online trajectory planner.
        target_queue = queue.Queue(maxsize=30)
        servo_queue = queue.Queue(maxsize=30)
        rp = RuckigPlanner(robots.robots['right_arm'], robots.get_state_pos('right_arm'), DoFs=7, dt=0.02)
        rp.start(servo_queue, target_queue)

    # server_url = "http://106.13.248.32:10090/act"
    server_url = "http://127.0.0.1:10090/act"
    # server_url = "http://114.111.24.161:10090/act"
    print(f"VLA Server URL: {server_url}")

    # Camera System Init
    print("Camera System Init...")
    caps = CameraWrapper(
        devices=["215322074711", "242622070332", 0],  # two RealSense serials + one UVC index
        num_realsense=2,
        width=640,
        height=480,
        fps=30,
        cv_format="MJPEG"
    )

    time.sleep(2)  # give the cameras time to warm up

    print(f'Smoother System Init for Grippers...')
    # Moving-average smoothing applied to dim 6 (the gripper) only.
    action_smoother = ActionSmoother(
        method='average',
        alpha=0.1,
        window_size=10,
        smooth_dims=[6]
    )
    # action_smoother = None

    print("Loop Start!")
    step_count = 0

    if USE_RTC_CONTROLLER:
        print("Using ASync (RTCController) Mode...")

        vla_client = VLAClient(server_url)

        def get_observation():
            """Grab the latest camera frames and arm state, packed for the VLA server."""
            # Read three times — presumably to flush stale buffered frames from the
            # camera wrapper; TODO confirm this is still needed.
            images = caps.get_images()
            images = caps.get_images()
            images = caps.get_images()

            if len(images) >= 3:
                img_right, img_front, img_env = images[:3]
            else:
                # Pad missing cameras with mid-gray frames so shapes stay fixed.
                filler = np.zeros((480, 640, 3), dtype=np.uint8)
                filler[:, :, :] = 128
                imgs = images + [filler] * (3 - len(images))
                img_right, img_front, img_env = imgs[:3]

            arm_state = normalization(robots.get_state_pos(arm='right_arm'))

            obs = {
                "video.cam_head": img_front[np.newaxis, ::],
                "video.cam_env": img_env[np.newaxis, ::],
                "video.cam_right_wrist": img_right[np.newaxis, ::],
                "state.right_arm": np.expand_dims(np.array(arm_state[:6], dtype=np.float32), axis=0),
                "state.right_gripper": np.expand_dims(np.array([arm_state[6]], dtype=np.float32), axis=0),
                "annotation.human.task_description": [task_description]
            }

            return obs

        # RTCController Settings
        H = 16  # action-chunk horizon
        d = 8   # NOTE(review): presumably the inference-delay budget in steps — confirm
        s = 5   # NOTE(review): presumably the chunk-switch step — confirm

        print("RTCController System Init...")
        rtc_controller = RTCController(
            vla_client=vla_client,
            observation_fn=get_observation,
            H=H,
            d=d,
            s=s,
        )

        print("predict action chunk...")
        rtc_controller.reset()

        if not rtc_controller.is_ready():
            print(f'RTCController Configuration Failed!')
            exit(1)

        print(f"RTCController Configuration Completed")
        print(f"RTC Params: H={H}, d={d}, s={s}")

        while True:
            try:
                loop_start_time = time.time()
                step_count += 1

                action = rtc_controller.step()

                if action is None:
                    print("RTCController is not ready...")
                    time.sleep(0.05)
                    continue

                denormalized_action = unnormalization(action)
                denormalized_action = action_smoother.smooth_action(denormalized_action) if action_smoother else denormalized_action

                if args.fast:
                    # Hand the target to the Ruckig planner thread.
                    target_queue.put(denormalized_action.tolist())
                else:
                    robots.move_to_pos(denormalized_action.tolist(), arm='right_arm')

                # if step_count % 10 == 0:
                #     print(f"[RTC] Step {step_count}: executor_index={rtc_controller.executor_index}")
                #     print(f"Action: {action}")
                #     current_state = robots.get_state_pos(arm='right_arm')
                #     print(f"Robot state: {current_state}")

                # Fixed-rate pacing: 10Hz in fast mode, 20Hz otherwise.
                if args.fast:
                    target_period = 0.1
                else:
                    target_period = 0.05
                elapsed = time.time() - loop_start_time
                sleep_time = max(0.0, target_period - elapsed)

                if sleep_time > 0:
                    time.sleep(sleep_time)
                else:
                    print(f"Loop control time cost: {elapsed:.3f}s > {target_period:.3f}s")

            except KeyboardInterrupt:
                print("Received interrupt signal, exiting safely...")
                break
            except Exception as e:
                print(f"ASYNC control mode error: {e}")
                time.sleep(0.05)

    else:
        print("Using Sync Mode...")

        current_actions = []
        current_action_idx = 0

        while True:
            try:
                loop_start_time = time.time()
                step_count += 1

                # Fetch a new chunk once every queued action has been executed.
                if current_action_idx >= len(current_actions):
                    print(f"[Sync Mode] Obtain new action chunk...")

                    # image — triple read flushes stale buffered frames (see get_observation)
                    images = caps.get_images()
                    images = caps.get_images()
                    images = caps.get_images()

                    if len(images) >= 3:
                        img_right, img_front, img_env = images[:3]
                    else:
                        filler = np.zeros((480, 640, 3), dtype=np.uint8)
                        filler[:, :, :] = 128
                        imgs = images + [filler] * (3 - len(images))
                        img_right, img_front, img_env = imgs[:3]

                    # state
                    arm_state = normalization(robots.get_state_pos(arm='right_arm'))

                    # obs
                    obs = {
                        "video.cam_head": img_front[np.newaxis, ::],
                        "video.cam_env": img_env[np.newaxis, ::],
                        "video.cam_right_wrist": img_right[np.newaxis, ::],
                        "state.right_arm": np.expand_dims(np.array(arm_state[:6], dtype=np.float32), axis=0),
                        "state.right_gripper": np.expand_dims(np.array([arm_state[6]], dtype=np.float32), axis=0),
                        "annotation.human.task_description": [task_description]
                    }

                    start_time = time.time()
                    action_chunk = predict_actions(server_url, obs)
                    print(f"inference time cost: {time.time() - start_time:.3f}s")

                    current_actions = []
                    current_action_idx = 0
                    for arm, gripper in zip(action_chunk['action.right_arm'], action_chunk['action.right_gripper']):
                        action = np.asarray(list(arm) + [float(gripper)], dtype=np.float32)
                        current_actions.append(action)

                    print(f"Obtained {len(current_actions)} action steps...")

                if current_action_idx < len(current_actions):
                    raw_action = current_actions[current_action_idx]
                    current_action_idx += 1

                    smoothed_action = action_smoother.smooth_action(raw_action)

                    denormalized_action = unnormalization(smoothed_action)
                    robots.move_to_pos(denormalized_action.tolist(), arm='right_arm')

                    if step_count % 10 == 0:
                        print(f"[Sync Mode] Step {step_count}: action_idx={current_action_idx}/{len(current_actions)}")
                        print(f"  Raw: {raw_action}")
                        print(f"  Smoothed: {smoothed_action}")

                # 20Hz control
                time.sleep(0.05)

            except KeyboardInterrupt:
                print("Received interrupt signal, exiting safely...")
                break
            except Exception as e:
                print(f"SYNC control mode error: {e}")
                time.sleep(0.05)

    print("Clearing resources...")
    try:
        caps.release()
        print("Camera resources released")
    except Exception as e:
        print(f"Camera clearing error: {e}")

    print("Program ended")
|
tests/install_av_opencv.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Rebuild OpenCV's Python bindings from source with FFmpeg enabled so cv2 can
# decode AV streams (the prebuilt wheel may lack the needed codecs).

# Make sure ffmpeg and its codec development headers are installed first.
sudo apt install libavcodec-dev libavformat-dev libavutil-dev libswscale-dev ffmpeg
export CMAKE_ARGS="-D WITH_FFMPEG=ON"
pip install --no-binary opencv-python --no-deps opencv-python -v
|
tests/modality.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"state": {
|
| 3 |
+
"left_arm": {
|
| 4 |
+
"start": 0,
|
| 5 |
+
"end": 6
|
| 6 |
+
},
|
| 7 |
+
"left_gripper": {
|
| 8 |
+
"start": 6,
|
| 9 |
+
"end": 7
|
| 10 |
+
},
|
| 11 |
+
"right_arm": {
|
| 12 |
+
"start": 7,
|
| 13 |
+
"end": 13
|
| 14 |
+
},
|
| 15 |
+
"right_gripper": {
|
| 16 |
+
"start": 13,
|
| 17 |
+
"end": 14
|
| 18 |
+
}
|
| 19 |
+
},
|
| 20 |
+
"action": {
|
| 21 |
+
"left_arm": {
|
| 22 |
+
"start": 0,
|
| 23 |
+
"end": 6
|
| 24 |
+
},
|
| 25 |
+
"left_gripper": {
|
| 26 |
+
"start": 6,
|
| 27 |
+
"end": 7
|
| 28 |
+
},
|
| 29 |
+
"right_arm": {
|
| 30 |
+
"start": 7,
|
| 31 |
+
"end": 13
|
| 32 |
+
},
|
| 33 |
+
"right_gripper": {
|
| 34 |
+
"start": 13,
|
| 35 |
+
"end": 14
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"video": {
|
| 39 |
+
"cam_high": {
|
| 40 |
+
"original_key": "observation.images.cam_high"
|
| 41 |
+
},
|
| 42 |
+
"cam_left_wrist": {
|
| 43 |
+
"original_key": "observation.images.cam_left_wrist"
|
| 44 |
+
},
|
| 45 |
+
"cam_right_wrist": {
|
| 46 |
+
"original_key": "observation.images.cam_right_wrist"
|
| 47 |
+
}
|
| 48 |
+
},
|
| 49 |
+
"annotation": {
|
| 50 |
+
"human.action.task_description": {
|
| 51 |
+
"original_key": "task_index"
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
}
|
tests/replay_sl.py
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
load lerobot dataset and replay it on REAL ARM
|
| 3 |
+
'''
|
| 4 |
+
#!/usr/bin/env python3
|
| 5 |
+
# -*- coding: utf-8 -*-
|
| 6 |
+
|
| 7 |
+
'''
|
| 8 |
+
test for control Agilex arm.
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
import argparse
|
| 15 |
+
import threading
|
| 16 |
+
import json
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import numpy as np
|
| 19 |
+
import json
|
| 20 |
+
import time
|
| 21 |
+
import logging
|
| 22 |
+
import argparse
|
| 23 |
+
import threading
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Dict, Any, List, Optional
|
| 26 |
+
|
| 27 |
+
from piper_sdk import C_PiperInterface_V2
|
| 28 |
+
|
| 29 |
+
# Logging setup: mirror all messages to arm_replay.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('arm_replay.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
class ArmDataReplayer:
|
| 41 |
+
"""机械臂数据重播器"""
|
| 42 |
+
|
| 43 |
+
    def __init__(self, arms_config: Dict[str, str]):
        """
        Initialize the replayer.

        Args:
            arms_config: arm name -> CAN interface name,
                e.g. {'left': 'can_left', 'right': 'can_right'}
        """
        self.arms_config = arms_config

        # Arm interfaces (name -> C_PiperInterface_V2), populated by connect_arms().
        self.arms = {}
        self.is_connected = False

        # Replay data: list of per-frame dicts plus dataset-level metadata.
        self.replay_data = None
        self.replay_metadata = None

        # Replay control state (worker thread, speed multiplier, frame cursor).
        self.is_replaying = False
        self.replay_thread = None
        self.replay_speed = 1.0
        self.current_frame = 0
        self.total_frames = 0

        # Replay statistics.
        self.frames_replayed = 0
        self.start_time = None

        logger.info("数据重播器初始化完成")
|
| 72 |
+
|
| 73 |
+
    def connect_arms(self) -> bool:
        """Connect to and enable every configured arm.

        Returns:
            True if all arms were connected; False (after disconnecting any
            partially-connected arms) on the first failure.
        """
        try:
            for arm_name, can_name in self.arms_config.items():
                logger.info(f"正在连接{arm_name}臂 ({can_name})...")

                arm = C_PiperInterface_V2(
                    can_name=can_name,
                    judge_flag=False,
                    can_auto_init=True,
                    dh_is_offset=1,
                    start_sdk_joint_limit=False,   # SDK joint limits disabled for replay
                    start_sdk_gripper_limit=False
                )

                arm.ConnectPort()
                self.arms[arm_name] = arm

                # Wait for the connection to stabilize before enabling.
                time.sleep(1.0)

                arm.EnableArm()
                time.sleep(0.5)

                # Verify the connection by checking the message rate.
                test_msg = arm.GetArmJointMsgs()
                if test_msg and test_msg.Hz > 0:
                    logger.info(f"✅ {arm_name}臂连接成功")
                else:
                    logger.warning(f"⚠️ {arm_name}臂连接可能不稳定")

            self.is_connected = True
            logger.info("所有机械臂连接完成")
            return True

        except Exception as e:
            logger.error(f"连接机械臂失败: {e}")
            self.disconnect_arms()
            return False
|
| 112 |
+
|
| 113 |
+
def disconnect_arms(self):
|
| 114 |
+
"""断开机械臂连接"""
|
| 115 |
+
try:
|
| 116 |
+
for arm_name, arm in self.arms.items():
|
| 117 |
+
if arm:
|
| 118 |
+
arm.DisconnectPort()
|
| 119 |
+
logger.info(f"{arm_name}臂已断开连接")
|
| 120 |
+
|
| 121 |
+
self.arms.clear()
|
| 122 |
+
self.is_connected = False
|
| 123 |
+
logger.info("所有机械臂已断开连接")
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
logger.error(f"断开连接失败: {e}")
|
| 127 |
+
|
| 128 |
+
    def load_hdf5_data(self, filename: str) -> bool:
        """Load replay frames from an HDF5 recording into self.replay_data.

        Args:
            filename: path to the .h5/.hdf5 recording.

        Returns:
            True on success, False on any error.
        """
        # NOTE(review): `h5py` is referenced here but never imported in this
        # module — add `import h5py` at the top of the file or this call raises
        # NameError at runtime.
        try:
            with h5py.File(filename, 'r') as f:
                # Dataset-level metadata stored as attributes on /metadata.
                self.replay_metadata = {
                    'collection_frequency': f['metadata'].attrs.get('collection_frequency', 30),
                    'total_samples': f['metadata'].attrs.get('total_samples', 0),
                    'arms': f['metadata'].attrs.get('arms', []),
                    'start_time': f['metadata'].attrs.get('start_time', 0),
                    'creation_time': f['metadata'].attrs.get('creation_time', '')
                }

                # Timing arrays (one entry per recorded frame).
                timestamps = f['timestamps'][:]
                relative_times = f['relative_times'][:]

                # Per-arm datasets.
                arms_data = {}
                for arm_name in self.replay_metadata['arms']:
                    if f'arms/{arm_name}' in f:
                        arm_group = f[f'arms/{arm_name}']
                        arms_data[arm_name] = {
                            'joint_positions': arm_group['joint_positions'][:],
                            'gripper_angles': arm_group['gripper_angles'][:],
                            'gripper_efforts': arm_group['gripper_efforts'][:]
                        }

                        # Compatibility: older recordings may lack gripper_positions.
                        if 'gripper_positions' in arm_group:
                            arms_data[arm_name]['gripper_positions'] = arm_group['gripper_positions'][:]
                        else:
                            # Fall back to zeros when position data is missing.
                            arms_data[arm_name]['gripper_positions'] = np.zeros_like(arms_data[arm_name]['gripper_angles'])

                # Rebuild a frame-oriented layout: one dict per time step.
                self.replay_data = []
                for i in range(len(timestamps)):
                    frame = {
                        'timestamp': timestamps[i],
                        'relative_time': relative_times[i],
                        'frame_index': i,
                        'arms': {}
                    }

                    for arm_name, data in arms_data.items():
                        if i < len(data['joint_positions']):
                            frame['arms'][arm_name] = {
                                'joint_positions': data['joint_positions'][i],
                                'gripper_angle': data['gripper_angles'][i],
                                'gripper_position': data['gripper_positions'][i],
                                'gripper_effort': data['gripper_efforts'][i]
                            }

                    self.replay_data.append(frame)

            self.total_frames = len(self.replay_data)
            logger.info(f"HDF5数据加载完成: {self.total_frames}帧")
            logger.info(f"包含机械臂: {self.replay_metadata['arms']}")
            logger.info(f"原始采集频率: {self.replay_metadata['collection_frequency']}Hz")

            return True

        except Exception as e:
            logger.error(f"加载HDF5数据失败: {e}")
            return False
|
| 194 |
+
|
| 195 |
+
    def load_lerobot_data(self, data_dir: str, episode_idx: int = 0) -> bool:
        """
        Load one episode of a LeRobot-format dataset directory into self.replay_data.

        Args:
            data_dir: root directory of the LeRobot dataset.
            episode_idx: index of the episode to load.

        Returns:
            True on success, False on any error.
        """
        try:
            logger.info(f"开始加载LeRobot数据集: {data_dir}, Episode: {episode_idx}")
            root_path = Path(data_dir)

            info_path = root_path / 'meta' / 'info.json'
            data_path = root_path / 'data'

            # A valid dataset root must have meta/info.json and a data/ directory.
            if not (root_path.is_dir() and info_path.exists() and data_path.exists()):
                logger.error(f"无效的LeRobot数据集目录结构。缺少 meta/info.json 或 data 目录。")
                return False

            with open(info_path, 'r', encoding='utf-8') as f:
                info = json.load(f)

            self.replay_metadata = {
                'collection_frequency': info.get('fps', 30),
                'arms': ['left', 'right'],
                'lerobot_info': info
            }
            logger.info(f"数据集元信息加载完成。采集频率: {self.replay_metadata['collection_frequency']}Hz")

            # Locate the parquet file for this episode.
            # Assumes the chunk is always 000 — TODO confirm for multi-chunk datasets.
            parquet_file = data_path / 'chunk-000' / f'episode_{episode_idx:06d}.parquet'
            if not parquet_file.exists():
                logger.error(f"未找到Episode {episode_idx} 的数据文件: {parquet_file}")
                return False

            logger.info(f"正在读取数据文件: {parquet_file}")
            df = pd.read_parquet(parquet_file)

            self.replay_data = []

            # LeRobot's 'action' is the puppet arm's next state, so it is used
            # directly as the control target. 'names' is a nested list.
            action_names = info['features']['action']['names'][0]

            # Map action-vector positions to per-arm joint/gripper slots.
            joint_indices = {'left': [None]*6, 'right': [None]*6}
            gripper_indices = {'left': -1, 'right': -1}

            for i, name in enumerate(action_names):
                if 'masterLeft' in name:
                    if 'joint6' in name:
                        # joint6 is the gripper channel in this naming scheme.
                        gripper_indices['left'] = i
                    else:
                        # Extract the digit N from '...jointN'.
                        joint_num = int(name.split('joint')[-1])
                        if 0 <= joint_num < 6:
                            joint_indices['left'][joint_num] = i
                elif 'masterRight' in name:
                    if 'joint6' in name:
                        gripper_indices['right'] = i
                    else:
                        joint_num = int(name.split('joint')[-1])
                        if 0 <= joint_num < 6:
                            joint_indices['right'][joint_num] = i

            logger.info(f"解析出的左臂关节索引: {joint_indices['left']}")
            logger.info(f"解析出的左臂夹爪索引: {gripper_indices['left']}")
            logger.info(f"解析出的右臂关节索引: {joint_indices['right']}")
            logger.info(f"解析出的右臂夹爪索引: {gripper_indices['right']}")


            for i, row in df.iterrows():
                frame = {
                    'timestamp': row.get('timestamp', time.time()),
                    'relative_time': row.get('timestamp', 0) - df['timestamp'].iloc[0] if 'timestamp' in df else i / self.replay_metadata['collection_frequency'],
                    'frame_index': i,
                    'arms': {}
                }

                action_values = row['action']

                for arm_name in ['left', 'right']:
                    # LeRobot action is typically the next state — use it as the target.
                    joint_positions = np.array([action_values[j] for j in joint_indices[arm_name]])

                    # Gripper values in this aloha_arm dataset are in [-1, 1]:
                    # -1 = closed, 1 = open; map to a 0-60 degree command.
                    gripper_action = action_values[gripper_indices[arm_name]]
                    # Mapping: 1 -> 0 deg (open), -1 -> 60 deg (closed).
                    gripper_angle = (1 - gripper_action) / 2 * 60

                    frame['arms'][arm_name] = {
                        'joint_positions': joint_positions,
                        'gripper_angle': gripper_angle,
                    }

                self.replay_data.append(frame)

            self.total_frames = len(self.replay_data)
            logger.info(f"LeRobot数据加载完成: {self.total_frames}帧")
            logger.info(f"包含机械臂: {list(self.replay_data[0]['arms'].keys())}")

            return True

        except Exception as e:
            logger.error(f"加载LeRobot数据失败: {e}", exc_info=True)
            return False
|
| 303 |
+
|
| 304 |
+
def load_data(self, path: str, episode_idx: int = 0) -> bool:
|
| 305 |
+
"""自动检测并加载数据"""
|
| 306 |
+
filepath = Path(path)
|
| 307 |
+
if not filepath.exists():
|
| 308 |
+
logger.error(f"数据文件或目录不存在: {path}")
|
| 309 |
+
return False
|
| 310 |
+
|
| 311 |
+
if filepath.is_dir():
|
| 312 |
+
# 认为是LeRobot数据集目录
|
| 313 |
+
return self.load_lerobot_data(path, episode_idx=episode_idx)
|
| 314 |
+
elif filepath.suffix.lower() == '.h5' or filepath.suffix.lower() == '.hdf5':
|
| 315 |
+
return self.load_hdf5_data(path)
|
| 316 |
+
# elif filepath.suffix.lower() == '.json': # 旧的逻辑,暂时禁用
|
| 317 |
+
else:
|
| 318 |
+
logger.error(f"不支持的文件格式或路径类型: {path}")
|
| 319 |
+
return False
|
| 320 |
+
|
| 321 |
+
def back_to_zero_position(self) -> bool:
|
| 322 |
+
|
| 323 |
+
for k, piper in self.arms:
|
| 324 |
+
logger.info(f'==> processing {k}')
|
| 325 |
+
# piper = self.arms.get('right', None)
|
| 326 |
+
# piper.JointConfig(joint_num=7, set_zero=0xAE)
|
| 327 |
+
# piper.GripperCtrl(set_zero=0xAE)
|
| 328 |
+
|
| 329 |
+
piper.JointCtrl(
|
| 330 |
+
joint_1=0, # 0度
|
| 331 |
+
joint_2=0, # 0度
|
| 332 |
+
joint_3=0, # 0度
|
| 333 |
+
joint_4=0, # 0度
|
| 334 |
+
joint_5=0, # 0度
|
| 335 |
+
joint_6=0 # 0度
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
joint_msgs = piper.GetArmJointMsgs()
|
| 339 |
+
print(f"关节状态: {joint_msgs}")
|
| 340 |
+
joint1_angle = joint_msgs.joint_state.joint_1
|
| 341 |
+
print(f"关节1角度: {joint1_angle/1000.0} 度")
|
| 342 |
+
joint2_angle = joint_msgs.joint_state.joint_2
|
| 343 |
+
print(f"关节2角度: {joint2_angle/1000.0} 度")
|
| 344 |
+
joint3_angle = joint_msgs.joint_state.joint_3
|
| 345 |
+
print(f"关节3角度: {joint3_angle/1000.0} 度")
|
| 346 |
+
joint4_angle = joint_msgs.joint_state.joint_4
|
| 347 |
+
print(f"关节4角度: {joint4_angle/1000.0} 度")
|
| 348 |
+
joint5_angle = joint_msgs.joint_state.joint_5
|
| 349 |
+
print(f"关节5角度: {joint5_angle/1000.0} 度")
|
| 350 |
+
joint6_angle = joint_msgs.joint_state.joint_6
|
| 351 |
+
print(f"关节6角度: {joint6_angle/1000.0} 度")
|
| 352 |
+
joint7_angle = joint_msgs.joint_state.joint_7
|
| 353 |
+
print(f"关节7角度: {joint7_angle/1000.0} 度")
|
| 354 |
+
|
| 355 |
+
# 获取夹爪状态
|
| 356 |
+
gripper_msgs = piper.GetArmGripperMsgs()
|
| 357 |
+
print(f"夹爪状态: {gripper_msgs}")
|
| 358 |
+
return True
|
| 359 |
+
|
| 360 |
+
    def move_to_start_position(self) -> bool:
        """Move all connected arms to the first frame's pose before replay begins.

        Returns:
            True on success, False if not connected, no data is loaded, or a
            command fails.
        """
        if not self.is_connected or not self.replay_data:
            logger.error("机械臂未连接或数据未加载")
            return False

        try:
            start_frame = self.replay_data[0]
            logger.info("正在移动到起始位置...")
            print(start_frame)
            self.back_to_zero_position()

            for arm_name, arm in self.arms.items():
                if arm_name in start_frame['arms']:
                    target_joint_positions = start_frame['arms'][arm_name]['joint_positions']
                    target_gripper_angle = start_frame['arms'][arm_name]['gripper_angle']

                    # Convert radians to the SDK's 0.001-degree integer units.
                    target_joint_positions_deg = (target_joint_positions * 180.0 * 1000.0 / np.pi).astype(int)

                    # Set control mode and a slow (20%) speed for the approach move.
                    arm.MotionCtrl_2(ctrl_mode=0x01, move_mode=0x01, move_spd_rate_ctrl=20)

                    # Send the joint command.
                    arm.JointCtrl(
                        int(target_joint_positions_deg[0]), int(target_joint_positions_deg[1]),
                        int(target_joint_positions_deg[2]), int(target_joint_positions_deg[3]),
                        int(target_joint_positions_deg[4]), int(target_joint_positions_deg[5])
                    )
                    # Gripper command (angle also in 0.001-degree units).
                    arm.GripperCtrl(int(target_gripper_angle * 1000.0), 0, 0x01, 0)

                    logger.info(f"{arm_name}臂目标关节位置 (0.001度): {target_joint_positions_deg.tolist()}")
                    logger.info(f"{arm_name}臂目标夹爪位置 (度): {target_gripper_angle:.2f}")

            # Wait for the motion to finish — fixed delay, no feedback polling.
            logger.info("等待机械臂移动到起始位置...")
            time.sleep(3.0)
            logger.info("==> 移动到起始位置完成")
            return True

        except Exception as e:
            logger.error(f"移动到起始位置失败: {e}")
            return False
|
| 404 |
+
|
| 405 |
+
def reset_to_zero(self):
|
| 406 |
+
"""重置到零位置"""
|
| 407 |
+
for arm_name, arm in self.arms.items():
|
| 408 |
+
arm.GripperCtrl(0, 0, 0x01, 0)
|
| 409 |
+
arm.JointCtrl(0, 0, 0, 0, 0, 0)
|
| 410 |
+
time.sleep(0.5)
|
| 411 |
+
logger.info(f"{arm_name}臂重置到零位置完成")
|
| 412 |
+
return True
|
| 413 |
+
|
| 414 |
+
    def start_replay(self, speed: float = 1.0, start_frame: int = 0, end_frame: Optional[int] = None):
        """Start replaying the loaded frames on a background thread.

        Args:
            speed: playback speed multiplier (1.0 = original recorded rate).
            start_frame: first frame index to replay.
            end_frame: end frame index; None means replay through the last frame.

        Returns:
            True if the replay thread was started, False otherwise.
        """
        if not self.is_connected or not self.replay_data:
            logger.error("机械臂未连接或数据未加载")
            return False

        if self.is_replaying:
            logger.warning("重播已在进行中")
            return False

        self.replay_speed = speed
        self.current_frame = start_frame
        self.frames_replayed = 0
        self.start_time = time.time()

        if end_frame is None:
            end_frame = self.total_frames

        # Launch the replay worker (daemon thread so it dies with the process).
        self.is_replaying = True
        self.replay_thread = threading.Thread(
            target=self._replay_loop,
            args=(start_frame, end_frame),
            daemon=True
        )
        self.replay_thread.start()

        logger.info(f"开始重播数据")
        logger.info(f"重播速度: {speed}x")
        logger.info(f"帧范围: {start_frame} - {end_frame}")

        return True
|
| 446 |
+
|
| 447 |
+
def stop_replay(self):
    """Request the playback thread to stop and wait briefly for it to exit.

    No-op when nothing is playing. The join uses a 2 s timeout so a stuck
    worker cannot hang the caller.
    """
    if not self.is_replaying:
        return

    # The worker polls this flag once per frame and exits when cleared.
    self.is_replaying = False

    worker = self.replay_thread
    if worker:
        worker.join(timeout=2.0)

    logger.info("重播已停止")
    logger.info(f"总重播帧数: {self.frames_replayed}")
|
| 459 |
+
|
| 460 |
+
def _replay_loop(self, start_frame: int, end_frame: int):
    """Worker loop (runs on the playback thread): stream recorded frames to the arms.

    Replays frames [start_frame, min(end_frame, total_frames)) at the
    recording's collection frequency divided by ``self.replay_speed``.
    Any per-frame error aborts playback; the arms are always returned to
    the zero pose when the loop ends.

    Args:
        start_frame: index of the first frame to send.
        end_frame: exclusive end index (clamped to total_frames).
    """
    # Pace the loop from the frequency stored in the recording's metadata
    # (defaults to 30 Hz), scaled by the requested playback speed.
    original_frequency = self.replay_metadata.get('collection_frequency', 30)
    base_dt = 1.0 / original_frequency
    adjusted_dt = base_dt / self.replay_speed

    logger.info(f"原始频率: {original_frequency}Hz, 调整后间隔: {adjusted_dt:.4f}s")

    for frame_idx in range(start_frame, min(end_frame, self.total_frames)):
        # stop_replay() clears this flag to request an early exit.
        if not self.is_replaying:
            break

        frame_start_time = time.time()

        try:
            frame = self.replay_data[frame_idx]
            self.current_frame = frame_idx

            # Send joint commands to each connected arm that has data in this frame.
            for arm_name, arm in self.arms.items():
                if arm_name in frame['arms']:
                    arm_data = frame['arms'][arm_name]
                    target_joint_positions = arm_data['joint_positions']
                    target_gripper_angle = arm_data['gripper_angle']

                    # Convert radians to the SDK's 0.001-degree integer units.
                    target_joint_positions_deg = (target_joint_positions * 180.0 * 1000.0 / np.pi).astype(int)

                    # Issue the 6-joint position command.
                    arm.JointCtrl(
                        int(target_joint_positions_deg[0]), int(target_joint_positions_deg[1]),
                        int(target_joint_positions_deg[2]), int(target_joint_positions_deg[3]),
                        int(target_joint_positions_deg[4]), int(target_joint_positions_deg[5])
                    )
                    # Gripper command; angle is likewise scaled by 1000.
                    arm.GripperCtrl(int(target_gripper_angle * 1000.0), 0, 0x01, 0)

            self.frames_replayed += 1

            # Sleep off whatever remains of this frame's time budget so the
            # loop tracks the adjusted frame interval.
            elapsed_time = time.time() - frame_start_time
            sleep_time = adjusted_dt - elapsed_time
            if sleep_time > 0:
                time.sleep(sleep_time)

        except Exception as e:
            logger.error(f"重播帧 {frame_idx} 时发生错误: {e}")
            self.is_replaying = False
            break

    self.is_replaying = False
    logger.info("重播循环结束")
    # Always park the arms at zero once playback ends (normally or on error).
    self.reset_to_zero()
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def get_replay_info(self) -> Dict[str, Any]:
    """Return a snapshot of the current playback state.

    Returns:
        Dict with total/current frame counters, replayed-frame count,
        running flag, speed multiplier, and the recording's metadata.
    """
    return dict(
        total_frames=self.total_frames,
        current_frame=self.current_frame,
        frames_replayed=self.frames_replayed,
        is_replaying=self.is_replaying,
        replay_speed=self.replay_speed,
        metadata=self.replay_metadata,
    )
|
| 525 |
+
|
| 526 |
+
def get_frame_data(self, frame_index: int) -> Optional[Dict[str, Any]]:
    """Return the recorded frame at ``frame_index``, or None if unavailable.

    Fix: negative indices previously fell through to Python's negative
    list indexing and silently returned a frame from the end of the
    recording; they now return None, consistent with the range check in
    ``seek_to_frame``.

    Args:
        frame_index: zero-based frame index into the loaded recording.

    Returns:
        The frame dict, or None when no data is loaded or the index is
        out of range.
    """
    if not self.replay_data:
        return None
    if not 0 <= frame_index < len(self.replay_data):
        return None
    return self.replay_data[frame_index]
|
| 531 |
+
|
| 532 |
+
def seek_to_frame(self, frame_index: int):
    """Move the playback cursor to ``frame_index``.

    Args:
        frame_index: zero-based target frame.

    Returns:
        True on success, False when no data is loaded or the index is
        out of range.
    """
    if not self.replay_data:
        logger.error("没有数据可跳转")
        return False

    # Reject anything outside [0, total_frames).
    if not 0 <= frame_index < self.total_frames:
        logger.error(f"帧索引超出范围: {frame_index}")
        return False

    self.current_frame = frame_index
    logger.info(f"跳转到帧: {frame_index}")
    return True
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def main():
    """CLI entry point: connect both arms, load a recording, and prepare replay."""
    parser = argparse.ArgumentParser(description="机械臂数据重播工具")
    parser.add_argument(
        '--path', type=str, required=True,
        help="要重播的数据文件路径(.h5)或LeRobot数据集目录",
    )
    parser.add_argument(
        '--episode', type=int, default=0,
        help="当路径为LeRobot目录时,指定要重播的episode索引",
    )
    parser.add_argument('--speed', type=float, default=1.0, help="重播速度倍率")
    parser.add_argument('--left-can', type=str, default='can_left', help="左臂CAN接口名称")
    parser.add_argument('--right-can', type=str, default='can_right', help="右臂CAN接口名称")
    args = parser.parse_args()

    # One CAN interface per arm.
    replayer = ArmDataReplayer({'left': args.left_can, 'right': args.right_can})

    # Connect to the hardware first; nothing else makes sense without it.
    if not replayer.connect_arms():
        logger.error("无法连接到机械臂,程序退出。")
        return

    # Load the recording (single .h5 file or a LeRobot episode).
    if not replayer.load_data(args.path, episode_idx=args.episode):
        logger.error("数据加载失败,程序退出。")
        replayer.disconnect_arms()
        return

    try:
        # Pre-position the arms at the recording's first pose.
        if not replayer.move_to_start_position():
            logger.error("移动到起始位置失败,程序退出。")
            return

        logger.info("准备开始重播...")
        # NOTE: the actual playback is currently disabled; re-enable with:
        # replayer.start_replay(speed=args.speed)
        # while replayer.is_replaying:
        #     info = replayer.get_replay_info()
        #     print(f"\r重播中... 帧: {info['current_frame']}/{info['total_frames']} ({(info['current_frame']+1)*100/info['total_frames']:.1f}%)", end="")
        #     time.sleep(0.5)
        # print("\n重播完成。")

    except KeyboardInterrupt:
        logger.info("接收到中断信号,正在停止...")

    finally:
        # Always stop any running playback and release the hardware.
        replayer.stop_replay()
        replayer.disconnect_arms()
        logger.info("程序已清理并退出。")


if __name__ == "__main__":
    main()
|
tests/save_s1.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Saving a s1 pretrained model for training
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from starforce.model.starforce_s1 import Starforce_S1, Starforce_S1_Config
|
| 8 |
+
from starforce.model.action_head.flow_matching_action_head import FlowmatchingActionHeadConfig
|
| 9 |
+
|
| 10 |
+
config = Starforce_S1_Config()
|
| 11 |
+
config.backbone_cfg = {
|
| 12 |
+
"tune_llm": False,
|
| 13 |
+
# "vllm_base_model_path": "Qwen/Qwen2.5-VL-7B-Instruct",
|
| 14 |
+
"vllm_base_model_path": "/pfs/pfs-ahGxdf/data/wujingyi/huggingface/Qwen2.5-VL-3B-Instruct",
|
| 15 |
+
"select_layer": 12,
|
| 16 |
+
"feature_dim": 2048,
|
| 17 |
+
"project_to_dim": 2048,
|
| 18 |
+
}
|
| 19 |
+
config.action_horizon = 16
|
| 20 |
+
config.action_dim = 32
|
| 21 |
+
config.action_head_cfg = {
|
| 22 |
+
"action_dim": 32,
|
| 23 |
+
"action_horizon": 16,
|
| 24 |
+
"add_pos_embed": True,
|
| 25 |
+
"backbone_embedding_dim": 2048,
|
| 26 |
+
"diffusion_model_cfg": {
|
| 27 |
+
"attention_head_dim": 48,
|
| 28 |
+
"cross_attention_dim": 2048,
|
| 29 |
+
"dropout": 0.2,
|
| 30 |
+
"final_dropout": True,
|
| 31 |
+
"interleave_self_attention": True,
|
| 32 |
+
"norm_type": "ada_norm",
|
| 33 |
+
"num_attention_heads": 32,
|
| 34 |
+
"num_layers": 16,
|
| 35 |
+
"output_dim": 1024,
|
| 36 |
+
"positional_embeddings": None,
|
| 37 |
+
},
|
| 38 |
+
"hidden_size": 1024,
|
| 39 |
+
"input_embedding_dim": 1536,
|
| 40 |
+
"max_action_dim": 32,
|
| 41 |
+
"max_state_dim": 64,
|
| 42 |
+
"model_dtype": "float32",
|
| 43 |
+
"noise_beta_alpha": 1.5,
|
| 44 |
+
"noise_beta_beta": 1.0,
|
| 45 |
+
"noise_s": 0.999,
|
| 46 |
+
"num_inference_timesteps": 4,
|
| 47 |
+
"num_target_vision_tokens": 32,
|
| 48 |
+
"num_timestep_buckets": 1000,
|
| 49 |
+
"tune_diffusion_model": True,
|
| 50 |
+
"tune_projector": True,
|
| 51 |
+
"use_vlln": True,
|
| 52 |
+
"vl_self_attention_cfg": {
|
| 53 |
+
"attention_head_dim": 64,
|
| 54 |
+
"dropout": 0.2,
|
| 55 |
+
"final_dropout": True,
|
| 56 |
+
"num_attention_heads": 32,
|
| 57 |
+
"num_layers": 4,
|
| 58 |
+
"positional_embeddings": None,
|
| 59 |
+
},
|
| 60 |
+
}
|
| 61 |
+
model = Starforce_S1(config=config, local_model_path=None)
|
| 62 |
+
|
| 63 |
+
# action_head_state_dict = torch.load("checkpoints/GR00T-N1.5-3B-action-expert.pth")
|
| 64 |
+
action_head_state_dict = torch.load("checkpoints/qz-action-expert.pth")
|
| 65 |
+
model.action_head.load_state_dict(action_head_state_dict)
|
| 66 |
+
|
| 67 |
+
model.save_pretrained("checkpoints/Starforce-S1-3B")
|
| 68 |
+
|
| 69 |
+
print(model.)
|
| 70 |
+
print("done!")
|
tests/save_s1_7B.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Saving a s1 pretrained model for training
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from starforce.model.starforce_s1 import Starforce_S1, Starforce_S1_Config
|
| 8 |
+
from starforce.model.action_head.flow_matching_action_head import FlowmatchingActionHeadConfig
|
| 9 |
+
|
| 10 |
+
config = Starforce_S1_Config()
|
| 11 |
+
config.backbone_cfg = {
|
| 12 |
+
"tune_llm": False,
|
| 13 |
+
# "vllm_base_model_path": "Qwen/Qwen2.5-VL-7B-Instruct",
|
| 14 |
+
"vllm_base_model_path": "/pfs/pfs-ahGxdf/data/wujingyi/huggingface/Qwen2.5-VL-7B-Instruct",
|
| 15 |
+
"select_layer": 12,
|
| 16 |
+
"feature_dim": 3584,
|
| 17 |
+
"project_to_dim": 2048,
|
| 18 |
+
}
|
| 19 |
+
config.action_horizon = 16
|
| 20 |
+
config.action_dim = 32
|
| 21 |
+
config.action_head_cfg = {
|
| 22 |
+
"action_dim": 32,
|
| 23 |
+
"action_horizon": 16,
|
| 24 |
+
"add_pos_embed": True,
|
| 25 |
+
"backbone_embedding_dim": 2048,
|
| 26 |
+
"diffusion_model_cfg": {
|
| 27 |
+
"attention_head_dim": 48,
|
| 28 |
+
"cross_attention_dim": 2048,
|
| 29 |
+
"dropout": 0.2,
|
| 30 |
+
"final_dropout": True,
|
| 31 |
+
"interleave_self_attention": True,
|
| 32 |
+
"norm_type": "ada_norm",
|
| 33 |
+
"num_attention_heads": 32,
|
| 34 |
+
"num_layers": 16,
|
| 35 |
+
"output_dim": 1024,
|
| 36 |
+
"positional_embeddings": None,
|
| 37 |
+
},
|
| 38 |
+
"hidden_size": 1024,
|
| 39 |
+
"input_embedding_dim": 1536,
|
| 40 |
+
"max_action_dim": 32,
|
| 41 |
+
"max_state_dim": 64,
|
| 42 |
+
"model_dtype": "float32",
|
| 43 |
+
"noise_beta_alpha": 1.5,
|
| 44 |
+
"noise_beta_beta": 1.0,
|
| 45 |
+
"noise_s": 0.999,
|
| 46 |
+
"num_inference_timesteps": 4,
|
| 47 |
+
"num_target_vision_tokens": 32,
|
| 48 |
+
"num_timestep_buckets": 1000,
|
| 49 |
+
"tune_diffusion_model": True,
|
| 50 |
+
"tune_projector": True,
|
| 51 |
+
"use_vlln": True,
|
| 52 |
+
"vl_self_attention_cfg": {
|
| 53 |
+
"attention_head_dim": 64,
|
| 54 |
+
"dropout": 0.2,
|
| 55 |
+
"final_dropout": True,
|
| 56 |
+
"num_attention_heads": 32,
|
| 57 |
+
"num_layers": 4,
|
| 58 |
+
"positional_embeddings": None,
|
| 59 |
+
},
|
| 60 |
+
}
|
| 61 |
+
model = Starforce_S1(config=config, local_model_path=None)
|
| 62 |
+
|
| 63 |
+
# action_head_state_dict = torch.load("checkpoints/GR00T-N1.5-3B-action-expert.pth")
|
| 64 |
+
action_head_state_dict = torch.load("checkpoints/qz-action-expert.pth")
|
| 65 |
+
model.action_head.load_state_dict(action_head_state_dict)
|
| 66 |
+
|
| 67 |
+
model.save_pretrained("checkpoints/Starforce-S1-7B")
|
| 68 |
+
|
| 69 |
+
print("done!")
|
tests/test_cv.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke test: verify OpenCV can open a LeRobot episode video and decode one frame."""
import cv2
import sys

# FIX: the first assignment was dead code — it was immediately shadowed by the
# second. Kept here only as a reference to another sample video.
# video_path = "data/sl/0723pre_data_v1/videos/chunk-000/observation.images.cam_high/episode_000045.mp4"
video_path = 'data/test_aloha_singlearm/videos/chunk-000/observation.images.cam_high/episode_000000.mp4'

# Dump the build configuration so codec support (ffmpeg/av) can be checked.
print("OpenCV build info:")
print(cv2.getBuildInformation())

cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
    print(f"Error: Failed to open video file: {video_path}")
    sys.exit(1)

ret, frame = cap.read()
if not ret:
    print("Error: Unable to read the first frame of the video.")
    cap.release()
    sys.exit(1)

print("Success: Video file opened and first frame read successfully.")
cap.release()
|
tests/test_hf.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smoke test: instantiate a randomly-initialized Qwen2.5-VL model from its
# pretrained config (weights are NOT loaded) and print its module tree.
import os

# Route Hugging Face Hub traffic through the mirror. This assignment MUST
# happen before `transformers` is imported, because the endpoint is read at
# import time — do not reorder these statements.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from transformers import AutoConfig
from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig
from transformers import AutoModelForCausalLM, AutoModel
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
import torch


vllm_base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"

# Only the config is fetched; constructing the model from it gives random weights.
config = Qwen2_5_VLConfig.from_pretrained(vllm_base_model_path, trust_remote_code=True)
vllm_model = Qwen2_5_VLForConditionalGeneration(
    config=config,
)

print(vllm_model)
|
tests/test_pi0.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from starforce.datasets.lerobot_dataset import LeRobotDataset
|
| 2 |
+
import torch
|
| 3 |
+
from starforce.models.pretrained import PreTrainedConfig
|
| 4 |
+
from starforce.models.build_model import make_policy
|
| 5 |
+
from loguru import logger
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main():
    """Benchmark pi0 policy inference latency on one LeRobot batch.

    Loads a single episode, moves one batch to the target device, restores a
    policy checkpoint, and times `select_action` over 30 iterations under
    autocast, printing total and average latency.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # FIX: torch.cuda.is_bf16_supported() raises on CUDA-less hosts, even
    # though the CPU fallback above explicitly supports them — guard it.
    dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float32
    )
    # paligemma doesn't support bf16?
    # dtype = torch.float32
    logger.info(f"##info, device: {device}, dtype: {dtype}")

    # dataset_repo_id = "danaaubakirova/koch_test"
    # dataset_repo_id = "data/robotwin2lerobot/block_hammer_beat"
    dataset_repo_id = "/pfs/data/xiongxiao/lerobot_fps30/open_laptop"
    # Previously benchmarked checkpoints, kept for reference:
    # ckpt_torch_dir = "/pfs/data/fgang/vla_holo/checkpoints/pi0"
    # ckpt_torch_dir = "/pfs/data/fgang/outputs_models/pi0-1-20000/pretrained_model"
    # ckpt_torch_dir = "/pfs/data/fgang/outputs_models/pi0-robotwin-30fps-tasks3"
    # ckpt_torch_dir = "/pfs/data/fgang/vla_holo/outputs/pi0-fixed-20ksteps"
    ckpt_torch_dir = "/pfs/data/fgang/outputs_models/pi0-robotwin-30fps-tasks5"

    dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=1,
    )
    batch = next(iter(dataloader))
    # Move every tensor in the batch to the benchmark device/dtype.
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device=device, dtype=dtype)
    print(f'dataset.meta: {dataset.meta}')

    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir, device=device)
    cfg.pretrained_path = ckpt_torch_dir
    policy = make_policy(cfg, ds_meta=dataset.meta)
    # policy.to(dtype)
    # print(policy)

    t0 = time.time()
    with torch.amp.autocast(device_type=device):
        benchmark_iters = 30
        for _ in range(benchmark_iters):
            t00 = time.time()
            action = policy.select_action(batch, n_steps_out=50)
            # FIX: torch.cuda.synchronize() raises on CPU-only hosts; only
            # synchronize when actually running on CUDA.
            if device == "cuda":
                torch.cuda.synchronize()
            # print("##info, action:", action.shape, action.dtype, action.device, action, time.time() - t00)
    t1 = time.time()
    print(f'cost: {t1-t0:.3f}, avg: {(t1-t0)/benchmark_iters}')


if __name__ == "__main__":
    main()
|
tests/test_starhelm.py
ADDED
|
File without changes
|
tests/test_tensor.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
Read model weights from
checkpoints/models/pi0-sl-b1/model.safetensors
and print every normalization-related key.
'''
# FIX: the module docstring above was duplicated verbatim twice; the second
# copy was dead text and has been removed.

from safetensors.torch import load_file

# Load the full state dict from the safetensors checkpoint.
weights = load_file("checkpoints/models/pi0-sl-b1/model.safetensors")

# Only the normalization buffers are of interest here; print shape, dtype
# and the raw values for each of them.
for key, value in weights.items():
    if 'normalize' in key:
        print(f"Key: {key}, Shape: {value.shape}, Type: {value.dtype} {value}")
|
tests/vis_lerobot_data.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from vlaholo.datasets.lerobot_dataset import LeRobotDataset
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
from matplotlib.animation import FuncAnimation
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
TODO:
|
| 10 |
+
|
| 11 |
+
support datasets == 4.0
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def plot_episode_joint_states(dataset_path: str, episode_index: int):
    """Render one episode of a LeRobot dataset as an MP4: camera streams on top,
    one animated joint-position subplot per joint below, with a moving time cursor.

    The result is written to ``outputs/episode_<index>_animation.mp4``.

    Args:
        dataset_path: path or HF repo ID of the LeRobot dataset.
        episode_index: episode to render; clamped to the last episode if
            out of range.
    """
    dataset = LeRobotDataset(dataset_path)

    # Clamp an out-of-range episode index to the last available episode.
    if episode_index >= dataset.num_episodes:
        print(
            f"episode index {episode_index} is out of range, total episodes: {dataset.num_episodes}"
        )
        episode_index = dataset.num_episodes - 1
        print(f"force set to max episode index: {episode_index}")

    hf_dataset = dataset.hf_dataset
    episode_ds = hf_dataset.filter(lambda x: x["episode_index"] == episode_index)
    video_paths = dataset.encode_episode_videos(episode_index=episode_index)

    # Open one capture per camera stream; fail loudly if any is unreadable.
    caps = {}
    for key, path in video_paths.items():
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {path}")
        caps[key] = cap

    # Assume all camera streams share the first stream's fps / frame count.
    fps = caps[next(iter(caps))].get(cv2.CAP_PROP_FPS)
    total_frames = int(caps[next(iter(caps))].get(cv2.CAP_PROP_FRAME_COUNT))

    df = episode_ds.to_pandas()
    joint_states = np.vstack(df["observation.state"].values)
    timestamps = df["timestamp"].values
    duration_sec = timestamps[-1] - timestamps[0]

    # Joint names may be stored as a single nested list; unwrap it, and fall
    # back to generic "Joint i" labels when no usable names are present.
    joint_names = dataset.features["observation.state"]["names"]
    if isinstance(joint_names, list) and len(joint_names) == 1 and isinstance(joint_names[0], list):
        joint_names = joint_names[0]
    if len(joint_names) <= 1:
        joint_names = [f"Joint {i}" for i in range(joint_states.shape[1])]

    n_joints = joint_states.shape[1]
    # Three joint plots per row.
    n_joint_rows = (n_joints + 2) // 3

    # Create the figure with symmetric left/right margins and headroom for titles.
    fig = plt.figure(figsize=(18, 4 + 4 * n_joint_rows))
    fig.subplots_adjust(top=0.92, bottom=0.05, left=0.05, right=0.95, hspace=0.4, wspace=0.3)

    # Top-level title.
    fig.suptitle(
        "Starforce Data Inspect System", x=0.5, y=0.98, fontsize=35, fontweight="bold", ha="center"
    )
    # Second-level stats line: fps, frame count, episode index, duration.
    stats_text = (
        f"FPS: {fps:.2f}    Total frames: {total_frames}    "
        f"Episode: {episode_index}    Duration: {duration_sec:.2f}s"
    )
    fig.text(0.5, 0.92, stats_text, ha="center", fontsize=21, fontweight="bold")

    plt.rcParams.update(
        {
            "font.family": "sans-serif",
            "font.sans-serif": ["Arial", "DejaVu Sans"],
            "font.size": 12,
            "axes.titlesize": 14,
            "axes.labelsize": 13,
            "axes.spines.top": False,
            "axes.spines.right": False,
        }
    )

    # Equal-width grid: one tall row of videos, then the joint-plot rows.
    gs = fig.add_gridspec(
        n_joint_rows + 1, 3, width_ratios=[1, 1, 1], height_ratios=[2] + [1] * n_joint_rows
    )

    # Video panels (top row), one per camera stream.
    video_axes, video_imgs = {}, {}
    for idx, key in enumerate(video_paths.keys()):
        ax = fig.add_subplot(gs[0, idx])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(key)
        img = ax.imshow(np.zeros((480, 640, 3)), aspect="auto")
        ax.set_box_aspect(480 / 640)
        video_axes[key] = ax
        video_imgs[key] = img

    # Joint trajectory subplots.
    joint_axes, lines, time_lines = [], [], []
    base_colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]
    # Repeat the palette as needed so every joint gets a color.
    colors = (base_colors * ((n_joints // len(base_colors)) + 1))[:n_joints]

    for i in range(n_joints):
        row, col = 1 + i // 3, i % 3
        ax = fig.add_subplot(gs[row, col])

        # Soft gradient background behind each trajectory.
        gradient = np.linspace(0, 1, 256).reshape(256, 1)
        extent = [timestamps[0], timestamps[-1], joint_states[:, i].min(), joint_states[:, i].max()]
        ax.imshow(
            np.repeat(gradient, 256, axis=1),
            aspect="auto",
            cmap="Blues",
            alpha=0.1,
            extent=extent,
            origin="lower",
            zorder=0,
        )

        # Trajectory line (filled in incrementally by the animation).
        (line,) = ax.plot([], [], label=joint_names[i], color=colors[i], linewidth=2.5, zorder=1)
        lines.append(line)

        # Axis cosmetics.
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("pos")
        ax.spines["left"].set_visible(False)

        ax.set_title(joint_names[i], fontweight="bold")
        ax.set_xlim(timestamps[0], timestamps[-1])
        # 10% vertical margin around the joint's value range.
        y0, y1 = joint_states[:, i].min(), joint_states[:, i].max()
        m = (y1 - y0) * 0.1
        ax.set_ylim(y0 - m, y1 + m)

        # Vertical cursor marking the current playback time.
        tl = ax.axvline(x=timestamps[0], color="crimson", alpha=0.7, linewidth=1.2, zorder=2)
        time_lines.append(tl)
        joint_axes.append(ax)

    # Animation callbacks: init clears the lines, animate advances one frame.
    def init():
        for ln in lines:
            ln.set_data([], [])
        return lines + time_lines + list(video_imgs.values())

    def animate(frame_idx):
        # Video may have more frames than state rows; clamp the data index.
        idx = min(frame_idx, len(timestamps) - 1)
        t = timestamps[idx]
        print(
            f"\rProcessing frames: {frame_idx + 1}/{total_frames} ({(frame_idx+1)/total_frames*100:.1f}%)",
            end="",
            flush=True,
        )

        # Advance each video stream by one frame (captures are read sequentially).
        for key, cap in caps.items():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                video_imgs[key].set_array(frame)
        # Extend each trajectory up to the current sample and move the cursor.
        for j, ln in enumerate(lines):
            ln.set_data(timestamps[: idx + 1], joint_states[: idx + 1, j])
        for tl in time_lines:
            tl.set_xdata([t, t])
        return lines + time_lines + list(video_imgs.values())

    anim = FuncAnimation(
        fig, animate, init_func=init, frames=total_frames, interval=1000 / fps, blit=True
    )
    print()

    # Rendering happens inside anim.save(); frames are pulled lazily.
    save_dir = "outputs/"
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"episode_{episode_index}_animation.mp4")
    anim.save(out_path, writer="ffmpeg", fps=fps)

    plt.close()
    for cap in caps.values():
        cap.release()
    print(f"Animation saved to: {out_path}")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
+
import argparse
|
| 193 |
+
|
| 194 |
+
parser = argparse.ArgumentParser(
|
| 195 |
+
description="Visualize joint states of a LeRobot dataset episode"
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument("dataset_path", type=str, help="Path or HF repo ID of the LeRobot dataset")
|
| 198 |
+
parser.add_argument("-i", type=int, default=89, help="Episode index to visualize")
|
| 199 |
+
args = parser.parse_args()
|
| 200 |
+
plot_episode_joint_states(args.dataset_path, args.i)
|
tests/vis_lerobot_data_v1.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from vlaholo.datasets.lerobot_dataset import LeRobotDataset
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
from matplotlib.animation import FuncAnimation
|
| 7 |
+
from vlaholo.utils.dataset_utils import DEFAULT_VIDEO_PATH
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def plot_episode_joint_states(dataset_path: str, episode_index: int):
    """Render one episode as an MP4: three camera panels plus a combined
    joint-state plot with a moving time cursor.

    The result is written to ``outputs/episode_<index>_animation.mp4``.

    Args:
        dataset_path: path or HF repo ID of the LeRobot dataset.
        episode_index: episode to render; clamped to the last episode if
            out of range.
    """
    dataset = LeRobotDataset(dataset_path)

    # FIX: was `episode_index > dataset.num_episodes` — an off-by-one that let
    # episode_index == num_episodes through and crash downstream. Episodes are
    # zero-indexed, so any index >= num_episodes is out of range (this matches
    # the check in tests/vis_lerobot_data.py).
    if episode_index >= dataset.num_episodes:
        print(
            f"episode index {episode_index} is out of range, total episodes: {dataset.num_episodes}"
        )
        episode_index = dataset.num_episodes - 1
        print(f"force set to max episode index: {episode_index}")

    hf_dataset = dataset.hf_dataset
    episode_ds = hf_dataset.filter(lambda x: x["episode_index"] == episode_index)
    video_paths = dataset.encode_episode_videos(episode_index=episode_index)

    # Open one capture per camera stream; fail loudly if any is unreadable.
    caps = {}
    for key, path in video_paths.items():
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {path}")
        caps[key] = cap

    # Assume all camera streams share the first stream's fps / frame count.
    fps = caps[list(caps.keys())[0]].get(cv2.CAP_PROP_FPS)
    total_frames = int(caps[list(caps.keys())[0]].get(cv2.CAP_PROP_FRAME_COUNT))

    df = episode_ds.to_pandas()
    joint_states = np.vstack(df["observation.state"].values)
    timestamps = df["timestamp"].values

    fig = plt.figure(figsize=(15, 10))
    gs = fig.add_gridspec(3, 2)

    video_names_map = {
        "observation.images.cam_high": "High Camera",
        "observation.images.cam_left_wrist": "Left Wrist Camera",
        "observation.images.cam_right_wrist": "Right Wrist Camera",
    }

    # Fixed layout: two cameras on the top row, one spanning the middle row.
    video_axes = {
        "observation.images.cam_high": fig.add_subplot(gs[0, 0]),
        "observation.images.cam_left_wrist": fig.add_subplot(gs[0, 1]),
        "observation.images.cam_right_wrist": fig.add_subplot(gs[1, :]),
    }

    # Bottom row: all joint trajectories share a single axis.
    joint_ax = fig.add_subplot(gs[2, :])

    video_imgs = {}
    for key, ax in video_axes.items():
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(video_names_map[key])
        img = ax.imshow(np.zeros((480, 640, 3)))
        video_imgs[key] = img

    n_joints = joint_states.shape[1]
    lines = []
    for i in range(n_joints):
        (line,) = joint_ax.plot([], [], label=f"Joint {i}")
        lines.append(line)

    joint_ax.set_xlim(timestamps[0], timestamps[-1])
    joint_ax.set_ylim(joint_states.min(), joint_states.max())
    joint_ax.grid(True)
    joint_ax.legend(loc="upper right")
    joint_ax.set_xlabel("Time (s)")

    # Vertical cursor marking the current playback time.
    time_line = joint_ax.axvline(x=timestamps[0], color="r")

    def init():
        for line in lines:
            line.set_data([], [])
        return lines + [time_line] + list(video_imgs.values())

    def animate(frame_idx):
        # Video may have more frames than state rows; clamp the data index.
        data_idx = min(frame_idx, len(timestamps) - 1)
        current_time = timestamps[data_idx]

        # Advance each video stream by one frame (captures are read sequentially).
        for key, cap in caps.items():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                video_imgs[key].set_array(frame)

        # Extend each trajectory up to the current sample and move the cursor.
        for i, line in enumerate(lines):
            line.set_data(timestamps[: data_idx + 1], joint_states[: data_idx + 1, i])

        time_line.set_xdata([current_time, current_time])
        return lines + [time_line] + list(video_imgs.values())

    anim = FuncAnimation(
        fig, animate, init_func=init, frames=total_frames, interval=1000 / fps, blit=True
    )

    save_dir = "outputs/"
    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, f"episode_{episode_index}_animation.mp4")
    anim.save(output_path, writer="ffmpeg", fps=fps)

    plt.close()
    for cap in caps.values():
        cap.release()

    print(f"Animation saved to: {output_path}")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
|
| 116 |
+
import argparse
|
| 117 |
+
|
| 118 |
+
parser = argparse.ArgumentParser(
|
| 119 |
+
description="Visualize joint states of a LeRobot dataset episode"
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument("dataset_path", type=str, help="Path or HF repo ID of the LeRobot dataset")
|
| 122 |
+
parser.add_argument("-i", type=int, default=89, help="Episode index to visualize")
|
| 123 |
+
args = parser.parse_args()
|
| 124 |
+
|
| 125 |
+
plot_episode_joint_states(args.dataset_path, args.i)
|
wandb/debug.log
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2025-12-20 21:53:42,814 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Current SDK version is 0.18.0
|
| 2 |
+
2025-12-20 21:53:42,814 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Configure stats pid to 3365891
|
| 3 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Loading settings from /home/lumos6/.config/wandb/settings
|
| 4 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Loading settings from /home/lumos6/work/starforce2/wandb/settings
|
| 5 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
|
| 6 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Applying setup settings: {'_disable_service': False}
|
| 7 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'finetune.py', 'program_abspath': '/home/lumos6/work/starforce2/finetune.py', 'program': '/home/lumos6/work/starforce2/finetune.py'}
|
| 8 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Applying login settings: {}
|
| 9 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Applying login settings: {'mode': 'offline'}
|
| 10 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:_log_setup():525] Logging user logs to /home/lumos6/work/starforce2/wandb/offline-run-20251220_215342-tsf926l3/logs/debug.log
|
| 11 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:_log_setup():526] Logging internal logs to /home/lumos6/work/starforce2/wandb/offline-run-20251220_215342-tsf926l3/logs/debug-internal.log
|
| 12 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():609] calling init triggers
|
| 13 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():616] wandb.init called with sweep_config: {}
|
| 14 |
+
config: {}
|
| 15 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():659] starting backend
|
| 16 |
+
2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():663] setting up manager
|
| 17 |
+
2025-12-20 21:53:42,817 INFO MainThread:3365891 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
|
| 18 |
+
2025-12-20 21:53:42,817 INFO MainThread:3365891 [wandb_init.py:init():671] backend started and connected
|
| 19 |
+
2025-12-20 21:53:42,818 INFO MainThread:3365891 [wandb_init.py:init():766] updated telemetry
|
| 20 |
+
2025-12-20 21:53:42,823 INFO MainThread:3365891 [wandb_init.py:init():799] communicating run to backend with 90.0 second timeout
|
| 21 |
+
2025-12-20 21:53:42,835 INFO MainThread:3365891 [wandb_init.py:init():850] starting run threads in backend
|
| 22 |
+
2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_console_start():2466] atexit reg
|
| 23 |
+
2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_redirect():2312] redirect: wrap_raw
|
| 24 |
+
2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_redirect():2377] Wrapping output streams.
|
| 25 |
+
2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_redirect():2402] Redirects installed.
|
| 26 |
+
2025-12-20 21:53:42,974 INFO MainThread:3365891 [wandb_init.py:init():893] run started, returning control to user process
|
| 27 |
+
2025-12-20 21:53:42,974 INFO MainThread:3365891 [wandb_run.py:_config_callback():1393] config_cb None None {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['GR00T_N1_5'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': 'outputs/gr00t-3b-piper-task-pickup-bs8-1gpu-step60k/final_model', 'transformers_version': '4.52.2', 'action_dim': 32, 'action_head_cfg': {'action_dim': 32, 'action_horizon': 16, 'add_pos_embed': True, 'backbone_embedding_dim': 2048, 'diffusion_model_cfg': {'attention_head_dim': 48, 'cross_attention_dim': 2048, 'dropout': 0.2, 'final_dropout': True, 'interleave_self_attention': True, 'norm_type': 'ada_norm', 'num_attention_heads': 32, 'num_layers': 16, 'output_dim': 1024, 
'positional_embeddings': None}, 'hidden_size': 1024, 'input_embedding_dim': 1536, 'max_action_dim': 32, 'max_state_dim': 64, 'model_dtype': 'float32', 'noise_beta_alpha': 1.5, 'noise_beta_beta': 1.0, 'noise_s': 0.999, 'num_inference_timesteps': 4, 'num_target_vision_tokens': 32, 'num_timestep_buckets': 1000, 'tune_diffusion_model': True, 'tune_projector': True, 'use_vlln': True, 'vl_self_attention_cfg': {'attention_head_dim': 64, 'dropout': 0.2, 'final_dropout': True, 'num_attention_heads': 32, 'num_layers': 4, 'positional_embeddings': None}}, 'action_horizon': 16, 'backbone_cfg': {'eagle_path': 'NVEagle/eagle_er-qwen3_1_7B-Siglip2_400M_stage1_5_128gpu_er_v7_1mlp_nops', 'load_bf16': False, 'project_to_dim': None, 'reproject_vision': False, 'select_layer': 12, 'tune_llm': False, 'tune_visual': True, 'use_flash_attention': True}, 'compute_dtype': 'bfloat16', 'hidden_size': 2048, 'model_dtype': 'float32', 'model_type': 'gr00t_n1_5', 'attn_implementation': None, 'output_dir': 'outputs/gr00t-3b-piper-task-pickup02-bs8-1gpu-step60k', 'overwrite_output_dir': False, 'do_train': False, 'do_eval': False, 'do_predict': False, 'eval_strategy': 'no', 'prediction_loss_only': False, 'per_device_train_batch_size': 16, 'per_device_eval_batch_size': 8, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 2.5e-05, 'weight_decay': 1e-05, 'adam_beta1': 0.95, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 300, 'max_steps': 60000, 'lr_scheduler_type': 'cosine', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.05, 'warmup_steps': 0, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': 'outputs/gr00t-3b-piper-task-pickup02-bs8-1gpu-step60k/runs/Dec20_21-53-38_lumos6', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 10, 
'logging_nan_inf_filter': True, 'save_strategy': 'steps', 'save_steps': 5000, 'save_total_limit': 3, 'save_safetensors': True, 'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': True, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': True, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': None, 'dataloader_num_workers': 8, 'dataloader_prefetch_factor': None, 'past_index': -1, 'run_name': 'outputs/gr00t-3b-piper-task-pickup02-bs8-1gpu-step60k', 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': '', 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard'], 'ddp_find_unused_parameters': False, 'ddp_bucket_cap_mb': 100, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': False, 'dataloader_persistent_workers': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': None, 'hub_always_push': 
False, 'gradient_checkpointing': False, 'gradient_checkpointing_kwargs': None, 'include_inputs_for_metrics': False, 'include_for_metrics': [], 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'eval_use_gather_object': False, 'average_tokens_across_devices': False}
|
| 28 |
+
2025-12-20 21:53:42,976 INFO MainThread:3365891 [wandb_config.py:__setitem__():154] config set model/num_parameters = 2724163520 - <bound method Run._config_callback of <wandb.sdk.wandb_run.Run object at 0x75c8c8147460>>
|
| 29 |
+
2025-12-20 21:53:42,976 INFO MainThread:3365891 [wandb_run.py:_config_callback():1393] config_cb model/num_parameters 2724163520 None
|
| 30 |
+
2025-12-21 08:08:04,099 WARNING MsgRouterThr:3365891 [router.py:message_loop():77] message_loop has been closed
|