# AILab_SAM3Segment.py
# Integrated standalone nodes:
# - SAM3Segment
# - Salia_ezpz_gated_Duo2
# - apply_segment_4
# - SAM3Segment_Salia (fused)
import os
import sys
import hashlib
import shutil
import threading
import urllib.request
import heapq  # NOTE(review): not used anywhere in this chunk — possibly used further down the file; verify before removing
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, List
import numpy as np
import torch
import torch.nn.functional as F  # NOTE(review): F is unused in this chunk — verify against the rest of the file
from PIL import Image, ImageFilter, ImageOps
from torch.hub import download_url_to_file
import folder_paths
import comfy.model_management
import comfy.model_management as model_management
from AILab_ImageMaskTools import pil2tensor, tensor2pil

# ======================================================================================
# SAM3Segment (original, with syntax fix)
# ======================================================================================
# Make the bundled sam3 package importable and verify its BPE vocab asset exists
# before importing the model builder (fail fast with a clear message otherwise).
CURRENT_DIR = os.path.dirname(__file__)
SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
if SAM3_LOCAL_DIR not in sys.path:
    sys.path.insert(0, SAM3_LOCAL_DIR)
SAM3_BPE_PATH = os.path.join(SAM3_LOCAL_DIR, "assets", "bpe_simple_vocab_16e6.txt.gz")
if not os.path.isfile(SAM3_BPE_PATH):
    raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")
from sam3.model_builder import build_sam3_image_model  # noqa: E402
from sam3.model.sam3_image_processor import Sam3Processor  # noqa: E402

# Default checkpoint entry for the SAM3 image model (single .pt file on HF).
_DEFAULT_PT_ENTRY = {
    "model_url": "https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
    "filename": "sam3.pt",
}
# Registry of selectable SAM3 models; currently just the default entry.
SAM3_MODELS = {
    "sam3": _DEFAULT_PT_ENTRY.copy(),
}


def get_sam3_pt_models():
    """Return a {"sam3": entry} dict pointing at a .pt checkpoint entry.

    Preference order: the registered "sam3" entry if it is a .pt; otherwise the
    first registry entry whose filename ends in .pt; otherwise a copy of the
    default entry.
    """
    entry = SAM3_MODELS.get("sam3")
    if entry and entry.get("filename", "").endswith(".pt"):
        return {"sam3": entry}
    for key, value in SAM3_MODELS.items():
        if value.get("filename", "").endswith(".pt"):
            return {"sam3": value}
        # NOTE(review): this branch only runs for entries whose filename is NOT
        # a .pt, and it overwrites both url and filename with the defaults — it
        # looks like a repair path for malformed registry entries; confirm intent.
        if "sam3" in key and value:
            candidate = value.copy()
            candidate["model_url"] = _DEFAULT_PT_ENTRY["model_url"]
            candidate["filename"] = _DEFAULT_PT_ENTRY["filename"]
            return {"sam3": candidate}
    return {"sam3": _DEFAULT_PT_ENTRY.copy()}


def process_mask(mask_image, invert_output=False, mask_blur=0, mask_offset=0):
    """Post-process a PIL "L" mask: optional invert, Gaussian blur, grow/shrink.

    mask_offset > 0 dilates (MaxFilter), < 0 erodes (MinFilter); the filter is
    applied abs(mask_offset) times with a kernel of size 2*abs(offset)+1.
    """
    if invert_output:
        mask_np = np.array(mask_image)
        mask_image = Image.fromarray(255 - mask_np)
    if mask_blur > 0:
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
    if mask_offset != 0:
        filt = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        size = abs(mask_offset) * 2 + 1
        for _ in range(abs(mask_offset)):
            mask_image = mask_image.filter(filt(size))
    return mask_image


def apply_background_color(image, mask_image, background="Alpha", background_color="#222222"):
    """Apply mask as alpha; composite over a solid color when background=="Color".

    Returns RGBA when background=="Alpha", RGB when background=="Color".
    """
    rgba_image = image.copy().convert("RGBA")
    rgba_image.putalpha(mask_image.convert("L"))
    if background == "Color":
        hex_color = background_color.lstrip("#")
        r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
        bg_image = Image.new("RGBA", image.size, (r, g, b, 255))
        composite = Image.alpha_composite(bg_image, rgba_image)
        return composite.convert("RGB")
    return rgba_image


def get_or_download_model_file(filename, url):
    """Locate `filename` in the ComfyUI "sam3" model folder, downloading it if absent.

    Falls back to <models_dir>/sam3/<filename> when folder_paths cannot resolve it.
    """
    local_path = None
    if hasattr(folder_paths, "get_full_path"):
        local_path = folder_paths.get_full_path("sam3", filename)
    if local_path and os.path.isfile(local_path):
        return local_path
    base_models_dir = getattr(folder_paths, "models_dir", os.path.join(CURRENT_DIR, "models"))
    models_dir = os.path.join(base_models_dir, "sam3")
    os.makedirs(models_dir, exist_ok=True)
    local_path = os.path.join(models_dir, filename)
    if not os.path.exists(local_path):
        print(f"Downloading (unknown) from {url} ...")
        download_url_to_file(url, local_path)
    return local_path


def _resolve_device(user_choice):
    """Map the UI device choice ("Auto"/"CPU"/"GPU") to a torch.device.

    Raises RuntimeError when "GPU" is requested but comfy's auto device is not CUDA.
    """
    auto_device = comfy.model_management.get_torch_device()
    if user_choice == "CPU":
        return torch.device("cpu")
    if user_choice == "GPU":
        if auto_device.type != "cuda":
            raise RuntimeError("GPU unavailable")
        return torch.device("cuda")
    return auto_device


class SAM3Segment:
    """ComfyUI node: text-prompted segmentation via the SAM3 image model.

    Outputs the masked image, the combined mask, and an RGB preview of the mask.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Describe the concept"}),
                "sam3_model": (list(SAM3_MODELS.keys()), {"default": "sam3"}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto"}),
                "confidence_threshold": ("FLOAT", {"default": 0.5, "min": 0.05, "max": 0.95, "step": 0.01}),
            },
            "optional": {
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
                "invert_output": ("BOOLEAN", {"default": False}),
                "unload_model": ("BOOLEAN", {"default": False}),
                "background": (["Alpha", "Color"], {"default": "Alpha"}),
                "background_color": ("COLORCODE", {"default": "#222222"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "segment"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # Cache of Sam3Processor instances keyed by (model_choice, "cuda"|"cpu").
        self.processor_cache = {}

    def _load_processor(self, model_choice, device_choice):
        """Build (or fetch from cache) a Sam3Processor; returns (processor, torch.device)."""
        torch_device = _resolve_device(device_choice)
        device_str = "cuda" if torch_device.type == "cuda" else "cpu"
        cache_key = (model_choice, device_str)
        if cache_key not in self.processor_cache:
            model_info = SAM3_MODELS[model_choice]
            ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
            model = build_sam3_image_model(
                bpe_path=SAM3_BPE_PATH,
                device=device_str,
                eval_mode=True,
                checkpoint_path=ckpt_path,
                load_from_HF=False,
                enable_segmentation=True,
                enable_inst_interactivity=False,
            )
            processor = Sam3Processor(model, device=device_str)
            self.processor_cache[cache_key] = processor
        return self.processor_cache[cache_key], torch_device

    def _empty_result(self, img_pil, background, background_color):
        """Produce the no-detection result: all-zero mask and the image composited accordingly."""
        w, h = img_pil.size
        mask_image = Image.new("L", (w, h), 0)
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        if background == "Alpha":
            result_image = result_image.convert("RGBA")
        else:
            result_image = result_image.convert("RGB")
        empty_mask = torch.zeros((1, h, w), dtype=torch.float32)
        # Broadcast the single-channel mask to a [1,H,W,3] RGB preview tensor.
        mask_rgb = empty_mask.reshape((-1, 1, h, w)).movedim(1, -1).expand(-1, -1, -1, 3)
        return result_image, empty_mask, mask_rgb

    def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask_offset, invert, background, background_color):
        """Segment one image with the text prompt; returns (PIL image, mask [1,H,W], mask RGB preview).

        All detected instance masks are merged into one via per-pixel max (amax).
        An empty/whitespace prompt falls back to the literal prompt "object".
        """
        img_pil = tensor2pil(img_tensor)
        text = prompt.strip() or "object"
        state = processor.set_image(img_pil)
        processor.reset_all_prompts(state)
        processor.set_confidence_threshold(confidence, state)
        state = processor.set_text_prompt(text, state)
        masks = state.get("masks")
        if masks is None or masks.numel() == 0:
            return self._empty_result(img_pil, background, background_color)
        masks = masks.float().to("cpu")
        if masks.ndim == 4:
            masks = masks.squeeze(1)
        combined = masks.amax(dim=0)
        mask_np = (combined.clamp(0, 1).numpy() * 255).astype(np.uint8)
        mask_image = Image.fromarray(mask_np, mode="L")
        mask_image = process_mask(mask_image, invert, mask_blur, mask_offset)
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        if background == "Alpha":
            result_image = result_image.convert("RGBA")
        else:
            result_image = result_image.convert("RGB")
        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
        mask_rgb = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
        return result_image, mask_tensor, mask_rgb

    def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
        """Node entry point: segment a batch of images; optionally drop the cached model afterwards.

        Runs under bfloat16 autocast on CUDA (disabled on CPU/MPS).
        Returns (images [B,H,W,C], masks [B,H,W], mask previews [B,H,W,3]).
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)
        processor, torch_device = self._load_processor(sam3_model, device)
        autocast_device = comfy.model_management.get_autocast_device(torch_device)
        autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
        ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()
        result_images, result_masks, result_mask_images = [], [], []
        with ctx:
            for tensor_img in image:
                img_pil, mask_tensor, mask_rgb = self._run_single(
                    processor, tensor_img, prompt, confidence_threshold,
                    mask_blur, mask_offset, invert_output, background, background_color,
                )
                result_images.append(pil2tensor(img_pil))
                result_masks.append(mask_tensor)
                result_mask_images.append(mask_rgb)
        if unload_model:
            # Evict the processor for this (model, device) pair and release CUDA memory.
            device_str = "cuda" if torch_device.type == "cuda" else "cpu"
            cache_key = (sam3_model, device_str)
            if cache_key in self.processor_cache:
                del self.processor_cache[cache_key]
            if torch_device.type == "cuda":
                torch.cuda.empty_cache()
        return torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0)


# ======================================================================================
# Salia_ezpz_gated_Duo2 (standalone)
# ======================================================================================
# transformers is required for depth-estimation pipeline
try:
    from transformers import pipeline
except Exception as e:
    # Defer the failure: _try_load_pipeline re-raises with this stored error.
    pipeline = None
    _TRANSFORMERS_IMPORT_ERROR = e

# Process-wide caches for loaded checkpoints / controlnets, guarded by locks.
_CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
_CN_CACHE: Dict[str, Any] = {}
_CKPT_LOCK = threading.Lock()
_CN_LOCK = threading.Lock()


def _find_plugin_root() -> Path:
    """
    Walk upwards from this file until we find an 'assets' folder.
    If not found, fall back to this file's directory.
    """
    here = Path(__file__).resolve()
    for parent in [here.parent] + list(here.parents)[:12]:
        if (parent / "assets").is_dir():
            return parent
    return here.parent


PLUGIN_ROOT = _find_plugin_root()


def _pil_lanczos():
    """Return the LANCZOS resample constant across old/new Pillow versions."""
    if hasattr(Image, "Resampling"):
        return Image.Resampling.LANCZOS
    return Image.LANCZOS


def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert a [B,H,W,C] or [H,W,C] float tensor in [0,1] to a PIL RGB/RGBA image (first batch item)."""
    if img.ndim == 4:
        img = img[0]
    img = img.detach().cpu().float().clamp(0, 1)
    arr = (img.numpy() * 255.0).round().astype(np.uint8)
    if arr.shape[-1] == 4:
        return Image.fromarray(arr, mode="RGBA")
    return Image.fromarray(arr, mode="RGB")


def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a [1,H,W,C] float tensor in [0,1] (RGBA kept if alpha present)."""
    if pil.mode not in ("RGB", "RGBA"):
        pil = pil.convert("RGBA") if "A" in pil.getbands() else pil.convert("RGB")
    arr = np.array(pil).astype(np.float32) / 255.0
    t = torch.from_numpy(arr)
    return t.unsqueeze(0)


def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """Convert a [B,H,W] or [H,W] float mask in [0,1] to a PIL "L" image (first batch item)."""
    if mask.ndim == 3:
        mask = mask[0]
    mask = mask.detach().cpu().float().clamp(0, 1)
    arr = (mask.numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr, mode="L")


def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a [1,H,W] float mask in [0,1]."""
    if pil_l.mode != "L":
        pil_l = pil_l.convert("L")
    arr = np.array(pil_l).astype(np.float32) / 255.0
    t = torch.from_numpy(arr)
    return t.unsqueeze(0)


def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize a [B,H,W,C] image tensor to (w,h) per batch item via PIL LANCZOS."""
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    outs = []
    for i in range(img.shape[0]):
        pil = _image_tensor_to_pil(img[i].unsqueeze(0))
        pil = pil.resize((int(w), int(h)), resample=_pil_lanczos())
        outs.append(_pil_to_image_tensor(pil))
    return torch.cat(outs, dim=0)


def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize a [B,H,W] mask tensor to (w,h) per batch item via PIL LANCZOS."""
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    outs = []
    for i in range(mask.shape[0]):
        pil = _mask_tensor_to_pil(mask[i].unsqueeze(0))
        pil = pil.resize((int(w), int(h)), resample=_pil_lanczos())
        outs.append(_pil_to_mask_tensor(pil))
    return torch.cat(outs, dim=0)


def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Attach a ComfyUI mask to an RGB tensor as alpha = 1 - mask; returns [B,H,W,4].

    A batch-1 mask is broadcast across a larger rgb batch; sizes must match exactly.
    """
    if rgb.ndim == 3:
        rgb = rgb.unsqueeze(0)
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    if rgb.ndim != 4 or rgb.shape[-1] != 3:
        raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
    if mask.ndim != 3:
        raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")
    if mask.shape[0] != rgb.shape[0]:
        if mask.shape[0] == 1 and rgb.shape[0] > 1:
            mask = mask.expand(rgb.shape[0], -1, -1)
        else:
            raise ValueError("Batch mismatch between rgb and mask.")
    if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]:
        raise ValueError(
            f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
        )
    mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
    # Invert: masked (1) pixels become transparent (alpha 0).
    alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1)
    rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1)
    return rgba


def _load_checkpoint_cached(ckpt_name: str):
    """Load (model, clip, vae) via ComfyUI's CheckpointLoaderSimple, memoized per checkpoint name."""
    with _CKPT_LOCK:
        if ckpt_name in _CKPT_CACHE:
            return _CKPT_CACHE[ckpt_name]
        import nodes
        loader = nodes.CheckpointLoaderSimple()
        fn = getattr(loader, loader.FUNCTION)
        model, clip, vae = fn(ckpt_name=ckpt_name)
        _CKPT_CACHE[ckpt_name] = (model, clip, vae)
        return model, clip, vae


def _load_controlnet_cached(control_net_name: str):
    """Load a ControlNet via ComfyUI's ControlNetLoader, memoized per name."""
    with _CN_LOCK:
        if control_net_name in _CN_CACHE:
            return _CN_CACHE[control_net_name]
        import nodes
        loader = nodes.ControlNetLoader()
        fn = getattr(loader, loader.FUNCTION)
        (cn,) = fn(control_net_name=control_net_name)
        _CN_CACHE[control_net_name] = cn
        return cn


def _assets_images_dir() -> Path:
    """Return the plugin's assets/images directory path (may not exist)."""
    return PLUGIN_ROOT / "assets" / "images"


def _list_asset_pngs() -> list:
    """List .png files under assets/images (recursive), as sorted POSIX relative paths."""
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    files = []
    for p in img_dir.rglob("*"):
        if p.is_file() and p.suffix.lower() == ".png":
            files.append(p.relative_to(img_dir).as_posix())
    files.sort()
    return files


def _safe_asset_path(asset_rel_path: str) -> Path:
    """Resolve a relative asset path inside assets/images, rejecting absolute paths,
    path traversal outside the assets root, missing files, and non-PNG files."""
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")
    base = img_dir.resolve()
    rel = Path(asset_rel_path)
    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")
    full = (base / rel).resolve()
    if base != full and base not in full.parents:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")
    return full


def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset PNG as ([1,H,W,3] image, [1,H,W] mask) where mask = 1 - alpha
    (opaque pixels -> mask 0), with EXIF orientation applied."""
    p = _safe_asset_path(asset_rel_path)
    im = Image.open(p)
    im = ImageOps.exif_transpose(im)
    rgba = im.convert("RGBA")
    rgb = rgba.convert("RGB")
    rgb_arr = np.array(rgb).astype(np.float32) / 255.0
    img_t = torch.from_numpy(rgb_arr)[None, ...]
    alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0
    mask = 1.0 - alpha
    mask_t = torch.from_numpy(mask)[None, ...]
    return img_t, mask_t


# Local depth-estimation model location and the files it must contain.
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}
# Hub repo used when the local model cannot be obtained.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
# Cache of transformers pipelines keyed by (model source, device string).
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}
_PIPE_LOCK = threading.Lock()


def _have_required_files() -> bool:
    """True when every required depth-model file exists under MODEL_DIR."""
    return all((MODEL_DIR / name).exists() for name in REQUIRED_FILES.keys())


def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """Download `url` to `dst` atomically (write to .tmp then rename)."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    if tmp.exists():
        try:
            tmp.unlink()
        except Exception:
            pass
    req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
        shutil.copyfileobj(r, f)
    tmp.replace(dst)


def ensure_local_model_files() -> bool:
    """Best-effort download of any missing depth-model files; returns success flag (never raises)."""
    if _have_required_files():
        return True
    try:
        for fname, url in REQUIRED_FILES.items():
            fpath = MODEL_DIR / fname
            if fpath.exists():
                continue
            _download_url_to_file(url, fpath)
        return _have_required_files()
    except Exception:
        return False


def HWC3(x: np.ndarray) -> np.ndarray:
    """Normalize a uint8 image to HxWx3: grayscale is replicated, RGBA is
    composited over white, RGB is returned unchanged."""
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    color = x[:, :, 0:3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    y = color * alpha + 255.0 * (1.0 - alpha)
    y = y.clip(0, 255).astype(np.uint8)
    return y


def pad64(x: int) -> int:
    """Padding needed to round x up to the next multiple of 64."""
    return int(np.ceil(float(x) / 64.0) * 64 - x)


def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return a contiguous deep copy of x (defensive against views/aliasing)."""
    return np.ascontiguousarray(x.copy()).copy()


def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """Scale so the shorter side equals `resolution`, then pad H/W up to multiples of 64.

    Returns (padded image, remove_pad callable that crops back to the scaled size).
    resolution <= 0 returns the (HWC3-normalized) input unchanged with an identity
    remove_pad. Uses cv2 when importable, PIL otherwise; downscales use AREA/LANCZOS.
    """
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None
    img = input_image if skip_hwc3 else HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # No resize and no 64-padding in this case.
        return img, (lambda x: x)
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        pil = Image.fromarray(img)
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad


def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """Pad H/W up to multiples of 64 without resizing; returns (padded, remove_pad)."""
    img = HWC3(img_u8)
    H_raw, W_raw, _ = img.shape
    H_pad, W_pad = pad64(H_raw), pad64(W_raw)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:H_raw, :W_raw, ...])

    return safer_memory(img_padded), remove_pad


def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """If input is RGBA: composite over white and return (rgb, original alpha channel).
    Otherwise return (HWC3(input), None)."""
    if inp_u8.ndim == 3 and inp_u8.shape[2] == 4:
        rgba = inp_u8.astype(np.uint8)
        rgb = rgba[:, :, 0:3].astype(np.float32)
        a = (rgba[:, :, 3:4].astype(np.float32) / 255.0)
        rgb_white = (rgb * a + 255.0 * (1.0 - a)).clip(0, 255).astype(np.uint8)
        alpha_u8 = rgba[:, :, 3].copy()
        return rgb_white, alpha_u8
    return HWC3(inp_u8), None


def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Multiply a depth RGB image by alpha so fully transparent pixels become black."""
    depth_rgb_u8 = HWC3(depth_rgb_u8)
    a = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    out = (depth_rgb_u8.astype(np.float32) * a).clip(0, 255).astype(np.uint8)
    return out


def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a [B,H,W,C]/[H,W,C] float tensor in [0,1] to a uint8 HWC array (first batch item)."""
    if img.ndim == 4:
        img = img[0]
    arr = img.detach().cpu().float().clamp(0, 1).numpy()
    u8 = (arr * 255.0).round().astype(np.uint8)
    return u8


def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image array to a [1,H,W,3] float tensor in [0,1]."""
    img_u8 = HWC3(img_u8)
    t = torch.from_numpy(img_u8.astype(np.float32) / 255.0)
    return t.unsqueeze(0)


def _try_load_pipeline(model_source: str, device: torch.device):
    """Create (or reuse) a transformers depth-estimation pipeline for model_source.

    Moving the model to `device` is best-effort; failures leave it where it loaded.
    Raises RuntimeError if transformers failed to import at module load.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")
    key = (model_source, str(device))
    with _PIPE_LOCK:
        if key in _PIPE_CACHE:
            return _PIPE_CACHE[key]
        p = pipeline(task="depth-estimation", model=model_source)
        try:
            p.model = p.model.to(device)
            p.device = device
        except Exception:
            pass
        _PIPE_CACHE[key] = p
        return p


def get_depth_pipeline(device: torch.device):
    """Return a depth pipeline: local model first, then the Zoe hub fallback, else None."""
    if ensure_local_model_files():
        try:
            return _try_load_pipeline(str(MODEL_DIR), device)
        except Exception:
            pass
    try:
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
    except Exception:
        return None


def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """Run the depth pipeline on a uint8 RGB image and return an inverted,
    percentile-normalized uint8 depth map at the working size.

    detect_resolution == -1 keeps the input size (pad-to-64 only); otherwise the
    short side is scaled to detect_resolution. Normalization stretches the
    2nd..85th percentile range and inverts (near -> bright).
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )
    pil_image = Image.fromarray(work_img)
    with torch.no_grad():
        result = pipe(pil_image)
    depth = result["depth"]
    # NOTE(review): both branches below are identical — the isinstance check is
    # currently dead weight (perhaps a tensor path was intended); confirm intent.
    if isinstance(depth, Image.Image):
        depth_array = np.array(depth, dtype=np.float32)
    else:
        depth_array = np.array(depth, dtype=np.float32)
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))
    depth_array = depth_array - vmin
    denom = (vmax - vmin)
    if abs(denom) < 1e-12:
        denom = 1e-6
    depth_array = depth_array / denom
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)
    detected_map = remove_pad(HWC3(depth_image))
    return detected_map


def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """Bilinear-resize an image back to (w0, h0); cv2 when available, PIL otherwise."""
    try:
        import cv2
        out = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return out.astype(np.uint8)
    except Exception:
        pil = Image.fromarray(depth_rgb_u8)
        pil = pil.resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil, dtype=np.uint8)


def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """Per-batch-item depth estimation for a ComfyUI IMAGE tensor.

    Best-effort throughout: if no pipeline can be loaded the input is returned
    unchanged, and any per-item failure falls back to that item's original pixels.
    RGBA inputs keep their alpha as a black-background multiply on the depth map.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")
    pipe_obj = None
    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None
    if pipe_obj is None:
        return image
    if image.ndim == 3:
        image = image.unsqueeze(0)
    outs = []
    for i in range(image.shape[0]):
        try:
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])
            inp_u8 = comfy_tensor_to_u8(image[i])
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None
            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )
            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)
            if had_rgba:
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)
                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # Keep the batch aligned: on failure, pass this item through untouched.
            outs.append(image[i].unsqueeze(0))
    return torch.cat(outs, dim=0)


def _salia_alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """Alpha-composite a square RGBA overlay onto `base` at (x, y) (straight alpha).

    The overlay must be square and fit entirely inside the base; a batch-1
    overlay is broadcast across a larger base batch. When base has an alpha
    channel it is composited with the standard over operator as well.
    """
    if base.ndim != 4 or overlay_rgba.ndim != 4:
        raise ValueError("base and overlay must be [B,H,W,C].")
    B, H, W, C = base.shape
    b2, sH, sW, c2 = overlay_rgba.shape
    if c2 != 4:
        raise ValueError("overlay_rgba must have 4 channels (RGBA).")
    if sH != sW:
        raise ValueError("overlay must be square.")
    s = sH
    if x < 0 or y < 0 or x + s > W or y + s > H:
        raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")
    if b2 != B:
        if b2 == 1 and B > 1:
            overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)
        else:
            raise ValueError("Batch mismatch between base and overlay.")
    out = base.clone()
    overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
    overlay_a = overlay_rgba[..., 3:4].clamp(0, 1)
    base_rgb = out[:, y:y + s, x:x + s, 0:3]
    comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a)
    out[:, y:y + s, x:x + s, 0:3] = comp_rgb
    if C == 4:
        base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1)
        comp_a = overlay_a + base_a * (1.0 - overlay_a)
        out[:, y:y + s, x:x + s, 3:4] = comp_a
    return out.clamp(0, 1)


# Hard-wired model names and sampling settings for the two refinement passes.
_HARDCODED_CKPT_NAME = "SaliaHighlady_Speedy.safetensors"
_HARDCODED_CONTROLNET_NAME = "diffusion_pytorch_model_promax.safetensors"
_HARDCODED_CN_START = 0.00
_HARDCODED_CN_END = 1.00
_PASS1_SAMPLER_NAME = "dpmpp_2m_sde_heun_gpu"
_PASS1_SCHEDULER = "karras"
_PASS1_STEPS = 29
_PASS1_CFG = 2.6
_PASS1_CONTROLNET_STRENGTH = 0.33
_PASS2_SAMPLER_NAME = "res_multistep_ancestral_cfg_pp"
_PASS2_SCHEDULER = "karras"
_PASS2_STEPS = 30
_PASS2_CFG = 1.7
_PASS2_CONTROLNET_STRENGTH = 0.5


class Salia_ezpz_gated_Duo2:
    """ComfyUI node: two-pass depth-ControlNet img2img refinement of a square region.

    A crop at (X_coord, Y_coord) is depth-mapped, upscaled, re-sampled with a
    hardcoded checkpoint + ControlNet, masked by an asset PNG's alpha, and pasted
    back; the whole process is skipped (pass-through) when trigger_string is empty.
    """
    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("image", "image_cropped")
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        assets = _list_asset_pngs() or [""]
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),
                "asset_image": (assets, {}),
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }

    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        positive_prompt: str = "",
        negative_prompt: str = "",
        asset_image: str = "",
        square_size_1: int = 384,
        upscale_factor_1: str = "4",
        denoise_1: float = 0.35,
        square_size_2: int = 384,
        upscale_factor_2: str = "4",
        denoise_2: float = 0.35,
    ):
        """Run both refinement passes; returns (full image, final square crop).

        Deterministic: the sampler seed is derived by hashing all effective
        settings and prompts, so identical inputs reproduce identical output.
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")
        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")
        x = int(X_coord)
        y = int(Y_coord)
        s1 = int(square_size_1)
        s2 = int(square_size_2)

        def _validate_square_bounds(s: int, label: str):
            # The square must lie fully inside the image.
            if s <= 0:
                raise ValueError(f"{label}: square_size must be > 0")
            if x < 0 or y < 0 or x + s > W or y + s > H:
                raise ValueError(f"{label}: out of bounds. image={W}x{H}, rect at ({x},{y}) size={s}")

        def _validate_upscale(up: int, s: int, label: str):
            # The upscaled square edge must be divisible by 8 for the VAE.
            if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
                raise ValueError(f"{label}: upscale_factor must be one of 1,2,4,6,8,10,12,14,16")
            if ((s * up) % 8) != 0:
                raise ValueError(f"{label}: square_size * upscale_factor must be divisible by 8 (VAE requirement).")

        def _crop_square(img: torch.Tensor, s: int) -> torch.Tensor:
            return img[:, y:y + s, x:x + s, :]

        _validate_square_bounds(s2, "final crop (square_size_2)")
        # Gate: an empty trigger string bypasses all processing.
        if trigger_string == "":
            out2 = image
            cropped = _crop_square(out2, s2)
            return (out2, cropped)
        _validate_square_bounds(s1, "pass1 (square_size_1)")
        _validate_square_bounds(s2, "pass2 (square_size_2)")
        up1 = int(upscale_factor_1)
        up2 = int(upscale_factor_2)
        _validate_upscale(up1, s1, "pass1")
        _validate_upscale(up2, s2, "pass2")
        d1 = float(max(0.0, min(1.0, denoise_1)))
        d2 = float(max(0.0, min(1.0, denoise_2)))
        if asset_image == "":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
        # Only the asset's alpha-derived mask is used; its RGB is discarded.
        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")
        if asset_mask.shape[0] != B:
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask vs input image batch.")
        import nodes
        try:
            model, clip, vae = _load_checkpoint_cached(_HARDCODED_CKPT_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("checkpoints") or []
            raise FileNotFoundError(
                f"Hardcoded ckpt not found: '{_HARDCODED_CKPT_NAME}'. "
                f"Put it in models/checkpoints. Available (first 50): {available[:50]}"
            ) from e
        try:
            controlnet = _load_controlnet_cached(_HARDCODED_CONTROLNET_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("controlnet") or []
            raise FileNotFoundError(
                f"Hardcoded controlnet not found: '{_HARDCODED_CONTROLNET_NAME}'. "
                f"Put it in models/controlnet. Available (first 50): {available[:50]}"
            ) from e
        # Instantiate the ComfyUI core nodes used by each pass and resolve their
        # entry-point methods once.
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)
        cn_apply = nodes.ControlNetApplyAdvanced()
        cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
        vae_enc = nodes.VAEEncode()
        vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
        ksampler = nodes.KSampler()
        k_fn = getattr(ksampler, ksampler.FUNCTION)
        vae_dec = nodes.VAEDecode()
        vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)

        def _run_pass(
            pass_index: int,
            in_image: torch.Tensor,
            s: int,
            up: int,
            denoise_v: float,
            steps_v: int,
            cfg_v: float,
            sampler_v: str,
            scheduler_v: str,
            controlnet_strength_v: float,
        ) -> torch.Tensor:
            """One refinement pass: crop -> depth -> upscale -> CN img2img -> masked paste-back."""
            up_w = s * up
            up_h = s * up
            crop = in_image[:, y:y + s, x:x + s, :]
            crop_rgb = crop[:, :, :, 0:3].contiguous()
            depth_small = _salia_depth_execute(crop_rgb, resolution=s)
            depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
            crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
            asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)
            pos_cn, neg_cn = cn_fn(
                strength=float(controlnet_strength_v),
                start_percent=float(_HARDCODED_CN_START),
                end_percent=float(_HARDCODED_CN_END),
                positive=pos_cond,
                negative=neg_cond,
                control_net=controlnet,
                image=depth_up,
                vae=vae,
            )
            (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)
            # Deterministic 64-bit seed from all parameters that influence the result.
            seed_material = (
                f"{_HARDCODED_CKPT_NAME}|{_HARDCODED_CONTROLNET_NAME}|{asset_image}|"
                f"pass={pass_index}|x={x}|y={y}|s={s}|up={up}|"
                f"steps={steps_v}|cfg={cfg_v}|sampler={sampler_v}|scheduler={scheduler_v}|denoise={denoise_v}|"
                f"cn_strength={controlnet_strength_v}|"
                f"{positive_prompt}|{negative_prompt}"
            ).encode("utf-8", errors="ignore")
            seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)
            (sampled_latent,) = k_fn(
                seed=seed64,
                steps=int(steps_v),
                cfg=float(cfg_v),
                sampler_name=str(sampler_v),
                scheduler=str(scheduler_v),
                denoise=float(denoise_v),
                model=model,
                positive=pos_cn,
                negative=neg_cn,
                latent_image=latent,
            )
            (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
            # Mask the decoded result with the asset alpha, shrink back to the
            # square size and alpha-paste it over the original region.
            rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)
            rgba_square = _resize_image_lanczos(rgba_up, s, s)
            out = _salia_alpha_over_region(in_image, rgba_square, x=x, y=y)
            return out

        out1 = _run_pass(
            pass_index=1,
            in_image=image,
            s=s1,
            up=up1,
            denoise_v=d1,
            steps_v=_PASS1_STEPS,
            cfg_v=_PASS1_CFG,
            sampler_v=_PASS1_SAMPLER_NAME,
            scheduler_v=_PASS1_SCHEDULER,
            controlnet_strength_v=_PASS1_CONTROLNET_STRENGTH,
        )
        out2 = _run_pass(
            pass_index=2,
            in_image=out1,
            s=s2,
            up=up2,
            denoise_v=d2,
            steps_v=_PASS2_STEPS,
            cfg_v=_PASS2_CFG,
            sampler_v=_PASS2_SAMPLER_NAME,
            scheduler_v=_PASS2_SCHEDULER,
            controlnet_strength_v=_PASS2_CONTROLNET_STRENGTH,
        )
        cropped = out2[:, y:y + s2, x:x + s2, :]
        return (out2, cropped)


# ======================================================================================
# apply_segment_4 (standalone, embedded) - rename internal alpha paste helper to avoid clash
# ======================================================================================
# Expects: /assets/images/*.png
_AP4_ASSETS_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", "images")


def ap4_list_pngs() -> List[str]:
    """Recursively list .png files under the ap4 assets dir as sorted, slash-normalized relative paths."""
    if not os.path.isdir(_AP4_ASSETS_DIR):
        return []
    files: List[str] = []
    for root, _, fnames in os.walk(_AP4_ASSETS_DIR):
        for f in fnames:
            if f.lower().endswith(".png"):
                full = os.path.join(root, f)
                if os.path.isfile(full):
                    rel = os.path.relpath(full, _AP4_ASSETS_DIR)
                    files.append(rel.replace("\\", "/"))
    return sorted(files)


def ap4_safe_path(filename: str) -> str:
    """Resolve `filename` under the ap4 assets dir; rejects paths escaping it (realpath check)."""
    candidate = os.path.join(_AP4_ASSETS_DIR, filename)
    real_assets = os.path.realpath(_AP4_ASSETS_DIR)
    real_candidate = os.path.realpath(candidate)
    if not real_candidate.startswith(real_assets + os.sep) and real_candidate != real_assets:
        raise ValueError("Unsafe path (path traversal detected).")
    return real_candidate


def ap4_file_hash(filename: str) -> str:
    """SHA-256 hex digest of an asset file, read in 1 MiB chunks."""
    path = ap4_safe_path(filename)
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)
    return h.hexdigest()


def ap4_load_image_from_assets(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset image as ([1,H,W,3] RGB tensor, [1,H,W] mask).

    Mask is 1 - alpha when an alpha band exists, else 1 - luminance. 32-bit
    integer ("I") images are rescaled to 8-bit range before conversion.
    """
    path = ap4_safe_path(filename)
    i = Image.open(path)
    i = ImageOps.exif_transpose(i)
    if i.mode == "I":
        i = i.point(lambda px: px * (1 / 255))
    rgb = i.convert("RGB")
    rgb_np = np.array(rgb).astype(np.float32) / 255.0
    image = torch.from_numpy(rgb_np)[None, ...]
    bands = i.getbands()
    if "A" in bands:
        a = np.array(i.getchannel("A")).astype(np.float32) / 255.0
        alpha = torch.from_numpy(a)
    else:
        l = np.array(i.convert("L")).astype(np.float32) / 255.0
        alpha = torch.from_numpy(l)
    mask = 1.0 - alpha
    mask = mask.clamp(0.0, 1.0).unsqueeze(0)
    return image, mask


def ap4_as_image(img: torch.Tensor) -> torch.Tensor:
    """Validate an IMAGE tensor: [B,H,W,C] with C in (3, 4); returns it unchanged."""
    if not isinstance(img, torch.Tensor):
        raise TypeError("IMAGE must be a torch.Tensor")
    if img.dim() != 4:
        raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
    if img.shape[-1] not in (3, 4):
        raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}")
    return img


def ap4_as_mask(mask: torch.Tensor) -> torch.Tensor:
    """Normalize a MASK tensor to [B,H,W] (promoting [H,W]); validates shape."""
    if not isinstance(mask, torch.Tensor):
        raise TypeError("MASK must be a torch.Tensor")
    if mask.dim() == 2:
        mask = mask.unsqueeze(0)
    if mask.dim() != 3:
        raise ValueError(f"Expected MASK shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
    return mask


def ap4_ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Return img with an alpha channel, appending an all-ones alpha to RGB input."""
    img = ap4_as_image(img)
    if img.shape[-1] == 4:
        return img
    B, H, W, _ = img.shape
    alpha = torch.ones((B, H, W, 1), device=img.device, dtype=img.dtype)
    return torch.cat([img, alpha], dim=-1)


def ap4_alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """Premultiplied-alpha "over" composite of `overlay` onto `canvas` at (x, y).

    Unlike _salia_alpha_over_region, out-of-bounds placement is allowed — the
    overlay is clipped to the overlapping region (no-op if there is no overlap).
    Batch-1 inputs are broadcast to match the other argument's batch size.
    """
    overlay = ap4_as_image(overlay)
    canvas = ap4_as_image(canvas)
    if overlay.shape[0] != canvas.shape[0]:
        if overlay.shape[0] == 1 and canvas.shape[0] > 1:
            overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
        elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
            canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")
    _, Hc, Wc, Cc = canvas.shape
    _, Ho, Wo, _ = overlay.shape
    x = int(x)
    y = int(y)
    out = canvas.clone()
    # Clip the paste rectangle to the canvas; derive the matching overlay window.
    x0c = max(0, x)
    y0c = max(0, y)
    x1c = min(Wc, x + Wo)
    y1c = min(Hc, y + Ho)
    if x1c <= x0c or y1c <= y0c:
        return out
    x0o = x0c - x
    y0o = y0c - y
    x1o = x0o + (x1c - x0c)
    y1o = y0o + (y1c - y0c)
    canvas_region = out[:, y0c:y1c, x0c:x1c, :]
    overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]
    canvas_rgba = ap4_ensure_rgba(canvas_region)
    overlay_rgba = ap4_ensure_rgba(overlay_region)
    over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
    over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)
    under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
    under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)
    # Premultiply, composite, then un-premultiply (guarding near-zero alpha).
    over_pm = over_rgb * over_a
    under_pm = under_rgb * under_a
    out_a = over_a + under_a * (1.0 - over_a)
    out_pm = over_pm + under_pm * (1.0 - over_a)
    eps = 1e-6
    out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
    out_rgb = out_rgb.clamp(0.0, 1.0)
    out_a = out_a.clamp(0.0, 1.0)
    if Cc == 3:
        out[:, y0c:y1c, x0c:x1c, :] = out_rgb
    else:
        out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)
    return out


# NOTE(review): the class below is truncated at this chunk boundary — the
# "intersection" branch (and anything after it) continues past the visible text.
class AP4_AILab_MaskCombiner_Exact:
    def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
        masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
        if len(masks) <= 1:
            return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)
        ref_shape = masks[0].shape
        masks = [self._resize_if_needed(m, ref_shape) for m in masks]
        if mode == "combine":
            result = torch.maximum(masks[0], masks[1])
            for mask in masks[2:]:
                result = torch.maximum(result, mask)
        elif mode == "intersection":
            result =
torch.minimum(masks[0], masks[1]) else: result = torch.abs(masks[0] - masks[1]) return (torch.clamp(result, 0, 1),) def _resize_if_needed(self, mask, target_shape): if mask.shape == target_shape: return mask if len(mask.shape) == 2: mask = mask.unsqueeze(0) elif len(mask.shape) == 4: mask = mask.squeeze(1) target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0] target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1] resized_masks = [] for i in range(mask.shape[0]): mask_np = mask[i].cpu().numpy() img = Image.fromarray((mask_np * 255).astype(np.uint8)) img_resized = img.resize((target_width, target_height), Image.LANCZOS) mask_resized = np.array(img_resized).astype(np.float32) / 255.0 resized_masks.append(torch.from_numpy(mask_resized)) return torch.stack(resized_masks) def ap4_resize_mask_comfy(alpha_mask: torch.Tensor, image_shape_hwc: Tuple[int, int, int]) -> torch.Tensor: H = int(image_shape_hwc[0]) W = int(image_shape_hwc[1]) return F.interpolate( alpha_mask.reshape((-1, 1, alpha_mask.shape[-2], alpha_mask.shape[-1])), size=(H, W), mode="bilinear", ).squeeze(1) def ap4_join_image_with_alpha_comfy(image: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor: image = ap4_as_image(image) alpha = ap4_as_mask(alpha) alpha = alpha.to(device=image.device, dtype=image.dtype) batch_size = min(len(image), len(alpha)) out_images = [] alpha_resized = 1.0 - ap4_resize_mask_comfy(alpha, image.shape[1:]) for i in range(batch_size): out_images.append(torch.cat((image[i][:, :, :3], alpha_resized[i].unsqueeze(2)), dim=2)) return torch.stack(out_images) def ap4_try_get_comfy_model_management(): try: import comfy.model_management as mm # type: ignore return mm except Exception: return None def ap4_gaussian_kernel_1d(kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor: center = (kernel_size - 1) / 2.0 xs = torch.arange(kernel_size, device=device, dtype=dtype) - center kernel = torch.exp(-(xs * xs) / 
(2.0 * sigma * sigma)) kernel = kernel / kernel.sum() return kernel def ap4_mask_blur(mask: torch.Tensor, amount: int = 8, device: str = "gpu") -> torch.Tensor: mask = ap4_as_mask(mask).clamp(0.0, 1.0) if amount == 0: return mask k = int(amount) if k % 2 == 0: k += 1 sigma = 0.3 * (((k - 1) * 0.5) - 1.0) + 0.8 mm = ap4_try_get_comfy_model_management() if device == "gpu": if mm is not None: proc_device = mm.get_torch_device() else: proc_device = torch.device("cuda") if torch.cuda.is_available() else mask.device elif device == "cpu": proc_device = torch.device("cpu") else: proc_device = mask.device out_device = mask.device if device in ("gpu", "cpu") and mm is not None: out_device = mm.intermediate_device() orig_dtype = mask.dtype x = mask.to(device=proc_device, dtype=torch.float32) _, H, W = x.shape pad = k // 2 pad_mode = "reflect" if (H > pad and W > pad and H > 1 and W > 1) else "replicate" x4 = x.unsqueeze(1) x4 = F.pad(x4, (pad, pad, pad, pad), mode=pad_mode) kern1d = ap4_gaussian_kernel_1d(k, sigma, device=proc_device, dtype=torch.float32) w_h = kern1d.view(1, 1, 1, k) w_v = kern1d.view(1, 1, k, 1) x4 = F.conv2d(x4, w_h) x4 = F.conv2d(x4, w_v) out = x4.squeeze(1).clamp(0.0, 1.0) return out.to(device=out_device, dtype=orig_dtype) def ap4_dilate_mask(mask: torch.Tensor, dilation: int = 3) -> torch.Tensor: mask = ap4_as_mask(mask).clamp(0.0, 1.0) dilation = int(dilation) if dilation == 0: return mask k = abs(dilation) x = mask.unsqueeze(1) if dilation > 0: y = F.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2) else: y = -F.max_pool2d(-x, kernel_size=k, stride=1, padding=k // 2) return y.squeeze(1).clamp(0.0, 1.0) def ap4_fill_holes_grayscale_numpy_heap(f: np.ndarray, connectivity: int = 8) -> np.ndarray: f = np.clip(f, 0.0, 1.0).astype(np.float32, copy=False) H, W = f.shape if H == 0 or W == 0: return f cost = np.full((H, W), np.inf, dtype=np.float32) finalized = np.zeros((H, W), dtype=np.bool_) heap: List[Tuple[float, int, int]] = [] def push(y: int, x: 
int): c = float(f[y, x]) if c < float(cost[y, x]): cost[y, x] = c heapq.heappush(heap, (c, y, x)) for x in range(W): push(0, x) if H > 1: push(H - 1, x) for y in range(H): push(y, 0) if W > 1: push(y, W - 1) if connectivity == 4: neigh = [(-1, 0), (1, 0), (0, -1), (0, 1)] else: neigh = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] eps = 1e-8 while heap: c, y, x = heapq.heappop(heap) if finalized[y, x]: continue if c > float(cost[y, x]) + eps: continue finalized[y, x] = True for dy, dx in neigh: ny = y + dy nx = x + dx if ny < 0 or ny >= H or nx < 0 or nx >= W: continue if finalized[ny, nx]: continue v = float(f[ny, nx]) nc = c if c >= v else v if nc < float(cost[ny, nx]) - eps: cost[ny, nx] = nc heapq.heappush(heap, (nc, ny, nx)) return cost def ap4_fill_holes_mask(mask: torch.Tensor) -> torch.Tensor: mask = ap4_as_mask(mask).clamp(0.0, 1.0) B, H, W = mask.shape device = mask.device dtype = mask.dtype mask_np = np.ascontiguousarray(mask.detach().cpu().numpy().astype(np.float32, copy=False)) filled_np = np.empty_like(mask_np) try: from skimage.morphology import reconstruction # type: ignore footprint = np.ones((3, 3), dtype=bool) for b in range(B): f = mask_np[b] seed = f.copy() if H > 2 and W > 2: seed[1:-1, 1:-1] = 1.0 else: seed[:, :] = 1.0 seed[0, :] = f[0, :] seed[-1, :] = f[-1, :] seed[:, 0] = f[:, 0] seed[:, -1] = f[:, -1] filled_np[b] = reconstruction(seed, f, method="erosion", footprint=footprint).astype(np.float32) except Exception: for b in range(B): filled_np[b] = ap4_fill_holes_grayscale_numpy_heap(mask_np[b], connectivity=8) out = torch.from_numpy(filled_np).to(device=device, dtype=dtype) return out.clamp(0.0, 1.0) class apply_segment_4: CATEGORY = "image/salia" @classmethod def INPUT_TYPES(cls): choices = ap4_list_pngs() or [""] return { "required": { "mask": ("MASK",), "image": (choices, {}), "img": ("IMAGE",), "canvas": ("IMAGE",), "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), "y": ("INT", 
{"default": 0, "min": -100000, "max": 100000, "step": 1}), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("Final_Image",) FUNCTION = "run" def run(self, mask, image, img, canvas, x, y): if image == "": raise FileNotFoundError("No PNGs found in assets/images next to this node") mask_in = ap4_as_mask(mask).clamp(0.0, 1.0) blurred = ap4_mask_blur(mask_in, amount=8, device="gpu") dilated = ap4_dilate_mask(blurred, dilation=3) filled = ap4_fill_holes_mask(dilated) inversed_mask = 1.0 - filled _asset_img, loaded_mask = ap4_load_image_from_assets(image) combiner = AP4_AILab_MaskCombiner_Exact() inv_cpu = inversed_mask.detach().cpu() loaded_cpu = ap4_as_mask(loaded_mask).detach().cpu() (alpha_mask,) = combiner.combine_masks(inv_cpu, mode="combine", mask_2=(1.0 - loaded_cpu)) alpha_mask = torch.clamp(alpha_mask, 0.0, 1.0) alpha_image = ap4_join_image_with_alpha_comfy(img, alpha_mask) canvas = ap4_as_image(canvas) alpha_image = alpha_image.to(device=canvas.device, dtype=canvas.dtype) final = ap4_alpha_over_region(alpha_image, canvas, x, y) return (final,) @classmethod def IS_CHANGED(cls, mask, image, img, canvas, x, y): if image == "": return image return ap4_file_hash(image) @classmethod def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y): if image == "": return "No PNGs found in assets/images next to this node" try: path = ap4_safe_path(image) except Exception as e: return str(e) if not os.path.isfile(path): return f"File not found in assets/images: {image}" return True # ====================================================================================== # Fused node: Salia_ezpz_gated_Duo2 -> SAM3Segment (hardcoded) -> apply_segment_4 # ====================================================================================== class SAM3Segment_Salia: CATEGORY = "image/salia" RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("Final_Image",) FUNCTION = "run" @classmethod def INPUT_TYPES(cls): # Use the exact dropdown sources of the embedded nodes salia_assets = _list_asset_pngs() 
or [""] ap4_assets = ap4_list_pngs() or [""] upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"] return { "required": { "image": ("IMAGE",), "trigger_string": ("STRING", {"default": ""}), "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}), "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}), "positive_prompt": ("STRING", {"default": "", "multiline": True}), "negative_prompt": ("STRING", {"default": "", "multiline": True}), "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "SAM3 prompt"}), "asset_image": (salia_assets, {}), "apply_asset_image": (ap4_assets, {}), "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}), "upscale_factor_1": (upscale_choices, {"default": "4"}), "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}), "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}), "upscale_factor_2": (upscale_choices, {"default": "4"}), "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}), } } def __init__(self): self._sam3 = SAM3Segment() self._salia = Salia_ezpz_gated_Duo2() self._ap4 = apply_segment_4() def run( self, image, trigger_string="", X_coord=0, Y_coord=0, positive_prompt="", negative_prompt="", prompt="", asset_image="", apply_asset_image="", square_size_1=384, upscale_factor_1="4", denoise_1=0.35, square_size_2=384, upscale_factor_2="4", denoise_2=0.35, ): # EXACT bypass: if trigger_string is empty, return input image as Final_Image if trigger_string == "": return (image,) # 1) Pre-node: Salia_ezpz_gated_Duo2 -> image_cropped _out_image, image_cropped = self._salia.run( image=image, trigger_string=trigger_string, X_coord=int(X_coord), Y_coord=int(Y_coord), positive_prompt=str(positive_prompt), negative_prompt=str(negative_prompt), asset_image=str(asset_image), square_size_1=int(square_size_1), upscale_factor_1=str(upscale_factor_1), denoise_1=float(denoise_1), 
square_size_2=int(square_size_2), upscale_factor_2=str(upscale_factor_2), denoise_2=float(denoise_2), ) # 2) Center: SAM3Segment with hardcoded settings on the CROPPED image seg_image, seg_mask, _mask_image = self._sam3.segment( image=image_cropped, prompt=str(prompt), sam3_model="sam3", device="GPU", confidence_threshold=0.50, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222", ) # 3) Post-node: apply_segment_4 onto ORIGINAL input canvas (not Duo2 output) (final_image,) = self._ap4.run( mask=seg_mask, image=str(apply_asset_image), img=seg_image, canvas=image, x=int(X_coord), y=int(Y_coord), ) return (final_image,) # ====================================================================================== # Node mappings (all nodes in this file) # ====================================================================================== NODE_CLASS_MAPPINGS = { "SAM3Segment": SAM3Segment, "Salia_ezpz_gated_Duo2": Salia_ezpz_gated_Duo2, "apply_segment_4": apply_segment_4, "SAM3Segment_Salia": SAM3Segment_Salia, } NODE_DISPLAY_NAME_MAPPINGS = { "SAM3Segment": "SAM3 Segmentation (RMBG)", "Salia_ezpz_gated_Duo2": "Salia_ezpz_gated_Duo2", "apply_segment_4": "apply_segment_4", "SAM3Segment_Salia": "SAM3Segment_Salia (Duo2 → SAM3 → apply_segment_4)", }