saliacoel commited on
Commit
0911006
·
verified ·
1 Parent(s): 41e34bb

Upload salia_detailer_ezpz.py

Browse files
Files changed (1) hide show
  1. salia_detailer_ezpz.py +403 -162
salia_detailer_ezpz.py CHANGED
@@ -1,12 +1,24 @@
1
  import hashlib
 
2
  import threading
 
 
3
  from typing import Any, Dict, Tuple, Optional
4
 
5
- import torch
6
  import numpy as np
 
7
  from PIL import Image, ImageOps
8
 
9
  import folder_paths
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  # -------------------------------------------------------------------------------------
@@ -19,12 +31,31 @@ _CKPT_LOCK = threading.Lock()
19
  _CN_LOCK = threading.Lock()
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # -------------------------------------------------------------------------------------
23
  # PIL helpers (Lanczos resize for IMAGE and MASK)
24
  # -------------------------------------------------------------------------------------
25
 
26
  def _pil_lanczos():
27
- # Pillow compatibility
28
  if hasattr(Image, "Resampling"):
29
  return Image.Resampling.LANCZOS
30
  return Image.LANCZOS
@@ -32,14 +63,12 @@ def _pil_lanczos():
32
 
33
  def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
34
  """
35
- Comfy IMAGE: [B,H,W,C] or [H,W,C], float [0..1]
36
- -> PIL RGB/RGBA
37
  """
38
  if img.ndim == 4:
39
  img = img[0]
40
  img = img.detach().cpu().float().clamp(0, 1)
41
  arr = (img.numpy() * 255.0).round().astype(np.uint8)
42
-
43
  if arr.shape[-1] == 4:
44
  return Image.fromarray(arr, mode="RGBA")
45
  return Image.fromarray(arr, mode="RGB")
@@ -80,11 +109,10 @@ def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
80
 
81
  def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
82
  """
83
- Resize Comfy IMAGE [B,H,W,C] with Lanczos via PIL, preserving channels.
84
  """
85
  if img.ndim != 4:
86
  raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
87
-
88
  outs = []
89
  for i in range(img.shape[0]):
90
  pil = _image_tensor_to_pil(img[i].unsqueeze(0))
@@ -99,7 +127,6 @@ def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
99
  """
100
  if mask.ndim != 3:
101
  raise ValueError("Expected MASK tensor with shape [B,H,W].")
102
-
103
  outs = []
104
  for i in range(mask.shape[0]):
105
  pil = _mask_tensor_to_pil(mask[i].unsqueeze(0))
@@ -121,7 +148,7 @@ def _load_checkpoint_cached(ckpt_name: str):
121
  if ckpt_name in _CKPT_CACHE:
122
  return _CKPT_CACHE[ckpt_name]
123
 
124
- import nodes # lazy
125
  loader = nodes.CheckpointLoaderSimple()
126
  fn = getattr(loader, loader.FUNCTION)
127
  model, clip, vae = fn(ckpt_name=ckpt_name)
@@ -139,7 +166,7 @@ def _load_controlnet_cached(control_net_name: str):
139
  if control_net_name in _CN_CACHE:
140
  return _CN_CACHE[control_net_name]
141
 
142
- import nodes # lazy
143
  loader = nodes.ControlNetLoader()
144
  fn = getattr(loader, loader.FUNCTION)
145
  (cn,) = fn(control_net_name=control_net_name)
@@ -149,110 +176,60 @@ def _load_controlnet_cached(control_net_name: str):
149
 
150
 
151
  # -------------------------------------------------------------------------------------
152
- # Assets/images dropdown + loader (INLINED, no LoadImage_SaliaOnline_Assets dependency)
153
  # -------------------------------------------------------------------------------------
154
 
155
- _ASSETS_DIR_CACHE: Optional["object"] = None
156
- _ASSETS_DIR_LOCK = threading.Lock()
157
 
158
 
159
- def _find_assets_images_dir():
160
- """
161
- Find the plugin's assets/images folder by walking upward from this file.
162
- This is robust even if Comfy imports modules in weird ways.
163
- """
164
- from pathlib import Path
165
-
166
- here = Path(__file__).resolve()
167
- # check a few levels up; plugin root should be near
168
- for parent in [here.parent] + list(here.parents)[:8]:
169
- candidate = parent / "assets" / "images"
170
- if candidate.is_dir():
171
- return candidate
172
- return None
173
-
174
-
175
- def _assets_images_dir():
176
- global _ASSETS_DIR_CACHE
177
- with _ASSETS_DIR_LOCK:
178
- if _ASSETS_DIR_CACHE is not None:
179
- # If it was found once, reuse.
180
- try:
181
- if _ASSETS_DIR_CACHE.is_dir():
182
- return _ASSETS_DIR_CACHE
183
- except Exception:
184
- pass
185
-
186
- found = _find_assets_images_dir()
187
- _ASSETS_DIR_CACHE = found
188
- return found
189
-
190
-
191
- def _list_asset_pngs():
192
- """
193
- List PNGs inside assets/images (recursive), returning paths relative to assets/images.
194
- """
195
  img_dir = _assets_images_dir()
196
- if img_dir is None:
197
  return []
198
-
199
  files = []
200
- try:
201
- for p in img_dir.rglob("*"):
202
- if p.is_file() and p.suffix.lower() == ".png":
203
- rel = p.relative_to(img_dir).as_posix()
204
- files.append(rel)
205
- files.sort()
206
- return files
207
- except Exception:
208
- return []
209
 
210
 
211
- def _safe_asset_path(asset_rel_path: str):
212
- """
213
- Resolve a selected dropdown entry to an actual file path inside assets/images.
214
- Prevents path traversal.
215
- """
216
- from pathlib import Path
217
-
218
  img_dir = _assets_images_dir()
219
- if img_dir is None:
220
- raise FileNotFoundError("assets/images folder not found (could not locate plugin assets).")
221
 
222
  base = img_dir.resolve()
223
  rel = Path(asset_rel_path)
224
-
225
  if rel.is_absolute():
226
  raise ValueError("Absolute paths are not allowed for asset_image.")
227
 
228
- # Resolve and verify containment
229
  full = (base / rel).resolve()
230
  if base != full and base not in full.parents:
231
  raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
232
 
233
  if not full.is_file():
234
  raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
235
-
236
  if full.suffix.lower() != ".png":
237
  raise ValueError(f"Asset is not a PNG: {asset_rel_path}")
238
 
239
  return full
240
 
241
 
242
- def _load_asset_image_and_mask(asset_rel_path: str):
243
  """
244
- Load PNG from assets/images and return (IMAGE, MASK) in ComfyUI formats.
245
 
246
- IMPORTANT: Mask semantics match ComfyUI core LoadImage:
247
  - If PNG has alpha: mask = 1 - alpha
248
- - If no alpha: mask = 0 (opaque)
249
  """
250
  p = _safe_asset_path(asset_rel_path)
251
 
252
  im = Image.open(p)
253
  im = ImageOps.exif_transpose(im)
254
 
255
- # Ensure we can extract alpha if present
256
  had_alpha = ("A" in im.getbands())
257
  rgba = im.convert("RGBA")
258
  rgb = rgba.convert("RGB")
@@ -261,8 +238,8 @@ def _load_asset_image_and_mask(asset_rel_path: str):
261
  img_t = torch.from_numpy(rgb_arr)[None, ...]
262
 
263
  if had_alpha:
264
- alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0 # [H,W], 1=opaque
265
- mask = 1.0 - alpha # Comfy MASK convention
266
  else:
267
  h, w = rgb.size[1], rgb.size[0]
268
  mask = np.zeros((h, w), dtype=np.float32)
@@ -272,20 +249,334 @@ def _load_asset_image_and_mask(asset_rel_path: str):
272
 
273
 
274
  # -------------------------------------------------------------------------------------
275
- # Salia_Depth (still lazy import, unchanged)
276
  # -------------------------------------------------------------------------------------
277
 
278
- def _run_salia_depth(image: torch.Tensor, resolution: int) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  """
280
- Lazy-import and run your Salia_Depth node.
281
- Returns IMAGE (depth)
 
282
  """
283
- from .salia_depth import Salia_Depth # heavy -> lazy import here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
- node = Salia_Depth()
286
- fn = getattr(node, node.FUNCTION)
287
- (depth_img,) = fn(image=image, resolution=int(resolution))
288
- return depth_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
 
291
  # -------------------------------------------------------------------------------------
@@ -327,7 +618,7 @@ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y
327
  comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
328
  out[:, y:y + s, x:x + s, 0:3] = comp_rgb
329
 
330
- # If base has alpha, composite alpha too (optional)
331
  if C == 4:
332
  base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
333
  comp_a = overlay_a + base_a * (1.0 - overlay_a)
@@ -340,11 +631,7 @@ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y
340
  # The One-Node Workflow
341
  # -------------------------------------------------------------------------------------
342
 
343
- class Salia_Detailer_EZPZ:
344
- """
345
- One node that replicates the workflow you described.
346
- """
347
-
348
  CATEGORY = "image/salia"
349
  RETURN_TYPES = ("IMAGE",)
350
  RETURN_NAMES = ("image",)
@@ -352,12 +639,10 @@ class Salia_Detailer_EZPZ:
352
 
353
  @classmethod
354
  def INPUT_TYPES(cls):
355
- # Dropdowns
356
  ckpts = folder_paths.get_filename_list("checkpoints") or ["<no checkpoints found>"]
357
  cns = folder_paths.get_filename_list("controlnet") or ["<no controlnets found>"]
358
  assets = _list_asset_pngs() or ["<no pngs found>"]
359
 
360
- # KSampler dropdowns (match comfy-core)
361
  try:
362
  import comfy.samplers
363
  sampler_names = comfy.samplers.KSampler.SAMPLERS
@@ -366,7 +651,6 @@ class Salia_Detailer_EZPZ:
366
  sampler_names = ["euler"]
367
  scheduler_names = ["karras"]
368
 
369
- # Upscale dropdown as requested
370
  upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
371
 
372
  return {
@@ -382,17 +666,14 @@ class Salia_Detailer_EZPZ:
382
 
383
  "upscale_factor": (upscale_choices, {"default": "4"}),
384
 
385
- # 3 dropdown menus you requested
386
  "ckpt_name": (ckpts, {}),
387
  "control_net_name": (cns, {}),
388
  "asset_image": (assets, {}),
389
 
390
- # ControlNet params
391
  "controlnet_strength": ("FLOAT", {"default": 0.33, "min": 0.00, "max": 10.00, "step": 0.01}),
392
  "controlnet_start_percent": ("FLOAT", {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.01}),
393
  "controlnet_end_percent": ("FLOAT", {"default": 1.00, "min": 0.00, "max": 1.00, "step": 0.01}),
394
 
395
- # KSampler params
396
  "steps": ("INT", {"default": 30, "min": 1, "max": 200, "step": 1}),
397
  "cfg": ("FLOAT", {"default": 2.6, "min": 0.00, "max": 10.00, "step": 0.05}),
398
  "sampler_name": (sampler_names, {"default": "euler"} if "euler" in sampler_names else {}),
@@ -409,7 +690,7 @@ class Salia_Detailer_EZPZ:
409
  square_size: int,
410
  positive_prompt: str,
411
  negative_prompt: str,
412
- upscale_factor: str, # dropdown returns str
413
  ckpt_name: str,
414
  control_net_name: str,
415
  asset_image: str,
@@ -422,12 +703,8 @@ class Salia_Detailer_EZPZ:
422
  scheduler: str,
423
  denoise: float,
424
  ):
425
- # -------------------------
426
- # Validate / normalize
427
- # -------------------------
428
  if image.ndim == 3:
429
  image = image.unsqueeze(0)
430
-
431
  if image.ndim != 4:
432
  raise ValueError("Input image must be [B,H,W,C].")
433
 
@@ -442,56 +719,43 @@ class Salia_Detailer_EZPZ:
442
  up = int(upscale_factor)
443
  if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
444
  raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
445
-
446
  if s <= 0:
447
  raise ValueError("square_size must be > 0")
448
-
449
  if x < 0 or y < 0 or x + s > W or y + s > H:
450
  raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")
451
 
452
  up_w = s * up
453
  up_h = s * up
454
 
455
- # VAE/UNet path is happiest with multiples of 8
456
  if (up_w % 8) != 0 or (up_h % 8) != 0:
457
  raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
458
 
459
- # Clamp controlnet percent range
460
  start_p = float(max(0.0, min(1.0, controlnet_start_percent)))
461
  end_p = float(max(0.0, min(1.0, controlnet_end_percent)))
462
  if end_p < start_p:
463
  start_p, end_p = end_p, start_p
464
 
465
- # -------------------------
466
- # 1) Crop square (we use it twice internally)
467
- # -------------------------
468
  crop = image[:, y:y + s, x:x + s, :]
469
- crop_rgb = crop[:, :, :, 0:3].contiguous() # force RGB for model/depth
470
 
471
- # -------------------------
472
- # 2) Depth path: Salia_Depth(crop) then upscale depth with Lanczos
473
- # -------------------------
474
- depth_small = _run_salia_depth(crop_rgb, resolution=s)
475
  depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
476
 
477
- # -------------------------
478
- # 3) Generation path: upscale crop with Lanczos then VAE Encode
479
- # -------------------------
480
  crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
481
 
482
- # -------------------------
483
- # 4) Load asset mask (INLINE assets loader) and resize to match upscaled resolution
484
- # -------------------------
485
  if asset_image == "<no pngs found>":
486
  raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
 
487
 
488
- _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image) # MASK is what we need
489
  if asset_mask.ndim == 2:
490
  asset_mask = asset_mask.unsqueeze(0)
491
  if asset_mask.ndim != 3:
492
  raise ValueError("Asset mask must be [B,H,W].")
493
 
494
- # Match batch
495
  if asset_mask.shape[0] != B:
496
  if asset_mask.shape[0] == 1 and B > 1:
497
  asset_mask = asset_mask.expand(B, -1, -1)
@@ -500,37 +764,28 @@ class Salia_Detailer_EZPZ:
500
 
501
  asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)
502
 
503
- # -------------------------
504
- # 5) Load checkpoint + controlnet (lazy + cached)
505
- # -------------------------
506
  if ckpt_name == "<no checkpoints found>":
507
- raise FileNotFoundError("No checkpoints found in your ComfyUI models/checkpoints folder.")
508
-
509
  if control_net_name == "<no controlnets found>":
510
- raise FileNotFoundError("No controlnets found in your ComfyUI models/controlnet folder.")
511
 
512
  model, clip, vae = _load_checkpoint_cached(ckpt_name)
513
  controlnet = _load_controlnet_cached(control_net_name)
514
 
515
- # -------------------------
516
- # 6) Encode prompts (CLIPTextEncode)
517
- # -------------------------
518
- import nodes # lazy
519
 
 
520
  pos_enc = nodes.CLIPTextEncode()
521
  neg_enc = nodes.CLIPTextEncode()
522
  pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
523
  neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
524
-
525
  (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
526
  (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)
527
 
528
- # -------------------------
529
- # 7) Apply ControlNet (ControlNetApplyAdvanced)
530
- # -------------------------
531
  cn_apply = nodes.ControlNetApplyAdvanced()
532
  cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
533
-
534
  pos_cn, neg_cn = cn_fn(
535
  strength=float(controlnet_strength),
536
  start_percent=float(start_p),
@@ -542,16 +797,12 @@ class Salia_Detailer_EZPZ:
542
  vae=vae,
543
  )
544
 
545
- # -------------------------
546
- # 8) VAE Encode (crop_up) -> latent
547
- # -------------------------
548
  vae_enc = nodes.VAEEncode()
549
  vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
550
  (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)
551
 
552
- # -------------------------
553
- # 9) KSampler
554
- # -------------------------
555
  seed_material = (
556
  f"{ckpt_name}|{control_net_name}|{asset_image}|{x}|{y}|{s}|{up}|"
557
  f"{steps}|{cfg}|{sampler_name}|{scheduler}|{denoise}|"
@@ -575,41 +826,31 @@ class Salia_Detailer_EZPZ:
575
  latent_image=latent,
576
  )
577
 
578
- # -------------------------
579
- # 10) VAE Decode -> RGB image
580
- # -------------------------
581
  vae_dec = nodes.VAEDecode()
582
  vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
583
  (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
584
 
585
- # -------------------------
586
- # 11) JoinImageWithAlpha (decoded_rgb + asset_mask_up) -> RGBA
587
- # -------------------------
588
  join = nodes.JoinImageWithAlpha()
589
  join_fn = getattr(join, join.FUNCTION)
590
-
591
  try:
592
  (rgba_up,) = join_fn(image=decoded_rgb, alpha=asset_mask_up)
593
  except TypeError:
594
  (rgba_up,) = join_fn(image=decoded_rgb, mask=asset_mask_up)
595
 
596
- # -------------------------
597
- # 12) Downscale RGBA back to original crop resolution (square_size) with Lanczos
598
- # -------------------------
599
  rgba_square = _resize_image_lanczos(rgba_up, s, s)
600
 
601
- # -------------------------
602
- # 13) Paste RGBA square onto original input image at X,Y using alpha-over
603
- # -------------------------
604
  out = _alpha_over_region(image, rgba_square, x=x, y=y)
605
-
606
  return (out,)
607
 
608
 
609
  NODE_CLASS_MAPPINGS = {
610
- "Salia_Detailer_EZPZ": Salia_Detailer_EZPZ,
611
  }
612
 
613
  NODE_DISPLAY_NAME_MAPPINGS = {
614
- "Salia_Detailer_EZPZ": "Salia_Detailer_EZPZ",
615
  }
 
1
  import hashlib
2
+ import shutil
3
  import threading
4
+ import urllib.request
5
+ from pathlib import Path
6
  from typing import Any, Dict, Tuple, Optional
7
 
 
8
  import numpy as np
9
+ import torch
10
  from PIL import Image, ImageOps
11
 
12
  import folder_paths
13
+ import comfy.model_management as model_management
14
+
15
+
16
+ # transformers is required for depth-estimation pipeline
17
+ try:
18
+ from transformers import pipeline
19
+ except Exception as e:
20
+ pipeline = None
21
+ _TRANSFORMERS_IMPORT_ERROR = e
22
 
23
 
24
  # -------------------------------------------------------------------------------------
 
31
  _CN_LOCK = threading.Lock()
32
 
33
 
34
+ # -------------------------------------------------------------------------------------
35
+ # Plugin root detection (robust against hyphen/underscore module naming)
36
+ # -------------------------------------------------------------------------------------
37
+
38
+ def _find_plugin_root() -> Path:
39
+ """
40
+ Walk upwards from this file until we find an 'assets' folder.
41
+ This works regardless of how Comfy names the python module.
42
+ """
43
+ here = Path(__file__).resolve()
44
+ for parent in [here.parent] + list(here.parents)[:10]:
45
+ if (parent / "assets").is_dir():
46
+ return parent
47
+ # fallback: typical layout nodes/<thisfile>.py -> plugin root is parent.parent
48
+ return here.parent.parent
49
+
50
+
51
+ PLUGIN_ROOT = _find_plugin_root()
52
+
53
+
54
  # -------------------------------------------------------------------------------------
55
  # PIL helpers (Lanczos resize for IMAGE and MASK)
56
  # -------------------------------------------------------------------------------------
57
 
58
  def _pil_lanczos():
 
59
  if hasattr(Image, "Resampling"):
60
  return Image.Resampling.LANCZOS
61
  return Image.LANCZOS
 
63
 
64
  def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
65
  """
66
+ Comfy IMAGE: [B,H,W,C] or [H,W,C], float [0..1] -> PIL RGB/RGBA
 
67
  """
68
  if img.ndim == 4:
69
  img = img[0]
70
  img = img.detach().cpu().float().clamp(0, 1)
71
  arr = (img.numpy() * 255.0).round().astype(np.uint8)
 
72
  if arr.shape[-1] == 4:
73
  return Image.fromarray(arr, mode="RGBA")
74
  return Image.fromarray(arr, mode="RGB")
 
109
 
110
  def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
111
  """
112
+ Resize Comfy IMAGE [B,H,W,C] with Lanczos via PIL.
113
  """
114
  if img.ndim != 4:
115
  raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
 
116
  outs = []
117
  for i in range(img.shape[0]):
118
  pil = _image_tensor_to_pil(img[i].unsqueeze(0))
 
127
  """
128
  if mask.ndim != 3:
129
  raise ValueError("Expected MASK tensor with shape [B,H,W].")
 
130
  outs = []
131
  for i in range(mask.shape[0]):
132
  pil = _mask_tensor_to_pil(mask[i].unsqueeze(0))
 
148
  if ckpt_name in _CKPT_CACHE:
149
  return _CKPT_CACHE[ckpt_name]
150
 
151
+ import nodes
152
  loader = nodes.CheckpointLoaderSimple()
153
  fn = getattr(loader, loader.FUNCTION)
154
  model, clip, vae = fn(ckpt_name=ckpt_name)
 
166
  if control_net_name in _CN_CACHE:
167
  return _CN_CACHE[control_net_name]
168
 
169
+ import nodes
170
  loader = nodes.ControlNetLoader()
171
  fn = getattr(loader, loader.FUNCTION)
172
  (cn,) = fn(control_net_name=control_net_name)
 
176
 
177
 
178
  # -------------------------------------------------------------------------------------
179
+ # Assets/images dropdown + loader (inlined; no LoadImage_SaliaOnline_Assets dependency)
180
  # -------------------------------------------------------------------------------------
181
 
182
def _assets_images_dir() -> Path:
    """Path of the bundled assets/images directory (may not exist)."""
    return PLUGIN_ROOT.joinpath("assets", "images")
184
 
185
 
186
def _list_asset_pngs() -> list:
    """Return sorted PNG paths (POSIX-style, relative) under assets/images.

    Missing directory yields an empty list so the dropdown can degrade
    gracefully instead of raising at node registration time.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    rel_paths = [
        entry.relative_to(img_dir).as_posix()
        for entry in img_dir.rglob("*")
        if entry.is_file() and entry.suffix.lower() == ".png"
    ]
    return sorted(rel_paths)
 
 
 
 
196
 
197
 
198
def _safe_asset_path(asset_rel_path: str) -> Path:
    """
    Resolve a dropdown selection to a real PNG file inside assets/images.

    Raises:
        FileNotFoundError: assets/images missing, or the file does not exist.
        ValueError: absolute path, path-traversal attempt, or non-PNG file.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")

    base = img_dir.resolve()
    rel = Path(asset_rel_path)
    # Reject absolute inputs up front; they would bypass the containment check.
    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    # Resolve symlinks/'..' first, then require the result to stay under base.
    full = (base / rel).resolve()
    if base != full and base not in full.parents:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")

    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")

    return full
218
 
219
 
220
+ def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
221
  """
222
+ Returns (IMAGE, MASK) in ComfyUI formats.
223
 
224
+ Mask semantics: match ComfyUI core LoadImage:
225
  - If PNG has alpha: mask = 1 - alpha
226
+ - Else: mask = 0
227
  """
228
  p = _safe_asset_path(asset_rel_path)
229
 
230
  im = Image.open(p)
231
  im = ImageOps.exif_transpose(im)
232
 
 
233
  had_alpha = ("A" in im.getbands())
234
  rgba = im.convert("RGBA")
235
  rgb = rgba.convert("RGB")
 
238
  img_t = torch.from_numpy(rgb_arr)[None, ...]
239
 
240
  if had_alpha:
241
+ alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0
242
+ mask = 1.0 - alpha
243
  else:
244
  h, w = rgb.size[1], rgb.size[0]
245
  mask = np.zeros((h, w), dtype=np.float32)
 
249
 
250
 
251
  # -------------------------------------------------------------------------------------
252
+ # Salia_Depth (INLINED: exact logic, no imports from other files)
253
  # -------------------------------------------------------------------------------------
254
 
255
# Local model path: assets/depth — created at import time so downloads
# always have a target directory.
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Files the local depth-estimation pipeline needs, mapped to download URLs.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# Hub repo used when the local model files cannot be obtained.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
266
+
267
+
268
def _have_required_files() -> bool:
    """True when every required depth-model file is present in MODEL_DIR."""
    for name in REQUIRED_FILES:
        if not (MODEL_DIR / name).exists():
            return False
    return True
270
+
271
+
272
+ def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
273
+ dst.parent.mkdir(parents=True, exist_ok=True)
274
+ tmp = dst.with_suffix(dst.suffix + ".tmp")
275
+
276
+ if tmp.exists():
277
+ try:
278
+ tmp.unlink()
279
+ except Exception:
280
+ pass
281
+
282
+ req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
283
+ with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
284
+ shutil.copyfileobj(r, f)
285
+
286
+ tmp.replace(dst)
287
+
288
+
289
def ensure_local_model_files() -> bool:
    """Ensure the local depth-model files exist, downloading any that are missing.

    Returns True when all required files are present afterwards. Never raises:
    a download failure simply yields False so callers can fall back to the
    hub repo.
    """
    if _have_required_files():
        return True
    try:
        missing = [name for name in REQUIRED_FILES if not (MODEL_DIR / name).exists()]
        for name in missing:
            _download_url_to_file(REQUIRED_FILES[name], MODEL_DIR / name)
        return _have_required_files()
    except Exception:
        return False
301
+
302
+
303
def HWC3(x: np.ndarray) -> np.ndarray:
    """Coerce a uint8 image to 3-channel HxWx3.

    Grayscale is replicated across channels; RGBA is composited over a
    white background. RGB passes through unchanged (same object).
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    # RGBA: blend onto white using the alpha channel.
    rgb = x[:, :, 0:3].astype(np.float32)
    coverage = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * coverage + 255.0 * (1.0 - coverage)
    return blended.clip(0, 255).astype(np.uint8)
320
+
321
+
322
def pad64(x: int) -> int:
    """Return the padding needed to round *x* up to the next multiple of 64.

    Pure integer arithmetic: the original float round-trip through
    np.ceil(float(x)/64) loses exactness for very large ints and is
    slower/noisier than a modulo for no benefit.
    """
    return (64 - x % 64) % 64
324
+
325
+
326
def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return an independent C-contiguous copy of *x*.

    ndarray.copy(order="C") already yields an owned, C-ordered array, so the
    previous copy -> ascontiguousarray -> copy chain did the same allocation
    three times over.
    """
    return x.copy(order="C")
328
+
329
+
330
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """
    Resize so the SHORT side equals *resolution*, then pad bottom/right to
    multiples of 64.

    Returns (padded_image, remove_pad) where remove_pad crops an array back
    to the pre-padding target size. resolution <= 0 skips resizing entirely
    and returns the (possibly HWC3-coerced) input with an identity crop.
    cv2 is optional; PIL is the fallback resizer.
    """
    # cv2 is imported lazily so the node still works without opencv installed.
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # No resize requested: identity remove_pad, no padding either.
        return img, (lambda x: x)

    # Scale factor that brings the short side to `resolution`.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # Downscales always use INTER_AREA; upscale_method only applies when enlarging.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback: BICUBIC for upscale, LANCZOS for downscale
        # (approximates the cv2 choices; filters are not identical).
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right only, so remove_pad is a plain top-left crop.
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the pre-padding target size; returns an owned copy.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
377
+
378
+
379
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """Pad an image (no resize) so both sides are multiples of 64.

    Returns (padded_image, remove_pad) where remove_pad crops an array back
    to the original height/width and returns an owned copy.
    """
    img = HWC3(img_u8)
    height, width = img.shape[0], img.shape[1]
    pad_bottom = pad64(height)
    pad_right = pad64(width)
    img_padded = np.pad(img, [[0, pad_bottom], [0, pad_right], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Padding was bottom/right only, so a top-left crop restores the size.
        return safer_memory(x[:height, :width, ...])

    return safer_memory(img_padded), remove_pad
389
+
390
+
391
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Split RGBA into (RGB-composited-over-white, alpha channel).

    Non-RGBA input is returned as (HWC3(input), None) so callers can tell
    whether an alpha channel was present.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    alpha_u8 = rgba[:, :, 3].copy()
    coverage = rgba[:, :, 3:4].astype(np.float32) / 255.0
    over_white = rgba[:, :, 0:3].astype(np.float32) * coverage + 255.0 * (1.0 - coverage)
    return over_white.clip(0, 255).astype(np.uint8), alpha_u8
400
+
401
+
402
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Multiply the depth map by alpha so fully transparent pixels become black."""
    rgb = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb * weight).clip(0, 255).astype(np.uint8)
407
+
408
+
409
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a Comfy IMAGE tensor ([B,H,W,C] or [H,W,C], floats in [0,1]) to uint8 HWC.

    Batched input uses only the first image.
    """
    single = img[0] if img.ndim == 4 else img
    clamped = single.detach().cpu().float().clamp(0, 1)
    return (clamped.numpy() * 255.0).round().astype(np.uint8)
415
+
416
+
417
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """uint8 image -> Comfy IMAGE tensor [1,H,W,3], float32 in [0,1]."""
    rgb = HWC3(img_u8)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled)[None, ...]
421
+
422
+
423
+ _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
424
+
425
+
426
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Build (or fetch from the module-level cache) a transformers
    depth-estimation pipeline.

    model_source may be a local directory or a HF hub repo id; the cache key
    includes the device string. Raises RuntimeError when transformers failed
    to import at module load time.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    key = (model_source, str(device))
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]

    p = pipeline(task="depth-estimation", model=model_source)
    # Best-effort move to the requested device; some pipeline versions do not
    # allow reassigning .device, so failures are deliberately ignored.
    try:
        p.model = p.model.to(device)
        p.device = device
    except Exception:
        pass

    _PIPE_CACHE[key] = p
    return p
443
+
444
+
445
def get_depth_pipeline(device: torch.device):
    """Return a depth pipeline: local model first, then the Zoe hub fallback.

    Returns None when neither source can be loaded; never raises.
    """
    sources = []
    if ensure_local_model_files():
        sources.append(str(MODEL_DIR))
    sources.append(ZOE_FALLBACK_REPO_ID)

    for source in sources:
        try:
            return _try_load_pipeline(source, device)
        except Exception:
            continue
    return None
455
+
456
+
457
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """Run the depth pipeline Zoe-style and return a uint8 HxWx3 depth map.

    detect_resolution == -1 pads to multiples of 64 without resizing;
    otherwise the short side is resized to detect_resolution first.
    Depth values are percentile-normalized (2nd..85th), inverted, scaled to
    uint8 and cropped back to the pre-padding size.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # np.array handles both PIL.Image and array-like pipeline outputs, so the
    # previous isinstance(depth, Image.Image) branch was two identical arms.
    depth_array = np.array(depth, dtype=np.float32)

    # Robust normalization: clip outliers via percentiles.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        denom = 1e-6  # constant depth map: avoid divide-by-zero
    depth_array = depth_array / denom

    # Invert the normalized map (original behavior; presumably to match the
    # ControlNet depth convention of near = bright — confirm against the model).
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    return remove_pad(HWC3(depth_image))
499
+
500
+
501
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """Bilinear-resize a uint8 image back to (w0, h0); prefers cv2, falls back to PIL."""
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        # cv2 missing (or its resize failed): use PIL's bilinear filter instead.
        pil = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil, dtype=np.uint8)
510
+
511
+
512
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """
    Internal callable version of the Salia_Depth node.

    input:  IMAGE [B,H,W,3 or 4], floats in [0,1]
    output: IMAGE [B,H,W,3] depth maps; the input is returned unchanged when
            no depth pipeline can be loaded, and individual items fall back
            to their original pixels on per-item failure.
    """
    # Get torch device; fall back to CPU if ComfyUI's model management fails.
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    # Load pipeline (local model, then hub fallback) — best effort.
    pipe = None
    try:
        pipe = get_depth_pipeline(device)
    except Exception:
        pipe = None

    # If everything fails, pass the input through untouched.
    if pipe is None:
        return image

    # Batch support: promote a single [H,W,C] image to [1,H,W,C].
    if image.ndim == 3:
        image = image.unsqueeze(0)

    outs = []
    for i in range(image.shape[0]):
        try:
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])

            inp_u8 = comfy_tensor_to_u8(image[i])

            # RGBA rule (pre): composite over white, remember the alpha.
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None

            # Depth estimation at the requested detect resolution.
            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )

            # Resize back to the original per-image size.
            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)

            # RGBA rule (post): zero out transparent areas (black background).
            if had_rgba:
                # Re-align alpha if the depth path changed dimensions.
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)

                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)

            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # NOTE(review): this fallback keeps the original item, which may be
            # 4-channel while successful items are 3-channel — torch.cat would
            # then fail on mismatched channel counts. Verify with RGBA batches.
            outs.append(image[i].unsqueeze(0))

    return torch.cat(outs, dim=0)
580
 
581
 
582
  # -------------------------------------------------------------------------------------
 
618
  comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
619
  out[:, y:y + s, x:x + s, 0:3] = comp_rgb
620
 
621
+ # If base has alpha, composite alpha too
622
  if C == 4:
623
  base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
624
  comp_a = overlay_a + base_a * (1.0 - overlay_a)
 
631
  # The One-Node Workflow
632
  # -------------------------------------------------------------------------------------
633
 
634
+ class Salia_OneNode_WorkflowSquare:
 
 
 
 
635
  CATEGORY = "image/salia"
636
  RETURN_TYPES = ("IMAGE",)
637
  RETURN_NAMES = ("image",)
 
639
 
640
  @classmethod
641
  def INPUT_TYPES(cls):
 
642
  ckpts = folder_paths.get_filename_list("checkpoints") or ["<no checkpoints found>"]
643
  cns = folder_paths.get_filename_list("controlnet") or ["<no controlnets found>"]
644
  assets = _list_asset_pngs() or ["<no pngs found>"]
645
 
 
646
  try:
647
  import comfy.samplers
648
  sampler_names = comfy.samplers.KSampler.SAMPLERS
 
651
  sampler_names = ["euler"]
652
  scheduler_names = ["karras"]
653
 
 
654
  upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
655
 
656
  return {
 
666
 
667
  "upscale_factor": (upscale_choices, {"default": "4"}),
668
 
 
669
  "ckpt_name": (ckpts, {}),
670
  "control_net_name": (cns, {}),
671
  "asset_image": (assets, {}),
672
 
 
673
  "controlnet_strength": ("FLOAT", {"default": 0.33, "min": 0.00, "max": 10.00, "step": 0.01}),
674
  "controlnet_start_percent": ("FLOAT", {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.01}),
675
  "controlnet_end_percent": ("FLOAT", {"default": 1.00, "min": 0.00, "max": 1.00, "step": 0.01}),
676
 
 
677
  "steps": ("INT", {"default": 30, "min": 1, "max": 200, "step": 1}),
678
  "cfg": ("FLOAT", {"default": 2.6, "min": 0.00, "max": 10.00, "step": 0.05}),
679
  "sampler_name": (sampler_names, {"default": "euler"} if "euler" in sampler_names else {}),
 
690
  square_size: int,
691
  positive_prompt: str,
692
  negative_prompt: str,
693
+ upscale_factor: str,
694
  ckpt_name: str,
695
  control_net_name: str,
696
  asset_image: str,
 
703
  scheduler: str,
704
  denoise: float,
705
  ):
 
 
 
706
  if image.ndim == 3:
707
  image = image.unsqueeze(0)
 
708
  if image.ndim != 4:
709
  raise ValueError("Input image must be [B,H,W,C].")
710
 
 
719
  up = int(upscale_factor)
720
  if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
721
  raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")
 
722
  if s <= 0:
723
  raise ValueError("square_size must be > 0")
 
724
  if x < 0 or y < 0 or x + s > W or y + s > H:
725
  raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")
726
 
727
  up_w = s * up
728
  up_h = s * up
729
 
 
730
  if (up_w % 8) != 0 or (up_h % 8) != 0:
731
  raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")
732
 
 
733
  start_p = float(max(0.0, min(1.0, controlnet_start_percent)))
734
  end_p = float(max(0.0, min(1.0, controlnet_end_percent)))
735
  if end_p < start_p:
736
  start_p, end_p = end_p, start_p
737
 
738
+ # 1) Crop
 
 
739
  crop = image[:, y:y + s, x:x + s, :]
740
+ crop_rgb = crop[:, :, :, 0:3].contiguous()
741
 
742
+ # 2) Depth (inline Salia_Depth) then Lanczos upscale
743
+ depth_small = _salia_depth_execute(crop_rgb, resolution=s)
 
 
744
  depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
745
 
746
+ # 3) Upscale crop for VAE Encode
 
 
747
  crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
748
 
749
+ # 4) Load asset mask (inline) and resize
 
 
750
  if asset_image == "<no pngs found>":
751
  raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
752
+ _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
753
 
 
754
  if asset_mask.ndim == 2:
755
  asset_mask = asset_mask.unsqueeze(0)
756
  if asset_mask.ndim != 3:
757
  raise ValueError("Asset mask must be [B,H,W].")
758
 
 
759
  if asset_mask.shape[0] != B:
760
  if asset_mask.shape[0] == 1 and B > 1:
761
  asset_mask = asset_mask.expand(B, -1, -1)
 
764
 
765
  asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)
766
 
767
+ # 5) Load checkpoint + controlnet (cached)
 
 
768
  if ckpt_name == "<no checkpoints found>":
769
+ raise FileNotFoundError("No checkpoints found in models/checkpoints.")
 
770
  if control_net_name == "<no controlnets found>":
771
+ raise FileNotFoundError("No controlnets found in models/controlnet.")
772
 
773
  model, clip, vae = _load_checkpoint_cached(ckpt_name)
774
  controlnet = _load_controlnet_cached(control_net_name)
775
 
776
+ import nodes
 
 
 
777
 
778
+ # 6) CLIP encodes
779
  pos_enc = nodes.CLIPTextEncode()
780
  neg_enc = nodes.CLIPTextEncode()
781
  pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
782
  neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
 
783
  (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
784
  (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)
785
 
786
+ # 7) Apply ControlNet
 
 
787
  cn_apply = nodes.ControlNetApplyAdvanced()
788
  cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
 
789
  pos_cn, neg_cn = cn_fn(
790
  strength=float(controlnet_strength),
791
  start_percent=float(start_p),
 
797
  vae=vae,
798
  )
799
 
800
+ # 8) VAE Encode
 
 
801
  vae_enc = nodes.VAEEncode()
802
  vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
803
  (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)
804
 
805
+ # 9) KSampler (deterministic seed derived from inputs)
 
 
806
  seed_material = (
807
  f"{ckpt_name}|{control_net_name}|{asset_image}|{x}|{y}|{s}|{up}|"
808
  f"{steps}|{cfg}|{sampler_name}|{scheduler}|{denoise}|"
 
826
  latent_image=latent,
827
  )
828
 
829
+ # 10) VAE Decode
 
 
830
  vae_dec = nodes.VAEDecode()
831
  vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
832
  (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
833
 
834
+ # 11) JoinImageWithAlpha
 
 
835
  join = nodes.JoinImageWithAlpha()
836
  join_fn = getattr(join, join.FUNCTION)
 
837
  try:
838
  (rgba_up,) = join_fn(image=decoded_rgb, alpha=asset_mask_up)
839
  except TypeError:
840
  (rgba_up,) = join_fn(image=decoded_rgb, mask=asset_mask_up)
841
 
842
+ # 12) Downscale RGBA back to crop size
 
 
843
  rgba_square = _resize_image_lanczos(rgba_up, s, s)
844
 
845
+ # 13) Paste back onto original at X,Y (alpha-over)
 
 
846
  out = _alpha_over_region(image, rgba_square, x=x, y=y)
 
847
  return (out,)
848
 
849
 
850
# ComfyUI registration: internal node id -> implementing class
NODE_CLASS_MAPPINGS = {
    "Salia_OneNode_WorkflowSquare": Salia_OneNode_WorkflowSquare,
}

# ComfyUI registration: internal node id -> title shown in the node picker
NODE_DISPLAY_NAME_MAPPINGS = {
    "Salia_OneNode_WorkflowSquare": "Salia One-Node Workflow (Crop+Depth+CN+Sample+Paste)",
}
  }