# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import torch
from diffusers import AutoencoderKLWan
from transformers import (
T5TokenizerFast,
UMT5EncoderModel,
)
from .model import WanTransformer3DModel
def load_vae(vae_path, torch_dtype, torch_device):
    """Load an ``AutoencoderKLWan`` checkpoint and move it to a device.

    Args:
        vae_path: Location handed to ``AutoencoderKLWan.from_pretrained``.
        torch_dtype: Forwarded as the ``torch_dtype`` keyword when loading.
        torch_device: Target passed to ``.to()`` on the loaded model.

    Returns:
        Whatever ``.to(torch_device)`` returns for the loaded VAE.
    """
    return AutoencoderKLWan.from_pretrained(
        vae_path,
        torch_dtype=torch_dtype,
    ).to(torch_device)
def load_text_encoder(text_encoder_path, torch_dtype, torch_device):
    """Load a ``UMT5EncoderModel`` checkpoint and move it to a device.

    Args:
        text_encoder_path: Location handed to
            ``UMT5EncoderModel.from_pretrained``.
        torch_dtype: Forwarded as the ``torch_dtype`` keyword when loading.
        torch_device: Target passed to ``.to()`` on the loaded model.

    Returns:
        Whatever ``.to(torch_device)`` returns for the loaded encoder.
    """
    return UMT5EncoderModel.from_pretrained(
        text_encoder_path,
        torch_dtype=torch_dtype,
    ).to(torch_device)
def load_tokenizer(tokenizer_path):
    """Return the ``T5TokenizerFast`` loaded from ``tokenizer_path``.

    Args:
        tokenizer_path: Location handed to
            ``T5TokenizerFast.from_pretrained``.
    """
    return T5TokenizerFast.from_pretrained(tokenizer_path)
def load_transformer(transformer_path, torch_dtype, torch_device):
    """Load a ``WanTransformer3DModel`` checkpoint and move it to a device.

    Args:
        transformer_path: Location handed to
            ``WanTransformer3DModel.from_pretrained``.
        torch_dtype: Forwarded as the ``torch_dtype`` keyword when loading.
        torch_device: Target passed to ``.to()`` on the loaded model.

    Returns:
        Whatever ``.to(torch_device)`` returns for the loaded transformer.
    """
    return WanTransformer3DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch_dtype,
    ).to(torch_device)
def patchify(x, patch_size):
    """Fold square spatial patches of a 5-D tensor into its channel axis.

    Args:
        x: Tensor whose shape unpacks as
            ``(batch, channels, frames, height, width)``.
        patch_size: Side length of each patch. ``None`` or ``1`` disables
            patching.

    Returns:
        ``x`` itself when ``patch_size`` is ``None`` or ``1``. Otherwise a
        tensor of shape ``(batch, channels * patch_size**2, frames,
        height // patch_size, width // patch_size)`` in which output channel
        ``c * patch_size**2 + pw * patch_size + ph`` holds the pixels at
        row offset ``ph`` and column offset ``pw`` of every patch of input
        channel ``c``.
    """
    if patch_size is None or patch_size == 1:
        return x

    n, c, t, h, w = x.shape
    grid_h = h // patch_size
    grid_w = w // patch_size

    # Split height and width into (grid index, offset inside the patch).
    tiles = x.view(n, c, t, grid_h, patch_size, grid_w, patch_size)
    # Move the two in-patch offsets (width offset first, then height offset)
    # directly behind the channel axis so they can be merged into it.
    tiles = tiles.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    return tiles.view(n, c * patch_size * patch_size, t, grid_h, grid_w)
class WanVAEStreamingWrapper:
    """Run a VAE's encoder one chunk at a time with a shared feature cache.

    The wrapper keeps a ``feat_cache`` list with one slot per
    ``WanCausalConv3d`` module in the encoder and passes the same list to
    the encoder on every ``encode_chunk`` call. What the encoder stores in
    those slots is decided by the encoder itself and is not visible here.
    """

    def __init__(self, vae_model):
        """Capture the VAE's encoder pieces and size the feature cache.

        Args:
            vae_model: Object exposing ``encoder``, ``quant_conv`` and
                ``config``; an optional ``_cached_conv_counts`` mapping
                supplies the encoder's conv count directly.
        """
        self.vae = vae_model
        self.encoder = vae_model.encoder
        self.quant_conv = vae_model.quant_conv

        try:
            conv_counts = self.vae._cached_conv_counts
        except AttributeError:
            # No precomputed count on the VAE: walk the encoder and match
            # modules by class name.
            self.enc_conv_num = sum(
                1 for module in self.encoder.modules()
                if module.__class__.__name__ == "WanCausalConv3d"
            )
        else:
            self.enc_conv_num = conv_counts["encoder"]

        self.clear_cache()

    def clear_cache(self):
        """Reset the feature cache to ``enc_conv_num`` empty slots."""
        self.feat_cache = [None] * self.enc_conv_num

    def encode_chunk(self, x_chunk):
        """Encode one chunk, reusing the cache from previous calls.

        Args:
            x_chunk: Input forwarded to the encoder; it is first passed
                through ``patchify`` when the VAE config has a non-``None``
                ``patch_size``.

        Returns:
            The encoder output after ``quant_conv`` has been applied.
        """
        patch_size = getattr(self.vae.config, "patch_size", None)
        if patch_size is not None:
            x_chunk = patchify(x_chunk, patch_size)

        # A fresh single-element index list is handed to the encoder on
        # every call, alongside the persistent cache.
        hidden = self.encoder(
            x_chunk,
            feat_cache=self.feat_cache,
            feat_idx=[0],
        )
        return self.quant_conv(hidden)
|