File size: 3,625 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import cv2
import torch
import yaml
import os
import sys

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from wm.dataset.dataset import RoboticsDatasetWrapper

def _to_numpy(x):
    """Return ``x`` as a NumPy array.

    Torch tensors are detached and copied to the CPU first, because
    ``Tensor.numpy()`` refuses tensors that require grad or live on an
    accelerator device.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _compute_pixel_path(actions, height, width, fill=0.3):
    """Integrate per-step ``(dx, dy)`` actions into pixel positions anchored at the image center.

    Args:
        actions: array-like of shape ``[N, 2]`` with per-step displacements,
            already in the orientation used for drawing.
        height: frame height in pixels.
        width: frame width in pixels.
        fill: the largest absolute displacement is mapped to
            ``fill * min(height, width)`` pixels, so the trajectory stays within
            the central ``2 * fill`` fraction of the frame.

    Returns:
        Float array of shape ``[N + 1, 2]`` holding ``(x, y)`` pixel coordinates.
        Row ``t`` is the position *before* action ``t`` is applied: row 0 is the
        image center and row ``N`` is the final position. Therefore action ``t``
        moves the marker from row ``t`` to row ``t + 1``.
    """
    steps = np.asarray(actions, dtype=np.float64).reshape(-1, 2)
    # Exclusive prefix sum: the frame at step t is shown before action t is taken,
    # so its position must not already include action t.
    path = np.concatenate([np.zeros((1, 2)), np.cumsum(steps, axis=0)], axis=0)

    max_disp = np.abs(path).max()
    scale = (min(height, width) * fill) / max_disp if max_disp > 0 else 1.0

    return path * scale + np.array([width // 2, height // 2])


def _render_frame(image_chw, t, num_frames, pixel_path, actions):
    """Convert one ``[C, H, W]`` RGB float frame to BGR uint8 and draw the trajectory overlay.

    The overlay consists of the trail up to step ``t`` (fading Blue -> Red), a green
    dot at the current position, a white arrow for action ``t`` (when one exists),
    and a step label.
    """
    # [C, H, W] -> [H, W, C]; clip so out-of-range floats don't wrap around in uint8.
    rgb = (np.clip(np.transpose(image_chw, (1, 2, 0)), 0.0, 1.0) * 255).astype(np.uint8)
    frame = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    # Full path history up to time t; the colour (BGR) fades from Blue (start) to Red (end).
    for i in range(1, t + 1):
        pt1 = (int(pixel_path[i - 1, 0]), int(pixel_path[i - 1, 1]))
        pt2 = (int(pixel_path[i, 0]), int(pixel_path[i, 1]))
        color = (int(255 * (1 - i / num_frames)), 0, int(255 * (i / num_frames)))
        cv2.line(frame, pt1, pt2, color, 2, cv2.LINE_AA)

    # Current position (green dot).
    curr_pos = (int(pixel_path[t, 0]), int(pixel_path[t, 1]))
    cv2.circle(frame, curr_pos, 5, (0, 255, 0), -1, cv2.LINE_AA)

    # Current action vector (white arrow). It ends exactly on the next trail vertex.
    # When there is one fewer action than frames, the last frame has no arrow.
    if t < len(actions):
        arrow_end = (int(pixel_path[t + 1, 0]), int(pixel_path[t + 1, 1]))
        cv2.arrowedLine(frame, curr_pos, arrow_end, (255, 255, 255), 2, tipLength=0.3)

    cv2.putText(frame, f"Step: {t}", (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return frame


def create_overlay_video(config_path, sample_idx=0, output_dir="results/visualizations", fps=10.0):
    """Render one dataset sample as an .mp4 with its accumulated action path drawn on top.

    The dataset is chosen by ``config['dataset']['name']`` in the YAML file at
    ``config_path`` and is loaded through ``RoboticsDatasetWrapper.get_dataset``.

    Args:
        config_path: path to a YAML config containing ``dataset.name``.
        sample_idx: index of the sample to visualize.
        output_dir: directory for the output video (created if missing).
        fps: frame rate of the written video.

    Returns:
        Path of the written video file.

    Raises:
        ValueError: if the sample has fewer than ``T - 1`` actions for ``T`` frames.
        RuntimeError: if OpenCV cannot open a video writer for the output path.
    """
    os.makedirs(output_dir, exist_ok=True)

    # 1. Load Config
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # 2. Setup Dataset
    print(f"Loading dataset: {config['dataset']['name']}")
    dataset = RoboticsDatasetWrapper.get_dataset(config['dataset']['name'])

    # 3. Get a sample
    sample = dataset[sample_idx]
    obs = _to_numpy(sample['obs'])  # expected [T, 3, H, W] floats in [0, 1]
    actions = _to_numpy(sample['action'])  # expected [T, 2] (dx, dy)

    # Transformations requested by user: switch dx and dy, then negate both.
    actions = -actions[:, [1, 0]]

    T, _, H, W = obs.shape
    if len(actions) < T - 1:
        raise ValueError(
            f"Expected at least {T - 1} actions for {T} frames, got {len(actions)}"
        )

    # 4-5. Accumulated path, scaled to pixels and anchored at the image center.
    pixel_path = _compute_pixel_path(actions, H, W)

    # 6. Video Setup
    video_path = os.path.join(output_dir, f"overlay_video_sample_{sample_idx}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, fps, (W, H))
    if not out.isOpened():
        # Without this check, a missing codec silently produces no file.
        raise RuntimeError(f"Could not open video writer for {video_path}")

    print(f"Generating overlay video: {W}x{H}, {T} frames")
    try:
        for t in range(T):
            out.write(_render_frame(obs[t], t, T, pixel_path, actions))
    finally:
        out.release()

    print(f"Overlay video saved to: {video_path}")
    return video_path

if __name__ == "__main__":
    # Example usage: python wm/eval/visualize_overlay.py --sample 0
    import argparse

    parser = argparse.ArgumentParser(
        description="Render a dataset sample as a video with its action trajectory overlaid."
    )
    parser.add_argument("--config", type=str, default="wm/config/fulltraj_dit/lang_table.yaml",
                        help="YAML config whose dataset.name selects the dataset.")
    parser.add_argument("--sample", type=int, default=0,
                        help="Index of the dataset sample to render.")
    # Previously hard-coded via create_overlay_video's default; same default here.
    parser.add_argument("--output-dir", type=str, default="results/visualizations",
                        help="Directory the .mp4 is written to (created if missing).")
    args = parser.parse_args()

    create_overlay_video(args.config, args.sample, output_dir=args.output_dir)