File size: 3,924 Bytes
c336648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
from diffusers import AutoencoderKL

# Default weight/activation dtype for all VAE wrappers (fp16 halves memory).
DTYPE = torch.float16
# Default device; override per instance via the `device` constructor argument.
DEVICE = "cuda:0"

class SDv1_VAE:
	scale = 1/8
	channels = 4
	def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
		self.device = device
		self.dtype = dtype
		self.model = AutoencoderKL.from_pretrained(
			"stabilityai/sd-vae-ft-mse"
		)
		self.model.eval().to(self.dtype).to(self.device)
		if dec_only:
			del self.model.encoder

	def encode(self, image):
		image = image.to(self.dtype).to(self.device)
		image = (image * 2.0) - 1.0 # assuming input is [0;1]
		with torch.no_grad():
			latent = self.model.encode(image).latent_dist.sample()
		return latent.to(image.dtype).to(image.device)

	def decode(self, latent, grad=False):
		latent = latent.to(self.dtype).to(self.device)
		if grad:
			out = self.model.decode(latent)[0]
		else:
			with torch.no_grad():
				out = self.model.decode(latent).sample
			out = torch.clamp(out, min=-1.0, max=1.0)
			out = (out + 1.0) / 2.0
		return out.to(latent.dtype).to(latent.device)

class SDXL_VAE(SDv1_VAE):
	"""SD-XL VAE wrapper (fp16-safe "madebyollin/sdxl-vae-fp16-fix" weights).

	Inherits `encode`/`decode` from SDv1_VAE; only the checkpoint differs.
	"""
	scale = 1/8
	channels = 4

	def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
		self.device = device
		self.dtype = dtype
		vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
		vae.eval().to(dtype).to(device)
		if dec_only:
			# Drop encoder weights when only decoding is needed.
			del vae.encoder
		self.model = vae

class SDv3_VAE(SDv1_VAE):
	"""SD 3 (medium) VAE wrapper: 16-channel latents at 1/8 scale.

	Inherits `encode`/`decode` from SDv1_VAE; only the checkpoint differs.
	"""
	scale = 1/8
	channels = 16

	def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
		self.device = device
		self.dtype = dtype
		vae = AutoencoderKL.from_pretrained(
			"stabilityai/stable-diffusion-3-medium-diffusers",
			subfolder="vae",
		)
		vae.eval().to(dtype).to(device)
		if dec_only:
			# Drop encoder weights when only decoding is needed.
			del vae.encoder
		self.model = vae

class CascadeC_VAE(SDv1_VAE):
	"""Stable Cascade stage-C EfficientNet encoder (1/32 scale, 16 channels).

	NOTE(review): inherits `encode`/`decode` from SDv1_VAE, but the model
	here is an EfficientNetEncoder — confirm those inherited methods are
	actually compatible with it before relying on them.
	"""
	scale = 1/32
	channels = 16

	def __init__(self, device=DEVICE, dtype=DTYPE, **kwargs):
		self.device = device
		self.dtype = dtype

		# For now this just piggybacks off of kohya-ss/sd-scripts.
		from library import stable_cascade as sc
		from safetensors.torch import load_file
		from huggingface_hub import hf_hub_download

		weights_path = hf_hub_download(
			repo_id="stabilityai/stable-cascade",
			filename="effnet_encoder.safetensors",
		)
		encoder = sc.EfficientNetEncoder()
		encoder.load_state_dict(load_file(str(weights_path)))
		self.model = encoder
		self.model.eval().to(self.dtype).to(self.device)

class CascadeA_VAE():
	scale = 1/4
	channels = 4
	def __init__(self, device=DEVICE, dtype=DTYPE, dec_only=False):
		self.device = device
		self.dtype = dtype

		# not sure if this will change in the future?
		from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
		self.model = PaellaVQModel.from_pretrained(
			"stabilityai/stable-cascade",
			subfolder="vqgan"
		)
		self.model.eval().to(self.dtype).to(self.device)
		if dec_only:
			del self.model.encoder

	def encode(self, image):
		image = image.to(self.dtype).to(self.device)
		with torch.no_grad():
			latent = self.model.encode(image).latents
		return latent.to(image.dtype).to(image.device)

	def decode(self, latent, grad=False):
		latent = latent.to(self.dtype).to(self.device)
		if grad:
			out = self.model.decode(latent)[0]
		else:
			with torch.no_grad():
				out = self.model.decode(latent).sample
			out = torch.clamp(out, min=0.0, max=1.0)
		return out.to(latent.dtype).to(latent.device)

class No_VAE:
	"""Identity 'VAE': passes images straight through (pixel space, 3 channels)."""
	scale = 1
	channels = 3

	def __init__(self, *args, **kwargs):
		# Accept and ignore any arguments for interface parity with the
		# real VAE wrappers.
		pass

	def encode(self, image):
		"""Return the input unchanged."""
		return image

	def decode(self, image):
		"""Return the input unchanged."""
		return image

# Registry mapping short version tags to VAE wrapper classes; used by load_vae.
vae_vers = {
	"no": No_VAE,
	"v1": SDv1_VAE,
	"xl": SDXL_VAE,
	"v3": SDv3_VAE,
	"cc": CascadeC_VAE,
	"ca": CascadeA_VAE,
}

def load_vae(ver, *args, **kwargs):
	"""Instantiate the VAE wrapper registered under `ver`.

	Args:
		ver: registry key ("no", "v1", "xl", "v3", "cc", "ca").
		*args, **kwargs: forwarded to the wrapper's constructor.

	Returns:
		A constructed VAE wrapper instance.

	Raises:
		ValueError: if `ver` is not a known VAE version.
	"""
	# `assert` is stripped under `python -O`, so validate with a real raise.
	try:
		vae_class = vae_vers[ver]
	except KeyError:
		raise ValueError(f"Unknown VAE '{ver}'") from None
	return vae_class(*args, **kwargs)