File size: 1,587 Bytes
01c7703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from einops import rearrange
import torch
from torch import nn

# Standalone: only Wan VAE (used by infworld_config.yaml)
from .vae import WanVAE


class WanVAEModelWrapper(nn.Module):
    """``nn.Module`` wrapper around the project's ``WanVAE``.

    Provides ``encode``/``decode`` entry points that accept 4-D or 5-D
    tensors and delegate to ``WanVAE.encode_batch`` / ``decode_batch``, plus
    ``get_latent_size`` to convert an input size into a latent size.

    Args:
        vae_pth: Checkpoint path forwarded to ``WanVAE``.
        dtype: dtype forwarded to ``WanVAE`` and kept as ``self.dtype``.
        device: Device forwarded to ``WanVAE`` and kept as ``self.device``.
        patch_size: Three per-axis compression factors used by
            ``get_latent_size``. Index 0 is the axis handled with the
            ``1 + (n - 1) // p`` rule; indices 1 and 2 are the spatial axes.
    """

    def __init__(self, vae_pth, dtype=torch.float, device="cuda", patch_size=(4, 8, 8)):
        super().__init__()
        self.module = WanVAE(
            vae_pth=vae_pth,
            dtype=dtype,
            device=device,
        )
        self.dtype = dtype
        self.device = device
        # Latent channel count, hard-coded here. Presumably it must match the
        # latent dimension of the loaded WanVAE checkpoint — TODO confirm.
        self.out_channels = 16
        self.patch_size = patch_size

    def encode(self, x):
        """Encode an input tensor with ``WanVAE.encode_batch``.

        Args:
            x: Either a 5-D tensor laid out as (B, C, T, H, W), or a 4-D
                tensor laid out as (B, C, H, W). A 4-D input gets a T=1 axis
                inserted at position 2 before encoding.

        Returns:
            Whatever ``self.module.encode_batch`` returns. Per the original
            notes this is a (B, C, T, H, W) latent — TODO confirm against
            ``WanVAE``.
        """
        if x.ndim == 4:
            x = rearrange(x, "B C H W -> B C 1 H W")
        return self.module.encode_batch(x)

    def decode(self, x):
        """Decode a latent tensor with ``WanVAE.decode_batch``.

        Args:
            x: Either a 5-D tensor laid out as (B, C, T, H, W), or a 4-D
                tensor laid out as (T, C, H, W). A 4-D input is treated as a
                single clip: the leading T axis is moved behind C and a batch
                axis of size 1 is prepended, giving (1, C, T, H, W).

        Returns:
            Whatever ``self.module.decode_batch`` returns. Per the original
            notes this is (B, C, T, H, W) — TODO confirm against ``WanVAE``.
        """
        if x.ndim == 4:
            # NOTE(review): unlike encode(), a 4-D input here is read as
            # (T, C, H, W), not (B, C, H, W) as the old comment claimed.
            # Behavior is kept as-is; confirm which layout callers pass.
            x = rearrange(x, "T C H W -> 1 C T H W")
        return self.module.decode_batch(x)

    def get_latent_size(self, input_size):
        """Map the first three entries of ``input_size`` to latent sizes.

        Entry 0 maps ``n -> 1 + (n - 1) // patch_size[0]``. Entries 1 and 2
        (the spatial axes) map ``n -> n // patch_size[i]`` and must divide
        exactly. Any entries beyond the third are ignored.

        Args:
            input_size: Indexable sequence of at least three ints.

        Returns:
            list[int]: The three latent sizes.

        Raises:
            ValueError: If a spatial size is not divisible by its patch size.
            IndexError: If ``input_size`` has fewer than three entries.
        """
        latent_size = [1 + (input_size[0] - 1) // self.patch_size[0]]
        for axis in (1, 2):
            size, patch = input_size[axis], self.patch_size[axis]
            # Explicit raise instead of assert so the check also runs under
            # `python -O`; otherwise a bad size would silently floor-divide.
            if size % patch != 0:
                raise ValueError(
                    "Input spatial size must be divisible by patch size "
                    f"(axis {axis}: size {size}, patch {patch})"
                )
            latent_size.append(size // patch)
        return latent_size