# depth / salia_depth.py  (scraped Hugging Face page header, preserved as comments)
# saliacoel's picture
# Update salia_depth.py
# 79a1f7d verified
import os
import shutil
import urllib.request
from pathlib import Path
from typing import Dict, Tuple, Any, Optional, List
import numpy as np
import torch
from PIL import Image
import comfy.model_management as model_management
# transformers is required for depth-estimation pipeline
try:
    from transformers import pipeline
except Exception as e:
    # Defer the failure: keep `pipeline = None` plus the original error so the
    # problem is reported when a pipeline is actually requested, instead of
    # crashing ComfyUI's node import.
    pipeline = None
    _TRANSFORMERS_IMPORT_ERROR = e
# --------------------------------------------------------------------------------------
# Paths / sources
# --------------------------------------------------------------------------------------
# This file: comfyui-salia_online/nodes/Salia_Depth.py
# Plugin root: comfyui-salia_online/
PLUGIN_ROOT = Path(__file__).resolve().parent.parent
# Requested local path: assets/depth
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
# NOTE: directory is created at import time (module-level side effect) so that
# later downloads always have a target to write into.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# Files that must exist locally for the local-first model load, mapped to the
# remote URL each one is fetched from when missing.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}
# "zoe-path" fallback
# Hub repo id used when the local model directory cannot be loaded.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
# --------------------------------------------------------------------------------------
# Logging helpers
# --------------------------------------------------------------------------------------
def _make_logger() -> Tuple[List[str], Any]:
lines: List[str] = []
def log(msg: str):
# console
try:
print(msg)
except Exception:
pass
# UI string
lines.append(str(msg))
return lines, log
def _fmt_bytes(n: Optional[int]) -> str:
if n is None:
return "?"
# simple readable
for unit in ["B", "KB", "MB", "GB", "TB"]:
if n < 1024:
return f"{n:.0f}{unit}"
n /= 1024.0
return f"{n:.1f}PB"
def _file_size(path: Path) -> Optional[int]:
try:
return path.stat().st_size
except Exception:
return None
def _hf_cache_info() -> Dict[str, str]:
info: Dict[str, str] = {}
info["env.HF_HOME"] = os.environ.get("HF_HOME", "")
info["env.HF_HUB_CACHE"] = os.environ.get("HF_HUB_CACHE", "")
info["env.TRANSFORMERS_CACHE"] = os.environ.get("TRANSFORMERS_CACHE", "")
info["env.HUGGINGFACE_HUB_CACHE"] = os.environ.get("HUGGINGFACE_HUB_CACHE", "")
try:
from huggingface_hub import constants as hf_constants
# These exist in most hub versions:
info["huggingface_hub.constants.HF_HOME"] = str(getattr(hf_constants, "HF_HOME", ""))
info["huggingface_hub.constants.HF_HUB_CACHE"] = str(getattr(hf_constants, "HF_HUB_CACHE", ""))
except Exception:
pass
return info
# --------------------------------------------------------------------------------------
# Download helpers
# --------------------------------------------------------------------------------------
def _have_required_files() -> bool:
    """True when every file listed in REQUIRED_FILES exists in MODEL_DIR."""
    for name in REQUIRED_FILES:
        if not (MODEL_DIR / name).exists():
            return False
    return True
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
"""
Download with atomic temp rename.
"""
dst.parent.mkdir(parents=True, exist_ok=True)
tmp = dst.with_suffix(dst.suffix + ".tmp")
if tmp.exists():
try:
tmp.unlink()
except Exception:
pass
req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
shutil.copyfileobj(r, f)
tmp.replace(dst)
def ensure_local_model_files(log) -> bool:
    """
    Ensure assets/depth holds every file listed in REQUIRED_FILES.

    Expected paths and source URLs are always logged, even when nothing needs
    downloading. Missing files are fetched one by one. Returns True when all
    files are present afterwards, False on any download failure.
    """
    log("[SaliaDepth] ===== Local model file check =====")
    log(f"[SaliaDepth] Plugin root: {PLUGIN_ROOT}")
    log(f"[SaliaDepth] Local model dir (on drive): {MODEL_DIR}")
    for name, url in REQUIRED_FILES.items():
        target = MODEL_DIR / name
        present = target.exists()
        nbytes = _file_size(target) if present else None
        log(f"[SaliaDepth] - {name}")
        log(f"[SaliaDepth] local path: {target} exists={present} size={_fmt_bytes(nbytes)}")
        log(f"[SaliaDepth] remote url : {url}")
    if _have_required_files():
        log("[SaliaDepth] All required local files already exist. No download needed.")
        return True
    log("[SaliaDepth] One or more local files missing. Attempting download...")
    try:
        for name, url in REQUIRED_FILES.items():
            target = MODEL_DIR / name
            if target.exists():
                continue  # already on disk; only fetch the gaps
            log(f"[SaliaDepth] Downloading '{name}' -> '{target}'")
            _download_url_to_file(url, target)
            log(f"[SaliaDepth] Downloaded '{name}' size={_fmt_bytes(_file_size(target))}")
        ok = _have_required_files()
        log(f"[SaliaDepth] Download finished. ok={ok}")
        return ok
    except Exception as exc:
        log(f"[SaliaDepth] Download failed with error: {repr(exc)}")
        return False
# --------------------------------------------------------------------------------------
# Exact Zoe-style preprocessing helpers (copied/adapted from your snippet)
# --------------------------------------------------------------------------------------
def HWC3(x: np.ndarray) -> np.ndarray:
    """
    Normalize a uint8 image to 3-channel HxWx3.

    Grayscale (2-D or 1-channel) is replicated across channels; RGBA is
    composited over a white background; RGB passes through unchanged.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)
    if channels == 3:
        return x
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    # channels == 4: alpha-blend onto white.
    rgb = x[:, :, :3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
def pad64(x: int) -> int:
    """
    Pixels needed to round *x* up to the next multiple of 64 (0 if aligned).

    Uses exact integer modular arithmetic instead of the original float
    `np.ceil(x / 64.0) * 64 - x`, which could lose precision for very large
    values; for all non-negative image dimensions the result is identical.
    """
    return (-x) % 64
def safer_memory(x: np.ndarray) -> np.ndarray:
    """
    Return a C-contiguous copy of *x* that shares no memory with the input.

    The original `np.ascontiguousarray(x.copy()).copy()` made a redundant
    extra copy: `x.copy()` already yields a fresh C-contiguous array, so
    `ascontiguousarray` returns it as-is and the trailing `.copy()` only
    duplicated it again. One copy gives the same result.
    """
    return np.ascontiguousarray(x.copy())
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
    log=None
) -> Tuple[np.ndarray, Any]:
    """
    EXACT behavior like the upstream zoe.transformers.py:
      k = resolution / min(H,W)
      resize to (W_target, H_target)
      pad to multiple of 64 (bottom/right, numpy edge-pad by default)
      return padded image and a remove_pad() closure that crops back

    Args:
        input_image: uint8 HWC image (any channel count HWC3 accepts).
        resolution: target size for the SHORTER side; <= 0 returns the image
            unresized with an identity remove_pad (the -1 path is handled by
            pad_only_to_64 elsewhere).
        upscale_method: cv2 interpolation name used only when upscaling (k > 1);
            downscales always use INTER_AREA, matching the original.
        skip_hwc3: skip the HWC3 normalization when the caller already did it.
        mode: numpy pad mode for the 64-alignment padding.
        log: optional logger callable for the cv2-missing warning.

    Returns:
        (padded uint8 image, remove_pad closure).
    """
    # prefer cv2 like original for matching results
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None
        if log:
            log("[SaliaDepth] WARN: cv2 not available; resizing will use PIL fallback (may change results).")
    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # keep original, but still pad to 64 (we will handle padding separately for -1 path)
        return img, (lambda x: x)
    # Scale so the shorter side hits `resolution`; both dims rounded to int.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # INTER_AREA for downscaling (k <= 1), caller-chosen method for upscaling.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        # PIL fallback
        # NOTE(review): BICUBIC/LANCZOS approximate cv2's CUBIC/AREA but are
        # not pixel-identical — results may drift slightly from the cv2 path.
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)
    # Pad bottom/right so both dims become multiples of 64.
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop the padding away, restoring the pre-pad target size.
        return safer_memory(x[:H_target, :W_target, ...])
    return safer_memory(img_padded), remove_pad
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """
    Pad an image to height/width multiples of 64 without resizing.

    Used for the resolution == -1 path: inference runs at the original
    resolution, and the returned remove_pad() crops results back to the
    original HxW.
    """
    img = HWC3(img_u8)
    height, width = img.shape[0], img.shape[1]
    extra_h = pad64(height)
    extra_w = pad64(width)
    padded = np.pad(img, [[0, extra_h], [0, extra_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the pre-padding size.
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
# --------------------------------------------------------------------------------------
# RGBA rules (as you requested)
# --------------------------------------------------------------------------------------
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Split an image into (rgb, alpha).

    RGBA input: returns RGB composited over a WHITE background plus a copy of
    the original alpha channel. Anything else: returns the 3-channel image
    and None for alpha.
    """
    is_rgba = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not is_rgba:
        # force to RGB
        return HWC3(inp_u8), None
    rgba = inp_u8.astype(np.uint8)
    color = rgba[:, :, :3].astype(np.float32)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    over_white = (color * alpha + 255.0 * (1.0 - alpha)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """
    Multiply the depth map by the alpha channel (composite over BLACK).

    Requested output rule: attach alpha to the depth RGB (conceptually RGBA),
    composite over a black background, and return RGB — which is exactly
    depth_rgb * alpha.
    """
    rgb = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (rgb * weight).clip(0, 255).astype(np.uint8)
# --------------------------------------------------------------------------------------
# ComfyUI conversion helpers
# --------------------------------------------------------------------------------------
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """
    Convert a Comfy IMAGE tensor to a uint8 HWC array.

    Comfy IMAGE is float in [0..1], shape [H,W,C] or [B,H,W,C]; for a batch
    only the first image is converted.
    """
    tensor = img[0] if img.ndim == 4 else img
    clamped = tensor.detach().cpu().float().clamp(0, 1)
    return (clamped.numpy() * 255.0).round().astype(np.uint8)
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 HWC image to a Comfy IMAGE tensor [1,H,W,C] in [0..1]."""
    rgb = HWC3(img_u8)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled)[None, ...]  # add batch dim -> [1,H,W,C]
# --------------------------------------------------------------------------------------
# Pipeline loading (local-first, then zoe fallback)
# --------------------------------------------------------------------------------------
# Process-wide cache so repeated node executions reuse an already-built pipeline.
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}  # (model_source, device_str) -> pipeline
def _try_load_pipeline(model_source: str, device: torch.device, log):
    """
    Build (or fetch from cache) a transformers depth-estimation pipeline.

    Mirrors the Zoe node: device= is intentionally NOT passed to pipeline();
    the model is moved to the device after construction instead. Raises
    RuntimeError when transformers failed to import at module load.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")
    cache_key = (model_source, str(device))
    cached = _PIPE_CACHE.get(cache_key)
    if cached is not None:
        log(f"[SaliaDepth] Using cached pipeline for source='{model_source}' device='{device}'")
        return cached
    log(f"[SaliaDepth] Creating pipeline(task='depth-estimation', model='{model_source}')")
    pipe = pipeline(task="depth-estimation", model=model_source)
    # Try to move model to torch device, like ZoeDetector.to()
    try:
        pipe.model = pipe.model.to(device)
        pipe.device = device  # Zoe code sets this; newer transformers uses torch.device internally
        log(f"[SaliaDepth] Moved pipeline model to device: {device}")
    except Exception as exc:
        log(f"[SaliaDepth] WARN: Could not move pipeline model to device {device}: {repr(exc)}")
    # Log config info for debugging
    try:
        cfg = pipe.model.config
        log(f"[SaliaDepth] Model class: {pipe.model.__class__.__name__}")
        log(f"[SaliaDepth] Config class: {cfg.__class__.__name__}")
        log(f"[SaliaDepth] Config model_type: {getattr(cfg, 'model_type', '')}")
        log(f"[SaliaDepth] Config _name_or_path: {getattr(cfg, '_name_or_path', '')}")
    except Exception as exc:
        log(f"[SaliaDepth] WARN: Could not log model config: {repr(exc)}")
    _PIPE_CACHE[cache_key] = pipe
    return pipe
def get_depth_pipeline(device: torch.device, log):
    """
    Resolve a depth pipeline, local-first.

    Order: 1) ensure assets/depth files exist (downloading if missing),
    2) load from the local directory, 3) fall back to the Zoe hub repo,
    4) return None when everything failed.
    """
    # Always log HF cache info (helps locate where fallback downloads go)
    log("[SaliaDepth] ===== Hugging Face cache info (fallback path) =====")
    for key, value in _hf_cache_info().items():
        if value:
            log(f"[SaliaDepth] {key} = {value}")
    log(f"[SaliaDepth] Zoe fallback repo id: {ZOE_FALLBACK_REPO_ID}")
    # Local-first
    if ensure_local_model_files(log):
        try:
            log(f"[SaliaDepth] Trying LOCAL model from directory: {MODEL_DIR}")
            return _try_load_pipeline(str(MODEL_DIR), device, log)
        except Exception as exc:
            log(f"[SaliaDepth] Local model load FAILED: {repr(exc)}")
    # Fallback
    try:
        log(f"[SaliaDepth] Trying ZOE fallback model: {ZOE_FALLBACK_REPO_ID}")
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device, log)
    except Exception as exc:
        log(f"[SaliaDepth] Zoe fallback load FAILED: {repr(exc)}")
        return None
# --------------------------------------------------------------------------------------
# Depth inference (Zoe-style)
# --------------------------------------------------------------------------------------
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    log,
    upscale_method: str = "INTER_CUBIC"
) -> np.ndarray:
    """
    Run depth estimation the way the original ZoeDetector.__call__ does.

    Preprocessing:
        detect_resolution == -1 -> keep original size, pad to multiple of 64
        otherwise               -> min-side resize to detect_resolution, pad to 64

    Postprocessing matches the Zoe code exactly: shift/scale by the 2nd/85th
    percentiles, invert, map to uint8, then crop the padding away.

    Returns:
        uint8 RGB depth map at the (unpadded) working resolution.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
        log(f"[SaliaDepth] Preprocess: resolution=-1 (no resize), padded to 64. work={work_img.shape}")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
            log=log
        )
        log(f"[SaliaDepth] Preprocess: min-side resized to {detect_resolution}, padded to 64. work={work_img.shape}")
    pil_image = Image.fromarray(work_img)
    with torch.no_grad():
        result = pipe(pil_image)
    # Fix over the original: the old if/else on isinstance(depth, Image.Image)
    # had two byte-identical branches — np.array(...) handles both PIL images
    # and array-likes the same way, so one call suffices.
    depth_array = np.array(result["depth"], dtype=np.float32)
    # EXACT normalization like the Zoe code: robust range via percentiles.
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))
    log(f"[SaliaDepth] Depth raw stats: shape={depth_array.shape} vmin(p2)={vmin:.6f} vmax(p85)={vmax:.6f} mean={float(depth_array.mean()):.6f}")
    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        # avoid division by zero on constant depth maps; log it
        log("[SaliaDepth] WARN: vmax==vmin; forcing denom epsilon to avoid NaNs.")
        denom = 1e-6
    depth_array = depth_array / denom
    # EXACT invert like the Zoe code.
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)
    detected_map = remove_pad(HWC3(depth_image))
    log(f"[SaliaDepth] Output (post-remove_pad): {detected_map.shape} dtype={detected_map.dtype}")
    return detected_map
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int, log) -> np.ndarray:
    """
    Scale a depth map back to the original input dimensions (w0 x h0).

    cv2 bilinear is preferred; PIL bilinear is the fallback when cv2 is
    missing or its resize raises.
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception as exc:
        log(f"[SaliaDepth] WARN: cv2 resize failed ({repr(exc)}); using PIL.")
        fallback = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(fallback, dtype=np.uint8)
# --------------------------------------------------------------------------------------
# ComfyUI Node
# --------------------------------------------------------------------------------------
class Salia_Depth_Preprocessor:
    """ComfyUI node: Zoe-style depth preprocessor with local-first model loading
    and a STRING output carrying the full diagnostic log."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                # note: default -1, min -1 (-1 = run at the original resolution)
                "resolution": ("INT", {"default": -1, "min": -1, "max": 8192, "step": 1}),
            }
        }

    # 2 outputs: image + log string
    RETURN_TYPES = ("IMAGE", "STRING")
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors/Normal and Depth Estimators"

    def execute(self, image, resolution=-1):
        """Run depth estimation on a (possibly batched) Comfy IMAGE.

        Returns (depth IMAGE batch, log text). Failures never raise: with no
        pipeline the input is returned unchanged; a per-item failure passes
        that item through while the rest of the batch is still processed.
        """
        lines, log = _make_logger()
        log("[SaliaDepth] ==================================================")
        log("[SaliaDepth] SaliaDepthPreprocessor starting")
        log(f"[SaliaDepth] resolution input = {resolution}")
        # Get torch device
        try:
            device = model_management.get_torch_device()
        except Exception as e:
            device = torch.device("cpu")
            log(f"[SaliaDepth] WARN: model_management.get_torch_device failed: {repr(e)} -> using CPU")
        log(f"[SaliaDepth] torch device = {device}")
        # Load pipeline
        pipe = None
        try:
            pipe = get_depth_pipeline(device, log)
        except Exception as e:
            log(f"[SaliaDepth] ERROR: get_depth_pipeline crashed: {repr(e)}")
            pipe = None
        if pipe is None:
            log("[SaliaDepth] FATAL: No pipeline available. Returning input image unchanged.")
            return (image, "\n".join(lines))
        # Batch support: promote a single [H,W,C] image to a batch of one.
        if image.ndim == 3:
            image = image.unsqueeze(0)
        outs = []
        for i in range(image.shape[0]):
            try:
                # Original dimensions
                h0 = int(image[i].shape[0])
                w0 = int(image[i].shape[1])
                c0 = int(image[i].shape[2])
                log(f"[SaliaDepth] ---- Batch index {i} input shape = ({h0},{w0},{c0}) ----")
                inp_u8 = comfy_tensor_to_u8(image[i])
                # RGBA rule (pre): composite over white for inference, keep alpha.
                rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
                had_rgba = alpha_u8 is not None
                log(f"[SaliaDepth] had_rgba={had_rgba}")
                # Run depth (Zoe-style)
                depth_rgb = depth_estimate_zoe_style(
                    pipe=pipe,
                    input_rgb_u8=rgb_for_depth,
                    detect_resolution=int(resolution),
                    log=log,
                    upscale_method="INTER_CUBIC"
                )
                # Resize back to original input size
                depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0, log=log)
                # RGBA rule (post): darken depth where the input was transparent.
                if had_rgba:
                    # Use original alpha at original size.
                    # If alpha size differs, resize alpha to match.
                    if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                        log("[SaliaDepth] Alpha size mismatch; resizing alpha to original size.")
                        try:
                            import cv2
                            alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                        except Exception:
                            pil_a = Image.fromarray(alpha_u8)
                            pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                            alpha_u8 = np.array(pil_a, dtype=np.uint8)
                    # "Put alpha on RGB turning it into RGBA, then put BLACK background behind it, then back to RGB"
                    depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
                    log("[SaliaDepth] Applied RGBA post-step (alpha + black background).")
                outs.append(u8_to_comfy_tensor(depth_rgb))
            except Exception as e:
                log(f"[SaliaDepth] ERROR: Inference failed at batch index {i}: {repr(e)}")
                log("[SaliaDepth] Passing through original input image for this batch item.")
                # NOTE(review): a failed 4-channel item is re-appended as-is while
                # successful items are 3-channel — torch.cat below would then fail
                # on mismatched channel counts. Confirm whether mixed batches with
                # partial failures are expected here.
                outs.append(image[i].unsqueeze(0))
        out = torch.cat(outs, dim=0)
        log("[SaliaDepth] Done.")
        return (out, "\n".join(lines))
# Registration tables read by ComfyUI to discover and label this node.
NODE_CLASS_MAPPINGS = {
    "SaliaDepthPreprocessor": Salia_Depth_Preprocessor
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "SaliaDepthPreprocessor": "Salia Depth (local assets/depth + logs)"
}