daVinci-MagiHuman / inference /model /vae2_2 /vae2_2_model.py
ethanchern's picture
init
873b6ec
raw
history blame contribute delete
374 Bytes
import gc
import torch
from .vae2_2_module import Wan2_2_VAE
def get_vae2_2(model_path, device="cuda", weight_dtype=torch.float32) -> Wan2_2_VAE:
vae = Wan2_2_VAE(vae_pth=model_path).to(device).to(weight_dtype)
vae.vae.requires_grad_(False)
vae.vae.eval()
gc.collect()
torch.cuda.empty_cache()
return vae
__all__ = ["Wan2_2_VAE", "get_vae2_2"]