Spaces:
Runtime error
Runtime error
| import torch | |
| from pathlib import Path | |
| from hyimage.common.constants import PRECISION_TO_TYPE | |
| from .hunyuanimage_vae import HunyuanVAE2D | |
def load_vae(device, vae_path: str = None, vae_precision: str = None):
    """Build a ``HunyuanVAE2D`` from ``vae_path`` and prepare it for inference.

    The model is instantiated from the config found via
    ``HunyuanVAE2D.load_config(vae_path)``. If ``vae_path`` exists on disk,
    weights are read from ``<vae_path>/pytorch_model.ckpt``; only entries whose
    key starts with ``"vae."`` are kept, with that leading prefix removed.

    Args:
        device: Target device passed to ``.to(device)``; skipped when ``None``.
        vae_path: Path used both for the config and for the checkpoint
            directory. NOTE(review): the default is ``None`` but
            ``Path(None)`` raises ``TypeError``, so a real path appears to be
            required in practice -- confirm against callers.
        vae_precision: Key into ``PRECISION_TO_TYPE`` selecting the dtype to
            cast the model to; no cast is done when ``None``.

    Returns:
        The VAE with gradients disabled and in eval mode.
    """
    config = HunyuanVAE2D.load_config(vae_path)
    vae = HunyuanVAE2D.from_config(config)

    vae_dir = Path(vae_path)
    if vae_dir.exists():
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at checkpoints from a trusted source.
        ckpt = torch.load(vae_dir / "pytorch_model.ckpt", map_location='cpu')
        # Some checkpoints nest the weights under a "state_dict" entry.
        if "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]

        # Strip only the *leading* "vae." prefix. str.replace would also
        # remove any later "vae." substring inside the key and corrupt it.
        prefix = "vae."
        vae_ckpt = {
            k[len(prefix):]: v
            for k, v in ckpt.items()
            if k.startswith(prefix)
        }
        vae.load_state_dict(vae_ckpt)

    if vae_precision is not None:
        vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])

    # Inference only: freeze the parameters.
    vae.requires_grad_(False)

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae