File size: 678 Bytes
98eeefd
 
f5651ba
 
98eeefd
 
 
f5651ba
 
 
98eeefd
 
 
 
 
 
 
 
 
 
 
f5651ba
98eeefd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""VAE loader (placeholder - actual loading handled by ModelManager)"""

import os
import sys
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
from config import MODELS_DIR


def load_vae(device: str = "cuda", dtype: torch.dtype = torch.float16) -> AutoencoderKL:
    """Load the ``stabilityai/sd-vae-ft-mse`` VAE from the HuggingFace Hub.

    Weights are cached under ``MODELS_DIR``. The latent scaling and shift
    factors are pinned explicitly (``scaling_factor=0.18215``,
    ``shift_factor=0``) so downstream code can read them from ``vae.config``
    without caring what the downloaded config contains.

    Args:
        device: Device to move the VAE to (e.g. ``"cuda"`` or ``"cpu"``).
        dtype: Floating-point dtype for the weights. Defaults to
            ``torch.float16``, the previously hard-coded value. Consider
            ``torch.float32`` on CPU, where half-precision op support is
            limited.

    Returns:
        The loaded ``AutoencoderKL`` on ``device``.
    """
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=dtype, cache_dir=MODELS_DIR
    )
    # ``vae.config`` is a FrozenDict. Assigning attributes on it only sets a
    # Python attribute and leaves the mapping itself (``vae.config[...]``,
    # ``save_pretrained`` output) unchanged. ``register_to_config`` rebuilds
    # the config so attribute and key access agree.
    vae.register_to_config(scaling_factor=0.18215, shift_factor=0)
    return vae.to(device)