saliacoel commited on
Commit
01cf9f1
·
verified ·
1 Parent(s): 673a4b7

Upload Salia_UltralyticsDetectorProvider2.py

Browse files
Files changed (1) hide show
  1. Salia_UltralyticsDetectorProvider2.py +741 -0
Salia_UltralyticsDetectorProvider2.py ADDED
@@ -0,0 +1,741 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Salia Ultralytics Detector Provider (ComfyUI custom node)
3
+
4
+ Goal:
5
+ - Provide the same outputs as Impact-Subpack's `UltralyticsDetectorProvider`:
6
+ - BBOX_DETECTOR
7
+ - SEGM_DETECTOR
8
+ - But packaged so you can drop it into your own custom node folder (your Salia_* environment)
9
+ without requiring ComfyUI-Impact-Subpack.
10
+
11
+ Notes:
12
+ - This file intentionally keeps dependencies minimal and self-contained.
13
+ - It uses `ultralytics.YOLO` to run `.pt` models directly (no TensorRT build step).
14
+ - For PyTorch >= 2.6, `torch.load` defaults to `weights_only=True` which can break
15
+ legacy `.pt` checkpoints. This file adds an OPTIONAL whitelist-based fallback
16
+ to `weights_only=False` (unsafe) for specifically trusted model filenames.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import os
22
+ import logging
23
+ import pickle
24
+ from datetime import datetime
25
+ from contextlib import contextmanager
26
+ from collections import namedtuple
27
+
28
+ import folder_paths
29
+
30
+ from PIL import Image
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn.functional as F
34
+
35
+ try:
36
+ import cv2 # opencv-python or opencv-python-headless
37
+ except Exception:
38
+ cv2 = None
39
+
40
+
41
+ # ---------------------------
42
+ # Model folders (same layout as Impact Subpack)
43
+ # ---------------------------
44
+
45
+ _SUPPORTED_PT_EXTS = getattr(folder_paths, "supported_pt_extensions", [".pt", ".pth", ".ckpt", ".safetensors"])
46
+
47
+
48
+ def _add_folder_path_and_extensions(folder_name: str, paths: list[str], extensions: list[str] | tuple[str, ...]):
49
+ """Add/merge a folder_paths entry without depending on Impact-Pack helpers."""
50
+ if folder_name in folder_paths.folder_names_and_paths:
51
+ existing_paths, existing_exts = folder_paths.folder_names_and_paths[folder_name]
52
+ merged_paths = list(existing_paths)
53
+ for p in paths:
54
+ if p not in merged_paths:
55
+ merged_paths.append(p)
56
+ merged_exts = list(existing_exts)
57
+ for ext in extensions:
58
+ if ext not in merged_exts:
59
+ merged_exts.append(ext)
60
+ folder_paths.folder_names_and_paths[folder_name] = (merged_paths, tuple(merged_exts))
61
+ else:
62
+ folder_paths.folder_names_and_paths[folder_name] = (list(paths), tuple(extensions))
63
+
64
+
65
+ def _update_model_paths(base_path: str):
66
+ """Register standard Impact-Subpack ultralytics model locations."""
67
+ _add_folder_path_and_extensions(
68
+ "ultralytics_bbox",
69
+ [os.path.join(base_path, "ultralytics", "bbox")],
70
+ _SUPPORTED_PT_EXTS,
71
+ )
72
+ _add_folder_path_and_extensions(
73
+ "ultralytics_segm",
74
+ [os.path.join(base_path, "ultralytics", "segm")],
75
+ _SUPPORTED_PT_EXTS,
76
+ )
77
+ _add_folder_path_and_extensions(
78
+ "ultralytics",
79
+ [os.path.join(base_path, "ultralytics")],
80
+ _SUPPORTED_PT_EXTS,
81
+ )
82
+
83
+
84
+ # Register common folders (models_dir + ComfyUI-Manager download_model_base)
85
+ _update_model_paths(folder_paths.models_dir)
86
+ if "download_model_base" in folder_paths.folder_names_and_paths:
87
+ try:
88
+ _update_model_paths(folder_paths.get_folder_paths("download_model_base")[0])
89
+ except Exception:
90
+ pass
91
+
92
+ # Also register local folder(s) inside THIS custom-node extension, so you can keep
93
+ # models next to your Salia_*.py files if you want.
94
+ _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
95
+ for local_dir in [
96
+ os.path.join(_THIS_DIR, "nodes"),
97
+ os.path.join(_THIS_DIR, "models"),
98
+ _THIS_DIR,
99
+ ]:
100
+ if os.path.isdir(local_dir):
101
+ _add_folder_path_and_extensions("ultralytics_bbox", [local_dir], _SUPPORTED_PT_EXTS)
102
+ _add_folder_path_and_extensions("ultralytics_segm", [local_dir], _SUPPORTED_PT_EXTS)
103
+ _add_folder_path_and_extensions("ultralytics", [local_dir], _SUPPORTED_PT_EXTS)
104
+
105
+
106
+ # ---------------------------
107
+ # Optional safe-load fallback (PyTorch >= 2.6)
108
+ # ---------------------------
109
+
110
+ _ORIG_TORCH_LOAD = torch.load
111
+
112
+
113
+ def _get_whitelist_file() -> str | None:
114
+ """Create/return the whitelist file path under ComfyUI's user directory."""
115
+ try:
116
+ user_dir = folder_paths.get_user_directory()
117
+ except Exception:
118
+ user_dir = None
119
+
120
+ if not user_dir or not os.path.isdir(user_dir):
121
+ return None
122
+
123
+ wl_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
124
+ wl_file = os.path.join(wl_dir, "model-whitelist.txt")
125
+ try:
126
+ os.makedirs(wl_dir, exist_ok=True)
127
+ if not os.path.exists(wl_file):
128
+ with open(wl_file, "w", encoding="utf-8") as f:
129
+ f.write("# Add base filenames of trusted legacy models here (one per line).\n")
130
+ f.write("# Example: eyes.pt\n")
131
+ f.write("# These will be allowed to load with weights_only=False if safe loading fails.\n")
132
+ f.write("# WARNING: Only add models you trust.\n")
133
+ except Exception:
134
+ return None
135
+
136
+ return wl_file
137
+
138
+
139
+ _WHITELIST_PATH = _get_whitelist_file()
140
+
141
+
142
+ # ---------------------------
143
+ # Model path logging (requested)
144
+ # ---------------------------
145
+
146
+ def _get_model_load_log_file() -> str:
147
+ """
148
+ Log file path used to record which ultralytics model file was actually loaded.
149
+ Prefer the same ComfyUI user dir used for the whitelist (if available).
150
+ """
151
+ # If whitelist exists, put log next to it (same directory).
152
+ if _WHITELIST_PATH:
153
+ base_dir = os.path.dirname(_WHITELIST_PATH)
154
+ return os.path.join(base_dir, "model-load-log.txt")
155
+
156
+ # Fallback: try ComfyUI user directory
157
+ try:
158
+ user_dir = folder_paths.get_user_directory()
159
+ except Exception:
160
+ user_dir = None
161
+
162
+ if user_dir and os.path.isdir(user_dir):
163
+ base_dir = os.path.join(user_dir, "default", "ComfyUI-Salia-Ultralytics")
164
+ try:
165
+ os.makedirs(base_dir, exist_ok=True)
166
+ except Exception:
167
+ pass
168
+ return os.path.join(base_dir, "model-load-log.txt")
169
+
170
+ # Last resort: next to this python file
171
+ return os.path.join(_THIS_DIR, "model-load-log.txt")
172
+
173
+
174
+ _MODEL_LOAD_LOG_PATH = _get_model_load_log_file()
175
+
176
+
177
+ def _find_all_model_paths(model_name: str) -> list[str]:
178
+ """
179
+ Find all possible on-disk matches across the registered ultralytics folders.
180
+ Useful if the same filename exists in multiple locations.
181
+ """
182
+ matches: list[str] = []
183
+
184
+ try:
185
+ ultra_roots = folder_paths.get_folder_paths("ultralytics")
186
+ except Exception:
187
+ ultra_roots = []
188
+
189
+ try:
190
+ bbox_roots = folder_paths.get_folder_paths("ultralytics_bbox")
191
+ except Exception:
192
+ bbox_roots = []
193
+
194
+ try:
195
+ segm_roots = folder_paths.get_folder_paths("ultralytics_segm")
196
+ except Exception:
197
+ segm_roots = []
198
+
199
+ def add_if_exists(root: str, rel: str):
200
+ p = os.path.join(root, rel)
201
+ if os.path.exists(p):
202
+ matches.append(os.path.abspath(p))
203
+
204
+ # model_name might be "bbox/foo.pt" or "segm/foo.pt" (includes subfolder)
205
+ for r in ultra_roots:
206
+ add_if_exists(r, model_name)
207
+
208
+ # Also search the specialized bbox/segm roots with the prefix stripped
209
+ if model_name.startswith("bbox/"):
210
+ rel = model_name[5:]
211
+ for r in bbox_roots:
212
+ add_if_exists(r, rel)
213
+ elif model_name.startswith("segm/"):
214
+ rel = model_name[5:]
215
+ for r in segm_roots:
216
+ add_if_exists(r, rel)
217
+
218
+ # De-dupe preserving order
219
+ out: list[str] = []
220
+ seen = set()
221
+ for p in matches:
222
+ if p not in seen:
223
+ seen.add(p)
224
+ out.append(p)
225
+ return out
226
+
227
+
228
+ def _log_selected_model(model_name: str, model_path: str, matches: list[str] | None = None):
229
+ """
230
+ Prints the resolved model path to console AND appends it to a log file.
231
+ """
232
+ # 1) Console output
233
+ print(f"[Salia Ultralytics] Selected model_name: {model_name}")
234
+ print(f"[Salia Ultralytics] Resolved model_path: {model_path}")
235
+ if matches and len(matches) > 1:
236
+ print("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
237
+ for p in matches:
238
+ print(f" - {p}")
239
+ print(f"[Salia Ultralytics] Model load log file: {_MODEL_LOAD_LOG_PATH}")
240
+
241
+ # Also emit to python logging (ComfyUI typically captures this)
242
+ logging.info("[Salia Ultralytics] Selected model_name: %s", model_name)
243
+ logging.info("[Salia Ultralytics] Resolved model_path: %s", model_path)
244
+ if matches and len(matches) > 1:
245
+ logging.warning("[Salia Ultralytics] Multiple matches found (first one is used by get_full_path):")
246
+ for p in matches:
247
+ logging.warning(" - %s", p)
248
+ logging.info("[Salia Ultralytics] Model load log file: %s", _MODEL_LOAD_LOG_PATH)
249
+
250
+ # 2) File append
251
+ try:
252
+ ts = datetime.now().isoformat(timespec="seconds")
253
+ exists = os.path.isfile(model_path)
254
+ size = os.path.getsize(model_path) if exists else -1
255
+
256
+ log_dir = os.path.dirname(_MODEL_LOAD_LOG_PATH)
257
+ if log_dir:
258
+ os.makedirs(log_dir, exist_ok=True)
259
+
260
+ with open(_MODEL_LOAD_LOG_PATH, "a", encoding="utf-8") as f:
261
+ f.write(f"{ts}\t{model_name}\t{model_path}\texists={exists}\tsize={size}\n")
262
+ if matches and len(matches) > 1:
263
+ for p in matches:
264
+ f.write(f"{ts}\tmatch\t{p}\n")
265
+ except Exception as e:
266
+ logging.warning("[Salia Ultralytics] Failed to write model-load log to %s: %s", _MODEL_LOAD_LOG_PATH, e)
267
+
268
+
269
+ def _load_whitelist(filepath: str | None) -> set[str]:
270
+ if not filepath:
271
+ return set()
272
+ try:
273
+ approved: set[str] = set()
274
+ with open(filepath, "r", encoding="utf-8") as f:
275
+ for line in f:
276
+ line = line.strip()
277
+ if line and not line.startswith("#"):
278
+ approved.add(os.path.basename(line))
279
+ return approved
280
+ except Exception:
281
+ return set()
282
+
283
+
284
+ _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)
285
+
286
+
287
+ def _torch_load_wrapper(*args, **kwargs):
288
+ """Try safe load first; if it fails due to weights-only restrictions, allow fallback if whitelisted."""
289
+ filename = None
290
+ if args and isinstance(args[0], str):
291
+ filename = os.path.basename(args[0])
292
+ elif isinstance(kwargs.get("f"), str):
293
+ filename = os.path.basename(kwargs["f"])
294
+
295
+ try:
296
+ return _ORIG_TORCH_LOAD(*args, **kwargs)
297
+ except pickle.UnpicklingError as e:
298
+ msg = str(e)
299
+ # Heuristic: this is the common PyTorch >=2.6 safe-load failure mode.
300
+ maybe_weights_only_error = (
301
+ "Weights only load failed" in msg
302
+ or "Unsupported global" in msg
303
+ or "disallowed" in msg
304
+ or "not allowed" in msg
305
+ or "getattr" in msg
306
+ )
307
+
308
+ if not maybe_weights_only_error:
309
+ raise
310
+
311
+ # Refresh whitelist from disk (so users can edit without restarting, sometimes)
312
+ global _MODEL_WHITELIST
313
+ _MODEL_WHITELIST = _load_whitelist(_WHITELIST_PATH)
314
+
315
+ if filename and filename in _MODEL_WHITELIST:
316
+ logging.warning(
317
+ "[Salia Ultralytics] Safe torch.load failed for '%s'. Retrying with weights_only=False because it's whitelisted (%s).",
318
+ filename,
319
+ _WHITELIST_PATH,
320
+ )
321
+ retry_kwargs = dict(kwargs)
322
+ retry_kwargs["weights_only"] = False
323
+ return _ORIG_TORCH_LOAD(*args, **retry_kwargs)
324
+
325
+ logging.error(
326
+ "[Salia Ultralytics] Blocked unsafe model load for '%s'.\n"
327
+ "Safe loading failed and the file is not whitelisted.\n"
328
+ "If you TRUST this model, add its base name to: %s",
329
+ filename or "[unknown]",
330
+ _WHITELIST_PATH or "[whitelist path unavailable]",
331
+ )
332
+ raise
333
+
334
+
335
+ @contextmanager
336
+ def _patched_torch_load_for_ultralytics():
337
+ """Patch torch.load only while ultralytics loads a checkpoint."""
338
+ # If PyTorch doesn't even have the safe-loader feature, don't patch.
339
+ if not hasattr(torch.serialization, "safe_globals"):
340
+ yield
341
+ return
342
+
343
+ prev = torch.load
344
+ torch.load = _torch_load_wrapper
345
+ try:
346
+ yield
347
+ finally:
348
+ torch.load = prev
349
+
350
+
351
+ def _load_yolo(model_path: str):
352
+ """Load an Ultralytics YOLO model (with optional safe-load fallback)."""
353
+ try:
354
+ from ultralytics import YOLO # lazy import
355
+ except Exception as e:
356
+ raise ImportError(
357
+ "[Salia Ultralytics] ultralytics is not installed. Install it in your ComfyUI env, e.g.:\n"
358
+ "pip install ultralytics"
359
+ ) from e
360
+
361
+ with _patched_torch_load_for_ultralytics():
362
+ return YOLO(model_path)
363
+
364
+
365
+ # ---------------------------
366
+ # Minimal Impact-compatible utilities (self-contained)
367
+ # ---------------------------
368
+
369
+ def _tensor2np_rgb(image: torch.Tensor) -> np.ndarray:
370
+ """Convert a ComfyUI IMAGE tensor to a uint8 RGB numpy image."""
371
+ # ComfyUI image is usually: (B,H,W,C) float in [0,1]
372
+ if not isinstance(image, torch.Tensor):
373
+ raise TypeError(f"Expected torch.Tensor, got {type(image)}")
374
+
375
+ if image.dim() == 4:
376
+ img = image[0]
377
+ else:
378
+ img = image
379
+
380
+ img = img.detach()
381
+ if img.is_cuda:
382
+ img = img.cpu()
383
+
384
+ img = img.clamp(0, 1).numpy()
385
+ if img.shape[-1] == 1:
386
+ img = np.repeat(img, 3, axis=-1)
387
+
388
+ img_u8 = (img * 255.0).round().astype(np.uint8)
389
+ return img_u8
390
+
391
+
392
+ def tensor2pil(image: torch.Tensor) -> Image.Image:
393
+ return Image.fromarray(_tensor2np_rgb(image))
394
+
395
+
396
+ def make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float, crop_min_size: int | None = None):
397
+ x1, y1, x2, y2 = [float(v) for v in bbox_xyxy]
398
+ bbox_w = max(1.0, x2 - x1)
399
+ bbox_h = max(1.0, y2 - y1)
400
+
401
+ crop_w = bbox_w * float(crop_factor)
402
+ crop_h = bbox_h * float(crop_factor)
403
+
404
+ if crop_min_size is not None:
405
+ crop_w = max(crop_w, float(crop_min_size))
406
+ crop_h = max(crop_h, float(crop_min_size))
407
+
408
+ cx = (x1 + x2) / 2.0
409
+ cy = (y1 + y2) / 2.0
410
+
411
+ rx1 = int(round(cx - crop_w / 2.0))
412
+ ry1 = int(round(cy - crop_h / 2.0))
413
+ rx2 = int(round(cx + crop_w / 2.0))
414
+ ry2 = int(round(cy + crop_h / 2.0))
415
+
416
+ rx1 = max(0, min(w - 1, rx1))
417
+ ry1 = max(0, min(h - 1, ry1))
418
+ rx2 = max(rx1 + 1, min(w, rx2))
419
+ ry2 = max(ry1 + 1, min(h, ry2))
420
+
421
+ return (rx1, ry1, rx2, ry2)
422
+
423
+
424
+ def crop_image(image: torch.Tensor, crop_region):
425
+ x1, y1, x2, y2 = crop_region
426
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
427
+ if image.dim() == 4:
428
+ return image[:, y1:y2, x1:x2, :]
429
+ if image.dim() == 3:
430
+ return image[y1:y2, x1:x2, :]
431
+ raise ValueError(f"Unexpected image tensor shape: {tuple(image.shape)}")
432
+
433
+
434
+ def crop_ndarray2(arr: np.ndarray, crop_region):
435
+ x1, y1, x2, y2 = crop_region
436
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
437
+ return arr[y1:y2, x1:x2]
438
+
439
+
440
+ def dilate_masks(segmasks, dilation: int):
441
+ if dilation <= 0:
442
+ return segmasks
443
+ if cv2 is None:
444
+ raise ImportError(
445
+ "[Salia Ultralytics] opencv-python is required for mask dilation but cv2 could not be imported.\n"
446
+ "Install: pip install opencv-python-headless"
447
+ )
448
+
449
+ k = int(dilation)
450
+ ksize = k * 2 + 1
451
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
452
+
453
+ out = []
454
+ for bbox, mask, conf in segmasks:
455
+ m = (mask > 0.5).astype(np.uint8) * 255
456
+ m = cv2.dilate(m, kernel, iterations=1)
457
+ out.append((bbox, (m > 0).astype(np.float32), conf))
458
+ return out
459
+
460
+
461
+ def combine_masks(segmasks, out_shape_hw: tuple[int, int] | None = None) -> torch.Tensor:
462
+ if not segmasks:
463
+ if out_shape_hw is None:
464
+ return torch.zeros((1, 1, 1), dtype=torch.float32)
465
+ h, w = out_shape_hw
466
+ return torch.zeros((1, h, w), dtype=torch.float32)
467
+
468
+ base = segmasks[0][1]
469
+ combined = np.zeros_like(base, dtype=np.float32)
470
+ for _, m, _ in segmasks:
471
+ combined = np.maximum(combined, m.astype(np.float32))
472
+ return torch.from_numpy(combined).unsqueeze(0)
473
+
474
+
475
+ # ---------------------------
476
+ # Impact-compatible detector wrapper objects
477
+ # ---------------------------
478
+
479
+ SEG = namedtuple(
480
+ "SEG",
481
+ [
482
+ "cropped_image",
483
+ "cropped_mask",
484
+ "confidence",
485
+ "crop_region",
486
+ "bbox",
487
+ "label",
488
+ "control_net_wrapper",
489
+ ],
490
+ defaults=[None],
491
+ )
492
+
493
+
494
+ class NO_BBOX_DETECTOR:
495
+ pass
496
+
497
+
498
+ class NO_SEGM_DETECTOR:
499
+ pass
500
+
501
+
502
+ def _create_segmasks(results):
503
+ # results = [labels, bboxes_xyxy, segms, confs]
504
+ bboxes = results[1]
505
+ segms = results[2]
506
+ confs = results[3]
507
+
508
+ out = []
509
+ for i in range(len(segms)):
510
+ out.append((bboxes[i], segms[i].astype(np.float32), confs[i]))
511
+ return out
512
+
513
+
514
+ def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
515
+ pred = model(image_pil, conf=confidence, device=device)
516
+
517
+ bboxes = pred[0].boxes.xyxy.cpu().numpy() # xyxy
518
+ if bboxes.shape[0] == 0:
519
+ return [[], [], [], []]
520
+
521
+ # Make simple rectangle masks for each bbox
522
+ np_img = np.array(image_pil)
523
+ if np_img.ndim == 2:
524
+ h, w = np_img.shape
525
+ else:
526
+ h, w = np_img.shape[0], np_img.shape[1]
527
+
528
+ segms = []
529
+ for x0, y0, x1, y1 in bboxes:
530
+ m = np.zeros((h, w), dtype=np.uint8)
531
+ x0i, y0i, x1i, y1i = int(x0), int(y0), int(x1), int(y1)
532
+ x0i = max(0, min(w - 1, x0i))
533
+ x1i = max(0, min(w, x1i))
534
+ y0i = max(0, min(h - 1, y0i))
535
+ y1i = max(0, min(h, y1i))
536
+ if cv2 is not None:
537
+ cv2.rectangle(m, (x0i, y0i), (x1i, y1i), 255, -1)
538
+ else:
539
+ m[y0i:y1i, x0i:x1i] = 255
540
+ segms.append((m > 0))
541
+
542
+ labels = []
543
+ confs = []
544
+ for i in range(len(bboxes)):
545
+ labels.append(pred[0].names[int(pred[0].boxes[i].cls.item())])
546
+ confs.append(pred[0].boxes[i].conf.detach().cpu().numpy())
547
+
548
+ return [labels, list(bboxes), segms, confs]
549
+
550
+
551
+ def _inference_segm(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
552
+ pred = model(image_pil, conf=confidence, device=device)
553
+
554
+ bboxes = pred[0].boxes.xyxy.cpu().numpy() # xyxy
555
+ if bboxes.shape[0] == 0:
556
+ return [[], [], [], []]
557
+
558
+ if pred[0].masks is None or pred[0].masks.data is None:
559
+ # fallback: no masks, treat like bbox
560
+ return _inference_bbox(model, image_pil, confidence=confidence, device=device)
561
+
562
+ segms = pred[0].masks.data.detach().cpu().numpy() # (n, h, w) in model-space
563
+
564
+ # Resize masks back to original image size
565
+ h_orig = image_pil.size[1]
566
+ w_orig = image_pil.size[0]
567
+
568
+ results = [[], [], [], []]
569
+
570
+ for i in range(len(bboxes)):
571
+ results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
572
+ results[1].append(bboxes[i])
573
+
574
+ mask = torch.from_numpy(segms[i]).float()
575
+ mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(h_orig, w_orig), mode="bilinear", align_corners=False)
576
+ mask = mask.squeeze(0).squeeze(0)
577
+
578
+ results[2].append(mask.numpy())
579
+ results[3].append(pred[0].boxes[i].conf.detach().cpu().numpy())
580
+
581
+ return results
582
+
583
+
584
+ class SaliaUltraBBoxDetector:
585
+ def __init__(self, model):
586
+ self.model = model
587
+
588
+ def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
589
+ drop_size = max(int(drop_size), 1)
590
+ detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold))
591
+ segmasks = _create_segmasks(detected)
592
+
593
+ if int(dilation) > 0:
594
+ segmasks = dilate_masks(segmasks, int(dilation))
595
+
596
+ items = []
597
+ h = image.shape[1]
598
+ w = image.shape[2]
599
+
600
+ for (bbox, mask, conf), label in zip(segmasks, detected[0]):
601
+ x1, y1, x2, y2 = bbox
602
+ if (x2 - x1) > drop_size and (y2 - y1) > drop_size:
603
+ crop_region = make_crop_region(w, h, bbox, float(crop_factor))
604
+
605
+ if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
606
+ crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region)
607
+
608
+ cropped_image = crop_image(image, crop_region)
609
+ cropped_mask = crop_ndarray2(mask, crop_region)
610
+
611
+ items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))
612
+
613
+ segs = (image.shape[1], image.shape[2]), items
614
+
615
+ if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
616
+ segs = detailer_hook.post_detection(segs)
617
+
618
+ return segs
619
+
620
+ def detect_combined(self, image, threshold, dilation):
621
+ detected = _inference_bbox(self.model, tensor2pil(image), confidence=float(threshold))
622
+ segmasks = _create_segmasks(detected)
623
+ if int(dilation) > 0:
624
+ segmasks = dilate_masks(segmasks, int(dilation))
625
+ return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))
626
+
627
+ def setAux(self, x):
628
+ pass
629
+
630
+
631
+ class SaliaUltraSegmDetector:
632
+ def __init__(self, model):
633
+ self.model = model
634
+
635
+ def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
636
+ drop_size = max(int(drop_size), 1)
637
+ detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold))
638
+ segmasks = _create_segmasks(detected)
639
+
640
+ if int(dilation) > 0:
641
+ segmasks = dilate_masks(segmasks, int(dilation))
642
+
643
+ items = []
644
+ h = image.shape[1]
645
+ w = image.shape[2]
646
+
647
+ for (bbox, mask, conf), label in zip(segmasks, detected[0]):
648
+ x1, y1, x2, y2 = bbox
649
+ if (x2 - x1) > drop_size and (y2 - y1) > drop_size:
650
+ crop_region = make_crop_region(w, h, bbox, float(crop_factor))
651
+
652
+ if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
653
+ crop_region = detailer_hook.post_crop_region(w, h, bbox, crop_region)
654
+
655
+ cropped_image = crop_image(image, crop_region)
656
+ cropped_mask = crop_ndarray2(mask, crop_region)
657
+
658
+ items.append(SEG(cropped_image, cropped_mask, conf, crop_region, bbox, label, None))
659
+
660
+ segs = (image.shape[1], image.shape[2]), items
661
+
662
+ if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
663
+ segs = detailer_hook.post_detection(segs)
664
+
665
+ return segs
666
+
667
+ def detect_combined(self, image, threshold, dilation):
668
+ detected = _inference_segm(self.model, tensor2pil(image), confidence=float(threshold))
669
+ segmasks = _create_segmasks(detected)
670
+ if int(dilation) > 0:
671
+ segmasks = dilate_masks(segmasks, int(dilation))
672
+ return combine_masks(segmasks, out_shape_hw=(image.shape[1], image.shape[2]))
673
+
674
+ def setAux(self, x):
675
+ pass
676
+
677
+
678
+ # ---------------------------
679
+ # The actual ComfyUI Node
680
+ # ---------------------------
681
+
682
+ class SaliaUltralyticsDetectorProvider2:
683
+ """Load an Ultralytics `.pt` model and provide Impact-compatible detectors."""
684
+
685
+ @classmethod
686
+ def INPUT_TYPES(cls):
687
+ bboxs = ["bbox/" + x for x in folder_paths.get_filename_list("ultralytics_bbox")]
688
+ segms = ["segm/" + x for x in folder_paths.get_filename_list("ultralytics_segm")]
689
+ return {"required": {"model_name": (bboxs + segms,)}}
690
+
691
+ RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
692
+ FUNCTION = "doit"
693
+ CATEGORY = "Salia/Detectors"
694
+
695
+ def doit(self, model_name: str):
696
+ # First, allow selecting a file like "bbox/foo.pt" that lives under models/ultralytics/bbox
697
+ model_path = folder_paths.get_full_path("ultralytics", model_name)
698
+
699
+ if model_path is None:
700
+ if model_name.startswith("bbox/"):
701
+ model_path = folder_paths.get_full_path("ultralytics_bbox", model_name[5:])
702
+ elif model_name.startswith("segm/"):
703
+ model_path = folder_paths.get_full_path("ultralytics_segm", model_name[5:])
704
+
705
+ if model_path is None:
706
+ cands = []
707
+ try:
708
+ cands.extend(folder_paths.get_folder_paths("ultralytics"))
709
+ if model_name.startswith("bbox/"):
710
+ cands.extend(folder_paths.get_folder_paths("ultralytics_bbox"))
711
+ elif model_name.startswith("segm/"):
712
+ cands.extend(folder_paths.get_folder_paths("ultralytics_segm"))
713
+ except Exception:
714
+ pass
715
+
716
+ formatted = "\n\t".join(cands)
717
+ raise ValueError(
718
+ f"[Salia Ultralytics] model file '{model_name}' was not found.\n"
719
+ f"Searched these folders:\n\t{formatted}\n"
720
+ f"Tip: put bbox models in 'models/ultralytics/bbox' or segm models in 'models/ultralytics/segm'."
721
+ )
722
+
723
+ # NEW: print + log the resolved on-disk path (and any duplicates)
724
+ matches = _find_all_model_paths(model_name)
725
+ _log_selected_model(model_name, os.path.abspath(model_path), matches)
726
+
727
+ model = _load_yolo(model_path)
728
+
729
+ if model_name.startswith("bbox/"):
730
+ return SaliaUltraBBoxDetector(model), NO_SEGM_DETECTOR()
731
+ else:
732
+ return SaliaUltraBBoxDetector(model), SaliaUltraSegmDetector(model)
733
+
734
+
735
+ NODE_CLASS_MAPPINGS = {
736
+ "SaliaUltralyticsDetectorProvider2": SaliaUltralyticsDetectorProvider2,
737
+ }
738
+
739
+ NODE_DISPLAY_NAME_MAPPINGS = {
740
+ "SaliaUltralyticsDetectorProvider2": "Salia Ultralytics Detector 2 (Salia)",
741
+ }