"""Visualization utilities for robotics-dataset action trajectories.

Renders the per-step 2D action deltas of a dataset sample either as a
static quiver plot (`visualize_action_traj`) or as a video pairing each
observation frame with its action arrow (`create_side_by_side_video`).
"""

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from torch.utils.data import DataLoader  # noqa: F401  (kept: may be used by importers)

from wm.dataset.dataset import RoboticsDatasetWrapper


def _load_dataset(config_path, label="dataset"):
    """Read the YAML config at *config_path* and return the wrapped dataset.

    *label* only affects the progress message so callers can keep their
    original log wording.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    name = config["dataset"]["name"]
    print(f"Loading {label}: {name}")
    return RoboticsDatasetWrapper.get_dataset(name)


def _as_numpy(x):
    """Convert a torch tensor (possibly on GPU / requiring grad) to numpy.

    The original plain ``.numpy()`` fails for CUDA or grad-tracking tensors;
    ``detach().cpu()`` handles both. Non-tensors pass through unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x


def _symmetric_limit(actions, pad=1.2, floor=1e-6):
    """Return an axis limit symmetric around 0 covering all action magnitudes.

    *pad* adds 20% headroom; *floor* guards against a singular (zero-width)
    axis range when every action in the sample is exactly zero.
    """
    return max(float(np.abs(actions).max()) * pad, floor)


def create_side_by_side_video(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Write an .mp4 pairing each observation frame with its action arrow.

    Left pane: the RGB observation at step t. Right pane: a quiver plot of
    the raw action delta (dx, dy) at step t, drawn from the origin with
    fixed symmetric axis limits so arrow scale is comparable across steps.

    Args:
        config_path: YAML config file with a ``dataset.name`` entry.
        sample_idx: index of the dataset sample to render.
        output_dir: directory the video is written into (created if missing).

    Raises:
        RuntimeError: if the OpenCV video writer cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)
    dataset = _load_dataset(config_path, label="dataset for video")

    sample = dataset[sample_idx]
    obs = _as_numpy(sample['obs'])         # [T, 3, H, W], values in [0, 1]
    actions = _as_numpy(sample['action'])  # [T, 2] raw deltas

    T, C, H, W = obs.shape
    print(f"Video length: {T}, Frame size: {H}x{W}")

    # Fixed, symmetric plot limits so the arrow scale is stable across frames.
    lim = _symmetric_limit(actions)
    x_lims = [-lim, lim]
    y_lims = [-lim, lim]

    plot_size = H  # square plot pane matched to the frame height
    video_path = os.path.join(output_dir, f"action_video_sample_{sample_idx}.mp4")
    # mp4v is the safest fourcc for plain OpenCV builds; avc1 plays better in
    # browsers/VSCode but needs extra codecs.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, 5.0, (W + plot_size, H))
    if not out.isOpened():
        # Previously a writer failure produced a silent empty file.
        raise RuntimeError(f"Could not open video writer for {video_path}")

    fig, ax = plt.subplots(figsize=(5, 5), dpi=100)
    canvas = FigureCanvas(fig)
    try:
        for t in range(T):
            # Observation: [C, H, W] float in [0, 1] -> [H, W, C] uint8 BGR.
            frame = (np.transpose(obs[t], (1, 2, 0)) * 255).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Action pane: a single arrow from the origin to (dx, dy).
            ax.clear()
            dx, dy = actions[t, 0], actions[t, 1]
            ax.quiver(0, 0, dx, dy, angles='xy', scale_units='xy', scale=1,
                      color='red', width=0.015)
            # Dot at the arrow tip for visibility.
            ax.scatter(dx, dy, color='red', s=50, edgecolors='k', zorder=5)
            ax.axhline(0, color='black', linewidth=0.8, alpha=0.3)
            ax.axvline(0, color='black', linewidth=0.8, alpha=0.3)
            ax.set_xlim(x_lims)
            ax.set_ylim(y_lims)
            ax.set_aspect('equal')
            ax.set_title(f"Step {t} | Action: ({dx:.3f}, {dy:.3f})", fontsize=10)
            ax.set_xlabel("dx")
            ax.set_ylabel("dy")
            ax.grid(True, linestyle=':', alpha=0.5)

            # Rasterize the figure. buffer_rgba() replaces tostring_rgb(),
            # which was deprecated in Matplotlib 3.8 and removed in 3.10.
            canvas.draw()
            plot_img = np.asarray(canvas.buffer_rgba())  # [h, w, 4] uint8 RGBA
            plot_img = cv2.cvtColor(plot_img, cv2.COLOR_RGBA2BGR)
            plot_img = cv2.resize(plot_img, (plot_size, H))

            # Combine side-by-side: [H, W + plot_size, 3].
            out.write(np.hstack((frame, plot_img)))
    finally:
        # Release resources even if rendering fails mid-loop.
        out.release()
        plt.close(fig)

    print(f"Side-by-side video saved to: {video_path}")


def visualize_action_traj(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Save a static plot of all 2D action vectors (deltas) for one sample.

    Each step's raw delta (dx, dy) is drawn as an arrow from the origin,
    colored by time (viridis) and annotated with its step index.

    Args:
        config_path: YAML config file with a ``dataset.name`` entry.
        sample_idx: index of the dataset sample to plot.
        output_dir: directory the .png is written into (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    dataset = _load_dataset(config_path, label="dataset")

    sample = dataset[sample_idx]
    actions = _as_numpy(sample['action'])  # [T, 2] raw deltas
    T = actions.shape[0]
    print(f"Trajectory length: {T}")

    lim = _symmetric_limit(actions)

    fig = plt.figure(figsize=(8, 8))
    try:
        # Origin axes for reference.
        plt.axhline(0, color='black', linewidth=1, alpha=0.5)
        plt.axvline(0, color='black', linewidth=1, alpha=0.5)

        # Color encodes time: early steps dark purple, late steps yellow.
        colors = plt.cm.viridis(np.linspace(0, 1, T))
        for i in range(T):
            dx, dy = actions[i, 0], actions[i, 1]
            plt.quiver(0, 0, dx, dy, angles='xy', scale_units='xy', scale=1,
                       color=colors[i], alpha=0.8, width=0.005)
            plt.scatter(dx, dy, color=colors[i], s=30, edgecolors='k', alpha=0.8)
            plt.annotate(f"{i}", (dx, dy), textcoords="offset points",
                         xytext=(0, 5), ha='center', fontsize=8)

        plt.title(f"2D Action Vectors (Deltas) - Sample {sample_idx}", fontsize=14)
        plt.xlabel("dx", fontsize=12)
        plt.ylabel("dy", fontsize=12)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.xlim([-lim, lim])
        plt.ylim([-lim, lim])
        plt.gca().set_aspect('equal')

        output_path = os.path.join(output_dir, f"action_traj_sample_{sample_idx}.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Action trajectory plot saved to: {output_path}")
    finally:
        # The original leaked one figure per call; close it explicitly so
        # iterating many samples doesn't accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str,
        default="/storage/ice-shared/ae8803che/hxue/data/world_model/wm/config/fulltraj_dit/lang_table.yaml")
    parser.add_argument("--samples", type=int, nargs='+', default=[0, 10, 50, 100])
    args = parser.parse_args()

    for idx in args.samples:
        try:
            visualize_action_traj(args.config, sample_idx=idx)
            create_side_by_side_video(args.config, sample_idx=idx)
        except Exception as e:
            # Best-effort batch: one bad sample shouldn't abort the rest.
            print(f"Error processing sample {idx}: {e}")