saliacoel commited on
Commit
f260c3a
·
verified ·
1 Parent(s): 7ce12b6

Upload salia_detailer_ezpz_gated.py

Browse files
Files changed (1) hide show
  1. salia_detailer_ezpz_gated.py +890 -0
salia_detailer_ezpz_gated.py ADDED
@@ -0,0 +1,890 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import shutil
3
+ import threading
4
+ import urllib.request
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Tuple, Optional
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageOps
11
+
12
+ import folder_paths
13
+ import comfy.model_management as model_management
14
+
15
+
16
+ # transformers is required for depth-estimation pipeline
17
+ try:
18
+ from transformers import pipeline
19
+ except Exception as e:
20
+ pipeline = None
21
+ _TRANSFORMERS_IMPORT_ERROR = e
22
+
23
+
24
+ # -------------------------------------------------------------------------------------
25
+ # Global caches (checkpoint + controlnet) so using the node multiple times won't reload
26
+ # -------------------------------------------------------------------------------------
27
+
28
# ckpt_name -> (model, clip, vae); guarded by _CKPT_LOCK
_CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
# control_net_name -> controlnet object; guarded by _CN_LOCK
_CN_CACHE: Dict[str, Any] = {}
_CKPT_LOCK = threading.Lock()
_CN_LOCK = threading.Lock()
32
+
33
+
34
+ # -------------------------------------------------------------------------------------
35
+ # Plugin root detection (works whether file is in plugin root or nodes/)
36
+ # -------------------------------------------------------------------------------------
37
+
38
+ def _find_plugin_root() -> Path:
39
+ """
40
+ Walk upwards from this file until we find an 'assets' folder.
41
+ Robust against hyphen/underscore package naming and different file placement.
42
+ """
43
+ here = Path(__file__).resolve()
44
+ for parent in [here.parent] + list(here.parents)[:12]:
45
+ if (parent / "assets").is_dir():
46
+ return parent
47
+ # fallback: typical nodes/<file>.py
48
+ return here.parent.parent
49
+
50
+
51
+ PLUGIN_ROOT = _find_plugin_root()
52
+
53
+
54
+ # -------------------------------------------------------------------------------------
55
+ # PIL helpers (Lanczos resize for IMAGE and MASK)
56
+ # -------------------------------------------------------------------------------------
57
+
58
def _pil_lanczos():
    """Return the Lanczos resampling constant for the installed Pillow."""
    # Pillow >= 9.1 exposes filters via the Resampling enum; older versions
    # keep them as module-level constants.
    resampling = getattr(Image, "Resampling", None)
    if resampling is not None:
        return resampling.LANCZOS
    return Image.LANCZOS
62
+
63
+
64
def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """
    Convert a Comfy IMAGE tensor ([B,H,W,C] or [H,W,C], float 0..1) to a
    PIL image — RGBA when a 4th channel is present, RGB otherwise.
    """
    if img.ndim == 4:
        img = img[0]  # batched input: use the first frame
    clamped = img.detach().cpu().float().clamp(0, 1)
    pixels = (clamped.numpy() * 255.0).round().astype(np.uint8)
    mode = "RGBA" if pixels.shape[-1] == 4 else "RGB"
    return Image.fromarray(pixels, mode=mode)
75
+
76
+
77
def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """
    Convert a PIL image to a Comfy IMAGE tensor [1,H,W,C], float 0..1.

    Non-RGB(A) images are converted first, keeping an alpha channel only
    when the source actually has one.
    """
    if pil.mode not in ("RGB", "RGBA"):
        target = "RGBA" if "A" in pil.getbands() else "RGB"
        pil = pil.convert(target)
    pixels = np.array(pil).astype(np.float32) / 255.0  # [H,W,C]
    return torch.from_numpy(pixels).unsqueeze(0)
86
+
87
+
88
def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """
    Convert a Comfy MASK tensor ([B,H,W] or [H,W], float 0..1) to a
    grayscale ('L') PIL image.
    """
    if mask.ndim == 3:
        mask = mask[0]  # batched input: use the first mask
    values = mask.detach().cpu().float().clamp(0, 1).numpy()
    gray = (values * 255.0).round().astype(np.uint8)
    return Image.fromarray(gray, mode="L")
97
+
98
+
99
def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """
    Convert a grayscale PIL image to a Comfy MASK tensor [1,H,W], float 0..1.
    """
    if pil_l.mode != "L":
        pil_l = pil_l.convert("L")
    values = np.array(pil_l).astype(np.float32) / 255.0  # [H,W]
    return torch.from_numpy(values).unsqueeze(0)
108
+
109
+
110
def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """
    Resize every frame of a Comfy IMAGE batch [B,H,W,C] to (w, h) using
    Lanczos resampling via PIL. Returns a new [B,h,w,C] tensor.
    """
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    target = (int(w), int(h))
    filt = _pil_lanczos()
    resized = []
    for frame in img:
        pil = _image_tensor_to_pil(frame.unsqueeze(0)).resize(target, resample=filt)
        resized.append(_pil_to_image_tensor(pil))
    return torch.cat(resized, dim=0)
122
+
123
+
124
def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """
    Resize every mask of a Comfy MASK batch [B,H,W] to (w, h) using Lanczos
    resampling via PIL. Returns a new [B,h,w] tensor.
    """
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    target = (int(w), int(h))
    filt = _pil_lanczos()
    resized = []
    for single in mask:
        pil = _mask_tensor_to_pil(single.unsqueeze(0)).resize(target, resample=filt)
        resized.append(_pil_to_mask_tensor(pil))
    return torch.cat(resized, dim=0)
136
+
137
+
138
+ # -------------------------------------------------------------------------------------
139
+ # ✅ ComfyUI 0.5.1 FIX: Manual JoinImageWithAlpha equivalent
140
+ # -------------------------------------------------------------------------------------
141
+
142
+ def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
143
+ """
144
+ Make RGBA from:
145
+ rgb: IMAGE [B,H,W,3] float [0..1]
146
+ mask: MASK [B,H,W] float [0..1] (Comfy convention: 1=masked/transparent)
147
+ Output:
148
+ rgba: IMAGE [B,H,W,4] where alpha = 1 - mask (1=opaque, 0=transparent)
149
+ """
150
+ if rgb.ndim == 3:
151
+ rgb = rgb.unsqueeze(0)
152
+ if mask.ndim == 2:
153
+ mask = mask.unsqueeze(0)
154
+
155
+ if rgb.ndim != 4 or rgb.shape[-1] != 3:
156
+ raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
157
+ if mask.ndim != 3:
158
+ raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
159
+
160
+ # Batch match
161
+ if mask.shape[0] != rgb.shape[0]:
162
+ if mask.shape[0] == 1 and rgb.shape[0] > 1:
163
+ mask = mask.expand(rgb.shape[0], -1, -1)
164
+ else:
165
+ raise ValueError("Batch mismatch between rgb and mask.")
166
+
167
+ # Size match
168
+ if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
169
+ raise ValueError(
170
+ f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
171
+ )
172
+
173
+ mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
174
+ alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1) # [B,H,W,1]
175
+
176
+ rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1) # [B,H,W,4]
177
+ return rgba
178
+
179
+
180
+ # -------------------------------------------------------------------------------------
181
+ # Core lazy loaders (checkpoint + controlnet), cached globally
182
+ # -------------------------------------------------------------------------------------
183
+
184
def _load_checkpoint_cached(ckpt_name: str):
    """
    Load a checkpoint through comfy-core CheckpointLoaderSimple, caching the
    result so repeated node executions do not reload from disk.

    Returns: (model, clip, vae)
    """
    with _CKPT_LOCK:
        cached = _CKPT_CACHE.get(ckpt_name)
        if cached is not None:
            return cached

        import nodes

        loader = nodes.CheckpointLoaderSimple()
        load_fn = getattr(loader, loader.FUNCTION)
        model, clip, vae = load_fn(ckpt_name=ckpt_name)

        _CKPT_CACHE[ckpt_name] = (model, clip, vae)
        return model, clip, vae
200
+
201
+
202
def _load_controlnet_cached(control_net_name: str):
    """
    Load a ControlNet through comfy-core ControlNetLoader, caching the
    result so repeated node executions do not reload from disk.
    """
    with _CN_LOCK:
        try:
            return _CN_CACHE[control_net_name]
        except KeyError:
            pass

        import nodes

        loader = nodes.ControlNetLoader()
        load_fn = getattr(loader, loader.FUNCTION)
        (controlnet,) = load_fn(control_net_name=control_net_name)

        _CN_CACHE[control_net_name] = controlnet
        return controlnet
218
+
219
+
220
+ # -------------------------------------------------------------------------------------
221
+ # Assets/images dropdown + loader (inlined; no LoadImage_SaliaOnline_Assets dependency)
222
+ # -------------------------------------------------------------------------------------
223
+
224
def _assets_images_dir() -> Path:
    """Directory holding the plugin's bundled PNG assets."""
    return PLUGIN_ROOT.joinpath("assets", "images")
226
+
227
+
228
def _list_asset_pngs() -> list:
    """
    Recursively list PNG files under assets/images as sorted, POSIX-style
    paths relative to that directory; empty list when the folder is missing.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    return sorted(
        p.relative_to(img_dir).as_posix()
        for p in img_dir.rglob("*")
        if p.is_file() and p.suffix.lower() == ".png"
    )
238
+
239
+
240
def _safe_asset_path(asset_rel_path: str) -> Path:
    """
    Resolve a dropdown-selected asset path to an absolute Path inside
    assets/images.

    Rejects absolute paths, path-traversal attempts, missing files and
    non-PNG files.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")

    base = img_dir.resolve()
    rel = Path(asset_rel_path)
    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    full = (base / rel).resolve()

    # path traversal protection: the resolved target must live under base
    inside = full == base or base in full.parents
    if not inside:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")

    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")

    return full
263
+
264
+
265
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load a bundled PNG and return (IMAGE, MASK) in ComfyUI formats.

    Mask semantics match ComfyUI core LoadImage:
      - alpha is the RGBA alpha channel normalized to [0..1]
      - mask = 1 - alpha
    """
    path = _safe_asset_path(asset_rel_path)

    pil = Image.open(path)
    pil = ImageOps.exif_transpose(pil)  # honor EXIF orientation

    rgba = pil.convert("RGBA")
    rgb_arr = np.array(rgba.convert("RGB")).astype(np.float32) / 255.0  # [H,W,3]
    image_t = torch.from_numpy(rgb_arr)[None, ...]

    alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0  # [H,W]
    mask_t = torch.from_numpy(1.0 - alpha)[None, ...]  # Comfy MASK convention

    return image_t, mask_t
289
+
290
+
291
+ # -------------------------------------------------------------------------------------
292
+ # Salia_Depth (INLINED, no imports from other files)
293
+ # -------------------------------------------------------------------------------------
294
+
295
# Local snapshot directory for the depth model; created eagerly at import
# time so downloads always have a destination.
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Files a transformers depth-estimation pipeline needs on disk, mapped to
# their download URLs.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# Hub repo used when the local snapshot cannot be obtained.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"

# (model_source, device_str) -> pipeline; guarded by _PIPE_LOCK
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}
_PIPE_LOCK = threading.Lock()
308
+
309
+
310
def _have_required_files() -> bool:
    """True when every required depth-model file exists in MODEL_DIR."""
    for name in REQUIRED_FILES:
        if not (MODEL_DIR / name).exists():
            return False
    return True
312
+
313
+
314
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """
    Download *url* to *dst* atomically: stream into a '<dst>.tmp' sibling,
    then rename over the destination so readers never see a partial file.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    # drop any stale partial download (best-effort)
    try:
        if tmp.exists():
            tmp.unlink()
    except Exception:
        pass

    request = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(request, timeout=timeout) as response, open(tmp, "wb") as out:
        shutil.copyfileobj(response, out)

    tmp.replace(dst)
329
+
330
+
331
def ensure_local_model_files() -> bool:
    """
    Download any missing depth-model files into MODEL_DIR.

    Best-effort: returns True when all required files are present afterwards,
    False on any download failure (callers then fall back to the Hub repo).
    """
    if _have_required_files():
        return True
    try:
        for fname, url in REQUIRED_FILES.items():
            target = MODEL_DIR / fname
            if not target.exists():
                _download_url_to_file(url, target)
        return _have_required_files()
    except Exception:
        return False
343
+
344
+
345
def HWC3(x: np.ndarray) -> np.ndarray:
    """
    Normalize a uint8 image array to 3-channel HxWxC form.

    - grayscale (HxW or HxWx1) is replicated across 3 channels
    - RGB (HxWx3) is returned unchanged (same object)
    - RGBA (HxWx4) is alpha-composited over a white background

    Raises ValueError on wrong dtype, rank, or channel count. (The original
    used `assert` for input validation, which disappears under `python -O`.)
    """
    if x.dtype != np.uint8:
        raise ValueError(f"HWC3 expects uint8 input, got {x.dtype}")
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(f"HWC3 expects a 2D or 3D array, got {x.ndim}D")
    H, W, C = x.shape
    if C not in (1, 3, 4):
        raise ValueError(f"HWC3 expects 1, 3 or 4 channels, got {C}")
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    # C == 4: composite over a white background
    color = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    y = color * alpha + 255.0 * (1.0 - alpha)
    return y.clip(0, 255).astype(np.uint8)
362
+
363
+
364
def pad64(x: int) -> int:
    """Number of pixels needed to round *x* up to the next multiple of 64."""
    remainder = x % 64
    return 0 if remainder == 0 else 64 - remainder
366
+
367
+
368
def safer_memory(x: np.ndarray) -> np.ndarray:
    """
    Return an independent, C-contiguous copy of *x*.

    ndarray.copy(order="C") already yields a fresh C-contiguous buffer, so
    the original ascontiguousarray(x.copy()).copy() triple pass was
    redundant (two extra full copies for the same result).
    """
    return x.copy(order="C")
370
+
371
+
372
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """
    Scale a uint8 image so its SHORTER side equals `resolution`, then pad
    the bottom/right edges up to multiples of 64 (UNet/VAE-friendly sizes).

    Returns (padded_image, remove_pad) where `remove_pad` crops a
    same-padded array back to the pre-padding target size.

    cv2 is optional: when it cannot be imported, PIL performs the resize.
    resolution <= 0 disables resizing entirely (identity remove_pad).
    """
    # cv2 is an optional dependency; fall back to PIL below when missing
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # no resizing requested: identity crop function
        return img, (lambda x: x)

    # scale factor that brings the shorter side to `resolution`
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # INTER_AREA is the better choice when shrinking (k <= 1)
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback: bicubic when enlarging, Lanczos when shrinking
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # pad bottom/right so both sides are multiples of 64
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # crop back to the pre-padding target size (closure over H/W_target)
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
419
+
420
+
421
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """
    Pad a uint8 image (bottom/right) to the next multiples of 64 without
    resizing. Returns (padded_image, remove_pad) where remove_pad crops a
    same-padded array back to the original size.
    """
    img = HWC3(img_u8)
    height, width = img.shape[0], img.shape[1]
    padded = np.pad(
        img,
        [[0, pad64(height)], [0, pad64(width)], [0, 0]],
        mode=mode,
    )

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
431
+
432
+
433
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Split an image into (rgb_over_white, alpha).

    RGBA input: RGB channels composited over a white background, plus a
    copy of the alpha channel. Anything else: HWC3-normalized RGB and None.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    color = rgba[:, :, 0:3].astype(np.float32)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    over_white = (color * alpha + 255.0 * (1.0 - alpha)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
442
+
443
+
444
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """
    Multiply a depth visualization by an alpha channel so fully transparent
    pixels become black (alpha 0 -> 0, alpha 255 -> unchanged).
    """
    depth = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (depth * weight).clip(0, 255).astype(np.uint8)
449
+
450
+
451
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """
    Convert a Comfy IMAGE tensor ([B,H,W,C] or [H,W,C], float 0..1) to a
    uint8 numpy array [H,W,C]; batched input uses the first frame.
    """
    frame = img[0] if img.ndim == 4 else img
    values = frame.detach().cpu().float().clamp(0, 1).numpy()
    return (values * 255.0).round().astype(np.uint8)
457
+
458
+
459
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """uint8 numpy image -> Comfy IMAGE tensor [1,H,W,3], float 0..1."""
    normalized = HWC3(img_u8).astype(np.float32) / 255.0
    return torch.from_numpy(normalized).unsqueeze(0)  # [1,H,W,C]
463
+
464
+
465
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Build (or fetch from cache) a transformers depth-estimation pipeline
    for `model_source`, moved to `device` on a best-effort basis.

    Raises RuntimeError when transformers failed to import.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    with _PIPE_LOCK:
        cached = _PIPE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        pipe = pipeline(task="depth-estimation", model=model_source)
        # best-effort device placement; some pipeline versions manage the
        # device themselves and may reject a manual move
        try:
            pipe.model = pipe.model.to(device)
            pipe.device = device
        except Exception:
            pass

        _PIPE_CACHE[cache_key] = pipe
        return pipe
483
+
484
+
485
def get_depth_pipeline(device: torch.device):
    """
    Return a depth-estimation pipeline, preferring the local snapshot and
    falling back to the ZoeDepth Hub repo. None when neither loads.
    """
    if ensure_local_model_files():
        try:
            return _try_load_pipeline(str(MODEL_DIR), device)
        except Exception:
            pass  # local snapshot broken: fall through to the Hub repo
    try:
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
    except Exception:
        return None
495
+
496
+
497
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """
    Run a depth-estimation pipeline ZoeDepth-style and return a uint8
    HxWx3 depth visualization at the pre-padding working size.

    detect_resolution == -1 keeps the input size (pad-to-64 only); any
    other value scales the shorter side to that resolution first.

    Normalization uses the 2nd/85th percentiles (clipping outliers) and
    inverts the map so nearer surfaces are brighter.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # np.array handles both PIL.Image and tensor/ndarray outputs, so the
    # original isinstance branch (two byte-identical arms) was redundant.
    depth_array = np.array(depth, dtype=np.float32)

    # robust percentile normalization
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        denom = 1e-6  # avoid division by ~zero on flat depth maps
    depth_array = depth_array / denom

    depth_array = 1.0 - depth_array  # invert: near = bright
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    return detected_map
539
+
540
+
541
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """
    Bilinear-resize a uint8 depth image back to (w0, h0), using cv2 when
    available and falling back to PIL otherwise.
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        pil = Image.fromarray(depth_rgb_u8)
        pil = pil.resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil, dtype=np.uint8)
550
+
551
+
552
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """
    Internal callable version of Salia_Depth.

    input:  IMAGE [B,H,W,3 or 4], float [0..1]
    output: IMAGE [B,H,W,3] depth visualization (the input is returned
            unchanged when no depth pipeline can be loaded at all)

    Per-frame failures fall back to the frame's own RGB channels so one
    bad frame never aborts the whole batch.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    pipe_obj = None
    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None

    if pipe_obj is None:
        # no pipeline available at all: pass the input through untouched
        return image

    if image.ndim == 3:
        image = image.unsqueeze(0)

    outs = []
    for i in range(image.shape[0]):
        try:
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])

            inp_u8 = comfy_tensor_to_u8(image[i])

            # RGBA frames are composited over white for depth estimation;
            # the alpha channel is kept to re-mask the result afterwards
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None

            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )

            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)

            if had_rgba:
                # defensive: match the alpha to the frame size before masking
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)

                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)

            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # BUGFIX: slice the fallback frame to its RGB channels; an RGBA
            # fallback mixed with 3-channel depth frames would make the
            # torch.cat below raise on mismatched channel counts.
            outs.append(image[i][..., 0:3].unsqueeze(0))

    return torch.cat(outs, dim=0)
612
+
613
+
614
+ # -------------------------------------------------------------------------------------
615
+ # Alpha-over paste (RGBA square onto base at X,Y)
616
+ # -------------------------------------------------------------------------------------
617
+
618
+ def _alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
619
+ """
620
+ base: [B,H,W,C] where C is 3 or 4, float [0..1]
621
+ overlay_rgba: [B,s,s,4] float [0..1]
622
+ """
623
+ if base.ndim != 4 or overlay_rgba.ndim != 4:
624
+ raise ValueError("base and overlay must be [B,H,W,C].")
625
+
626
+ B, H, W, C = base.shape
627
+ b2, sH, sW, c2 = overlay_rgba.shape
628
+ if c2 != 4:
629
+ raise ValueError("overlay_rgba must have 4 channels (RGBA).")
630
+ if sH != sW:
631
+ raise ValueError("overlay must be square.")
632
+ s = sH
633
+
634
+ if x < 0 or y < 0 or x + s > W or y + s > H:
635
+ raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
636
+
637
+ if b2 != B:
638
+ if b2 == 1 and B > 1:
639
+ overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
640
+ else:
641
+ raise ValueError("Batch mismatch between base and overlay.")
642
+
643
+ out = base.clone()
644
+
645
+ overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
646
+ overlay_a = overlay_rgba[..., 3:4].clamp(0, 1)
647
+
648
+ base_rgb = out[:, y:y + s, x:x + s, 0:3]
649
+ comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
650
+ out[:, y:y + s, x:x + s, 0:3] = comp_rgb
651
+
652
+ if C == 4:
653
+ base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
654
+ comp_a = overlay_a + base_a * (1.0 - overlay_a)
655
+ out[:, y:y + s, x:x + s, 3:4] = comp_a
656
+
657
+ return out.clamp(0, 1)
658
+
659
+
660
+ # -------------------------------------------------------------------------------------
661
+ # The One-Node Workflow
662
+ # -------------------------------------------------------------------------------------
663
+
664
class Salia_ezpz_gated:
    """
    One-node detailer workflow, gated on a trigger string.

    Crops a square from the input image, renders a depth map of it,
    upscales both, runs a ControlNet-guided img2img pass (checkpoint +
    KSampler + VAE), applies a bundled PNG's alpha as a mask, and pastes
    the result back onto the original image at the crop position.

    An empty `trigger_string` bypasses the whole pipeline and returns the
    input unchanged.
    """
    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        """Build the input schema; dropdowns are populated at call time."""
        # placeholder entries keep the dropdowns valid when folders are empty;
        # run() raises a clear error if a placeholder is actually selected
        ckpts = folder_paths.get_filename_list("checkpoints") or ["<no checkpoints found>"]
        cns = folder_paths.get_filename_list("controlnet") or ["<no controlnets found>"]
        assets = _list_asset_pngs() or ["<no pngs found>"]

        # sampler/scheduler lists come from comfy when importable;
        # otherwise fall back to a minimal safe choice
        try:
            import comfy.samplers
            sampler_names = comfy.samplers.KSampler.SAMPLERS
            scheduler_names = comfy.samplers.KSampler.SCHEDULERS
        except Exception:
            sampler_names = ["euler"]
            scheduler_names = ["karras"]

        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]

        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),

                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "square_size": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),

                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),

                "upscale_factor": (upscale_choices, {"default": "4"}),

                "ckpt_name": (ckpts, {}),
                "control_net_name": (cns, {}),
                "asset_image": (assets, {}),

                "controlnet_strength": ("FLOAT", {"default": 0.33, "min": 0.00, "max": 10.00, "step": 0.01}),
                "controlnet_start_percent": ("FLOAT", {"default": 0.00, "min": 0.00, "max": 1.00, "step": 0.01}),
                "controlnet_end_percent": ("FLOAT", {"default": 1.00, "min": 0.00, "max": 1.00, "step": 0.01}),

                "steps": ("INT", {"default": 30, "min": 1, "max": 200, "step": 1}),
                "cfg": ("FLOAT", {"default": 2.6, "min": 0.00, "max": 10.00, "step": 0.05}),
                "sampler_name": (sampler_names, {"default": "euler"} if "euler" in sampler_names else {}),
                "scheduler": (scheduler_names, {"default": "karras"} if "karras" in scheduler_names else {}),
                "denoise": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }

    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        square_size: int = 384,
        positive_prompt: str = "",
        negative_prompt: str = "",
        upscale_factor: str = "4",
        ckpt_name: str = "",
        control_net_name: str = "",
        asset_image: str = "",
        controlnet_strength: float = 0.33,
        controlnet_start_percent: float = 0.0,
        controlnet_end_percent: float = 1.0,
        steps: int = 30,
        cfg: float = 2.6,
        sampler_name: str = "euler",
        scheduler: str = "karras",
        denoise: float = 0.35,
    ):
        """
        Execute the full crop -> depth -> ControlNet img2img -> paste-back
        pipeline. Returns a one-tuple with the modified IMAGE batch.

        Raises ValueError / FileNotFoundError on invalid geometry or
        missing checkpoint / controlnet / asset selections.
        """
        # If trigger_string is exactly empty, bypass everything and return input unchanged.
        if trigger_string == "":
            return (image,)

        # Normalize input to [B,H,W,C]
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")

        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")

        x = int(X_coord)
        y = int(Y_coord)
        s = int(square_size)

        up = int(upscale_factor)
        if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
            raise ValueError("upscale_factor must be one of: 1,2,4,6,8,10,12,14,16")

        if s <= 0:
            raise ValueError("square_size must be > 0")
        if x < 0 or y < 0 or x + s > W or y + s > H:
            raise ValueError(f"Crop out of bounds. image={W}x{H}, crop at ({x},{y}) size={s}")

        up_w = s * up
        up_h = s * up

        # VAE/UNet path likes multiples of 8
        if (up_w % 8) != 0 or (up_h % 8) != 0:
            raise ValueError("square_size * upscale_factor must be divisible by 8 (required by VAE pipeline).")

        # clamp the ControlNet schedule window and swap if inverted
        start_p = float(max(0.0, min(1.0, controlnet_start_percent)))
        end_p = float(max(0.0, min(1.0, controlnet_end_percent)))
        if end_p < start_p:
            start_p, end_p = end_p, start_p

        # 1) Crop square
        crop = image[:, y:y + s, x:x + s, :]
        crop_rgb = crop[:, :, :, 0:3].contiguous()

        # 2) Depth (inline Salia_Depth) then upscale with Lanczos
        depth_small = _salia_depth_execute(crop_rgb, resolution=s)
        depth_up = _resize_image_lanczos(depth_small, up_w, up_h)

        # 3) Upscale crop for VAE encode
        crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)

        # 4) Load asset mask and resize
        if asset_image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)

        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")

        if asset_mask.shape[0] != B:
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask.")

        asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)

        # 5) Load checkpoint + controlnet (cached)
        if ckpt_name == "<no checkpoints found>":
            raise FileNotFoundError("No checkpoints found in models/checkpoints.")
        if control_net_name == "<no controlnets found>":
            raise FileNotFoundError("No controlnets found in models/controlnet.")

        model, clip, vae = _load_checkpoint_cached(ckpt_name)
        controlnet = _load_controlnet_cached(control_net_name)

        import nodes

        # 6) CLIP encodes
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)

        # 7) Apply ControlNet (depth map as the hint image)
        cn_apply = nodes.ControlNetApplyAdvanced()
        cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
        pos_cn, neg_cn = cn_fn(
            strength=float(controlnet_strength),
            start_percent=float(start_p),
            end_percent=float(end_p),
            positive=pos_cond,
            negative=neg_cond,
            control_net=controlnet,
            image=depth_up,
            vae=vae,
        )

        # 8) VAE Encode
        vae_enc = nodes.VAEEncode()
        vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
        (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)

        # 9) KSampler (deterministic seed derived from inputs, so identical
        #    settings always reproduce the same result)
        seed_material = (
            f"{ckpt_name}|{control_net_name}|{asset_image}|{x}|{y}|{s}|{up}|"
            f"{steps}|{cfg}|{sampler_name}|{scheduler}|{denoise}|"
            f"{controlnet_strength}|{start_p}|{end_p}|"
            f"{positive_prompt}|{negative_prompt}"
        ).encode("utf-8", errors="ignore")
        seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)

        ksampler = nodes.KSampler()
        k_fn = getattr(ksampler, ksampler.FUNCTION)
        (sampled_latent,) = k_fn(
            seed=seed64,
            steps=int(steps),
            cfg=float(cfg),
            sampler_name=str(sampler_name),
            scheduler=str(scheduler),
            denoise=float(denoise),
            model=model,
            positive=pos_cn,
            negative=neg_cn,
            latent_image=latent,
        )

        # 10) VAE Decode -> RGB
        vae_dec = nodes.VAEDecode()
        vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
        (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)

        # 11) ✅ Manual "JoinImageWithAlpha"
        rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)

        # 12) Downscale RGBA back to crop size
        rgba_square = _resize_image_lanczos(rgba_up, s, s)

        # 13) Paste back onto original at X,Y (alpha-over)
        out = _alpha_over_region(image, rgba_square, x=x, y=y)
        return (out,)
882
+
883
+
884
# Registration tables read by ComfyUI at import time.
NODE_CLASS_MAPPINGS = {
    "Salia_ezpz_gated": Salia_ezpz_gated,
}

# Human-readable node title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Salia_ezpz_gated": "Salia EZPZ Gated",
}