nnh-pbbb committed
Commit cd793b5 (verified)
1 Parent(s): 5dac561

Add files using upload-large-folder tool
robot/cam.py ADDED
@@ -0,0 +1,144 @@
+ import threading
+ import json_numpy
+ import numpy as np
+ import requests
+ import queue
+ import cv2
+ import pickle, os
+ from typing import Any, cast
+
+ json_numpy.patch()
+
+ from starforce.experiment.data_config import DATA_CONFIG_MAP
+ from starforce.model.policy import Gr00tPolicy
+
+ # Import RTCController from separate module
+ from loguru import logger
+
+ try:
+     import pyrealsense2 as rs
+ except ImportError:
+     print("Warning: pyrealsense2 not available. Camera functionality will be limited.")
+     rs = None
+
+ # Help static analyzers: treat rs as dynamic Any when available
+ if rs is not None:
+     rs = cast(Any, rs)
+
+
+ class CameraWrapper:
+     def __init__(
+         self, devices=None, width=640, height=480, fps=30, num_realsense=0, cv_format="MJPEG"
+     ):
+         self.width = width
+         self.height = height
+         self.fps = fps
+         self.num_realsense = max(0, int(num_realsense))
+         self.cv_format = cv_format
+         self.cameras = []  # list of dicts: {type: 'rs'|'cv', handle: pipeline|cap}
+         self.device_ids = devices if devices is not None else []
+         self._open_cameras()
+         print(f"successfully opened {len(self.cameras)} cameras!")
+
+     def _open_cameras(self):
+         if not self.device_ids:
+             print("No devices provided for CameraWrapper")
+             return
+
+         for idx, dev in enumerate(self.device_ids):
+             # Decide camera type: the first num_realsense entries are RealSense serials
+             use_realsense = idx < self.num_realsense
+
+             if use_realsense:
+                 if rs is None:
+                     print(
+                         f"pyrealsense2 not available, skipping RealSense device at index {idx} (id: {dev})"
+                     )
+                     continue
+                 try:
+                     serial = str(dev)
+                     pipeline = rs.pipeline()  # type: ignore[attr-defined]
+                     config = rs.config()  # type: ignore[attr-defined]
+                     config.enable_device(serial)
+                     config.enable_stream(rs.stream.color, self.width, self.height, rs.format.bgr8, self.fps)  # type: ignore[attr-defined]
+                     pipeline.start(config)
+                     self.cameras.append({"type": "rs", "handle": pipeline})
+                     print(f"RealSense camera {serial} opened successfully")
+                 except Exception as e:
+                     print(f"Failed to open RealSense camera {dev}: {e}")
+             else:
+                 try:
+                     device_index = int(dev)
+                     print(f"Ready to read device: {device_index}")
+                     cap = cv2.VideoCapture(device_index)
+
+                     if self.cv_format == "MJPEG":
+                         cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))  # type: ignore[attr-defined]
+                     elif self.cv_format == "YUYV":
+                         cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))  # type: ignore[attr-defined]
+
+                     cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
+                     cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
+                     cap.set(cv2.CAP_PROP_FPS, self.fps)
+
+                     if not cap.isOpened():
+                         raise ValueError(f"Cannot open OpenCV camera {device_index}")
+
+                     self.cameras.append({"type": "cv", "handle": cap})
+                     print(f"OpenCV camera {device_index} opened successfully")
+                 except Exception as e:
+                     print(f"Failed to open OpenCV camera {dev}: {e}")
+
+     def get_images(self):
+         images = []
+         if len(self.cameras) == 0:
+             # Return dummy images if no cameras available - use 640x480 which is expected by the model
+             for _ in range(max(1, len(self.device_ids))):
+                 dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                 dummy_img[:, :, :] = 128  # Gray color instead of black
+                 images.append(dummy_img)
+             return images
+
+         for cam in self.cameras:
+             if cam["type"] == "rs":
+                 try:
+                     pipeline = cam["handle"]
+                     frames = pipeline.wait_for_frames()
+                     color_frame = frames.get_color_frame()
+                     if not color_frame:
+                         dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                         dummy_img[:, :, :] = 128
+                         images.append(dummy_img)
+                     else:
+                         img = np.asanyarray(color_frame.get_data())
+                         images.append(img)
+                 except Exception as e:
+                     print(f"Error reading from RealSense: {e}")
+                     dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                     dummy_img[:, :, :] = 128
+                     images.append(dummy_img)
+             elif cam["type"] == "cv":
+                 cap = cam["handle"]
+                 ret, frame = cap.read()
+                 if not ret or frame is None:
+                     dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                     dummy_img[:, :, :] = 128
+                     images.append(dummy_img)
+                 else:
+                     images.append(frame)
+         return images
+
+     def release(self):
+         for cam in self.cameras:
+             if cam["type"] == "rs":
+                 try:
+                     cam["handle"].stop()
+                 except Exception:
+                     pass
+             elif cam["type"] == "cv":
+                 try:
+                     cam["handle"].release()
+                 except Exception:
+                     pass
+         self.cameras = []
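
A minimal usage sketch for the `CameraWrapper` above (the device IDs are placeholders: RealSense serial strings first, then a V4L2 index, following the `num_realsense` convention in `_open_cameras`):

```python
# Hypothetical device list: two RealSense serials followed by one OpenCV index.
caps = CameraWrapper(devices=["123456789012", "210987654321", 0], num_realsense=2)
try:
    frames = caps.get_images()  # list of HxWx3 uint8 BGR arrays (gray dummies on failure)
    print([f.shape for f in frames])
finally:
    caps.release()
```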
starforce.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
starhelm/starhelm/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .websocket_client_policy import WebsocketClientPolicy
+
+ __version__ = "0.1.0"
starhelm/starhelm/image_tools_test.py ADDED
@@ -0,0 +1,37 @@
+ import numpy as np
+
+ import openpi_client.image_tools as image_tools
+
+
+ def test_resize_with_pad_shapes():
+     # Test case 1: Resize image to larger dimensions
+     images = np.zeros((2, 10, 10, 3), dtype=np.uint8)  # Input images of shape (batch_size, height, width, channels)
+     height = 20
+     width = 20
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (2, height, width, 3)
+     assert np.all(resized_images == 0)
+
+     # Test case 2: Resize image to smaller dimensions
+     images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
+     height = 15
+     width = 15
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (3, height, width, 3)
+     assert np.all(resized_images == 0)
+
+     # Test case 3: Resize image to the same dimensions
+     images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
+     height = 50
+     width = 50
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (1, height, width, 3)
+     assert np.all(resized_images == 0)
+
+     # Test case 4: Resize image with odd-numbered padding
+     images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
+     height = 60
+     width = 80
+     resized_images = image_tools.resize_with_pad(images, height, width)
+     assert resized_images.shape == (1, height, width, 3)
+     assert np.all(resized_images == 0)
test_starhelm.py ADDED
@@ -0,0 +1,88 @@
+ """
+ Read data from a LeRobot dataset and send requests to the inference server.
+ """
+
+ from starhelm.websocket_client_policy import WebsocketClientPolicy
+ import time
+ import torch
+ from vlaholo.datasets.lerobot_dataset import LeRobotDataset
+
+ # from lerobot.datasets.lerobot_dataset import LeRobotDataset
+ # from starforce.data.dataset import LeRobotSingleDataset as LeRobotDataset
+ import sys
+
+ import os
+
+ os.environ.pop("http_proxy", None)
+ os.environ.pop("https_proxy", None)
+ os.environ.pop("all_proxy", None)
+
+
+ def get_dummy_data():
+     """
+     Same data loader as in vlaholo.
+     """
+     dataset_repo_id = "data/qz_zz/0801_task9/pick"
+     dataset = LeRobotDataset(dataset_repo_id, episodes=[10], video_backend="pyav")
+
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         num_workers=0,
+         batch_size=1,
+     )
+
+     batch = next(batch for i, batch in enumerate(dataloader) if i == 29)
+     return dataset, batch
+
+
+ def tensor_as_np_image(t):
+     # [B, C, H, W] tensor -> [1, H, W, C] numpy array
+     return t[0].permute(1, 2, 0).cpu().unsqueeze(0).numpy()
+
+
+ if __name__ == "__main__":
+
+     # vla_model = WebsocketClientPolicy(host="172.16.0.171", port=9001)
+     vla_model = WebsocketClientPolicy(host="172.16.0.111", port=9001)
+
+     ds, batch_data = get_dummy_data()
+     print(batch_data.keys())
+     t0 = time.time()
+     benchmark_iters = 30
+     for _ in range(benchmark_iters):
+         # print(batch)
+         t00 = time.time()
+
+         # HWC numpy arrays with pixel values in 0-1
+         image_cam_high = tensor_as_np_image(batch_data["observation.images.cam_high"])
+         image_cam_left = tensor_as_np_image(batch_data["observation.images.cam_left_wrist"])
+         image_cam_right = tensor_as_np_image(batch_data["observation.images.cam_right_wrist"])
+         # [1, H, W, 3], pixel values in 0-1
+         print(f"image_cam_high: {image_cam_high.shape}")
+         # obs format
+         obs = {
+             "images": {
+                 "cam_high": image_cam_high,
+                 "cam_left_wrist": image_cam_left,
+                 "cam_right_wrist": image_cam_right,
+             },
+             # state: [1, 14]
+             "state": batch_data["observation.state"].cpu().numpy(),
+             # language instruction string
+             "prompt": batch_data["task"],
+             # for verbose output
+             "debug": True,
+         }
+         action = vla_model.infer(obs=obs)
+
+         print(
+             "##info, action:",
+             action,
+             time.time() - t00,
+         )
+         break
+     t1 = time.time()
+     print(f"cost: {t1-t0:.3f}, avg: {(t1-t0)/benchmark_iters}")
tests/alter_lerobot_key.py ADDED
@@ -0,0 +1,53 @@
+ """
+ Rename a key in an existing LeRobot dataset.
+ """
+
+ import os
+ import shutil
+ from pathlib import Path
+ from datasets import load_dataset
+ from vlaholo.datasets.lerobot_dataset import LeRobotDataset
+
+ # Paths
+ # old_root = Path("data/qz/lerobot_data/airbot_datasets/airbot_data_0724/pick")
+ old_root = Path("data/sl/0721pre_data_v3")
+ old_root = Path("/pfs/data/yangcheng/.cache/huggingface/lerobot/pick")
+ new_root = old_root / 'new'
+
+ # 1. Load the existing dataset using LeRobotDataset
+ old_ds = LeRobotDataset(repo_id=old_root)
+ print("done")
+
+ # 2. Create a fresh LeRobotDataset at the new location with the same metadata
+ new_ds = LeRobotDataset.create(
+     repo_id="converted_dataset",
+     fps=old_ds.meta.info["fps"],
+     root=new_root,
+     features=old_ds.meta.info["features"],
+     use_videos=old_ds.meta.info.get("video", True),
+ )
+ # 3. Copy auxiliary folders (videos, tasks, stats, info)
+ for folder in ["videos", "meta"]:
+     src = old_root / folder
+     dst = new_root / folder
+     if src.exists():
+         shutil.copytree(src, dst, dirs_exist_ok=True)
+
+ # 4. Iterate through episodes and rewrite parquet with the renamed key
+ for ep_idx in range(old_ds.meta.total_episodes):
+     # Load episode data via Hugging Face datasets
+     data_path = old_root / old_ds.meta.get_data_file_path(ep_idx)
+     ds = load_dataset(
+         "parquet",
+         data_files=[str(data_path)],
+         split="train",
+     )
+     # Rename the column
+     ds = ds.rename_column("state", "observation.state")
+
+     # Write back via LeRobotDataset's internal layout
+     out_path = new_root / new_ds.meta.get_data_file_path(ep_idx)
+     ds.to_parquet(str(out_path))
+
+ print(f"✅ New dataset with renamed key saved to {new_root}")
tests/async_client.py ADDED
@@ -0,0 +1,602 @@
+ import time
+ import threading
+ import json_numpy
+ import numpy as np
+ import requests
+ import queue
+ import cv2
+ import pickle, os
+ from typing import Any, cast
+
+ json_numpy.patch()
+
+ from airbot_py.arm import AIRBOTPlay, RobotMode, SpeedProfile
+
+ # Import RTCController from separate module
+ from rtc_controller import RTCController
+ from ruckig_planner import RuckigPlanner
+
+
+ try:
+     import pyrealsense2 as rs
+ except ImportError:
+     print("Warning: pyrealsense2 not available. Camera functionality will be limited.")
+     rs = None
+
+ # Help static analyzers: treat rs as dynamic Any when available
+ if rs is not None:
+     rs = cast(Any, rs)
+
+
+ class CameraWrapper:
+     def __init__(self, devices=None, width=640, height=480, fps=30, num_realsense=0, cv_format="MJPEG"):
+         self.width = width
+         self.height = height
+         self.fps = fps
+         self.num_realsense = max(0, int(num_realsense))
+         self.cv_format = cv_format
+         self.cameras = []  # list of dicts: {type: 'rs'|'cv', handle: pipeline|cap}
+         self.device_ids = devices if devices is not None else []
+         self._open_cameras()
+         print(f'successfully opened {len(self.cameras)} cameras!')
+
+     def _open_cameras(self):
+         if not self.device_ids:
+             print("No devices provided for CameraWrapper")
+             return
+
+         for idx, dev in enumerate(self.device_ids):
+             # Decide camera type
+             use_realsense = (idx < self.num_realsense)
+
+             if use_realsense:
+                 if rs is None:
+                     print(f"pyrealsense2 not available, skipping RealSense device at index {idx} (id: {dev})")
+                     continue
+                 try:
+                     serial = str(dev)
+                     pipeline = rs.pipeline()  # type: ignore[attr-defined]
+                     config = rs.config()  # type: ignore[attr-defined]
+                     config.enable_device(serial)
+                     config.enable_stream(rs.stream.color, self.width, self.height, rs.format.bgr8, self.fps)  # type: ignore[attr-defined]
+                     pipeline.start(config)
+                     self.cameras.append({"type": "rs", "handle": pipeline})
+                     print(f"RealSense camera {serial} opened successfully")
+                 except Exception as e:
+                     print(f"Failed to open RealSense camera {dev}: {e}")
+             else:
+                 try:
+                     device_index = int(dev)
+                     cap = cv2.VideoCapture(device_index)
+
+                     if self.cv_format == "MJPEG":
+                         cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"MJPG"))  # type: ignore[attr-defined]
+                     elif self.cv_format == "YUYV":
+                         cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV"))  # type: ignore[attr-defined]
+
+                     cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
+                     cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
+                     cap.set(cv2.CAP_PROP_FPS, self.fps)
+
+                     if not cap.isOpened():
+                         print(f"Error: Cannot open OpenCV camera {device_index}")
+                         continue
+
+                     self.cameras.append({"type": "cv", "handle": cap})
+                     print(f"OpenCV camera {device_index} opened successfully")
+                 except Exception as e:
+                     print(f"Failed to open OpenCV camera {dev}: {e}")
+
+     def get_images(self):
+         images = []
+         if len(self.cameras) == 0:
+             # Return dummy images if no cameras available - use 640x480 which is expected by the model
+             for _ in range(max(1, len(self.device_ids))):
+                 dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                 dummy_img[:, :, :] = 128  # Gray color instead of black
+                 images.append(dummy_img)
+             return images
+
+         for cam in self.cameras:
+             if cam["type"] == "rs":
+                 try:
+                     pipeline = cam["handle"]
+                     frames = pipeline.wait_for_frames()
+                     color_frame = frames.get_color_frame()
+                     if not color_frame:
+                         dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                         dummy_img[:, :, :] = 128
+                         images.append(dummy_img)
+                     else:
+                         img = np.asanyarray(color_frame.get_data())
+                         images.append(img)
+                 except Exception as e:
+                     print(f"Error reading from RealSense: {e}")
+                     dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                     dummy_img[:, :, :] = 128
+                     images.append(dummy_img)
+             elif cam["type"] == "cv":
+                 cap = cam["handle"]
+                 ret, frame = cap.read()
+                 if not ret or frame is None:
+                     dummy_img = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+                     dummy_img[:, :, :] = 128
+                     images.append(dummy_img)
+                 else:
+                     images.append(frame)
+         return images
+
+     def release(self):
+         for cam in self.cameras:
+             if cam["type"] == "rs":
+                 try:
+                     cam["handle"].stop()
+                 except Exception:
+                     pass
+             elif cam["type"] == "cv":
+                 try:
+                     cam["handle"].release()
+                 except Exception:
+                     pass
+         self.cameras = []
+
+
+ def normalization(state):
+     """Normalize robot state for model input.
+
+     - Threshold gripper: > 0.04 -> 1.0 else -1.0
+     - Keep the 6 joint values unchanged
+     Returns np.ndarray of shape (7,)
+     """
+     arr = np.array(state, dtype=np.float32).copy()
+     arr[6] = 1.0 if arr[6] > 0.04 else -1.0
+     return arr
+
+
+ def unnormalization(action):
+     """Unnormalize model action to robot command space.
+
+     - Map gripper: > 0 -> 0.06850814 (open), else -> 0.025 (close)
+     - Keep the 6 joint values unchanged
+     Returns np.ndarray of shape (7,)
+     """
+     arr = np.array(action, dtype=np.float32).copy()
+     arr[6] = 0.06850814 if arr[6] > 0 else 0.025
+     return arr
+
+
+ class RobotWrapper:
+     # Ports dict example: {'left_arm': 50051, 'right_arm': None}
+     # Only arms with a non-None port will be initialized
+     def __init__(self, url='localhost', ports=None, arm_speed='slow', type='move'):
+         if ports is None:
+             ports = {'left_arm': None, 'right_arm': None}
+         assert any(p is not None for p in ports.values()), "at least one arm port is required"
+         assert arm_speed in ['slow', 'default', 'fast'], "arm_speed must be in ['slow', 'default', 'fast']"
+         assert type in ['move', 'servo'], "type must be in ['move', 'servo']"
+         self.type = type
+
+         self.robots = {}
+
+         for arm_name in ['left_arm', 'right_arm']:
+             port = ports.get(arm_name)
+             if port is None:
+                 continue
+
+             robot = AIRBOTPlay(url=url, port=port)
+             robot.connect()
+             robot.set_speed_profile(SpeedProfile.SLOW if arm_speed == 'slow' else SpeedProfile.FAST)
+
+             product_info = robot.get_product_info()
+             print("---------------------------------------------------------")
+             print(f"Arm: {arm_name}")
+             print(f"Product name: {product_info['product_type']}")
+             print(f"Serial number: {product_info['sn']}")
+             print(f"Simulation mode: {product_info['is_sim']}")
+             print(f"Using interfaces: {product_info['interfaces']}")
+             print(f"Installed end effectors: {product_info['eef_types']}")
+             print(f"Firmware versions: {product_info['fw_versions']}")
+             print("---------------------------------------------------------")
+
+             # Default initial joints per arm (optional; can be customized)
+             if arm_name == 'left_arm':
+                 joints = [0.0, 0.0, 0.15, -1.7, 0.1, 1.7]
+             else:
+                 joints = [0.0, 0.0, 0.15, 1.7, -0.1, -1.7]
+
+             # use move mode to move to the initial pose
+             robot.switch_mode(RobotMode.PLANNING_POS)
+             # Overwrite with a safe neutral pose
+             joints = [0.00019073777366429567, 0.17948424816131592, 0.027656977996230125, 1.4654383659362793, -0.3435187339782715, -1.4288166761398315]
+             # joints = [0.03604944050312042, 0.17948424816131592, 0.029564354568719864, 1.6039139032363892, -0.3419928252696991, -1.5939955711364746]
+             robot.move_to_joint_pos(joints)
+             robot.move_eef_pos([0.06850814])
+             print(f'arm: {arm_name}, joints: {joints}')
+
+             if self.type == 'move':
+                 robot.switch_mode(RobotMode.PLANNING_POS)
+             elif self.type == 'servo':
+                 robot.switch_mode(RobotMode.SERVO_JOINT_POS)
+
+             init_joint_pos = robot.get_joint_pos()
+             init_eef_pos = robot.get_eef_pos()
+             print(f'[{arm_name}] init_joint_pos: {init_joint_pos}, init_eef_pos: {init_eef_pos}')
+             self.robots[arm_name] = robot
+             print(f"robot arm {arm_name} (port: {port}) init success!")
+         time.sleep(2)
+
+     def move_to_pos(self, pos, arm='right_arm'):
+         assert arm in self.robots, f"arm '{arm}' not initialized"
+         if self.type == 'move':
+             self.robots[arm].move_to_joint_pos(pos[:6], blocking=True)
+             self.robots[arm].move_eef_pos([pos[6]], blocking=True)
+         elif self.type == 'servo':
+             self.robots[arm].servo_joint_pos(pos[:6])
+             self.robots[arm].servo_eef_pos([pos[6]])
+
+     def get_joint_pos(self, arm='right_arm'):
+         assert arm in self.robots, f"arm '{arm}' not initialized"
+         return self.robots[arm].get_joint_pos()
+
+     def get_eef_pos(self, arm='right_arm'):
+         assert arm in self.robots, f"arm '{arm}' not initialized"
+         return self.robots[arm].get_eef_pos()
+
+     def get_state_pos(self, arm='right_arm'):
+         assert arm in self.robots, f"arm '{arm}' not initialized"
+         pos = self.robots[arm].get_joint_pos()
+         eef_pos = self.robots[arm].get_eef_pos()
+         result = pos + eef_pos
+         return result
+
+
+ class ActionSmoother:
+     def __init__(self, method='exponential', alpha=0.3, window_size=5, smooth_dims=None):
+         self.method = method
+         self.window_size = window_size
+         self.smooth_dims = smooth_dims
+
+         if isinstance(alpha, (list, tuple, np.ndarray)):
+             self.alpha = np.array(alpha, dtype=np.float32)
+         else:
+             self.alpha = alpha
+
+         self.history = []
+         self.smoothed_action = None
+
+     def smooth_action(self, raw_action):
+         raw_action = np.array(raw_action, dtype=np.float32)
+
+         self.history.append(raw_action.copy())
+         if len(self.history) > self.window_size:
+             self.history = self.history[-self.window_size:]
+
+         if self.smoothed_action is None:
+             self.smoothed_action = raw_action.copy()
+             return self.smoothed_action
+
+         result_action = raw_action.copy()
+
+         if self.smooth_dims is None:
+             dims_to_smooth = list(range(len(raw_action)))
+         else:
+             dims_to_smooth = [d for d in self.smooth_dims if d < len(raw_action)]
+
+         if self.method == 'exponential':
+             if isinstance(self.alpha, np.ndarray):
+                 for dim in dims_to_smooth:
+                     alpha = self.alpha[dim] if dim < len(self.alpha) else self.alpha[-1]
+                     result_action[dim] = alpha * raw_action[dim] + (1 - alpha) * self.smoothed_action[dim]
+             else:
+                 for dim in dims_to_smooth:
+                     result_action[dim] = self.alpha * raw_action[dim] + (1 - self.alpha) * self.smoothed_action[dim]
+
+         elif self.method == 'moving_average':
+             history_array = np.array(self.history)
+             for dim in dims_to_smooth:
+                 result_action[dim] = np.mean(history_array[:, dim])
+
+         elif self.method == 'linear_interpolation':
+             for dim in dims_to_smooth:
+                 result_action[dim] = 0.7 * raw_action[dim] + 0.3 * self.smoothed_action[dim]
+
+         elif self.method == 'identity':
+             return result_action
+
+         elif self.method == 'average':
+             # Same behavior as 'moving_average': mean over the history window
+             history_array = np.array(self.history)
+             for dim in dims_to_smooth:
+                 result_action[dim] = np.mean(history_array[:, dim])
+
+         else:
+             raise ValueError(f"Unknown smoothing method: {self.method}")
+
+         self.smoothed_action = result_action.copy()
+
+         return result_action
+
+     def reset(self):
+         self.history = []
+         self.smoothed_action = None
+
+
+ class VLAClient:
+     def __init__(self, server_url):
+         self.server_url = server_url
+
+     def predict(self, obs):
+         try:
+             response = requests.post(
+                 self.server_url,
+                 json={"observation": obs},
+             )
+             action_chunk = response.json()
+
+             actions = []
+             for arm, gripper in zip(action_chunk['action.right_arm'], action_chunk['action.right_gripper']):
+                 action = np.asarray(list(arm) + [float(gripper)], dtype=np.float32)
+                 actions.append(action)
+
+             return np.array(actions)
+         except Exception as e:
+             print(f"VLA prediction error: {e}")
+             return None
+
+
+ def predict_actions(server_url, obs):
+     response = requests.post(
+         server_url,
+         json={"observation": obs},
+     )
+     return response.json()
+
+
+ if __name__ == "__main__":
+
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--use_rtc", action='store_true', default=False)
+     parser.add_argument("--fast", action='store_true', default=False)
+     args = parser.parse_args()
+
+     USE_RTC_CONTROLLER = args.use_rtc
+     print(f"Use RTC Controller: {USE_RTC_CONTROLLER}")
+
+     # Later assignments override the earlier ones; the last description is used.
+     task_description = "Pick up each block one by one and place them all into the bowl."
+     task_description = 'Pick up each block one by one and place them all into the blue bowl.'
+     task_description = 'Pick up each block one by one and place them all into the right bowl.'
+
+     print("Robot System Init...")
+     robots = RobotWrapper(
+         url='localhost',
+         ports={'left_arm': None, 'right_arm': 50053},
+         arm_speed='default' if args.fast else 'slow',
+         type='servo'
+     )
+
+     if args.fast:
+         target_queue = queue.Queue(maxsize=30)
+         servo_queue = queue.Queue(maxsize=30)
+         rp = RuckigPlanner(robots.robots['right_arm'], robots.get_state_pos('right_arm'), DoFs=7, dt=0.02)
+         rp.start(servo_queue, target_queue)
+
+     # server_url = "http://106.13.248.32:10090/act"
+     server_url = "http://127.0.0.1:10090/act"
+     # server_url = "http://114.111.24.161:10090/act"
+     print(f"VLA Server URL: {server_url}")
+
+     # Camera System Init
+     print("Camera System Init...")
+     caps = CameraWrapper(
+         devices=["215322074711", "242622070332", 0],
+         num_realsense=2,
+         width=640,
+         height=480,
+         fps=30,
+         cv_format="MJPEG"
+     )
+
+     time.sleep(2)
+
+     print("Smoother System Init for Grippers...")
+     action_smoother = ActionSmoother(
+         method='average',
+         alpha=0.1,
+         window_size=10,
+         smooth_dims=[6]
+     )
+     # action_smoother = None
+
+     print("Loop Start!")
+     step_count = 0
+
+     if USE_RTC_CONTROLLER:
+         print("Using Async (RTCController) Mode...")
+
+         vla_client = VLAClient(server_url)
+
+         def get_observation():
+             # Read a few frames back-to-back (likely to flush stale buffered images)
+             images = caps.get_images()
+             images = caps.get_images()
+             images = caps.get_images()
+
+             if len(images) >= 3:
+                 img_right, img_front, img_env = images[:3]
+             else:
+                 filler = np.zeros((480, 640, 3), dtype=np.uint8)
+                 filler[:, :, :] = 128
+                 imgs = images + [filler] * (3 - len(images))
+                 img_right, img_front, img_env = imgs[:3]
+
+             arm_state = normalization(robots.get_state_pos(arm='right_arm'))
+
+             obs = {
+                 "video.cam_head": img_front[np.newaxis, ::],
+                 "video.cam_env": img_env[np.newaxis, ::],
+                 "video.cam_right_wrist": img_right[np.newaxis, ::],
+                 "state.right_arm": np.expand_dims(np.array(arm_state[:6], dtype=np.float32), axis=0),
+                 "state.right_gripper": np.expand_dims(np.array([arm_state[6]], dtype=np.float32), axis=0),
+                 "annotation.human.task_description": [task_description]
+             }
+
+             return obs
+
+         # RTCController settings
+         H = 16
+         d = 8
+         s = 5
+
+         print("RTCController System Init...")
+         rtc_controller = RTCController(
+             vla_client=vla_client,
+             observation_fn=get_observation,
+             H=H,
+             d=d,
+             s=s,
+         )
+
+         print("predict action chunk...")
+         rtc_controller.reset()
+
+         if not rtc_controller.is_ready():
+             print("RTCController Configuration Failed!")
+             exit(1)
+
+         print("RTCController Configuration Completed")
+         print(f"RTC Params: H={H}, d={d}, s={s}")
+
+         while True:
+             try:
+                 loop_start_time = time.time()
+                 step_count += 1
+
+                 action = rtc_controller.step()
+
+                 if action is None:
+                     print("RTCController is not ready...")
+                     time.sleep(0.05)
+                     continue
+
+                 denormalized_action = unnormalization(action)
+                 denormalized_action = action_smoother.smooth_action(denormalized_action) if action_smoother else denormalized_action
+
+                 if args.fast:
+                     target_queue.put(denormalized_action.tolist())
+                 else:
+                     robots.move_to_pos(denormalized_action.tolist(), arm='right_arm')
+
+                 # if step_count % 10 == 0:
+                 #     print(f"[RTC] Step {step_count}: executor_index={rtc_controller.executor_index}")
+                 #     print(f"Action: {action}")
+                 #     current_state = robots.get_state_pos(arm='right_arm')
+                 #     print(f"Robot state: {current_state}")
+
+                 # Target control period: 10 Hz in fast mode, 20 Hz otherwise
+                 if args.fast:
+                     target_period = 0.1
+                 else:
+                     target_period = 0.05
+                 elapsed = time.time() - loop_start_time
+                 sleep_time = max(0.0, target_period - elapsed)
+
+                 if sleep_time > 0:
+                     time.sleep(sleep_time)
+                 else:
+                     print(f"Loop control time cost: {elapsed:.3f}s > {target_period:.3f}s")
+
+             except KeyboardInterrupt:
+                 print("Received interrupt signal, exiting safely...")
+                 break
+             except Exception as e:
+                 print(f"ASYNC control mode error: {e}")
+                 time.sleep(0.05)
+
+     else:
+         print("Using Sync Mode...")
+
+         current_actions = []
+         current_action_idx = 0
+         chunk_size = 16
+
+         while True:
+             try:
+                 loop_start_time = time.time()
+                 step_count += 1
+
+                 # If all actions in the current chunk have been executed, request a new chunk
+                 if current_action_idx >= len(current_actions):
+                     print("[Sync Mode] Obtaining new action chunk...")
+
+                     # images: read a few frames back-to-back (likely to flush stale buffered images)
+                     images = caps.get_images()
+                     images = caps.get_images()
+                     images = caps.get_images()
+
+                     if len(images) >= 3:
+                         img_right, img_front, img_env = images[:3]
+                     else:
+                         filler = np.zeros((480, 640, 3), dtype=np.uint8)
+                         filler[:, :, :] = 128
+                         imgs = images + [filler] * (3 - len(images))
+                         img_right, img_front, img_env = imgs[:3]
+
+                     # state
+                     arm_state = normalization(robots.get_state_pos(arm='right_arm'))
+
+                     # obs
+                     obs = {
+                         "video.cam_head": img_front[np.newaxis, ::],
+                         "video.cam_env": img_env[np.newaxis, ::],
+                         "video.cam_right_wrist": img_right[np.newaxis, ::],
+                         "state.right_arm": np.expand_dims(np.array(arm_state[:6], dtype=np.float32), axis=0),
+                         "state.right_gripper": np.expand_dims(np.array([arm_state[6]], dtype=np.float32), axis=0),
+                         "annotation.human.task_description": [task_description]
+                     }
+
+                     start_time = time.time()
+                     action_chunk = predict_actions(server_url, obs)
+                     print(f"inference time cost: {time.time() - start_time:.3f}s")
+
+                     current_actions = []
+                     current_action_idx = 0
+                     for arm, gripper in zip(action_chunk['action.right_arm'], action_chunk['action.right_gripper']):
+                         action = np.asarray(list(arm) + [float(gripper)], dtype=np.float32)
+                         current_actions.append(action)
+
+                     print(f"Obtained {len(current_actions)} action steps...")
+
+                 if current_action_idx < len(current_actions):
+                     raw_action = current_actions[current_action_idx]
+                     current_action_idx += 1
+
+                     smoothed_action = action_smoother.smooth_action(raw_action)
+
+                     denormalized_action = unnormalization(smoothed_action)
+                     robots.move_to_pos(denormalized_action.tolist(), arm='right_arm')
+
+                     if step_count % 10 == 0:
+                         print(f"[Sync Mode] Step {step_count}: action_idx={current_action_idx}/{len(current_actions)}")
+                         print(f"  Raw: {raw_action}")
+                         print(f"  Smoothed: {smoothed_action}")
+
+                 # ~20 Hz control
+                 time.sleep(0.05)
+
+             except KeyboardInterrupt:
+                 print("Received interrupt signal, exiting safely...")
+                 break
+             except Exception as e:
+                 print(f"SYNC control mode error: {e}")
+                 time.sleep(0.05)
+
+     print("Clearing resources...")
+     try:
+         caps.release()
+         print("Camera resources released")
+     except Exception as e:
+         print(f"Camera clearing error: {e}")
+
+     print("Program ended")
tests/install_av_opencv.sh ADDED
@@ -0,0 +1,5 @@
+
+ # make sure ffmpeg has the AV codecs available
+ sudo apt install libavcodec-dev libavformat-dev libavutil-dev libswscale-dev ffmpeg
+ export CMAKE_ARGS="-D WITH_FFMPEG=ON"
+ pip install --no-binary opencv-python --no-deps opencv-python -v
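
A quick sanity check after reinstalling (a sketch; the exact formatting of the build info varies across OpenCV versions) to confirm the rebuilt wheel actually links FFmpeg:

```python
import cv2

# Print every build-info line mentioning FFMPEG; "YES" indicates the codecs are linked.
info = cv2.getBuildInformation()
print([line.strip() for line in info.splitlines() if "FFMPEG" in line])
```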
tests/modality.json ADDED
@@ -0,0 +1,54 @@
+ {
+     "state": {
+         "left_arm": {
+             "start": 0,
+             "end": 6
+         },
+         "left_gripper": {
+             "start": 6,
+             "end": 7
+         },
+         "right_arm": {
+             "start": 7,
+             "end": 13
+         },
+         "right_gripper": {
+             "start": 13,
+             "end": 14
+         }
+     },
+     "action": {
+         "left_arm": {
+             "start": 0,
+             "end": 6
+         },
+         "left_gripper": {
+             "start": 6,
+             "end": 7
+         },
+         "right_arm": {
+             "start": 7,
+             "end": 13
+         },
+         "right_gripper": {
+             "start": 13,
+             "end": 14
+         }
+     },
+     "video": {
+         "cam_high": {
+             "original_key": "observation.images.cam_high"
+         },
+         "cam_left_wrist": {
+             "original_key": "observation.images.cam_left_wrist"
+         },
+         "cam_right_wrist": {
+             "original_key": "observation.images.cam_right_wrist"
+         }
+     },
+     "annotation": {
+         "human.action.task_description": {
+             "original_key": "task_index"
+         }
+     }
+ }
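
The `state` and `action` blocks above partition a flat 14-dimensional vector into per-arm and per-gripper slices. A small illustrative sketch of slicing with this config (the file path and loading code are assumptions, not a documented loader API):

```python
import json
import numpy as np

with open("tests/modality.json") as f:  # hypothetical relative path
    modality = json.load(f)

state = np.arange(14, dtype=np.float32)  # placeholder 14-dim state vector
parts = {name: state[cfg["start"]:cfg["end"]] for name, cfg in modality["state"].items()}
print({k: v.shape for k, v in parts.items()})
# -> {'left_arm': (6,), 'left_gripper': (1,), 'right_arm': (6,), 'right_gripper': (1,)}
```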
tests/replay_sl.py ADDED
@@ -0,0 +1,628 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ '''
+ Load a LeRobot dataset and replay it on a REAL ARM.
+
+ Test for controlling the Agilex arm.
+ '''
+
+ import os
+ import sys
+ import time
+ import argparse
+ import threading
+ import json
+ import logging
+ import h5py
+ import pandas as pd
+ import numpy as np
+ from pathlib import Path
+ from typing import Dict, Any, List, Optional
+
+ from piper_sdk import C_PiperInterface_V2
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler('arm_replay.log'),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ class ArmDataReplayer:
+     """Arm data replayer"""
+
+     def __init__(self, arms_config: Dict[str, str]):
+         """
+         Initialize the replayer.
+
+         Args:
+             arms_config: arm configuration, e.g. {'left': 'can_left', 'right': 'can_right'}
+         """
+         self.arms_config = arms_config
+
+         # Arm interfaces
+         self.arms = {}
+         self.is_connected = False
+
+         # Replay data
+         self.replay_data = None
+         self.replay_metadata = None
+
+         # Replay control
+         self.is_replaying = False
+         self.replay_thread = None
+         self.replay_speed = 1.0
+         self.current_frame = 0
+         self.total_frames = 0
+
+         # Replay statistics
+         self.frames_replayed = 0
+         self.start_time = None
+
+         logger.info("Data replayer initialized")
+
+     def connect_arms(self) -> bool:
+         """Connect to the arms"""
+         try:
+             for arm_name, can_name in self.arms_config.items():
+                 logger.info(f"Connecting {arm_name} arm ({can_name})...")
+
+                 arm = C_PiperInterface_V2(
+                     can_name=can_name,
+                     judge_flag=False,
+                     can_auto_init=True,
+                     dh_is_offset=1,
+                     start_sdk_joint_limit=False,
+                     start_sdk_gripper_limit=False
+                 )
+
+                 arm.ConnectPort()
+                 self.arms[arm_name] = arm
+
+                 # Wait for the connection to stabilize
+                 time.sleep(1.0)
+
+                 arm.EnableArm()
+                 time.sleep(0.5)
+
+                 # Verify the connection
+                 test_msg = arm.GetArmJointMsgs()
+                 if test_msg and test_msg.Hz > 0:
+                     logger.info(f"✅ {arm_name} arm connected")
+                 else:
+                     logger.warning(f"⚠️ {arm_name} arm connection may be unstable")
+
+             self.is_connected = True
+             logger.info("All arms connected")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to connect arms: {e}")
+             self.disconnect_arms()
+             return False
+
+     def disconnect_arms(self):
+         """Disconnect the arms"""
+         try:
+             for arm_name, arm in self.arms.items():
+                 if arm:
+                     arm.DisconnectPort()
+                     logger.info(f"{arm_name} arm disconnected")
+
+             self.arms.clear()
+             self.is_connected = False
+             logger.info("All arms disconnected")
+
+         except Exception as e:
+             logger.error(f"Failed to disconnect: {e}")
+
+     def load_hdf5_data(self, filename: str) -> bool:
+         """Load HDF5-format data"""
+         try:
+             with h5py.File(filename, 'r') as f:
+                 # Read metadata
+                 self.replay_metadata = {
+                     'collection_frequency': f['metadata'].attrs.get('collection_frequency', 30),
+                     'total_samples': f['metadata'].attrs.get('total_samples', 0),
+                     'arms': f['metadata'].attrs.get('arms', []),
+                     'start_time': f['metadata'].attrs.get('start_time', 0),
+                     'creation_time': f['metadata'].attrs.get('creation_time', '')
+                 }
+
+                 # Read time data
+                 timestamps = f['timestamps'][:]
+                 relative_times = f['relative_times'][:]
+
+                 # Read arm data
+                 arms_data = {}
+                 for arm_name in self.replay_metadata['arms']:
+                     if f'arms/{arm_name}' in f:
+                         arm_group = f[f'arms/{arm_name}']
+                         arms_data[arm_name] = {
+                             'joint_positions': arm_group['joint_positions'][:],
+                             'gripper_angles': arm_group['gripper_angles'][:],
+                             'gripper_efforts': arm_group['gripper_efforts'][:]
+                         }
+
+                         # Compatibility: check whether a gripper_positions field exists
+                         if 'gripper_positions' in arm_group:
+                             arms_data[arm_name]['gripper_positions'] = arm_group['gripper_positions'][:]
+                         else:
+                             # If no position data, fall back to zeros
+                             arms_data[arm_name]['gripper_positions'] = np.zeros_like(arms_data[arm_name]['gripper_angles'])
+
+                 # Rebuild the data in a frame-oriented layout
+                 self.replay_data = []
+                 for i in range(len(timestamps)):
+                     frame = {
+                         'timestamp': timestamps[i],
+                         'relative_time': relative_times[i],
+                         'frame_index': i,
+                         'arms': {}
+                     }
+
+                     for arm_name, data in arms_data.items():
+                         if i < len(data['joint_positions']):
+                             frame['arms'][arm_name] = {
+                                 'joint_positions': data['joint_positions'][i],
+                                 'gripper_angle': data['gripper_angles'][i],
+                                 'gripper_position': data['gripper_positions'][i],
+                                 'gripper_effort': data['gripper_efforts'][i]
+                             }
+
+                     self.replay_data.append(frame)
+
+             self.total_frames = len(self.replay_data)
+             logger.info(f"HDF5 data loaded: {self.total_frames} frames")
+             logger.info(f"Arms included: {self.replay_metadata['arms']}")
+             logger.info(f"Original collection frequency: {self.replay_metadata['collection_frequency']}Hz")
+
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to load HDF5 data: {e}")
+             return False
+
+     def load_lerobot_data(self, data_dir: str, episode_idx: int = 0) -> bool:
+         """
+         Load a LeRobot-format dataset directory.
+
+         Args:
+             data_dir: root directory of the LeRobot dataset.
+             episode_idx: index of the episode to load.
+         """
+         try:
+             logger.info(f"Loading LeRobot dataset: {data_dir}, Episode: {episode_idx}")
+             root_path = Path(data_dir)
+
+             info_path = root_path / 'meta' / 'info.json'
+             data_path = root_path / 'data'
+
+             if not (root_path.is_dir() and info_path.exists() and data_path.exists()):
+                 logger.error("Invalid LeRobot dataset directory structure: missing meta/info.json or data directory.")
+                 return False
+
+             with open(info_path, 'r', encoding='utf-8') as f:
+                 info = json.load(f)
+
+             self.replay_metadata = {
+                 'collection_frequency': info.get('fps', 30),
+                 'arms': ['left', 'right'],
+                 'lerobot_info': info
+             }
+             logger.info(f"Dataset metadata loaded. Collection frequency: {self.replay_metadata['collection_frequency']}Hz")
+
+             # Locate the corresponding parquet file
+             # Assumes the chunk is always 000
+             parquet_file = data_path / 'chunk-000' / f'episode_{episode_idx:06d}.parquet'
+             if not parquet_file.exists():
+                 logger.error(f"Data file for episode {episode_idx} not found: {parquet_file}")
+                 return False
+
+             logger.info(f"Reading data file: {parquet_file}")
+             df = pd.read_parquet(parquet_file)
+
+             self.replay_data = []
+
+             # In LeRobot the action is the puppet arm's next state, so we use it as the target
+             # names is a nested list
+             action_names = info['features']['action']['names'][0]
+
+             # Build index mappings
+             joint_indices = {'left': [None]*6, 'right': [None]*6}
+             gripper_indices = {'left': -1, 'right': -1}
+
+             for i, name in enumerate(action_names):
+                 if 'masterLeft' in name:
+                     if 'joint6' in name:
+                         gripper_indices['left'] = i
+                     else:
+                         # Extract the digit 0 from '...joint0'
+                         joint_num = int(name.split('joint')[-1])
+                         if 0 <= joint_num < 6:
+                             joint_indices['left'][joint_num] = i
+                 elif 'masterRight' in name:
+                     if 'joint6' in name:
+                         gripper_indices['right'] = i
+                     else:
+                         joint_num = int(name.split('joint')[-1])
+                         if 0 <= joint_num < 6:
+                             joint_indices['right'][joint_num] = i
+
+             logger.info(f"Parsed left-arm joint indices: {joint_indices['left']}")
+             logger.info(f"Parsed left-arm gripper index: {gripper_indices['left']}")
+             logger.info(f"Parsed right-arm joint indices: {joint_indices['right']}")
+             logger.info(f"Parsed right-arm gripper index: {gripper_indices['right']}")
+
+             for i, row in df.iterrows():
+                 frame = {
+                     'timestamp': row.get('timestamp', time.time()),
+                     'relative_time': row.get('timestamp', 0) - df['timestamp'].iloc[0] if 'timestamp' in df else i / self.replay_metadata['collection_frequency'],
+                     'frame_index': i,
+                     'arms': {}
+                 }
+
+                 action_values = row['action']
+
+                 for arm_name in ['left', 'right']:
+                     # The LeRobot action is usually the next state, so use it directly as the target
+                     joint_positions = np.array([action_values[j] for j in joint_indices[arm_name]])
+
+                     # Gripper values in LeRobot are usually in [-1, 1] and must be mapped to 0-60 degrees
+                     # For this aloha_arm dataset the gripper value is in [-1, 1]: -1 closed, 1 open
+                     gripper_action = action_values[gripper_indices[arm_name]]
+                     # Mapping: 1 -> 0 degrees (open), -1 -> 60 degrees (closed)
+                     gripper_angle = (1 - gripper_action) / 2 * 60
+
+                     frame['arms'][arm_name] = {
+                         'joint_positions': joint_positions,
+                         'gripper_angle': gripper_angle,
+                     }
+
+                 self.replay_data.append(frame)
+
+             self.total_frames = len(self.replay_data)
+             logger.info(f"LeRobot data loaded: {self.total_frames} frames")
+             logger.info(f"Arms included: {list(self.replay_data[0]['arms'].keys())}")
+
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to load LeRobot data: {e}", exc_info=True)
+             return False
+
+     def load_data(self, path: str, episode_idx: int = 0) -> bool:
+         """Auto-detect the data format and load it"""
+         filepath = Path(path)
+         if not filepath.exists():
+             logger.error(f"Data file or directory does not exist: {path}")
+             return False
+
+         if filepath.is_dir():
+             # Treat as a LeRobot dataset directory
+             return self.load_lerobot_data(path, episode_idx=episode_idx)
+         elif filepath.suffix.lower() == '.h5' or filepath.suffix.lower() == '.hdf5':
+             return self.load_hdf5_data(path)
+         # elif filepath.suffix.lower() == '.json':  # old logic, disabled for now
+         else:
+             logger.error(f"Unsupported file format or path type: {path}")
+             return False
+
+     def back_to_zero_position(self) -> bool:
+
+         for k, piper in self.arms.items():
+             logger.info(f'==> processing {k}')
+             # piper = self.arms.get('right', None)
+             # piper.JointConfig(joint_num=7, set_zero=0xAE)
+             # piper.GripperCtrl(set_zero=0xAE)
+
+             piper.JointCtrl(
+                 joint_1=0,  # 0 degrees
+                 joint_2=0,  # 0 degrees
+                 joint_3=0,  # 0 degrees
+                 joint_4=0,  # 0 degrees
+                 joint_5=0,  # 0 degrees
+                 joint_6=0   # 0 degrees
+             )
+
+             joint_msgs = piper.GetArmJointMsgs()
+             print(f"Joint state: {joint_msgs}")
+             joint1_angle = joint_msgs.joint_state.joint_1
+             print(f"Joint 1 angle: {joint1_angle/1000.0} degrees")
+             joint2_angle = joint_msgs.joint_state.joint_2
+             print(f"Joint 2 angle: {joint2_angle/1000.0} degrees")
+             joint3_angle = joint_msgs.joint_state.joint_3
+             print(f"Joint 3 angle: {joint3_angle/1000.0} degrees")
+             joint4_angle = joint_msgs.joint_state.joint_4
+             print(f"Joint 4 angle: {joint4_angle/1000.0} degrees")
+             joint5_angle = joint_msgs.joint_state.joint_5
+             print(f"Joint 5 angle: {joint5_angle/1000.0} degrees")
+             joint6_angle = joint_msgs.joint_state.joint_6
+             print(f"Joint 6 angle: {joint6_angle/1000.0} degrees")
+             joint7_angle = joint_msgs.joint_state.joint_7
+             print(f"Joint 7 angle: {joint7_angle/1000.0} degrees")
+
+             # Get gripper state
+             gripper_msgs = piper.GetArmGripperMsgs()
+             print(f"Gripper state: {gripper_msgs}")
+         return True
+
+     def move_to_start_position(self) -> bool:
+         """Move to the start position"""
+         if not self.is_connected or not self.replay_data:
+             logger.error("Arms not connected or data not loaded")
+             return False
+
+         try:
+             start_frame = self.replay_data[0]
+             logger.info("Moving to start position...")
+             print(start_frame)
+             self.back_to_zero_position()
+
+             for arm_name, arm in self.arms.items():
+                 if arm_name in start_frame['arms']:
+                     target_joint_positions = start_frame['arms'][arm_name]['joint_positions']
+                     target_gripper_angle = start_frame['arms'][arm_name]['gripper_angle']
+
+                     # Convert radians to 0.001-degree units
+                     target_joint_positions_deg = (target_joint_positions * 180.0 * 1000.0 / np.pi).astype(int)
+
+                     # Set control mode and speed
+                     arm.MotionCtrl_2(ctrl_mode=0x01, move_mode=0x01, move_spd_rate_ctrl=20)
+
+                     # Send joint control command
+                     arm.JointCtrl(
+                         int(target_joint_positions_deg[0]), int(target_joint_positions_deg[1]),
+                         int(target_joint_positions_deg[2]), int(target_joint_positions_deg[3]),
+                         int(target_joint_positions_deg[4]), int(target_joint_positions_deg[5])
+                     )
+                     # Gripper control
+                     arm.GripperCtrl(int(target_gripper_angle * 1000.0), 0, 0x01, 0)
+
+                     logger.info(f"{arm_name} arm target joint positions (0.001 deg): {target_joint_positions_deg.tolist()}")
+                     logger.info(f"{arm_name} arm target gripper position (deg): {target_gripper_angle:.2f}")
+
+             # Wait for the motion to finish
+             logger.info("Waiting for the arms to reach the start position...")
+             time.sleep(3.0)
+             logger.info("==> Move to start position complete")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to move to start position: {e}")
+             return False
+
+     def reset_to_zero(self):
+         """Reset to the zero position"""
+         for arm_name, arm in self.arms.items():
+             arm.GripperCtrl(0, 0, 0x01, 0)
+             arm.JointCtrl(0, 0, 0, 0, 0, 0)
+             time.sleep(0.5)
+             logger.info(f"{arm_name} arm reset to zero position")
+         return True
+
+     def start_replay(self, speed: float = 1.0, start_frame: int = 0, end_frame: Optional[int] = None):
+         """Start replay"""
+         if not self.is_connected or not self.replay_data:
+             logger.error("Arms not connected or data not loaded")
+             return False
+
+         if self.is_replaying:
+             logger.warning("Replay already in progress")
+             return False
+
+         self.replay_speed = speed
+         self.current_frame = start_frame
+         self.frames_replayed = 0
+         self.start_time = time.time()
+
+         if end_frame is None:
+             end_frame = self.total_frames
+
+         # Start the replay thread
+         self.is_replaying = True
+         self.replay_thread = threading.Thread(
+             target=self._replay_loop,
+             args=(start_frame, end_frame),
+             daemon=True
+         )
+         self.replay_thread.start()
+
+         logger.info("Replay started")
+         logger.info(f"Replay speed: {speed}x")
+         logger.info(f"Frame range: {start_frame} - {end_frame}")
+
+         return True
+
+     def stop_replay(self):
+         """Stop replay"""
+         if not self.is_replaying:
+             return
+
+         self.is_replaying = False
+
+         if self.replay_thread:
+             self.replay_thread.join(timeout=2.0)
+
+         logger.info("Replay stopped")
+         logger.info(f"Total frames replayed: {self.frames_replayed}")
+
+     def _replay_loop(self, start_frame: int, end_frame: int):
+         """Replay loop"""
+         original_frequency = self.replay_metadata.get('collection_frequency', 30)
+         base_dt = 1.0 / original_frequency
+         adjusted_dt = base_dt / self.replay_speed
+
+         logger.info(f"Original frequency: {original_frequency}Hz, adjusted interval: {adjusted_dt:.4f}s")
+
+         for frame_idx in range(start_frame, min(end_frame, self.total_frames)):
+             if not self.is_replaying:
+                 break
+
+             frame_start_time = time.time()
+
+             try:
+                 frame = self.replay_data[frame_idx]
+                 self.current_frame = frame_idx
+
+                 # Execute joint control
+                 for arm_name, arm in self.arms.items():
+                     if arm_name in frame['arms']:
+                         arm_data = frame['arms'][arm_name]
+                         target_joint_positions = arm_data['joint_positions']
+                         target_gripper_angle = arm_data['gripper_angle']
+
+                         # Convert radians to 0.001-degree units
+                         target_joint_positions_deg = (target_joint_positions * 180.0 * 1000.0 / np.pi).astype(int)
+
+                         # Send joint control command
+                         arm.JointCtrl(
+                             int(target_joint_positions_deg[0]), int(target_joint_positions_deg[1]),
+                             int(target_joint_positions_deg[2]), int(target_joint_positions_deg[3]),
+                             int(target_joint_positions_deg[4]), int(target_joint_positions_deg[5])
+                         )
+                         # Gripper control
+                         arm.GripperCtrl(int(target_gripper_angle * 1000.0), 0, 0x01, 0)
+
+                 self.frames_replayed += 1
+
+                 # Wait for the next cycle
+                 elapsed_time = time.time() - frame_start_time
+                 sleep_time = adjusted_dt - elapsed_time
+                 if sleep_time > 0:
+                     time.sleep(sleep_time)
+
+             except Exception as e:
+                 logger.error(f"Error while replaying frame {frame_idx}: {e}")
+                 self.is_replaying = False
+                 break
+
+         self.is_replaying = False
+         logger.info("Replay loop finished")
+         self.reset_to_zero()
+
+     def get_replay_info(self) -> Dict[str, Any]:
+         """Get replay info"""
+         return {
+             'total_frames': self.total_frames,
+             'current_frame': self.current_frame,
+             'frames_replayed': self.frames_replayed,
+             'is_replaying': self.is_replaying,
+             'replay_speed': self.replay_speed,
+             'metadata': self.replay_metadata
+         }
+
+     def get_frame_data(self, frame_index: int) -> Optional[Dict[str, Any]]:
+         """Get the data of the given frame"""
+         if not self.replay_data or frame_index >= len(self.replay_data):
+             return None
+         return self.replay_data[frame_index]
+
+     def seek_to_frame(self, frame_index: int):
+         """Jump to the given frame"""
+         if not self.replay_data:
+             logger.error("No data to seek in")
+             return False
+
+         if frame_index < 0 or frame_index >= self.total_frames:
+             logger.error(f"Frame index out of range: {frame_index}")
+             return False
+
+         self.current_frame = frame_index
+         logger.info(f"Jumped to frame: {frame_index}")
+         return True
+
+
+ def main():
+     """Main entry point"""
+     parser = argparse.ArgumentParser(description="Arm data replay tool")
+     parser.add_argument(
+         '--path',
+         type=str,
+         required=True,
+         help="Path of the data file (.h5) or LeRobot dataset directory to replay"
+     )
+     parser.add_argument(
+         '--episode',
+         type=int,
+         default=0,
+         help="Episode index to replay when the path is a LeRobot directory"
+     )
+     parser.add_argument(
+         '--speed',
+         type=float,
+         default=1.0,
+         help="Replay speed multiplier"
+     )
+     parser.add_argument(
+         '--left-can',
+         type=str,
+         default='can_left',
+         help="Left-arm CAN interface name"
+     )
+     parser.add_argument(
+         '--right-can',
+         type=str,
+         default='can_right',
+         help="Right-arm CAN interface name"
+     )
+
+     args = parser.parse_args()
+
+     arms_config = {
+         'left': args.left_can,
+         'right': args.right_can
+     }
+
+     replayer = ArmDataReplayer(arms_config)
+
+     # Connect the arms
+     if not replayer.connect_arms():
+         logger.error("Could not connect to the arms, exiting.")
+         return
+
+     # Load data
+     if not replayer.load_data(args.path, episode_idx=args.episode):
+         logger.error("Data loading failed, exiting.")
+         replayer.disconnect_arms()
+         return
+
+     try:
+         # Move to the start position
+         if not replayer.move_to_start_position():
+             logger.error("Failed to move to the start position, exiting.")
+             return
+
+         logger.info("Ready to start replay...")
+         # # Start replay
+         # replayer.start_replay(speed=args.speed)
+
+         # # Keep the main thread alive until replay finishes or the user interrupts
+         # while replayer.is_replaying:
+         #     info = replayer.get_replay_info()
+         #     print(f"\rReplaying... frame: {info['current_frame']}/{info['total_frames']} ({(info['current_frame']+1)*100/info['total_frames']:.1f}%)", end="")
+         #     time.sleep(0.5)
+
+         # print("\nReplay complete.")
+
+     except KeyboardInterrupt:
+         logger.info("Received interrupt signal, stopping...")
+
+     finally:
+         replayer.stop_replay()
+         replayer.disconnect_arms()
+         logger.info("Program cleaned up and exited.")
+
+ if __name__ == "__main__":
+     main()
tests/save_s1.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ Save an S1 pretrained model for training.
+ """
+
+ import torch
+ from starforce.model.starforce_s1 import Starforce_S1, Starforce_S1_Config
+ from starforce.model.action_head.flow_matching_action_head import FlowmatchingActionHeadConfig
+
+ config = Starforce_S1_Config()
+ config.backbone_cfg = {
+     "tune_llm": False,
+     # "vllm_base_model_path": "Qwen/Qwen2.5-VL-7B-Instruct",
+     "vllm_base_model_path": "/pfs/pfs-ahGxdf/data/wujingyi/huggingface/Qwen2.5-VL-3B-Instruct",
+     "select_layer": 12,
+     "feature_dim": 2048,
+     "project_to_dim": 2048,
+ }
+ config.action_horizon = 16
+ config.action_dim = 32
+ config.action_head_cfg = {
+     "action_dim": 32,
+     "action_horizon": 16,
+     "add_pos_embed": True,
+     "backbone_embedding_dim": 2048,
+     "diffusion_model_cfg": {
+         "attention_head_dim": 48,
+         "cross_attention_dim": 2048,
+         "dropout": 0.2,
+         "final_dropout": True,
+         "interleave_self_attention": True,
+         "norm_type": "ada_norm",
+         "num_attention_heads": 32,
+         "num_layers": 16,
+         "output_dim": 1024,
+         "positional_embeddings": None,
+     },
+     "hidden_size": 1024,
+     "input_embedding_dim": 1536,
+     "max_action_dim": 32,
+     "max_state_dim": 64,
+     "model_dtype": "float32",
+     "noise_beta_alpha": 1.5,
+     "noise_beta_beta": 1.0,
+     "noise_s": 0.999,
+     "num_inference_timesteps": 4,
+     "num_target_vision_tokens": 32,
+     "num_timestep_buckets": 1000,
+     "tune_diffusion_model": True,
+     "tune_projector": True,
+     "use_vlln": True,
+     "vl_self_attention_cfg": {
+         "attention_head_dim": 64,
+         "dropout": 0.2,
+         "final_dropout": True,
+         "num_attention_heads": 32,
+         "num_layers": 4,
+         "positional_embeddings": None,
+     },
+ }
+ model = Starforce_S1(config=config, local_model_path=None)
+
+ # action_head_state_dict = torch.load("checkpoints/GR00T-N1.5-3B-action-expert.pth")
+ action_head_state_dict = torch.load("checkpoints/qz-action-expert.pth")
+ model.action_head.load_state_dict(action_head_state_dict)
+
+ model.save_pretrained("checkpoints/Starforce-S1-3B")
+
+ print(model)
+ print("done!")
tests/save_s1_7B.py ADDED
@@ -0,0 +1,69 @@
+ """
+ Save an S1 pretrained model for training.
+ """
+
+ import torch
+ from starforce.model.starforce_s1 import Starforce_S1, Starforce_S1_Config
+ from starforce.model.action_head.flow_matching_action_head import FlowmatchingActionHeadConfig
+
+ config = Starforce_S1_Config()
+ config.backbone_cfg = {
+     "tune_llm": False,
+     # "vllm_base_model_path": "Qwen/Qwen2.5-VL-7B-Instruct",
+     "vllm_base_model_path": "/pfs/pfs-ahGxdf/data/wujingyi/huggingface/Qwen2.5-VL-7B-Instruct",
+     "select_layer": 12,
+     "feature_dim": 3584,
+     "project_to_dim": 2048,
+ }
+ config.action_horizon = 16
+ config.action_dim = 32
+ config.action_head_cfg = {
+     "action_dim": 32,
+     "action_horizon": 16,
+     "add_pos_embed": True,
+     "backbone_embedding_dim": 2048,
+     "diffusion_model_cfg": {
+         "attention_head_dim": 48,
+         "cross_attention_dim": 2048,
+         "dropout": 0.2,
+         "final_dropout": True,
+         "interleave_self_attention": True,
+         "norm_type": "ada_norm",
+         "num_attention_heads": 32,
+         "num_layers": 16,
+         "output_dim": 1024,
+         "positional_embeddings": None,
+     },
+     "hidden_size": 1024,
+     "input_embedding_dim": 1536,
+     "max_action_dim": 32,
+     "max_state_dim": 64,
+     "model_dtype": "float32",
+     "noise_beta_alpha": 1.5,
+     "noise_beta_beta": 1.0,
+     "noise_s": 0.999,
+     "num_inference_timesteps": 4,
+     "num_target_vision_tokens": 32,
+     "num_timestep_buckets": 1000,
+     "tune_diffusion_model": True,
+     "tune_projector": True,
+     "use_vlln": True,
+     "vl_self_attention_cfg": {
+         "attention_head_dim": 64,
+         "dropout": 0.2,
+         "final_dropout": True,
+         "num_attention_heads": 32,
+         "num_layers": 4,
+         "positional_embeddings": None,
+     },
+ }
+ model = Starforce_S1(config=config, local_model_path=None)
+
+ # action_head_state_dict = torch.load("checkpoints/GR00T-N1.5-3B-action-expert.pth")
+ action_head_state_dict = torch.load("checkpoints/qz-action-expert.pth")
+ model.action_head.load_state_dict(action_head_state_dict)
+
+ model.save_pretrained("checkpoints/Starforce-S1-7B")
+
+ print("done!")
tests/test_cv.py ADDED
@@ -0,0 +1,22 @@
+ import cv2
+ import sys
+
+ # video_path = "data/sl/0723pre_data_v1/videos/chunk-000/observation.images.cam_high/episode_000045.mp4"
+ video_path = 'data/test_aloha_singlearm/videos/chunk-000/observation.images.cam_high/episode_000000.mp4'  # Replace with your actual video path
+ print("OpenCV build info:")
+ print(cv2.getBuildInformation())
+
+ cap = cv2.VideoCapture(video_path)
+
+ if not cap.isOpened():
+     print(f"Error: Failed to open video file: {video_path}")
+     sys.exit(1)
+
+ ret, frame = cap.read()
+ if not ret:
+     print("Error: Unable to read the first frame of the video.")
+     cap.release()
+     sys.exit(1)
+
+ print("Success: Video file opened and first frame read successfully.")
+ cap.release()
tests/test_hf.py ADDED
@@ -0,0 +1,19 @@
+ import os
+
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+
+ from transformers import AutoConfig
+ from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig
+ from transformers import AutoModelForCausalLM, AutoModel
+ from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
+ import torch
+
+
+ vllm_base_model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+ config = Qwen2_5_VLConfig.from_pretrained(vllm_base_model_path, trust_remote_code=True)
+ vllm_model = Qwen2_5_VLForConditionalGeneration(
+     config=config,
+ )
+
+ print(vllm_model)
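Note that building the model from a config alone, as above, yields randomly initialized weights; it only checks that the architecture instantiates. To pull the actual pretrained weights, the standard transformers route is from_pretrained, sketched here:

# Load real pretrained weights instead of a random init (standard transformers API).
import torch
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

vllm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,  # assumes bf16-capable hardware; drop for full fp32
)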
tests/test_pi0.py ADDED
@@ -0,0 +1,58 @@
+ from starforce.datasets.lerobot_dataset import LeRobotDataset
+ import torch
+ from starforce.models.pretrained import PreTrainedConfig
+ from starforce.models.build_model import make_policy
+ from loguru import logger
+ import time
+
+
+ def main():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+     # paligemma doesn't support bf16?
+     # dtype = torch.float32
+     logger.info(f"##info, device: {device}, dtype: {dtype}")
+
+     # dataset_repo_id = "danaaubakirova/koch_test"
+     # dataset_repo_id = "data/robotwin2lerobot/block_hammer_beat"
+     dataset_repo_id = "/pfs/data/xiongxiao/lerobot_fps30/open_laptop"
+     # ckpt_torch_dir = "/pfs/data/fgang/vla_holo/checkpoints/pi0"
+     # ckpt_torch_dir = "/pfs/data/fgang/outputs_models/pi0-1-20000/pretrained_model"
+     # ckpt_torch_dir = "/pfs/data/fgang/outputs_models/pi0-robotwin-30fps-tasks3"
+     # ckpt_torch_dir = "/pfs/data/fgang/vla_holo/outputs/pi0-fixed-20ksteps"
+     ckpt_torch_dir = "/pfs/data/fgang/outputs_models/pi0-robotwin-30fps-tasks5"
+
+     dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         num_workers=0,
+         batch_size=1,
+     )
+     batch = next(iter(dataloader))
+     # Move tensors to the target device and dtype
+     for k in batch:
+         if isinstance(batch[k], torch.Tensor):
+             batch[k] = batch[k].to(device=device, dtype=dtype)
+     print(f'dataset.meta: {dataset.meta}')
+
+     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir, device=device)
+     cfg.pretrained_path = ckpt_torch_dir
+     policy = make_policy(cfg, ds_meta=dataset.meta)
+     # policy.to(dtype)
+     # print(policy)
+
+     t0 = time.time()
+     with torch.amp.autocast(device_type=device):
+         benchmark_iters = 30
+         for _ in range(benchmark_iters):
+             # print(batch)
+             t00 = time.time()
+             action = policy.select_action(batch, n_steps_out=50)
+             torch.cuda.synchronize()
+             # print("##info, action:", action.shape, action.dtype, action.device, action, time.time() - t00)
+     t1 = time.time()
+     print(f'cost: {t1 - t0:.3f}s, avg: {(t1 - t0) / benchmark_iters:.3f}s/iter')
+
+
+ if __name__ == "__main__":
+     main()
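One caveat with the timing above: the first select_action call typically pays one-off warmup costs (CUDA context creation, kernel selection), which inflates the average. A sketch of excluding it, reusing the names from the script:

# Warm up once outside the timed region so lazy initialization doesn't skew the average.
with torch.amp.autocast(device_type=device):
    policy.select_action(batch, n_steps_out=50)
torch.cuda.synchronize()

t0 = time.time()
with torch.amp.autocast(device_type=device):
    for _ in range(benchmark_iters):
        policy.select_action(batch, n_steps_out=50)
        torch.cuda.synchronize()
print(f"avg: {(time.time() - t0) / benchmark_iters:.3f}s/iter")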
tests/test_starhelm.py ADDED
File without changes
tests/test_tensor.py ADDED
@@ -0,0 +1,21 @@
+ '''
+ Read model weights from
+ checkpoints/models/pi0-sl-b1/model.safetensors
+ and print the normalization-related keys and values.
+ '''
+
+ from safetensors.torch import load_file
+
+ # Load the weights from the safetensors file
+ weights = load_file("checkpoints/models/pi0-sl-b1/model.safetensors")
+
+ # Print the normalization-related keys with their shapes, dtypes, and values
+ for key, value in weights.items():
+     if 'normalize' in key:
+         print(f"Key: {key}, Shape: {value.shape}, Type: {value.dtype} {value}")
tests/vis_lerobot_data.py ADDED
@@ -0,0 +1,200 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from vlaholo.datasets.lerobot_dataset import LeRobotDataset
+ import os
+ import cv2
+ from matplotlib.animation import FuncAnimation
+
+ """
+ TODO:
+
+ support datasets == 4.0
+ """
+
+
+ def plot_episode_joint_states(dataset_path: str, episode_index: int):
+     dataset = LeRobotDataset(dataset_path)
+
+     if episode_index >= dataset.num_episodes:
+         print(
+             f"episode index {episode_index} is out of range, total episodes: {dataset.num_episodes}"
+         )
+         episode_index = dataset.num_episodes - 1
+         print(f"force set to max episode index: {episode_index}")
+
+     hf_dataset = dataset.hf_dataset
+     episode_ds = hf_dataset.filter(lambda x: x["episode_index"] == episode_index)
+     video_paths = dataset.encode_episode_videos(episode_index=episode_index)
+
+     caps = {}
+     for key, path in video_paths.items():
+         cap = cv2.VideoCapture(path)
+         if not cap.isOpened():
+             raise ValueError(f"Could not open video: {path}")
+         caps[key] = cap
+
+     fps = caps[next(iter(caps))].get(cv2.CAP_PROP_FPS)
+     total_frames = int(caps[next(iter(caps))].get(cv2.CAP_PROP_FRAME_COUNT))
+
+     df = episode_ds.to_pandas()
+     joint_states = np.vstack(df["observation.state"].values)
+     timestamps = df["timestamp"].values
+     duration_sec = timestamps[-1] - timestamps[0]
+
+     joint_names = dataset.features["observation.state"]["names"]
+     if isinstance(joint_names, list) and len(joint_names) == 1 and isinstance(joint_names[0], list):
+         joint_names = joint_names[0]
+     if len(joint_names) <= 1:
+         joint_names = [f"Joint {i}" for i in range(joint_states.shape[1])]
+
+     n_joints = joint_states.shape[1]
+     n_joint_rows = (n_joints + 2) // 3
+
+     # Create the figure with symmetric left/right margins and headroom at the top
+     fig = plt.figure(figsize=(18, 4 + 4 * n_joint_rows))
+     fig.subplots_adjust(top=0.92, bottom=0.05, left=0.05, right=0.95, hspace=0.4, wspace=0.3)
+
+     # Top-level title in a large font
+     fig.suptitle(
+         "Starforce Data Inspect System", x=0.5, y=0.98, fontsize=35, fontweight="bold", ha="center"
+     )
+     # Secondary stats line: fps, total frames, episode index, trajectory duration
+     stats_text = (
+         f"FPS: {fps:.2f}    Total frames: {total_frames}    "
+         f"Episode: {episode_index}    Duration: {duration_sec:.2f}s"
+     )
+     fig.text(0.5, 0.92, stats_text, ha="center", fontsize=21, fontweight="bold")
+
+     plt.rcParams.update(
+         {
+             "font.family": "sans-serif",
+             "font.sans-serif": ["Arial", "DejaVu Sans"],
+             "font.size": 12,
+             "axes.titlesize": 14,
+             "axes.labelsize": 13,
+             "axes.spines.top": False,
+             "axes.spines.right": False,
+         }
+     )
+
+     # Use a GridSpec with equal column widths
+     gs = fig.add_gridspec(
+         n_joint_rows + 1, 3, width_ratios=[1, 1, 1], height_ratios=[2] + [1] * n_joint_rows
+     )
+
+     # Render the video panels
+     video_axes, video_imgs = {}, {}
+     for idx, key in enumerate(video_paths.keys()):
+         ax = fig.add_subplot(gs[0, idx])
+         ax.set_xticks([])
+         ax.set_yticks([])
+         ax.set_title(key)
+         img = ax.imshow(np.zeros((480, 640, 3)), aspect="auto")
+         ax.set_box_aspect(480 / 640)
+         video_axes[key] = ax
+         video_imgs[key] = img
+
+     # Plot the joint trajectories
+     joint_axes, lines, time_lines = [], [], []
+     base_colors = [
+         "#1f77b4",
+         "#ff7f0e",
+         "#2ca02c",
+         "#d62728",
+         "#9467bd",
+         "#8c564b",
+         "#e377c2",
+         "#7f7f7f",
+         "#bcbd22",
+         "#17becf",
+     ]
+     colors = (base_colors * ((n_joints // len(base_colors)) + 1))[:n_joints]
+
+     for i in range(n_joints):
+         row, col = 1 + i // 3, i % 3
+         ax = fig.add_subplot(gs[row, col])
+
+         # Gradient background
+         gradient = np.linspace(0, 1, 256).reshape(256, 1)
+         extent = [timestamps[0], timestamps[-1], joint_states[:, i].min(), joint_states[:, i].max()]
+         ax.imshow(
+             np.repeat(gradient, 256, axis=1),
+             aspect="auto",
+             cmap="Blues",
+             alpha=0.1,
+             extent=extent,
+             origin="lower",
+             zorder=0,
+         )
+
+         # Trajectory line
+         (line,) = ax.plot([], [], label=joint_names[i], color=colors[i], linewidth=2.5, zorder=1)
+         lines.append(line)
+
+         # Axis settings
+         ax.set_xlabel("Time (s)")
+         ax.set_ylabel("pos")
+         ax.spines["left"].set_visible(False)
+
+         ax.set_title(joint_names[i], fontweight="bold")
+         ax.set_xlim(timestamps[0], timestamps[-1])
+         y0, y1 = joint_states[:, i].min(), joint_states[:, i].max()
+         m = (y1 - y0) * 0.1
+         ax.set_ylim(y0 - m, y1 + m)
+
+         tl = ax.axvline(x=timestamps[0], color="crimson", alpha=0.7, linewidth=1.2, zorder=2)
+         time_lines.append(tl)
+         joint_axes.append(ax)
+
+     # Init and animation callbacks
+     def init():
+         for ln in lines:
+             ln.set_data([], [])
+         return lines + time_lines + list(video_imgs.values())
+
+     def animate(frame_idx):
+         idx = min(frame_idx, len(timestamps) - 1)
+         t = timestamps[idx]
+         print(
+             f"\rProcessing frames: {frame_idx + 1}/{total_frames} ({(frame_idx+1)/total_frames*100:.1f}%)",
+             end="",
+             flush=True,
+         )
+
+         for key, cap in caps.items():
+             ret, frame = cap.read()
+             if ret:
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 video_imgs[key].set_array(frame)
+         for j, ln in enumerate(lines):
+             ln.set_data(timestamps[: idx + 1], joint_states[: idx + 1, j])
+         for tl in time_lines:
+             tl.set_xdata([t, t])
+         return lines + time_lines + list(video_imgs.values())
+
+     anim = FuncAnimation(
+         fig, animate, init_func=init, frames=total_frames, interval=1000 / fps, blit=True
+     )
+     print()
+
+     save_dir = "outputs/"
+     os.makedirs(save_dir, exist_ok=True)
+     out_path = os.path.join(save_dir, f"episode_{episode_index}_animation.mp4")
+     anim.save(out_path, writer="ffmpeg", fps=fps)
+
+     plt.close()
+     for cap in caps.values():
+         cap.release()
+     print(f"Animation saved to: {out_path}")
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Visualize joint states of a LeRobot dataset episode"
+     )
+     parser.add_argument("dataset_path", type=str, help="Path or HF repo ID of the LeRobot dataset")
+     parser.add_argument("-i", type=int, default=89, help="Episode index to visualize")
+     args = parser.parse_args()
+     plot_episode_joint_states(args.dataset_path, args.i)
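One caveat in the animate callback above: cap.read() advances each video sequentially, so if the episode has fewer data rows than video frames the plots and the video panels drift apart. A sketch of seeking explicitly by frame index instead, using OpenCV's CAP_PROP_POS_FRAMES (note that seeking can be slow or inexact on some codecs):

# Seek to the frame matching the current data index instead of relying on sequential reads.
def read_frame_at(cap, frame_idx):
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if ret else None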
tests/vis_lerobot_data_v1.py ADDED
@@ -0,0 +1,125 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from vlaholo.datasets.lerobot_dataset import LeRobotDataset
+ import os
+ import cv2
+ from matplotlib.animation import FuncAnimation
+ from vlaholo.utils.dataset_utils import DEFAULT_VIDEO_PATH
+
+
+ def plot_episode_joint_states(dataset_path: str, episode_index: int):
+
+     dataset = LeRobotDataset(dataset_path)
+
+     # Valid indices run 0..num_episodes-1, so clamp anything >= num_episodes
+     if episode_index >= dataset.num_episodes:
+         print(
+             f"episode index {episode_index} is out of range, total episodes: {dataset.num_episodes}"
+         )
+         episode_index = dataset.num_episodes - 1
+         print(f"force set to max episode index: {episode_index}")
+
+     hf_dataset = dataset.hf_dataset
+     episode_ds = hf_dataset.filter(lambda x: x["episode_index"] == episode_index)
+     video_paths = dataset.encode_episode_videos(episode_index=episode_index)
+
+     caps = {}
+     for key, path in video_paths.items():
+         cap = cv2.VideoCapture(path)
+         if not cap.isOpened():
+             raise ValueError(f"Could not open video: {path}")
+         caps[key] = cap
+
+     fps = caps[list(caps.keys())[0]].get(cv2.CAP_PROP_FPS)
+     total_frames = int(caps[list(caps.keys())[0]].get(cv2.CAP_PROP_FRAME_COUNT))
+
+     df = episode_ds.to_pandas()
+     joint_states = np.vstack(df["observation.state"].values)
+     timestamps = df["timestamp"].values
+
+     fig = plt.figure(figsize=(15, 10))
+     gs = fig.add_gridspec(3, 2)
+
+     video_names_map = {
+         "observation.images.cam_high": "High Camera",
+         "observation.images.cam_left_wrist": "Left Wrist Camera",
+         "observation.images.cam_right_wrist": "Right Wrist Camera",
+     }
+
+     video_axes = {
+         "observation.images.cam_high": fig.add_subplot(gs[0, 0]),
+         "observation.images.cam_left_wrist": fig.add_subplot(gs[0, 1]),
+         "observation.images.cam_right_wrist": fig.add_subplot(gs[1, :]),
+     }
+
+     joint_ax = fig.add_subplot(gs[2, :])
+
+     video_imgs = {}
+     for key, ax in video_axes.items():
+         ax.set_xticks([])
+         ax.set_yticks([])
+         ax.set_title(video_names_map[key])
+         img = ax.imshow(np.zeros((480, 640, 3)))
+         video_imgs[key] = img
+
+     n_joints = joint_states.shape[1]
+     lines = []
+     for i in range(n_joints):
+         (line,) = joint_ax.plot([], [], label=f"Joint {i}")
+         lines.append(line)
+
+     joint_ax.set_xlim(timestamps[0], timestamps[-1])
+     joint_ax.set_ylim(joint_states.min(), joint_states.max())
+     joint_ax.grid(True)
+     joint_ax.legend(loc="upper right")
+     joint_ax.set_xlabel("Time (s)")
+
+     time_line = joint_ax.axvline(x=timestamps[0], color="r")
+
+     def init():
+         for line in lines:
+             line.set_data([], [])
+         return lines + [time_line] + list(video_imgs.values())
+
+     def animate(frame_idx):
+         data_idx = min(frame_idx, len(timestamps) - 1)
+         current_time = timestamps[data_idx]
+
+         for key, cap in caps.items():
+             ret, frame = cap.read()
+             if ret:
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 video_imgs[key].set_array(frame)
+
+         for i, line in enumerate(lines):
+             line.set_data(timestamps[: data_idx + 1], joint_states[: data_idx + 1, i])
+
+         time_line.set_xdata([current_time, current_time])
+         return lines + [time_line] + list(video_imgs.values())
+
+     anim = FuncAnimation(
+         fig, animate, init_func=init, frames=total_frames, interval=1000 / fps, blit=True
+     )
+
+     save_dir = "outputs/"
+     os.makedirs(save_dir, exist_ok=True)
+     output_path = os.path.join(save_dir, f"episode_{episode_index}_animation.mp4")
+     anim.save(output_path, writer="ffmpeg", fps=fps)
+
+     plt.close()
+     for cap in caps.values():
+         cap.release()
+
+     print(f"Animation saved to: {output_path}")
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Visualize joint states of a LeRobot dataset episode"
+     )
+     parser.add_argument("dataset_path", type=str, help="Path or HF repo ID of the LeRobot dataset")
+     parser.add_argument("-i", type=int, default=89, help="Episode index to visualize")
+     args = parser.parse_args()
+
+     plot_episode_joint_states(args.dataset_path, args.i)
wandb/debug.log ADDED
@@ -0,0 +1,30 @@
+ 2025-12-20 21:53:42,814 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Current SDK version is 0.18.0
+ 2025-12-20 21:53:42,814 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Configure stats pid to 3365891
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Loading settings from /home/lumos6/.config/wandb/settings
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Loading settings from /home/lumos6/work/starforce2/wandb/settings
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Loading settings from environment variables: {}
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Applying setup settings: {'_disable_service': False}
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Inferring run settings from compute environment: {'program_relpath': 'finetune.py', 'program_abspath': '/home/lumos6/work/starforce2/finetune.py', 'program': '/home/lumos6/work/starforce2/finetune.py'}
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Applying login settings: {}
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_setup.py:_flush():77] Applying login settings: {'mode': 'offline'}
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:_log_setup():525] Logging user logs to /home/lumos6/work/starforce2/wandb/offline-run-20251220_215342-tsf926l3/logs/debug.log
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:_log_setup():526] Logging internal logs to /home/lumos6/work/starforce2/wandb/offline-run-20251220_215342-tsf926l3/logs/debug-internal.log
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():609] calling init triggers
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():616] wandb.init called with sweep_config: {}
+ config: {}
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():659] starting backend
+ 2025-12-20 21:53:42,815 INFO MainThread:3365891 [wandb_init.py:init():663] setting up manager
+ 2025-12-20 21:53:42,817 INFO MainThread:3365891 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
+ 2025-12-20 21:53:42,817 INFO MainThread:3365891 [wandb_init.py:init():671] backend started and connected
+ 2025-12-20 21:53:42,818 INFO MainThread:3365891 [wandb_init.py:init():766] updated telemetry
+ 2025-12-20 21:53:42,823 INFO MainThread:3365891 [wandb_init.py:init():799] communicating run to backend with 90.0 second timeout
+ 2025-12-20 21:53:42,835 INFO MainThread:3365891 [wandb_init.py:init():850] starting run threads in backend
+ 2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_console_start():2466] atexit reg
+ 2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_redirect():2312] redirect: wrap_raw
+ 2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_redirect():2377] Wrapping output streams.
+ 2025-12-20 21:53:42,973 INFO MainThread:3365891 [wandb_run.py:_redirect():2402] Redirects installed.
+ 2025-12-20 21:53:42,974 INFO MainThread:3365891 [wandb_init.py:init():893] run started, returning control to user process
+ 2025-12-20 21:53:42,974 INFO MainThread:3365891 [wandb_run.py:_config_callback():1393] config_cb None None {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['GR00T_N1_5'], 'finetuning_task': None, 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'}, 'label2id': {'LABEL_0': 0, 'LABEL_1': 1}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': 'outputs/gr00t-3b-piper-task-pickup-bs8-1gpu-step60k/final_model', 'transformers_version': '4.52.2', 'action_dim': 32, 'action_head_cfg': {'action_dim': 32, 'action_horizon': 16, 'add_pos_embed': True, 'backbone_embedding_dim': 2048, 'diffusion_model_cfg': {'attention_head_dim': 48, 'cross_attention_dim': 2048, 'dropout': 0.2, 'final_dropout': True, 'interleave_self_attention': True, 'norm_type': 'ada_norm', 'num_attention_heads': 32, 'num_layers': 16, 'output_dim': 1024, 'positional_embeddings': None}, 'hidden_size': 1024, 'input_embedding_dim': 1536, 'max_action_dim': 32, 'max_state_dim': 64, 'model_dtype': 'float32', 'noise_beta_alpha': 1.5, 'noise_beta_beta': 1.0, 'noise_s': 0.999, 'num_inference_timesteps': 4, 'num_target_vision_tokens': 32, 'num_timestep_buckets': 1000, 'tune_diffusion_model': True, 'tune_projector': True, 'use_vlln': True, 'vl_self_attention_cfg': {'attention_head_dim': 64, 'dropout': 0.2, 'final_dropout': True, 'num_attention_heads': 32, 'num_layers': 4, 'positional_embeddings': None}}, 'action_horizon': 16, 'backbone_cfg': {'eagle_path': 'NVEagle/eagle_er-qwen3_1_7B-Siglip2_400M_stage1_5_128gpu_er_v7_1mlp_nops', 'load_bf16': False, 'project_to_dim': None, 'reproject_vision': False, 'select_layer': 12, 'tune_llm': False, 'tune_visual': True, 'use_flash_attention': True}, 'compute_dtype': 'bfloat16', 'hidden_size': 2048, 'model_dtype': 'float32', 'model_type': 'gr00t_n1_5', 'attn_implementation': None, 'output_dir': 'outputs/gr00t-3b-piper-task-pickup02-bs8-1gpu-step60k', 'overwrite_output_dir': False, 'do_train': False, 'do_eval': False, 'do_predict': False, 'eval_strategy': 'no', 'prediction_loss_only': False, 'per_device_train_batch_size': 16, 'per_device_eval_batch_size': 8, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 2.5e-05, 'weight_decay': 1e-05, 'adam_beta1': 0.95, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 
'num_train_epochs': 300, 'max_steps': 60000, 'lr_scheduler_type': 'cosine', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.05, 'warmup_steps': 0, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': 'outputs/gr00t-3b-piper-task-pickup02-bs8-1gpu-step60k/runs/Dec20_21-53-38_lumos6', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 10, 'logging_nan_inf_filter': True, 'save_strategy': 'steps', 'save_steps': 5000, 'save_total_limit': 3, 'save_safetensors': True, 'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': True, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': True, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': None, 'dataloader_num_workers': 8, 'dataloader_prefetch_factor': None, 'past_index': -1, 'run_name': 'outputs/gr00t-3b-piper-task-pickup02-bs8-1gpu-step60k', 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': '', 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard'], 'ddp_find_unused_parameters': False, 'ddp_bucket_cap_mb': 100, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': False, 'dataloader_persistent_workers': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': None, 'hub_always_push': False, 'gradient_checkpointing': False, 'gradient_checkpointing_kwargs': None, 'include_inputs_for_metrics': False, 'include_for_metrics': [], 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'eval_use_gather_object': False, 'average_tokens_across_devices': False}
+ 2025-12-20 21:53:42,976 INFO MainThread:3365891 [wandb_config.py:__setitem__():154] config set model/num_parameters = 2724163520 - <bound method Run._config_callback of <wandb.sdk.wandb_run.Run object at 0x75c8c8147460>>
+ 2025-12-20 21:53:42,976 INFO MainThread:3365891 [wandb_run.py:_config_callback():1393] config_cb model/num_parameters 2724163520 None
+ 2025-12-21 08:08:04,099 WARNING MsgRouterThr:3365891 [router.py:message_loop():77] message_loop has been closed