saliacoel commited on
Commit
79a1f7d
·
verified ·
1 Parent(s): 9f8df29

Update salia_depth.py

Browse files
Files changed (1) hide show
  1. salia_depth.py +610 -610
salia_depth.py CHANGED
@@ -1,611 +1,611 @@
1
- import os
2
- import shutil
3
- import urllib.request
4
- from pathlib import Path
5
- from typing import Dict, Tuple, Any, Optional, List
6
-
7
- import numpy as np
8
- import torch
9
- from PIL import Image
10
-
11
- import comfy.model_management as model_management
12
-
13
- # transformers is required for depth-estimation pipeline
14
- try:
15
- from transformers import pipeline
16
- except Exception as e:
17
- pipeline = None
18
- _TRANSFORMERS_IMPORT_ERROR = e
19
-
20
-
21
# --------------------------------------------------------------------------------------
# Paths / sources
# --------------------------------------------------------------------------------------

# This file: comfyui-salia_online/nodes/Salia_Depth.py
# Plugin root: comfyui-salia_online/
PLUGIN_ROOT = Path(__file__).resolve().parent.parent

# Requested local path: assets/depth
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
# NOTE: directory creation happens at import time (module-level side effect).
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# The three files a transformers depth-estimation pipeline needs to load
# from a local directory, mapped to their download URLs (local-first loading).
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# "zoe-path" fallback
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
41
-
42
-
43
- # --------------------------------------------------------------------------------------
44
- # Logging helpers
45
- # --------------------------------------------------------------------------------------
46
-
47
- def _make_logger() -> Tuple[List[str], Any]:
48
- lines: List[str] = []
49
-
50
- def log(msg: str):
51
- # console
52
- try:
53
- print(msg)
54
- except Exception:
55
- pass
56
- # UI string
57
- lines.append(str(msg))
58
-
59
- return lines, log
60
-
61
-
62
- def _fmt_bytes(n: Optional[int]) -> str:
63
- if n is None:
64
- return "?"
65
- # simple readable
66
- for unit in ["B", "KB", "MB", "GB", "TB"]:
67
- if n < 1024:
68
- return f"{n:.0f}{unit}"
69
- n /= 1024.0
70
- return f"{n:.1f}PB"
71
-
72
-
73
- def _file_size(path: Path) -> Optional[int]:
74
- try:
75
- return path.stat().st_size
76
- except Exception:
77
- return None
78
-
79
-
80
- def _hf_cache_info() -> Dict[str, str]:
81
- info: Dict[str, str] = {}
82
- info["env.HF_HOME"] = os.environ.get("HF_HOME", "")
83
- info["env.HF_HUB_CACHE"] = os.environ.get("HF_HUB_CACHE", "")
84
- info["env.TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "")
85
- info["env.HUGGINGFACE_HUB_CACHE"] = os.environ.get("HUGGINGFACE_HUB_CACHE", "")
86
-
87
- try:
88
- from huggingface_hub import constants as hf_constants
89
- # These exist in most hub versions:
90
- info["huggingface_hub.constants.HF_HOME"] = str(getattr(hf_constants, "HF_HOME", ""))
91
- info["huggingface_hub.constants.HF_HUB_CACHE"] = str(getattr(hf_constants, "HF_HUB_CACHE", ""))
92
- except Exception:
93
- pass
94
-
95
- return info
96
-
97
-
98
- # --------------------------------------------------------------------------------------
99
- # Download helpers
100
- # --------------------------------------------------------------------------------------
101
-
102
def _have_required_files() -> bool:
    """True when every file listed in REQUIRED_FILES exists under MODEL_DIR."""
    for fname in REQUIRED_FILES:
        if not (MODEL_DIR / fname).exists():
            return False
    return True
104
-
105
-
106
- def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
107
- """
108
- Download with atomic temp rename.
109
- """
110
- dst.parent.mkdir(parents=True, exist_ok=True)
111
- tmp = dst.with_suffix(dst.suffix + ".tmp")
112
-
113
- if tmp.exists():
114
- try:
115
- tmp.unlink()
116
- except Exception:
117
- pass
118
-
119
- req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
120
- with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
121
- shutil.copyfileobj(r, f)
122
-
123
- tmp.replace(dst)
124
-
125
-
126
def ensure_local_model_files(log) -> bool:
    """
    Ensure assets/depth contains the 3 files.
    Returns True if present or downloaded successfully, else False.

    Args:
        log: Callable taking a message string (from _make_logger()).
    """
    # Always log expected locations + URLs, even if we don't download.
    log("[SaliaDepth] ===== Local model file check =====")
    log(f"[SaliaDepth] Plugin root: {PLUGIN_ROOT}")
    log(f"[SaliaDepth] Local model dir (on drive): {MODEL_DIR}")

    # Report per-file status before deciding whether to download.
    for fname, url in REQUIRED_FILES.items():
        fpath = MODEL_DIR / fname
        exists = fpath.exists()
        size = _file_size(fpath) if exists else None
        log(f"[SaliaDepth] - {fname}")
        log(f"[SaliaDepth] local path: {fpath} exists={exists} size={_fmt_bytes(size)}")
        log(f"[SaliaDepth] remote url : {url}")

    if _have_required_files():
        log("[SaliaDepth] All required local files already exist. No download needed.")
        return True

    log("[SaliaDepth] One or more local files missing. Attempting download...")

    try:
        for fname, url in REQUIRED_FILES.items():
            fpath = MODEL_DIR / fname
            # Only fetch files that are actually missing.
            if fpath.exists():
                continue
            log(f"[SaliaDepth] Downloading '{fname}' -> '{fpath}'")
            _download_url_to_file(url, fpath)
            log(f"[SaliaDepth] Downloaded '{fname}' size={_fmt_bytes(_file_size(fpath))}")

        # Re-check afterwards rather than assuming success.
        ok = _have_required_files()
        log(f"[SaliaDepth] Download finished. ok={ok}")
        return ok
    except Exception as e:
        # Best-effort: log and let the caller fall back to the hub model.
        log(f"[SaliaDepth] Download failed with error: {repr(e)}")
        return False
165
-
166
-
167
- # --------------------------------------------------------------------------------------
168
- # Exact Zoe-style preprocessing helpers (copied/adapted from your snippet)
169
- # --------------------------------------------------------------------------------------
170
-
171
def HWC3(x: np.ndarray) -> np.ndarray:
    """
    Normalize a uint8 image to 3-channel HWC.

    Grayscale (HW or HW1) is replicated across channels; RGBA is alpha
    composited over a white background; RGB passes through untouched.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x] * 3, axis=2)
    # 4 channels: blend RGB over white using the alpha channel.
    rgb = x[:, :, :3].astype(np.float32)
    alpha = x[:, :, 3:].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)  # white background
    return blended.clip(0, 255).astype(np.uint8)
188
-
189
-
190
def pad64(x: int) -> int:
    """Pixels needed to pad *x* up to the next multiple of 64 (0 if aligned)."""
    return (-int(x)) % 64
192
-
193
-
194
def safer_memory(x: np.ndarray) -> np.ndarray:
    """
    Return a C-contiguous copy of *x* that shares no memory with it.

    Fix: the original chained `x.copy()` -> `ascontiguousarray` -> `.copy()`,
    performing two (sometimes three) full copies; a single contiguous copy
    gives the same guarantees.
    """
    return np.array(x, order="C", copy=True)
196
-
197
-
198
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
    log=None
) -> Tuple[np.ndarray, Any]:
    """
    EXACT behavior like your zoe.transformers.py:
        k = resolution / min(H,W)
        resize to (W_target, H_target)
        pad to multiple of 64
        return padded image and remove_pad() closure

    Args:
        input_image: uint8 image (HWC, or HW when skip_hwc3 is False).
        resolution: target length for the SHORTER side; <= 0 returns the
            image unchanged with an identity remove_pad.
        upscale_method: cv2 interpolation name, used only when upscaling.
        skip_hwc3: bypass HWC3() normalization of the input.
        mode: np.pad mode for the bottom/right 64-alignment padding.
        log: optional logger callable for warnings.

    Returns:
        (padded uint8 image, remove_pad closure that crops a result back to
        the resized-but-unpadded size).
    """
    # prefer cv2 like original for matching results
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None
        if log:
            log("[SaliaDepth] WARN: cv2 not available; resizing will use PIL fallback (may change results).")

    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # keep original, but still pad to 64 (we will handle padding separately for -1 path)
        return img, (lambda x: x)

    # Scale factor so the shorter side becomes `resolution` (aspect kept).
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # INTER_AREA is used when downscaling (k <= 1), matching the original.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback
        # NOTE(review): Image.BICUBIC/Image.LANCZOS are legacy aliases of
        # Image.Resampling.* in newer Pillow; they still work but may warn.
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right so both dimensions are multiples of 64.
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the resized (unpadded) size; returns an owned copy.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
261
-
262
-
263
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """
    For resolution == -1: keep the original resolution but pad bottom/right
    to multiples of 64, returning a remove_pad() closure that restores the
    original size.
    """
    rgb = HWC3(img_u8)
    height, width = rgb.shape[0], rgb.shape[1]
    pad_h = pad64(height)
    pad_w = pad64(width)
    padded = np.pad(rgb, [[0, pad_h], [0, pad_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to pre-pad size; return an independent copy.
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
277
-
278
-
279
- # --------------------------------------------------------------------------------------
280
- # RGBA rules (as you requested)
281
- # --------------------------------------------------------------------------------------
282
-
283
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Split an input image for depth estimation.

    RGBA input: returns (RGB composited over WHITE, the original alpha
    channel kept separately). Non-RGBA input: returns (HWC3 RGB, None).
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        # force to RGB
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    color = rgba[:, :, :3].astype(np.float32)
    over_white = (color * alpha + 255.0 * (1.0 - alpha)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
297
-
298
-
299
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """
    Re-apply the original alpha to the depth map over a BLACK backdrop.

    Conceptually: attach alpha (RGBA) -> composite over black -> drop alpha.
    Black contributes nothing, so this reduces to depth_rgb * alpha.
    """
    rgb = HWC3(depth_rgb_u8)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    scaled = rgb.astype(np.float32) * weight
    return scaled.clip(0, 255).astype(np.uint8)
311
-
312
-
313
- # --------------------------------------------------------------------------------------
314
- # ComfyUI conversion helpers
315
- # --------------------------------------------------------------------------------------
316
-
317
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """
    Convert a ComfyUI IMAGE tensor (float 0..1, [H,W,C] or [B,H,W,C]) into a
    uint8 HWC array. Batched input uses only the first image.
    """
    if img.ndim == 4:
        img = img[0]
    clamped = img.detach().cpu().float().clamp(0, 1)
    return (clamped.numpy() * 255.0).round().astype(np.uint8)
327
-
328
-
329
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image into a ComfyUI IMAGE tensor of shape [1,H,W,C]."""
    rgb = HWC3(img_u8)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
333
-
334
-
335
- # --------------------------------------------------------------------------------------
336
- # Pipeline loading (local-first, then zoe fallback)
337
- # --------------------------------------------------------------------------------------
338
-
339
- _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
340
-
341
-
342
def _try_load_pipeline(model_source: str, device: torch.device, log):
    """
    Use transformers.pipeline like Zoe code does.
    We intentionally do NOT pass device=... here, and instead move model like Zoe node.

    Args:
        model_source: local directory path or Hugging Face repo id.
        device: torch device to move the loaded model to.
        log: logger callable.

    Returns:
        The (possibly cached) transformers depth-estimation pipeline.

    Raises:
        RuntimeError: if transformers failed to import at module load time.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    # Cache per (source, device) so repeated node executions reuse the model.
    key = (model_source, str(device))
    if key in _PIPE_CACHE:
        log(f"[SaliaDepth] Using cached pipeline for source='{model_source}' device='{device}'")
        return _PIPE_CACHE[key]

    log(f"[SaliaDepth] Creating pipeline(task='depth-estimation', model='{model_source}')")
    p = pipeline(task="depth-estimation", model=model_source)

    # Try to move model to torch device, like ZoeDetector.to()
    try:
        p.model = p.model.to(device)
        p.device = device  # Zoe code sets this; newer transformers uses torch.device internally
        log(f"[SaliaDepth] Moved pipeline model to device: {device}")
    except Exception as e:
        # Non-fatal: inference can still run on whatever device the model is on.
        log(f"[SaliaDepth] WARN: Could not move pipeline model to device {device}: {repr(e)}")

    # Log config info for debugging
    try:
        cfg = p.model.config
        log(f"[SaliaDepth] Model class: {p.model.__class__.__name__}")
        log(f"[SaliaDepth] Config class: {cfg.__class__.__name__}")
        log(f"[SaliaDepth] Config model_type: {getattr(cfg, 'model_type', '')}")
        log(f"[SaliaDepth] Config _name_or_path: {getattr(cfg, '_name_or_path', '')}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not log model config: {repr(e)}")

    _PIPE_CACHE[key] = p
    return p
378
-
379
-
380
def get_depth_pipeline(device: torch.device, log):
    """
    1) Ensure assets/depth files exist (download if missing)
    2) Try load local dir
    3) Fallback to Intel/zoedepth-nyu-kitti
    4) If both fail -> None

    Args:
        device: torch device the pipeline model should be moved to.
        log: logger callable.

    Returns:
        A depth-estimation pipeline, or None when neither source loads.
    """
    # Always log HF cache info (helps locate where fallback downloads go)
    log("[SaliaDepth] ===== Hugging Face cache info (fallback path) =====")
    for k, v in _hf_cache_info().items():
        if v:
            log(f"[SaliaDepth] {k} = {v}")
    log(f"[SaliaDepth] Zoe fallback repo id: {ZOE_FALLBACK_REPO_ID}")

    # Local-first
    local_ok = ensure_local_model_files(log)
    if local_ok:
        try:
            log(f"[SaliaDepth] Trying LOCAL model from directory: {MODEL_DIR}")
            return _try_load_pipeline(str(MODEL_DIR), device, log)
        except Exception as e:
            # Fall through to the hub fallback below.
            log(f"[SaliaDepth] Local model load FAILED: {repr(e)}")

    # Fallback
    try:
        log(f"[SaliaDepth] Trying ZOE fallback model: {ZOE_FALLBACK_REPO_ID}")
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device, log)
    except Exception as e:
        log(f"[SaliaDepth] Zoe fallback load FAILED: {repr(e)}")

    return None
411
-
412
-
413
- # --------------------------------------------------------------------------------------
414
- # Depth inference (Zoe-style)
415
- # --------------------------------------------------------------------------------------
416
-
417
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    log,
    upscale_method: str = "INTER_CUBIC"
) -> np.ndarray:
    """
    Run depth estimation matching ZoeDetector.__call__ preprocessing,
    percentile normalization and inversion.

    Args:
        pipe: transformers depth-estimation pipeline; called with a PIL image
            and must return a mapping with a "depth" entry.
        input_rgb_u8: uint8 RGB input image (HWC).
        detect_resolution: -1 keeps the original size (pad-to-64 only);
            otherwise the shorter side is resized to this value first.
        log: logger callable.
        upscale_method: cv2 interpolation name used when upscaling.

    Returns:
        uint8 RGB depth map cropped back to the pre-pad size.
    """
    # detect_resolution:
    #  - if -1: keep original but pad-to-64
    #  - else: min-side resize to detect_resolution, then pad-to-64
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
        log(f"[SaliaDepth] Preprocess: resolution=-1 (no resize), padded to 64. work={work_img.shape}")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
            log=log
        )
        log(f"[SaliaDepth] Preprocess: min-side resized to {detect_resolution}, padded to 64. work={work_img.shape}")

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # np.array handles both PIL images and array-likes, so the original
    # `isinstance(depth, Image.Image)` branch (whose two arms were identical)
    # has been collapsed.
    depth_array = np.array(depth, dtype=np.float32)

    # EXACT normalization like your Zoe code: stretch p2..p85 to 0..1.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    log(f"[SaliaDepth] Depth raw stats: shape={depth_array.shape} vmin(p2)={vmin:.6f} vmax(p85)={vmax:.6f} mean={float(depth_array.mean()):.6f}")

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        # avoid division by zero; log it
        log("[SaliaDepth] WARN: vmax==vmin; forcing denom epsilon to avoid NaNs.")
        denom = 1e-6
    depth_array = depth_array / denom

    # EXACT invert like your Zoe code
    depth_array = 1.0 - depth_array

    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    log(f"[SaliaDepth] Output (post-remove_pad): {detected_map.shape} dtype={detected_map.dtype}")
    return detected_map
478
-
479
-
480
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int, log) -> np.ndarray:
    """
    Resize the depth output back to the original input size (w0 x h0).

    Prefers cv2 bilinear resize; falls back to PIL when cv2 is unavailable
    or fails (the failure is logged, not raised).
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception as e:
        log(f"[SaliaDepth] WARN: cv2 resize failed ({repr(e)}); using PIL.")

    pil_img = Image.fromarray(depth_rgb_u8)
    pil_img = pil_img.resize((w0, h0), resample=Image.BILINEAR)
    return np.array(pil_img, dtype=np.uint8)
494
-
495
-
496
- # --------------------------------------------------------------------------------------
497
- # ComfyUI Node
498
- # --------------------------------------------------------------------------------------
499
-
500
class Salia_Depth_Preprocessor:
    """
    ComfyUI node: depth-map preprocessor using the local assets/depth model
    (downloaded on demand) with an Intel ZoeDepth fallback.

    Outputs the depth map as an IMAGE plus the accumulated diagnostic log as
    a STRING. On any failure the input image (or batch item) is passed
    through unchanged so the workflow never hard-fails.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI input schema; resolution = -1 means "keep original size".
        return {
            "required": {
                "image": ("IMAGE",),
                # note: default -1, min -1
                "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
            }
        }

    # 2 outputs: image + log string
    RETURN_TYPES = ("IMAGE", "STRING")
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, resolution=-1):
        """
        Run depth estimation on a (possibly batched) ComfyUI IMAGE tensor.

        Args:
            image: float tensor [B,H,W,C] or [H,W,C], values in 0..1.
            resolution: min-side detect resolution; -1 keeps original size.

        Returns:
            (IMAGE tensor [B,H,W,3], log string).
        """
        lines, log = _make_logger()
        log("[SaliaDepth] ==================================================")
        log("[SaliaDepth] SaliaDepthPreprocessor starting")
        log(f"[SaliaDepth] resolution input = {resolution}")

        # Get torch device (fall back to CPU if the ComfyUI helper fails)
        try:
            device = model_management.get_torch_device()
        except Exception as e:
            device = torch.device("cpu")
            log(f"[SaliaDepth] WARN: model_management.get_torch_device failed: {repr(e)} -> using CPU")

        log(f"[SaliaDepth] torch device = {device}")

        # Load pipeline
        pipe = None
        try:
            pipe = get_depth_pipeline(device, log)
        except Exception as e:
            log(f"[SaliaDepth] ERROR: get_depth_pipeline crashed: {repr(e)}")
            pipe = None

        if pipe is None:
            # No model at all: return the input unchanged rather than failing.
            log("[SaliaDepth] FATAL: No pipeline available. Returning input image unchanged.")
            return (image, "\n".join(lines))

        # Batch support: normalize [H,W,C] -> [1,H,W,C]
        if image.ndim == 3:
            image = image.unsqueeze(0)

        outs = []
        for i in range(image.shape[0]):
            try:
                # Original dimensions (used to restore output size later)
                h0 = int(image[i].shape[0])
                w0 = int(image[i].shape[1])
                c0 = int(image[i].shape[2])
                log(f"[SaliaDepth] ---- Batch index {i} input shape = ({h0},{w0},{c0}) ----")

                inp_u8 = comfy_tensor_to_u8(image[i])

                # RGBA rule (pre): composite over white, remember alpha
                rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
                had_rgba = alpha_u8 is not None
                log(f"[SaliaDepth] had_rgba={had_rgba}")

                # Run depth (Zoe-style)
                depth_rgb = depth_estimate_zoe_style(
                    pipe=pipe,
                    input_rgb_u8=rgb_for_depth,
                    detect_resolution=int(resolution),
                    log=log,
                    upscale_method="INTER_CUBIC"
                )

                # Resize back to original input size
                depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0, log=log)

                # RGBA rule (post)
                if had_rgba:
                    # Use original alpha at original size.
                    # If alpha size differs, resize alpha to match.
                    if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                        log("[SaliaDepth] Alpha size mismatch; resizing alpha to original size.")
                        try:
                            import cv2
                            alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                        except Exception:
                            # cv2 missing or failed: PIL bilinear fallback.
                            pil_a = Image.fromarray(alpha_u8)
                            pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                            alpha_u8 = np.array(pil_a, dtype=np.uint8)

                    # "Put alpha on RGB turning it into RGBA, then put BLACK background behind it, then back to RGB"
                    depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
                    log("[SaliaDepth] Applied RGBA post-step (alpha + black background).")

                outs.append(u8_to_comfy_tensor(depth_rgb))

            except Exception as e:
                # Per-item fallback keeps the batch shape intact.
                log(f"[SaliaDepth] ERROR: Inference failed at batch index {i}: {repr(e)}")
                log("[SaliaDepth] Passing through original input image for this batch item.")
                outs.append(image[i].unsqueeze(0))

        out = torch.cat(outs, dim=0)
        log("[SaliaDepth] Done.")
        return (out, "\n".join(lines))
603
-
604
-
605
# Registration tables read by ComfyUI when the custom-node package loads.
NODE_CLASS_MAPPINGS = {
    "SaliaDepthPreprocessor": Salia_Depth_Preprocessor
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SaliaDepthPreprocessor": "Salia Depth (local assets/depth + logs)"
}
 
1
+ import os
2
+ import shutil
3
+ import urllib.request
4
+ from pathlib import Path
5
+ from typing import Dict, Tuple, Any, Optional, List
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ import comfy.model_management as model_management
12
+
13
+ # transformers is required for depth-estimation pipeline
14
+ try:
15
+ from transformers import pipeline
16
+ except Exception as e:
17
+ pipeline = None
18
+ _TRANSFORMERS_IMPORT_ERROR = e
19
+
20
+
21
+ # --------------------------------------------------------------------------------------
22
+ # Paths / sources
23
+ # --------------------------------------------------------------------------------------
24
+
25
+ # This file: comfyui-salia_online/nodes/Salia_Depth.py
26
+ # Plugin root: comfyui-salia_online/
27
+ PLUGIN_ROOT = Path(__file__).resolve().parent.parent
28
+
29
+ # Requested local path: assets/depth
30
+ MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
31
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
32
+
33
+ REQUIRED_FILES = {
34
+ "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
35
+ "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
36
+ "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
37
+ }
38
+
39
+ # "zoe-path" fallback
40
+ ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
41
+
42
+
43
+ # --------------------------------------------------------------------------------------
44
+ # Logging helpers
45
+ # --------------------------------------------------------------------------------------
46
+
47
+ def _make_logger() -> Tuple[List[str], Any]:
48
+ lines: List[str] = []
49
+
50
+ def log(msg: str):
51
+ # console
52
+ try:
53
+ print(msg)
54
+ except Exception:
55
+ pass
56
+ # UI string
57
+ lines.append(str(msg))
58
+
59
+ return lines, log
60
+
61
+
62
+ def _fmt_bytes(n: Optional[int]) -> str:
63
+ if n is None:
64
+ return "?"
65
+ # simple readable
66
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
67
+ if n < 1024:
68
+ return f"{n:.0f}{unit}"
69
+ n /= 1024.0
70
+ return f"{n:.1f}PB"
71
+
72
+
73
+ def _file_size(path: Path) -> Optional[int]:
74
+ try:
75
+ return path.stat().st_size
76
+ except Exception:
77
+ return None
78
+
79
+
80
+ def _hf_cache_info() -> Dict[str, str]:
81
+ info: Dict[str, str] = {}
82
+ info["env.HF_HOME"] = os.environ.get("HF_HOME", "")
83
+ info["env.HF_HUB_CACHE"] = os.environ.get("HF_HUB_CACHE", "")
84
+ info["env.TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "")
85
+ info["env.HUGGINGFACE_HUB_CACHE"] = os.environ.get("HUGGINGFACE_HUB_CACHE", "")
86
+
87
+ try:
88
+ from huggingface_hub import constants as hf_constants
89
+ # These exist in most hub versions:
90
+ info["huggingface_hub.constants.HF_HOME"] = str(getattr(hf_constants, "HF_HOME", ""))
91
+ info["huggingface_hub.constants.HF_HUB_CACHE"] = str(getattr(hf_constants, "HF_HUB_CACHE", ""))
92
+ except Exception:
93
+ pass
94
+
95
+ return info
96
+
97
+
98
+ # --------------------------------------------------------------------------------------
99
+ # Download helpers
100
+ # --------------------------------------------------------------------------------------
101
+
102
+ def _have_required_files() -> bool:
103
+ return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES.keys())
104
+
105
+
106
+ def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
107
+ """
108
+ Download with atomic temp rename.
109
+ """
110
+ dst.parent.mkdir(parents=True, exist_ok=True)
111
+ tmp = dst.with_suffix(dst.suffix + ".tmp")
112
+
113
+ if tmp.exists():
114
+ try:
115
+ tmp.unlink()
116
+ except Exception:
117
+ pass
118
+
119
+ req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
120
+ with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
121
+ shutil.copyfileobj(r, f)
122
+
123
+ tmp.replace(dst)
124
+
125
+
126
+ def ensure_local_model_files(log) -> bool:
127
+ """
128
+ Ensure assets/depth contains the 3 files.
129
+ Returns True if present or downloaded successfully, else False.
130
+ """
131
+ # Always log expected locations + URLs, even if we don't download.
132
+ log("[SaliaDepth] ===== Local model file check =====")
133
+ log(f"[SaliaDepth] Plugin root: {PLUGIN_ROOT}")
134
+ log(f"[SaliaDepth] Local model dir (on drive): {MODEL_DIR}")
135
+
136
+ for fname, url in REQUIRED_FILES.items():
137
+ fpath = MODEL_DIR / fname
138
+ exists = fpath.exists()
139
+ size = _file_size(fpath) if exists else None
140
+ log(f"[SaliaDepth] - {fname}")
141
+ log(f"[SaliaDepth] local path: {fpath} exists={exists} size={_fmt_bytes(size)}")
142
+ log(f"[SaliaDepth] remote url : {url}")
143
+
144
+ if _have_required_files():
145
+ log("[SaliaDepth] All required local files already exist. No download needed.")
146
+ return True
147
+
148
+ log("[SaliaDepth] One or more local files missing. Attempting download...")
149
+
150
+ try:
151
+ for fname, url in REQUIRED_FILES.items():
152
+ fpath = MODEL_DIR / fname
153
+ if fpath.exists():
154
+ continue
155
+ log(f"[SaliaDepth] Downloading '{fname}' -> '{fpath}'")
156
+ _download_url_to_file(url, fpath)
157
+ log(f"[SaliaDepth] Downloaded '{fname}' size={_fmt_bytes(_file_size(fpath))}")
158
+
159
+ ok = _have_required_files()
160
+ log(f"[SaliaDepth] Download finished. ok={ok}")
161
+ return ok
162
+ except Exception as e:
163
+ log(f"[SaliaDepth] Download failed with error: {repr(e)}")
164
+ return False
165
+
166
+
167
+ # --------------------------------------------------------------------------------------
168
+ # Exact Zoe-style preprocessing helpers (copied/adapted from your snippet)
169
+ # --------------------------------------------------------------------------------------
170
+
171
+ def HWC3(x: np.ndarray) -> np.ndarray:
172
+ assert x.dtype == np.uint8
173
+ if x.ndim == 2:
174
+ x = x[:, :, None]
175
+ assert x.ndim == 3
176
+ H, W, C = x.shape
177
+ assert C == 1 or C == 3 or C == 4
178
+ if C == 3:
179
+ return x
180
+ if C == 1:
181
+ return np.concatenate([x, x, x], axis=2)
182
+ # C == 4
183
+ color = x[:, :, 0:3].astype(np.float32)
184
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
185
+ y = color * alpha + 255.0 * (1.0 - alpha) # white background
186
+ y = y.clip(0, 255).astype(np.uint8)
187
+ return y
188
+
189
+
190
+ def pad64(x: int) -> int:
191
+ return int(np.ceil(float(x) / 64.0) * 64 - x)
192
+
193
+
194
+ def safer_memory(x: np.ndarray) -> np.ndarray:
195
+ return np.ascontiguousarray(x.copy()).copy()
196
+
197
+
198
+ def resize_image_with_pad_min_side(
199
+ input_image: np.ndarray,
200
+ resolution: int,
201
+ upscale_method: str = "INTER_CUBIC",
202
+ skip_hwc3: bool = False,
203
+ mode: str = "edge",
204
+ log=None
205
+ ) -> Tuple[np.ndarray, Any]:
206
+ """
207
+ EXACT behavior like your zoe.transformers.py:
208
+ k = resolution / min(H,W)
209
+ resize to (W_target, H_target)
210
+ pad to multiple of 64
211
+ return padded image and remove_pad() closure
212
+ """
213
+ # prefer cv2 like original for matching results
214
+ cv2 = None
215
+ try:
216
+ import cv2 as _cv2
217
+ cv2 = _cv2
218
+ except Exception:
219
+ cv2 = None
220
+ if log:
221
+ log("[SaliaDepth] WARN: cv2 not available; resizing will use PIL fallback (may change results).")
222
+
223
+ if skip_hwc3:
224
+ img = input_image
225
+ else:
226
+ img = HWC3(input_image)
227
+
228
+ H_raw, W_raw, _ = img.shape
229
+ if resolution <= 0:
230
+ # keep original, but still pad to 64 (we will handle padding separately for -1 path)
231
+ return img, (lambda x: x)
232
+
233
+ k = float(resolution) / float(min(H_raw, W_raw))
234
+ H_target = int(np.round(float(H_raw) * k))
235
+ W_target = int(np.round(float(W_raw) * k))
236
+
237
+ if cv2 is not None:
238
+ upscale_methods = {
239
+ "INTER_NEAREST": cv2.INTER_NEAREST,
240
+ "INTER_LINEAR": cv2.INTER_LINEAR,
241
+ "INTER_AREA": cv2.INTER_AREA,
242
+ "INTER_CUBIC": cv2.INTER_CUBIC,
243
+ "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
244
+ }
245
+ method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
246
+ img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
247
+ else:
248
+ # PIL fallback
249
+ pil = Image.fromarray(img)
250
+ resample = Image.BICUBIC if k > 1 else Image.LANCZOS
251
+ pil = pil.resize((W_target, H_target), resample=resample)
252
+ img = np.array(pil, dtype=np.uint8)
253
+
254
+ H_pad, W_pad = pad64(H_target), pad64(W_target)
255
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
256
+
257
+ def remove_pad(x: np.ndarray) -> np.ndarray:
258
+ return safer_memory(x[:H_target, :W_target, ...])
259
+
260
+ return safer_memory(img_padded), remove_pad
261
+
262
+
263
+ def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
264
+ """
265
+ For resolution == -1: keep original resolution but still pad to multiples of 64,
266
+ then provide remove_pad that returns original size.
267
+ """
268
+ img = HWC3(img_u8)
269
+ H_raw, W_raw, _ = img.shape
270
+ H_pad, W_pad = pad64(H_raw), pad64(W_raw)
271
+ img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
272
+
273
+ def remove_pad(x: np.ndarray) -> np.ndarray:
274
+ return safer_memory(x[:H_raw, :W_raw, ...])
275
+
276
+ return safer_memory(img_padded), remove_pad
277
+
278
+
279
+ # --------------------------------------------------------------------------------------
280
+ # RGBA rules (as you requested)
281
+ # --------------------------------------------------------------------------------------
282
+
283
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Flatten RGBA input onto a white background, preserving the alpha plane.

    Returns ``(rgb_u8, alpha_u8)``: for RGBA input, the RGB channels
    composited over white plus a copy of the alpha channel; for anything
    else, the input coerced to 3-channel RGB and ``None``.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        # force to RGB
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    color = rgba[:, :, 0:3].astype(np.float32)
    # Standard "over" compositing against a pure white backdrop.
    flattened = color * alpha + 255.0 * (1.0 - alpha)
    rgb_white = flattened.clip(0, 255).astype(np.uint8)
    return rgb_white, rgba[:, :, 3].copy()
297
+
298
+
299
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Re-apply an alpha channel to a depth map and flatten it onto black.

    Conceptually: attach ``alpha_u8`` to the depth image (making it RGBA),
    composite over a black background, and return RGB.  Compositing over
    black reduces to a plain per-pixel multiply by alpha.
    """
    rgb = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb * weight).clip(0, 255).astype(np.uint8)
311
+
312
+
313
+ # --------------------------------------------------------------------------------------
314
+ # ComfyUI conversion helpers
315
+ # --------------------------------------------------------------------------------------
316
+
317
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a ComfyUI IMAGE tensor (float [0..1], [H,W,C] or [B,H,W,C]) to uint8 HWC.

    Batched input is reduced to its first frame.
    """
    if img.ndim == 4:  # [B,H,W,C] -> first image of the batch
        img = img[0]
    clamped = img.detach().cpu().float().clamp(0, 1)
    return (clamped.numpy() * 255.0).round().astype(np.uint8)
327
+
328
+
329
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 HWC image into a batched ComfyUI IMAGE tensor [1,H,W,C]."""
    rgb = HWC3(img_u8).astype(np.float32) / 255.0
    return torch.from_numpy(rgb)[None, ...]
333
+
334
+
335
+ # --------------------------------------------------------------------------------------
336
+ # Pipeline loading (local-first, then zoe fallback)
337
+ # --------------------------------------------------------------------------------------
338
+
339
# Process-wide pipeline cache so repeated node executions reuse the already
# loaded transformers pipeline instead of reloading weights each run.
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}  # (model_source, device_str) -> pipeline
340
+
341
+
342
def _try_load_pipeline(model_source: str, device: torch.device, log):
    """Build (or fetch from cache) a transformers depth-estimation pipeline.

    Mirrors the Zoe node: the pipeline is created WITHOUT a device argument
    and the underlying model is moved to ``device`` afterwards.

    Raises RuntimeError when transformers failed to import.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    cached = _PIPE_CACHE.get(cache_key)
    if cached is not None:
        log(f"[SaliaDepth] Using cached pipeline for source='{model_source}' device='{device}'")
        return cached

    log(f"[SaliaDepth] Creating pipeline(task='depth-estimation', model='{model_source}')")
    p = pipeline(task="depth-estimation", model=model_source)

    # Best-effort device placement, matching ZoeDetector.to().
    try:
        p.model = p.model.to(device)
        p.device = device  # Zoe code sets this; newer transformers uses torch.device internally
        log(f"[SaliaDepth] Moved pipeline model to device: {device}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not move pipeline model to device {device}: {repr(e)}")

    # Dump config details so a misloaded checkpoint is easy to spot in logs.
    try:
        cfg = p.model.config
        log(f"[SaliaDepth] Model class: {p.model.__class__.__name__}")
        log(f"[SaliaDepth] Config class: {cfg.__class__.__name__}")
        log(f"[SaliaDepth] Config model_type: {getattr(cfg, 'model_type', '')}")
        log(f"[SaliaDepth] Config _name_or_path: {getattr(cfg, '_name_or_path', '')}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not log model config: {repr(e)}")

    _PIPE_CACHE[cache_key] = p
    return p
378
+
379
+
380
def get_depth_pipeline(device: torch.device, log):
    """Resolve a depth pipeline with a local-first strategy.

    Order of attempts:
      1) ensure the assets/depth files exist (download if missing)
      2) load from the local model directory
      3) fall back to Intel/zoedepth-nyu-kitti
      4) give up and return None
    """
    # Surface HF cache locations up front — that is where the fallback
    # repo would be downloaded to.
    log("[SaliaDepth] ===== Hugging Face cache info (fallback path) =====")
    for k, v in _hf_cache_info().items():
        if v:
            log(f"[SaliaDepth] {k} = {v}")
    log(f"[SaliaDepth] Zoe fallback repo id: {ZOE_FALLBACK_REPO_ID}")

    # Local-first attempt.
    if ensure_local_model_files(log):
        try:
            log(f"[SaliaDepth] Trying LOCAL model from directory: {MODEL_DIR}")
            return _try_load_pipeline(str(MODEL_DIR), device, log)
        except Exception as e:
            log(f"[SaliaDepth] Local model load FAILED: {repr(e)}")

    # Remote fallback attempt.
    try:
        log(f"[SaliaDepth] Trying ZOE fallback model: {ZOE_FALLBACK_REPO_ID}")
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device, log)
    except Exception as e:
        log(f"[SaliaDepth] Zoe fallback load FAILED: {repr(e)}")

    return None
411
+
412
+
413
+ # --------------------------------------------------------------------------------------
414
+ # Depth inference (Zoe-style)
415
+ # --------------------------------------------------------------------------------------
416
+
417
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    log,
    upscale_method: str = "INTER_CUBIC"
) -> np.ndarray:
    """Run depth estimation with Zoe-style pre/post-processing.

    Closely mirrors ZoeDetector.__call__:
      * detect_resolution == -1: keep original size, pad to multiples of 64
      * otherwise: resize so the min side equals detect_resolution, then pad
      * normalize depth between its 2nd and 85th percentile, invert,
        and scale to uint8

    Args:
        pipe: transformers depth-estimation pipeline.
        input_rgb_u8: uint8 RGB image (HWC).
        detect_resolution: target min-side resolution, or -1 for no resize.
        log: logging callback accepting a single string.
        upscale_method: cv2 interpolation name used when upscaling.

    Returns:
        uint8 RGB depth map at the pre-padding working resolution.
    """
    # detect_resolution:
    # - if -1: keep original but pad-to-64
    # - else: min-side resize to detect_resolution, then pad-to-64
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
        log(f"[SaliaDepth] Preprocess: resolution=-1 (no resize), padded to 64. work={work_img.shape}")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
            log=log
        )
        log(f"[SaliaDepth] Preprocess: min-side resized to {detect_resolution}, padded to 64. work={work_img.shape}")

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
    depth = result["depth"]

    # FIX: the previous isinstance(depth, Image.Image) branch was dead code —
    # both branches were identical.  np.array handles PIL images and
    # array-likes the same way.
    depth_array = np.array(depth, dtype=np.float32)

    # EXACT normalization like the Zoe reference implementation.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    log(f"[SaliaDepth] Depth raw stats: shape={depth_array.shape} vmin(p2)={vmin:.6f} vmax(p85)={vmax:.6f} mean={float(depth_array.mean()):.6f}")

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        # avoid division by zero on constant depth maps; log it
        log("[SaliaDepth] WARN: vmax==vmin; forcing denom epsilon to avoid NaNs.")
        denom = 1e-6
    depth_array = depth_array / denom

    # EXACT invert like the Zoe reference implementation.
    depth_array = 1.0 - depth_array

    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    log(f"[SaliaDepth] Output (post-remove_pad): {detected_map.shape} dtype={detected_map.dtype}")
    return detected_map
478
+
479
+
480
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int, log) -> np.ndarray:
    """Resize a depth map back to the original input size (w0 x h0).

    Prefers cv2 bilinear resize; falls back to PIL when cv2 is unavailable
    or fails for any reason.
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception as e:
        log(f"[SaliaDepth] WARN: cv2 resize failed ({repr(e)}); using PIL.")
        fallback = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(fallback, dtype=np.uint8)
494
+
495
+
496
+ # --------------------------------------------------------------------------------------
497
+ # ComfyUI Node
498
+ # --------------------------------------------------------------------------------------
499
+
500
class Salia_Depth_Preprocessor:
    """ComfyUI node: depth-map preprocessor backed by a transformers pipeline.

    Loads a depth model (local assets/depth first, Zoe fallback second),
    runs Zoe-style depth estimation per batch item, and returns the depth
    maps together with the full diagnostic log as a STRING output.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "image": ("IMAGE",),
                # note: default -1, min -1
                "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
            }
        }

    # 2 outputs: image + log string
    RETURN_TYPES = ("IMAGE", "STRING")
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, resolution=-1):
        """Run depth estimation on a (possibly batched) ComfyUI IMAGE.

        Args:
            image: Comfy IMAGE tensor, [H,W,C] or [B,H,W,C], float [0..1].
            resolution: min-side detect resolution; -1 keeps the input size.

        Returns:
            (IMAGE tensor [B,H,W,3], log string).  On pipeline failure the
            input image is passed through unchanged; per-item inference
            failures pass through that item only.
        """
        lines, log = _make_logger()
        log("[SaliaDepth] ==================================================")
        log("[SaliaDepth] SaliaDepthPreprocessor starting")
        log(f"[SaliaDepth] resolution input = {resolution}")

        # Get torch device
        try:
            device = model_management.get_torch_device()
        except Exception as e:
            device = torch.device("cpu")
            log(f"[SaliaDepth] WARN: model_management.get_torch_device failed: {repr(e)} -> using CPU")

        log(f"[SaliaDepth] torch device = {device}")

        # Load pipeline
        pipe = None
        try:
            pipe = get_depth_pipeline(device, log)
        except Exception as e:
            log(f"[SaliaDepth] ERROR: get_depth_pipeline crashed: {repr(e)}")
            pipe = None

        if pipe is None:
            # NOTE(review): this early return hands back `image` as-is, which
            # may still be 3-dim (un-batched) — confirm downstream nodes accept that.
            log("[SaliaDepth] FATAL: No pipeline available. Returning input image unchanged.")
            return (image, "\n".join(lines))

        # Batch support
        if image.ndim == 3:
            image = image.unsqueeze(0)

        outs = []
        for i in range(image.shape[0]):
            try:
                # Original dimensions
                h0 = int(image[i].shape[0])
                w0 = int(image[i].shape[1])
                c0 = int(image[i].shape[2])
                log(f"[SaliaDepth] ---- Batch index {i} input shape = ({h0},{w0},{c0}) ----")

                inp_u8 = comfy_tensor_to_u8(image[i])

                # RGBA rule (pre)
                rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
                had_rgba = alpha_u8 is not None
                log(f"[SaliaDepth] had_rgba={had_rgba}")

                # Run depth (Zoe-style)
                depth_rgb = depth_estimate_zoe_style(
                    pipe=pipe,
                    input_rgb_u8=rgb_for_depth,
                    detect_resolution=int(resolution),
                    log=log,
                    upscale_method="INTER_CUBIC"
                )

                # Resize back to original input size
                depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0, log=log)

                # RGBA rule (post)
                if had_rgba:
                    # Use original alpha at original size.
                    # If alpha size differs, resize alpha to match.
                    if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                        log("[SaliaDepth] Alpha size mismatch; resizing alpha to original size.")
                        try:
                            import cv2
                            alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                        except Exception:
                            pil_a = Image.fromarray(alpha_u8)
                            pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                            alpha_u8 = np.array(pil_a, dtype=np.uint8)

                    # "Put alpha on RGB turning it into RGBA, then put BLACK background behind it, then back to RGB"
                    depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
                    log("[SaliaDepth] Applied RGBA post-step (alpha + black background).")

                outs.append(u8_to_comfy_tensor(depth_rgb))

            except Exception as e:
                # NOTE(review): pass-through keeps the item's original channel
                # count; if it is RGBA (4ch) while successful items are 3ch,
                # torch.cat below would fail — confirm this mixed case is acceptable.
                log(f"[SaliaDepth] ERROR: Inference failed at batch index {i}: {repr(e)}")
                log("[SaliaDepth] Passing through original input image for this batch item.")
                outs.append(image[i].unsqueeze(0))

        out = torch.cat(outs, dim=0)
        log("[SaliaDepth] Done.")
        return (out, "\n".join(lines))
603
+
604
+
605
# ComfyUI registration: internal node id -> implementation class.
NODE_CLASS_MAPPINGS = {
    "SaliaDepthPreprocessor": Salia_Depth_Preprocessor
}

# ComfyUI registration: internal node id -> human-readable menu label.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SaliaDepthPreprocessor": "Salia Depth (local assets/depth + logs)"
}