# gr00t1.5_starforce/tests/vis_lerobot_data.py
# NOTE: the lines below were Hugging Face upload-page artifacts, preserved as
# comments so the file remains valid Python:
#   nnh-pbbb's picture
#   Add files using upload-large-folder tool
#   cd793b5 verified
import numpy as np
import matplotlib.pyplot as plt
from vlaholo.datasets.lerobot_dataset import LeRobotDataset
import os
import cv2
from matplotlib.animation import FuncAnimation
"""
TODO:
support datasets == 4.0
"""
def plot_episode_joint_states(dataset_path: str, episode_index: int) -> None:
    """Render an animated MP4 combining camera videos and joint-state curves.

    Loads one episode from a LeRobot dataset, decodes its camera videos,
    plots each joint trajectory with a moving time cursor, and saves the
    animation to ``outputs/episode_<index>_animation.mp4``.

    Args:
        dataset_path: Path or HF repo ID of the LeRobot dataset.
        episode_index: Episode to visualize; clamped to the last episode
            when out of range.

    Raises:
        ValueError: If a decoded video file cannot be opened.
    """
    dataset = LeRobotDataset(dataset_path)
    # Clamp out-of-range indices to the last episode instead of failing.
    if episode_index >= dataset.num_episodes:
        print(
            f"episode index {episode_index} is out of range, total episodes: {dataset.num_episodes}"
        )
        episode_index = dataset.num_episodes - 1
        print(f"force set to max episode index: {episode_index}")
    hf_dataset = dataset.hf_dataset
    episode_ds = hf_dataset.filter(lambda x: x["episode_index"] == episode_index)
    video_paths = dataset.encode_episode_videos(episode_index=episode_index)
    caps = {}
    fig = None
    try:
        for key, path in video_paths.items():
            cap = cv2.VideoCapture(path)
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {path}")
            caps[key] = cap
        first_cap = caps[next(iter(caps))]
        fps = first_cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # Some containers report 0 FPS; fall back to a sane default to
            # avoid a zero division in the animation interval below.
            fps = 30.0
        total_frames = int(first_cap.get(cv2.CAP_PROP_FRAME_COUNT))

        df = episode_ds.to_pandas()
        joint_states = np.vstack(df["observation.state"].values)
        timestamps = df["timestamp"].values
        duration_sec = timestamps[-1] - timestamps[0]
        n_joints = joint_states.shape[1]

        joint_names = dataset.features["observation.state"]["names"]
        # Some datasets nest the name list one level deep; unwrap it.
        if isinstance(joint_names, list) and len(joint_names) == 1 and isinstance(joint_names[0], list):
            joint_names = joint_names[0]
        # Fall back to generic labels when names are missing or do not match
        # the joint count (the old `<= 1` check discarded a real single name).
        if not joint_names or len(joint_names) != n_joints:
            joint_names = [f"Joint {i}" for i in range(n_joints)]

        n_joint_rows = (n_joints + 2) // 3
        # Enough columns for all cameras; joint plots always use the left 3.
        n_cols = max(3, len(video_paths))

        # Apply global style BEFORE creating the figure so every artist
        # (suptitle, axes, labels) picks it up.
        plt.rcParams.update(
            {
                "font.family": "sans-serif",
                "font.sans-serif": ["Arial", "DejaVu Sans"],
                "font.size": 12,
                "axes.titlesize": 14,
                "axes.labelsize": 13,
                "axes.spines.top": False,
                "axes.spines.right": False,
            }
        )
        # Figure with symmetric margins and headroom for the titles.
        fig = plt.figure(figsize=(18, 4 + 4 * n_joint_rows))
        fig.subplots_adjust(top=0.92, bottom=0.05, left=0.05, right=0.95, hspace=0.4, wspace=0.3)
        # Primary title.
        fig.suptitle(
            "Starforce Data Inspect System", x=0.5, y=0.98, fontsize=35, fontweight="bold", ha="center"
        )
        # Secondary stats line: fps, frame count, episode index, duration.
        stats_text = (
            f"FPS: {fps:.2f} Total frames: {total_frames} "
            f"Episode: {episode_index} Duration: {duration_sec:.2f}s"
        )
        fig.text(0.5, 0.92, stats_text, ha="center", fontsize=21, fontweight="bold")

        # Grid: one video row on top, then joint-curve rows (3 per row).
        gs = fig.add_gridspec(
            n_joint_rows + 1,
            n_cols,
            width_ratios=[1] * n_cols,
            height_ratios=[2] + [1] * n_joint_rows,
        )

        # Video panels.
        video_axes, video_imgs = {}, {}
        for idx, key in enumerate(video_paths.keys()):
            ax = fig.add_subplot(gs[0, idx])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(key)
            # Placeholder image; real frames are streamed in by animate().
            img = ax.imshow(np.zeros((480, 640, 3)), aspect="auto")
            ax.set_box_aspect(480 / 640)
            video_axes[key] = ax
            video_imgs[key] = img

        # Joint-trajectory panels.
        joint_axes, lines, time_lines = [], [], []
        base_colors = [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ]
        # Cycle the palette when there are more joints than base colors.
        colors = (base_colors * ((n_joints // len(base_colors)) + 1))[:n_joints]
        for i in range(n_joints):
            row, col = 1 + i // 3, i % 3
            ax = fig.add_subplot(gs[row, col])
            # Subtle vertical gradient background spanning the data range.
            gradient = np.linspace(0, 1, 256).reshape(256, 1)
            extent = [timestamps[0], timestamps[-1], joint_states[:, i].min(), joint_states[:, i].max()]
            ax.imshow(
                np.repeat(gradient, 256, axis=1),
                aspect="auto",
                cmap="Blues",
                alpha=0.1,
                extent=extent,
                origin="lower",
                zorder=0,
            )
            # Trajectory line, filled in incrementally by animate().
            (line,) = ax.plot([], [], label=joint_names[i], color=colors[i], linewidth=2.5, zorder=1)
            lines.append(line)
            # Axis cosmetics and fixed limits (10% vertical margin).
            ax.set_xlabel("Time (s)")
            ax.set_ylabel("pos")
            ax.spines["left"].set_visible(False)
            ax.set_title(joint_names[i], fontweight="bold")
            ax.set_xlim(timestamps[0], timestamps[-1])
            y0, y1 = joint_states[:, i].min(), joint_states[:, i].max()
            m = (y1 - y0) * 0.1
            ax.set_ylim(y0 - m, y1 + m)
            # Moving time cursor.
            tl = ax.axvline(x=timestamps[0], color="crimson", alpha=0.7, linewidth=1.2, zorder=2)
            time_lines.append(tl)
            joint_axes.append(ax)

        def init():
            """Reset all trajectory lines for blitting."""
            for ln in lines:
                ln.set_data([], [])
            return lines + time_lines + list(video_imgs.values())

        def animate(frame_idx):
            """Advance one video frame and extend the joint curves to it."""
            idx = min(frame_idx, len(timestamps) - 1)
            t = timestamps[idx]
            print(
                f"\rProcessing frames: {frame_idx + 1}/{total_frames} ({(frame_idx+1)/total_frames*100:.1f}%)",
                end="",
                flush=True,
            )
            # Sequential read keeps each capture in sync with frame_idx.
            for key, cap in caps.items():
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    video_imgs[key].set_array(frame)
            for j, ln in enumerate(lines):
                ln.set_data(timestamps[: idx + 1], joint_states[: idx + 1, j])
            for tl in time_lines:
                tl.set_xdata([t, t])
            return lines + time_lines + list(video_imgs.values())

        anim = FuncAnimation(
            fig, animate, init_func=init, frames=total_frames, interval=1000 / fps, blit=True
        )
        print()
        save_dir = "outputs/"
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f"episode_{episode_index}_animation.mp4")
        anim.save(out_path, writer="ffmpeg", fps=fps)
        print(f"Animation saved to: {out_path}")
    finally:
        # Always release decoder handles and the figure, even on error.
        if fig is not None:
            plt.close(fig)
        for cap in caps.values():
            cap.release()
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(
        description="Visualize joint states of a LeRobot dataset episode"
    )
    cli.add_argument(
        "dataset_path", type=str, help="Path or HF repo ID of the LeRobot dataset"
    )
    cli.add_argument("-i", type=int, default=89, help="Episode index to visualize")
    opts = cli.parse_args()
    plot_episode_joint_states(opts.dataset_path, opts.i)