File size: 4,312 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import sys
import torch

# Add project root to sys.path to allow absolute imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))

from wm.model.wan_base.modules.vae import _video_vae


class WanVAEWrapper(torch.nn.Module):
    """Frozen video VAE exposing a time-first [B, T, C, H, W] tensor interface.

    Thin wrapper around ``_video_vae``. The wrapped model is put in eval mode
    with gradients disabled. Per-channel latent statistics are registered as
    the ``mean`` / ``std`` buffers (so they follow ``.to(...)`` calls and are
    saved in the state dict) and are handed to the model as
    ``scale = [mean, 1 / std]`` on every encode/decode.
    """

    # Per-channel latent statistics, one entry per latent channel. They are
    # passed to the model as ``scale``; presumably the model normalizes latents
    # as (z - mean) / std -- NOTE(review): confirm against ``_video_vae``.
    LATENT_MEAN = (
        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
    )
    LATENT_STD = (
        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
    )
    # Checkpoint loaded when no ``pretrained_path`` is given (relative path,
    # so it is resolved against the current working directory).
    DEFAULT_PRETRAINED_PATH = "wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"

    def __init__(self, pretrained_path=None):
        """Load the VAE weights and register the latent statistics.

        Args:
            pretrained_path: Checkpoint path forwarded to ``_video_vae``.
                ``None`` selects ``DEFAULT_PRETRAINED_PATH``.
        """
        super().__init__()

        # Buffers rather than parameters: they move with the module but are
        # not returned by ``parameters()``.
        self.register_buffer(
            "mean", torch.tensor(self.LATENT_MEAN, dtype=torch.float32)
        )
        self.register_buffer(
            "std", torch.tensor(self.LATENT_STD, dtype=torch.float32)
        )

        if pretrained_path is None:
            pretrained_path = self.DEFAULT_PRETRAINED_PATH

        # The latent width must match the number of per-channel statistics (16).
        self.model = _video_vae(
            pretrained_path=pretrained_path,
            z_dim=len(self.LATENT_MEAN),
        ).eval().requires_grad_(False)

    def _latent_scale(self, device, dtype):
        """Return ``[mean, 1 / std]`` cast to ``device``/``dtype`` for the model."""
        return [
            self.mean.to(device=device, dtype=dtype),
            1.0 / self.std.to(device=device, dtype=dtype),
        ]

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a pixel-space video batch into latents.

        Args:
            x: Video of shape [batch_size, num_frames, num_channels, height,
                width]. Each batch element is encoded separately. Presumably
                num_frames must be of the form 1 + 4k for this VAE -- confirm.

        Returns:
            Latents of shape [batch_size, latent_frames, latent_channels,
            latent_height, latent_width] in the dtype produced by the model.
            The result is a permuted view and may be non-contiguous.
        """
        # The model works channels-first: [B, C, T, H, W].
        x = x.permute(0, 2, 1, 3, 4)
        scale = self._latent_scale(x.device, x.dtype)

        # Encode one clip at a time with a singleton batch dimension.
        latents = torch.stack(
            [self.model.encode(clip.unsqueeze(0), scale).squeeze(0) for clip in x],
            dim=0,
        )
        # Back to time-first: [B, T', C', H', W'].
        return latents.permute(0, 2, 1, 3, 4)

    def decode_to_pixel(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode latents back to pixel-space video.

        Args:
            latent: Latents of shape [batch_size, num_frames, num_channels,
                height, width]. Each batch element is decoded separately.

        Returns:
            float32 video of shape [batch_size, num_frames, num_channels,
            height, width] with values clamped to [-1, 1]. The result is a
            permuted view and may be non-contiguous.
        """
        # The model works channels-first: [B, C, T, H, W].
        zs = latent.permute(0, 2, 1, 3, 4)
        scale = self._latent_scale(latent.device, latent.dtype)

        # Decode one clip at a time; upcast to float32 before clamping.
        frames = torch.stack(
            [
                self.model.decode(z.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0)
                for z in zs
            ],
            dim=0,
        )
        # Back to time-first: [B, T, C, H, W].
        return frames.permute(0, 2, 1, 3, 4)

if __name__ == "__main__":
    # Smoke test: load the VAE, round-trip a random clip, and report the
    # shapes and reconstruction MSE. Exits with status 1 on failure.
    # Usage: python <this file> [CHECKPOINT_PATH]
    import traceback

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Testing WanVAEWrapper on {device}...")

    # An explicit command-line path wins; otherwise try the shared cluster
    # checkpoint and fall back to the repo-relative default.
    if len(sys.argv) > 1:
        ckpt_path = sys.argv[1]
    else:
        ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
        if not os.path.exists(ckpt_path):
            print(f"Warning: Checkpoint not found at {ckpt_path}")
            ckpt_path = "wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"

    try:
        vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)
        print("Model loaded successfully.")

        # Dummy video (B, T, C, H, W); fixed seed so runs are comparable.
        # T should be 1 + 4*k for Wan VAE (e.g., 5, 9, 13...)
        B, T, C, H, W = 1, 5, 3, 128, 128
        torch.manual_seed(0)
        dummy_video = torch.randn(B, T, C, H, W, device=device).clamp(-1, 1)

        print(f"Input video shape: {dummy_video.shape}")

        with torch.no_grad():
            latent = vae.encode(dummy_video)
            print(f"Latent shape: {latent.shape}")

            recon = vae.decode_to_pixel(latent)
            print(f"Reconstructed video shape: {recon.shape}")

            mse = torch.nn.functional.mse_loss(dummy_video, recon)
            print(f"Reconstruction MSE: {mse.item():.6f}")

    except Exception as e:
        # Top-level boundary: report the failure, then signal it through the
        # exit status instead of exiting 0.
        print(f"Test failed: {e}")
        traceback.print_exc()
        sys.exit(1)