# FaceDetailerStandalone_MIN_FIXED_FAST_EMBEDDED_SAM.py
# One-node Face Detailer (image-only) with fixed settings + embedded Ultralytics bbox detector + embedded SAM loader.
# - Output parity with Impact Pack Face Detailer at the same settings
# - No separate bbox-detector node; detector is cached/constructed internally
# - No separate SAM loader node; SAM is cached/constructed internally
# - Lightweight runtime overhead (cached imports, inference_mode, fused layers, TF32, FP16 on CUDA)
import os
from dataclasses import dataclass
from typing import List, Tuple, Optional
import warnings
warnings.filterwarnings("ignore")
# Silence OpenCV before importing it (env var) and after (setLogLevel)
os.environ["OPENCV_LOG_LEVEL"] = "ERROR"
import numpy as np
import torch
import comfy
import comfy.samplers
import comfy.model_management
from PIL import Image
import cv2
try:
if hasattr(cv2, "setLogLevel"):
try:
            lvl = cv2.LOG_LEVEL_ERROR if hasattr(cv2, "LOG_LEVEL_ERROR") else 2  # 2 == LOG_LEVEL_ERROR in OpenCV's log-level enum
cv2.setLogLevel(lvl)
except Exception:
pass
except Exception:
pass
# ---------------- Fixed FaceDetailer settings (do not expose in UI) ----------------
# GUIDE_SIZE = 512
# GUIDE_SIZE_FOR_BBOX = True
# MAX_SIZE = 1024
# STEPS = 30
# CFG = 7.0
# SCHEDULER = "simple"
# DENOISE = 0.5
# FEATHER = 5
# NOISE_MASK = True
# FORCE_INPAINT = True
# BBOX_THRESHOLD = 0.5
# BBOX_DILATION = 10
# BBOX_CROP_FACTOR = 3.0
# DROP_SIZE = 10
# SAM_DETECTION_HINT = "center-1"
# SAM_DILATION = 0
# SAM_THRESHOLD = 0.93
# SAM_BBOX_EXPANSION = 0
# SAM_MASK_HINT_THRESHOLD = 0.7
# SAM_MASK_HINT_USE_NEGATIVE = "False"
# WILDCARD = ""
# CYCLE = 1
# INPAINT_MODEL = False
# NOISE_MASK_FEATHER = 20
# TILED_ENCODE = False
# TILED_DECODE = False
# ---------------------------------------------------------------------
# ---------------- Ultralytics / YOLO detector integration (embedded) ----------------
# Torch runtime perf switches
torch.backends.cudnn.benchmark = True # autotune best conv algorithms
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
torch.set_float32_matmul_precision("high") # PyTorch 2.x
except Exception:
pass
# Optional Impact Pack interop (SEG type)
try:
# If Impact Pack is installed, use its SEG to be perfectly compatible.
from impact.core import SEG as _IMPACT_SEG # type: ignore
_USE_IMPACT_SEG = True
except Exception:
_USE_IMPACT_SEG = False
@dataclass
class _LocalSEG:
cropped_image: Optional[torch.Tensor]
cropped_mask: np.ndarray # 2D float32 [0..1]
confidence: float
crop_region: Tuple[int, int, int, int] # (x1,y1,x2,y2)
bbox: Tuple[int, int, int, int] # (x1,y1,x2,y2)
label: str
control_net_wrapper: Optional[object] = None
SEG = _IMPACT_SEG if _USE_IMPACT_SEG else _LocalSEG
# ---------------------------------------------------------------------
# LOCAL ASSET PATHS (no hardcoded absolute paths)
# ---------------------------------------------------------------------
# Base directory of this node file (cross-platform, works on RunPod/ComfyUI)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Local YOLO model path inside this custom node folder
YOLO_MODEL_PATH = os.path.join(BASE_DIR, "assets", "face_yolov8m_salia.pt")
YOLO_IMGSZ = 640
# Local SAM checkpoint path inside this custom node folder
SAM_CKPT_PATH = os.path.join(BASE_DIR, "assets", "sam_vit_b_01ec64_salia.pth")
# Cached instances (process-local)
_CACHED_YOLO_MODEL = None
_CACHED_ULTRA_DETECTOR = None
def _tensor_to_pil(image: torch.Tensor) -> Image.Image:
# image: [1, H, W, 3], float(0..1)
img = image[0].detach().cpu().clamp(0, 1).numpy()
img = (img * 255.0).round().astype(np.uint8) # (H, W, 3) RGB
return Image.fromarray(img, mode="RGB")
def _make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float) -> Tuple[int, int, int, int]:
x1, y1, x2, y2 = map(int, bbox_xyxy)
cx = (x1 + x2) * 0.5
cy = (y1 + y2) * 0.5
bw = (x2 - x1)
bh = (y2 - y1)
new_w = max(1, int(bw * crop_factor))
new_h = max(1, int(bh * crop_factor))
# center to image
nx1 = int(max(0, round(cx - new_w * 0.5)))
ny1 = int(max(0, round(cy - new_h * 0.5)))
nx2 = int(min(w, nx1 + new_w))
ny2 = int(min(h, ny1 + new_h))
# clamp again
nx1 = max(0, min(nx1, w - 1))
ny1 = max(0, min(ny1, h - 1))
nx2 = max(nx1 + 1, min(nx2, w))
ny2 = max(ny1 + 1, min(ny2, h))
return (nx1, ny1, nx2, ny2)
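# Worked example (illustrative numbers only): a 100x100 face bbox at
# (200, 200)-(300, 300) in a 1024x768 image with crop_factor=3.0 expands to a
# 300x300 context crop centered on the face:
#
#   _make_crop_region(1024, 768, (200, 200, 300, 300), 3.0)
#   # -> (100, 100, 400, 400)
#
# Near image borders the region is clamped to (0, 0, W, H) rather than
# re-centered, so the crop can end up smaller than bw * crop_factor.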
def _crop_tensor_image(image: torch.Tensor, crop: Tuple[int, int, int, int]) -> torch.Tensor:
# image: [1,H,W,3]; crop: (x1,y1,x2,y2)
x1, y1, xb, yb = crop
return image[:, y1:yb, x1:xb, :].contiguous()
def _crop_ndarray(mask: np.ndarray, crop: Tuple[int, int, int, int]) -> np.ndarray:
# mask: [H,W] float/bool/uint8; crop: (x1,y1,x2,y2)
x1, y1, xb, yb = crop
return mask[int(y1):int(yb), int(x1):int(xb)]
def _dilate_masks(segmasks: List[Tuple[np.ndarray, np.ndarray, float]], factor: int):
if factor == 0 or not segmasks:
return segmasks
k = abs(int(factor))
if k < 1:
return segmasks
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
do_dilate = factor > 0
out = []
for (bbox, m, conf) in segmasks:
u8 = (m * 255.0).astype(np.uint8) if m.dtype != np.uint8 else m
d = cv2.dilate(u8, kernel, iterations=1) if do_dilate else cv2.erode(u8, kernel, iterations=1)
out.append((bbox, d.astype(np.float32) / 255.0, conf))
return out
def _combine_masks(segmasks: List[Tuple[np.ndarray, np.ndarray, float]]) -> Optional[torch.Tensor]:
if not segmasks:
return None
h = segmasks[0][1].shape[0]
w = segmasks[0][1].shape[1]
acc = np.zeros((h, w), dtype=np.uint8)
for _, m, _ in segmasks:
u8 = (m * 255.0).astype(np.uint8) if m.dtype != np.uint8 else m
acc = cv2.bitwise_or(acc, u8)
return torch.from_numpy(acc.astype(np.float32) / 255.0) # [H,W], float32 0..1 CPU
def _pick_device_str(user_device: str = "") -> str:
if user_device:
return user_device
return "cuda" if torch.cuda.is_available() else "cpu"
@torch.inference_mode()
def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
"""
Returns results = [labels(str), bboxes(xyxy), segms(full-image bool masks), conf(float)]
For bbox models, segm "masks" are rectangles from the boxes (Subpack parity).
"""
pred = model(
image_pil,
conf=confidence,
device=_pick_device_str(device),
verbose=False,
imgsz=YOLO_IMGSZ, # fixed size can be faster
)
p0 = pred[0]
boxes = p0.boxes
bboxes = boxes.xyxy.detach().cpu().numpy() # (N,4) float, xyxy
    if bboxes.shape[0] == 0:
        return [[], [], [], []]
    W, H = image_pil.size
    segms = []
    for x0, y0, x1, y1 in bboxes:
        m = np.zeros((H, W), np.uint8)
        cv2.rectangle(m, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
        segms.append(m.astype(bool))
results = [[], [], [], []]
names = p0.names
for i, (bbox, segm) in enumerate(zip(bboxes, segms)):
cls_i = int(boxes.cls[i].item())
results[0].append(names[cls_i])
results[1].append(bbox)
results[2].append(segm)
results[3].append(float(boxes.conf[i].item()))
return results
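# Result-shape sketch for two hypothetical detections:
#
#   results[0] == ["face", "face"]                  # labels
#   results[1] == [array([x1, y1, x2, y2]), ...]    # xyxy boxes (float)
#   results[2] == [HxW bool mask, HxW bool mask]    # filled bbox rectangles
#   results[3] == [0.91, 0.87]                      # confidences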
def _create_segmasks(results):
bboxes = results[1]
segms = results[2]
confs = results[3]
out = []
for i in range(len(segms)):
out.append((bboxes[i], segms[i].astype(np.float32), confs[i]))
return out
class UltraBBoxDetector:
def __init__(self, yolo_model):
self.bbox_model = yolo_model
def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
drop_size = max(int(drop_size), 1)
detected = _inference_bbox(self.bbox_model, _tensor_to_pil(image), threshold)
segmasks = _create_segmasks(detected)
if int(dilation) != 0:
segmasks = _dilate_masks(segmasks, int(dilation))
H = int(image.shape[1])
W = int(image.shape[2])
items = []
for (bbox_xyxy, full_mask, conf), label in zip(segmasks, detected[0]):
x1, y1, x2, y2 = map(int, bbox_xyxy)
if (x2 - x1) > drop_size and (y2 - y1) > drop_size:
crop_region = _make_crop_region(W, H, (x1, y1, x2, y2), float(crop_factor))
if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
crop_region = detailer_hook.post_crop_region(W, H, (x1, y1, x2, y2), crop_region)
cropped_image = _crop_tensor_image(image, crop_region)
cropped_mask = _crop_ndarray(full_mask, crop_region).astype(np.float32)
items.append(SEG(cropped_image, cropped_mask, float(conf), crop_region, (x1, y1, x2, y2), str(label), None))
segs = ((H, W), items)
if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
segs = detailer_hook.post_detection(segs)
return segs
def detect_combined(self, image, threshold, dilation):
detected = _inference_bbox(self.bbox_model, _tensor_to_pil(image), threshold)
segmasks = _create_segmasks(detected)
if int(dilation) != 0:
segmasks = _dilate_masks(segmasks, int(dilation))
return _combine_masks(segmasks)
def setAux(self, x):
# kept for signature parity
pass
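# Usage sketch (comment-only; `image` stands for a hypothetical Comfy IMAGE
# tensor of shape [1, H, W, 3] in [0..1], with the fixed settings from above):
#
#   detector = _get_embedded_detector()
#   (h, w), items = detector.detect(image, threshold=0.5, dilation=10,
#                                   crop_factor=3.0, drop_size=10)
#   for seg in items:
#       print(seg.label, seg.confidence, seg.crop_region)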
def _load_ultralytics_model(model_path: str):
# Import here so that module import doesn't hard-fail if ultralytics is missing
try:
from ultralytics import YOLO
except Exception as e:
raise RuntimeError(
"[FaceDetailerStandalone] The 'ultralytics' package is required for the embedded bbox detector.\n"
"Install in your ComfyUI python: python -m pip install --upgrade ultralytics"
) from e
if not os.path.isfile(model_path):
raise FileNotFoundError(
"[FaceDetailerStandalone] Embedded YOLO model file not found.\n"
f"Expected at: {model_path}\n"
"Please place 'face_yolov8m_salia.pt' in the 'assets' folder next to this node."
)
yolo = YOLO(model_path)
# One-time graph/model optimizations
try:
dev = _pick_device_str()
try:
yolo.to(dev) # newer Ultralytics
except Exception:
yolo.model.to(dev) # older versions
except Exception:
pass
# Fuse Conv+BN where possible (small speedup)
try:
yolo.fuse()
except Exception:
pass
# Use half precision weights on CUDA (big win; safe for inference)
try:
if torch.cuda.is_available():
yolo.model.half()
except Exception:
pass
return yolo
def _get_embedded_detector():
global _CACHED_YOLO_MODEL, _CACHED_ULTRA_DETECTOR
if _CACHED_ULTRA_DETECTOR is not None:
return _CACHED_ULTRA_DETECTOR
if _CACHED_YOLO_MODEL is None:
_CACHED_YOLO_MODEL = _load_ultralytics_model(YOLO_MODEL_PATH)
_CACHED_ULTRA_DETECTOR = UltraBBoxDetector(_CACHED_YOLO_MODEL)
return _CACHED_ULTRA_DETECTOR
# ---------------- Embedded SAM loader (GPU-only, local asset path, one reused predictor) ----------------
# Same design as a standalone SAM loader node, but embedded and cached in this module.
def _to_numpy_rgb(image_tensor):
"""
Comfy 'IMAGE' is NHWC in [0..1]. Convert to uint8 HxWx3 RGB numpy.
Accepts torch.Tensor (NHWC) or numpy already in HWC.
"""
if isinstance(image_tensor, torch.Tensor):
img = image_tensor
if img.dim() == 4 and img.shape[0] == 1:
img = img[0]
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).cpu().numpy() # HWC
return img
elif isinstance(image_tensor, np.ndarray):
if image_tensor.dtype != np.uint8:
img = np.clip(image_tensor, 0, 255).astype(np.uint8)
else:
img = image_tensor
return img
else:
raise TypeError(f"Unsupported image type for SAM: {type(image_tensor)}")
class _SAMWrapperGPUOnlyFast:
"""
FaceDetailer-compatible wrapper:
- Stays on CUDA
- Reuses a single SamPredictor
- predict(image, points, plabs, bbox, threshold) -> list[HxW float32 CPU masks]
"""
def __init__(self, model):
self.model = model
dev = comfy.model_management.get_torch_device()
if "cuda" not in str(dev).lower():
raise RuntimeError(
f"[FaceDetailerStandalone] GPU-only SAM: CUDA device not available (got '{dev}')."
)
self._device = dev
self.model.to(self._device).eval()
# Lazy import for segment_anything predictor
from segment_anything import SamPredictor # type: ignore
# Reuse one predictor instance (cheaper than re-creating every call)
self._predictor = SamPredictor(self.model)
def prepare_device(self):
if "cuda" not in str(self._device).lower():
raise RuntimeError("[FaceDetailerStandalone] CUDA device lost/unavailable for SAM.")
def release_device(self):
# GPU-only; keep on GPU (no-op)
pass
@torch.inference_mode()
def predict(self, image, points, plabs, bbox, threshold: float):
"""
image: Comfy IMAGE (NHWC, [0..1]) or numpy
points: list[[x,y], ...] or None
plabs: list[int] (1=fg, 0=bg) or None
bbox: [x1,y1,x2,y2] or None
threshold: float in [0..1]
returns: list of HxW float32 CPU masks (0/1)
"""
self.prepare_device()
np_img = _to_numpy_rgb(image)
# Some builds call set_image(img, "RGB"); accept both signatures.
try:
self._predictor.set_image(np_img, "RGB")
except TypeError:
self._predictor.set_image(np_img)
pc = np.array(points, dtype=np.float32) if points else None
pl = np.array(plabs, dtype=np.int32) if plabs else None
bx = np.array(bbox, dtype=np.float32) if bbox is not None else None
# Keep provided behavior: multimask_output=False
masks, scores, _ = self._predictor.predict(
point_coords=pc,
point_labels=pl,
box=bx,
multimask_output=False
)
out = []
if masks is not None and scores is not None:
for m, s in zip(masks, scores):
if float(s) >= float(threshold):
if isinstance(m, torch.Tensor):
t = m.to(torch.float32).cpu()
else:
t = torch.from_numpy(m.astype(np.float32)).cpu()
out.append(t)
return out
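# Prediction sketch (comment-only; coordinates are hypothetical). FaceDetailer's
# fixed "center-1" hint amounts to one positive point at the detection center:
#
#   wrapper = _get_embedded_sam().sam_wrapper
#   masks = wrapper.predict(image, points=[[cx, cy]], plabs=[1],
#                           bbox=None, threshold=0.93)
#   # -> possibly-empty list of HxW float32 CPU masks (score >= threshold)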
# Cache for SAM
_CACHED_SAM_MODEL = None
def _get_embedded_sam():
"""Load SAM vit_b from SAM_CKPT_PATH and attach GPU-only fast wrapper, cached."""
global _CACHED_SAM_MODEL
if _CACHED_SAM_MODEL is not None:
return _CACHED_SAM_MODEL
if not os.path.isfile(SAM_CKPT_PATH):
raise FileNotFoundError(
f"[FaceDetailerStandalone] SAM checkpoint not found:\n {SAM_CKPT_PATH}\n"
f"Place 'sam_vit_b_01ec64_salia.pth' in the 'assets' folder next to this node."
)
# Import here to avoid module import failure at file load time
try:
from segment_anything import sam_model_registry # type: ignore
except Exception as e:
raise RuntimeError(
"[FaceDetailerStandalone] 'segment_anything' is not installed for embedded SAM. "
"Install in your Comfy python, e.g.: python -m pip install "
"git+https://github.com/facebookresearch/segment-anything"
) from e
# Fixed to vit_b (matches 'sam_vit_b_01ec64' weights)
sam = sam_model_registry['vit_b'](checkpoint=SAM_CKPT_PATH)
sam.eval() # ensure eval mode
# Attach GPU-only, faster wrapper
wrapper = _SAMWrapperGPUOnlyFast(sam)
sam.sam_wrapper = wrapper
_CACHED_SAM_MODEL = sam
return _CACHED_SAM_MODEL
# ---------------- Impact Pack Face Detailer binding ----------------
_ENHANCE_FACE = None
_IMPORT_ERR = None
try:
from impact.impact_pack import FaceDetailer as _FD
_ENHANCE_FACE = _FD.enhance_face
except Exception as _e:
_IMPORT_ERR = _e
_ENHANCE_FACE = None
# ---------------- Single public node ----------------
class dn_04:
@classmethod
def INPUT_TYPES(cls):
# Only essential, connectable parts remain editable. (No bbox or SAM inputs.)
return {
"required": {
"image": ("IMAGE",),
"model": ("MODEL", {"tooltip": "If `ImpactDummyInput` is connected to model, inference is skipped."}),
"clip": ("CLIP",),
"vae": ("VAE",),
# Keep sampler selectable; all other knobs are fixed
"sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
# Conditioning stays connectable
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
# Keep seed editable but fixed after generate for reproducibility
"seed": ("INT", {
"default": 0,
"min": 0,
"max": 0xffffffffffffffff,
"step": 1,
"control_after_generate": "fixed",
}),
},
"optional": {
# No external SAM input; embedded
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "doit"
CATEGORY = "ImpactPack/Standalone"
DESCRIPTION = (
"Face Detailer with requested parameters hardcoded (non-editable), "
"and embedded Ultralytics face bbox detector + embedded SAM (no external input nodes). "
"Optimized call path (cached imports + inference_mode) for lower overhead; "
"results identical to Impact Pack Face Detailer at the same settings."
)
def doit(
self,
image, model, clip, vae,
sampler_name,
positive, negative,
seed,
):
if _ENHANCE_FACE is None:
raise RuntimeError(
"ComfyUI-Impact-Pack is required for Face Detailer logic. "
"Please install/enable ComfyUI-Impact-Pack."
) from _IMPORT_ERR
# Embedded detector & SAM (cached)
bbox_detector = _get_embedded_detector()
sam_model_opt = _get_embedded_sam()
enhance = _ENHANCE_FACE
# Determine batch size safely
B = image.shape[0] if (hasattr(image, "shape") and image.ndim == 4) else 1
# No autograd, faster kernel choices, identical math for inference
with torch.inference_mode():
if B == 1:
# Fast-path for single image (avoid list + cat)
single = image[0] if image.ndim == 4 else image # [H,W,C]
enhanced_img, _, _, _, _ = enhance(
single.unsqueeze(0), # -> [1,H,W,C]
model, clip, vae,
                    512, True, 1024,          # guide_size, guide_size_for_bbox, max_size
                    seed, 30, 7.0,            # seed, steps, cfg
                    sampler_name, "simple",   # sampler_name, scheduler
                    positive, negative,
                    0.5, 5, True, True,       # denoise, feather, noise_mask, force_inpaint
                    0.5, 10, 3.0,             # bbox_threshold, bbox_dilation, bbox_crop_factor
                    "center-1", 0, 0.93, 0,   # sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion
                    0.7, "False",             # sam_mask_hint_threshold, sam_mask_hint_use_negative
                    10, bbox_detector,        # drop_size, bbox_detector
# Internals not exposed (kept fixed/None)
segm_detector=None, sam_model_opt=sam_model_opt,
wildcard_opt="", detailer_hook=None,
refiner_ratio=None, refiner_model=None, refiner_clip=None,
refiner_positive=None, refiner_negative=None,
cycle=1, inpaint_model=False,
noise_mask_feather=20,
scheduler_func_opt=None,
tiled_encode=False, tiled_decode=False,
)
return (enhanced_img,)
            # Batch of images: process each frame independently with seed + i for per-frame determinism
out_imgs = []
for i, single in enumerate(image.unbind(0)):
enhanced_img, _, _, _, _ = enhance(
single.unsqueeze(0), # [1,H,W,C]
model, clip, vae,
512, True, 1024,
seed + i, 30, 7.0,
sampler_name, "simple",
positive, negative,
0.5, 5, True, True,
0.5, 10, 3.0,
"center-1", 0, 0.93, 0,
0.7, "False",
10, bbox_detector,
segm_detector=None, sam_model_opt=sam_model_opt,
wildcard_opt="", detailer_hook=None,
refiner_ratio=None, refiner_model=None, refiner_clip=None,
refiner_positive=None, refiner_negative=None,
cycle=1, inpaint_model=False,
noise_mask_feather=20,
scheduler_func_opt=None,
tiled_encode=False, tiled_decode=False,
)
out_imgs.append(enhanced_img)
return (torch.cat(out_imgs, dim=0),)
NODE_CLASS_MAPPINGS = {
"dn_04": dn_04,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"dn_04": "dn_04",
}
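# Smoke-test sketch (comment-only; assumes a live ComfyUI session where `model`,
# `clip`, `vae`, `positive`, `negative`, and an IMAGE tensor `image` exist):
#
#   node = dn_04()
#   (detailed,) = node.doit(image, model, clip, vae, "euler", positive, negative, seed=0)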