from collections import defaultdict
from typing import Dict, List

import torch
from tqdm import trange

from .model.iter import try_get_iter
class VAEDecodeBatched:
    """Node that decodes latents through a VAE in fixed-size batches.

    Splitting the latent tensor into batches of at most ``batch_size``
    bounds peak memory during decode.  When the VAE is an ensemble
    (``try_get_iter`` returns its members), each decoded batch is split
    into per-VAE chunks and the final image tensor is regrouped so all
    images from VAE 0 come first, then VAE 1, and so on.
    """

    def __init__(self, device="cpu"):
        # Stored for node-interface parity; decode() does not read it.
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT", ),
                "vae": ("VAE", ),
                "batch_size": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 32,
                    "step": 1
                }),
            }
        }
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"
    CATEGORY = "latent"

    def decode(self, vae, samples, batch_size: int):
        """Decode ``samples['samples']`` in batches and return ``(images,)``.

        :param vae: VAE (or VAE ensemble) exposing ``decode(tensor)``.
        :param samples: LATENT dict; the tensor lives under ``'samples'``.
        :param batch_size: maximum latents decoded per VAE call.
        """
        s = samples['samples']
        n = s.shape[0]
        # An ensemble VAE yields one stacked output per member; a plain
        # VAE is treated as an ensemble of one.
        iters = try_get_iter(vae)
        vae_num = 1 if iters is None else len(iters)
        vae_results: Dict[int, List[torch.Tensor]] = defaultdict(list)
        for start in trange(0, n, batch_size):
            end = min(start + batch_size, n)  # clamp the last batch
            decoded = vae.decode(s[start:end, ...])
            # Distribute this batch's output across the member VAEs.
            for member, chunk in enumerate(torch.chunk(decoded, vae_num)):
                vae_results[member].append(chunk)
        # Regroup: all batches of member 0, then member 1, ...
        ordered: List[torch.Tensor] = []
        for member in sorted(vae_results):
            ordered.extend(vae_results[member])
        return (torch.cat(ordered).contiguous(),)
class VAEEncodeBatched:
    """Node that encodes images through a VAE in fixed-size batches.

    Images are cropped to height/width multiples of 64 and to 3
    channels before encoding, then encoded at most ``batch_size`` at a
    time to bound peak memory.
    """

    def __init__(self, device="cpu"):
        # Stored for node-interface parity; encode() does not read it.
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "pixels": ("IMAGE", ),
                "vae": ("VAE", ),
                "batch_size": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 32,
                    "step": 1
                }),
            }
        }
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"
    CATEGORY = "latent"

    def encode(self, vae, pixels, batch_size: int):
        """Encode ``pixels`` in batches and return ``({'samples': latents},)``.

        :param vae: VAE exposing ``encode(tensor)``.
        :param pixels: image tensor, layout (N, H, W, C) — inferred from
            the crop below; confirm against callers.
        :param batch_size: maximum images encoded per VAE call.
        """
        n = pixels.shape[0]
        # Crop H and W down to multiples of 64 so the VAE gets aligned sizes.
        x = (pixels.shape[1] // 64) * 64
        y = (pixels.shape[2] // 64) * 64
        if pixels.shape[1] != x or pixels.shape[2] != y:
            pixels = pixels[:, :x, :y, :]
        # Drop any alpha channel; keep RGB only.
        pixels = pixels[:, :, :, :3]
        results = []
        for i in trange(0, n, batch_size):
            # BUG FIX: was max(i + batch_size, n), which made the first
            # batch span the entire tensor and later batches re-encode
            # overlapping tails, duplicating latents in the output.
            # min() clamps the final batch like VAEDecodeBatched.decode.
            e = min(i + batch_size, n)
            results.append(vae.encode(pixels[i:e, ...]))
        return ({"samples": torch.cat(results)}, )