saliacoel committed on
Commit
1cc0ed9
·
verified ·
1 Parent(s): 3d625c7

Upload salia_depth.py

Browse files
Files changed (1) hide show
  1. salia_depth.py +229 -227
salia_depth.py CHANGED
@@ -1,7 +1,10 @@
 
 
 
1
  import shutil
2
  import urllib.request
3
  from pathlib import Path
4
- from typing import Dict, Tuple, Any, Optional
5
 
6
  import numpy as np
7
  import torch
@@ -9,272 +12,248 @@ from PIL import Image
9
 
10
  import comfy.model_management as model_management
11
 
12
- # transformers is required
13
  try:
14
- from transformers import pipeline
15
- except Exception as e:
16
- pipeline = None
17
- _TRANSFORMERS_IMPORT_ERROR = e
18
 
 
 
 
19
 
20
- # --------------------------------------------------------------------------------------
21
- # Paths / sources
22
- # --------------------------------------------------------------------------------------
23
 
24
- # This file: comfyui-salia_online/nodes/Salia_Depth.py
25
- # Plugin root: comfyui-salia_online/
26
- PLUGIN_ROOT = Path(__file__).resolve().parent.parent
27
 
28
- # Requested local path: assets/depth
29
- MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
30
- MODEL_DIR.mkdir(parents=True, exist_ok=True)
31
 
32
- REQUIRED_FILES = {
33
- "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
34
- "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
35
- "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
 
36
  }
37
 
38
- # "zoe-path" fallback (matches what your current ZoeDetector code pulls)
39
- ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
 
 
41
 
42
- # --------------------------------------------------------------------------------------
43
- # Download + validation helpers
44
- # --------------------------------------------------------------------------------------
45
 
46
- def _have_required_files() -> bool:
47
- return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES.keys())
48
 
49
 
50
- def _download_url_to_file(url: str, dst: Path, timeout: int = 120) -> None:
51
  """
52
- Download with an atomic temp file -> rename.
 
53
  """
54
- dst.parent.mkdir(parents=True, exist_ok=True)
55
  tmp = dst.with_suffix(dst.suffix + ".tmp")
56
 
57
- if tmp.exists():
58
- try:
59
- tmp.unlink()
60
- except Exception:
61
- pass
62
-
63
- req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.0"})
64
  with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
65
  shutil.copyfileobj(r, f)
66
 
67
- tmp.replace(dst)
 
 
 
68
 
69
 
70
- def ensure_local_model_files() -> bool:
71
  """
72
- Ensure assets/depth contains config.json, model.safetensors, preprocessor_config.json.
73
- Returns True if files are present (either already or downloaded).
74
- Returns False if download failed.
75
  """
76
- if _have_required_files():
77
- return True
78
 
79
- print("[SaliaDepth] Local model files missing in:", str(MODEL_DIR))
80
- print("[SaliaDepth] Attempting to download required files from saliacoel/depth ...")
 
81
 
 
82
  try:
83
- for fname, url in REQUIRED_FILES.items():
84
- fpath = MODEL_DIR / fname
85
- if fpath.exists():
86
- continue
87
- print(f"[SaliaDepth] Downloading {fname} ...")
88
- _download_url_to_file(url, fpath)
89
- ok = _have_required_files()
90
- print(f"[SaliaDepth] Download complete. ok={ok}")
91
- return ok
92
  except Exception as e:
93
- print("[SaliaDepth] Download failed:", repr(e))
94
  return False
95
 
96
 
97
- # --------------------------------------------------------------------------------------
98
- # Pipeline cache / load
99
- # --------------------------------------------------------------------------------------
100
-
101
- _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
102
-
103
-
104
- def _pipeline_device_arg(device: torch.device) -> int:
105
- # transformers.pipeline: device=-1 for CPU, 0..N for CUDA index
106
- if device.type == "cuda":
107
- return int(device.index) if device.index is not None else 0
108
- return -1
109
-
110
-
111
- def _try_load_pipeline(model_source: str, device: torch.device):
112
  """
113
- model_source can be:
114
- - local directory path (string)
115
- - HF repo id
116
  """
117
- if pipeline is None:
118
- raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")
119
 
120
- key = (model_source, str(device))
121
- if key in _PIPE_CACHE:
122
- return _PIPE_CACHE[key]
 
123
 
124
- dev_arg = _pipeline_device_arg(device)
125
- print(f"[SaliaDepth] Loading depth-estimation pipeline from '{model_source}' (device={dev_arg})")
 
126
 
127
- p = pipeline(task="depth-estimation", model=model_source, device=dev_arg)
 
 
128
 
129
- # If Comfy gives MPS (mac), pipeline device arg is -1; try moving model anyway.
130
- try:
131
- p.model = p.model.to(device)
132
- except Exception:
133
- pass
134
-
135
- _PIPE_CACHE[key] = p
136
- return p
137
 
138
 
139
- def get_depth_pipeline(device: torch.device):
140
  """
141
- 1) Try local assets/depth (download if missing)
142
- 2) Fallback to zoe-path Intel/zoedepth-nyu-kitti
143
- 3) If both fail -> return None
144
  """
145
- # 1) local-first
146
- if ensure_local_model_files():
147
- try:
148
- return _try_load_pipeline(str(MODEL_DIR), device)
149
- except Exception as e:
150
- print("[SaliaDepth] Local model load failed:", repr(e))
151
-
152
- # 2) zoe fallback
153
- try:
154
- print("[SaliaDepth] Falling back to Zoe path:", ZOE_FALLBACK_REPO_ID)
155
- return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
156
- except Exception as e:
157
- print("[SaliaDepth] Zoe fallback load failed:", repr(e))
158
-
159
- # 3) total failure
160
- return None
161
 
 
 
 
162
 
163
- # --------------------------------------------------------------------------------------
164
- # Image utilities
165
- # --------------------------------------------------------------------------------------
166
 
167
- def _hwc3(x: np.ndarray) -> np.ndarray:
168
- assert x.dtype == np.uint8
169
- if x.ndim == 2:
170
- x = x[:, :, None]
171
- if x.shape[2] == 1:
172
- return np.concatenate([x, x, x], axis=2)
173
- if x.shape[2] == 3:
174
- return x
175
- if x.shape[2] == 4:
176
- color = x[:, :, 0:3].astype(np.float32)
177
- alpha = x[:, :, 3:4].astype(np.float32) / 255.0
178
- y = color * alpha + 255.0 * (1.0 - alpha)
179
- return y.clip(0, 255).astype(np.uint8)
180
- raise ValueError("Unexpected channel count")
181
 
182
 
183
- def _pad64(n: int) -> int:
184
- return int(np.ceil(float(n) / 64.0) * 64 - n)
 
 
 
 
 
 
 
 
185
 
186
 
187
- def _resize_long_side(image_u8: np.ndarray, long_side: int) -> np.ndarray:
188
  """
189
- Resize so that max(H,W) == long_side. If long_side equals current long side -> no change.
190
  """
191
- h, w = image_u8.shape[:2]
192
- cur_long = max(h, w)
193
- if long_side <= 0 or long_side == cur_long:
194
- return image_u8
195
-
196
- scale = float(long_side) / float(cur_long)
197
- new_w = int(round(w * scale))
198
- new_h = int(round(h * scale))
199
-
200
- pil = Image.fromarray(image_u8)
201
- # Downscale with LANCZOS, upscale with BICUBIC
202
- resample = Image.BICUBIC if scale > 1.0 else Image.LANCZOS
203
- pil = pil.resize((new_w, new_h), resample=resample)
204
- return np.array(pil, dtype=np.uint8)
205
 
206
 
207
- def _pad_to_64(image_u8: np.ndarray, mode: str = "edge"):
208
- h, w = image_u8.shape[:2]
209
- hp = _pad64(h)
210
- wp = _pad64(w)
211
- padded = np.pad(image_u8, ((0, hp), (0, wp), (0, 0)), mode=mode)
212
-
213
- def remove_pad(x: np.ndarray) -> np.ndarray:
214
- return x[:h, :w, :]
 
 
 
215
 
216
- return padded, remove_pad
 
 
217
 
218
 
219
- def _comfy_to_u8(img: torch.Tensor) -> np.ndarray:
220
  """
221
- Comfy IMAGE is float [0..1], shape [H,W,C] or [B,H,W,C]
222
  """
223
- if img.ndim == 4:
224
- img = img[0]
225
- img = img.detach().cpu().float().clamp(0, 1)
226
- arr = (img.numpy() * 255.0).round().astype(np.uint8)
227
- return arr
228
 
 
 
 
229
 
230
- def _u8_to_comfy(img_u8: np.ndarray) -> torch.Tensor:
231
- img_u8 = _hwc3(img_u8)
232
- t = torch.from_numpy(img_u8.astype(np.float32) / 255.0)
233
- return t.unsqueeze(0) # [1,H,W,C]
234
 
 
 
235
 
236
- def _depth_to_uint8(pipe, input_u8: np.ndarray, detect_long_side: int) -> np.ndarray:
 
237
  """
238
- Run depth estimation:
239
- - resize (long side)
240
- - pad to 64
241
- - infer
242
- - normalize (percentiles like your zoe code)
243
- - remove pad
244
- - return 3-channel uint8
245
  """
246
- input_u8 = _hwc3(input_u8)
247
- resized = _resize_long_side(input_u8, detect_long_side)
248
- padded, remove_pad = _pad_to_64(resized, mode="edge")
249
-
250
- pil = Image.fromarray(padded)
251
 
252
- with torch.no_grad():
253
- result = pipe(pil)
254
- depth = result["depth"]
255
 
256
- if isinstance(depth, Image.Image):
257
- depth_arr = np.array(depth, dtype=np.float32)
258
- else:
259
- depth_arr = np.array(depth, dtype=np.float32)
260
 
261
- vmin = np.percentile(depth_arr, 2)
262
- vmax = np.percentile(depth_arr, 85)
263
- denom = (vmax - vmin) if (vmax - vmin) > 1e-6 else 1e-6
264
 
265
- depth_arr = (depth_arr - vmin) / denom
266
- depth_arr = 1.0 - depth_arr
267
 
268
- depth_u8 = (depth_arr * 255.0).clip(0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- depth_rgb = _hwc3(depth_u8)
271
- depth_rgb = remove_pad(depth_rgb)
272
- return depth_rgb
 
 
 
273
 
274
 
275
- # --------------------------------------------------------------------------------------
276
  # ComfyUI Node
277
- # --------------------------------------------------------------------------------------
278
 
279
  class Salia_Depth_Preprocessor:
280
  @classmethod
@@ -291,51 +270,74 @@ class Salia_Depth_Preprocessor:
291
  FUNCTION = "execute"
292
  CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
293
 
294
- def execute(self, image, resolution=-1):
295
  """
296
- If everything fails (local model + zoe fallback), return input image unchanged.
 
297
  """
 
298
  try:
299
- device = model_management.get_torch_device()
 
 
300
  except Exception:
301
- device = torch.device("cpu")
302
-
303
- pipe = get_depth_pipeline(device)
304
- if pipe is None:
305
- # Hard fail: return input image unchanged
306
- print("[SaliaDepth] No pipeline available. Returning input image unchanged.")
307
  return (image,)
308
 
309
- # Batch support: image is [B,H,W,C]
310
- if image.ndim == 3:
311
- image = image.unsqueeze(0)
 
 
312
 
313
- outs = []
314
- for i in range(image.shape[0]):
315
- # original size
316
- h0 = int(image[i].shape[0])
317
- w0 = int(image[i].shape[1])
318
- long_side = max(w0, h0)
319
 
320
- detect_long_side = long_side if int(resolution) == -1 else int(resolution)
321
 
 
322
  try:
323
- inp_u8 = _comfy_to_u8(image[i])
324
- depth_u8 = _depth_to_uint8(pipe, inp_u8, detect_long_side)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
 
326
- # resize depth back to original input size
327
- pil = Image.fromarray(depth_u8)
328
- pil = pil.resize((w0, h0), resample=Image.BILINEAR)
329
- depth_u8 = np.array(pil, dtype=np.uint8)
330
 
331
- outs.append(_u8_to_comfy(depth_u8))
332
  except Exception as e:
333
- # Per-image fail: return that image unchanged
334
- print(f"[SaliaDepth] Inference failed for batch index {i}: {repr(e)}. Passing through input.")
335
- outs.append(image[i].unsqueeze(0))
336
 
337
- out = torch.cat(outs, dim=0)
338
- return (out,)
339
 
340
 
341
  NODE_CLASS_MAPPINGS = {
@@ -343,5 +345,5 @@ NODE_CLASS_MAPPINGS = {
343
  }
344
 
345
  NODE_DISPLAY_NAME_MAPPINGS = {
346
- "SaliaDepthPreprocessor": "Salia Depth (assets/depth local-first)"
347
  }
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
  import shutil
5
  import urllib.request
6
  from pathlib import Path
7
+ from typing import Dict, Tuple, Optional, List
8
 
9
  import numpy as np
10
  import torch
 
12
 
13
  import comfy.model_management as model_management
14
 
 
15
  try:
16
+ import cv2
17
+ except Exception:
18
+ cv2 = None
19
+
20
 
21
# ---------------------------------------------------------------
# Filesystem locations and download sources
# ---------------------------------------------------------------

# This module sits in <plugin>/nodes/, so the plugin root is one level up.
PLUGIN_ROOT = Path(__file__).resolve().parents[1]

# Local model folder: exactly <plugin>/assets/depth.
ASSETS_DEPTH_DIR = PLUGIN_ROOT / "assets" / "depth"

# Files that must be present locally before the model can be loaded offline.
REQUIRED_FILES = ["config.json", "preprocessor_config.json", "model.safetensors"]

# Download URL for each required file (same order as REQUIRED_FILES).
HF_BASE = "https://huggingface.co/saliacoel/depth/resolve/main"
FILE_URLS = {fname: f"{HF_BASE}/{fname}" for fname in REQUIRED_FILES}

# Hub repository used as the "zoe-path" fallback when the local copy is unusable.
FALLBACK_ZOE_REPO = "Intel/zoedepth-nyu-kitti"


# ---------------------------------------------------------------
# Global model cache
# ---------------------------------------------------------------
# (device_str, source_id) -> (processor, model)
_MODEL_CACHE: Dict[Tuple[str, str], Tuple[object, object]] = {}
49
+
50
+
51
+ # -----------------------------
52
+ # Utility
53
+ # -----------------------------
54
+
55
+ def _ensure_dir(p: Path) -> None:
56
+ p.mkdir(parents=True, exist_ok=True)
57
+
58
 
59
+ def _file_ok(p: Path) -> bool:
60
+ # existence + non-empty is a good baseline against partial downloads
61
+ return p.exists() and p.is_file() and p.stat().st_size > 0
62
 
 
 
 
63
 
64
def _have_local_files() -> bool:
    """Return True only if every entry of REQUIRED_FILES passes ``_file_ok`` inside ASSETS_DEPTH_DIR."""
    for fname in REQUIRED_FILES:
        if not _file_ok(ASSETS_DEPTH_DIR / fname):
            # Stop at the first missing/empty file, like all() would.
            return False
    return True
66
 
67
 
68
+ def _download_file(url: str, dst: Path, timeout: int = 60) -> None:
69
  """
70
+ Download url -> dst atomically (tmp + replace).
71
+ Raises on failure.
72
  """
73
+ _ensure_dir(dst.parent)
74
  tmp = dst.with_suffix(dst.suffix + ".tmp")
75
 
76
+ req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-Salia-Depth/1.0"})
 
 
 
 
 
 
77
  with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
78
  shutil.copyfileobj(r, f)
79
 
80
+ if not _file_ok(tmp):
81
+ raise RuntimeError(f"Downloaded file is empty/corrupt: {tmp}")
82
+
83
+ os.replace(tmp, dst)
84
 
85
 
86
def _ensure_local_model_files() -> bool:
    """
    Make sure all REQUIRED_FILES are available in ASSETS_DEPTH_DIR.

    Files that are missing or empty are fetched from their FILE_URLS entry.

    Returns:
        True if every required file is present afterwards, False if a
        download (or the final re-check) raised.
    """
    _ensure_dir(ASSETS_DEPTH_DIR)

    # Fast path: nothing to download.
    if _have_local_files():
        return True

    try:
        for fname in REQUIRED_FILES:
            target = ASSETS_DEPTH_DIR / fname
            if _file_ok(target):
                continue
            _download_file(FILE_URLS[fname], target)
        all_present = _have_local_files()
    except Exception as e:
        print(f"[SaliaDepth] Download from saliacoel/depth failed: {e}")
        return False
    return all_present
107
 
108
 
109
+ def _resize_max_side_uint8(img_u8: np.ndarray, max_side: int) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  """
111
+ Resize uint8 HWC so that max(H,W) == max_side, keep aspect ratio.
112
+ If max_side <= 0 or already matches, returns original.
 
113
  """
114
+ if max_side <= 0:
115
+ return img_u8
116
 
117
+ h, w = img_u8.shape[:2]
118
+ cur_max = max(h, w)
119
+ if cur_max == 0 or cur_max == max_side:
120
+ return img_u8
121
 
122
+ scale = float(max_side) / float(cur_max)
123
+ new_w = max(1, int(round(w * scale)))
124
+ new_h = max(1, int(round(h * scale)))
125
 
126
+ if cv2 is not None:
127
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
128
+ return cv2.resize(img_u8, (new_w, new_h), interpolation=interp)
129
 
130
+ # PIL fallback
131
+ pil = Image.fromarray(img_u8)
132
+ resample = Image.Resampling.LANCZOS if scale < 1 else Image.Resampling.BICUBIC
133
+ pil = pil.resize((new_w, new_h), resample=resample)
134
+ return np.array(pil, dtype=np.uint8)
 
 
 
135
 
136
 
137
+ def _depth_to_hint_rgb(depth_2d: np.ndarray) -> np.ndarray:
138
  """
139
+ Normalize depth to a ControlNet-style grayscale RGB hint.
140
+ Uses percentile normalization (2..85) and inverts.
 
141
  """
142
+ d = depth_2d.astype(np.float32)
143
+ if not np.isfinite(d).all():
144
+ d = np.nan_to_num(d, nan=0.0, posinf=0.0, neginf=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ vmin = np.percentile(d, 2)
147
+ vmax = np.percentile(d, 85)
148
+ denom = max(vmax - vmin, 1e-6)
149
 
150
+ dn = (d - vmin) / denom
151
+ dn = np.clip(dn, 0.0, 1.0)
152
+ dn = 1.0 - dn
153
 
154
+ u8 = (dn * 255.0).round().clip(0, 255).astype(np.uint8)
155
+ return np.stack([u8, u8, u8], axis=-1)
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
 
158
+ def _comfy_tensor_to_uint8_hwc(img: torch.Tensor) -> np.ndarray:
159
+ """
160
+ ComfyUI IMAGE: float [0..1], shape [H,W,3]
161
+ -> uint8 HWC
162
+ """
163
+ x = img.detach()
164
+ if x.is_cuda:
165
+ x = x.cpu()
166
+ x = x.float().clamp(0, 1).numpy()
167
+ return (x * 255.0).round().clip(0, 255).astype(np.uint8)
168
 
169
 
170
+ def _uint8_hwc_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
171
  """
172
+ uint8 HWC -> float32 tensor HWC [0..1]
173
  """
174
+ return torch.from_numpy(img_u8.astype(np.float32) / 255.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
 
177
+ def _post_process_depth(processor, outputs, target_h: int, target_w: int) -> np.ndarray:
178
+ """
179
+ Transformers API compatibility shim.
180
+ Some versions use target_sizes, some source_sizes.
181
+ Returns depth as float32 HxW.
182
+ """
183
+ # Try the most common signature first
184
+ try:
185
+ post = processor.post_process_depth_estimation(outputs, target_sizes=[(target_h, target_w)])
186
+ except TypeError:
187
+ post = processor.post_process_depth_estimation(outputs, source_sizes=[(target_h, target_w)])
188
 
189
+ # expected: list[{"predicted_depth": tensor[H,W]}]
190
+ depth_t = post[0]["predicted_depth"]
191
+ return depth_t.detach().float().cpu().numpy()
192
 
193
 
194
def _load_zoedepth_from_local(device: torch.device):
    """
    Load the ZoeDepth processor and model from ASSETS_DEPTH_DIR without network access.

    The (processor, model) pair is cached in _MODEL_CACHE per device, so
    repeated calls for the same device reuse the loaded model.  The model is
    put into eval mode and moved to ``device``.
    """
    from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation

    cache_key = (str(device), f"local::{ASSETS_DEPTH_DIR}")
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    model_dir = str(ASSETS_DEPTH_DIR)
    processor = AutoImageProcessor.from_pretrained(model_dir, local_files_only=True)
    model = ZoeDepthForDepthEstimation.from_pretrained(model_dir, local_files_only=True)
    model.eval().to(device)

    pack = (processor, model)
    _MODEL_CACHE[cache_key] = pack
    return pack
210
 
211
+
212
def _load_zoedepth_fallback(device: torch.device):
    """
    Load the ZoeDepth processor and model from the hub repo FALLBACK_ZOE_REPO ("zoe-path").

    The (processor, model) pair is cached in _MODEL_CACHE per device, so
    repeated calls for the same device reuse the loaded model.  The model is
    put into eval mode and moved to ``device``.
    """
    from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation

    cache_key = (str(device), f"hf::{FALLBACK_ZOE_REPO}")
    cached = _MODEL_CACHE.get(cache_key)
    if cached is not None:
        return cached

    processor = AutoImageProcessor.from_pretrained(FALLBACK_ZOE_REPO)
    model = ZoeDepthForDepthEstimation.from_pretrained(FALLBACK_ZOE_REPO)
    model.eval().to(device)

    pack = (processor, model)
    _MODEL_CACHE[cache_key] = pack
    return pack
 
228
 
 
 
229
 
230
def _get_model(device: torch.device):
    """
    Return a (processor, model) pair for ``device``, or None if nothing can be loaded.

    Order of attempts:
      1. the local copy in assets/depth (downloading missing files first),
      2. the zoe-path fallback repository,
      3. give up and return None.

    Every failure is printed and then skipped; this function does not raise
    for load errors.
    """
    # Step 1a: make sure the local files are there.
    try:
        local_ready = _ensure_local_model_files()
    except Exception as e:
        print(f"[SaliaDepth] Local ensure/load unexpected error. Fallback to zoe-path. Error: {e}")
        local_ready = False

    # Step 1b: load from disk.
    if local_ready:
        try:
            return _load_zoedepth_from_local(device)
        except Exception as e:
            print(f"[SaliaDepth] Local load failed (assets/depth). Will fallback to zoe-path. Error: {e}")

    # Step 2: zoe-path fallback.
    try:
        return _load_zoedepth_fallback(device)
    except Exception as e:
        print(f"[SaliaDepth] Zoe fallback load failed. Will passthrough image. Error: {e}")

    # Step 3: nothing worked.
    return None
252
 
253
 
254
+ # -----------------------------
255
  # ComfyUI Node
256
+ # -----------------------------
257
 
258
  class Salia_Depth_Preprocessor:
259
  @classmethod
 
270
  FUNCTION = "execute"
271
  CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"
272
 
273
+ def execute(self, image: torch.Tensor, resolution: int = -1):
274
  """
275
+ If anything fails:
276
+ - return (image,) passthrough
277
  """
278
+ # Basic shape validation; if weird, passthrough
279
  try:
280
+ if image.dim() != 4 or image.shape[-1] != 3:
281
+ print(f"[SaliaDepth] Unexpected input IMAGE shape {tuple(image.shape)}; passthrough.")
282
+ return (image,)
283
  except Exception:
 
 
 
 
 
 
284
  return (image,)
285
 
286
+ device = model_management.get_torch_device()
287
+
288
+ model_pack = _get_model(device)
289
+ if model_pack is None:
290
+ return (image,)
291
 
292
+ processor, model = model_pack
 
 
 
 
 
293
 
294
+ outs: List[torch.Tensor] = []
295
 
296
+ for b in range(image.shape[0]):
297
  try:
298
+ # input in original size
299
+ img_u8 = _comfy_tensor_to_uint8_hwc(image[b])
300
+ h0, w0 = img_u8.shape[0], img_u8.shape[1]
301
+
302
+ # note 5: if -1, use bigger side (max(w,h))
303
+ max_side = max(w0, h0) if resolution == -1 else int(resolution)
304
+
305
+ # resize for inference (max side rule)
306
+ img_inf = _resize_max_side_uint8(img_u8, max_side=max_side)
307
+ pil = Image.fromarray(img_inf)
308
+
309
+ # preprocess
310
+ inputs = processor(images=pil, return_tensors="pt")
311
+ pixel_values = inputs["pixel_values"].to(device)
312
+
313
+ with torch.inference_mode():
314
+ outputs = model(pixel_values=pixel_values)
315
+
316
+ # postprocess back to inference image size
317
+ depth_np = _post_process_depth(processor, outputs, pil.height, pil.width)
318
+
319
+ # depth -> grayscale RGB hint
320
+ hint_rgb = _depth_to_hint_rgb(depth_np)
321
+
322
+ # resize hint back to original size
323
+ if hint_rgb.shape[0] != h0 or hint_rgb.shape[1] != w0:
324
+ if cv2 is not None:
325
+ hint_rgb = cv2.resize(hint_rgb, (w0, h0), interpolation=cv2.INTER_CUBIC)
326
+ else:
327
+ hint_rgb = np.array(
328
+ Image.fromarray(hint_rgb).resize((w0, h0), resample=Image.Resampling.BICUBIC),
329
+ dtype=np.uint8
330
+ )
331
 
332
+ outs.append(_uint8_hwc_to_comfy_tensor(hint_rgb))
 
 
 
333
 
 
334
  except Exception as e:
335
+ # Per-image failure -> passthrough that image (keeps batch size consistent)
336
+ print(f"[SaliaDepth] Inference failed on batch index {b}; passthrough that frame. Error: {e}")
337
+ outs.append(image[b].detach().cpu() if image[b].is_cuda else image[b])
338
 
339
+ out_batch = torch.stack(outs, dim=0)
340
+ return (out_batch,)
341
 
342
 
343
  NODE_CLASS_MAPPINGS = {
 
345
  }
346
 
347
# Display name for the "SaliaDepthPreprocessor" node key.
NODE_DISPLAY_NAME_MAPPINGS = {"SaliaDepthPreprocessor": "Salia Depth"}