# FaceDetailerStandalone_MIN_FIXED_FAST_EMBEDDED_SAM.py
# One-node Face Detailer (image-only) with fixed settings + embedded Ultralytics bbox detector + embedded SAM loader.
# - Output parity with Impact Pack Face Detailer at the same settings
# - No separate bbox-detector node; detector is cached/constructed internally
# - No separate SAM loader node; SAM is cached/constructed internally
# - Lightweight runtime overhead (cached imports, inference_mode, fused layers, TF32, FP16 on CUDA)
import os
from dataclasses import dataclass
from typing import List, Tuple, Optional
import warnings
warnings.filterwarnings("ignore")
# Silence OpenCV before importing it (env var) and after (setLogLevel)
os.environ["OPENCV_LOG_LEVEL"] = "ERROR"
import numpy as np
import torch
import comfy.model_management
import comfy.samplers
from PIL import Image
import cv2
try:
    if hasattr(cv2, "setLogLevel"):
        lvl = cv2.LOG_LEVEL_ERROR if hasattr(cv2, "LOG_LEVEL_ERROR") else 3  # 3 == error
        cv2.setLogLevel(lvl)
except Exception:
    pass
# ---------------- Fixed FaceDetailer settings (do not expose in UI) ----------------
# GUIDE_SIZE = 512
# GUIDE_SIZE_FOR_BBOX = True
# MAX_SIZE = 1024
# STEPS = 30
# CFG = 7.0
# SCHEDULER = "simple"
# DENOISE = 0.5
# FEATHER = 5
# NOISE_MASK = True
# FORCE_INPAINT = True
# BBOX_THRESHOLD = 0.5
# BBOX_DILATION = 10
# BBOX_CROP_FACTOR = 3.0
# DROP_SIZE = 10
# SAM_DETECTION_HINT = "center-1"
# SAM_DILATION = 0
# SAM_THRESHOLD = 0.93
# SAM_BBOX_EXPANSION = 0
# SAM_MASK_HINT_THRESHOLD = 0.7
# SAM_MASK_HINT_USE_NEGATIVE = "False"
# WILDCARD = ""
# CYCLE = 1
# INPAINT_MODEL = False
# NOISE_MASK_FEATHER = 20
# TILED_ENCODE = False
# TILED_DECODE = False
# ---------------------------------------------------------------------
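# NOTE: The fixed values above are mirrored positionally in the enhance_face(...)
# calls inside dn_04.doit below; if you change a setting here, update both the
# single-image and batch call sites so they stay in agreement.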
# ---------------- Ultralytics / YOLO detector integration (embedded) ----------------
# Torch runtime perf switches
torch.backends.cudnn.benchmark = True # autotune best conv algorithms
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")  # available in PyTorch 1.12+
except Exception:
pass
# Optional Impact Pack interop (SEG type)
try:
# If Impact Pack is installed, use its SEG to be perfectly compatible.
from impact.core import SEG as _IMPACT_SEG # type: ignore
_USE_IMPACT_SEG = True
except Exception:
_USE_IMPACT_SEG = False
@dataclass
class _LocalSEG:
cropped_image: Optional[torch.Tensor]
cropped_mask: np.ndarray # 2D float32 [0..1]
confidence: float
crop_region: Tuple[int, int, int, int] # (x1,y1,x2,y2)
bbox: Tuple[int, int, int, int] # (x1,y1,x2,y2)
label: str
control_net_wrapper: Optional[object] = None
SEG = _IMPACT_SEG if _USE_IMPACT_SEG else _LocalSEG
# ---------------------------------------------------------------------
# LOCAL ASSET PATHS (no hardcoded absolute paths)
# ---------------------------------------------------------------------
# Base directory of this node file (cross-platform, works on RunPod/ComfyUI)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Local YOLO model path inside this custom node folder
YOLO_MODEL_PATH = os.path.join(BASE_DIR, "assets", "face_yolov8m_salia.pt")
YOLO_IMGSZ = 640
# Local SAM checkpoint path inside this custom node folder
SAM_CKPT_PATH = os.path.join(BASE_DIR, "assets", "sam_vit_b_01ec64_salia.pth")
# Cached instances (process-local)
_CACHED_YOLO_MODEL = None
_CACHED_ULTRA_DETECTOR = None
def _tensor_to_pil(image: torch.Tensor) -> Image.Image:
# image: [1, H, W, 3], float(0..1)
img = image[0].detach().cpu().clamp(0, 1).numpy()
img = (img * 255.0).round().astype(np.uint8) # (H, W, 3) RGB
return Image.fromarray(img, mode="RGB")
def _make_crop_region(w: int, h: int, bbox_xyxy, crop_factor: float) -> Tuple[int, int, int, int]:
x1, y1, x2, y2 = map(int, bbox_xyxy)
cx = (x1 + x2) * 0.5
cy = (y1 + y2) * 0.5
bw = (x2 - x1)
bh = (y2 - y1)
new_w = max(1, int(bw * crop_factor))
new_h = max(1, int(bh * crop_factor))
# center to image
nx1 = int(max(0, round(cx - new_w * 0.5)))
ny1 = int(max(0, round(cy - new_h * 0.5)))
nx2 = int(min(w, nx1 + new_w))
ny2 = int(min(h, ny1 + new_h))
# clamp again
nx1 = max(0, min(nx1, w - 1))
ny1 = max(0, min(ny1, h - 1))
nx2 = max(nx1 + 1, min(nx2, w))
ny2 = max(ny1 + 1, min(ny2, h))
return (nx1, ny1, nx2, ny2)
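# Worked example for _make_crop_region: a 100x80 bbox at (200,150)-(300,230)
# with crop_factor=3.0 requests a 300x240 region centered at (250,190), i.e.
# (100,70)-(400,310) before clipping to the image bounds, so the sampler sees
# context around the face rather than the face alone.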
def _crop_tensor_image(image: torch.Tensor, crop: Tuple[int, int, int, int]) -> torch.Tensor:
# image: [1,H,W,3]; crop: (x1,y1,x2,y2)
x1, y1, xb, yb = crop
return image[:, y1:yb, x1:xb, :].contiguous()
def _crop_ndarray(mask: np.ndarray, crop: Tuple[int, int, int, int]) -> np.ndarray:
# mask: [H,W] float/bool/uint8; crop: (x1,y1,x2,y2)
x1, y1, xb, yb = crop
return mask[int(y1):int(yb), int(x1):int(xb)]
def _dilate_masks(segmasks: List[Tuple[np.ndarray, np.ndarray, float]], factor: int):
if factor == 0 or not segmasks:
return segmasks
k = abs(int(factor))
if k < 1:
return segmasks
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
do_dilate = factor > 0
out = []
for (bbox, m, conf) in segmasks:
u8 = (m * 255.0).astype(np.uint8) if m.dtype != np.uint8 else m
d = cv2.dilate(u8, kernel, iterations=1) if do_dilate else cv2.erode(u8, kernel, iterations=1)
out.append((bbox, d.astype(np.float32) / 255.0, conf))
return out
def _combine_masks(segmasks: List[Tuple[np.ndarray, np.ndarray, float]]) -> Optional[torch.Tensor]:
if not segmasks:
return None
h = segmasks[0][1].shape[0]
w = segmasks[0][1].shape[1]
acc = np.zeros((h, w), dtype=np.uint8)
for _, m, _ in segmasks:
u8 = (m * 255.0).astype(np.uint8) if m.dtype != np.uint8 else m
acc = cv2.bitwise_or(acc, u8)
return torch.from_numpy(acc.astype(np.float32) / 255.0) # [H,W], float32 0..1 CPU
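# The result is the bitwise OR (union) of all per-detection masks at full
# image resolution: one mask covering every detected face.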
def _pick_device_str(user_device: str = "") -> str:
if user_device:
return user_device
return "cuda" if torch.cuda.is_available() else "cpu"
@torch.inference_mode()
def _inference_bbox(model, image_pil: Image.Image, confidence: float = 0.3, device: str = ""):
"""
Returns results = [labels(str), bboxes(xyxy), segms(full-image bool masks), conf(float)]
For bbox models, segm "masks" are rectangles from the boxes (Subpack parity).
"""
pred = model(
image_pil,
conf=confidence,
device=_pick_device_str(device),
verbose=False,
imgsz=YOLO_IMGSZ, # fixed size can be faster
)
p0 = pred[0]
boxes = p0.boxes
bboxes = boxes.xyxy.detach().cpu().numpy() # (N,4) float, xyxy
W, H = image_pil.size
    if bboxes.shape[0] == 0:
        return [[], [], [], []]
    # For a bbox-only model, synthesize full-image rectangle masks from the boxes
    segms = []
    for x0, y0, x1, y1 in bboxes:
        m = np.zeros((H, W), np.uint8)
        cv2.rectangle(m, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
        segms.append(m.astype(bool))
results = [[], [], [], []]
names = p0.names
for i, (bbox, segm) in enumerate(zip(bboxes, segms)):
cls_i = int(boxes.cls[i].item())
results[0].append(names[cls_i])
results[1].append(bbox)
results[2].append(segm)
results[3].append(float(boxes.conf[i].item()))
return results
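# Shape of the returned structure for two detections (illustrative values):
#   [['face', 'face'],                       # labels
#    [array([x1,y1,x2,y2]), array([...])],   # xyxy boxes
#    [bool HxW mask, bool HxW mask],         # rectangle masks from the boxes
#    [0.91, 0.84]]                           # confidences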
def _create_segmasks(results):
bboxes = results[1]
segms = results[2]
confs = results[3]
out = []
for i in range(len(segms)):
out.append((bboxes[i], segms[i].astype(np.float32), confs[i]))
return out
class UltraBBoxDetector:
def __init__(self, yolo_model):
self.bbox_model = yolo_model
def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
drop_size = max(int(drop_size), 1)
detected = _inference_bbox(self.bbox_model, _tensor_to_pil(image), threshold)
segmasks = _create_segmasks(detected)
if int(dilation) != 0:
segmasks = _dilate_masks(segmasks, int(dilation))
H = int(image.shape[1])
W = int(image.shape[2])
items = []
for (bbox_xyxy, full_mask, conf), label in zip(segmasks, detected[0]):
x1, y1, x2, y2 = map(int, bbox_xyxy)
if (x2 - x1) > drop_size and (y2 - y1) > drop_size:
crop_region = _make_crop_region(W, H, (x1, y1, x2, y2), float(crop_factor))
if detailer_hook is not None and hasattr(detailer_hook, "post_crop_region"):
crop_region = detailer_hook.post_crop_region(W, H, (x1, y1, x2, y2), crop_region)
cropped_image = _crop_tensor_image(image, crop_region)
cropped_mask = _crop_ndarray(full_mask, crop_region).astype(np.float32)
items.append(SEG(cropped_image, cropped_mask, float(conf), crop_region, (x1, y1, x2, y2), str(label), None))
segs = ((H, W), items)
if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
segs = detailer_hook.post_detection(segs)
return segs
def detect_combined(self, image, threshold, dilation):
detected = _inference_bbox(self.bbox_model, _tensor_to_pil(image), threshold)
segmasks = _create_segmasks(detected)
if int(dilation) != 0:
segmasks = _dilate_masks(segmasks, int(dilation))
return _combine_masks(segmasks)
def setAux(self, x):
# kept for signature parity
pass
def _load_ultralytics_model(model_path: str):
# Import here so that module import doesn't hard-fail if ultralytics is missing
try:
from ultralytics import YOLO
except Exception as e:
raise RuntimeError(
"[FaceDetailerStandalone] The 'ultralytics' package is required for the embedded bbox detector.\n"
"Install in your ComfyUI python: python -m pip install --upgrade ultralytics"
) from e
if not os.path.isfile(model_path):
raise FileNotFoundError(
"[FaceDetailerStandalone] Embedded YOLO model file not found.\n"
f"Expected at: {model_path}\n"
"Please place 'face_yolov8m_salia.pt' in the 'assets' folder next to this node."
)
yolo = YOLO(model_path)
# One-time graph/model optimizations
try:
dev = _pick_device_str()
try:
yolo.to(dev) # newer Ultralytics
except Exception:
yolo.model.to(dev) # older versions
except Exception:
pass
# Fuse Conv+BN where possible (small speedup)
try:
yolo.fuse()
except Exception:
pass
    # Half-precision weights on CUDA. Note: depending on the Ultralytics
    # version, the predictor may re-cast weights to match its own `half`
    # argument at inference time, so treat this cast as best-effort.
    try:
        if torch.cuda.is_available():
            yolo.model.half()
    except Exception:
        pass
return yolo
def _get_embedded_detector():
global _CACHED_YOLO_MODEL, _CACHED_ULTRA_DETECTOR
if _CACHED_ULTRA_DETECTOR is not None:
return _CACHED_ULTRA_DETECTOR
if _CACHED_YOLO_MODEL is None:
_CACHED_YOLO_MODEL = _load_ultralytics_model(YOLO_MODEL_PATH)
_CACHED_ULTRA_DETECTOR = UltraBBoxDetector(_CACHED_YOLO_MODEL)
return _CACHED_ULTRA_DETECTOR
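# Minimal standalone sketch (hypothetical usage; `img` is any Comfy IMAGE,
# i.e. an NHWC float tensor in [0..1] with batch size 1):
#   detector = _get_embedded_detector()
#   (shape_hw, segs) = detector.detect(img, threshold=0.5, dilation=10,
#                                      crop_factor=3.0, drop_size=10)
#   for seg in segs:
#       print(seg.bbox, seg.confidence)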
# ---------------- Embedded SAM loader (GPU-only, hardcoded path, reuse one predictor) ----------------
# Matches the SAMLoaderStandalone design, but embedded and cached.
def _to_numpy_rgb(image_tensor):
"""
Comfy 'IMAGE' is NHWC in [0..1]. Convert to uint8 HxWx3 RGB numpy.
Accepts torch.Tensor (NHWC) or numpy already in HWC.
"""
if isinstance(image_tensor, torch.Tensor):
img = image_tensor
if img.dim() == 4 and img.shape[0] == 1:
img = img[0]
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).cpu().numpy() # HWC
return img
    elif isinstance(image_tensor, np.ndarray):
        img = image_tensor
        if img.dtype != np.uint8:
            # Comfy-style float arrays are in [0..1]; scale before casting
            if img.size > 0 and float(img.max()) <= 1.0:
                img = img * 255.0
            img = np.clip(img, 0, 255).astype(np.uint8)
        return img
else:
raise TypeError(f"Unsupported image type for SAM: {type(image_tensor)}")
class _SAMWrapperGPUOnlyFast:
"""
FaceDetailer-compatible wrapper:
- Stays on CUDA
- Reuses a single SamPredictor
- predict(image, points, plabs, bbox, threshold) -> list[HxW float32 CPU masks]
"""
def __init__(self, model):
self.model = model
dev = comfy.model_management.get_torch_device()
if "cuda" not in str(dev).lower():
raise RuntimeError(
f"[FaceDetailerStandalone] GPU-only SAM: CUDA device not available (got '{dev}')."
)
self._device = dev
self.model.to(self._device).eval()
# Lazy import for segment_anything predictor
from segment_anything import SamPredictor # type: ignore
# Reuse one predictor instance (cheaper than re-creating every call)
self._predictor = SamPredictor(self.model)
def prepare_device(self):
if "cuda" not in str(self._device).lower():
raise RuntimeError("[FaceDetailerStandalone] CUDA device lost/unavailable for SAM.")
def release_device(self):
# GPU-only; keep on GPU (no-op)
pass
@torch.inference_mode()
def predict(self, image, points, plabs, bbox, threshold: float):
"""
image: Comfy IMAGE (NHWC, [0..1]) or numpy
points: list[[x,y], ...] or None
plabs: list[int] (1=fg, 0=bg) or None
bbox: [x1,y1,x2,y2] or None
threshold: float in [0..1]
returns: list of HxW float32 CPU masks (0/1)
"""
self.prepare_device()
np_img = _to_numpy_rgb(image)
# Some builds call set_image(img, "RGB"); accept both signatures.
try:
self._predictor.set_image(np_img, "RGB")
except TypeError:
self._predictor.set_image(np_img)
pc = np.array(points, dtype=np.float32) if points else None
pl = np.array(plabs, dtype=np.int32) if plabs else None
bx = np.array(bbox, dtype=np.float32) if bbox is not None else None
# Keep provided behavior: multimask_output=False
masks, scores, _ = self._predictor.predict(
point_coords=pc,
point_labels=pl,
box=bx,
multimask_output=False
)
out = []
if masks is not None and scores is not None:
for m, s in zip(masks, scores):
if float(s) >= float(threshold):
if isinstance(m, torch.Tensor):
t = m.to(torch.float32).cpu()
else:
t = torch.from_numpy(m.astype(np.float32)).cpu()
out.append(t)
return out
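# Note: with multimask_output=False, SamPredictor returns one mask per call,
# so _SAMWrapperGPUOnlyFast.predict yields at most one HxW tensor (and an
# empty list if the mask's score falls below the threshold).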
# Cache for SAM
_CACHED_SAM_MODEL = None
def _get_embedded_sam():
"""Load SAM vit_b from SAM_CKPT_PATH and attach GPU-only fast wrapper, cached."""
global _CACHED_SAM_MODEL
if _CACHED_SAM_MODEL is not None:
return _CACHED_SAM_MODEL
if not os.path.isfile(SAM_CKPT_PATH):
raise FileNotFoundError(
f"[FaceDetailerStandalone] SAM checkpoint not found:\n {SAM_CKPT_PATH}\n"
f"Place 'sam_vit_b_01ec64_salia.pth' in the 'assets' folder next to this node."
)
# Import here to avoid module import failure at file load time
try:
from segment_anything import sam_model_registry # type: ignore
except Exception as e:
raise RuntimeError(
"[FaceDetailerStandalone] 'segment_anything' is not installed for embedded SAM. "
"Install in your Comfy python, e.g.: python -m pip install "
"git+https://github.com/facebookresearch/segment-anything"
) from e
# Fixed to vit_b (matches 'sam_vit_b_01ec64' weights)
sam = sam_model_registry['vit_b'](checkpoint=SAM_CKPT_PATH)
sam.eval() # ensure eval mode
# Attach GPU-only, faster wrapper
wrapper = _SAMWrapperGPUOnlyFast(sam)
sam.sam_wrapper = wrapper
_CACHED_SAM_MODEL = sam
return _CACHED_SAM_MODEL
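# Interop note: recent Impact Pack versions consume the loader's output through
# the `sam_wrapper` attribute attached above (calling prepare_device / predict /
# release_device on it), which is what makes this embedded model a drop-in
# replacement for the separate SAMLoader node.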
# ---------------- Impact Pack Face Detailer binding ----------------
_ENHANCE_FACE = None
_IMPORT_ERR = None
try:
from impact.impact_pack import FaceDetailer as _FD
_ENHANCE_FACE = _FD.enhance_face
except Exception as _e:
_IMPORT_ERR = _e
_ENHANCE_FACE = None
# ---------------- Single public node ----------------
class dn_04:
@classmethod
def INPUT_TYPES(cls):
# Only essential, connectable parts remain editable. (No bbox or SAM inputs.)
return {
"required": {
"image": ("IMAGE",),
"model": ("MODEL", {"tooltip": "If `ImpactDummyInput` is connected to model, inference is skipped."}),
"clip": ("CLIP",),
"vae": ("VAE",),
# Keep sampler selectable; all other knobs are fixed
"sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
# Conditioning stays connectable
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
# Keep seed editable but fixed after generate for reproducibility
"seed": ("INT", {
"default": 0,
"min": 0,
"max": 0xffffffffffffffff,
"step": 1,
"control_after_generate": "fixed",
}),
},
"optional": {
# No external SAM input; embedded
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "doit"
CATEGORY = "ImpactPack/Standalone"
DESCRIPTION = (
"Face Detailer with requested parameters hardcoded (non-editable), "
"and embedded Ultralytics face bbox detector + embedded SAM (no external input nodes). "
"Optimized call path (cached imports + inference_mode) for lower overhead; "
"results identical to Impact Pack Face Detailer at the same settings."
)
def doit(
self,
image, model, clip, vae,
sampler_name,
positive, negative,
seed,
):
if _ENHANCE_FACE is None:
raise RuntimeError(
"ComfyUI-Impact-Pack is required for Face Detailer logic. "
"Please install/enable ComfyUI-Impact-Pack."
) from _IMPORT_ERR
# Embedded detector & SAM (cached)
bbox_detector = _get_embedded_detector()
sam_model_opt = _get_embedded_sam()
enhance = _ENHANCE_FACE
# Determine batch size safely
B = image.shape[0] if (hasattr(image, "shape") and image.ndim == 4) else 1
# No autograd, faster kernel choices, identical math for inference
with torch.inference_mode():
if B == 1:
# Fast-path for single image (avoid list + cat)
single = image[0] if image.ndim == 4 else image # [H,W,C]
                enhanced_img, _, _, _, _ = enhance(
                    single.unsqueeze(0),        # -> [1,H,W,C]
                    model, clip, vae,
                    512, True, 1024,            # guide_size, guide_size_for_bbox, max_size
                    seed, 30, 7.0,              # seed, steps, cfg
                    sampler_name, "simple",     # sampler, scheduler
                    positive, negative,
                    0.5, 5, True, True,         # denoise, feather, noise_mask, force_inpaint
                    0.5, 10, 3.0,               # bbox_threshold, bbox_dilation, bbox_crop_factor
                    "center-1", 0, 0.93, 0,     # sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion
                    0.7, "False",               # sam_mask_hint_threshold, sam_mask_hint_use_negative
                    10, bbox_detector,          # drop_size, bbox_detector
# Internals not exposed (kept fixed/None)
segm_detector=None, sam_model_opt=sam_model_opt,
wildcard_opt="", detailer_hook=None,
refiner_ratio=None, refiner_model=None, refiner_clip=None,
refiner_positive=None, refiner_negative=None,
cycle=1, inpaint_model=False,
noise_mask_feather=20,
scheduler_func_opt=None,
tiled_encode=False, tiled_decode=False,
)
return (enhanced_img,)
# Batch of images; per-frame process with seed+i
out_imgs = []
for i, single in enumerate(image.unbind(0)):
enhanced_img, _, _, _, _ = enhance(
single.unsqueeze(0), # [1,H,W,C]
model, clip, vae,
512, True, 1024,
seed + i, 30, 7.0,
sampler_name, "simple",
positive, negative,
0.5, 5, True, True,
0.5, 10, 3.0,
"center-1", 0, 0.93, 0,
0.7, "False",
10, bbox_detector,
segm_detector=None, sam_model_opt=sam_model_opt,
wildcard_opt="", detailer_hook=None,
refiner_ratio=None, refiner_model=None, refiner_clip=None,
refiner_positive=None, refiner_negative=None,
cycle=1, inpaint_model=False,
noise_mask_feather=20,
scheduler_func_opt=None,
tiled_encode=False, tiled_decode=False,
)
out_imgs.append(enhanced_img)
return (torch.cat(out_imgs, dim=0),)
NODE_CLASS_MAPPINGS = {
"dn_04": dn_04,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"dn_04": "dn_04",
}
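# To expose this node, re-export the mappings from the package __init__.py,
# e.g. (a minimal sketch, assuming this file is saved as dn04.py inside the
# custom node folder):
#   from .dn04 import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
#   __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]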