| | import os |
| | import sys |
| | import torch |
| |
|
| | |
| | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))) |
| |
|
| | from wm.model.wan_base.modules.vae import _video_vae |
| |
|
| |
|
class WanVAEWrapper(torch.nn.Module):
    """Frozen wrapper around the Wan2.1 video VAE for encoding/decoding videos.

    Videos are handled in ``[B, T, C, H, W]`` layout at this interface and
    permuted to the ``[B, C, T, H, W]`` layout the underlying VAE expects.
    The wrapped model is put in eval mode with gradients disabled, since it
    is only used for inference.
    """

    def __init__(self, pretrained_path=None):
        """Load the pretrained Wan2.1 VAE.

        Args:
            pretrained_path: Path to the ``Wan2.1_VAE.pth`` checkpoint.
                Defaults to the relative path
                ``wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth`` when ``None``.
        """
        super().__init__()

        # Per-channel statistics of the 16-dim latent space, used to
        # normalize latents (stored as buffers so .to(device) moves them).
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
        self.register_buffer("std", torch.tensor(std, dtype=torch.float32))

        if pretrained_path is None:
            pretrained_path = "wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"

        # Inference-only: freeze the VAE so no autograd state is tracked.
        self.model = _video_vae(
            pretrained_path=pretrained_path,
            z_dim=16,
        ).eval().requires_grad_(False)

    def _scale(self, device: torch.device, dtype: torch.dtype) -> list:
        """Return the ``[shift, inverse_scale]`` pair the VAE expects.

        Previously duplicated in ``encode`` and ``decode_to_pixel``; kept in
        one place so the normalization convention cannot drift between them.
        """
        return [
            self.mean.to(device=device, dtype=dtype),
            1.0 / self.std.to(device=device, dtype=dtype),
        ]

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a pixel video ``[B, T, C, H, W]`` into normalized latents.

        Returns latents in ``[B, T', C_z, H', W']`` layout (time axis second,
        matching the input convention). Frames per batch element are encoded
        one sample at a time to bound peak memory.
        """
        x = x.permute(0, 2, 1, 3, 4)  # [B, T, C, H, W] -> [B, C, T, H, W]

        scale = self._scale(x.device, x.dtype)

        latents = [
            self.model.encode(u.unsqueeze(0), scale).squeeze(0)
            for u in x
        ]
        latents = torch.stack(latents, dim=0)

        # Back to time-second layout for callers.
        return latents.permute(0, 2, 1, 3, 4)

    def decode_to_pixel(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode latents ``[B, T', C_z, H', W']`` back to pixel space.

        Output is float32, clamped to ``[-1, 1]``, in ``[B, T, C, H, W]``
        layout. Batch elements are decoded one at a time to bound memory.
        """
        zs = latent.permute(0, 2, 1, 3, 4)  # -> [B, C_z, T', H', W']

        scale = self._scale(latent.device, latent.dtype)

        output = [
            self.model.decode(u.unsqueeze(0),
                              scale).float().clamp_(-1, 1).squeeze(0)
            for u in zs
        ]
        output = torch.stack(output, dim=0)

        return output.permute(0, 2, 1, 3, 4)
| |
|
if __name__ == "__main__":
    # Smoke test: load the VAE, round-trip a random clip, report MSE.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Testing WanVAEWrapper on {device}...")

    # Prefer the cluster checkpoint; fall back to the repo-relative path.
    ckpt_path = "/storage/ice-shared/ae8803che/hxue/data/checkpoint/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
    if not os.path.exists(ckpt_path):
        print(f"Warning: Checkpoint not found at {ckpt_path}")
        ckpt_path = "wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"

    try:
        vae = WanVAEWrapper(pretrained_path=ckpt_path).to(device)
        print("Model loaded successfully.")

        # Small random clip in [-1, 1]: 1 sample, 5 frames, RGB, 128x128.
        batch, frames, channels, height, width = 1, 5, 3, 128, 128
        video = torch.randn(batch, frames, channels, height, width)
        video = video.to(device).clamp(-1, 1)
        print(f"Input video shape: {video.shape}")

        with torch.no_grad():
            latent = vae.encode(video)
            print(f"Latent shape: {latent.shape}")

            decoded = vae.decode_to_pixel(latent)
            print(f"Reconstructed video shape: {decoded.shape}")

            mse = torch.nn.functional.mse_loss(video, decoded)
            print(f"Reconstruction MSE: {mse.item():.6f}")

    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
|