File size: 5,518 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import sys
import torch
import cv2
import numpy as np
from PIL import Image
from pathlib import Path

# Add project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from wm.model.tokenizer.wan_tokenizer import WanVAEWrapper

def load_video(video_path, num_frames=13):
    """Load ``num_frames`` RGB frames sampled at uniform intervals from a video.

    Args:
        video_path: path to a video file readable by OpenCV.
        num_frames: number of frames to return (padded by repeating the last
            frame if the video is shorter than expected).

    Returns:
        np.ndarray of shape (T, H, W, C), dtype uint8, RGB order — or ``None``
        if the video cannot be read or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames <= 0:
            # Fix: release the capture on this early-exit path too
            # (the original leaked the VideoCapture handle here).
            return None

        # Uniformly spaced target indices; a set gives O(1) membership tests
        # instead of an O(num_frames) numpy scan per decoded frame.
        indices = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
        frames = []

        curr_idx = 0
        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if curr_idx in indices:
                # OpenCV decodes BGR; convert to RGB for downstream use.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
            curr_idx += 1
    finally:
        cap.release()

    if not frames:
        return None

    # Pad by repeating the last frame if decoding stopped early
    # (e.g. the reported frame count was optimistic).
    while len(frames) < num_frames:
        frames.append(frames[-1])

    return np.stack(frames)  # (T, H, W, C)

def save_comparison(gt_frames, recon_frames, save_path, amp=1.0):
    """Save a 3-row comparison grid: ground truth, reconstruction, abs-diff.

    Fix: the original comments claimed the difference was multiplied by 5
    while the code multiplied by 1.0; the amplification is now an explicit
    parameter whose default preserves the original behavior.

    Args:
        gt_frames: (T, H, W, C) uint8 ground-truth frames.
        recon_frames: (T, H, W, C) uint8 reconstructed frames.
        save_path: output image path (str or Path).
        amp: multiplier applied to the absolute difference before display;
            1.0 (default) shows the raw error, larger values make small
            errors more visible.
    """
    T, H, W, C = gt_frames.shape

    # Per-pixel absolute error, optionally amplified for visibility.
    diff = np.abs(gt_frames.astype(np.float32) - recon_frames.astype(np.float32))
    diff_vis = (diff * amp).clip(0, 255).astype(np.uint8)

    # Grid layout: row 0 = GT, row 1 = reconstruction, row 2 = amplified diff;
    # one column per frame.
    combined = np.zeros((H * 3, W * T, C), dtype=np.uint8)

    for t in range(T):
        combined[:H, t*W:(t+1)*W] = gt_frames[t]
        combined[H:H*2, t*W:(t+1)*W] = recon_frames[t]
        combined[H*2:, t*W:(t+1)*W] = diff_vis[t]

    Image.fromarray(combined).save(save_path)

def main(num_examples=4):
    """Round-trip sample videos from several datasets through the Wan VAE.

    For each configured dataset, loads up to ``num_examples`` videos, encodes
    and decodes them with the Wan VAE, and saves a GT / reconstruction / diff
    comparison grid as a PNG per video.

    Args:
        num_examples: videos visualized per dataset. Generalizes the original
            hard-coded limit of 4 (default preserves that behavior).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Initialize VAE
    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print(f"Checkpoint not found at {ckpt_path}")
        return

    print("Loading VAE model...")
    vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)

    # Dataset name -> root directory containing metadata and videos.
    datasets = {
        "language_table": "/storage/ice-shared/ae8803che/hxue/data/dataset/language_table/",
        "dreamer4": "/storage/ice-shared/ae8803che/hxue/data/dataset/dreamer4_processed/",
        "rt1": "/storage/ice-shared/ae8803che/hxue/data/dataset/rt1/",
        "recon": "/storage/ice-shared/ae8803che/hxue/data/dataset/recon_processed/"
    }

    output_dir = Path("/storage/ice-shared/ae8803che/hxue/data/world_model/results/test_wan_vae")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Wan VAE requires T = 1 + 4k frames. We'll use 13 frames for reconstruction.
    T_vae = 13
    # Final visualization will sample 8 frames as requested.
    T_vis = 8

    for ds_name, ds_root in datasets.items():
        print(f"\n--- Processing {ds_name} ---")
        meta_path = Path(ds_root) / "metadata_lite.pt"
        if not meta_path.exists():
            print(f"Metadata not found for {ds_name}")
            continue

        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only safe because the metadata file is trusted local data.
        metadata = torch.load(meta_path, map_location='cpu', weights_only=False)

        # Visualize up to num_examples videos from this dataset.
        count = 0
        for i in range(len(metadata)):
            if count >= num_examples: break

            video_rel_path = metadata[i]['video_path']
            # Handle both absolute and relative paths
            if video_rel_path.startswith('/'):
                video_path = video_rel_path
            else:
                video_path = str(Path(ds_root) / video_rel_path)

            if not os.path.exists(video_path):
                continue

            print(f"Example {count}: {video_path}")

            # Load frames
            raw_frames = load_video(video_path, num_frames=T_vae)
            if raw_frames is None: continue

            # Prepare tensor for VAE: (B, T, C, H, W) in [-1, 1]
            # Wan VAE expects frames to be normalized to [-1, 1]
            video_tensor = torch.from_numpy(raw_frames).permute(0, 3, 1, 2).float() / 255.0
            video_tensor = (video_tensor * 2.0 - 1.0).unsqueeze(0).to(device)

            with torch.no_grad():
                # Encode and Decode
                latent = vae.encode(video_tensor)
                print(f"  Input shape (B, T, C, H, W): {video_tensor.shape}")
                print(f"  Latent shape (B, T', C', H', W'): {latent.shape}")
                recon = vae.decode_to_pixel(latent) # [B, T, C, H, W] in [-1, 1]

            # Convert reconstructed back to uint8 RGB (T, H, W, C);
            # inverse of the [-1, 1] normalization above.
            recon_np = ((recon[0].cpu().permute(0, 2, 3, 1).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
            gt_np = raw_frames # (T, H, W, C) uint8 RGB

            # Sample 8 frames for visualization as requested
            vis_indices = np.linspace(0, T_vae - 1, T_vis, dtype=int)
            vis_gt = gt_np[vis_indices]
            vis_recon = recon_np[vis_indices]

            save_path = output_dir / f"{ds_name}_sample_{count}.png"
            save_comparison(vis_gt, vis_recon, save_path)
            print(f"  Result saved to {save_path}")
            count += 1

    print("\nAll tests completed.")

# Script entry point: run the VAE reconstruction test when executed directly.
if __name__ == "__main__":
    main()