# world_model/wm/eval/visualize_action_trajectory.py
# (Uploaded via huggingface_hub by t1an, commit f17ae24.)
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy as np
import yaml
import torch
from torch.utils.data import DataLoader
from wm.dataset.dataset import RoboticsDatasetWrapper
import os
import cv2
def create_side_by_side_video(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Render a side-by-side video: observation frames next to a per-step action plot.

    Each output frame places the raw observation image on the left and a 2D
    quiver plot of that step's action delta (dx, dy) on the right.

    Args:
        config_path: Path to a YAML config with config['dataset']['name'].
        sample_idx: Index of the dataset sample to visualize.
        output_dir: Directory the .mp4 is written to (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load config
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # 2. Setup dataset
    print(f"Loading dataset for video: {config['dataset']['name']}")
    dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    # 3. Get a sample
    sample = dataset[sample_idx]
    obs = sample['obs']         # assumes [T, 3, H, W] in [0, 1] -- TODO confirm with dataset
    actions = sample['action']  # assumes [T, 2] (dx, dy deltas)
    if isinstance(obs, torch.Tensor):
        # detach().cpu() so CUDA / grad-tracking tensors convert safely
        obs = obs.detach().cpu().numpy()
    if isinstance(actions, torch.Tensor):
        actions = actions.detach().cpu().numpy()

    # Use raw actions (deltas)
    T, C, H, W = obs.shape
    print(f"Video length: {T}, Frame size: {H}x{W}")

    # Fixed plot limits symmetric around (0, 0). Guard against an all-zero
    # trajectory, which would otherwise yield degenerate [-0, 0] axis limits.
    max_val = np.abs(actions).max()
    lim = max(max_val * 1.2, 1e-6)  # 20% padding
    x_lims = [-lim, lim]
    y_lims = [-lim, lim]

    # Video setup. avc1 plays better in browser/VSCode players when available,
    # but mp4v is safer for plain OpenCV builds, so use mp4v.
    plot_size = H
    video_path = os.path.join(output_dir, f"action_video_sample_{sample_idx}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, 5.0, (W + plot_size, H))

    fig, ax = plt.subplots(figsize=(5, 5), dpi=100)
    canvas = FigureCanvas(fig)
    try:
        for t in range(T):
            # 1. Observation frame: [C, H, W] float in [0, 1] -> [H, W, C] uint8 BGR
            frame = (np.transpose(obs[t], (1, 2, 0)) * 255).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # 2. Action plot: arrow from (0, 0) to (dx, dy)
            ax.clear()
            dx, dy = actions[t, 0], actions[t, 1]
            ax.quiver(0, 0, dx, dy, angles='xy', scale_units='xy', scale=1, color='red', width=0.015)
            # Dot at the arrow tip for visibility
            ax.scatter(dx, dy, color='red', s=50, edgecolors='k', zorder=5)
            # Origin axes
            ax.axhline(0, color='black', linewidth=0.8, alpha=0.3)
            ax.axvline(0, color='black', linewidth=0.8, alpha=0.3)
            ax.set_xlim(x_lims)
            ax.set_ylim(y_lims)
            ax.set_aspect('equal')
            ax.set_title(f"Step {t} | Action: ({dx:.3f}, {dy:.3f})", fontsize=10)
            ax.set_xlabel("dx")
            ax.set_ylabel("dy")
            ax.grid(True, linestyle=':', alpha=0.5)

            # Rasterize the figure. buffer_rgba() replaces tostring_rgb(),
            # which was deprecated and removed in Matplotlib 3.8+.
            canvas.draw()
            plot_img = np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[..., :3])
            plot_img = cv2.cvtColor(plot_img, cv2.COLOR_RGB2BGR)
            plot_img = cv2.resize(plot_img, (plot_size, H))

            # 3. Combine side-by-side and append to the video
            out.write(np.hstack((frame, plot_img)))
    finally:
        # Always release the writer and figure, even if a frame fails mid-loop
        out.release()
        plt.close(fig)
    print(f"Side-by-side video saved to: {video_path}")
def visualize_action_traj(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Plot one sample's 2D action deltas as origin-anchored arrows and save a PNG.

    Arrows are colored by timestep (viridis gradient) and annotated with their
    step index. The output goes to ``output_dir/action_traj_sample_{sample_idx}.png``.

    Args:
        config_path: Path to a YAML config with config['dataset']['name'].
        sample_idx: Index of the dataset sample to visualize.
        output_dir: Directory the image is written to (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load config
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # 2. Setup dataset
    print(f"Loading dataset: {config['dataset']['name']}")
    dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    # 3. Get a sample
    sample = dataset[sample_idx]
    actions = sample['action']  # assumes [T, 2] (dx, dy deltas)
    if isinstance(actions, torch.Tensor):
        # detach().cpu() so CUDA / grad-tracking tensors convert safely
        actions = actions.detach().cpu().numpy()

    # Use raw actions (deltas)
    T = actions.shape[0]
    print(f"Trajectory length: {T}")

    # 4. Plot 2D action vectors
    fig = plt.figure(figsize=(8, 8))
    try:
        # Limits symmetric around (0, 0); guard against an all-zero trajectory,
        # which would otherwise yield degenerate [-0, 0] axis limits.
        max_val = np.abs(actions).max()
        lim = max(max_val * 1.2, 1e-6)

        # Origin axes
        plt.axhline(0, color='black', linewidth=1, alpha=0.5)
        plt.axvline(0, color='black', linewidth=1, alpha=0.5)

        # All arrows, colored by timestep
        colors = plt.cm.viridis(np.linspace(0, 1, T))
        for i in range(T):
            dx, dy = actions[i, 0], actions[i, 1]
            plt.quiver(0, 0, dx, dy, angles='xy', scale_units='xy', scale=1,
                       color=colors[i], alpha=0.8, width=0.005)
            plt.scatter(dx, dy, color=colors[i], s=30, edgecolors='k', alpha=0.8)
            plt.annotate(f"{i}", (dx, dy), textcoords="offset points", xytext=(0, 5), ha='center', fontsize=8)

        plt.title(f"2D Action Vectors (Deltas) - Sample {sample_idx}", fontsize=14)
        plt.xlabel("dx", fontsize=12)
        plt.ylabel("dy", fontsize=12)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.xlim([-lim, lim])
        plt.ylim([-lim, lim])
        plt.gca().set_aspect('equal')

        output_path = os.path.join(output_dir, f"action_traj_sample_{sample_idx}.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Close the figure so repeated calls (e.g. the __main__ sample loop)
        # don't accumulate open 300-dpi figures.
        plt.close(fig)
    print(f"Action trajectory plot saved to: {output_path}")
if __name__ == "__main__":
    import argparse
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="/storage/ice-shared/ae8803che/hxue/data/world_model/wm/config/fulltraj_dit/lang_table.yaml")
    parser.add_argument("--samples", type=int, nargs='+', default=[0, 10, 50, 100])
    args = parser.parse_args()

    # Best-effort: a failing sample (e.g. index out of range) should not stop
    # the remaining samples, but the full traceback is printed for diagnosis.
    for idx in args.samples:
        try:
            visualize_action_traj(args.config, sample_idx=idx)
            create_side_by_side_video(args.config, sample_idx=idx)
        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            traceback.print_exc()