import os
import sys
import torch
# Add project root to sys.path to allow absolute imports
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")))
from wm.model.wan_base.modules.vae import _video_vae
class WanVAEWrapper(torch.nn.Module):
def __init__(self, pretrained_path=None):
super().__init__()
# Mean and std for scaling latents
mean = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
std = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
self.register_buffer("std", torch.tensor(std, dtype=torch.float32))
# Default path if none provided
if pretrained_path is None:
pretrained_path = "wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
# init model
self.model = _video_vae(
pretrained_path=pretrained_path,
z_dim=16,
).eval().requires_grad_(False)
def encode(self, x: torch.Tensor) -> torch.Tensor:
# x: [batch_size, num_frames, num_channels, height, width]
# Convert to [batch_size, num_channels, num_frames, height, width]
x = x.permute(0, 2, 1, 3, 4)
device, dtype = x.device, x.dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
latents = [
self.model.encode(u.unsqueeze(0), scale).squeeze(0)
for u in x
]
latents = torch.stack(latents, dim=0)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
latents = latents.permute(0, 2, 1, 3, 4)
return latents
def decode_to_pixel(self, latent: torch.Tensor) -> torch.Tensor:
# latent: [batch_size, num_frames, num_channels, height, width]
# to [batch_size, num_channels, num_frames, height, width]
zs = latent.permute(0, 2, 1, 3, 4)
device, dtype = latent.device, latent.dtype
scale = [self.mean.to(device=device, dtype=dtype),
1.0 / self.std.to(device=device, dtype=dtype)]
output = [
self.model.decode(u.unsqueeze(0),
scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
output = torch.stack(output, dim=0)
# from [batch_size, num_channels, num_frames, height, width]
# to [batch_size, num_frames, num_channels, height, width]
output = output.permute(0, 2, 1, 3, 4)
return output
if __name__ == "__main__":
# Test code
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Testing WanVAEWrapper on {device}...")
ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
if not os.path.exists(ckpt_path):
print(f"Warning: Checkpoint not found at {ckpt_path}")
ckpt_path = "wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
try:
vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)
print("Model loaded successfully.")
# Create dummy video (B, T, C, H, W)
# T should be 1 + 4*k for Wan VAE (e.g., 5, 9, 13...)
B, T, C, H, W = 1, 5, 3, 128, 128
dummy_video = torch.randn(B, T, C, H, W).to(device).clamp(-1, 1)
print(f"Input video shape: {dummy_video.shape}")
with torch.no_grad():
# Test encode
latent = vae.encode(dummy_video)
print(f"Latent shape: {latent.shape}")
# Test decode
recon = vae.decode_to_pixel(latent)
print(f"Reconstructed video shape: {recon.shape}")
mse = torch.nn.functional.mse_loss(dummy_video, recon)
print(f"Reconstruction MSE: {mse.item():.6f}")
except Exception as e:
print(f"Test failed: {e}")
import traceback
traceback.print_exc()
|