| | import matplotlib.pyplot as plt |
| | from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas |
| | import numpy as np |
| | import yaml |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from wm.dataset.dataset import RoboticsDatasetWrapper |
| | import os |
| | import cv2 |
| |
|
def create_side_by_side_video(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Render an MP4 pairing each observation frame with a plot of its 2D action.

    Loads one sample from the dataset named in the YAML config, then writes a
    video where the left half is the RGB observation at step t and the right
    half is a quiver plot of the (dx, dy) action taken at that step.

    Args:
        config_path: Path to a YAML config containing a ``dataset.name`` entry.
        sample_idx: Index of the dataset sample to visualize.
        output_dir: Directory the MP4 is written into (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    print(f"Loading dataset for video: {config['dataset']['name']}")
    dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    sample = dataset[sample_idx]
    obs = sample['obs']
    actions = sample['action']

    # .detach().cpu() so tensors that live on GPU or carry grad convert cleanly;
    # a bare .numpy() raises in both cases.
    if isinstance(obs, torch.Tensor):
        obs = obs.detach().cpu().numpy()
    if isinstance(actions, torch.Tensor):
        actions = actions.detach().cpu().numpy()

    # obs assumed to be (T, C, H, W) with values in [0, 1] — TODO confirm
    # against the dataset implementation.
    T, C, H, W = obs.shape
    print(f"Video length: {T}, Frame size: {H}x{W}")

    # Symmetric axis limits with 20% headroom. Guard the all-zero action case,
    # which would otherwise collapse the axes to a single point.
    max_val = np.abs(actions).max()
    lim = max_val * 1.2 if max_val > 0 else 1.0
    x_lims = [-lim, lim]
    y_lims = [-lim, lim]

    # The action plot is rendered square at the frame height, so the output
    # video is (W + H) pixels wide.
    plot_size = H
    video_path = os.path.join(output_dir, f"action_video_sample_{sample_idx}.mp4")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, 5.0, (W + plot_size, H))

    fig, ax = plt.subplots(figsize=(5, 5), dpi=100)
    canvas = FigureCanvas(fig)

    try:
        for t in range(T):
            # CHW float frame -> HWC uint8 BGR for OpenCV.
            frame = (np.transpose(obs[t], (1, 2, 0)) * 255).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            ax.clear()

            dx, dy = actions[t, 0], actions[t, 1]

            # Arrow from the origin to the current action delta.
            ax.quiver(0, 0, dx, dy, angles='xy', scale_units='xy', scale=1, color='red', width=0.015)
            ax.scatter(dx, dy, color='red', s=50, edgecolors='k', zorder=5)

            ax.axhline(0, color='black', linewidth=0.8, alpha=0.3)
            ax.axvline(0, color='black', linewidth=0.8, alpha=0.3)

            ax.set_xlim(x_lims)
            ax.set_ylim(y_lims)
            ax.set_aspect('equal')
            ax.set_title(f"Step {t} | Action: ({dx:.3f}, {dy:.3f})", fontsize=10)
            ax.set_xlabel("dx")
            ax.set_ylabel("dy")
            ax.grid(True, linestyle=':', alpha=0.5)

            # Rasterize the figure. buffer_rgba() replaces tostring_rgb(),
            # which was deprecated in Matplotlib 3.8 and later removed;
            # RGBA2BGR drops the alpha channel in a single C-level pass.
            canvas.draw()
            plot_img = np.asarray(canvas.buffer_rgba())
            plot_img = cv2.cvtColor(plot_img, cv2.COLOR_RGBA2BGR)
            plot_img = cv2.resize(plot_img, (plot_size, H))

            combined = np.hstack((frame, plot_img))
            out.write(combined)
    finally:
        # Release the writer and close the figure even if rendering fails
        # mid-loop; otherwise the file handle and figure leak.
        out.release()
        plt.close(fig)

    print(f"Side-by-side video saved to: {video_path}")
| |
|
def visualize_action_traj(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Plot all 2D action deltas of one dataset sample as arrows from the origin.

    Each step's (dx, dy) is drawn as a quiver arrow colored by its position in
    the trajectory (viridis colormap) and annotated with its step index. The
    figure is saved as a PNG under ``output_dir``.

    Args:
        config_path: Path to a YAML config containing a ``dataset.name`` entry.
        sample_idx: Index of the dataset sample to visualize.
        output_dir: Directory the PNG is written into (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    print(f"Loading dataset: {config['dataset']['name']}")
    dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    sample = dataset[sample_idx]
    actions = sample['action']

    # .detach().cpu() so tensors on GPU or with grad convert cleanly.
    if isinstance(actions, torch.Tensor):
        actions = actions.detach().cpu().numpy()

    T = actions.shape[0]
    print(f"Trajectory length: {T}")

    fig = plt.figure(figsize=(8, 8))

    # Symmetric axis limits with 20% headroom; guard the all-zero action case,
    # which would otherwise produce degenerate xlim/ylim of [0, 0].
    max_val = np.abs(actions).max()
    lim = max_val * 1.2 if max_val > 0 else 1.0

    plt.axhline(0, color='black', linewidth=1, alpha=0.5)
    plt.axvline(0, color='black', linewidth=1, alpha=0.5)

    # Color each step by its trajectory position so temporal order is visible.
    colors = plt.cm.viridis(np.linspace(0, 1, T))
    for i in range(T):
        dx, dy = actions[i, 0], actions[i, 1]
        plt.quiver(0, 0, dx, dy, angles='xy', scale_units='xy', scale=1,
                   color=colors[i], alpha=0.8, width=0.005)
        plt.scatter(dx, dy, color=colors[i], s=30, edgecolors='k', alpha=0.8)
        plt.annotate(f"{i}", (dx, dy), textcoords="offset points", xytext=(0, 5), ha='center', fontsize=8)

    plt.title(f"2D Action Vectors (Deltas) - Sample {sample_idx}", fontsize=14)
    plt.xlabel("dx", fontsize=12)
    plt.ylabel("dy", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.xlim([-lim, lim])
    plt.ylim([-lim, lim])
    plt.gca().set_aspect('equal')

    output_path = os.path.join(output_dir, f"action_traj_sample_{sample_idx}.png")
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    # Close the figure: this function is called in a loop over many samples,
    # and un-closed pyplot figures accumulate and leak memory.
    plt.close(fig)
    print(f"Action trajectory plot saved to: {output_path}")
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--config", type=str, default="/storage/ice-shared/ae8803che/hxue/data/world_model/wm/config/fulltraj_dit/lang_table.yaml") |
| | parser.add_argument("--samples", type=int, nargs='+', default=[0, 10, 50, 100]) |
| | args = parser.parse_args() |
| | |
| | for idx in args.samples: |
| | try: |
| | visualize_action_traj(args.config, sample_idx=idx) |
| | create_side_by_side_video(args.config, sample_idx=idx) |
| | except Exception as e: |
| | print(f"Error processing sample {idx}: {e}") |
| |
|