File size: 13,707 Bytes
6728bc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 |
# reference: https://github.com/comfyanonymous/ComfyUI/blob/v0.3.64/comfy/sd.py#L273
import itertools
import math
import torch
from backend import memory_management
from backend.patcher.base import ModelPatcher
@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None):
"""https://github.com/comfyanonymous/ComfyUI/blob/v0.3.64/comfy/utils.py#L901"""
dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return up * val
def get_downscale(dim, val):
up = upscale_amount[dim]
if callable(up):
return up(val)
else:
return val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
if callable(up):
return up(val)
else:
return val / up
if downscale:
get_scale = get_downscale
get_pos = get_downscale_pos
else:
get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a):
out = []
for i in range(len(a)):
out.append(round(get_scale(i, a[i])))
return out
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
for b in range(samples.shape[0]):
s = samples[b : b + 1]
if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
output[b : b + 1] = function(s).to(output_device)
continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d]) if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
if feather >= mask.shape[d]:
continue
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = out
o_d = out_div
for d in range(dims):
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o.add_(ps * mask)
o_d.add_(mask)
output[b : b + 1] = out / out_div
return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu"):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device)
class VAE:
def __init__(self, model=None, device=None, dtype=None, no_init=False, *, is_wan=False):
if no_init:
return
if not is_wan:
self.upscale_ratio = 8
self.upscale_index_formula = None
self.downscale_ratio = 8
self.downscale_index_formula = None
self.latent_dim = 2
self.latent_channels = int(model.config.latent_channels) # 4 | 16
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * memory_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * memory_management.dtype_size(dtype)
else:
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = int(model.config.z_dim) # 16
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2] <= 4 else 6000) * shape[3] * shape[4] * memory_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2] <= 4 else 7000) * shape[3] * shape[4] * (8 * 8) * memory_management.dtype_size(dtype)
self.output_channels = 3
self.first_stage_model = model.eval()
self.device = device or memory_management.vae_device()
offload_device = memory_management.vae_offload_device()
self.vae_dtype = dtype or memory_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype)
self.output_device = memory_management.intermediate_device()
self.patcher = ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
self.is_wan = is_wan
def clone(self):
n = VAE(no_init=True)
n.patcher = self.patcher.clone()
n.memory_used_encode = self.memory_used_encode
n.memory_used_decode = self.memory_used_decode
n.downscale_ratio = self.downscale_ratio
n.latent_channels = self.latent_channels
n.first_stage_model = self.first_stage_model
n.device = self.device
n.vae_dtype = self.vae_dtype
n.output_device = self.output_device
n.is_wan = self.is_wan
return n
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output((tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device) + tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device) + tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device)) / 3.0)
return output
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
samples += tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
samples += tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
samples /= 3.0
return samples
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in: torch.Tensor):
if memory_management.VAE_ALWAYS_TILED:
return self.decode_tiled(samples_in).to(self.output_device)
pixel_samples = None
_tile = False
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = memory_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x : x + batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x : x + batch_number] = out
except memory_management.OOM_EXCEPTION:
print("Warning: Encountered Out of Memory during VAE decoding; Retrying with Tiled VAE Decoding...")
_tile = True
if _tile:
memory_management.soft_empty_cache()
return self.decode_tiled(samples_in).to(self.output_device)
pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
return pixel_samples
def decode_tiled(self, samples: torch.Tensor, tile_x: int = 64, tile_y: int = 64, overlap: int = 16):
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype)
memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
args = {
"tile_x": tile_x,
"tile_y": tile_y,
"overlap": overlap,
}
if not self.is_wan:
output = self.decode_tiled_(samples, **args)
else:
args["overlap"] = (1, overlap, overlap)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples: torch.Tensor):
if memory_management.VAE_ALWAYS_TILED:
return self.encode_tiled(pixel_samples)
_samples = pixel_samples.movedim(-1, 1)
if self.is_wan and _samples.ndim < 5:
_samples = _samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(_samples.shape, self.vae_dtype)
memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = memory_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, _samples.shape[0], batch_number):
pixels_in = self.process_input(_samples[x : x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x : x + batch_number] = out
_tile = False
except memory_management.OOM_EXCEPTION:
print("Warning: Encountered Out of Memory during VAE Encoding; Retrying with Tiled VAE Encoding...")
_tile = True
if _tile:
memory_management.soft_empty_cache()
return self.encode_tiled(pixel_samples)
return samples
def encode_tiled(self, pixel_samples: torch.Tensor, tile_x: int = 512, tile_y: int = 512, overlap: int = 64):
pixel_samples = pixel_samples.movedim(-1, 1)
if self.is_wan:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
args = {
"tile_x": tile_x,
"tile_y": tile_y,
"overlap": overlap,
}
if not self.is_wan:
return self.encode_tiled_(pixel_samples, **args)
args["tile_t"] = self.upscale_ratio[0](9999)
args["overlap"] = (1, overlap, overlap)
maximum = self.upscale_ratio[0](self.downscale_ratio[0](pixel_samples.shape[2]))
return self.encode_tiled_3d(pixel_samples[:, :, :maximum], **args)
def process_input(self, image):
return image * 2.0 - 1.0
def process_output(self, image):
return torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|