| |
| import torch |
| from diffusers import AutoencoderKLWan |
| from transformers import ( |
| T5TokenizerFast, |
| UMT5EncoderModel, |
| ) |
|
|
| from .model import WanTransformer3DModel |
|
|
|
|
| def load_vae( |
| vae_path, |
| torch_dtype, |
| torch_device, |
| ): |
| vae = AutoencoderKLWan.from_pretrained( |
| vae_path, |
| torch_dtype=torch_dtype, |
| ) |
| return vae.to(torch_device) |
|
|
|
|
| def load_text_encoder( |
| text_encoder_path, |
| torch_dtype, |
| torch_device, |
| ): |
| text_encoder = UMT5EncoderModel.from_pretrained( |
| text_encoder_path, |
| torch_dtype=torch_dtype, |
| ) |
| return text_encoder.to(torch_device) |
|
|
|
|
| def load_tokenizer(tokenizer_path, ): |
| tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path, ) |
| return tokenizer |
|
|
|
|
| def load_transformer( |
| transformer_path, |
| torch_dtype, |
| torch_device, |
| ): |
| model = WanTransformer3DModel.from_pretrained( |
| transformer_path, |
| torch_dtype=torch_dtype, |
| ) |
| return model.to(torch_device) |
|
|
|
|
| def patchify(x, patch_size): |
| if patch_size is None or patch_size == 1: |
| return x |
| batch_size, channels, frames, height, width = x.shape |
| x = x.view(batch_size, channels, frames, height // patch_size, patch_size, |
| width // patch_size, patch_size) |
| x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous() |
| x = x.view(batch_size, channels * patch_size * patch_size, frames, |
| height // patch_size, width // patch_size) |
| return x |
|
|
|
|
| class WanVAEStreamingWrapper: |
|
|
| def __init__(self, vae_model): |
| self.vae = vae_model |
| self.encoder = vae_model.encoder |
| self.quant_conv = vae_model.quant_conv |
|
|
| if hasattr(self.vae, "_cached_conv_counts"): |
| self.enc_conv_num = self.vae._cached_conv_counts["encoder"] |
| else: |
| count = 0 |
| for m in self.encoder.modules(): |
| if m.__class__.__name__ == "WanCausalConv3d": |
| count += 1 |
| self.enc_conv_num = count |
|
|
| self.clear_cache() |
|
|
| def clear_cache(self): |
| self.feat_cache = [None] * self.enc_conv_num |
|
|
| def encode_chunk(self, x_chunk): |
| if hasattr(self.vae.config, |
| "patch_size") and self.vae.config.patch_size is not None: |
| x_chunk = patchify(x_chunk, self.vae.config.patch_size) |
| feat_idx = [0] |
| out = self.encoder(x_chunk, |
| feat_cache=self.feat_cache, |
| feat_idx=feat_idx) |
| enc = self.quant_conv(out) |
| return enc |
|
|