saliacoel committed on
Commit
eb0d4e7
·
verified ·
1 Parent(s): b0b1cf0

Upload 2 files

Browse files
salia_detailer_ezpz_gated_Doubletime.py ADDED
@@ -0,0 +1,1075 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import shutil
3
+ import threading
4
+ import urllib.request
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Tuple, Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageOps
11
+
12
+ import folder_paths
13
+ import comfy.model_management as model_management
14
+
15
+
16
+ # transformers is required for depth-estimation pipeline
17
+ try:
18
+ from transformers import pipeline
19
+ except Exception as e:
20
+ pipeline = None
21
+ _TRANSFORMERS_IMPORT_ERROR = e
22
+
23
+
24
+ # -------------------------------------------------------------------------------------
25
+ # Global caches (checkpoint + controlnet) so using the node multiple times won't reload
26
+ # -------------------------------------------------------------------------------------
27
+
28
+ _CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
29
+ _CN_CACHE: Dict[str, Any] = {}
30
+ _CKPT_LOCK = threading.Lock()
31
+ _CN_LOCK = threading.Lock()
32
+
33
+
34
+ # -------------------------------------------------------------------------------------
35
+ # Plugin root detection (works whether file is in plugin root or nodes/)
36
+ # -------------------------------------------------------------------------------------
37
+
38
+ def _find_plugin_root() -> Path:
39
+ """
40
+ Walk upwards from this file until we find an 'assets' folder.
41
+ Robust against hyphen/underscore package naming and different file placement.
42
+ """
43
+ here = Path(__file__).resolve()
44
+ for parent in [here.parent] + list(here.parents)[:12]:
45
+ if (parent / "assets").is_dir():
46
+ return parent
47
+ # fallback: typical nodes/<file>.py
48
+ return here.parent.parent
49
+
50
+
51
+ PLUGIN_ROOT = _find_plugin_root()
52
+
53
+
54
+ # -------------------------------------------------------------------------------------
55
+ # PIL helpers (Lanczos resize for IMAGE and MASK)
56
+ # -------------------------------------------------------------------------------------
57
+
58
def _pil_lanczos():
    """Return the Lanczos resampling filter across old and new Pillow APIs."""
    resampling = getattr(Image, "Resampling", None)
    if resampling is not None:
        return resampling.LANCZOS
    return Image.LANCZOS
62
+
63
+
64
def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert a Comfy IMAGE tensor ([B,H,W,C] or [H,W,C], float [0..1]) to PIL.

    Batched input uses only the first frame. 4-channel input yields RGBA,
    otherwise RGB.
    """
    if img.ndim == 4:
        img = img[0]  # first batch element only
    u8 = (img.detach().cpu().float().clamp(0, 1).numpy() * 255.0).round().astype(np.uint8)
    mode = "RGBA" if u8.shape[-1] == 4 else "RGB"
    return Image.fromarray(u8, mode=mode)
75
+
76
+
77
def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert PIL RGB/RGBA to a Comfy IMAGE tensor [1,H,W,C] in [0..1]."""
    if pil.mode not in ("RGB", "RGBA"):
        # Preserve alpha only if the source actually carries an alpha band.
        target = "RGBA" if "A" in pil.getbands() else "RGB"
        pil = pil.convert(target)
    data = np.array(pil, dtype=np.float32) / 255.0  # [H,W,C]
    return torch.from_numpy(data).unsqueeze(0)
86
+
87
+
88
def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """Convert a Comfy MASK ([B,H,W] or [H,W], float [0..1]) to a PIL 'L' image."""
    if mask.ndim == 3:
        mask = mask[0]  # first batch element only
    u8 = (mask.detach().cpu().float().clamp(0, 1).numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(u8, mode="L")
97
+
98
+
99
def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """Convert a PIL 'L' image to a Comfy MASK tensor [1,H,W] in [0..1]."""
    if pil_l.mode != "L":
        pil_l = pil_l.convert("L")
    data = np.array(pil_l, dtype=np.float32) / 255.0  # [H,W]
    return torch.from_numpy(data).unsqueeze(0)
108
+
109
+
110
def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize every frame of a Comfy IMAGE batch [B,H,W,C] via PIL Lanczos."""
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    size = (int(w), int(h))
    resized = [
        _pil_to_image_tensor(
            _image_tensor_to_pil(frame.unsqueeze(0)).resize(size, resample=_pil_lanczos())
        )
        for frame in img
    ]
    return torch.cat(resized, dim=0)
122
+
123
+
124
def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize every mask of a Comfy MASK batch [B,H,W] via PIL Lanczos."""
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    size = (int(w), int(h))
    resized = [
        _pil_to_mask_tensor(
            _mask_tensor_to_pil(m.unsqueeze(0)).resize(size, resample=_pil_lanczos())
        )
        for m in mask
    ]
    return torch.cat(resized, dim=0)
136
+
137
+
138
+ # -------------------------------------------------------------------------------------
139
+ # ✅ ComfyUI 0.5.1 FIX: Manual JoinImageWithAlpha equivalent
140
+ # -------------------------------------------------------------------------------------
141
+
142
+ def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
143
+ """
144
+ Make RGBA from:
145
+ rgb: IMAGE [B,H,W,3] float [0..1]
146
+ mask: MASK [B,H,W] float [0..1] (Comfy convention: 1=masked/transparent)
147
+ Output:
148
+ rgba: IMAGE [B,H,W,4] where alpha = 1 - mask (1=opaque, 0=transparent)
149
+ """
150
+ if rgb.ndim == 3:
151
+ rgb = rgb.unsqueeze(0)
152
+ if mask.ndim == 2:
153
+ mask = mask.unsqueeze(0)
154
+
155
+ if rgb.ndim != 4 or rgb.shape[-1] != 3:
156
+ raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
157
+ if mask.ndim != 3:
158
+ raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
159
+
160
+ # Batch match
161
+ if mask.shape[0] != rgb.shape[0]:
162
+ if mask.shape[0] == 1 and rgb.shape[0] > 1:
163
+ mask = mask.expand(rgb.shape[0], -1, -1)
164
+ else:
165
+ raise ValueError("Batch mismatch between rgb and mask.")
166
+
167
+ # Size match
168
+ if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
169
+ raise ValueError(
170
+ f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
171
+ )
172
+
173
+ mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
174
+ alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1) # [B,H,W,1]
175
+
176
+ rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1) # [B,H,W,4]
177
+ return rgba
178
+
179
+
180
+ # -------------------------------------------------------------------------------------
181
+ # Core lazy loaders (checkpoint + controlnet), cached globally
182
+ # -------------------------------------------------------------------------------------
183
+
184
def _load_checkpoint_cached(ckpt_name: str):
    """Load (model, clip, vae) via comfy-core CheckpointLoaderSimple, cached.

    The global cache avoids reloading the same checkpoint across repeated
    node executions; the lock makes cache fill thread-safe.
    """
    with _CKPT_LOCK:
        cached = _CKPT_CACHE.get(ckpt_name)
        if cached is not None:
            return cached

        import nodes
        loader = nodes.CheckpointLoaderSimple()
        load_fn = getattr(loader, loader.FUNCTION)
        model, clip, vae = load_fn(ckpt_name=ckpt_name)

        _CKPT_CACHE[ckpt_name] = (model, clip, vae)
        return model, clip, vae
200
+
201
+
202
def _load_controlnet_cached(control_net_name: str):
    """Load a ControlNet via comfy-core ControlNetLoader, with global caching."""
    with _CN_LOCK:
        try:
            return _CN_CACHE[control_net_name]
        except KeyError:
            pass  # not cached yet: load below while still holding the lock

        import nodes
        loader = nodes.ControlNetLoader()
        load_fn = getattr(loader, loader.FUNCTION)
        (controlnet,) = load_fn(control_net_name=control_net_name)

        _CN_CACHE[control_net_name] = controlnet
        return controlnet
218
+
219
+
220
+ # -------------------------------------------------------------------------------------
221
+ # Assets/images dropdown + loader (inlined; no LoadImage_SaliaOnline_Assets dependency)
222
+ # -------------------------------------------------------------------------------------
223
+
224
def _assets_images_dir() -> Path:
    """Absolute path of the bundled assets/images directory."""
    return PLUGIN_ROOT.joinpath("assets", "images")
226
+
227
+
228
def _list_asset_pngs() -> list:
    """List PNGs under assets/images as sorted, POSIX-style relative paths.

    Returns an empty list when the directory does not exist.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    return sorted(
        p.relative_to(img_dir).as_posix()
        for p in img_dir.rglob("*")
        if p.is_file() and p.suffix.lower() == ".png"
    )
238
+
239
+
240
def _safe_asset_path(asset_rel_path: str) -> Path:
    """Resolve a relative asset path inside assets/images, rejecting traversal.

    Raises FileNotFoundError for missing folder/file, ValueError for absolute
    paths, escapes out of the assets tree, or non-PNG extensions.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")

    rel = Path(asset_rel_path)
    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    base = img_dir.resolve()
    full = (base / rel).resolve()

    # Reject anything whose resolved path escapes the assets/images tree.
    inside = (full == base) or (base in full.parents)
    if not inside:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")

    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")

    return full
263
+
264
+
265
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset PNG as (IMAGE [1,H,W,3], MASK [1,H,W]) tensors.

    Mask semantics follow ComfyUI core LoadImage: mask = 1 - alpha, where
    alpha is the RGBA alpha channel normalized to [0..1].
    """
    path = _safe_asset_path(asset_rel_path)

    pil = ImageOps.exif_transpose(Image.open(path))
    rgba = pil.convert("RGBA")

    rgb_np = np.array(rgba.convert("RGB"), dtype=np.float32) / 255.0  # [H,W,3]
    image_t = torch.from_numpy(rgb_np).unsqueeze(0)

    alpha_np = np.array(rgba.getchannel("A"), dtype=np.float32) / 255.0  # [H,W]
    mask_t = torch.from_numpy(1.0 - alpha_np).unsqueeze(0)  # Comfy: 1 = transparent
    return image_t, mask_t
289
+
290
+
291
+ # -------------------------------------------------------------------------------------
292
+ # Salia_Depth (INLINED, no imports from other files)
293
+ # -------------------------------------------------------------------------------------
294
+
295
+ MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
296
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
297
+
298
+ REQUIRED_FILES = {
299
+ "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
300
+ "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
301
+ "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
302
+ }
303
+
304
+ ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
305
+
306
+ _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
307
+ _PIPE_LOCK = threading.Lock()
308
+
309
+
310
def _have_required_files() -> bool:
    """True when every required depth-model file is present in MODEL_DIR."""
    return all((MODEL_DIR / fname).exists() for fname in REQUIRED_FILES)
312
+
313
+
314
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """Download `url` to `dst` atomically (write to a .tmp file, then rename)."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    # Drop any stale partial download left over from a previous attempt.
    try:
        tmp.unlink()
    except Exception:
        pass

    req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        with open(tmp, "wb") as out:
            shutil.copyfileobj(resp, out)

    tmp.replace(dst)  # atomic on the same filesystem
329
+
330
+
331
def ensure_local_model_files() -> bool:
    """Download any missing depth-model files; return True when all are present.

    Best-effort: any download failure returns False instead of raising.
    """
    if _have_required_files():
        return True
    try:
        for fname, url in REQUIRED_FILES.items():
            target = MODEL_DIR / fname
            if not target.exists():
                _download_url_to_file(url, target)
        return _have_required_files()
    except Exception:
        return False
343
+
344
+
345
def HWC3(x: np.ndarray) -> np.ndarray:
    """Normalize a uint8 image to 3-channel HxWx3.

    Grayscale is replicated across channels; RGBA is alpha-composited over a
    white background. Input must be uint8 with 1, 3, or 4 channels.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    # channels == 4: composite onto white.
    rgb = x[:, :, 0:3].astype(np.float32)
    a = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * a + 255.0 * (1.0 - a)
    return blended.clip(0, 255).astype(np.uint8)
362
+
363
+
364
def pad64(x: int) -> int:
    """Padding (in pixels) needed to round x up to the next multiple of 64."""
    return (-int(x)) % 64
366
+
367
+
368
def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return an independent, C-contiguous copy of x.

    Fixes a redundant allocation chain: the original
    ``np.ascontiguousarray(x.copy()).copy()`` could copy the data up to three
    times; a single contiguous copy yields the same result.
    """
    return np.ascontiguousarray(x).copy()
370
+
371
+
372
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """Scale so the SHORTER side equals `resolution`, then pad to multiples of 64.

    Returns (padded_image, remove_pad), where remove_pad crops an array of the
    padded shape back to the pre-padding target size. resolution <= 0 skips
    resizing entirely and returns an identity remove_pad. Uses cv2 when
    importable, falling back to PIL otherwise.
    """
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None  # cv2 is optional; PIL fallback is used below

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # No resize requested: return the image as-is with an identity un-pad.
        return img, (lambda x: x)

    # Scale factor so that min(H, W) becomes `resolution`.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # INTER_AREA when downscaling (k <= 1): better quality for shrink.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right so both sides are multiples of 64 (UNet-friendly sizes).
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the resized (pre-pad) size; returns a fresh copy.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
419
+
420
+
421
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """Pad an image to the next multiple of 64 per side; return (padded, undo)."""
    img = HWC3(img_u8)
    height, width = img.shape[0], img.shape[1]
    pad_h = pad64(height)
    pad_w = pad64(width)
    padded = np.pad(img, [[0, pad_h], [0, pad_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the original size; returns an independent copy.
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
431
+
432
+
433
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Composite RGBA over white; return (rgb_u8, alpha_u8 or None).

    Non-RGBA input is normalized to 3 channels and returned with alpha None.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    a = rgba[:, :, 3:4].astype(np.float32) / 255.0
    rgb = rgba[:, :, 0:3].astype(np.float32)
    over_white = (rgb * a + 255.0 * (1.0 - a)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
442
+
443
+
444
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Multiply a depth RGB image by alpha so transparent areas become black."""
    rgb = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb * weight).clip(0, 255).astype(np.uint8)
449
+
450
+
451
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a Comfy IMAGE tensor (float [0..1]) to a uint8 HxWxC array.

    Batched input uses only the first frame.
    """
    if img.ndim == 4:
        img = img[0]
    clipped = img.detach().cpu().float().clamp(0, 1)
    return (clipped.numpy() * 255.0).round().astype(np.uint8)
457
+
458
+
459
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image array to a Comfy IMAGE tensor [1,H,W,3]."""
    as_float = HWC3(img_u8).astype(np.float32) / 255.0
    return torch.from_numpy(as_float).unsqueeze(0)  # [1,H,W,C]
463
+
464
+
465
def _try_load_pipeline(model_source: str, device: torch.device):
    """Build (or fetch from cache) a depth-estimation pipeline for `device`.

    Raises RuntimeError if transformers failed to import. Moving the model to
    the device is best-effort: failures leave it where transformers put it.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    key = (model_source, str(device))
    with _PIPE_LOCK:
        cached = _PIPE_CACHE.get(key)
        if cached is not None:
            return cached

        pipe = pipeline(task="depth-estimation", model=model_source)
        try:
            pipe.model = pipe.model.to(device)
            pipe.device = device
        except Exception:
            pass  # best-effort device placement

        _PIPE_CACHE[key] = pipe
        return pipe
483
+
484
+
485
def get_depth_pipeline(device: torch.device):
    """Return a depth pipeline, preferring local files, else the Zoe fallback.

    Returns None when neither source can be loaded (callers handle None).
    """
    if ensure_local_model_files():
        try:
            return _try_load_pipeline(str(MODEL_DIR), device)
        except Exception:
            pass  # local load failed: fall through to the hub fallback
    try:
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
    except Exception:
        return None
495
+
496
+
497
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """Run a depth-estimation pipeline and post-process the result Zoe-style.

    detect_resolution == -1 keeps the input size (padding to /64 only);
    otherwise the short side is resized to detect_resolution first. Returns an
    HxWx3 uint8 depth map at the working (pre-pad) size — NOT the original
    input size; callers resize back themselves.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # Both branches coerce to a float32 ndarray (handles PIL image or tensor).
    if isinstance(depth, Image.Image):
        depth_array = np.array(depth, dtype=np.float32)
    else:
        depth_array = np.array(depth, dtype=np.float32)

    # Robust normalization window: 2nd..85th percentile of the raw depth.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        denom = 1e-6  # degenerate (flat) depth: avoid division by zero
    depth_array = depth_array / denom

    # Invert so near = bright; values outside [0,1] are clipped below.
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    return detected_map
539
+
540
+
541
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """Resize a depth image back to w0 x h0 (cv2 if available, else PIL)."""
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        # cv2 missing (or resize failed): bilinear resize via PIL instead.
        resampled = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(resampled, dtype=np.uint8)
550
+
551
+
552
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """
    Internal callable version of Salia_Depth:
      input:  IMAGE [B,H,W,3 or 4]
      output: IMAGE [B,H,W,3]

    Best-effort throughout: if the depth pipeline cannot be loaded the input
    is returned unchanged, and any per-frame failure keeps the original frame.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    pipe_obj = None
    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None

    if pipe_obj is None:
        # No depth model available: pass the input through untouched.
        return image

    if image.ndim == 3:
        image = image.unsqueeze(0)

    outs = []
    for i in range(image.shape[0]):
        try:
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])

            inp_u8 = comfy_tensor_to_u8(image[i])

            # RGBA input: run depth on a white-composited RGB copy and keep
            # the alpha so transparent areas can be blacked out afterwards.
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None

            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )

            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)

            if had_rgba:
                # Alpha may differ in size when depth ran at detect_resolution.
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)

                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)

            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # Per-frame failure: keep the original frame rather than aborting.
            outs.append(image[i].unsqueeze(0))

    return torch.cat(outs, dim=0)
612
+
613
+
614
+ # -------------------------------------------------------------------------------------
615
+ # Alpha-over paste (RGBA square onto base at X,Y)
616
+ # -------------------------------------------------------------------------------------
617
+
618
+ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
619
+ """
620
+ base: [B,H,W,C] where C is 3 or 4, float [0..1]
621
+ overlay_rgba: [B,s,s,4] float [0..1]
622
+ """
623
+ if base.ndim != 4 or overlay_rgba.ndim != 4:
624
+ raise ValueError("base and overlay must be [B,H,W,C].")
625
+
626
+ B, H, W, C = base.shape
627
+ b2, sH, sW, c2 = overlay_rgba.shape
628
+ if c2 != 4:
629
+ raise ValueError("overlay_rgba must have 4 channels (RGBA).")
630
+ if sH != sW:
631
+ raise ValueError("overlay must be square.")
632
+ s = sH
633
+
634
+ if x < 0 or y < 0 or x + s > W or y + s > H:
635
+ raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
636
+
637
+ if b2 != B:
638
+ if b2 == 1 and B > 1:
639
+ overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
640
+ else:
641
+ raise ValueError("Batch mismatch between base and overlay.")
642
+
643
+ out = base.clone()
644
+
645
+ overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
646
+ overlay_a = overlay_rgba[..., 3:4].clamp(0, 1)
647
+
648
+ base_rgb = out[:, y:y + s, x:x + s, 0:3]
649
+ comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
650
+ out[:, y:y + s, x:x + s, 0:3] = comp_rgb
651
+
652
+ if C == 4:
653
+ base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
654
+ comp_a = overlay_a + base_a * (1.0 - overlay_a)
655
+ out[:, y:y + s, x:x + s, 3:4] = comp_a
656
+
657
+ return out.clamp(0, 1)
658
+
659
+
660
+
661
+
662
+
663
+
664
+ # -------------------------------------------------------------------------------------
665
+ # Two-pass EZPZ node (same pipeline twice sequentially)
666
+ # -------------------------------------------------------------------------------------
667
+
668
+ # Hardcoded constants requested
669
+ _TWOPASS_CKPT_NAME = "SaliaHighlady_Speedy.safetensors"
670
+ _TWOPASS_CONTROLNET_NAME = "diffusion_pytorch_model_promax.safetensors"
671
+ _TWOPASS_CN_START_PERCENT = 0.00
672
+ _TWOPASS_CN_END_PERCENT = 1.00
673
+
674
+ # Pass 1 hardcoded sampler/settings
675
+ _PASS1_SAMPLER = "dpmpp_2m_sde_heun_gpu"
676
+ _PASS1_SCHEDULER = "karras"
677
+ _PASS1_STEPS = 29
678
+ _PASS1_CFG = 2.6
679
+ _PASS1_CN_STRENGTH = 0.33
680
+
681
+ # Pass 2 hardcoded sampler/settings
682
+ _PASS2_SAMPLER = "res_multistep_ancestral_cfg_pp"
683
+ _PASS2_SCHEDULER = "karras"
684
+ _PASS2_STEPS = 30
685
+ _PASS2_CFG = 1.7
686
+ _PASS2_CN_STRENGTH = 0.5
687
+
688
+
689
def _ensure_model_assets_exist_or_throw():
    """Fail fast when the hardcoded checkpoint or ControlNet file is missing.

    Raises FileNotFoundError listing (up to 50) available alternatives.
    """
    if folder_paths.get_full_path("checkpoints", _TWOPASS_CKPT_NAME) is None:
        available = folder_paths.get_filename_list("checkpoints") or []
        raise FileNotFoundError(
            f"Hardcoded ckpt_name not found: {_TWOPASS_CKPT_NAME}\n"
            f"Available checkpoints: {available[:50]}{' ...' if len(available) > 50 else ''}"
        )

    if folder_paths.get_full_path("controlnet", _TWOPASS_CONTROLNET_NAME) is None:
        available = folder_paths.get_filename_list("controlnet") or []
        raise FileNotFoundError(
            f"Hardcoded control_net_name not found: {_TWOPASS_CONTROLNET_NAME}\n"
            f"Available controlnets: {available[:50]}{' ...' if len(available) > 50 else ''}"
        )
707
+
708
+
709
def _validate_sampler_scheduler_exist_or_throw(sampler_name: str, scheduler: str):
    """Raise ValueError when the hardcoded sampler/scheduler is unknown to comfy.

    If comfy.samplers cannot be imported, validation is skipped and the later
    KSampler call will surface the problem itself.
    """
    try:
        import comfy.samplers
    except ImportError:
        return

    samplers = set(comfy.samplers.KSampler.SAMPLERS)
    schedulers = set(comfy.samplers.KSampler.SCHEDULERS)
    if sampler_name not in samplers:
        raise ValueError(
            f"Hardcoded sampler_name not available: {sampler_name}\n"
            f"Available samplers: {sorted(list(samplers))}"
        )
    if scheduler not in schedulers:
        raise ValueError(
            f"Hardcoded scheduler not available: {scheduler}\n"
            f"Available schedulers: {sorted(list(schedulers))}"
        )
727
+
728
+
729
+ def _crop_square_or_throw(img: torch.Tensor, x: int, y: int, s: int) -> torch.Tensor:
730
+ if img.ndim == 3:
731
+ img = img.unsqueeze(0)
732
+ if img.ndim != 4:
733
+ raise ValueError("Input image must be [B,H,W,C].")
734
+
735
+ B, H, W, C = img.shape
736
+ if C not in (3, 4):
737
+ raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")
738
+
739
+ if s <= 0:
740
+ raise ValueError("square_size must be > 0")
741
+ if x < 0 or y < 0 or x + s > W or y + s > H:
742
+ raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")
743
+
744
+ return img[:, y:y + s, x:x + s, :]
745
+
746
+
747
def _run_one_pass(
    base_image: torch.Tensor,
    *,
    x: int,
    y: int,
    square_size: int,
    upscale_factor: int,
    denoise: float,
    steps: int,
    cfg: float,
    sampler_name: str,
    scheduler: str,
    controlnet_strength: float,
    # shared objects
    model,
    clip,
    vae,
    controlnet,
    pos_cond,
    neg_cond,
    asset_mask_batched: torch.Tensor,
    pass_tag: str,
    positive_prompt: str,
    negative_prompt: str,
    asset_image: str,
) -> torch.Tensor:
    """
    Executes the exact same pipeline as Salia_ezpz_gated, but parameterized
    for one pass. Returns a full image with the processed square pasted back.

    Pipeline: crop square -> depth map -> Lanczos upscale -> depth ControlNet
    -> VAE encode -> KSampler (deterministic seed derived from all inputs)
    -> VAE decode -> re-attach asset alpha -> downscale -> alpha-over paste.

    NOTE(review): `clip` is accepted but not used in this function body.
    """

    # Normalize base_image
    if base_image.ndim == 3:
        base_image = base_image.unsqueeze(0)
    if base_image.ndim != 4:
        raise ValueError("Input image must be [B,H,W,C].")

    B, H, W, C = base_image.shape
    if C not in (3, 4):
        raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")

    s = int(square_size)
    up = int(upscale_factor)

    if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
        raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")

    # bounds check for this pass
    if s <= 0:
        raise ValueError("square_size must be > 0")
    if x < 0 or y < 0 or x + s > W or y + s > H:
        raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")

    up_w = s * up
    up_h = s * up

    # VAE/UNet likes multiples of 8
    if (up_w % 8) != 0 or (up_h % 8) != 0:
        raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")

    # hardcoded CN start/end
    start_p = float(_TWOPASS_CN_START_PERCENT)
    end_p = float(_TWOPASS_CN_END_PERCENT)

    # 1) Crop square (alpha channel, if any, is dropped for processing)
    crop = base_image[:, y:y + s, x:x + s, :]
    crop_rgb = crop[:, :, :, 0:3].contiguous()

    # 2) Depth then upscale with Lanczos
    depth_small = _salia_depth_execute(crop_rgb, resolution=s)
    depth_up = _resize_image_lanczos(depth_small, up_w, up_h)

    # 3) Upscale crop for VAE encode
    crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)

    # 4) Resize asset mask for this pass
    if asset_mask_batched.ndim != 3:
        raise ValueError("asset_mask_batched must be [B,H,W].")
    if asset_mask_batched.shape[0] != B:
        raise ValueError("Batch mismatch for asset mask vs base image.")
    asset_mask_up = _resize_mask_lanczos(asset_mask_batched, up_w, up_h)

    import nodes

    # 5) Apply ControlNet (depth image conditions both pos and neg prompts)
    cn_apply = nodes.ControlNetApplyAdvanced()
    cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
    pos_cn, neg_cn = cn_fn(
        strength=float(controlnet_strength),
        start_percent=float(start_p),
        end_percent=float(end_p),
        positive=pos_cond,
        negative=neg_cond,
        control_net=controlnet,
        image=depth_up,
        vae=vae,
    )

    # 6) VAE Encode
    vae_enc = nodes.VAEEncode()
    vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
    (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)

    # 7) KSampler seed (deterministic, pass-specific): hash all parameters so
    #    identical inputs reproduce identical results across runs.
    seed_material = (
        f"{pass_tag}|{_TWOPASS_CKPT_NAME}|{_TWOPASS_CONTROLNET_NAME}|{asset_image}|{x}|{y}|{s}|{up}|"
        f"{steps}|{cfg}|{sampler_name}|{scheduler}|{denoise}|"
        f"{controlnet_strength}|{start_p}|{end_p}|"
        f"{positive_prompt}|{negative_prompt}"
    ).encode("utf-8", errors="ignore")
    seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)

    ksampler = nodes.KSampler()
    k_fn = getattr(ksampler, ksampler.FUNCTION)
    (sampled_latent,) = k_fn(
        seed=seed64,
        steps=int(steps),
        cfg=float(cfg),
        sampler_name=str(sampler_name),
        scheduler=str(scheduler),
        denoise=float(denoise),
        model=model,
        positive=pos_cn,
        negative=neg_cn,
        latent_image=latent,
    )

    # 8) VAE Decode -> RGB
    vae_dec = nodes.VAEDecode()
    vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
    (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)

    # 9) Manual JoinImageWithAlpha: decoded_rgb + asset_mask_up -> RGBA
    rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)

    # 10) Downscale RGBA back to crop size
    rgba_square = _resize_image_lanczos(rgba_up, s, s)

    # 11) Paste back onto base at X,Y (alpha-over)
    out = _alpha_over_region(base_image, rgba_square, x=x, y=y)
    return out
888
+
889
+
890
class Salia_ezpz_gated_twopass:
    """Two-pass gated detailer node.

    Crops a square at (X_coord, Y_coord), runs two sequential
    depth-guided img2img passes over it (pass 2 consumes pass 1's output),
    and returns the full composited image plus the final crop.
    An empty ``trigger_string`` bypasses all processing but still crops.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("image", "image_cropped")
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the ComfyUI input widgets for this node."""
        # Placeholder entry keeps the dropdown valid when no assets ship.
        assets = _list_asset_pngs() or ["<no pngs found>"]
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]

        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),

                # Shared coords for both passes + final crop
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),

                # Per-pass crop sizes
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),

                # Shared prompts
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),

                # Per-pass upscale/denoise
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),

                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),

                # Shared asset image for alpha mask
                "asset_image": (assets, {}),
            }
        }

    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        square_size_1: int = 384,
        square_size_2: int = 384,
        positive_prompt: str = "",
        negative_prompt: str = "",
        upscale_factor_1: str = "4",
        upscale_factor_2: str = "4",
        denoise_1: float = 0.35,
        denoise_2: float = 0.35,
        asset_image: str = "",
    ):
        """Execute the two-pass workflow.

        Returns:
            (image, image_cropped): the second-pass composited image and
            its ``square_size_2`` crop at (X_coord, Y_coord).

        Raises:
            ValueError: bad tensor shapes or out-of-bounds crop.
            FileNotFoundError: no asset PNGs available.
        """
        # Normalize input to [B,H,W,C] early (cropping always happens)
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")
        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")

        x = int(X_coord)
        y = int(Y_coord)
        s2 = int(square_size_2)

        # Always validate final crop bounds (even in bypass mode)
        if s2 <= 0:
            raise ValueError("square_size_2 must be > 0")
        if x < 0 or y < 0 or x + s2 > W or y + s2 > H:
            raise ValueError(f"Final crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s2}")

        # If trigger_string is exactly empty: bypass ALL processing, but still crop from "second output"
        if trigger_string == "":
            out2 = image
            cropped = out2[:, y:y + s2, x:x + s2, :]
            return (out2, cropped)

        # Validate hardcoded samplers/schedulers early for clearer failures
        _validate_sampler_scheduler_exist_or_throw(_PASS1_SAMPLER, _PASS1_SCHEDULER)
        _validate_sampler_scheduler_exist_or_throw(_PASS2_SAMPLER, _PASS2_SCHEDULER)

        # Validate hardcoded model asset names exist
        _ensure_model_assets_exist_or_throw()

        # Asset image (shared)
        if asset_image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)

        # Batch-match the mask once
        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")
        if asset_mask.shape[0] != B:
            # A single-frame mask is broadcast across the whole batch.
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask.")

        # Load checkpoint + controlnet (cached)
        model, clip, vae = _load_checkpoint_cached(_TWOPASS_CKPT_NAME)
        controlnet = _load_controlnet_cached(_TWOPASS_CONTROLNET_NAME)

        import nodes

        # CLIP encodes ONCE (shared prompts)
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)

        # Pass 1
        out1 = _run_one_pass(
            image,
            x=x,
            y=y,
            square_size=int(square_size_1),
            upscale_factor=int(upscale_factor_1),
            denoise=float(denoise_1),
            steps=int(_PASS1_STEPS),
            cfg=float(_PASS1_CFG),
            sampler_name=str(_PASS1_SAMPLER),
            scheduler=str(_PASS1_SCHEDULER),
            controlnet_strength=float(_PASS1_CN_STRENGTH),
            model=model,
            clip=clip,
            vae=vae,
            controlnet=controlnet,
            pos_cond=pos_cond,
            neg_cond=neg_cond,
            asset_mask_batched=asset_mask,
            pass_tag="PASS1",
            positive_prompt=str(positive_prompt),
            negative_prompt=str(negative_prompt),
            asset_image=str(asset_image),
        )

        # Pass 2 (uses output of pass 1)
        out2 = _run_one_pass(
            out1,
            x=x,
            y=y,
            square_size=int(square_size_2),
            upscale_factor=int(upscale_factor_2),
            denoise=float(denoise_2),
            steps=int(_PASS2_STEPS),
            cfg=float(_PASS2_CFG),
            sampler_name=str(_PASS2_SAMPLER),
            scheduler=str(_PASS2_SCHEDULER),
            controlnet_strength=float(_PASS2_CN_STRENGTH),
            model=model,
            clip=clip,
            vae=vae,
            controlnet=controlnet,
            pos_cond=pos_cond,
            neg_cond=neg_cond,
            asset_mask_batched=asset_mask,
            pass_tag="PASS2",
            positive_prompt=str(positive_prompt),
            negative_prompt=str(negative_prompt),
            asset_image=str(asset_image),
        )

        # Final crop output from second-pass image (always uses square_size_2)
        cropped = out2[:, y:y + s2, x:x + s2, :]
        return (out2, cropped)
1063
+
1064
+
1065
+ # -------------------------------------------------------------------------------------
1066
+ # Node mappings (include both nodes)
1067
+ # -------------------------------------------------------------------------------------
1068
+
1069
NODE_CLASS_MAPPINGS = {
    "Salia_ezpz_gated_DoubleTime": Salia_ezpz_gated_DoubleTime,
    # Fix: Salia_ezpz_gated_twopass is defined in this file but was never
    # registered, so ComfyUI could not instantiate it.
    "Salia_ezpz_gated_twopass": Salia_ezpz_gated_twopass,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Salia_ezpz_gated_DoubleTime": "Salia_ezpz_gated_DoubleTime",
    "Salia_ezpz_gated_twopass": "Salia_ezpz_gated_twopass",
}
salia_detailer_ezpz_gated_Duo2.py ADDED
@@ -0,0 +1,1252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import shutil
3
+ import threading
4
+ import urllib.request
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Tuple, Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageOps
11
+
12
+ import folder_paths
13
+ import comfy.model_management as model_management
14
+
15
+
16
+ # transformers is required for depth-estimation pipeline
17
+ try:
18
+ from transformers import pipeline
19
+ except Exception as e:
20
+ pipeline = None
21
+ _TRANSFORMERS_IMPORT_ERROR = e
22
+
23
+
24
+ # -------------------------------------------------------------------------------------
25
+ # Global caches (checkpoint + controlnet) so using the node multiple times won't reload
26
+ # -------------------------------------------------------------------------------------
27
+
28
# ckpt_name -> (model, clip, vae); shared across all node instances so
# repeated use of the node never reloads the checkpoint.
_CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
# control_net_name -> loaded controlnet object.
_CN_CACHE: Dict[str, Any] = {}
# Locks guard the caches against concurrent graph executions.
_CKPT_LOCK = threading.Lock()
_CN_LOCK = threading.Lock()
32
+
33
+
34
+ # -------------------------------------------------------------------------------------
35
+ # Plugin root detection (works whether file is in plugin root or nodes/)
36
+ # -------------------------------------------------------------------------------------
37
+
38
+ def _find_plugin_root() -> Path:
39
+ """
40
+ Walk upwards from this file until we find an 'assets' folder.
41
+ Robust against hyphen/underscore package naming and different file placement.
42
+ """
43
+ here = Path(__file__).resolve()
44
+ for parent in [here.parent] + list(here.parents)[:12]:
45
+ if (parent / "assets").is_dir():
46
+ return parent
47
+ # fallback: typical nodes/<file>.py
48
+ return here.parent.parent
49
+
50
+
51
# Resolved once at import time; all asset lookups are relative to this directory.
PLUGIN_ROOT = _find_plugin_root()
52
+
53
+
54
+ # -------------------------------------------------------------------------------------
55
+ # PIL helpers (Lanczos resize for IMAGE and MASK)
56
+ # -------------------------------------------------------------------------------------
57
+
58
def _pil_lanczos():
    """Return the Lanczos resampling constant for the installed Pillow."""
    # Pillow >= 9.1 exposes filters on the Resampling enum; older builds
    # keep them as module-level constants.
    try:
        return Image.Resampling.LANCZOS
    except AttributeError:
        return Image.LANCZOS
62
+
63
+
64
def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """
    Convert a Comfy IMAGE tensor ([B,H,W,C] or [H,W,C], float [0..1]) to a
    PIL image; only the first batch item is used. 4-channel input -> RGBA.
    """
    frame = img[0] if img.ndim == 4 else img
    u8 = (frame.detach().cpu().float().clamp(0, 1).numpy() * 255.0).round().astype(np.uint8)
    mode = "RGBA" if u8.shape[-1] == 4 else "RGB"
    return Image.fromarray(u8, mode=mode)
75
+
76
+
77
def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """
    Convert a PIL image to a Comfy IMAGE tensor [1,H,W,C], float [0..1].
    Non-RGB(A) modes are converted, keeping alpha when the source has one.
    """
    if pil.mode not in ("RGB", "RGBA"):
        target = "RGBA" if "A" in pil.getbands() else "RGB"
        pil = pil.convert(target)
    arr = np.array(pil).astype(np.float32) / 255.0  # [H,W,C]
    return torch.from_numpy(arr).unsqueeze(0)
86
+
87
+
88
def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """
    Convert a Comfy MASK tensor ([B,H,W] or [H,W], float [0..1]) to an
    8-bit grayscale PIL image; only the first batch item is used.
    """
    frame = mask[0] if mask.ndim == 3 else mask
    u8 = (frame.detach().cpu().float().clamp(0, 1).numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(u8, mode="L")
97
+
98
+
99
def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """
    Convert a grayscale PIL image to a Comfy MASK tensor [1,H,W], float [0..1].
    """
    gray = pil_l if pil_l.mode == "L" else pil_l.convert("L")
    arr = np.array(gray).astype(np.float32) / 255.0  # [H,W]
    return torch.from_numpy(arr).unsqueeze(0)
108
+
109
+
110
def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """
    Resize every batch item of a Comfy IMAGE [B,H,W,C] to (w, h) using
    PIL Lanczos resampling; returns a new [B,h,w,C] tensor.
    """
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    size = (int(w), int(h))
    resample = _pil_lanczos()
    resized = [
        _pil_to_image_tensor(
            _image_tensor_to_pil(img[i].unsqueeze(0)).resize(size, resample=resample)
        )
        for i in range(img.shape[0])
    ]
    return torch.cat(resized, dim=0)
122
+
123
+
124
def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """
    Resize every batch item of a Comfy MASK [B,H,W] to (w, h) using
    PIL Lanczos resampling; returns a new [B,h,w] tensor.
    """
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    size = (int(w), int(h))
    resample = _pil_lanczos()
    resized = [
        _pil_to_mask_tensor(
            _mask_tensor_to_pil(mask[i].unsqueeze(0)).resize(size, resample=resample)
        )
        for i in range(mask.shape[0])
    ]
    return torch.cat(resized, dim=0)
136
+
137
+
138
+ # -------------------------------------------------------------------------------------
139
+ # ✅ ComfyUI 0.5.1 FIX: Manual JoinImageWithAlpha equivalent
140
+ # -------------------------------------------------------------------------------------
141
+
142
+ def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
143
+ """
144
+ Make RGBA from:
145
+ rgb: IMAGE [B,H,W,3] float [0..1]
146
+ mask: MASK [B,H,W] float [0..1] (Comfy convention: 1=masked/transparent)
147
+ Output:
148
+ rgba: IMAGE [B,H,W,4] where alpha = 1 - mask (1=opaque, 0=transparent)
149
+ """
150
+ if rgb.ndim == 3:
151
+ rgb = rgb.unsqueeze(0)
152
+ if mask.ndim == 2:
153
+ mask = mask.unsqueeze(0)
154
+
155
+ if rgb.ndim != 4 or rgb.shape[-1] != 3:
156
+ raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
157
+ if mask.ndim != 3:
158
+ raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
159
+
160
+ # Batch match
161
+ if mask.shape[0] != rgb.shape[0]:
162
+ if mask.shape[0] == 1 and rgb.shape[0] > 1:
163
+ mask = mask.expand(rgb.shape[0], -1, -1)
164
+ else:
165
+ raise ValueError("Batch mismatch between rgb and mask.")
166
+
167
+ # Size match
168
+ if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
169
+ raise ValueError(
170
+ f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
171
+ )
172
+
173
+ mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
174
+ alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1) # [B,H,W,1]
175
+
176
+ rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1) # [B,H,W,4]
177
+ return rgba
178
+
179
+
180
+ # -------------------------------------------------------------------------------------
181
+ # Core lazy loaders (checkpoint + controlnet), cached globally
182
+ # -------------------------------------------------------------------------------------
183
+
184
def _load_checkpoint_cached(ckpt_name: str):
    """
    Mirrors comfy-core CheckpointLoaderSimple, but cached to avoid double-loads.

    Args:
        ckpt_name: checkpoint filename as listed by folder_paths.

    Returns:
        (model, clip, vae)
    """
    # Lock held for the whole load so concurrent callers can never load twice.
    with _CKPT_LOCK:
        if ckpt_name in _CKPT_CACHE:
            return _CKPT_CACHE[ckpt_name]

        # Lazy import: ComfyUI's `nodes` module only exists at runtime.
        import nodes
        loader = nodes.CheckpointLoaderSimple()
        # Dispatch through the node's declared FUNCTION attribute, as the executor does.
        fn = getattr(loader, loader.FUNCTION)
        model, clip, vae = fn(ckpt_name=ckpt_name)

        _CKPT_CACHE[ckpt_name] = (model, clip, vae)
        return model, clip, vae
200
+
201
+
202
def _load_controlnet_cached(control_net_name: str):
    """
    Mirrors comfy-core ControlNetLoader, but cached to avoid double-loads.

    Args:
        control_net_name: controlnet filename as listed by folder_paths.

    Returns:
        The loaded controlnet object.
    """
    # Lock held for the whole load so concurrent callers can never load twice.
    with _CN_LOCK:
        if control_net_name in _CN_CACHE:
            return _CN_CACHE[control_net_name]

        # Lazy import: ComfyUI's `nodes` module only exists at runtime.
        import nodes
        loader = nodes.ControlNetLoader()
        fn = getattr(loader, loader.FUNCTION)
        # The node returns a 1-tuple; unpack it.
        (cn,) = fn(control_net_name=control_net_name)

        _CN_CACHE[control_net_name] = cn
        return cn
218
+
219
+
220
+ # -------------------------------------------------------------------------------------
221
+ # Assets/images dropdown + loader (inlined; no LoadImage_SaliaOnline_Assets dependency)
222
+ # -------------------------------------------------------------------------------------
223
+
224
def _assets_images_dir() -> Path:
    """Directory holding the plugin's bundled PNG assets."""
    return PLUGIN_ROOT / "assets" / "images"
226
+
227
+
228
def _list_asset_pngs() -> list:
    """
    Collect the relative POSIX path of every *.png under assets/images,
    sorted alphabetically. Returns [] when the folder is missing.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    return sorted(
        entry.relative_to(img_dir).as_posix()
        for entry in img_dir.rglob("*")
        if entry.is_file() and entry.suffix.lower() == ".png"
    )
238
+
239
+
240
def _safe_asset_path(asset_rel_path: str) -> Path:
    """
    Resolve a dropdown-selected relative path to a PNG inside assets/images.

    Raises:
        FileNotFoundError: assets/images missing, or the file does not exist.
        ValueError: absolute path, path-traversal attempt, or non-PNG file.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")

    base = img_dir.resolve()
    rel = Path(asset_rel_path)

    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    full = (base / rel).resolve()

    # path traversal protection: after resolving symlinks/"..", the result
    # must still be `base` itself or live underneath it.
    if base != full and base not in full.parents:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")

    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")

    return full
263
+
264
+
265
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns (IMAGE, MASK) in ComfyUI formats: IMAGE [1,H,W,3], MASK [1,H,W].

    Mask semantics: match ComfyUI core LoadImage:
      - alpha is RGBA alpha channel normalized to [0..1]
      - mask = 1 - alpha
    """
    p = _safe_asset_path(asset_rel_path)

    im = Image.open(p)
    # Apply the EXIF orientation tag so the pixels match what viewers show.
    im = ImageOps.exif_transpose(im)

    # Convert via RGBA so an alpha channel exists even for opaque PNGs.
    rgba = im.convert("RGBA")
    rgb = rgba.convert("RGB")

    rgb_arr = np.array(rgb).astype(np.float32) / 255.0  # [H,W,3]
    img_t = torch.from_numpy(rgb_arr)[None, ...]

    alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0  # [H,W]
    mask = 1.0 - alpha  # Comfy MASK convention

    mask_t = torch.from_numpy(mask)[None, ...]
    return img_t, mask_t
289
+
290
+
291
+ # -------------------------------------------------------------------------------------
292
+ # Salia_Depth (INLINED, no imports from other files)
293
+ # -------------------------------------------------------------------------------------
294
+
295
# Local cache directory for the bundled depth model; created at import time.
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Files the local depth pipeline needs, mapped to their download URLs.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# Hub repo used when the local model files cannot be obtained.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"

_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}  # (model_source, device_str) -> pipeline
_PIPE_LOCK = threading.Lock()
308
+
309
+
310
def _have_required_files() -> bool:
    """True when every required depth-model file is already on disk."""
    for name in REQUIRED_FILES:
        if not (MODEL_DIR / name).exists():
            return False
    return True
312
+
313
+
314
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """
    Download *url* to *dst* atomically: write to a '.tmp' sibling, then rename.

    Raises on network/filesystem errors; callers treat failure as non-fatal.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    # Remove a stale temp file left by an interrupted download.
    if tmp.exists():
        try:
            tmp.unlink()
        except Exception:
            pass

    req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
        shutil.copyfileobj(r, f)

    # Rename only after the full body was written, so dst is never partial.
    tmp.replace(dst)
329
+
330
+
331
def ensure_local_model_files() -> bool:
    """
    Best-effort download of any missing depth-model files.

    Returns True when all required files are present afterwards; never raises
    (download failures simply yield False so callers can fall back).
    """
    if _have_required_files():
        return True
    try:
        for fname, url in REQUIRED_FILES.items():
            target = MODEL_DIR / fname
            if not target.exists():
                _download_url_to_file(url, target)
        return _have_required_files()
    except Exception:
        # Swallow on purpose: the caller falls back to the hub repo.
        return False
343
+
344
+
345
def HWC3(x: np.ndarray) -> np.ndarray:
    """
    Coerce a uint8 image to HxWx3.

    Grayscale (2-D or 1-channel) is replicated across three channels;
    3-channel input is returned unchanged; RGBA is composited over a
    white background.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)

    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)

    # channels == 4: blend RGB over white using the alpha channel.
    color = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = color * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
362
+
363
+
364
def pad64(x: int) -> int:
    """Number of pixels needed to round *x* up to the next multiple of 64."""
    remainder = int(x) % 64
    return 0 if remainder == 0 else 64 - remainder
366
+
367
+
368
def safer_memory(x: np.ndarray) -> np.ndarray:
    """
    Return an independent, C-contiguous copy of *x* (safe to mutate without
    touching the source, and safe to hand to C/torch code that expects
    contiguous memory).

    Fix: the original performed up to three copies
    (``np.ascontiguousarray(x.copy()).copy()``); one explicit copy of the
    contiguous view yields an identical result.
    """
    return np.ascontiguousarray(x).copy()
370
+
371
+
372
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """
    Scale the image so its SHORT side equals *resolution*, then pad H/W up to
    multiples of 64.

    Args:
        input_image: uint8 image (coerced to HxWx3 unless skip_hwc3).
        resolution: target short-side size; <= 0 returns the image unchanged.
        upscale_method: cv2 interpolation name, used only when upscaling.
        skip_hwc3: skip the HWC3 coercion when the caller already did it.
        mode: np.pad mode for the bottom/right padding.

    Returns:
        (padded_image, remove_pad) where remove_pad crops a same-sized array
        back to the pre-pad (resized) dimensions.
    """
    # cv2 is optional; fall back to PIL when it is not installed.
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # No-op: identity remove_pad so callers can use a uniform API.
        return img, (lambda x: x)

    # Scale factor chosen from the short side.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # Downscaling always uses INTER_AREA regardless of upscale_method.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback: bicubic up, Lanczos down (upscale_method is ignored).
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right so both dimensions become multiples of 64.
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the resized (pre-pad) size.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
419
+
420
+
421
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """
    Pad an image's H/W up to the next multiples of 64 without resizing.

    Returns (padded_image, remove_pad) where remove_pad crops a same-sized
    array back to the original dimensions.
    """
    img = HWC3(img_u8)
    h_raw, w_raw = img.shape[0], img.shape[1]
    pad_h, pad_w = pad64(h_raw), pad64(w_raw)
    padded = np.pad(img, [[0, pad_h], [0, pad_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the pre-pad size.
        return safer_memory(x[:h_raw, :w_raw, ...])

    return safer_memory(padded), remove_pad
431
+
432
+
433
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    RGBA input -> (RGB composited over white, alpha channel copy).
    Non-RGBA input -> (HWC3-coerced image, None).
    """
    if inp_u8.ndim != 3 or inp_u8.shape[2] != 4:
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    a = rgba[:, :, 3:4].astype(np.float32) / 255.0
    rgb = rgba[:, :, 0:3].astype(np.float32)
    # Standard over-white blend: rgb*a + 255*(1-a).
    over_white = (rgb * a + 255.0 * (1.0 - a)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
442
+
443
+
444
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """
    Multiply a depth map by an alpha channel so fully transparent areas
    become black; returns uint8 HxWx3.
    """
    rgb = HWC3(depth_rgb_u8)
    weights = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb.astype(np.float32) * weights).clip(0, 255).astype(np.uint8)
449
+
450
+
451
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """
    Comfy IMAGE ([B,H,W,C] or [H,W,C], float [0..1]) -> uint8 HxWxC array.
    Only the first batch item is used; values are clamped before scaling.
    """
    frame = img[0] if img.ndim == 4 else img
    clipped = frame.detach().cpu().float().clamp(0, 1).numpy()
    return (clipped * 255.0).round().astype(np.uint8)
457
+
458
+
459
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """uint8 HxWx{1,3,4} array -> Comfy IMAGE [1,H,W,3], float [0..1]."""
    rgb = HWC3(img_u8)
    return torch.from_numpy(rgb.astype(np.float32) / 255.0).unsqueeze(0)
463
+
464
+
465
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Load (or fetch from cache) a transformers depth-estimation pipeline.

    Args:
        model_source: local directory path or hub repo id.
        device: torch device to move the model to (best-effort).

    Raises:
        RuntimeError: when transformers failed to import at module load.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    # Cache key includes the device so CPU/GPU variants don't collide.
    key = (model_source, str(device))
    with _PIPE_LOCK:
        if key in _PIPE_CACHE:
            return _PIPE_CACHE[key]

        p = pipeline(task="depth-estimation", model=model_source)
        # Device move is best-effort; some pipeline versions manage it themselves.
        try:
            p.model = p.model.to(device)
            p.device = device
        except Exception:
            pass

        _PIPE_CACHE[key] = p
        return p
483
+
484
+
485
def get_depth_pipeline(device: torch.device):
    """
    Obtain a depth pipeline: prefer the locally downloaded model, then fall
    back to the hub ZoeDepth repo.

    Returns None when neither source loads; callers treat that as "depth
    unavailable" and pass images through unchanged.
    """
    if ensure_local_model_files():
        try:
            return _try_load_pipeline(str(MODEL_DIR), device)
        except Exception:
            # Local load failed; fall through to the hub fallback below.
            pass
    try:
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
    except Exception:
        return None
495
+
496
+
497
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """
    Run the depth pipeline and post-process its output into a uint8 RGB map.

    Args:
        pipe: transformers depth-estimation pipeline.
        input_rgb_u8: uint8 RGB image.
        detect_resolution: short-side working resolution; -1 keeps the native
            size (pad-to-64 only).
        upscale_method: interpolation name forwarded to the resize helper.

    Returns:
        uint8 HxWx3 depth map at the (resized, unpadded) working size.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # Fix: the original isinstance(depth, Image.Image) check had two identical
    # branches; np.array handles PIL images and array-likes uniformly.
    depth_array = np.array(depth, dtype=np.float32)

    # Percentile-based normalization clips outliers at both ends.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    denom = vmax - vmin
    if abs(denom) < 1e-12:
        denom = 1e-6  # avoid division by ~zero on flat depth maps
    depth_array = (depth_array - vmin) / denom

    # Invert the normalized map before quantizing to 8 bits.
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    return remove_pad(HWC3(depth_image))
539
+
540
+
541
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """
    Resize a depth map back to (w0, h0) pixels: cv2 bilinear when available,
    PIL bilinear otherwise (also used if the cv2 path raises).
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        fallback = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(fallback, dtype=np.uint8)
550
+
551
+
552
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """
    Internal callable version of Salia_Depth:
      input:  IMAGE [B,H,W,3 or 4]
      output: IMAGE [B,H,W,3]

    Best-effort by design: if the pipeline cannot be loaded the whole batch is
    passed through unchanged, and any per-frame failure falls back to that
    frame's original pixels.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    pipe_obj = None
    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None

    # No pipeline available: pass the input through untouched.
    if pipe_obj is None:
        return image

    if image.ndim == 3:
        image = image.unsqueeze(0)

    outs = []
    for i in range(image.shape[0]):
        try:
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])

            inp_u8 = comfy_tensor_to_u8(image[i])

            # RGBA frames are composited over white for estimation; the
            # alpha channel is kept so it can be re-applied afterwards.
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None

            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )

            # Depth runs at a working resolution; restore the frame size.
            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)

            if had_rgba:
                # Resize alpha if it no longer matches the frame size.
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)

                # Transparent regions become black in the depth map.
                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)

            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # Per-frame failure: keep the original frame.
            outs.append(image[i].unsqueeze(0))

    return torch.cat(outs, dim=0)
612
+
613
+
614
+ # -------------------------------------------------------------------------------------
615
+ # Alpha-over paste (RGBA square onto base at X,Y)
616
+ # -------------------------------------------------------------------------------------
617
+
618
+ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
619
+ """
620
+ base: [B,H,W,C] where C is 3 or 4, float [0..1]
621
+ overlay_rgba: [B,s,s,4] float [0..1]
622
+ """
623
+ if base.ndim != 4 or overlay_rgba.ndim != 4:
624
+ raise ValueError("base and overlay must be [B,H,W,C].")
625
+
626
+ B, H, W, C = base.shape
627
+ b2, sH, sW, c2 = overlay_rgba.shape
628
+ if c2 != 4:
629
+ raise ValueError("overlay_rgba must have 4 channels (RGBA).")
630
+ if sH != sW:
631
+ raise ValueError("overlay must be square.")
632
+ s = sH
633
+
634
+ if x < 0 or y < 0 or x + s > W or y + s > H:
635
+ raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
636
+
637
+ if b2 != B:
638
+ if b2 == 1 and B > 1:
639
+ overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
640
+ else:
641
+ raise ValueError("Batch mismatch between base and overlay.")
642
+
643
+ out = base.clone()
644
+
645
+ overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
646
+ overlay_a = overlay_rgba[..., 3:4].clamp(0, 1)
647
+
648
+ base_rgb = out[:, y:y + s, x:x + s, 0:3]
649
+ comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
650
+ out[:, y:y + s, x:x + s, 0:3] = comp_rgb
651
+
652
+ if C == 4:
653
+ base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
654
+ comp_a = overlay_a + base_a * (1.0 - overlay_a)
655
+ out[:, y:y + s, x:x + s, 3:4] = comp_a
656
+
657
+ return out.clamp(0, 1)
658
+
659
+
660
+ # -------------------------------------------------------------------------------------
661
+ # The One-Node Workflow
662
+ # -------------------------------------------------------------------------------------
663
+
664
class Salia_ezpz_gated:
    """
    One-node EZPZ detailer workflow for ComfyUI.

    Pipeline (see run()): crop a square from the input at (X_coord, Y_coord),
    compute a depth hint from the crop, upscale both, run one ControlNet-guided
    KSampler pass over the upscaled crop, attach the asset PNG's mask as alpha,
    downscale, and alpha-paste the result back over the original image.

    Gating: if `trigger_string` is exactly the empty string, the node is a
    no-op and returns the input image unchanged.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"  # name of the entry-point method ComfyUI invokes

    @classmethod
    def INPUT_TYPES(cls):
        """Build the ComfyUI input schema (dropdowns populated at graph load)."""
        # Placeholder strings keep the dropdowns usable when a model folder is
        # empty; run() raises a clear error if a placeholder is selected.
        ckpts = folder_paths.get_filename_list("checkpoints") or ["<no checkpoints found>"]
        cns = folder_paths.get_filename_list("controlnet") or ["<no controlnets found>"]
        assets = _list_asset_pngs() or ["<no pngs found>"]

        # Prefer the live sampler/scheduler lists from this Comfy install;
        # fall back to safe defaults if comfy.samplers is not importable here.
        try:
            import comfy.samplers
            sampler_names = comfy.samplers.KSampler.SAMPLERS
            scheduler_names = comfy.samplers.KSampler.SCHEDULERS
        except Exception:
            sampler_names = ["euler"]
            scheduler_names = ["karras"]

        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]

        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),

                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "square_size": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),

                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),

                "upscale_factor": (upscale_choices, {"default": "4"}),

                "ckpt_name": (ckpts, {}),
                "control_net_name": (cns, {}),
                "asset_image": (assets, {}),

                "controlnet_strength": ("FLOAT", {"default": 0.33, "min": 0.00, "max": 10.00, "step": 0.01}),
                "controlnet_start_percent": ("FLOAT", {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.01}),
                "controlnet_end_percent": ("FLOAT", {"default": 1.00, "min": 0.00, "max": 1.00, "step": 0.01}),

                "steps": ("INT", {"default": 30, "min": 1, "max": 200, "step": 1}),
                "cfg": ("FLOAT", {"default": 2.6, "min": 0.00, "max": 10.00, "step": 0.05}),
                # Only pin a default when it exists in this install's lists.
                "sampler_name": (sampler_names, {"default": "euler"} if "euler" in sampler_names else {}),
                "scheduler": (scheduler_names, {"default": "karras"} if "karras" in scheduler_names else {}),
                "denoise": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }

    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        square_size: int = 384,
        positive_prompt: str = "",
        negative_prompt: str = "",
        upscale_factor: str = "4",
        ckpt_name: str = "",
        control_net_name: str = "",
        asset_image: str = "",
        controlnet_strength: float = 0.33,
        controlnet_start_percent: float = 0.0,
        controlnet_end_percent: float = 1.0,
        steps: int = 30,
        cfg: float = 2.6,
        sampler_name: str = "euler",
        scheduler: str = "karras",
        denoise: float = 0.35,
    ):
        """
        Execute the detailer pipeline and return a single-element tuple
        holding the composited [B,H,W,C] image tensor.

        Raises:
            ValueError: bad tensor shape, bad parameter value, or
                crop/paste out of bounds.
            FileNotFoundError: a placeholder dropdown entry was selected.
        """
        # If trigger_string is exactly empty, bypass everything and return input unchanged.
        if trigger_string == "":
            return (image,)

        # Normalize input to [B,H,W,C]
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")

        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")

        x = int(X_coord)
        y = int(Y_coord)
        s = int(square_size)

        # upscale_factor arrives from the UI as a string choice.
        up = int(upscale_factor)
        if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
            raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")

        if s <= 0:
            raise ValueError("square_size must be > 0")
        if x < 0 or y < 0 or x + s > W or y + s > H:
            raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")

        up_w = s * up
        up_h = s * up

        # VAE/UNet path likes multiples of 8
        if (up_w % 8) != 0 or (up_h % 8) != 0:
            raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")

        # Clamp the ControlNet window to [0,1] and swap if reversed.
        start_p = float(max(0.0, min(1.0, controlnet_start_percent)))
        end_p = float(max(0.0, min(1.0, controlnet_end_percent)))
        if end_p < start_p:
            start_p, end_p = end_p, start_p

        # 1) Crop square
        crop = image[:, y:y + s, x:x + s, :]
        crop_rgb = crop[:, :, :, 0:3].contiguous()

        # 2) Depth (inline Salia_Depth) then upscale with Lanczos
        depth_small = _salia_depth_execute(crop_rgb, resolution=s)
        depth_up = _resize_image_lanczos(depth_small, up_w, up_h)

        # 3) Upscale crop for VAE encode
        crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)

        # 4) Load asset mask and resize
        if asset_image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)

        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")

        # A single-frame mask is broadcast across the image batch.
        if asset_mask.shape[0] != B:
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask.")

        asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)

        # 5) Load checkpoint + controlnet (cached)
        if ckpt_name == "<no checkpoints found>":
            raise FileNotFoundError("No checkpoints found in models/checkpoints.")
        if control_net_name == "<no controlnets found>":
            raise FileNotFoundError("No controlnets found in models/controlnet.")

        model, clip, vae = _load_checkpoint_cached(ckpt_name)
        controlnet = _load_controlnet_cached(control_net_name)

        # Imported lazily so this module can load before ComfyUI's nodes do.
        import nodes

        # 6) CLIP encodes
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)

        # 7) Apply ControlNet (depth image as the hint)
        cn_apply = nodes.ControlNetApplyAdvanced()
        cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
        pos_cn, neg_cn = cn_fn(
            strength=float(controlnet_strength),
            start_percent=float(start_p),
            end_percent=float(end_p),
            positive=pos_cond,
            negative=neg_cond,
            control_net=controlnet,
            image=depth_up,
            vae=vae,
        )

        # 8) VAE Encode
        vae_enc = nodes.VAEEncode()
        vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
        (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)

        # 9) KSampler (deterministic seed derived from inputs)
        # Hashing the inputs makes reruns reproducible for identical settings.
        seed_material = (
            f"{ckpt_name}|{control_net_name}|{asset_image}|{x}|{y}|{s}|{up}|"
            f"{steps}|{cfg}|{sampler_name}|{scheduler}|{denoise}|"
            f"{controlnet_strength}|{start_p}|{end_p}|"
            f"{positive_prompt}|{negative_prompt}"
        ).encode("utf-8", errors="ignore")
        seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)

        ksampler = nodes.KSampler()
        k_fn = getattr(ksampler, ksampler.FUNCTION)
        (sampled_latent,) = k_fn(
            seed=seed64,
            steps=int(steps),
            cfg=float(cfg),
            sampler_name=str(sampler_name),
            scheduler=str(scheduler),
            denoise=float(denoise),
            model=model,
            positive=pos_cn,
            negative=neg_cn,
            latent_image=latent,
        )

        # 10) VAE Decode -> RGB
        vae_dec = nodes.VAEDecode()
        vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
        (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)

        # 11) Manual "JoinImageWithAlpha": attach asset mask as alpha channel
        rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)

        # 12) Downscale RGBA back to crop size
        rgba_square = _resize_image_lanczos(rgba_up, s, s)

        # 13) Paste back onto original at X,Y (alpha-over)
        out = _alpha_over_region(image, rgba_square, x=x, y=y)
        return (out,)
882
+
883
# -------------------------------------------------------------------------------------
# Two-Pass EZPZ node (hardcoded ckpt/controlnet + per-pass hardcoded sampler/settings)
# -------------------------------------------------------------------------------------

# Model files the Duo2 node expects to find in models/checkpoints and
# models/controlnet respectively; the node raises FileNotFoundError otherwise.
_HARDCODED_CKPT_NAME = "SaliaHighlady_Speedy.safetensors"
_HARDCODED_CONTROLNET_NAME = "diffusion_pytorch_model_promax.safetensors"
# ControlNet is active for the full denoise schedule (start..end as fractions).
_HARDCODED_CN_START = 0.00
_HARDCODED_CN_END = 1.00

# Pass 1 hardcoded settings
_PASS1_SAMPLER_NAME = "dpmpp_2m_sde_heun_gpu"
_PASS1_SCHEDULER = "karras"
_PASS1_STEPS = 29
_PASS1_CFG = 2.6
_PASS1_CONTROLNET_STRENGTH = 0.33

# Pass 2 hardcoded settings
_PASS2_SAMPLER_NAME = "res_multistep_ancestral_cfg_pp"
_PASS2_SCHEDULER = "karras"
_PASS2_STEPS = 30
_PASS2_CFG = 1.7
_PASS2_CONTROLNET_STRENGTH = 0.5
905
+
906
+
907
class Salia_ezpz_gated_Duo2:
    """
    Runs the same EZPZ pipeline twice, sequentially:
        input -> pass1 -> pass2 -> output

    Outputs:
        (image, image_cropped)
        image         = pass2 final composite
        image_cropped = crop from pass2 final composite at X/Y with square_size_2

    Special:
        If trigger_string == "":
          - bypass both passes (no depth/cn/ksampler/etc)
          - still crops from the (bypassed) "second output" which equals the input

    Checkpoint, controlnet, sampler, scheduler, steps, cfg, and controlnet
    strength are hardcoded via the _HARDCODED_* / _PASS1_* / _PASS2_* module
    constants; only coords, sizes, upscales, denoise, prompts, and the asset
    mask are user inputs.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("image", "image_cropped")
    FUNCTION = "run"  # name of the entry-point method ComfyUI invokes

    @classmethod
    def INPUT_TYPES(cls):
        """Build the ComfyUI input schema for the two-pass node."""
        assets = _list_asset_pngs() or ["<no pngs found>"]

        # Keep the same upscale choices as the single-pass node
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]

        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),

                # shared coords (used in BOTH passes + final crop)
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),

                # shared prompts (used in BOTH passes)
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),

                # shared asset mask (used in BOTH passes)
                "asset_image": (assets, {}),

                # pass 1 variable inputs
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),

                # pass 2 variable inputs (+ used for final output crop)
                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }

    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        positive_prompt: str = "",
        negative_prompt: str = "",
        asset_image: str = "",
        square_size_1: int = 384,
        upscale_factor_1: str = "4",
        denoise_1: float = 0.35,
        square_size_2: int = 384,
        upscale_factor_2: str = "4",
        denoise_2: float = 0.35,
    ):
        """
        Run pass 1 then pass 2 of the EZPZ pipeline and return
        (final composite, crop of the final composite at X/Y of size square_size_2).

        Raises:
            ValueError: bad tensor shape, bad parameter, out-of-bounds crop,
                or an unavailable hardcoded sampler/scheduler.
            FileNotFoundError: missing hardcoded checkpoint/controlnet or
                placeholder asset entry selected.
        """
        # -----------------------------
        # Normalize input to [B,H,W,C]
        # -----------------------------
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")

        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")

        x = int(X_coord)
        y = int(Y_coord)

        s1 = int(square_size_1)
        s2 = int(square_size_2)

        # -----------------------------
        # Small helpers (validation/crop)
        # -----------------------------
        def _validate_square_bounds(s: int, label: str):
            # Ensures the (x, y, s) square fits inside the W x H image.
            if s <= 0:
                raise ValueError(f"{label}: square_size must be > 0")
            if x < 0 or y < 0 or x + s > W or y + s > H:
                raise ValueError(
                    f"{label}: out of bounds. image={W}x{H}, rect at ({x},{y}) size={s}"
                )

        def _validate_upscale(up: int, s: int, label: str):
            # The VAE pipeline requires upscaled dimensions divisible by 8.
            if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
                raise ValueError(f"{label}: upscale_factor must be one of 1,2,4,6,8,10,12,14,16")
            if ((s * up) % 8) != 0:
                raise ValueError(f"{label}: square_size * upscale_factor must be divisible by 8 (VAE requirement).")

        def _crop_square(img: torch.Tensor, s: int) -> torch.Tensor:
            # img is [B,H,W,C]; crops the s x s square at (x, y).
            return img[:, y:y + s, x:x + s, :]

        # Always validate the final crop (required even on bypass)
        _validate_square_bounds(s2, "final crop (square_size_2)")

        # -----------------------------
        # Trigger bypass: skip both passes, still crop from "second output"
        # -----------------------------
        if trigger_string == "":
            out2 = image  # passthrough
            cropped = _crop_square(out2, s2)
            return (out2, cropped)

        # If we're not bypassing, validate pass-1 too.
        # (s2 was already bounds-checked above for the final crop; the pass-2
        # crop uses the same (x, y, s2) rectangle, so no second check needed.)
        _validate_square_bounds(s1, "pass1 (square_size_1)")

        up1 = int(upscale_factor_1)
        up2 = int(upscale_factor_2)
        _validate_upscale(up1, s1, "pass1")
        _validate_upscale(up2, s2, "pass2")

        # Clamp denoise defensively (UI already enforces range, but keep it safe)
        d1 = float(max(0.0, min(1.0, denoise_1)))
        d2 = float(max(0.0, min(1.0, denoise_2)))

        # -----------------------------
        # Load asset mask ONCE (resized per pass)
        # -----------------------------
        if asset_image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")

        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)

        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")

        # A single-frame mask is broadcast across the image batch.
        if asset_mask.shape[0] != B:
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask vs input image batch.")

        # -----------------------------
        # Load checkpoint + controlnet ONCE (cached globally)
        # -----------------------------
        import nodes

        try:
            model, clip, vae = _load_checkpoint_cached(_HARDCODED_CKPT_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("checkpoints") or []
            raise FileNotFoundError(
                f"Hardcoded ckpt not found: '{_HARDCODED_CKPT_NAME}'. "
                f"Put it in models/checkpoints. Available (first 50): {available[:50]}"
            ) from e

        try:
            controlnet = _load_controlnet_cached(_HARDCODED_CONTROLNET_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("controlnet") or []
            raise FileNotFoundError(
                f"Hardcoded controlnet not found: '{_HARDCODED_CONTROLNET_NAME}'. "
                f"Put it in models/controlnet. Available (first 50): {available[:50]}"
            ) from e

        # Nice early error if the hardcoded samplers/schedulers don't exist in
        # this Comfy install.
        # BUGFIX: the availability checks used to live INSIDE the try body, so
        # the blanket `except Exception: pass` guarding the import silently
        # swallowed the ValueError they raised and the check never fired.
        # The checks now sit in the `else:` clause, which the except does not
        # cover, so a missing sampler/scheduler fails loudly here.
        try:
            import comfy.samplers
        except Exception:
            # If comfy.samplers can't be imported here for any reason, let
            # KSampler handle it later.
            pass
        else:
            avail_samplers = set(comfy.samplers.KSampler.SAMPLERS)
            avail_scheds = set(comfy.samplers.KSampler.SCHEDULERS)
            for sname in (_PASS1_SAMPLER_NAME, _PASS2_SAMPLER_NAME):
                if sname not in avail_samplers:
                    raise ValueError(
                        f"Sampler '{sname}' not available in this ComfyUI install. "
                        f"Available (first 50): {list(avail_samplers)[:50]}"
                    )
            for sch in (_PASS1_SCHEDULER, _PASS2_SCHEDULER):
                if sch not in avail_scheds:
                    raise ValueError(
                        f"Scheduler '{sch}' not available in this ComfyUI install. "
                        f"Available: {list(avail_scheds)}"
                    )

        # -----------------------------
        # Encode prompts ONCE (shared between both passes)
        # -----------------------------
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)

        # Instantiate node objects ONCE (tiny, but avoids duplication)
        cn_apply = nodes.ControlNetApplyAdvanced()
        cn_fn = getattr(cn_apply, cn_apply.FUNCTION)

        vae_enc = nodes.VAEEncode()
        vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)

        ksampler = nodes.KSampler()
        k_fn = getattr(ksampler, ksampler.FUNCTION)

        vae_dec = nodes.VAEDecode()
        vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)

        # -----------------------------
        # One pass of the exact pipeline (parameterized)
        # -----------------------------
        def _run_pass(
            pass_index: int,
            in_image: torch.Tensor,
            s: int,
            up: int,
            denoise_v: float,
            steps_v: int,
            cfg_v: float,
            sampler_v: str,
            scheduler_v: str,
            controlnet_strength_v: float,
        ) -> torch.Tensor:
            """Crop -> depth -> upscale -> CN+KSampler -> mask -> paste back."""
            up_w = s * up
            up_h = s * up

            # 1) Crop square
            crop = in_image[:, y:y + s, x:x + s, :]
            crop_rgb = crop[:, :, :, 0:3].contiguous()

            # 2) Depth (inline Salia_Depth) then upscale with Lanczos
            depth_small = _salia_depth_execute(crop_rgb, resolution=s)
            depth_up = _resize_image_lanczos(depth_small, up_w, up_h)

            # 3) Upscale crop for VAE encode
            crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)

            # 4) Resize asset mask to this pass's upscaled size
            asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)

            # 5) Apply ControlNet (hardcoded start/end)
            pos_cn, neg_cn = cn_fn(
                strength=float(controlnet_strength_v),
                start_percent=float(_HARDCODED_CN_START),
                end_percent=float(_HARDCODED_CN_END),
                positive=pos_cond,
                negative=neg_cond,
                control_net=controlnet,
                image=depth_up,
                vae=vae,
            )

            # 6) VAE Encode
            (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)

            # 7) KSampler (deterministic seed derived from inputs + pass index)
            # Hashing the inputs makes reruns reproducible for identical settings.
            seed_material = (
                f"{_HARDCODED_CKPT_NAME}|{_HARDCODED_CONTROLNET_NAME}|{asset_image}|"
                f"pass={pass_index}|x={x}|y={y}|s={s}|up={up}|"
                f"steps={steps_v}|cfg={cfg_v}|sampler={sampler_v}|scheduler={scheduler_v}|denoise={denoise_v}|"
                f"cn_strength={controlnet_strength_v}|"
                f"{positive_prompt}|{negative_prompt}"
            ).encode("utf-8", errors="ignore")
            seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)

            (sampled_latent,) = k_fn(
                seed=seed64,
                steps=int(steps_v),
                cfg=float(cfg_v),
                sampler_name=str(sampler_v),
                scheduler=str(scheduler_v),
                denoise=float(denoise_v),
                model=model,
                positive=pos_cn,
                negative=neg_cn,
                latent_image=latent,
            )

            # 8) VAE Decode -> RGB
            (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)

            # 9) Join alpha using the asset mask (same approach as the single-pass node)
            rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)

            # 10) Downscale RGBA back to crop size, then alpha-over paste back
            rgba_square = _resize_image_lanczos(rgba_up, s, s)
            return _alpha_over_region(in_image, rgba_square, x=x, y=y)

        # -----------------------------
        # Run pass 1 then pass 2
        # -----------------------------
        out1 = _run_pass(
            pass_index=1,
            in_image=image,
            s=s1,
            up=up1,
            denoise_v=d1,
            steps_v=_PASS1_STEPS,
            cfg_v=_PASS1_CFG,
            sampler_v=_PASS1_SAMPLER_NAME,
            scheduler_v=_PASS1_SCHEDULER,
            controlnet_strength_v=_PASS1_CONTROLNET_STRENGTH,
        )

        out2 = _run_pass(
            pass_index=2,
            in_image=out1,
            s=s2,
            up=up2,
            denoise_v=d2,
            steps_v=_PASS2_STEPS,
            cfg_v=_PASS2_CFG,
            sampler_v=_PASS2_SAMPLER_NAME,
            scheduler_v=_PASS2_SCHEDULER,
            controlnet_strength_v=_PASS2_CONTROLNET_STRENGTH,
        )

        # Final crop from pass-2 output
        cropped = _crop_square(out2, s2)

        return (out2, cropped)
1242
+
1243
+
1244
# Registration tables read by ComfyUI at plugin import time:
# internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "Salia_ezpz_gated": Salia_ezpz_gated,
    "Salia_ezpz_gated_Duo2": Salia_ezpz_gated_Duo2,
}

# internal node id -> human-readable name shown in the node picker UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Salia_ezpz_gated": "Salia EZPZ Gated",
    "Salia_ezpz_gated_Duo2": "Salia_ezpz_gated_Duo2",
}