File size: 21,766 Bytes
79a1f7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d625c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
import os
import shutil
import urllib.request
from pathlib import Path
from typing import Dict, Tuple, Any, Optional, List

import numpy as np
import torch
from PIL import Image

import comfy.model_management as model_management

# transformers is required for the depth-estimation pipeline. Defer an import
# failure to pipeline-creation time so the node module still loads and the
# error can be surfaced in the UI log instead of crashing ComfyUI startup.
try:
    from transformers import pipeline
except Exception as e:
    pipeline = None
    _TRANSFORMERS_IMPORT_ERROR = e  # re-raised by _try_load_pipeline with context


# --------------------------------------------------------------------------------------
# Paths / sources
# --------------------------------------------------------------------------------------

# This file: comfyui-salia_online/nodes/Salia_Depth.py
# Plugin root: comfyui-salia_online/
PLUGIN_ROOT = Path(__file__).resolve().parent.parent

# Local model directory: <plugin>/assets/depth.
# NOTE: the directory is created at import time (module-load side effect).
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Files required for transformers to load the model from MODEL_DIR, mapped to
# their download URLs. Checked/downloaded by ensure_local_model_files().
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# Hub repo used when the local files are missing or fail to load.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"


# --------------------------------------------------------------------------------------
# Logging helpers
# --------------------------------------------------------------------------------------

def _make_logger() -> Tuple[List[str], Any]:
    lines: List[str] = []

    def log(msg: str):
        # console
        try:
            print(msg)
        except Exception:
            pass
        # UI string
        lines.append(str(msg))

    return lines, log


def _fmt_bytes(n: Optional[int]) -> str:
    if n is None:
        return "?"
    # simple readable
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if n < 1024:
            return f"{n:.0f}{unit}"
        n /= 1024.0
    return f"{n:.1f}PB"


def _file_size(path: Path) -> Optional[int]:
    try:
        return path.stat().st_size
    except Exception:
        return None


def _hf_cache_info() -> Dict[str, str]:
    info: Dict[str, str] = {}
    info["env.HF_HOME"] = os.environ.get("HF_HOME", "")
    info["env.HF_HUB_CACHE"] = os.environ.get("HF_HUB_CACHE", "")
    info["env.TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "")
    info["env.HUGGINGFACE_HUB_CACHE"] = os.environ.get("HUGGINGFACE_HUB_CACHE", "")

    try:
        from huggingface_hub import constants as hf_constants
        # These exist in most hub versions:
        info["huggingface_hub.constants.HF_HOME"] = str(getattr(hf_constants, "HF_HOME", ""))
        info["huggingface_hub.constants.HF_HUB_CACHE"] = str(getattr(hf_constants, "HF_HUB_CACHE", ""))
    except Exception:
        pass

    return info


# --------------------------------------------------------------------------------------
# Download helpers
# --------------------------------------------------------------------------------------

def _have_required_files() -> bool:
    """True when every file listed in REQUIRED_FILES exists under MODEL_DIR."""
    for name in REQUIRED_FILES:
        if not (MODEL_DIR / name).exists():
            return False
    return True


def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """
    Fetch *url* into *dst*, writing through a ``.tmp`` file so the final
    rename is atomic and a crashed download never leaves a partial *dst*.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = dst.with_suffix(dst.suffix + ".tmp")

    # Clear a stale partial download from a previous run, if any.
    if tmp_path.exists():
        try:
            tmp_path.unlink()
        except Exception:
            pass

    request = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(request, timeout=timeout) as response:
        with open(tmp_path, "wb") as out_file:
            shutil.copyfileobj(response, out_file)

    tmp_path.replace(dst)


def ensure_local_model_files(log) -> bool:
    """
    Make sure assets/depth contains the three required model files.

    Always logs the expected local paths and remote URLs, downloads whatever
    is missing, and returns True when all files are present afterwards.
    """
    log("[SaliaDepth] ===== Local model file check =====")
    log(f"[SaliaDepth] Plugin root: {PLUGIN_ROOT}")
    log(f"[SaliaDepth] Local model dir (on drive): {MODEL_DIR}")

    # Report status of every required file up front, download or not.
    for name, url in REQUIRED_FILES.items():
        target = MODEL_DIR / name
        present = target.exists()
        byte_count = _file_size(target) if present else None
        log(f"[SaliaDepth]   - {name}")
        log(f"[SaliaDepth]       local path: {target}  exists={present}  size={_fmt_bytes(byte_count)}")
        log(f"[SaliaDepth]       remote url : {url}")

    if _have_required_files():
        log("[SaliaDepth] All required local files already exist. No download needed.")
        return True

    log("[SaliaDepth] One or more local files missing. Attempting download...")

    try:
        missing = [(n, u) for n, u in REQUIRED_FILES.items() if not (MODEL_DIR / n).exists()]
        for name, url in missing:
            target = MODEL_DIR / name
            log(f"[SaliaDepth] Downloading '{name}' -> '{target}'")
            _download_url_to_file(url, target)
            log(f"[SaliaDepth] Downloaded '{name}' size={_fmt_bytes(_file_size(target))}")

        all_present = _have_required_files()
        log(f"[SaliaDepth] Download finished. ok={all_present}")
        return all_present
    except Exception as e:
        log(f"[SaliaDepth] Download failed with error: {repr(e)}")
        return False


# --------------------------------------------------------------------------------------
# Exact Zoe-style preprocessing helpers (copied/adapted from your snippet)
# --------------------------------------------------------------------------------------

def HWC3(x: np.ndarray) -> np.ndarray:
    """Coerce a uint8 image to 3-channel HWC; RGBA is composited over white."""
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)

    if channels == 3:
        return x
    if channels == 1:
        # Replicate the single gray channel into R, G, B.
        return np.repeat(x, 3, axis=2)

    # channels == 4: alpha-composite over a white background.
    rgb = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)


def pad64(x: int) -> int:
    """Padding needed to round *x* up to the next multiple of 64 (0 if aligned)."""
    return (64 - (x % 64)) % 64


def safer_memory(x: np.ndarray) -> np.ndarray:
    """
    Return an independent, C-contiguous copy of *x*.

    The original chained ``ascontiguousarray(x.copy()).copy()``, which
    duplicates the buffer two or three times; a single C-ordered copy gives
    the same observable result (fresh memory, C-contiguous, same dtype/values)
    with one allocation.
    """
    return np.array(x, order="C")


def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
    log=None
) -> Tuple[np.ndarray, Any]:
    """
    EXACT behavior like your zoe.transformers.py:
      k = resolution / min(H,W)
      resize to (W_target, H_target)
      pad to multiple of 64
      return padded image and remove_pad() closure

    Args:
        input_image: uint8 image, HWC (or HW when skip_hwc3 is False).
        resolution: target size for the shorter side; <= 0 returns the image
            unchanged with an identity remove_pad (the -1 "keep size" path is
            handled separately by pad_only_to_64).
        upscale_method: cv2 interpolation name; only honored when upscaling
            (k > 1) and cv2 is importable.
        skip_hwc3: bypass the HWC3 channel normalization.
        mode: np.pad mode used for the pad-to-64 border.
        log: optional logging callback for warnings.

    Returns:
        (padded_image, remove_pad) where remove_pad crops a same-padded array
        back to the pre-pad (resized) size.
    """
    # prefer cv2 like original for matching results
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None
        if log:
            log("[SaliaDepth] WARN: cv2 not available; resizing will use PIL fallback (may change results).")

    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # keep original, but still pad to 64 (we will handle padding separately for -1 path)
        return img, (lambda x: x)

    # Scale so the SHORTER side hits `resolution`; both sides scale by k.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # INTER_AREA for downscaling (k <= 1), requested method for upscaling.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback. NOTE(review): this ignores `upscale_method` and uses
        # BICUBIC/LANCZOS, so results can differ slightly from the cv2 path.
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right only, so remove_pad can crop with simple [:H, :W].
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the pre-pad size and return a fresh contiguous copy.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad


def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """
    Keep the original resolution (resolution == -1 path) but pad height and
    width up to multiples of 64; remove_pad crops back to the original size.
    """
    rgb = HWC3(img_u8)
    height, width = rgb.shape[0], rgb.shape[1]

    padded = np.pad(
        rgb,
        [[0, pad64(height)], [0, pad64(width)], [0, 0]],
        mode=mode,
    )

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad


# --------------------------------------------------------------------------------------
# RGBA rules (as you requested)
# --------------------------------------------------------------------------------------

def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Split an image into (rgb, alpha):
      * RGBA input -> RGB composited over WHITE, plus the alpha channel (uint8).
      * anything else -> 3-channel RGB via HWC3, with alpha = None.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    color = rgba[:, :, 0:3].astype(np.float32)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    over_white = (color * alpha + 255.0 * (1.0 - alpha)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()


def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """
    Composite the depth RGB over BLACK using *alpha_u8* as coverage.

    Conceptually: attach alpha -> RGBA -> composite on black -> back to RGB,
    which reduces to depth_rgb * (alpha / 255).
    """
    rgb = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb * weight).clip(0, 255).astype(np.uint8)


# --------------------------------------------------------------------------------------
# ComfyUI conversion helpers
# --------------------------------------------------------------------------------------

def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """
    Convert a ComfyUI IMAGE tensor (float [0..1], [H,W,C] or [B,H,W,C])
    to a uint8 HWC array. Batched inputs contribute only their first frame.
    """
    frame = img[0] if img.ndim == 4 else img
    arr = frame.detach().cpu().float().clamp(0, 1).numpy()
    return (arr * 255.0).round().astype(np.uint8)


def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image to a ComfyUI IMAGE tensor [1,H,W,C] in [0..1]."""
    rgb = HWC3(img_u8)
    scaled = torch.from_numpy(rgb.astype(np.float32) / 255.0)
    return scaled.unsqueeze(0)


# --------------------------------------------------------------------------------------
# Pipeline loading (local-first, then zoe fallback)
# --------------------------------------------------------------------------------------

# Process-wide cache of constructed pipelines so repeated node executions
# reuse an already-loaded model instead of reloading weights every run.
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}  # (model_source, device_str) -> pipeline


def _try_load_pipeline(model_source: str, device: torch.device, log):
    """
    Build (or fetch from cache) a transformers depth-estimation pipeline.

    Mirrors the Zoe node: the pipeline is created WITHOUT a device argument
    and the model is moved to *device* afterwards.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    cached = _PIPE_CACHE.get(cache_key)
    if cached is not None:
        log(f"[SaliaDepth] Using cached pipeline for source='{model_source}' device='{device}'")
        return cached

    log(f"[SaliaDepth] Creating pipeline(task='depth-estimation', model='{model_source}')")
    pipe = pipeline(task="depth-estimation", model=model_source)

    # Move the model like ZoeDetector.to(); failure is non-fatal (CPU run).
    try:
        pipe.model = pipe.model.to(device)
        # Zoe code sets this; newer transformers uses torch.device internally.
        pipe.device = device
        log(f"[SaliaDepth] Moved pipeline model to device: {device}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not move pipeline model to device {device}: {repr(e)}")

    # Best-effort config dump to help debugging which model actually loaded.
    try:
        cfg = pipe.model.config
        log(f"[SaliaDepth] Model class: {pipe.model.__class__.__name__}")
        log(f"[SaliaDepth] Config class: {cfg.__class__.__name__}")
        log(f"[SaliaDepth] Config model_type: {getattr(cfg, 'model_type', '')}")
        log(f"[SaliaDepth] Config _name_or_path: {getattr(cfg, '_name_or_path', '')}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not log model config: {repr(e)}")

    _PIPE_CACHE[cache_key] = pipe
    return pipe


def get_depth_pipeline(device: torch.device, log):
    """
    Resolve a depth pipeline, local-first:
      1) ensure assets/depth files exist (downloading if missing)
      2) load from the local directory
      3) fall back to Intel/zoedepth-nyu-kitti
      4) return None when everything fails
    """
    # Log HF cache locations first — helps find where fallback downloads land.
    log("[SaliaDepth] ===== Hugging Face cache info (fallback path) =====")
    for name, value in _hf_cache_info().items():
        if value:
            log(f"[SaliaDepth] {name} = {value}")
    log(f"[SaliaDepth] Zoe fallback repo id: {ZOE_FALLBACK_REPO_ID}")

    # Local-first attempt.
    if ensure_local_model_files(log):
        try:
            log(f"[SaliaDepth] Trying LOCAL model from directory: {MODEL_DIR}")
            return _try_load_pipeline(str(MODEL_DIR), device, log)
        except Exception as e:
            log(f"[SaliaDepth] Local model load FAILED: {repr(e)}")

    # Hub fallback attempt.
    try:
        log(f"[SaliaDepth] Trying ZOE fallback model: {ZOE_FALLBACK_REPO_ID}")
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device, log)
    except Exception as e:
        log(f"[SaliaDepth] Zoe fallback load FAILED: {repr(e)}")

    return None


# --------------------------------------------------------------------------------------
# Depth inference (Zoe-style)
# --------------------------------------------------------------------------------------

def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    log,
    upscale_method: str = "INTER_CUBIC"
) -> np.ndarray:
    """
    Run depth estimation following the original ZoeDetector.__call__ logic.

    Args:
        pipe: transformers depth-estimation pipeline; called with a PIL image
            and expected to return a dict with a "depth" entry.
        input_rgb_u8: uint8 HWC RGB input image.
        detect_resolution: -1 keeps the original size (pad-to-64 only);
            otherwise the min side is resized to this value, then padded.
        log: logging callback taking a single string.
        upscale_method: cv2 interpolation name used when upscaling.

    Returns:
        uint8 RGB depth map cropped back (via remove_pad) to the unpadded size.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
        log(f"[SaliaDepth] Preprocess: resolution=-1 (no resize), padded to 64. work={work_img.shape}")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
            log=log
        )
        log(f"[SaliaDepth] Preprocess: min-side resized to {detect_resolution}, padded to 64. work={work_img.shape}")

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        # np.array handles both PIL.Image and array-like depth outputs, so the
        # original isinstance(depth, Image.Image) branch — which executed
        # identical code in both arms — is collapsed into one line.
        depth_array = np.array(result["depth"], dtype=np.float32)

        # EXACT percentile-based normalization like the Zoe code.
        vmin = float(np.percentile(depth_array, 2))
        vmax = float(np.percentile(depth_array, 85))

        log(f"[SaliaDepth] Depth raw stats: shape={depth_array.shape} vmin(p2)={vmin:.6f} vmax(p85)={vmax:.6f} mean={float(depth_array.mean()):.6f}")

        depth_array = depth_array - vmin
        denom = (vmax - vmin)
        if abs(denom) < 1e-12:
            # Constant depth map: avoid division by zero / NaNs.
            log("[SaliaDepth] WARN: vmax==vmin; forcing denom epsilon to avoid NaNs.")
            denom = 1e-6
        depth_array = depth_array / denom

        # EXACT invert like the Zoe code (near -> bright).
        depth_array = 1.0 - depth_array

        depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    log(f"[SaliaDepth] Output (post-remove_pad): {detected_map.shape} dtype={detected_map.dtype}")
    return detected_map


def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int, log) -> np.ndarray:
    """
    Resize the depth output back to the original input size (w0, h0).
    Prefers cv2 bilinear; falls back to PIL when cv2 is unavailable or fails.
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception as e:
        log(f"[SaliaDepth] WARN: cv2 resize failed ({repr(e)}); using PIL.")
        pil_img = Image.fromarray(depth_rgb_u8)
        pil_img = pil_img.resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil_img, dtype=np.uint8)


# --------------------------------------------------------------------------------------
# ComfyUI Node
# --------------------------------------------------------------------------------------

class Salia_Depth_Preprocessor:
    """
    ComfyUI node: depth-map preprocessor using a local assets/depth model with
    a ZoeDepth hub fallback. Returns (IMAGE, STRING) where the string is the
    full run log. On any failure the input image is passed through unchanged
    so a workflow never hard-crashes on this node.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare node inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "image": ("IMAGE",),
                # note: default -1, min -1 (-1 = keep original resolution)
                "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
            }
        }

    # 2 outputs: image + log string
    RETURN_TYPES = ("IMAGE", "STRING")
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, resolution=-1):
        """
        Run depth estimation on each image in the batch.

        Args:
            image: ComfyUI IMAGE tensor, [H,W,C] or [B,H,W,C], float [0..1].
            resolution: min-side detect resolution; -1 keeps the input size.

        Returns:
            (IMAGE tensor [B,H,W,C], newline-joined log string).
        """
        lines, log = _make_logger()
        log("[SaliaDepth] ==================================================")
        log("[SaliaDepth] SaliaDepthPreprocessor starting")
        log(f"[SaliaDepth] resolution input = {resolution}")

        # Get torch device (fall back to CPU if Comfy's manager is unavailable)
        try:
            device = model_management.get_torch_device()
        except Exception as e:
            device = torch.device("cpu")
            log(f"[SaliaDepth] WARN: model_management.get_torch_device failed: {repr(e)} -> using CPU")

        log(f"[SaliaDepth] torch device = {device}")

        # Load pipeline (local-first, hub fallback); never let loading raise.
        pipe = None
        try:
            pipe = get_depth_pipeline(device, log)
        except Exception as e:
            log(f"[SaliaDepth] ERROR: get_depth_pipeline crashed: {repr(e)}")
            pipe = None

        if pipe is None:
            log("[SaliaDepth] FATAL: No pipeline available. Returning input image unchanged.")
            return (image, "\n".join(lines))

        # Batch support: normalize single image [H,W,C] to [1,H,W,C]
        if image.ndim == 3:
            image = image.unsqueeze(0)

        outs = []
        for i in range(image.shape[0]):
            try:
                # Original dimensions (used to resize the result back)
                h0 = int(image[i].shape[0])
                w0 = int(image[i].shape[1])
                c0 = int(image[i].shape[2])
                log(f"[SaliaDepth] ---- Batch index {i} input shape = ({h0},{w0},{c0}) ----")

                inp_u8 = comfy_tensor_to_u8(image[i])

                # RGBA rule (pre): composite over white, keep alpha for later
                rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
                had_rgba = alpha_u8 is not None
                log(f"[SaliaDepth] had_rgba={had_rgba}")

                # Run depth (Zoe-style preprocessing + normalization)
                depth_rgb = depth_estimate_zoe_style(
                    pipe=pipe,
                    input_rgb_u8=rgb_for_depth,
                    detect_resolution=int(resolution),
                    log=log,
                    upscale_method="INTER_CUBIC"
                )

                # Resize back to original input size
                depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0, log=log)

                # RGBA rule (post): re-apply alpha over a black background
                if had_rgba:
                    # Use original alpha at original size.
                    # If alpha size differs, resize alpha to match.
                    if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                        log("[SaliaDepth] Alpha size mismatch; resizing alpha to original size.")
                        try:
                            import cv2
                            alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                        except Exception:
                            # PIL fallback when cv2 is unavailable
                            pil_a = Image.fromarray(alpha_u8)
                            pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                            alpha_u8 = np.array(pil_a, dtype=np.uint8)

                    # "Put alpha on RGB turning it into RGBA, then put BLACK background behind it, then back to RGB"
                    depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
                    log("[SaliaDepth] Applied RGBA post-step (alpha + black background).")

                outs.append(u8_to_comfy_tensor(depth_rgb))

            except Exception as e:
                # Per-item failure: pass the original frame through so the
                # rest of the batch still produces output.
                log(f"[SaliaDepth] ERROR: Inference failed at batch index {i}: {repr(e)}")
                log("[SaliaDepth] Passing through original input image for this batch item.")
                outs.append(image[i].unsqueeze(0))

        out = torch.cat(outs, dim=0)
        log("[SaliaDepth] Done.")
        return (out, "\n".join(lines))


# Registration tables read by ComfyUI at plugin import time.
NODE_CLASS_MAPPINGS = {
    "SaliaDepthPreprocessor": Salia_Depth_Preprocessor
}

# Human-readable node title shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SaliaDepthPreprocessor": "Salia Depth (local assets/depth + logs)"
}