File size: 3,924 Bytes
c336648 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import torch
from diffusers import AutoencoderKL
DTYPE = torch.float16
DEVICE = "cuda:0"
class SDv1_VAE:
    """Wrapper around the SD v1.x `sd-vae-ft-mse` autoencoder.

    Images are expected in [0;1]; encode/decode convert to the model's
    dtype/device internally and return results on the caller's original
    dtype/device.
    """
    scale = 1/8      # spatial downscale factor of the latent space
    channels = 4     # latent channels

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        """Load the VAE in eval mode onto `device` with `dtype`.

        dec_only: drop the encoder weights to save memory (decode-only use).
        """
        self.device = device
        self.dtype = dtype
        self.model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse"
        )
        self.model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del self.model.encoder

    def encode(self, image):
        """Encode a [0;1] image tensor to a sampled latent."""
        # Capture the caller's dtype/device BEFORE converting: the original
        # code read them after rebinding `image`, making the restore a no-op
        # (results always came back in the model's dtype/device).
        orig_dtype, orig_device = image.dtype, image.device
        image = image.to(self.dtype).to(self.device)
        image = (image * 2.0) - 1.0 # assuming input is [0;1]
        with torch.no_grad():
            latent = self.model.encode(image).latent_dist.sample()
        return latent.to(orig_dtype).to(orig_device)

    def decode(self, latent, grad=False):
        """Decode a latent back to a [0;1] image tensor.

        grad: keep the autograd graph (skip `torch.no_grad()`).
        """
        orig_dtype, orig_device = latent.dtype, latent.device
        latent = latent.to(self.dtype).to(self.device)
        if grad:
            out = self.model.decode(latent)[0]
        else:
            with torch.no_grad():
                out = self.model.decode(latent).sample
        out = torch.clamp(out, min=-1.0, max=1.0)
        out = (out + 1.0) / 2.0
        return out.to(orig_dtype).to(orig_device)
class SDXL_VAE(SDv1_VAE):
    """SDXL variant of the VAE wrapper; same latent layout as SD v1."""
    scale = 1/8
    channels = 4

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        self.device = device
        self.dtype = dtype
        # fp16-safe SDXL VAE weights
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        vae.eval().to(self.dtype).to(self.device)
        if dec_only:
            # decoder-only use: free the encoder weights
            del vae.encoder
        self.model = vae
class SDv3_VAE(SDv1_VAE):
    """SD3 VAE wrapper: 16 latent channels at the same 1/8 spatial scale."""
    scale = 1/8
    channels = 16

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        self.device = device
        self.dtype = dtype
        # SD3 ships its VAE inside the full pipeline repo, under `vae/`
        model = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            subfolder="vae",
        )
        model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del model.encoder
        self.model = model
class CascadeC_VAE(SDv1_VAE):
    """Stable Cascade stage-C latent encoder (EfficientNet).

    NOTE(review): this loads only an encoder network; the inherited
    decode() from SDv1_VAE is presumably unusable here — confirm.
    """
    scale = 1/32
    channels = 16

    def __init__(self, device=DEVICE, dtype=DTYPE, **kwargs):
        self.device = device
        self.dtype = dtype
        # For now this is just piggybacking off of koyha-ss/sd-scripts
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        from library import stable_cascade as sc
        encoder = sc.EfficientNetEncoder()
        weights_path = hf_hub_download(
            repo_id = "stabilityai/stable-cascade",
            filename = "effnet_encoder.safetensors",
        )
        encoder.load_state_dict(load_file(str(weights_path)))
        self.model = encoder
        self.model.eval().to(self.dtype).to(self.device)
class CascadeA_VAE():
    """Stable Cascade stage-A VQ model (PaellaVQModel) wrapper.

    Unlike the KL VAEs above, images are used in [0;1] directly — no
    [-1;1] rescaling on either side.
    """
    scale = 1/4
    channels = 4

    def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
        """Load the VQ model in eval mode onto `device` with `dtype`.

        dec_only: drop the encoder weights to save memory (decode-only use).
        """
        self.device = device
        self.dtype = dtype
        # not sure if this will change in the future?
        from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
        self.model = PaellaVQModel.from_pretrained(
            "stabilityai/stable-cascade",
            subfolder="vqgan"
        )
        self.model.eval().to(self.dtype).to(self.device)
        if dec_only:
            del self.model.encoder

    def encode(self, image):
        """Encode a [0;1] image to VQ latents."""
        # Capture caller's dtype/device BEFORE converting: the original read
        # them after rebinding `image`, which made the restore a no-op.
        orig_dtype, orig_device = image.dtype, image.device
        image = image.to(self.dtype).to(self.device)
        with torch.no_grad():
            latent = self.model.encode(image).latents
        return latent.to(orig_dtype).to(orig_device)

    def decode(self, latent, grad=False):
        """Decode VQ latents to a [0;1] image.

        grad: keep the autograd graph (skip `torch.no_grad()`).
        """
        orig_dtype, orig_device = latent.dtype, latent.device
        latent = latent.to(self.dtype).to(self.device)
        if grad:
            out = self.model.decode(latent)[0]
        else:
            with torch.no_grad():
                out = self.model.decode(latent).sample
        out = torch.clamp(out, min=0.0, max=1.0)
        return out.to(orig_dtype).to(orig_device)
class No_VAE():
    """No-op stand-in: pixel space IS the latent space."""
    scale = 1      # no spatial downscaling
    channels = 3   # plain RGB

    def __init__(self, *args, **kwargs):
        # signature-compatible with the real wrappers; nothing to load
        pass

    def encode(self, image):
        """Identity: return the input unchanged."""
        return image

    def decode(self, image):
        """Identity: return the input unchanged."""
        return image
# Registry of VAE wrapper classes, keyed by the short version tag
# accepted by load_vae() below.
vae_vers = {
    "no": No_VAE,
    "v1": SDv1_VAE,
    "xl": SDXL_VAE,
    "v3": SDv3_VAE,
    "cc": CascadeC_VAE,
    "ca": CascadeA_VAE,
}
def load_vae(ver, *args, **kwargs):
    """Instantiate the VAE wrapper registered under `ver` in `vae_vers`.

    Extra positional/keyword arguments are forwarded to the wrapper's
    constructor (typically device, dtype, dec_only).

    Raises:
        ValueError: if `ver` is not a known VAE tag.
    """
    try:
        vae_class = vae_vers[ver]
    except KeyError:
        # `assert` is stripped under `python -O`; validate explicitly instead
        raise ValueError(f"Unknown VAE '{ver}'") from None
    return vae_class(*args, **kwargs)
|