File size: 2,668 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from collections import defaultdict
from typing import Dict, List
import torch
from tqdm import trange
from .model.iter import try_get_iter

class VAEDecodeBatched:
    """Node that decodes a LATENT batch into an IMAGE batch in slices.

    The latent batch is cut along dim 0 into slices of at most ``batch_size``
    items, and each slice is handed to ``vae.decode`` separately.
    """

    def __init__(self, device="cpu"):
        # Stored on the instance; no method in this class reads it.
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs (latents, VAE, slice size)."""
        batch_size_spec = ("INT", {"default": 1, "min": 1, "max": 32, "step": 1})
        required = {
            "samples": ("LATENT",),
            "vae": ("VAE",),
            "batch_size": batch_size_spec,
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "latent"

    def decode(self, vae, samples, batch_size: int):
        """Decode ``samples['samples']`` slice by slice and return ``(images,)``.

        Args:
            vae: Object exposing ``decode(tensor)``. It is also passed to
                ``try_get_iter``; when that returns something sized, its length
                is the number of pieces each decoded slice is split into.
            samples: Dict whose ``'samples'`` entry is the latent tensor; its
                dim 0 is cut into slices of ``batch_size``.
            batch_size: Maximum number of latents passed to ``vae.decode`` per call.

        Returns:
            A one-element tuple holding one contiguous tensor. Each decoded
            slice is split with ``torch.chunk`` into one piece per entry
            reported by ``try_get_iter`` (1 when it returns None). All pieces
            with index 0 are concatenated first, then index 1, and so on.
            NOTE(review): this presumes a combined VAE stacks its members'
            outputs along dim 0 in equal shares; confirm in ``.model.iter``.
        """
        latents = samples['samples']
        total = latents.shape[0]

        members = try_get_iter(vae)
        group_count = 1 if members is None else len(members)

        # Collect decoded pieces per group index, in slice order.
        per_group: Dict[int, List[torch.Tensor]] = defaultdict(list)
        for start in trange(0, total, batch_size):
            stop = min(start + batch_size, total)
            decoded = vae.decode(latents[start:stop, ...])
            for group_idx, piece in enumerate(torch.chunk(decoded, group_count)):
                per_group[group_idx].append(piece)

        # Emit every piece of group 0, then group 1, and so on.
        ordered = [piece for group_idx in sorted(per_group) for piece in per_group[group_idx]]
        return (torch.cat(ordered).contiguous(),)


class VAEEncodeBatched:
    """Node that encodes an IMAGE batch into a LATENT batch in slices.

    The image batch is cut along dim 0 into slices of at most ``batch_size``
    items, and each slice is handed to ``vae.encode`` separately.
    """

    def __init__(self, device="cpu"):
        # Stored on the instance; no method in this class reads it.
        self.device = device

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs (images, VAE, slice size)."""
        return {
            "required": {
                "pixels": ("IMAGE", ),
                "vae": ("VAE", ),
                "batch_size": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 32,
                    "step": 1
                }),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    CATEGORY = "latent"

    def encode(self, vae, pixels, batch_size: int):
        """Encode ``pixels`` in slices of at most ``batch_size`` images.

        Args:
            vae: Object exposing ``encode(tensor)``; it is called once per slice.
            pixels: 4-D image tensor indexed ``[batch, d1, d2, channel]``
                (presumably ComfyUI's (B, H, W, C) layout; confirm).
                Dims 1 and 2 are cropped down to multiples of 64, and only
                channels 0-2 are kept.
            batch_size: Maximum number of images passed to ``vae.encode`` per call.

        Returns:
            A one-element tuple holding ``{"samples": latents}``. ``latents``
            is the per-slice outputs of ``vae.encode`` concatenated along dim 0.
            Every input image is encoded exactly once.
        """
        n = pixels.shape[0]

        # Crop dims 1 and 2 down to the nearest multiple of 64 (presumably so
        # the VAE's downsampling divides evenly; confirm). Also keep only
        # channels 0-2, dropping any extra channel such as alpha.
        d1 = (pixels.shape[1] // 64) * 64
        d2 = (pixels.shape[2] // 64) * 64
        pixels = pixels[:, :d1, :d2, :3]

        results = []
        for i in trange(0, n, batch_size):
            # Clamp the slice end to n. The previous max() made every slice
            # run to the end of the batch, which re-encoded images and
            # returned more latents than input images.
            e = min(i + batch_size, n)
            results.append(vae.encode(pixels[i:e, ...]))

        return ({"samples": torch.cat(results)}, )