saliacoel commited on
Commit
3d625c7
·
verified ·
1 Parent(s): d02b8fc

Upload salia_depth.py

Browse files
Files changed (1) hide show
  1. salia_depth.py +347 -0
salia_depth.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import urllib.request
3
+ from pathlib import Path
4
+ from typing import Dict, Tuple, Any, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+ import comfy.model_management as model_management
11
+
12
+ # transformers is required
13
+ try:
14
+ from transformers import pipeline
15
+ except Exception as e:
16
+ pipeline = None
17
+ _TRANSFORMERS_IMPORT_ERROR = e
18
+
19
+
20
+ # --------------------------------------------------------------------------------------
21
+ # Paths / sources
22
+ # --------------------------------------------------------------------------------------
23
+
24
+ # This file: comfyui-salia_online/nodes/Salia_Depth.py
25
+ # Plugin root: comfyui-salia_online/
26
+ PLUGIN_ROOT = Path(__file__).resolve().parent.parent
27
+
28
+ # Requested local path: assets/depth
29
+ MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
30
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
31
+
32
+ REQUIRED_FILES = {
33
+ "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
34
+ "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
35
+ "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
36
+ }
37
+
38
+ # "zoe-path" fallback (matches what your current ZoeDetector code pulls)
39
+ ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
40
+
41
+
42
+ # --------------------------------------------------------------------------------------
43
+ # Download + validation helpers
44
+ # --------------------------------------------------------------------------------------
45
+
46
def _have_required_files() -> bool:
    """
    Report whether every file named in REQUIRED_FILES exists under MODEL_DIR.
    """
    for filename in REQUIRED_FILES:
        if not (MODEL_DIR / filename).exists():
            return False
    return True
48
+
49
+
50
def _download_url_to_file(url: str, dst: Path, timeout: int = 120) -> None:
    """
    Download ``url`` to ``dst`` via a sibling ``<dst>.tmp`` file that is renamed
    into place only after the transfer finished.

    Args:
        url: Source URL (anything ``urllib.request.urlopen`` can open).
        dst: Final destination path; parent directories are created.
        timeout: Socket timeout in seconds passed to ``urlopen``.

    Raises:
        OSError: If the server announced a Content-Length that does not match
            the number of bytes received (truncated transfer), or on any
            file-system failure.
        urllib.error.URLError: If the URL cannot be opened.

    On any failure the temp file is removed again and ``dst`` is left untouched.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    def _discard_tmp() -> None:
        # Best effort: a missing or undeletable temp file must not mask the
        # real outcome of the download.
        try:
            tmp.unlink()
        except OSError:
            pass

    # Remove leftovers from an earlier interrupted run.
    _discard_tmp()

    req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
            shutil.copyfileobj(r, f)
            written = f.tell()
            declared = r.headers.get("Content-Length")

        # Chunked reads do not raise when the connection drops early, so compare
        # against the announced size before promoting the file.
        if declared is not None and declared.strip().isdigit() and int(declared) != written:
            raise OSError(
                f"incomplete download from {url}: got {written} of {int(declared)} bytes"
            )

        tmp.replace(dst)
    except BaseException:
        # Never leave a partial file behind (also on KeyboardInterrupt).
        _discard_tmp()
        raise
68
+
69
+
70
def ensure_local_model_files() -> bool:
    """
    Make sure every file listed in REQUIRED_FILES is present in MODEL_DIR,
    downloading the missing ones.

    Returns True when all files are present afterwards (already there or freshly
    downloaded) and False when a download raised.
    """
    if _have_required_files():
        return True

    print("[SaliaDepth] Local model files missing in:", str(MODEL_DIR))
    print("[SaliaDepth] Attempting to download required files from saliacoel/depth ...")

    try:
        for filename, source_url in REQUIRED_FILES.items():
            target = MODEL_DIR / filename
            if not target.exists():
                print(f"[SaliaDepth] Downloading {filename} ...")
                _download_url_to_file(source_url, target)
        complete = _have_required_files()
        print(f"[SaliaDepth] Download complete. ok={complete}")
    except Exception as err:
        # Any failure (network, disk, ...) is reported and turned into False so
        # the caller can fall back to another model source.
        print("[SaliaDepth] Download failed:", repr(err))
        return False
    return complete
95
+
96
+
97
+ # --------------------------------------------------------------------------------------
98
+ # Pipeline cache / load
99
+ # --------------------------------------------------------------------------------------
100
+
101
+ _PIPE_CACHE: Dict[Tuple[str, str], Any] = {} # (model_source, device_str) -> pipeline
102
+
103
+
104
def _pipeline_device_arg(device: torch.device) -> int:
    """
    Translate a torch device into the integer form used for the pipeline's
    ``device`` argument: the CUDA index (0 when none is given) for CUDA devices,
    -1 for everything else.
    """
    if device.type != "cuda":
        return -1
    cuda_index = device.index
    return 0 if cuda_index is None else int(cuda_index)
109
+
110
+
111
def _try_load_pipeline(model_source: str, device: torch.device):
    """
    Build (or fetch from _PIPE_CACHE) a depth-estimation pipeline.

    Args:
        model_source: Either a local directory path (as a string) or a
            Hugging Face repo id.
        device: Torch device the pipeline should run on.

    Returns:
        The pipeline object; cached per ``(model_source, str(device))``.

    Raises:
        RuntimeError: If ``transformers`` could not be imported.
        Exception: Whatever ``transformers.pipeline`` raises while loading.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    key = (model_source, str(device))
    cached = _PIPE_CACHE.get(key)
    if cached is not None:
        return cached

    # CUDA/CPU keep the integer convention. Any other accelerator (e.g. MPS) is
    # handed over as a torch.device so that the pipeline itself places both the
    # model and its inputs there. Building the pipeline on CPU and moving only
    # the model afterwards leaves the inputs on CPU and breaks inference.
    if device.type in ("cuda", "cpu"):
        dev_arg = _pipeline_device_arg(device)
    else:
        dev_arg = device
    print(f"[SaliaDepth] Loading depth-estimation pipeline from '{model_source}' (device={dev_arg})")

    p = pipeline(task="depth-estimation", model=model_source, device=dev_arg)

    _PIPE_CACHE[key] = p
    return p
137
+
138
+
139
def get_depth_pipeline(device: torch.device):
    """
    Resolve a depth pipeline for ``device``, trying sources in this order:

    1) the local MODEL_DIR (files are downloaded first if missing)
    2) the ZOE_FALLBACK_REPO_ID repo

    Returns the pipeline, or None when both sources fail.
    """
    # 1) local-first
    local_ready = ensure_local_model_files()
    if local_ready:
        try:
            return _try_load_pipeline(str(MODEL_DIR), device)
        except Exception as local_err:
            print("[SaliaDepth] Local model load failed:", repr(local_err))

    # 2) zoe fallback; 3) None on total failure
    try:
        print("[SaliaDepth] Falling back to Zoe path:", ZOE_FALLBACK_REPO_ID)
        fallback = _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
    except Exception as zoe_err:
        print("[SaliaDepth] Zoe fallback load failed:", repr(zoe_err))
        fallback = None
    return fallback
161
+
162
+
163
+ # --------------------------------------------------------------------------------------
164
+ # Image utilities
165
+ # --------------------------------------------------------------------------------------
166
+
167
def _hwc3(x: np.ndarray) -> np.ndarray:
    """
    Convert a uint8 image to 3-channel HWC layout.

    - 2-D input is treated as a single channel.
    - 1 channel is replicated three times.
    - 3 channels are returned unchanged (same array object).
    - 4 channels are alpha-composited over a white background.

    Raises:
        TypeError: If ``x`` is not uint8.
        ValueError: If ``x`` is not 2-D/3-D or has an unsupported channel count.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if x.dtype != np.uint8:
        raise TypeError(f"_hwc3 expects a uint8 array, got dtype {x.dtype}")
    if x.ndim == 2:
        x = x[:, :, None]
    if x.ndim != 3:
        raise ValueError(f"_hwc3 expects a 2-D or 3-D array, got {x.ndim} dimensions")

    channels = x.shape[2]
    if channels == 1:
        return np.concatenate([x, x, x], axis=2)
    if channels == 3:
        return x
    if channels == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        # Blend against white: fully transparent pixels become 255.
        y = color * alpha + 255.0 * (1.0 - alpha)
        return y.clip(0, 255).astype(np.uint8)
    raise ValueError("Unexpected channel count")
181
+
182
+
183
def _pad64(n: int) -> int:
    """
    Return how many units must be added to ``n`` to reach the next multiple
    of 64 (0 when ``n`` already is one).
    """
    shortfall = (-n) % 64
    return int(shortfall)
185
+
186
+
187
def _resize_long_side(image_u8: np.ndarray, long_side: int) -> np.ndarray:
    """
    Resize so that max(H,W) == long_side, keeping the aspect ratio.

    The input is returned unchanged when ``long_side`` is <= 0, when it already
    equals the current long side, or when the image has a zero-sized dimension
    (nothing to scale, and the scale factor would be undefined).

    Downscaling uses LANCZOS, upscaling uses BICUBIC.
    """
    h, w = image_u8.shape[:2]
    cur_long = max(h, w)
    if long_side <= 0 or long_side == cur_long or min(h, w) <= 0:
        return image_u8

    scale = float(long_side) / float(cur_long)
    # Clamp to 1px: for very elongated images the short side can round to 0,
    # which PIL rejects.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))

    pil = Image.fromarray(image_u8)
    resample = Image.BICUBIC if scale > 1.0 else Image.LANCZOS
    pil = pil.resize((new_w, new_h), resample=resample)
    return np.array(pil, dtype=np.uint8)
205
+
206
+
207
def _pad_to_64(image_u8: np.ndarray, mode: str = "edge"):
    """
    Pad an HWC image at the bottom and right so height and width become
    multiples of 64.

    Returns ``(padded, remove_pad)`` where ``remove_pad(x)`` crops an HWC array
    back to the original height and width.
    """
    orig_h, orig_w = image_u8.shape[:2]
    pad_spec = ((0, _pad64(orig_h)), (0, _pad64(orig_w)), (0, 0))
    padded = np.pad(image_u8, pad_spec, mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return x[:orig_h, :orig_w, :]

    return padded, remove_pad
217
+
218
+
219
def _comfy_to_u8(img: torch.Tensor) -> np.ndarray:
    """
    Convert a Comfy IMAGE tensor (float, nominally [0..1], shape [H,W,C] or
    [B,H,W,C]) to a uint8 numpy array. For 4-D input only the first batch entry
    is converted. Values are clamped to [0, 1] before scaling to 0..255.
    """
    frame = img[0] if img.ndim == 4 else img
    unit = frame.detach().cpu().float().clamp(0, 1).numpy()
    return np.round(unit * 255.0).astype(np.uint8)
228
+
229
+
230
def _u8_to_comfy(img_u8: np.ndarray) -> torch.Tensor:
    """
    Convert a uint8 image to a float32 tensor in [0..1] with a leading batch
    axis, i.e. shape [1,H,W,C]. The image is first forced to 3 channels.
    """
    rgb = _hwc3(img_u8)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
234
+
235
+
236
def _depth_to_uint8(pipe, input_u8: np.ndarray, detect_long_side: int) -> np.ndarray:
    """
    Run depth estimation on one uint8 image and return a 3-channel uint8 map.

    Steps:
      - force 3 channels, resize by long side, edge-pad to multiples of 64
      - run ``pipe`` and read its "depth" output
      - normalise between the 2nd and 85th percentile, then invert
      - scale to 0..255, expand to 3 channels, crop the padding away
    """
    rgb = _hwc3(input_u8)
    padded, remove_pad = _pad_to_64(_resize_long_side(rgb, detect_long_side), mode="edge")

    with torch.no_grad():
        prediction = pipe(Image.fromarray(padded))

    # Works for a PIL image as well as any other array-like "depth" output.
    depth_arr = np.array(prediction["depth"], dtype=np.float32)

    low = np.percentile(depth_arr, 2)
    high = np.percentile(depth_arr, 85)
    span = high - low
    if not (span > 1e-6):
        # Guard against division by (almost) zero on flat depth maps.
        span = 1e-6

    inverted = 1.0 - (depth_arr - low) / span
    depth_u8 = (inverted * 255.0).clip(0, 255).astype(np.uint8)

    return remove_pad(_hwc3(depth_u8))
273
+
274
+
275
+ # --------------------------------------------------------------------------------------
276
+ # ComfyUI Node
277
+ # --------------------------------------------------------------------------------------
278
+
279
class Salia_Depth_Preprocessor:
    """
    ComfyUI node that turns an IMAGE batch into depth-map images.

    The depth pipeline is resolved through ``get_depth_pipeline``. If no
    pipeline can be loaded the input is returned unchanged; if inference fails
    for a single frame, that frame is passed through (as 3-channel RGB) so the
    output batch stays stackable.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                # -1 (default and minimum) means: run detection at the input's
                # own long side, i.e. without resizing.
                "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    @staticmethod
    def _passthrough_frame(frame: torch.Tensor) -> torch.Tensor:
        """
        Turn one input frame ([H,W,C] or [H,W]) into a [1,H,W,3] float CPU
        tensor so it can be concatenated with depth outputs.

        1 channel is replicated, 4 channels are composited over white using the
        alpha channel, 3 channels are passed through as they are.

        Raises:
            ValueError: For any other channel count.
        """
        frame = frame.detach().cpu().float()
        if frame.ndim == 2:
            frame = frame.unsqueeze(-1)
        channels = int(frame.shape[-1])
        if channels == 1:
            frame = frame.repeat(1, 1, 3)
        elif channels == 4:
            alpha = frame[..., 3:4]
            frame = (frame[..., :3] * alpha + (1.0 - alpha)).clamp(0, 1)
        elif channels != 3:
            raise ValueError("Unexpected channel count")
        return frame.unsqueeze(0)

    def execute(self, image, resolution=-1):
        """
        Compute a depth image for every frame in ``image``.

        Args:
            image: Comfy IMAGE tensor, [B,H,W,C] or [H,W,C].
            resolution: Long-side size used for detection; -1 keeps the input's
                long side.

        Returns:
            A 1-tuple with the [B,H,W,3] result. If everything fails (local
            model + zoe fallback), the input image is returned unchanged.
        """
        try:
            device = model_management.get_torch_device()
        except Exception:
            device = torch.device("cpu")

        pipe = get_depth_pipeline(device)
        if pipe is None:
            # Hard fail: return input image unchanged
            print("[SaliaDepth] No pipeline available. Returning input image unchanged.")
            return (image,)

        # Batch support: image is [B,H,W,C]
        if image.ndim == 3:
            image = image.unsqueeze(0)

        # All frames of a batch share one size, so compute it once.
        h0 = int(image.shape[1])
        w0 = int(image.shape[2])
        detect_long_side = max(w0, h0) if int(resolution) == -1 else int(resolution)

        outs = []
        for i in range(image.shape[0]):
            try:
                inp_u8 = _comfy_to_u8(image[i])
                depth_u8 = _depth_to_uint8(pipe, inp_u8, detect_long_side)

                # resize depth back to original input size
                pil = Image.fromarray(depth_u8)
                pil = pil.resize((w0, h0), resample=Image.BILINEAR)
                depth_u8 = np.array(pil, dtype=np.uint8)

                outs.append(_u8_to_comfy(depth_u8))
            except Exception as e:
                # Per-image fail: pass the input through, normalised to 3
                # channels so torch.cat below cannot hit a shape mismatch.
                print(f"[SaliaDepth] Inference failed for batch index {i}: {repr(e)}. Passing through input.")
                outs.append(self._passthrough_frame(image[i]))

        out = torch.cat(outs, dim=0)
        return (out,)
339
+
340
+
341
+ NODE_CLASS_MAPPINGS = {
342
+ "SaliaDepthPreprocessor": Salia_Depth_Preprocessor
343
+ }
344
+
345
+ NODE_DISPLAY_NAME_MAPPINGS = {
346
+ "SaliaDepthPreprocessor": "Salia Depth (assets/depth local-first)"
347
+ }