| | import numpy as np |
| | import cv2 |
| | import torch |
| | import yaml |
| | import os |
| | import sys |
| |
|
| | |
| | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) |
| |
|
| | from wm.dataset.dataset import RoboticsDatasetWrapper |
| |
|
def create_overlay_video(config_path, sample_idx=0, output_dir="results/visualizations"):
    """Render an MP4 overlaying the cumulative action trajectory on each
    observation frame of one dataset sample.

    Args:
        config_path: Path to a YAML config containing a ``dataset.name`` entry.
        sample_idx: Index of the dataset sample to visualize.
        output_dir: Directory the video is written to (created if missing).

    Raises:
        RuntimeError: If the OpenCV video writer cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    print(f"Loading dataset: {config['dataset']['name']}")
    dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    sample = dataset[sample_idx]
    obs = sample['obs']
    actions = sample['action']

    if isinstance(obs, torch.Tensor):
        obs = obs.numpy()
    if isinstance(actions, torch.Tensor):
        actions = actions.numpy()

    # Swap the action columns to (x, y) order and flip sign so the path is
    # drawn in image coordinates. NOTE(review): this assumes the dataset
    # stores actions as (row, col) deltas with an inverted axis convention —
    # TODO confirm against the dataset definition.
    actions = actions[:, [1, 0]]
    actions = -actions

    # obs is expected as (T, C, H, W) with float values in [0, 1] —
    # presumably RGB; verify against the dataset loader.
    T, C, H, W = obs.shape

    # Integrate per-step deltas into an absolute path.
    path = np.cumsum(actions, axis=0)

    # Scale so the path spans at most ~30% of the smaller image dimension.
    max_disp = np.abs(path).max()
    scale = (min(H, W) * 0.3) / max_disp if max_disp > 0 else 1.0

    # Center the path on the image; column 0 is x, column 1 is y.
    pixel_path = path * scale + np.array([W // 2, H // 2])

    video_path = os.path.join(output_dir, f"overlay_video_sample_{sample_idx}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, 10.0, (W, H))
    # Fail loudly instead of silently writing an empty/zero-byte file when
    # the codec or path is unavailable.
    if not out.isOpened():
        raise RuntimeError(f"Could not open video writer for {video_path}")

    print(f"Generating overlay video: {W}x{H}, {T} frames")

    try:
        for t in range(T):
            # CHW float -> HWC uint8 BGR. Clip first so values slightly
            # outside [0, 1] don't wrap around on the uint8 cast.
            frame = np.transpose(obs[t], (1, 2, 0))
            frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Draw the trajectory up to the current step, fading between the
            # two endpoint colors. (range(1, 1) is empty at t == 0, so no
            # explicit guard is needed.)
            for i in range(1, t + 1):
                pt1 = (int(pixel_path[i - 1, 0]), int(pixel_path[i - 1, 1]))
                pt2 = (int(pixel_path[i, 0]), int(pixel_path[i, 1]))
                color = (int(255 * (1 - i / T)), 0, int(255 * (i / T)))
                cv2.line(frame, pt1, pt2, color, 2, cv2.LINE_AA)

            # Current position marker (green filled circle).
            curr_pos = (int(pixel_path[t, 0]), int(pixel_path[t, 1]))
            cv2.circle(frame, curr_pos, 5, (0, 255, 0), -1, cv2.LINE_AA)

            # Arrow for this step's action, drawn at the same pixel scale
            # as the integrated path.
            adx, ady = actions[t, 0] * scale, actions[t, 1] * scale
            arrow_end = (int(pixel_path[t, 0] + adx), int(pixel_path[t, 1] + ady))
            cv2.arrowedLine(frame, curr_pos, arrow_end, (255, 255, 255), 2, tipLength=0.3)

            # Frame-index annotation.
            cv2.putText(frame, f"Step: {t}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            out.write(frame)
    finally:
        # Release even if a drawing call raises, so the container is
        # finalized and any partial video is flushed to disk.
        out.release()
    print(f"Overlay video saved to: {video_path}")
| |
|
if __name__ == "__main__":
    import argparse

    # CLI entry point: choose a config file and a sample index, then render.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config",
        type=str,
        default="wm/config/fulltraj_dit/lang_table.yaml",
    )
    arg_parser.add_argument("--sample", type=int, default=0)
    cli_args = arg_parser.parse_args()

    create_overlay_video(cli_args.config, cli_args.sample)
| |
|