| import numpy as np |
| import matplotlib.pyplot as plt |
| from vlaholo.datasets.lerobot_dataset import LeRobotDataset |
| import os |
| import cv2 |
| from matplotlib.animation import FuncAnimation |
|
|
| """ |
| TODO: |
| |
| support datasets == 4.0 |
| """ |
|
|
|
|
def plot_episode_joint_states(dataset_path: str, episode_index: int) -> None:
    """Render an episode as an MP4: camera streams on top, joint-state plots below.

    Loads one episode from a LeRobotDataset, plays each camera video in the top
    row of a matplotlib figure, animates one subplot per joint underneath with a
    moving time cursor, and saves the result to
    ``outputs/episode_{episode_index}_animation.mp4``.

    Args:
        dataset_path: Path or Hugging Face repo ID of the LeRobot dataset.
        episode_index: Episode to visualize. Out-of-range values are clamped
            into ``[0, num_episodes - 1]`` with a warning instead of raising.

    Raises:
        ValueError: If one of the episode's video files cannot be opened.
    """
    dataset = LeRobotDataset(dataset_path)

    # Clamp out-of-range indices instead of failing. The original only handled
    # the too-large case; a negative index would silently filter an empty episode.
    if episode_index >= dataset.num_episodes:
        print(
            f"episode index {episode_index} is out of range, total episodes: {dataset.num_episodes}"
        )
        episode_index = dataset.num_episodes - 1
        print(f"force set to max episode index: {episode_index}")
    elif episode_index < 0:
        print(f"episode index {episode_index} is negative, force set to 0")
        episode_index = 0

    hf_dataset = dataset.hf_dataset
    episode_ds = hf_dataset.filter(lambda x: x["episode_index"] == episode_index)
    video_paths = dataset.encode_episode_videos(episode_index=episode_index)

    # Open every camera stream up front. Everything below runs inside
    # try/finally so the capture handles are released even if setup or
    # rendering raises (the original leaked them on any exception).
    caps = {}
    try:
        for key, path in video_paths.items():
            cap = cv2.VideoCapture(path)
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {path}")
            caps[key] = cap

        first_cap = caps[next(iter(caps))]
        fps = first_cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(first_cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0:
            # CAP_PROP_FPS can legitimately return 0 for some containers;
            # fall back so `interval=1000/fps` below cannot divide by zero.
            fps = getattr(dataset, "fps", None) or 30.0

        df = episode_ds.to_pandas()
        joint_states = np.vstack(df["observation.state"].values)  # (T, n_joints)
        timestamps = df["timestamp"].values
        duration_sec = timestamps[-1] - timestamps[0]

        n_joints = joint_states.shape[1]

        # Feature metadata sometimes nests the name list one level deep.
        joint_names = dataset.features["observation.state"]["names"]
        if isinstance(joint_names, list) and len(joint_names) == 1 and isinstance(joint_names[0], list):
            joint_names = joint_names[0]
        # Fall back to generic labels whenever the names don't match the state
        # dimension (the original only checked `len <= 1`, which could
        # IndexError at joint_names[i] below).
        if not isinstance(joint_names, list) or len(joint_names) != n_joints:
            joint_names = [f"Joint {i}" for i in range(n_joints)]

        n_joint_rows = (n_joints + 2) // 3  # 3 joint plots per row, rounded up

        fig = plt.figure(figsize=(18, 4 + 4 * n_joint_rows))
        fig.subplots_adjust(top=0.92, bottom=0.05, left=0.05, right=0.95, hspace=0.4, wspace=0.3)

        fig.suptitle(
            "Starforce Data Inspect System", x=0.5, y=0.98, fontsize=35, fontweight="bold", ha="center"
        )

        stats_text = (
            f"FPS: {fps:.2f} Total frames: {total_frames} "
            f"Episode: {episode_index} Duration: {duration_sec:.2f}s"
        )
        fig.text(0.5, 0.92, stats_text, ha="center", fontsize=21, fontweight="bold")

        # Applied after the titles above (which carry explicit sizes); affects
        # all axes created below.
        plt.rcParams.update(
            {
                "font.family": "sans-serif",
                "font.sans-serif": ["Arial", "DejaVu Sans"],
                "font.size": 12,
                "axes.titlesize": 14,
                "axes.labelsize": 13,
                "axes.spines.top": False,
                "axes.spines.right": False,
            }
        )

        # Row 0: videos (taller); rows 1..n_joint_rows: joint plots.
        gs = fig.add_gridspec(
            n_joint_rows + 1, 3, width_ratios=[1, 1, 1], height_ratios=[2] + [1] * n_joint_rows
        )

        # One image artist per camera; frames are blitted into it during animation.
        video_axes, video_imgs = {}, {}
        for idx, key in enumerate(video_paths.keys()):
            ax = fig.add_subplot(gs[0, idx])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(key)
            img = ax.imshow(np.zeros((480, 640, 3)), aspect="auto")
            ax.set_box_aspect(480 / 640)
            video_axes[key] = ax
            video_imgs[key] = img

        joint_axes, lines, time_lines = [], [], []
        base_colors = [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ]
        # Cycle the palette if there are more joints than base colors.
        colors = (base_colors * ((n_joints // len(base_colors)) + 1))[:n_joints]

        for i in range(n_joints):
            row, col = 1 + i // 3, i % 3
            ax = fig.add_subplot(gs[row, col])

            # Faint vertical gradient as a decorative background for each plot.
            gradient = np.linspace(0, 1, 256).reshape(256, 1)
            extent = [timestamps[0], timestamps[-1], joint_states[:, i].min(), joint_states[:, i].max()]
            ax.imshow(
                np.repeat(gradient, 256, axis=1),
                aspect="auto",
                cmap="Blues",
                alpha=0.1,
                extent=extent,
                origin="lower",
                zorder=0,
            )

            (line,) = ax.plot([], [], label=joint_names[i], color=colors[i], linewidth=2.5, zorder=1)
            lines.append(line)

            ax.set_xlabel("Time (s)")
            ax.set_ylabel("pos")
            ax.spines["left"].set_visible(False)

            ax.set_title(joint_names[i], fontweight="bold")
            ax.set_xlim(timestamps[0], timestamps[-1])
            y0, y1 = joint_states[:, i].min(), joint_states[:, i].max()
            # `or 1e-3` prevents a degenerate zero-height ylim when a joint is
            # constant over the whole episode.
            m = (y1 - y0) * 0.1 or 1e-3
            ax.set_ylim(y0 - m, y1 + m)

            # Red cursor marking the current time; moved every animation frame.
            tl = ax.axvline(x=timestamps[0], color="crimson", alpha=0.7, linewidth=1.2, zorder=2)
            time_lines.append(tl)
            joint_axes.append(ax)

        def init():
            """Blit init: clear the joint curves and return all animated artists."""
            for ln in lines:
                ln.set_data([], [])
            return lines + time_lines + list(video_imgs.values())

        def animate(frame_idx):
            """Advance videos by one frame and extend joint curves up to that time."""
            # Video frame count may exceed the state table length; clamp.
            idx = min(frame_idx, len(timestamps) - 1)
            t = timestamps[idx]
            print(
                f"\rProcessing frames: {frame_idx + 1}/{total_frames} ({(frame_idx+1)/total_frames*100:.1f}%)",
                end="",
                flush=True,
            )

            for key, cap in caps.items():
                ret, frame = cap.read()
                if ret:
                    # OpenCV decodes BGR; matplotlib expects RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    video_imgs[key].set_array(frame)
            for j, ln in enumerate(lines):
                ln.set_data(timestamps[: idx + 1], joint_states[: idx + 1, j])
            for tl in time_lines:
                tl.set_xdata([t, t])
            return lines + time_lines + list(video_imgs.values())

        anim = FuncAnimation(
            fig, animate, init_func=init, frames=total_frames, interval=1000 / fps, blit=True
        )

        save_dir = "outputs/"
        os.makedirs(save_dir, exist_ok=True)
        out_path = os.path.join(save_dir, f"episode_{episode_index}_animation.mp4")
        anim.save(out_path, writer="ffmpeg", fps=fps)
        # Terminate the "\r...%" progress line emitted during anim.save (the
        # original printed this newline before saving, leaving the progress
        # line unterminated).
        print()

        plt.close(fig)
        print(f"Animation saved to: {out_path}")
    finally:
        for cap in caps.values():
            cap.release()
|
|
|
|
| if __name__ == "__main__": |
| import argparse |
|
|
| parser = argparse.ArgumentParser( |
| description="Visualize joint states of a LeRobot dataset episode" |
| ) |
| parser.add_argument("dataset_path", type=str, help="Path or HF repo ID of the LeRobot dataset") |
| parser.add_argument("-i", type=int, default=89, help="Episode index to visualize") |
| args = parser.parse_args() |
| plot_episode_joint_states(args.dataset_path, args.i) |
|
|