Spaces:
Runtime error
Runtime error
File size: 374 Bytes
873b6ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import gc
import torch
from .vae2_2_module import Wan2_2_VAE
def get_vae2_2(model_path, device="cuda", weight_dtype=torch.float32) -> Wan2_2_VAE:
    """Build a ``Wan2_2_VAE`` from a checkpoint path and freeze it.

    The wrapper is constructed with ``vae_pth=model_path``, moved to
    ``device`` and then converted to ``weight_dtype`` through two successive
    ``.to()`` calls. Its inner ``vae`` attribute has gradients disabled and
    is switched to eval mode.

    Args:
        model_path: Forwarded to ``Wan2_2_VAE`` as ``vae_pth``.
        device: Argument of the first ``.to()`` call (default ``"cuda"``).
        weight_dtype: Argument of the second ``.to()`` call
            (default ``torch.float32``).

    Returns:
        The wrapper as returned by the final ``.to()`` call.
    """
    wrapper = Wan2_2_VAE(vae_pth=model_path)
    wrapper = wrapper.to(device)
    wrapper = wrapper.to(weight_dtype)

    # Freeze the wrapped network: no gradients, eval mode.
    inner = wrapper.vae
    inner.requires_grad_(False)
    inner.eval()

    # Run a GC pass and release unused cached CUDA allocator memory
    # (presumably to drop leftovers from checkpoint loading -- confirm).
    gc.collect()
    torch.cuda.empty_cache()

    return wrapper
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["Wan2_2_VAE", "get_vae2_2"]
|