zk-Armor commited on
Commit
0d41590
·
verified ·
1 Parent(s): d97958a

Upload ComfyUI/latent_preview.py

Browse files
Files changed (1) hide show
  1. ComfyUI/latent_preview.py +137 -0
ComfyUI/latent_preview.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from comfy.cli_args import args, LatentPreviewMethod
4
+ from comfy.taesd.taesd import TAESD
5
+ from comfy.sd import VAE
6
+ import comfy.model_management
7
+ import folder_paths
8
+ import comfy.utils
9
+ import logging
10
+
11
+ default_preview_method = args.preview_method
12
+
13
+ MAX_PREVIEW_RESOLUTION = args.preview_size
14
+ VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
15
+
16
+ def preview_to_image(latent_image, do_scale=True):
17
+ if do_scale:
18
+ latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
19
+ .mul(0xFF) # to 0..255
20
+ )
21
+ else:
22
+ latents_ubyte = (latent_image.clamp(0, 1)
23
+ .mul(0xFF) # to 0..255
24
+ )
25
+ if comfy.model_management.directml_enabled:
26
+ latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
27
+ latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
28
+
29
+ return Image.fromarray(latents_ubyte.numpy())
30
+
31
+ class LatentPreviewer:
32
+ def decode_latent_to_preview(self, x0):
33
+ pass
34
+
35
+ def decode_latent_to_preview_image(self, preview_format, x0):
36
+ preview_image = self.decode_latent_to_preview(x0)
37
+ return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
38
+
39
+ class TAESDPreviewerImpl(LatentPreviewer):
40
+ def __init__(self, taesd):
41
+ self.taesd = taesd
42
+
43
+ def decode_latent_to_preview(self, x0):
44
+ x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
45
+ return preview_to_image(x_sample)
46
+
47
+ class TAEHVPreviewerImpl(TAESDPreviewerImpl):
48
+ def decode_latent_to_preview(self, x0):
49
+ x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
50
+ return preview_to_image(x_sample, do_scale=False)
51
+
52
+ class Latent2RGBPreviewer(LatentPreviewer):
53
+ def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
54
+ self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
55
+ self.latent_rgb_factors_bias = None
56
+ if latent_rgb_factors_bias is not None:
57
+ self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
58
+ self.latent_rgb_factors_reshape = latent_rgb_factors_reshape
59
+
60
+ def decode_latent_to_preview(self, x0):
61
+ if self.latent_rgb_factors_reshape is not None:
62
+ x0 = self.latent_rgb_factors_reshape(x0)
63
+ self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
64
+ if self.latent_rgb_factors_bias is not None:
65
+ self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
66
+
67
+ if x0.ndim == 5:
68
+ x0 = x0[0, :, 0]
69
+ else:
70
+ x0 = x0[0]
71
+
72
+ latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
73
+ # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
74
+
75
+ return preview_to_image(latent_image)
76
+
77
+
78
+ def get_previewer(device, latent_format):
79
+ previewer = None
80
+ method = args.preview_method
81
+ if method != LatentPreviewMethod.NoPreviews:
82
+ # TODO previewer methods
83
+ taesd_decoder_path = None
84
+ if latent_format.taesd_decoder_name is not None:
85
+ taesd_decoder_path = next(
86
+ (fn for fn in folder_paths.get_filename_list("vae_approx")
87
+ if fn.startswith(latent_format.taesd_decoder_name)),
88
+ ""
89
+ )
90
+ taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
91
+
92
+ if method == LatentPreviewMethod.Auto:
93
+ method = LatentPreviewMethod.Latent2RGB
94
+
95
+ if method == LatentPreviewMethod.TAESD:
96
+ if taesd_decoder_path:
97
+ if latent_format.taesd_decoder_name in VIDEO_TAES:
98
+ taesd = VAE(comfy.utils.load_torch_file(taesd_decoder_path))
99
+ taesd.first_stage_model.show_progress_bar = False
100
+ previewer = TAEHVPreviewerImpl(taesd)
101
+ else:
102
+ taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
103
+ previewer = TAESDPreviewerImpl(taesd)
104
+ else:
105
+ logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
106
+
107
+ if previewer is None:
108
+ if latent_format.latent_rgb_factors is not None:
109
+ previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
110
+ return previewer
111
+
112
+ def prepare_callback(model, steps, x0_output_dict=None):
113
+ preview_format = "JPEG"
114
+ if preview_format not in ["JPEG", "PNG"]:
115
+ preview_format = "JPEG"
116
+
117
+ previewer = get_previewer(model.load_device, model.model.latent_format)
118
+
119
+ pbar = comfy.utils.ProgressBar(steps)
120
+ def callback(step, x0, x, total_steps):
121
+ if x0_output_dict is not None:
122
+ x0_output_dict["x0"] = x0
123
+
124
+ preview_bytes = None
125
+ if previewer:
126
+ preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
127
+ pbar.update_absolute(step + 1, total_steps, preview_bytes)
128
+ return callback
129
+
130
+ def set_preview_method(override: str = None):
131
+ if override and override != "default":
132
+ method = LatentPreviewMethod.from_string(override)
133
+ if method is not None:
134
+ args.preview_method = method
135
+ return
136
+ args.preview_method = default_preview_method
137
+