saliacoel commited on
Commit
d18b98d
·
verified ·
1 Parent(s): 3b94cb9

Upload salia_detailer_ezpz.py

Browse files
Files changed (1) hide show
  1. salia_detailer_ezpz.py +531 -0
salia_detailer_ezpz.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import threading
3
+ from typing import Any, Dict, Tuple, Optional
4
+
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ import folder_paths
10
+
11
+
12
+ # -------------------------------------------------------------------------------------
13
+ # Global caches (lazy-load + don't load duplicates across multiple node instances)
14
+ # -------------------------------------------------------------------------------------
15
+
16
+ _CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
17
+ _CN_CACHE: Dict[str, Any] = {}
18
+ _CKPT_LOCK = threading.Lock()
19
+ _CN_LOCK = threading.Lock()
20
+
21
+
22
+ # -------------------------------------------------------------------------------------
23
+ # PIL helpers (Lanczos resize for IMAGE and MASK)
24
+ # -------------------------------------------------------------------------------------
25
+
26
+ def _pil_lanczos():
27
+ # Pillow compatibility
28
+ if hasattr(Image, "Resampling"):
29
+ return Image.Resampling.LANCZOS
30
+ return Image.LANCZOS
31
+
32
+
33
+ def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
34
+ """
35
+ Comfy IMAGE: [B,H,W,C] or [H,W,C], float [0..1]
36
+ -> PIL RGB/RGBA
37
+ """
38
+ if img.ndim == 4:
39
+ img = img[0]
40
+ img = img.detach().cpu().float().clamp(0, 1)
41
+ arr = (img.numpy() * 255.0).round().astype(np.uint8)
42
+
43
+ if arr.shape[-1] == 4:
44
+ return Image.fromarray(arr, mode="RGBA")
45
+ return Image.fromarray(arr, mode="RGB")
46
+
47
+
48
+ def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
49
+ """
50
+ PIL RGB/RGBA -> Comfy IMAGE [1,H,W,C], float [0..1]
51
+ """
52
+ if pil.mode not in ("RGB", "RGBA"):
53
+ pil = pil.convert("RGBA") if "A" in pil.getbands() else pil.convert("RGB")
54
+ arr = np.array(pil).astype(np.float32) / 255.0
55
+ t = torch.from_numpy(arr) # [H,W,C]
56
+ return t.unsqueeze(0)
57
+
58
+
59
+ def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
60
+ """
61
+ Comfy MASK: [B,H,W] or [H,W], float [0..1] -> PIL L
62
+ """
63
+ if mask.ndim == 3:
64
+ mask = mask[0]
65
+ mask = mask.detach().cpu().float().clamp(0, 1)
66
+ arr = (mask.numpy() * 255.0).round().astype(np.uint8)
67
+ return Image.fromarray(arr, mode="L")
68
+
69
+
70
+ def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
71
+ """
72
+ PIL L -> Comfy MASK [1,H,W], float [0..1]
73
+ """
74
+ if pil_l.mode != "L":
75
+ pil_l = pil_l.convert("L")
76
+ arr = np.array(pil_l).astype(np.float32) / 255.0
77
+ t = torch.from_numpy(arr) # [H,W]
78
+ return t.unsqueeze(0)
79
+
80
+
81
+ def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
82
+ """
83
+ Resize Comfy IMAGE [B,H,W,C] with Lanczos via PIL, preserving channels.
84
+ """
85
+ if img.ndim != 4:
86
+ raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
87
+
88
+ outs = []
89
+ for i in range(img.shape[0]):
90
+ pil = _image_tensor_to_pil(img[i].unsqueeze(0))
91
+ pil = pil.resize((int(w), int(h)), resample=_pil_lanczos())
92
+ outs.append(_pil_to_image_tensor(pil))
93
+ return torch.cat(outs, dim=0)
94
+
95
+
96
+ def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
97
+ """
98
+ Resize Comfy MASK [B,H,W] with Lanczos via PIL.
99
+ """
100
+ if mask.ndim != 3:
101
+ raise ValueError("Expected MASK tensor with shape [B,H,W].")
102
+
103
+ outs = []
104
+ for i in range(mask.shape[0]):
105
+ pil = _mask_tensor_to_pil(mask[i].unsqueeze(0))
106
+ pil = pil.resize((int(w), int(h)), resample=_pil_lanczos())
107
+ outs.append(_pil_to_mask_tensor(pil))
108
+ return torch.cat(outs, dim=0)
109
+
110
+
111
+ # -------------------------------------------------------------------------------------
112
+ # Core lazy loaders (checkpoint + controlnet), cached globally
113
+ # -------------------------------------------------------------------------------------
114
+
115
+ def _load_checkpoint_cached(ckpt_name: str):
116
+ """
117
+ Mirrors comfy-core CheckpointLoaderSimple, but cached to avoid double-loads.
118
+ Returns: (model, clip, vae)
119
+ """
120
+ with _CKPT_LOCK:
121
+ if ckpt_name in _CKPT_CACHE:
122
+ return _CKPT_CACHE[ckpt_name]
123
+
124
+ import nodes # lazy
125
+ loader = nodes.CheckpointLoaderSimple()
126
+ fn = getattr(loader, loader.FUNCTION)
127
+ model, clip, vae = fn(ckpt_name=ckpt_name)
128
+
129
+ _CKPT_CACHE[ckpt_name] = (model, clip, vae)
130
+ return model, clip, vae
131
+
132
+
133
+ def _load_controlnet_cached(control_net_name: str):
134
+ """
135
+ Mirrors comfy-core ControlNetLoader, but cached to avoid double-loads.
136
+ Returns: controlnet
137
+ """
138
+ with _CN_LOCK:
139
+ if control_net_name in _CN_CACHE:
140
+ return _CN_CACHE[control_net_name]
141
+
142
+ import nodes # lazy
143
+ loader = nodes.ControlNetLoader()
144
+ fn = getattr(loader, loader.FUNCTION)
145
+ (cn,) = fn(control_net_name=control_net_name)
146
+
147
+ _CN_CACHE[control_net_name] = cn
148
+ return cn
149
+
150
+
151
+ # -------------------------------------------------------------------------------------
152
+ # Asset dropdown support (from comfyui-salia_online assets/images)
153
+ # (We still lazy-call the user's LoadImage_SaliaOnline_Assets for consistent mask behavior.)
154
+ # -------------------------------------------------------------------------------------
155
+
156
+ def _list_asset_pngs_fallback():
157
+ # Fallback scanner (if utils import fails)
158
+ try:
159
+ from pathlib import Path
160
+ plugin_root = Path(__file__).resolve().parent.parent
161
+ img_dir = plugin_root / "assets" / "images"
162
+ if not img_dir.exists():
163
+ return []
164
+ files = sorted([p.name for p in img_dir.glob("*.png")])
165
+ return files
166
+ except Exception:
167
+ return []
168
+
169
+
170
+ def _list_asset_pngs():
171
+ try:
172
+ # Prefer your plugin's own list function (same dropdown as your node)
173
+ from ..utils.io import list_pngs # type: ignore
174
+ return list_pngs() or []
175
+ except Exception:
176
+ return _list_asset_pngs_fallback()
177
+
178
+
179
+ def _load_asset_mask(asset_name: str):
180
+ """
181
+ Lazy-import and run your LoadImage_SaliaOnline_Assets node.
182
+ Returns: MASK
183
+ """
184
+ # NOTE: Keep this lazy so importing the plugin doesn't force-load anything.
185
+ from .salia_loadimage_assets import LoadImage_SaliaOnline_Assets # lazy-ish (light)
186
+
187
+ loader = LoadImage_SaliaOnline_Assets()
188
+ img, mask = loader.run(asset_name)
189
+ return mask
190
+
191
+
192
+ def _run_salia_depth(image: torch.Tensor, resolution: int) -> torch.Tensor:
193
+ """
194
+ Lazy-import and run your Salia_Depth node.
195
+ Returns IMAGE (depth)
196
+ """
197
+ from .salia_depth import Salia_Depth # heavy -> lazy import here
198
+
199
+ node = Salia_Depth()
200
+ fn = getattr(node, node.FUNCTION)
201
+ (depth_img,) = fn(image=image, resolution=int(resolution))
202
+ return depth_img
203
+
204
+
205
+ # -------------------------------------------------------------------------------------
206
+ # Alpha-over paste (RGBA square onto base at X,Y)
207
+ # -------------------------------------------------------------------------------------
208
+
209
+ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
210
+ """
211
+ base: [B,H,W,C] where C is 3 or 4, float [0..1]
212
+ overlay_rgba: [B,s,s,4] float [0..1]
213
+ """
214
+ if base.ndim != 4 or overlay_rgba.ndim != 4:
215
+ raise ValueError("base and overlay must be [B,H,W,C].")
216
+
217
+ B, H, W, C = base.shape
218
+ b2, sH, sW, c2 = overlay_rgba.shape
219
+ if c2 != 4:
220
+ raise ValueError("overlay_rgba must have 4 channels (RGBA).")
221
+ if sH != sW:
222
+ raise ValueError("overlay must be square.")
223
+ s = sH
224
+
225
+ if x < 0 or y < 0 or x + s > W or y + s > H:
226
+ raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
227
+
228
+ # Match batch
229
+ if b2 != B:
230
+ if b2 == 1 and B > 1:
231
+ overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
232
+ else:
233
+ raise ValueError("Batch mismatch between base and overlay.")
234
+
235
+ out = base.clone()
236
+
237
+ overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
238
+ overlay_a = overlay_rgba[..., 3:4].clamp(0, 1)
239
+
240
+ base_rgb = out[:, y:y + s, x:x + s, 0:3]
241
+ comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
242
+ out[:, y:y + s, x:x + s, 0:3] = comp_rgb
243
+
244
+ # If base has alpha, composite alpha too (optional)
245
+ if C == 4:
246
+ base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
247
+ comp_a = overlay_a + base_a * (1.0 - overlay_a)
248
+ out[:, y:y + s, x:x + s, 3:4] = comp_a
249
+
250
+ return out.clamp(0, 1)
251
+
252
+
253
+ # -------------------------------------------------------------------------------------
254
+ # The One-Node Workflow
255
+ # -------------------------------------------------------------------------------------
256
+
257
+ class Salia_Detailer_EZPZ:
258
+ """
259
+ One node that replicates the workflow you described.
260
+ """
261
+
262
+ CATEGORY = "image/salia"
263
+ RETURN_TYPES = ("IMAGE",)
264
+ RETURN_NAMES = ("image",)
265
+ FUNCTION = "run"
266
+
267
+ @classmethod
268
+ def INPUT_TYPES(cls):
269
+ # Dropdowns
270
+ ckpts = folder_paths.get_filename_list("checkpoints") or ["<no checkpoints found>"]
271
+ cns = folder_paths.get_filename_list("controlnet") or ["<no controlnets found>"]
272
+ assets = _list_asset_pngs() or ["<no pngs found>"]
273
+
274
+ # KSampler dropdowns (match comfy-core)
275
+ try:
276
+ import comfy.samplers
277
+ sampler_names = comfy.samplers.KSampler.SAMPLERS
278
+ scheduler_names = comfy.samplers.KSampler.SCHEDULERS
279
+ except Exception:
280
+ sampler_names = ["euler"]
281
+ scheduler_names = ["karras"]
282
+
283
+ # Upscale dropdown as requested
284
+ upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
285
+
286
+ return {
287
+ "required": {
288
+ "image": ("IMAGE",),
289
+
290
+ "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
291
+ "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
292
+ "square_size": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
293
+
294
+ "positive_prompt": ("STRING", {"default": "", "multiline": True}),
295
+ "negative_prompt": ("STRING", {"default": "", "multiline": True}),
296
+
297
+ "upscale_factor": (upscale_choices, {"default": "4"}),
298
+
299
+ # 3 dropdown menus you requested
300
+ "ckpt_name": (ckpts, {}),
301
+ "control_net_name": (cns, {}),
302
+ "asset_image": (assets, {}),
303
+
304
+ # ControlNet params
305
+ "controlnet_strength": ("FLOAT", {"default": 0.33, "min": 0.00, "max": 10.00, "step": 0.01}),
306
+ "controlnet_start_percent": ("FLOAT", {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.01}),
307
+ "controlnet_end_percent": ("FLOAT", {"default": 1.00, "min": 0.00, "max": 1.00, "step": 0.01}),
308
+
309
+ # KSampler params
310
+ "steps": ("INT", {"default": 30, "min": 1, "max": 200, "step": 1}),
311
+ "cfg": ("FLOAT", {"default": 2.6, "min": 0.00, "max": 10.00, "step": 0.05}),
312
+ "sampler_name": (sampler_names, {"default": "euler"} if "euler" in sampler_names else {}),
313
+ "scheduler": (scheduler_names, {"default": "karras"} if "karras" in scheduler_names else {}),
314
+ "denoise": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
315
+ }
316
+ }
317
+
318
+ def run(
319
+ self,
320
+ image: torch.Tensor,
321
+ X_coord: int,
322
+ Y_coord: int,
323
+ square_size: int,
324
+ positive_prompt: str,
325
+ negative_prompt: str,
326
+ upscale_factor: str, # dropdown returns str
327
+ ckpt_name: str,
328
+ control_net_name: str,
329
+ asset_image: str,
330
+ controlnet_strength: float,
331
+ controlnet_start_percent: float,
332
+ controlnet_end_percent: float,
333
+ steps: int,
334
+ cfg: float,
335
+ sampler_name: str,
336
+ scheduler: str,
337
+ denoise: float,
338
+ ):
339
+ # -------------------------
340
+ # Validate / normalize
341
+ # -------------------------
342
+ if image.ndim == 3:
343
+ image = image.unsqueeze(0)
344
+
345
+ if image.ndim != 4:
346
+ raise ValueError("Input image must be [B,H,W,C].")
347
+
348
+ B, H, W, C = image.shape
349
+ if C not in (3, 4):
350
+ raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")
351
+
352
+ x = int(X_coord)
353
+ y = int(Y_coord)
354
+ s = int(square_size)
355
+
356
+ up = int(upscale_factor)
357
+ if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
358
+ raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
359
+
360
+ if s <= 0:
361
+ raise ValueError("square_size must be > 0")
362
+
363
+ if x < 0 or y < 0 or x + s > W or y + s > H:
364
+ raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")
365
+
366
+ up_w = s * up
367
+ up_h = s * up
368
+
369
+ # VAE/UNet path is happiest with multiples of 8
370
+ if (up_w % 8) != 0 or (up_h % 8) != 0:
371
+ raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
372
+
373
+ # Clamp controlnet percent range
374
+ start_p = float(max(0.0, min(1.0, controlnet_start_percent)))
375
+ end_p = float(max(0.0, min(1.0, controlnet_end_percent)))
376
+ if end_p < start_p:
377
+ start_p, end_p = end_p, start_p
378
+
379
+ # -------------------------
380
+ # 1) Crop square (we use it twice internally)
381
+ # -------------------------
382
+ crop = image[:, y:y + s, x:x + s, :]
383
+ crop_rgb = crop[:, :, :, 0:3].contiguous() # force RGB for model/depth
384
+
385
+ # -------------------------
386
+ # 2) Depth path: Salia_Depth(crop) then upscale depth with Lanczos
387
+ # -------------------------
388
+ depth_small = _run_salia_depth(crop_rgb, resolution=s)
389
+ depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
390
+
391
+ # -------------------------
392
+ # 3) Generation path: upscale crop with Lanczos then VAE Encode
393
+ # -------------------------
394
+ crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
395
+
396
+ # -------------------------
397
+ # 4) Load asset mask (dropdown) and resize it to match upscaled resolution
398
+ # -------------------------
399
+ if asset_image == "<no pngs found>":
400
+ raise FileNotFoundError("No PNGs found in comfyui-salia_online/assets/images")
401
+
402
+ asset_mask = _load_asset_mask(asset_image) # MASK
403
+ if asset_mask.ndim == 2:
404
+ asset_mask = asset_mask.unsqueeze(0)
405
+ if asset_mask.ndim != 3:
406
+ raise ValueError("Asset mask must be [B,H,W].")
407
+
408
+ # Match batch
409
+ if asset_mask.shape[0] != B:
410
+ if asset_mask.shape[0] == 1 and B > 1:
411
+ asset_mask = asset_mask.expand(B, -1, -1)
412
+ else:
413
+ raise ValueError("Batch mismatch for asset mask.")
414
+
415
+ asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)
416
+
417
+ # -------------------------
418
+ # 5) Load checkpoint + controlnet (lazy + cached)
419
+ # -------------------------
420
+ if ckpt_name == "<no checkpoints found>":
421
+ raise FileNotFoundError("No checkpoints found in your ComfyUI models/checkpoints folder.")
422
+
423
+ if control_net_name == "<no controlnets found>":
424
+ raise FileNotFoundError("No controlnets found in your ComfyUI models/controlnet folder.")
425
+
426
+ model, clip, vae = _load_checkpoint_cached(ckpt_name)
427
+ controlnet = _load_controlnet_cached(control_net_name)
428
+
429
+ # -------------------------
430
+ # 6) Encode prompts (CLIPTextEncode)
431
+ # -------------------------
432
+ import nodes # lazy
433
+
434
+ pos_enc = nodes.CLIPTextEncode()
435
+ neg_enc = nodes.CLIPTextEncode()
436
+ pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
437
+ neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
438
+
439
+ (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
440
+ (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)
441
+
442
+ # -------------------------
443
+ # 7) Apply ControlNet (ControlNetApplyAdvanced)
444
+ # -------------------------
445
+ cn_apply = nodes.ControlNetApplyAdvanced()
446
+ cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
447
+
448
+ pos_cn, neg_cn = cn_fn(
449
+ strength=float(controlnet_strength),
450
+ start_percent=float(start_p),
451
+ end_percent=float(end_p),
452
+ positive=pos_cond,
453
+ negative=neg_cond,
454
+ control_net=controlnet,
455
+ image=depth_up,
456
+ vae=vae,
457
+ )
458
+
459
+ # -------------------------
460
+ # 8) VAE Encode (crop_up) -> latent
461
+ # -------------------------
462
+ vae_enc = nodes.VAEEncode()
463
+ vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
464
+ (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)
465
+
466
+ # -------------------------
467
+ # 9) KSampler
468
+ # -------------------------
469
+ # No seed input requested: derive a stable seed from inputs so changing anything changes seed.
470
+ seed_material = (
471
+ f"{ckpt_name}|{control_net_name}|{asset_image}|{x}|{y}|{s}|{up}|"
472
+ f"{steps}|{cfg}|{sampler_name}|{scheduler}|{denoise}|"
473
+ f"{controlnet_strength}|{start_p}|{end_p}|"
474
+ f"{positive_prompt}|{negative_prompt}"
475
+ ).encode("utf-8", errors="ignore")
476
+ seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)
477
+
478
+ ksampler = nodes.KSampler()
479
+ k_fn = getattr(ksampler, ksampler.FUNCTION)
480
+ (sampled_latent,) = k_fn(
481
+ seed=seed64,
482
+ steps=int(steps),
483
+ cfg=float(cfg),
484
+ sampler_name=str(sampler_name),
485
+ scheduler=str(scheduler),
486
+ denoise=float(denoise),
487
+ model=model,
488
+ positive=pos_cn,
489
+ negative=neg_cn,
490
+ latent_image=latent,
491
+ )
492
+
493
+ # -------------------------
494
+ # 10) VAE Decode -> RGB image
495
+ # -------------------------
496
+ vae_dec = nodes.VAEDecode()
497
+ vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
498
+ (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
499
+
500
+ # -------------------------
501
+ # 11) JoinImageWithAlpha (decoded_rgb + asset_mask_up) -> RGBA
502
+ # -------------------------
503
+ join = nodes.JoinImageWithAlpha()
504
+ join_fn = getattr(join, join.FUNCTION)
505
+
506
+ # Some Comfy versions name the mask input "alpha", others "mask".
507
+ try:
508
+ (rgba_up,) = join_fn(image=decoded_rgb, alpha=asset_mask_up)
509
+ except TypeError:
510
+ (rgba_up,) = join_fn(image=decoded_rgb, mask=asset_mask_up)
511
+
512
+ # -------------------------
513
+ # 12) Downscale RGBA back to original crop resolution (square_size) with Lanczos
514
+ # -------------------------
515
+ rgba_square = _resize_image_lanczos(rgba_up, s, s)
516
+
517
+ # -------------------------
518
+ # 13) Paste RGBA square onto original input image at X,Y using alpha-over
519
+ # -------------------------
520
+ out = _alpha_over_region(image, rgba_square, x=x, y=y)
521
+
522
+ return (out,)
523
+
524
+
525
+ NODE_CLASS_MAPPINGS = {
526
+ "Salia_Detailer_EZPZ": Salia_Detailer_EZPZ,
527
+ }
528
+
529
+ NODE_DISPLAY_NAME_MAPPINGS = {
530
+ "Salia_Detailer_EZPZ": "Salia_Detailer_EZPZ",
531
+ }