""" Salia Ultralytics Detector Provider (ComfyUI custom node) Goal: - Provide the same outputs as Impact-Subpack's `UltralyticsDetectorProvider`: - BBOX_DETECTOR - SEGM_DETECTOR - But packaged so you can drop it into your own custom node folder (your Salia_* environment) without requiring ComfyUI-Impact-Subpack. Notes: - This file intentionally keeps dependencies minimal and self-contained. - It uses `ultralytics.YOLO` to run `.pt` models directly (no TensorRT build step). - For PyTorch >= 2.6, `torch.load` defaults to `weights_only=True` which can break legacy `.pt` checkpoints. This file adds an OPTIONAL whitelist-based fallback to `weights_only=False` (unsafe) for specifically trusted model filenames. """ from __future__ import annotations import os import logging import pickle from datetime import datetime from contextlib import contextmanager from collections import namedtuple import folder_paths from PIL import Image import numpy as np import torch import torch.nn.functional as F try: import cv2 # opencv-python or opencv-python-headless except Exception: cv2 = None # --------------------------- # Model folders (same layout as Impact Subpack) # --------------------------- _SUPPORTED_PT_EXTS = getattr(folder_paths, "supported_pt_extensions", [".pt", ".pth", ".ckpt", ".safetensors"]) def _add_folder_path_and_extensions(folder_name: str, paths: list[str], extensions: list[str] | tuple[str, ...]): """Add/merge a folder_paths entry without depending on Impact-Pack helpers.""" if folder_name in folder_paths.folder_names_and_paths: existing_paths, existing_exts = folder_paths.folder_names_and_paths[folder_name] merged_paths = list(existing_paths) for p in paths: if p not in merged_paths: merged_paths.append(p) merged_exts = list(existing_exts) for ext in extensions: if ext not in merged_exts: merged_exts.append(ext) folder_paths.folder_names_and_paths[folder_name] = (merged_paths, tuple(merged_exts)) else: folder_paths.folder_names_and_paths[folder_name] = (list(paths), tuple(extensions)) def _update_model_paths(base_path: str): """Register standard Impact-Subpack ultralytics model locations.""" _add_folder_path_and_extensions( "ultralytics_bbox", [os.path.join(base_path, "ultralytics", "bbox")], _SUPPORTED_PT_EXTS, ) _add_folder_path_and_extensions( "ultralytics_segm", [os.path.join(base_path, "ultralytics", "segm")], _SUPPORTED_PT_EXTS, ) _add_folder_path_and_extensions( "ultralytics", [os.path.join(base_path, "ultralytics")], _SUPPORTED_PT_EXTS, ) # Register common folders (models_dir + ComfyUI-Manager download_model_base) _update_model_paths(folder_paths.models_dir) if "download_model_base" in folder_paths.folder_names_and_paths: try: _update_model_paths(folder_paths.get_folder_paths("download_model_base")[0]) except Exception: pass # Also register local folder(s) inside THIS custom-node extension, so you can keep # models next to your Salia_*.py files if you want. _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) for local_dir in [ os.path.join(_THIS_DIR, "nodes"), os.path.join(_THIS_DIR, "models"), _THIS_DIR, ]: if os.path.isdir(local_dir): _add_folder_path_and_extensions("ultralytics_bbox", [local_dir], _SUPPORTED_PT_EXTS) _add_folder_path_and_extensions("ultralytics_segm", [local_dir], _SUPPORTED_PT_EXTS) _add_folder_path_and_extensions("ultralytics", [local_dir], _SUPPORTED_PT_EXTS) # --------------------------- # Optional safe-load fallback (PyTorch >= 2.6) # --------------------------- _ORIG_TORCH_LOAD = torch.load def _get_whitelist_file() -> str | None: """Create/return the whitelist file path under ComfyUI's user directory.""" try: user_dir = folder_paths.get_user_directory() except Exception: user_dir = None if not user_dir or not os.path.isdir(user_dir): return None wl_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics") wl_file = os.path.join(wl_dir, "model-whitelist.txt") try: os.makedirs(wl_dir, exist_ok=True) if not os.path.exists(wl_file): with open(wl_file, "w", encoding="utf-8") as f: f.write("# Add base filenames of trusted legacy models here (one per line).\n") f.write("# Example: eyes.pt\n") f.write("# These will be allowed to load with weights_only=False if safe loading fails.\n") f.write("# WARNING: Only add models you trust.\n") except Exception: return None return wl_file _WHITELIST_PATH = _get_whitelist_file() # --------------------------- # Model path logging (requested) # --------------------------- def _get_model_load_log_file() -> str: """ Log file path used to record which ultralytics model file was actually loaded. Prefer the same ComfyUI user dir used for the whitelist (if available). """ # If whitelist exists, put log next to it (same directory). if _WHITELIST_PATH: base_dir = os.path.dirname(_WHITELIST_PATH) return os.path.join(base_dir, "model-load-log.txt") # Fallback: try ComfyUI user directory try: user_dir = folder_paths.get_user_directory() except Exception: user_dir = None if user_dir and os.path.isdir(user_dir): base_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics") try: os.makedirs(base_dir, exist_ok=True) except Exception: pass return os.path.join(base_dir, "model-load-log.txt") # Last resort: next to this python file return os.path.join(_THIS_DIR, "model-load-log.txt") _MODEL_LOAD_LOG_PATH = _get_model_load_log_file() def _find_all_model_paths(model_name: str) -> list[str]: """ Find all possible on-disk matches across the registered ultralytics folders. Useful if the same filename exists in multiple locations. """ matches: list[str] = [] try: ultra_roots = folder_paths.get_folder_paths("ultralytics") except Exception: ultra_roots = [] try: bbox_roots = folder_paths.get_folder_paths("ultralytics_bbox") except Exception: bbox_roots = [] try: segm_roots = folder_paths.get_folder_paths("ultralytics_segm") except Exception: segm_roots = [] def add_if_exists(root: str, rel: str): p = os.path.join(root, rel) if os.path.exists(p): matches.append(os.path.abspath(p)) # model_name might be "bbox/foo.pt" or "segm/foo.pt" (includes subfolder) for r in ultra_roots: add_if_exists(r, model_name) # Also search the specialized bbox/segm roots with the prefix stripped if model_name.startswith("bbox/"): rel = model_name[5:] for r in bbox_roots: add_if_exists(r, rel) elif model_name.startswith("segm/"): rel = model_name[5:] for r in segm_roots: add_if_exists(r, rel) # De-dupe preserving order out: list[str] = [] seen = set() for p in matches: if p not in seen: seen.add(p) out.append(p) return out def _log_selected_model(model_name: str, model_path: str, matches: list[str] | None = None): """ Prints the resolved model path to console AND appends it to a log file. """ # 1) Console output print(f"[Salia Ultralytics] Selected model_name: {model_name}") print(f"[Salia Ultralytics] Resolved model_path: {model_path}") if matches and len(matches) > 1: print("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):") for p in matches: print(f" - {p}") print(f"[Salia Ultralytics] Model load log file: {_MODEL_LOAD_LOG_PATH}") # Also emit to python logging (ComfyUI typically captures this) logging.info("[Salia Ultralytics] Selected model_name: %s", model_name) logging.info("[Salia Ultralytics] Resolved model_path: %s", model_path) if matches and len(matches) > 1: logging.warning("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):") for p in matches: logging.warning(" - %s", p) logging.info("[Salia Ultralytics] Model load log file: %s", _MODEL_LOAD_LOG_PATH) # 2) File append try: ts = datetime.now().isoformat(timespec="seconds") exists = os.path.isfile(model_path) size = os.path.getsize(model_path) if exists else -1 log_dir = os.path.dirname(_MODEL_LOAD_LOG_PATH) if log_dir: os.makedirs(log_dir, exist_ok=True) with open(_MODEL_LOAD_LOG_PATH, "a", encoding="utf-8") as f: f.write(f"{ts}\t{model_name}\t{model_path}\texists={exists}\tsize={size}\n") if matches and len(matches) > 1: for p in matches: f.write(f"{ts}\tmatch\t{p}\n") except Exception as e: logging.warning("[Salia Ultralytics] Failed to write model-load log to %s: %s", _MODEL_LOAD_LOG_PATH, e) def _load_whitelist(filepath: str | None) -> set[str]: if not filepath: return set() try: approved: set[str] = set() with open(filepath, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line and not line.startswith("#"): approved.add(os.path.basename(line)) return approved except Exception: return set() _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH) def _torch_load_wrapper(*args, **kwargs): """Try safe load first; if it fails due to weights-only restrictions, allow fallback if whitelisted.""" filename = None if args and isinstance(args[0], str): filename = os.path.basename(args[0]) elif isinstance(kwargs.get("f"), str): filename = os.path.basename(kwargs["f"]) try: return _ORIG_TORCH_LOAD(*args, **kwargs) except pickle.UnpicklingError as e: msg = str(e) # Heuristic: this is the common PyTorch >=2.6 safe-load failure mode. maybe_weights_only_error = ( "Weights only load failed" in msg or "Unsupported global" in msg or "disallowed" in msg or "not allowed" in msg or "getattr" in msg ) if not maybe_weights_only_error: raise # Refresh whitelist from disk (so users can edit without restarting, sometimes) global _MODEL_WHITELIST _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH) if filename and filename in _MODEL_WHITELIST: logging.warning( "[Salia Ultralytics] Safe torch.load failed for '%s'. Retrying with weights_only=False because it's whitelisted (%s).", filename, _WHITELIST_PATH, ) retry_kwargs = dict(kwargs) retry_kwargs["weights_only"] = False return _ORIG_TORCH_LOAD(*args, **retry_kwargs) logging.error( "[Salia Ultralytics] Blocked unsafe model load for '%s'.\n" "Safe loading failed and the file is not whitelisted.\n" "If you TRUST this model, add its base name to: %s", filename or "[unknown]", _WHITELIST_PATH or "[whitelist path unavailable]", ) raise @contextmanager def _patched_torch_load_for_ultralytics(): """Patch torch.load only while ultralytics loads a checkpoint.""" # If PyTorch doesn't even have the safe-loader feature, don't patch. if not hasattr(torch.serialization, "safe_globals"): yield return prev = torch.load torch.load = _torch_load_wrapper try: yield finally: torch.load = prev def _load_yolo(model_path: str): """Load an Ultralytics YOLO model (with optional safe-load fallback).""" try: from ultralytics import YOLO # lazy import except Exception as e: raise ImportError( "[Salia Ultralytics] ultralytics is not installed. Install it in your ComfyUI env, e.g.:\n" "pip install ultralytics" ) from e with _patched_torch_load_for_ultralytics(): return YOLO(model_path) # --------------------------- # Minimal Impact-compatible utilities (self-contained) # --------------------------- def _tensor2np_rgb(image: torch.Tensor) -> np.ndarray: """Convert a ComfyUI IMAGE tensor to a uint8 RGB numpy image.""" # ComfyUI image is usually: (B,H,W,C) float in [0,1] if not isinstance(image, torch.Tensor): raise TypeError(f"Expected torch.Tensor, got {type(image)}") if image.dim() == 4: img = image[0] else: img = image img = img.detach() if img.is_cuda: img = img.cpu() img = img.clamp(0, 1).numpy() if img.shape[-1] == 1: img = np.repeat(img, 3, axis=-1) img_u8 = (img * 255.0).round().astype(np.uint8) return img_u8 def tensor2pil(image: torch.Tensor) -> Image.Image: return Image.fromarray(_tensor2np_rgb(image)) def make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float, crop_min_size: int | None = None): x1, y1, x2, y2 = [float(v) for v in bbox_xyxy] bbox_w = max(1.0, x2 - x1) bbox_h = max(1.0, y2 - y1) crop_w = bbox_w * float(crop_factor) crop_h = bbox_h * float(crop_factor) if crop_min_size is not None: crop_w = max(crop_w, float(crop_min_size)) crop_h = max(crop_h, float(crop_min_size)) cx = (x1 + x2) / 2.0 cy = (y1 + y2) / 2.0 rx1 = int(round(cx - crop_w / 2.0)) ry1 = int(round(cy - crop_h / 2.0)) rx2 = int(round(cx + crop_w / 2.0)) ry2 = int(round(cy + crop_h / 2.0)) rx1 = max(0, min(w - 1, rx1)) ry1 = max(0, min(h - 1, ry1)) rx2 = max(rx1 + 1, min(w, rx2)) ry2 = max(ry1 + 1, min(h, ry2)) return (rx1, ry1, rx2, ry2) def crop_image(image: torch.Tensor, crop_region): x1, y1, x2, y2 = crop_region x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) if image.dim() == 4: return image[:, y1:y2, x1:x2, :] if image.dim() == 3: return image[y1:y2, x1:x2, :] raise ValueError(f"Unexpected image tensor shape: {tuple(image.shape)}") def crop_ndarray2(arr: np.ndarray, crop_region): x1, y1, x2, y2 = crop_region x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) return arr[y1:y2, x1:x2] def dilate_masks(segmasks, dilation: int): if dilation <= 0: return segmasks if cv2 is None: raise ImportError( "[Salia Ultralytics] opencv-python is required for mask dilation but cv2 could not be imported.\n" "Install: pip install opencv-python-headless" ) k = int(dilation) ksize = k * 2 + 1 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize)) out = [] for bbox, mask, conf in segmasks: m = (mask > 0.5).astype(np.uint8) * 255 m = cv2.dilate(m, kernel, iterations=1) out.append((bbox, (m > 0).astype(np.float32), conf)) return out def combine_masks(segmasks, out_shape_hw: tuple[int, int] | None = None) -> torch.Tensor: if not segmasks: if out_shape_hw is None: return torch.zeros((1, 1, 1), dtype=torch.float32) h, w = out_shape_hw return torch.zeros((1, h, w), dtype=torch.float32) base = segmasks[0][1] combined = np.zeros_like(base, dtype=np.float32) for _, m, _ in segmasks: combined = np.maximum(combined, m.astype(np.float32)) return torch.from_numpy(combined).unsqueeze(0) # --------------------------- # Impact-compatible detector wrapper objects # --------------------------- SEG = namedtuple( "SEG", [ "cropped_image", "cropped_mask", "confidence", "crop_region", "bbox", "label", "control_net_wrapper", ], defaults=[None], ) class NO_BBOX_DETECTOR: pass class NO_SEGM_DETECTOR: pass def _create_segmasks(results): # results = [labels, bboxes_xyxy, segms, confs] bboxes = results[1] segms = results[2] confs = results[3] out = [] for i in range(len(segms)): out.append((bboxes[i], segms[i].astype(np.float32), confs[i])) return out def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""): pred = model(image_pil, conf=confidence, device=device) bboxes = pred[0].boxes.xyxy.cpu().numpy() # xyxy if bboxes.shape[0] == 0: return [[], [], [], []] # Make simple rectangle masks for each bbox np_img = np.array(image_pil) if np_img.ndim == 2: h, w = np_img.shape else: h, w = np_img.shape[0], np_img.shape[1] segms = [] for x0, y0, x1, y1 in bboxes: m = np.zeros((h, w), dtype=np.uint8) x0i, y0i, x1i, y1i = int(x0), int(y0), int(x1), int(y1) x0i = max(0, min(w - 1, x0i)) x1i = max(0, min(w, x1i)) y0i = max(0, min(h - 1, y0i)) y1i = max(0, min(h, y1i)) if cv2 is not None: cv2.rectangle(m, (x0i, y0i), (x1i, y1i), 255, -1) else: m[y0i:y1i, x0i:x1i] = 255 segms.append((m > 0)) labels = [] confs = [] for i in range(len(bboxes)): labels.append(pred[0].names[int(pred[0].boxes[i].cls.item())]) confs.append(pred[0].boxes[i].conf.detach().cpu().numpy()) return [labels, list(bboxes), segms, confs] def _inference_segm(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""): pred = model(image_pil, conf=confidence, device=device) bboxes = pred[0].boxes.xyxy.cpu().numpy() # xyxy if bboxes.shape[0] == 0: return [[], [], [], []] if pred[0].masks is None or pred[0].masks.data is None: # fallback: no masks, treat like bbox return _inference_bbox(model, image_pil, confidence=confidence, device=device) segms = pred[0].masks.data.detach().cpu().numpy() # (n, h, w) in model-space # Resize masks back to original image size h_orig = image_pil.size[1] w_orig = image_pil.size[0] results = [[], [], [], []] for i in range(len(bboxes)): results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())]) results[1].append(bboxes[i]) mask = torch.from_numpy(segms[i]).float() mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(h_orig, w_orig), mode="bilinear", align_corners=False) mask = mask.squeeze(0).squeeze(0) results[2].append(mask.numpy()) results[3].append(pred[0].boxes[i].conf.detach().cpu().numpy()) return results class SaliaUltraBBoxDetector: def __init__(self, model): self.model = model def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): drop_size = max(int(drop_size), 1) detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold)) segmasks = _create_segmasks(detected) if int(dilation) > 0: segmasks = dilate_masks(segmasks, int(dilation)) items = [] h = image.shape[1] w = image.shape[2] for (bbox, mask, conf), label in zip(segmasks, detected[0]): x1, y1, x2, y2 = bbox if (x2 - x1) > drop_size and (y2 - y1) > drop_size: crop_region = make_crop_region(w, h, bbox, float(crop_factor)) if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"): crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region) cropped_image = crop_image(image, crop_region) cropped_mask = crop_ndarray2(mask, crop_region) items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None)) segs = (image.shape[1], image.shape[2]), items if detailer_hook is not None and hasattr(detailer_hook, "post_detection"): segs = detailer_hook.post_detection(segs) return segs def detect_combined(self, image, threshold, dilation): detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold)) segmasks = _create_segmasks(detected) if int(dilation) > 0: segmasks = dilate_masks(segmasks, int(dilation)) return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2])) def setAux(self, x): pass class SaliaUltraSegmDetector: def __init__(self, model): self.model = model def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None): drop_size = max(int(drop_size), 1) detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold)) segmasks = _create_segmasks(detected) if int(dilation) > 0: segmasks = dilate_masks(segmasks, int(dilation)) items = [] h = image.shape[1] w = image.shape[2] for (bbox, mask, conf), label in zip(segmasks, detected[0]): x1, y1, x2, y2 = bbox if (x2 - x1) > drop_size and (y2 - y1) > drop_size: crop_region = make_crop_region(w, h, bbox, float(crop_factor)) if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"): crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region) cropped_image = crop_image(image, crop_region) cropped_mask = crop_ndarray2(mask, crop_region) items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None)) segs = (image.shape[1], image.shape[2]), items if detailer_hook is not None and hasattr(detailer_hook, "post_detection"): segs = detailer_hook.post_detection(segs) return segs def detect_combined(self, image, threshold, dilation): detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold)) segmasks = _create_segmasks(detected) if int(dilation) > 0: segmasks = dilate_masks(segmasks, int(dilation)) return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2])) def setAux(self, x): pass # --------------------------- # The actual ComfyUI Node # --------------------------- class SaliaUltralyticsDetectorProvider2: """Load an Ultralytics `.pt` model and provide Impact-compatible detectors.""" @classmethod def INPUT_TYPES(cls): bboxs = ["bbox/" + x for x in folder_paths.get_filename_list("ultralytics_bbox")] segms = ["segm/" + x for x in folder_paths.get_filename_list("ultralytics_segm")] return {"required": {"model_name": (bboxs + segms,)}} RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR") FUNCTION = "doit" CATEGORY = "Salia/Detectors" def doit(self, model_name: str): # First, allow selecting a file like "bbox/foo.pt" that lives under models/ultralytics/bbox model_path = folder_paths.get_full_path("ultralytics", model_name) if model_path is None: if model_name.startswith("bbox/"): model_path = folder_paths.get_full_path("ultralytics_bbox", model_name[5:]) elif model_name.startswith("segm/"): model_path = folder_paths.get_full_path("ultralytics_segm", model_name[5:]) if model_path is None: cands = [] try: cands.extend(folder_paths.get_folder_paths("ultralytics")) if model_name.startswith("bbox/"): cands.extend(folder_paths.get_folder_paths("ultralytics_bbox")) elif model_name.startswith("segm/"): cands.extend(folder_paths.get_folder_paths("ultralytics_segm")) except Exception: pass formatted = "\n\t".join(cands) raise ValueError( f"[Salia Ultralytics] model file '{model_name}' was not found.\n" f"Searched these folders:\n\t{formatted}\n" f"Tip: put bbox models in 'models/ultralytics/bbox' or segm models in 'models/ultralytics/segm'." ) # NEW: print + log the resolved on-disk path (and any duplicates) matches = _find_all_model_paths(model_name) _log_selected_model(model_name, os.path.abspath(model_path), matches) model = _load_yolo(model_path) if model_name.startswith("bbox/"): return SaliaUltraBBoxDetector(model), NO_SEGM_DETECTOR() else: return SaliaUltraBBoxDetector(model), SaliaUltraSegmDetector(model) NODE_CLASS_MAPPINGS = { "SaliaUltralyticsDetectorProvider2": SaliaUltralyticsDetectorProvider2, } NODE_DISPLAY_NAME_MAPPINGS = { "SaliaUltralyticsDetectorProvider2": "Salia Ultralytics Detector 2 (Salia)", }