saliacoel commited on
Commit
f1929d6
·
verified ·
1 Parent(s): 1cc0ed9

Upload salia_depth.py

Browse files
Files changed (1) hide show
  1. salia_depth.py +483 -221
salia_depth.py CHANGED
@@ -1,10 +1,8 @@
1
- from __future__ import annotations
2
-
3
  import os
4
  import shutil
5
  import urllib.request
6
  from pathlib import Path
7
- from typing import Dict, Tuple, Optional, List
8
 
9
  import numpy as np
10
  import torch
@@ -12,248 +10,492 @@ from PIL import Image
12
 
13
  import comfy.model_management as model_management
14
 
 
15
  try:
16
- import cv2
17
- except Exception:
18
- cv2 = None
 
19
 
20
 
21
- # -----------------------------
22
- # Paths / URLs (per your spec)
23
- # -----------------------------
24
 
25
- # nodes/Salia_Depth.py -> comfyui-salia_online/
26
- PLUGIN_ROOT = Path(__file__).resolve().parents[1]
 
27
 
28
- # MUST be assets/depth (not assets/assets, not assets/)
29
- ASSETS_DEPTH_DIR = PLUGIN_ROOT / "assets" / "depth"
 
30
 
31
- REQUIRED_FILES = ["config.json", "preprocessor_config.json", "model.safetensors"]
32
-
33
- HF_BASE = "https://huggingface.co/saliacoel/depth/resolve/main"
34
- FILE_URLS = {
35
- "config.json": f"{HF_BASE}/config.json",
36
- "preprocessor_config.json": f"{HF_BASE}/preprocessor_config.json",
37
- "model.safetensors": f"{HF_BASE}/model.safetensors",
38
  }
39
 
40
- # Fallback “zoe-path
41
- FALLBACK_ZOE_REPO = "Intel/zoedepth-nyu-kitti"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
- # -----------------------------
45
- # Global model cache
46
- # -----------------------------
47
- # key: (device_str, source_id) -> (processor, model)
48
- _MODEL_CACHE: Dict[Tuple[str, str], Tuple[object, object]] = {}
49
 
50
 
51
- # -----------------------------
52
- # Utility
53
- # -----------------------------
 
 
 
54
 
55
- def _ensure_dir(p: Path) -> None:
56
- p.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
57
 
 
58
 
59
- def _file_ok(p: Path) -> bool:
60
- # existence + non-empty is a good baseline against partial downloads
61
- return p.exists() and p.is_file() and p.stat().st_size > 0
62
 
 
 
 
63
 
64
- def _have_local_files() -> bool:
65
- return all(_file_ok(ASSETS_DEPTH_DIR / f) for f in REQUIRED_FILES)
66
 
67
 
68
- def _download_file(url: str, dst: Path, timeout: int = 60) -> None:
69
  """
70
- Download url -> dst atomically (tmp + replace).
71
- Raises on failure.
72
  """
73
- _ensure_dir(dst.parent)
74
  tmp = dst.with_suffix(dst.suffix + ".tmp")
75
 
76
- req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-Salia-Depth/1.0"})
 
 
 
 
 
 
77
  with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
78
  shutil.copyfileobj(r, f)
79
 
80
- if not _file_ok(tmp):
81
- raise RuntimeError(f"Downloaded file is empty/corrupt: {tmp}")
82
 
83
- os.replace(tmp, dst)
84
 
85
-
86
- def _ensure_local_model_files() -> bool:
87
  """
88
- Ensure the 3 required files exist in assets/depth.
89
- Returns True if available afterwards, False if download failed.
90
  """
91
- _ensure_dir(ASSETS_DEPTH_DIR)
92
-
93
- # already present
94
- if _have_local_files():
 
 
 
 
 
 
 
 
 
 
 
95
  return True
96
 
97
- # try download missing ones
 
98
  try:
99
- for fname in REQUIRED_FILES:
100
- dst = ASSETS_DEPTH_DIR / fname
101
- if not _file_ok(dst):
102
- _download_file(FILE_URLS[fname], dst)
103
- return _have_local_files()
 
 
 
 
 
 
104
  except Exception as e:
105
- print(f"[SaliaDepth] Download from saliacoel/depth failed: {e}")
106
  return False
107
 
108
 
109
- def _resize_max_side_uint8(img_u8: np.ndarray, max_side: int) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  """
111
- Resize uint8 HWC so that max(H,W) == max_side, keep aspect ratio.
112
- If max_side <= 0 or already matches, returns original.
 
 
 
113
  """
114
- if max_side <= 0:
115
- return img_u8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- h, w = img_u8.shape[:2]
118
- cur_max = max(h, w)
119
- if cur_max == 0 or cur_max == max_side:
120
- return img_u8
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
- scale = float(max_side) / float(cur_max)
123
- new_w = max(1, int(round(w * scale)))
124
- new_h = max(1, int(round(h * scale)))
125
 
126
- if cv2 is not None:
127
- interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
128
- return cv2.resize(img_u8, (new_w, new_h), interpolation=interp)
129
 
130
- # PIL fallback
131
- pil = Image.fromarray(img_u8)
132
- resample = Image.Resampling.LANCZOS if scale < 1 else Image.Resampling.BICUBIC
133
- pil = pil.resize((new_w, new_h), resample=resample)
134
- return np.array(pil, dtype=np.uint8)
135
 
136
 
137
- def _depth_to_hint_rgb(depth_2d: np.ndarray) -> np.ndarray:
138
  """
139
- Normalize depth to a ControlNet-style grayscale RGB hint.
140
- Uses percentile normalization (2..85) and inverts.
141
  """
142
- d = depth_2d.astype(np.float32)
143
- if not np.isfinite(d).all():
144
- d = np.nan_to_num(d, nan=0.0, posinf=0.0, neginf=0.0)
 
145
 
146
- vmin = np.percentile(d, 2)
147
- vmax = np.percentile(d, 85)
148
- denom = max(vmax - vmin, 1e-6)
149
 
150
- dn = (d - vmin) / denom
151
- dn = np.clip(dn, 0.0, 1.0)
152
- dn = 1.0 - dn
153
 
154
- u8 = (dn * 255.0).round().clip(0, 255).astype(np.uint8)
155
- return np.stack([u8, u8, u8], axis=-1)
156
 
 
 
 
157
 
158
- def _comfy_tensor_to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
159
  """
160
- ComfyUI IMAGE: float [0..1], shape [H,W,3]
161
- -> uint8 HWC
162
  """
163
- x = img.detach()
164
- if x.is_cuda:
165
- x = x.cpu()
166
- x = x.float().clamp(0, 1).numpy()
167
- return (x * 255.0).round().clip(0, 255).astype(np.uint8)
168
-
169
-
170
- def _uint8_hwc_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
 
 
 
 
171
  """
172
- uint8 HWC -> float32 tensor HWC [0..1]
 
 
 
 
173
  """
174
- return torch.from_numpy(img_u8.astype(np.float32) / 255.0)
 
 
 
175
 
176
 
177
- def _post_process_depth(processor, outputs, target_h: int, target_w: int) -> np.ndarray:
 
 
 
 
178
  """
179
- Transformers API compatibility shim.
180
- Some versions use target_sizes, some source_sizes.
181
- Returns depth as float32 HxW.
182
  """
183
- # Try the most common signature first
184
- try:
185
- post = processor.post_process_depth_estimation(outputs, target_sizes=[(target_h, target_w)])
186
- except TypeError:
187
- post = processor.post_process_depth_estimation(outputs, source_sizes=[(target_h, target_w)])
188
 
189
- # expected: list[{"predicted_depth": tensor[H,W]}]
190
- depth_t = post[0]["predicted_depth"]
191
- return depth_t.detach().float().cpu().numpy()
192
 
 
 
 
 
193
 
194
- def _load_zoedepth_from_local(device: torch.device):
195
- """
196
- Load ZoeDepth from ASSETS_DEPTH_DIR (offline).
197
- """
198
- from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
199
 
200
- key = (str(device), f"local::{ASSETS_DEPTH_DIR}")
201
- if key in _MODEL_CACHE:
202
- return _MODEL_CACHE[key]
203
 
204
- processor = AutoImageProcessor.from_pretrained(str(ASSETS_DEPTH_DIR), local_files_only=True)
205
- model = ZoeDepthForDepthEstimation.from_pretrained(str(ASSETS_DEPTH_DIR), local_files_only=True)
206
- model.eval().to(device)
207
 
208
- _MODEL_CACHE[key] = (processor, model)
209
- return processor, model
210
 
211
-
212
- def _load_zoedepth_fallback(device: torch.device):
213
  """
214
- Load ZoeDepth from HF (zoe-path fallback).
 
215
  """
216
- from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
 
 
 
 
 
 
217
 
218
- key = (str(device), f"hf::{FALLBACK_ZOE_REPO}")
219
- if key in _MODEL_CACHE:
220
- return _MODEL_CACHE[key]
 
 
 
 
 
 
 
221
 
222
- processor = AutoImageProcessor.from_pretrained(FALLBACK_ZOE_REPO)
223
- model = ZoeDepthForDepthEstimation.from_pretrained(FALLBACK_ZOE_REPO)
224
- model.eval().to(device)
 
 
 
 
 
 
225
 
226
- _MODEL_CACHE[key] = (processor, model)
227
- return processor, model
228
 
229
 
230
- def _get_model(device: torch.device):
231
  """
232
- 1) Try local assets/depth (download if missing)
233
- 2) If that fails -> zoe-path fallback
234
- 3) If that fails -> return None
 
235
  """
 
 
 
 
 
 
 
236
  # Local-first
 
 
 
 
 
 
 
 
 
237
  try:
238
- if _ensure_local_model_files():
239
- try:
240
- return _load_zoedepth_from_local(device)
241
- except Exception as e:
242
- print(f"[SaliaDepth] Local load failed (assets/depth). Will fallback to zoe-path. Error: {e}")
243
  except Exception as e:
244
- print(f"[SaliaDepth] Local ensure/load unexpected error. Fallback to zoe-path. Error: {e}")
245
 
246
- # Fallback: zoe-path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  try:
248
- return _load_zoedepth_fallback(device)
 
 
249
  except Exception as e:
250
- print(f"[SaliaDepth] Zoe fallback load failed. Will passthrough image. Error: {e}")
251
- return None
 
 
252
 
253
 
254
- # -----------------------------
255
  # ComfyUI Node
256
- # -----------------------------
257
 
258
  class Salia_Depth_Preprocessor:
259
  @classmethod
@@ -261,83 +503,103 @@ class Salia_Depth_Preprocessor:
261
  return {
262
  "required": {
263
  "image": ("IMAGE",),
264
- # note 5: default -1, min -1
265
  "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
266
  }
267
  }
268
 
269
- RETURN_TYPES = ("IMAGE",)
 
270
  FUNCTION = "execute"
271
  CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
272
 
273
- def execute(self, image: torch.Tensor, resolution: int = -1):
274
- """
275
- If anything fails:
276
- - return (image,) passthrough
277
- """
278
- # Basic shape validation; if weird, passthrough
 
279
  try:
280
- if image.dim() != 4 or image.shape[-1] != 3:
281
- print(f"[SaliaDepth] Unexpected input IMAGE shape {tuple(image.shape)}; passthrough.")
282
- return (image,)
283
- except Exception:
284
- return (image,)
285
 
286
- device = model_management.get_torch_device()
287
 
288
- model_pack = _get_model(device)
289
- if model_pack is None:
290
- return (image,)
 
 
 
 
291
 
292
- processor, model = model_pack
 
 
293
 
294
- outs: List[torch.Tensor] = []
 
 
295
 
296
- for b in range(image.shape[0]):
 
297
  try:
298
- # input in original size
299
- img_u8 = _comfy_tensor_to_uint8_hwc(image[b])
300
- h0, w0 = img_u8.shape[0], img_u8.shape[1]
301
-
302
- # note 5: if -1, use bigger side (max(w,h))
303
- max_side = max(w0, h0) if resolution == -1 else int(resolution)
304
-
305
- # resize for inference (max side rule)
306
- img_inf = _resize_max_side_uint8(img_u8, max_side=max_side)
307
- pil = Image.fromarray(img_inf)
308
-
309
- # preprocess
310
- inputs = processor(images=pil, return_tensors="pt")
311
- pixel_values = inputs["pixel_values"].to(device)
312
-
313
- with torch.inference_mode():
314
- outputs = model(pixel_values=pixel_values)
315
-
316
- # postprocess back to inference image size
317
- depth_np = _post_process_depth(processor, outputs, pil.height, pil.width)
318
-
319
- # depth -> grayscale RGB hint
320
- hint_rgb = _depth_to_hint_rgb(depth_np)
321
-
322
- # resize hint back to original size
323
- if hint_rgb.shape[0] != h0 or hint_rgb.shape[1] != w0:
324
- if cv2 is not None:
325
- hint_rgb = cv2.resize(hint_rgb, (w0, h0), interpolation=cv2.INTER_CUBIC)
326
- else:
327
- hint_rgb = np.array(
328
- Image.fromarray(hint_rgb).resize((w0, h0), resample=Image.Resampling.BICUBIC),
329
- dtype=np.uint8
330
- )
331
-
332
- outs.append(_uint8_hwc_to_comfy_tensor(hint_rgb))
 
 
 
 
 
 
 
 
 
333
 
334
  except Exception as e:
335
- # Per-image failure -> passthrough that image (keeps batch size consistent)
336
- print(f"[SaliaDepth] Inference failed on batch index {b}; passthrough that frame. Error: {e}")
337
- outs.append(image[b].detach().cpu() if image[b].is_cuda else image[b])
338
 
339
- out_batch = torch.stack(outs, dim=0)
340
- return (out_batch,)
 
341
 
342
 
343
  NODE_CLASS_MAPPINGS = {
@@ -345,5 +607,5 @@ NODE_CLASS_MAPPINGS = {
345
  }
346
 
347
  NODE_DISPLAY_NAME_MAPPINGS = {
348
- "SaliaDepthPreprocessor": "Salia Depth"
349
  }
 
 
 
1
  import os
2
  import shutil
3
  import urllib.request
4
  from pathlib import Path
5
+ from typing import Dict, Tuple, Any, Optional, List
6
 
7
  import numpy as np
8
  import torch
 
10
 
11
  import comfy.model_management as model_management
12
 
13
# transformers is required for the depth-estimation pipeline; defer any
# import failure so the node can report it at run time instead of breaking
# ComfyUI's node discovery.
try:
    from transformers import pipeline
except Exception as e:
    pipeline = None
    _TRANSFORMERS_IMPORT_ERROR = e


# --------------------------------------------------------------------------------------
# Paths / sources
# --------------------------------------------------------------------------------------

# This file: comfyui-salia_online/nodes/Salia_Depth.py
# Plugin root: comfyui-salia_online/
PLUGIN_ROOT = Path(__file__).resolve().parent.parent

# Requested local path: assets/depth
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Required model files mapped to their download URLs.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# "zoe-path" fallback
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
41
+
42
+
43
+ # --------------------------------------------------------------------------------------
44
+ # Logging helpers
45
+ # --------------------------------------------------------------------------------------
46
+
47
+ def _make_logger() -> Tuple[List[str], Any]:
48
+ lines: List[str] = []
49
+
50
+ def log(msg: str):
51
+ # console
52
+ try:
53
+ print(msg)
54
+ except Exception:
55
+ pass
56
+ # UI string
57
+ lines.append(str(msg))
58
+
59
+ return lines, log
60
+
61
+
62
+ def _fmt_bytes(n: Optional[int]) -> str:
63
+ if n is None:
64
+ return "?"
65
+ # simple readable
66
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
67
+ if n < 1024:
68
+ return f"{n:.0f}{unit}"
69
+ n /= 1024.0
70
+ return f"{n:.1f}PB"
71
 
72
 
73
+ def _file_size(path: Path) -> Optional[int]:
74
+ try:
75
+ return path.stat().st_size
76
+ except Exception:
77
+ return None
78
 
79
 
80
+ def _hf_cache_info() -> Dict[str, str]:
81
+ info: Dict[str, str] = {}
82
+ info["env.HF_HOME"] = os.environ.get("HF_HOME", "")
83
+ info["env.HF_HUB_CACHE"] = os.environ.get("HF_HUB_CACHE", "")
84
+ info["env.TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "")
85
+ info["env.HUGGINGFACE_HUB_CACHE"] = os.environ.get("HUGGINGFACE_HUB_CACHE", "")
86
 
87
+ try:
88
+ from huggingface_hub import constants as hf_constants
89
+ # These exist in most hub versions:
90
+ info["huggingface_hub.constants.HF_HOME"] = str(getattr(hf_constants, "HF_HOME", ""))
91
+ info["huggingface_hub.constants.HF_HUB_CACHE"] = str(getattr(hf_constants, "HF_HUB_CACHE", ""))
92
+ except Exception:
93
+ pass
94
 
95
+ return info
96
 
 
 
 
97
 
98
+ # --------------------------------------------------------------------------------------
99
+ # Download helpers
100
+ # --------------------------------------------------------------------------------------
101
 
102
def _have_required_files() -> bool:
    """True when every required model file exists locally and is non-empty.

    A bare exists() check would accept zero-byte leftovers from an
    interrupted download; the size > 0 guard (restoring the previous
    revision's `_file_ok` behavior) forces the caller down the download /
    fallback path instead of attempting to load a corrupt local model.
    """
    for name in REQUIRED_FILES:
        fpath = MODEL_DIR / name
        try:
            if not fpath.is_file() or fpath.stat().st_size <= 0:
                return False
        except OSError:
            # stat() race / permission problem -> treat as missing
            return False
    return True
104
 
105
 
106
+ def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
107
  """
108
+ Download with atomic temp rename.
 
109
  """
110
+ dst.parent.mkdir(parents=True, exist_ok=True)
111
  tmp = dst.with_suffix(dst.suffix + ".tmp")
112
 
113
+ if tmp.exists():
114
+ try:
115
+ tmp.unlink()
116
+ except Exception:
117
+ pass
118
+
119
+ req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
120
  with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
121
  shutil.copyfileobj(r, f)
122
 
123
+ tmp.replace(dst)
 
124
 
 
125
 
126
def ensure_local_model_files(log) -> bool:
    """Make sure MODEL_DIR holds the three required model files.

    Always logs the expected local paths and remote URLs, downloads whatever
    is missing, and returns True when all files are present afterwards.
    """
    # Report expected locations + URLs up front, even when nothing gets downloaded.
    log("[SaliaDepth] ===== Local model file check =====")
    log(f"[SaliaDepth] Plugin root: {PLUGIN_ROOT}")
    log(f"[SaliaDepth] Local model dir (on drive): {MODEL_DIR}")

    for fname, url in REQUIRED_FILES.items():
        fpath = MODEL_DIR / fname
        present = fpath.exists()
        size = _file_size(fpath) if present else None
        log(f"[SaliaDepth] - {fname}")
        log(f"[SaliaDepth] local path: {fpath} exists={present} size={_fmt_bytes(size)}")
        log(f"[SaliaDepth] remote url : {url}")

    if _have_required_files():
        log("[SaliaDepth] All required local files already exist. No download needed.")
        return True

    log("[SaliaDepth] One or more local files missing. Attempting download...")

    try:
        for fname, url in REQUIRED_FILES.items():
            fpath = MODEL_DIR / fname
            if fpath.exists():
                # Skip files already on disk; only fetch what is missing.
                continue
            log(f"[SaliaDepth] Downloading '{fname}' -> '{fpath}'")
            _download_url_to_file(url, fpath)
            log(f"[SaliaDepth] Downloaded '{fname}' size={_fmt_bytes(_file_size(fpath))}")

        ok = _have_required_files()
        log(f"[SaliaDepth] Download finished. ok={ok}")
        return ok
    except Exception as e:
        log(f"[SaliaDepth] Download failed with error: {repr(e)}")
        return False
165
 
166
 
167
+ # --------------------------------------------------------------------------------------
168
+ # Exact Zoe-style preprocessing helpers (copied/adapted from your snippet)
169
+ # --------------------------------------------------------------------------------------
170
+
171
def HWC3(x: np.ndarray) -> np.ndarray:
    """Coerce a uint8 image to 3-channel HWC.

    Grayscale is replicated across channels; RGBA is alpha-composited
    over a white background. Asserts mirror the upstream ControlNet helper.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.concatenate([x] * 3, axis=2)
    # channels == 4: composite RGBA over white.
    rgb = x[:, :, :3].astype(np.float32)
    a = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * a + 255.0 * (1.0 - a)
    return blended.clip(0, 255).astype(np.uint8)
188
+
189
+
190
def pad64(x: int) -> int:
    """Pixels required to round *x* up to the next multiple of 64."""
    return (64 - int(x) % 64) % 64
192
+
193
+
194
def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return a contiguous, private copy of *x* (defensive, mirrors upstream Zoe code)."""
    contiguous = np.ascontiguousarray(x.copy())
    return contiguous.copy()
196
+
197
+
198
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
    log=None
) -> Tuple[np.ndarray, Any]:
    """Min-side resize to *resolution*, then pad H/W up to multiples of 64.

    Mirrors the upstream Zoe preprocessor: k = resolution / min(H, W),
    resize to the scaled extent, edge-pad, and return
    (padded_image, remove_pad) where remove_pad crops a same-padded array
    back to the resized (unpadded) extent.
    """
    # cv2 is preferred so results match the reference implementation.
    try:
        import cv2
    except Exception:
        cv2 = None
        if log:
            log("[SaliaDepth] WARN: cv2 not available; resizing will use PIL fallback (may change results).")

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # Non-positive resolution: hand the image back untouched with a no-op
        # crop (the -1 "keep original" path pads separately via pad_only_to_64).
        return img, (lambda x: x)

    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        method_table = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        chosen = method_table.get(upscale_method, cv2.INTER_CUBIC)
        # Downscales always use INTER_AREA, matching the reference code.
        img = cv2.resize(img, (W_target, H_target), interpolation=chosen if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback when cv2 is missing.
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        img = np.array(pil.resize((W_target, H_target), resample=resample), dtype=np.uint8)

    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the resized (pre-padding) size.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
 
 
 
 
261
 
262
 
263
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """Keep the original resolution but edge-pad H/W to multiples of 64.

    Used for resolution == -1; the returned remove_pad crops a same-padded
    array back to the original size.
    """
    img = HWC3(img_u8)
    height, width, _ = img.shape
    pad_h, pad_w = pad64(height), pad64(width)
    padded = np.pad(img, [[0, pad_h], [0, pad_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
 
 
277
 
 
 
278
 
279
+ # --------------------------------------------------------------------------------------
280
+ # RGBA rules (as you requested)
281
+ # --------------------------------------------------------------------------------------
282
 
283
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Split an image into (RGB, alpha).

    RGBA input -> RGB composited over a WHITE background plus the alpha plane.
    Anything else -> HWC3-normalized RGB and None for alpha.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        # Force to RGB; no alpha to carry along.
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    color = rgba[:, :, 0:3].astype(np.float32)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    over_white = (color * alpha + 255.0 * (1.0 - alpha)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
297
+
298
+
299
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Re-attach *alpha_u8* to the depth map and flatten it onto black.

    Conceptually: depth RGB + alpha -> RGBA -> composite over BLACK -> RGB,
    which reduces to multiplying the depth values by the alpha fraction.
    """
    rgb = HWC3(depth_rgb_u8).astype(np.float32)
    frac = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb * frac).clip(0, 255).astype(np.uint8)
311
 
312
 
313
+ # --------------------------------------------------------------------------------------
314
+ # ComfyUI conversion helpers
315
+ # --------------------------------------------------------------------------------------
316
+
317
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a Comfy IMAGE tensor (float [0..1], [H,W,C] or [B,H,W,C]) to uint8 HWC.

    For batched input only the first frame is converted.
    """
    frame = img[0] if img.ndim == 4 else img
    arr = frame.detach().cpu().float().clamp(0, 1).numpy()
    return (arr * 255.0).round().astype(np.uint8)
327
 
 
 
 
328
 
329
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """uint8 HWC image -> Comfy IMAGE tensor [1,H,W,C], float32 in [0..1]."""
    rgb = HWC3(img_u8)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled)[None, ...]
333
 
 
 
 
 
 
334
 
335
+ # --------------------------------------------------------------------------------------
336
+ # Pipeline loading (local-first, then zoe fallback)
337
+ # --------------------------------------------------------------------------------------
338
 
339
+ _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
 
 
340
 
 
 
341
 
342
def _try_load_pipeline(model_source: str, device: torch.device, log):
    """Build (or fetch from cache) a transformers depth-estimation pipeline.

    Mirrors the Zoe node: the pipeline is created WITHOUT device=... and the
    model is moved to *device* afterwards. Raises when transformers failed to
    import or pipeline construction fails.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    cached = _PIPE_CACHE.get(cache_key)
    if cached is not None:
        log(f"[SaliaDepth] Using cached pipeline for source='{model_source}' device='{device}'")
        return cached

    log(f"[SaliaDepth] Creating pipeline(task='depth-estimation', model='{model_source}')")
    pipe = pipeline(task="depth-estimation", model=model_source)

    # Move the model like ZoeDetector.to(); tolerate failure (model stays put).
    try:
        pipe.model = pipe.model.to(device)
        pipe.device = device  # Zoe code sets this; newer transformers uses torch.device internally
        log(f"[SaliaDepth] Moved pipeline model to device: {device}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not move pipeline model to device {device}: {repr(e)}")

    # Dump config details to help debug model-source mixups.
    try:
        cfg = pipe.model.config
        log(f"[SaliaDepth] Model class: {pipe.model.__class__.__name__}")
        log(f"[SaliaDepth] Config class: {cfg.__class__.__name__}")
        log(f"[SaliaDepth] Config model_type: {getattr(cfg, 'model_type', '')}")
        log(f"[SaliaDepth] Config _name_or_path: {getattr(cfg, '_name_or_path', '')}")
    except Exception as e:
        log(f"[SaliaDepth] WARN: Could not log model config: {repr(e)}")

    _PIPE_CACHE[cache_key] = pipe
    return pipe
378
 
379
 
380
def get_depth_pipeline(device: torch.device, log):
    """Resolve a depth pipeline: local assets/depth first, then the Zoe fallback.

    Steps:
      1) ensure local files exist (downloading if missing)
      2) try loading from the local directory
      3) fall back to ZOE_FALLBACK_REPO_ID
      4) return None when everything fails
    """
    # Log HF cache locations first — that is where the fallback repo lands on disk.
    log("[SaliaDepth] ===== Hugging Face cache info (fallback path) =====")
    for name, value in _hf_cache_info().items():
        if value:
            log(f"[SaliaDepth] {name} = {value}")
    log(f"[SaliaDepth] Zoe fallback repo id: {ZOE_FALLBACK_REPO_ID}")

    # Local-first
    if ensure_local_model_files(log):
        try:
            log(f"[SaliaDepth] Trying LOCAL model from directory: {MODEL_DIR}")
            return _try_load_pipeline(str(MODEL_DIR), device, log)
        except Exception as e:
            log(f"[SaliaDepth] Local model load FAILED: {repr(e)}")

    # Fallback
    try:
        log(f"[SaliaDepth] Trying ZOE fallback model: {ZOE_FALLBACK_REPO_ID}")
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device, log)
    except Exception as e:
        log(f"[SaliaDepth] Zoe fallback load FAILED: {repr(e)}")

    return None
411
+
412
+
413
+ # --------------------------------------------------------------------------------------
414
+ # Depth inference (Zoe-style)
415
+ # --------------------------------------------------------------------------------------
416
+
417
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    log,
    upscale_method: str = "INTER_CUBIC"
) -> np.ndarray:
    """Run the depth pipeline Zoe-style and return a uint8 RGB depth map.

    Preprocessing:
      * detect_resolution == -1 -> keep the original size, pad to multiples of 64
      * otherwise               -> min-side resize to detect_resolution, then pad
    Postprocessing matches the reference Zoe code exactly: percentile (2..85)
    normalization, inversion, byte quantization, then pad removal.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
        log(f"[SaliaDepth] Preprocess: resolution=-1 (no resize), padded to 64. work={work_img.shape}")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
            log=log
        )
        log(f"[SaliaDepth] Preprocess: min-side resized to {detect_resolution}, padded to 64. work={work_img.shape}")

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # The pipeline may return a PIL image or an array-like; np.array handles
    # both, so the previous isinstance branch (whose two arms were identical)
    # is collapsed into a single conversion.
    depth_array = np.array(depth, dtype=np.float32)

    # EXACT normalization like the reference Zoe code.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    log(f"[SaliaDepth] Depth raw stats: shape={depth_array.shape} vmin(p2)={vmin:.6f} vmax(p85)={vmax:.6f} mean={float(depth_array.mean()):.6f}")

    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        # Flat depth map: avoid division by zero (output becomes uniform).
        log("[SaliaDepth] WARN: vmax==vmin; forcing denom epsilon to avoid NaNs.")
        denom = 1e-6
    depth_array = depth_array / denom

    # EXACT invert like the reference Zoe code: near -> bright.
    depth_array = 1.0 - depth_array

    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    detected_map = remove_pad(HWC3(depth_image))
    log(f"[SaliaDepth] Output (post-remove_pad): {detected_map.shape} dtype={detected_map.dtype}")
    return detected_map
478
+
479
+
480
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int, log) -> np.ndarray:
    """Resize the depth output back to the original (w0, h0) input size.

    Uses cv2 when available; otherwise falls back to PIL bilinear.
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception as e:
        log(f"[SaliaDepth] WARN: cv2 resize failed ({repr(e)}); using PIL.")
        pil = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil, dtype=np.uint8)
494
 
495
 
496
+ # --------------------------------------------------------------------------------------
497
  # ComfyUI Node
498
+ # --------------------------------------------------------------------------------------
499
 
500
  class Salia_Depth_Preprocessor:
501
  @classmethod
 
503
  return {
504
  "required": {
505
  "image": ("IMAGE",),
506
+ # note: default -1, min -1
507
  "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
508
  }
509
  }
510
 
511
+ # 2 outputs: image + log string
512
+ RETURN_TYPES = ("IMAGE", "STRING")
513
  FUNCTION = "execute"
514
  CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
515
 
516
+ def execute(self, image, resolution=-1):
517
+ lines, log = _make_logger()
518
+ log("[SaliaDepth] ==================================================")
519
+ log("[SaliaDepth] SaliaDepthPreprocessor starting")
520
+ log(f"[SaliaDepth] resolution input = {resolution}")
521
+
522
+ # Get torch device
523
  try:
524
+ device = model_management.get_torch_device()
525
+ except Exception as e:
526
+ device = torch.device("cpu")
527
+ log(f"[SaliaDepth] WARN: model_management.get_torch_device failed: {repr(e)} -> using CPU")
 
528
 
529
+ log(f"[SaliaDepth] torch device = {device}")
530
 
531
+ # Load pipeline
532
+ pipe = None
533
+ try:
534
+ pipe = get_depth_pipeline(device, log)
535
+ except Exception as e:
536
+ log(f"[SaliaDepth] ERROR: get_depth_pipeline crashed: {repr(e)}")
537
+ pipe = None
538
 
539
+ if pipe is None:
540
+ log("[SaliaDepth] FATAL: No pipeline available. Returning input image unchanged.")
541
+ return (image, "\n".join(lines))
542
 
543
+ # Batch support
544
+ if image.ndim == 3:
545
+ image = image.unsqueeze(0)
546
 
547
+ outs = []
548
+ for i in range(image.shape[0]):
549
  try:
550
+ # Original dimensions
551
+ h0 = int(image[i].shape[0])
552
+ w0 = int(image[i].shape[1])
553
+ c0 = int(image[i].shape[2])
554
+ log(f"[SaliaDepth] ---- Batch index {i} input shape = ({h0},{w0},{c0}) ----")
555
+
556
+ inp_u8 = comfy_tensor_to_u8(image[i])
557
+
558
+ # RGBA rule (pre)
559
+ rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
560
+ had_rgba = alpha_u8 is not None
561
+ log(f"[SaliaDepth] had_rgba={had_rgba}")
562
+
563
+ # Run depth (Zoe-style)
564
+ depth_rgb = depth_estimate_zoe_style(
565
+ pipe=pipe,
566
+ input_rgb_u8=rgb_for_depth,
567
+ detect_resolution=int(resolution),
568
+ log=log,
569
+ upscale_method="INTER_CUBIC"
570
+ )
571
+
572
+ # Resize back to original input size
573
+ depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0, log=log)
574
+
575
+ # RGBA rule (post)
576
+ if had_rgba:
577
+ # Use original alpha at original size.
578
+ # If alpha size differs, resize alpha to match.
579
+ if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
580
+ log("[SaliaDepth] Alpha size mismatch; resizing alpha to original size.")
581
+ try:
582
+ import cv2
583
+ alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
584
+ except Exception:
585
+ pil_a = Image.fromarray(alpha_u8)
586
+ pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
587
+ alpha_u8 = np.array(pil_a, dtype=np.uint8)
588
+
589
+ # "Put alpha on RGB turning it into RGBA, then put BLACK background behind it, then back to RGB"
590
+ depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
591
+ log("[SaliaDepth] Applied RGBA post-step (alpha + black background).")
592
+
593
+ outs.append(u8_to_comfy_tensor(depth_rgb))
594
 
595
  except Exception as e:
596
+ log(f"[SaliaDepth] ERROR: Inference failed at batch index {i}: {repr(e)}")
597
+ log("[SaliaDepth] Passing through original input image for this batch item.")
598
+ outs.append(image[i].unsqueeze(0))
599
 
600
+ out = torch.cat(outs, dim=0)
601
+ log("[SaliaDepth] Done.")
602
+ return (out, "\n".join(lines))
603
 
604
 
605
  NODE_CLASS_MAPPINGS = {
 
607
  }
608
 
609
  NODE_DISPLAY_NAME_MAPPINGS = {
610
+ "SaliaDepthPreprocessor": "Salia Depth (local assets/depth + logs)"
611
  }