Spaces:
Runtime error
Runtime error
| import gc | |
| import torch | |
| from .vae2_2_module import Wan2_2_VAE | |
| def get_vae2_2(model_path, device="cuda", weight_dtype=torch.float32) -> Wan2_2_VAE: | |
| vae = Wan2_2_VAE(vae_pth=model_path).to(device).to(weight_dtype) | |
| vae.vae.requires_grad_(False) | |
| vae.vae.eval() | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| return vae | |
| __all__ = ["Wan2_2_VAE", "get_vae2_2"] | |