import torch
from einops import rearrange
from torch import nn

# Standalone: only Wan VAE (used by infworld_config.yaml)
from .vae import WanVAE


class WanVAEModelWrapper(nn.Module):
    def __init__(self, vae_pth, dtype=torch.float, device="cuda", patch_size=(4, 8, 8)):
        super().__init__()
        self.module = WanVAE(
            vae_pth=vae_pth,
            dtype=dtype,
            device=device,
        )
        self.dtype = dtype
        self.device = device
        # Number of latent channels produced by the Wan VAE encoder.
        self.out_channels = 16
        # (temporal, height, width) compression factors of the VAE.
        self.patch_size = patch_size

    def encode(self, x):
        # Input:  x of shape (B, C, T, H, W), or (B, C, H, W) for single frames.
        # Output: latent of shape (B, C', T', H', W').
        if len(x.shape) == 4:
            # A 4D batch of images is treated as a batch of one-frame videos.
            x = rearrange(x, "B C H W -> B C 1 H W")
        x = self.module.encode_batch(x)
        return x
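
    # Shape example (illustrative sketch): with the default patch_size
    # (4, 8, 8) and out_channels 16, an input of
    #   x = torch.randn(1, 3, 81, 480, 832)   # (B, C, T, H, W)
    # should encode to a latent of shape (1, 16, 21, 60, 104), matching
    # get_latent_size((81, 480, 832)) -- assuming WanVAE.encode_batch applies
    # the declared compression factors.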

    def decode(self, x):
        # Input:  latent of shape (B, C, T, H, W), or (T, C, H, W).
        # Output: video of shape (B, C, T, H, W).
        if len(x.shape) == 4:
            # Note the asymmetry with encode(): a 4D input here is read as a
            # stack of T latent frames with an implicit batch size of 1.
            x = rearrange(x, "T C H W -> 1 C T H W")
        x = self.module.decode_batch(x)
        return x
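
    # Example (hypothetical shapes): a (21, 16, 60, 104) latent stack -- 21
    # frames, no batch dim -- takes the 4D branch above and should decode to
    # a single video of shape (1, 3, 81, 480, 832).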

    def get_latent_size(self, input_size):
        # Map a pixel-space size (T, H, W) to the corresponding latent size.
        # Time is compressed causally: the first frame yields one latent
        # frame, then every patch_size[0] further frames add another.
        latent_size = []
        for i in range(3):
            if i == 0:
                target_size = 1 + (input_size[i] - 1) // self.patch_size[i]
            else:
                assert input_size[i] % self.patch_size[i] == 0, \
                    "Input spatial size must be divisible by patch size"
                target_size = input_size[i] // self.patch_size[i]
            latent_size.append(target_size)
        return latent_size
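

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). The checkpoint
# path below is hypothetical; substitute the Wan VAE weights referenced by
# your config. The round trip needs a CUDA device and that checkpoint; the
# latent-size arithmetic itself is exact.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    wrapper = WanVAEModelWrapper(
        vae_pth="checkpoints/Wan2.1_VAE.pth",  # hypothetical path
        dtype=torch.float,
        device="cuda",
    )

    # Time compresses as 1 + (T - 1) // 4, space as H // 8 and W // 8,
    # so (81, 480, 832) -> [21, 60, 104].
    print(wrapper.get_latent_size((81, 480, 832)))

    # Round trip on random pixels; expected shapes assume the declared
    # (4, 8, 8) compression and 16 latent channels.
    video = torch.randn(1, 3, 81, 480, 832, device="cuda")
    latent = wrapper.encode(video)  # expected: (1, 16, 21, 60, 104)
    recon = wrapper.decode(latent)  # expected: (1, 3, 81, 480, 832)
    print(latent.shape, recon.shape)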