File size: 374 Bytes
873b6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import gc

import torch

from .vae2_2_module import Wan2_2_VAE


def get_vae2_2(model_path, device="cuda", weight_dtype=torch.float32) -> Wan2_2_VAE:
    """Build a frozen ``Wan2_2_VAE`` for inference.

    Args:
        model_path: Checkpoint path, forwarded to ``Wan2_2_VAE`` as ``vae_pth``.
        device: Device the model is moved to (default ``"cuda"``).
        weight_dtype: Dtype the model is cast to (default ``torch.float32``).

    Returns:
        The ``Wan2_2_VAE`` wrapper, cast to ``weight_dtype`` and placed on
        ``device``. Its inner ``.vae`` network has gradients disabled and is
        in eval mode.
    """
    vae = Wan2_2_VAE(vae_pth=model_path)
    # Cast first, then move. With a narrower weight_dtype (e.g. bf16/fp16),
    # this transfers the already-downcast tensors to the device instead of
    # staging full-precision copies there and converting afterwards.
    # NOTE(review): assumes the wrapper's .to() follows nn.Module semantics,
    # where the final device/dtype do not depend on call order -- confirm.
    vae = vae.to(weight_dtype).to(device)
    # Inference only: freeze the parameters and switch layers to eval behaviour.
    vae.vae.requires_grad_(False)
    vae.vae.eval()
    # Collect garbage left over from construction (presumably temporary
    # checkpoint data -- not visible here), then hand unused cached CUDA
    # memory back to the driver.
    gc.collect()
    torch.cuda.empty_cache()
    return vae


# Public API of this package: the VAE wrapper class and its loader helper.
__all__ = [
    "Wan2_2_VAE",
    "get_vae2_2",
]