# MyCustomNodes / AILab_SAM3Segment.py
# saliacoel's picture
# Update AILab_SAM3Segment.py
# f559035 verified
# AILab_SAM3Segment.py
# Integrated standalone nodes:
# - SAM3Segment
# - Salia_ezpz_gated_Duo2
# - apply_segment_4
# - SAM3Segment_Salia (fused)
import os
import sys
import hashlib
import shutil
import threading
import urllib.request
import heapq
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Tuple, Optional, List
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageFilter, ImageOps
from torch.hub import download_url_to_file
import folder_paths
import comfy.model_management
import comfy.model_management as model_management
from AILab_ImageMaskTools import pil2tensor, tensor2pil
# ======================================================================================
# SAM3Segment (original, with syntax fix)
# ======================================================================================
# Directory that contains this file; the local "sam3" folder is expected next to it.
CURRENT_DIR = os.path.dirname(__file__)
SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
# Put the local sam3 folder at the front of the import path (only once).
if SAM3_LOCAL_DIR not in sys.path:
    sys.path.insert(0, SAM3_LOCAL_DIR)
# BPE vocabulary file handed to build_sam3_image_model; importing this module
# fails immediately when it is missing.
SAM3_BPE_PATH = os.path.join(SAM3_LOCAL_DIR, "assets", "bpe_simple_vocab_16e6.txt.gz")
if not os.path.isfile(SAM3_BPE_PATH):
    raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")
# These imports must come after the sys.path adjustment above (hence noqa: E402).
from sam3.model_builder import build_sam3_image_model # noqa: E402
from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402
# Download URL and on-disk filename of the default SAM3 checkpoint.
_DEFAULT_PT_ENTRY = {
    "model_url": "https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
    "filename": "sam3.pt",
}
# Registry of selectable models; its keys are offered as the "sam3_model"
# choices of SAM3Segment and each value is looked up when loading a processor.
SAM3_MODELS = {
    "sam3": _DEFAULT_PT_ENTRY.copy(),
}
def get_sam3_pt_models():
    """Return a one-entry ``{"sam3": info}`` mapping describing a ``.pt`` checkpoint.

    Lookup order:
      1. the ``"sam3"`` entry of ``SAM3_MODELS`` when its filename ends in ``.pt``;
      2. the first registry entry whose filename ends in ``.pt``;
      3. the first non-empty entry whose key contains ``"sam3"``, copied and
         re-pointed at the default URL/filename;
      4. a copy of ``_DEFAULT_PT_ENTRY``.
    """

    def _is_pt(info):
        return info.get("filename", "").endswith(".pt")

    preferred = SAM3_MODELS.get("sam3")
    if preferred and _is_pt(preferred):
        return {"sam3": preferred}

    for name, info in SAM3_MODELS.items():
        if _is_pt(info):
            return {"sam3": info}
        if "sam3" in name and info:
            # Keep any extra fields of the entry but force the default checkpoint.
            patched = info.copy()
            patched["model_url"] = _DEFAULT_PT_ENTRY["model_url"]
            patched["filename"] = _DEFAULT_PT_ENTRY["filename"]
            return {"sam3": patched}

    return {"sam3": _DEFAULT_PT_ENTRY.copy()}
def process_mask(mask_image, invert_output=False, mask_blur=0, mask_offset=0):
    """Post-process a PIL mask: optional inversion, Gaussian blur, then grow/shrink.

    Args:
        mask_image: PIL image holding the mask.
        invert_output: when true, replace every value ``v`` with ``255 - v``.
        mask_blur: Gaussian blur radius; applied only when > 0.
        mask_offset: positive values apply a max filter (grow), negative values
            a min filter (shrink); 0 leaves the mask untouched.

    Returns:
        The processed PIL image.
    """
    result = mask_image
    if invert_output:
        result = Image.fromarray(255 - np.array(result))
    if mask_blur > 0:
        result = result.filter(ImageFilter.GaussianBlur(radius=mask_blur))
    if mask_offset == 0:
        return result

    passes = abs(mask_offset)
    kernel_cls = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
    kernel_size = passes * 2 + 1
    # NOTE(review): a kernel of size 2*|offset|+1 is applied |offset| times, so
    # the effective reach grows faster than |offset| pixels - confirm intended.
    for _ in range(passes):
        result = result.filter(kernel_cls(kernel_size))
    return result
def apply_background_color(image, mask_image, background="Alpha", background_color="#222222"):
    """Cut ``image`` out with ``mask_image`` and optionally flatten it onto a colour.

    Args:
        image: source PIL image.
        mask_image: PIL mask; its "L" conversion becomes the alpha channel.
        background: ``"Color"`` composites over ``background_color`` and returns
            an RGB image; any other value returns the RGBA cut-out.
        background_color: ``#RRGGBB`` hex string (leading ``#`` optional).
    """
    cutout = image.copy().convert("RGBA")
    cutout.putalpha(mask_image.convert("L"))
    if background != "Color":
        return cutout

    digits = background_color.lstrip("#")
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    backdrop = Image.new("RGBA", image.size, (red, green, blue, 255))
    return Image.alpha_composite(backdrop, cutout).convert("RGB")
def get_or_download_model_file(filename, url):
    """Return a local path for ``filename``, downloading it from ``url`` if absent.

    A path reported by ``folder_paths.get_full_path("sam3", filename)`` wins when
    it points at an existing file. Otherwise the file is expected at (and, if
    missing, downloaded to) ``<models_dir>/sam3/<filename>``.
    """
    if hasattr(folder_paths, "get_full_path"):
        registered = folder_paths.get_full_path("sam3", filename)
        if registered and os.path.isfile(registered):
            return registered

    # Fall back to a "models" folder beside this file when folder_paths has no models_dir.
    models_root = getattr(folder_paths, "models_dir", os.path.join(CURRENT_DIR, "models"))
    target_dir = os.path.join(models_root, "sam3")
    os.makedirs(target_dir, exist_ok=True)

    target = os.path.join(target_dir, filename)
    if os.path.exists(target):
        return target
    print(f"Downloading {filename} from {url} ...")
    download_url_to_file(url, target)
    return target
def _resolve_device(user_choice):
    """Translate the node's device choice into a ``torch.device``.

    ``"CPU"`` forces the CPU, ``"GPU"`` requires that ComfyUI's default device is
    CUDA (otherwise ``RuntimeError``), and anything else returns ComfyUI's
    default device unchanged.
    """
    detected = comfy.model_management.get_torch_device()
    if user_choice == "CPU":
        return torch.device("cpu")
    if user_choice != "GPU":
        return detected
    if detected.type != "cuda":
        raise RuntimeError("GPU unavailable")
    return torch.device("cuda")
class SAM3Segment:
    """ComfyUI node that segments an image batch with SAM3 from a text prompt.

    For every image the masks returned by the processor are merged into one
    mask, post-processed (invert / blur / offset) and used to cut the subject
    out of the image, either on a transparent or a solid-colour background.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs in ComfyUI's INPUT_TYPES format."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Describe the concept"}),
                "sam3_model": (list(SAM3_MODELS.keys()), {"default": "sam3"}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto"}),
                "confidence_threshold": ("FLOAT", {"default": 0.5, "min": 0.05, "max": 0.95, "step": 0.01}),
            },
            "optional": {
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
                "invert_output": ("BOOLEAN", {"default": False}),
                "unload_model": ("BOOLEAN", {"default": False}),
                "background": (["Alpha", "Color"], {"default": "Alpha"}),
                "background_color": ("COLORCODE", {"default": "#222222"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "segment"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # (model name, "cuda" | "cpu") -> Sam3Processor
        self.processor_cache = {}

    @staticmethod
    def _device_str(torch_device):
        """Collapse a torch.device to the "cuda"/"cpu" string used for cache keys."""
        return "cuda" if torch_device.type == "cuda" else "cpu"

    def _load_processor(self, model_choice, device_choice):
        """Return ``(processor, torch_device)``, building and caching the processor on first use."""
        torch_device = _resolve_device(device_choice)
        device_str = self._device_str(torch_device)
        cache_key = (model_choice, device_str)
        if cache_key not in self.processor_cache:
            model_info = SAM3_MODELS[model_choice]
            ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
            model = build_sam3_image_model(
                bpe_path=SAM3_BPE_PATH,
                device=device_str,
                eval_mode=True,
                checkpoint_path=ckpt_path,
                load_from_HF=False,
                enable_segmentation=True,
                enable_inst_interactivity=False,
            )
            self.processor_cache[cache_key] = Sam3Processor(model, device=device_str)
        return self.processor_cache[cache_key], torch_device

    def _package_result(self, img_pil, mask_image, background, background_color):
        """Build the ``(PIL image, mask [1,H,W], mask_rgb [1,H,W,3])`` triple for one frame."""
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        result_image = result_image.convert("RGBA" if background == "Alpha" else "RGB")
        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
        mask_rgb = (
            mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width))
            .movedim(1, -1)
            .expand(-1, -1, -1, 3)
        )
        return result_image, mask_tensor, mask_rgb

    def _empty_result(self, img_pil, background, background_color, invert=False):
        """Result for a frame in which nothing was detected.

        The mask is all zeros, or all ones when ``invert`` is set, so that the
        "invert_output" option behaves the same whether or not something was
        found. Blur and offset are skipped because they cannot change a
        constant mask.
        """
        fill = 255 if invert else 0
        mask_image = Image.new("L", img_pil.size, fill)
        return self._package_result(img_pil, mask_image, background, background_color)

    def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask_offset, invert, background, background_color):
        """Segment one frame and return ``(PIL image, mask tensor, mask_rgb tensor)``."""
        img_pil = tensor2pil(img_tensor)
        text = prompt.strip() or "object"  # an empty prompt falls back to "object"
        state = processor.set_image(img_pil)
        processor.reset_all_prompts(state)
        processor.set_confidence_threshold(confidence, state)
        state = processor.set_text_prompt(text, state)

        masks = state.get("masks")
        if masks is None or masks.numel() == 0:
            return self._empty_result(img_pil, background, background_color, invert)

        masks = masks.float().to("cpu")
        if masks.ndim == 4:
            masks = masks.squeeze(1)
        # Union of all returned masks: per-pixel maximum over the first axis.
        combined = masks.amax(dim=0)
        mask_np = (combined.clamp(0, 1).numpy() * 255).astype(np.uint8)
        mask_image = Image.fromarray(mask_np, mode="L")
        mask_image = process_mask(mask_image, invert, mask_blur, mask_offset)
        return self._package_result(img_pil, mask_image, background, background_color)

    def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
        """Node entry point: segment every frame of ``image``.

        Returns:
            ``(images, masks, mask_images)`` concatenated along the batch axis.
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)
        processor, torch_device = self._load_processor(sam3_model, device)
        autocast_device = comfy.model_management.get_autocast_device(torch_device)
        autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
        ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()

        result_images, result_masks, result_mask_images = [], [], []
        with ctx:
            for tensor_img in image:
                img_pil, mask_tensor, mask_rgb = self._run_single(
                    processor,
                    tensor_img,
                    prompt,
                    confidence_threshold,
                    mask_blur,
                    mask_offset,
                    invert_output,
                    background,
                    background_color,
                )
                result_images.append(pil2tensor(img_pil))
                result_masks.append(mask_tensor)
                result_mask_images.append(mask_rgb)

        if unload_model:
            self.processor_cache.pop((sam3_model, self._device_str(torch_device)), None)
            # Drop the local reference as well; otherwise the model is still
            # alive when the CUDA cache is emptied and nothing is released.
            processor = None
            if torch_device.type == "cuda":
                import gc

                gc.collect()
                torch.cuda.empty_cache()

        return torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0)
# ======================================================================================
# Salia_ezpz_gated_Duo2 (standalone)
# ======================================================================================
# transformers is required for depth-estimation pipeline
# Set to the import failure (if any) so _try_load_pipeline can report it; stays
# None when transformers imported fine, so the name is always defined.
_TRANSFORMERS_IMPORT_ERROR: Optional[BaseException] = None
try:
    from transformers import pipeline
except Exception as e:  # broad on purpose: keep this module importable whatever the import raises
    pipeline = None
    _TRANSFORMERS_IMPORT_ERROR = e

# Process-wide caches, keyed by checkpoint / controlnet name, with their locks.
_CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {}
_CN_CACHE: Dict[str, Any] = {}
_CKPT_LOCK = threading.Lock()
_CN_LOCK = threading.Lock()
def _find_plugin_root() -> Path:
    """
    Locate the plugin root: the nearest directory, starting at this file's
    folder and walking upwards, that contains an 'assets' sub-folder.
    Falls back to this file's folder when none is found.
    """
    this_file = Path(__file__).resolve()
    candidates = [this_file.parent]
    candidates.extend(list(this_file.parents)[:12])  # bounded upward search
    for folder in candidates:
        if (folder / "assets").is_dir():
            return folder
    return this_file.parent


PLUGIN_ROOT = _find_plugin_root()
def _pil_lanczos():
    """Return PIL's LANCZOS constant, using ``Image.Resampling`` when it exists."""
    resampling = getattr(Image, "Resampling", None)
    if resampling is None:
        return Image.LANCZOS
    return resampling.LANCZOS
def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert one image tensor to PIL (first batch item when 4-D).

    Values are clamped to [0, 1] and scaled to 8 bits; a last axis of size 4
    yields an RGBA image, anything else is treated as RGB.
    """
    frame = img[0] if img.ndim == 4 else img
    unit = frame.detach().cpu().float().clamp(0, 1)
    arr = (unit.numpy() * 255.0).round().astype(np.uint8)
    mode = "RGBA" if arr.shape[-1] == 4 else "RGB"
    return Image.fromarray(arr, mode=mode)
def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float tensor in [0, 1] with a leading batch axis of 1.

    Images that are neither RGB nor RGBA are converted first: to RGBA when
    they carry an "A" band, otherwise to RGB.
    """
    if pil.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "A" in pil.getbands() else "RGB"
        pil = pil.convert(target_mode)
    scaled = np.array(pil).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """Convert a mask tensor (first item when 3-D) to an 8-bit "L" PIL image."""
    plane = mask[0] if mask.ndim == 3 else mask
    unit = plane.detach().cpu().float().clamp(0, 1)
    arr = (unit.numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr, mode="L")
def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """Convert a PIL image (forced to mode "L") to a float mask tensor with a leading axis of 1."""
    gray = pil_l if pil_l.mode == "L" else pil_l.convert("L")
    scaled = np.array(gray).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize every frame of a 4-D image batch to ``w`` x ``h`` with PIL's LANCZOS filter.

    Raises:
        ValueError: if ``img`` is not 4-dimensional.
    """
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    resized = []
    for index in range(img.shape[0]):
        frame = _image_tensor_to_pil(img[index].unsqueeze(0))
        frame = frame.resize((int(w), int(h)), resample=_pil_lanczos())
        resized.append(_pil_to_image_tensor(frame))
    return torch.cat(resized, dim=0)
def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """Resize every mask of a 3-D mask batch to ``w`` x ``h`` with PIL's LANCZOS filter.

    Raises:
        ValueError: if ``mask`` is not 3-dimensional.
    """
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    resized = []
    for index in range(mask.shape[0]):
        plane = _mask_tensor_to_pil(mask[index].unsqueeze(0))
        plane = plane.resize((int(w), int(h)), resample=_pil_lanczos())
        resized.append(_pil_to_mask_tensor(plane))
    return torch.cat(resized, dim=0)
def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Append an alpha channel computed as ``1 - mask`` to an RGB batch.

    A 3-D ``rgb`` and a 2-D ``mask`` get a batch axis added; a single mask is
    broadcast over a larger RGB batch. Both inputs are clamped to [0, 1].

    Raises:
        ValueError: on wrong rank/channel count, batch mismatch or size mismatch.
    """
    rgb = rgb.unsqueeze(0) if rgb.ndim == 3 else rgb
    mask = mask.unsqueeze(0) if mask.ndim == 2 else mask

    if rgb.ndim != 4 or rgb.shape[-1] != 3:
        raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}")
    if mask.ndim != 3:
        raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}")

    batch = rgb.shape[0]
    if mask.shape[0] != batch:
        can_broadcast = mask.shape[0] == 1 and batch > 1
        if not can_broadcast:
            raise ValueError("Batch mismatch between rgb and mask.")
        mask = mask.expand(batch, -1, -1)

    same_size = mask.shape[1] == rgb.shape[1] and mask.shape[2] == rgb.shape[2]
    if not same_size:
        raise ValueError(
            f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}"
        )

    mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1)
    # Mask value 1 becomes alpha 0 (fully transparent) and vice versa.
    alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1)
    return torch.cat([rgb.clamp(0, 1), alpha], dim=-1)
def _load_checkpoint_cached(ckpt_name: str):
    """Return ``(model, clip, vae)`` for ``ckpt_name``, loading it once per process.

    Loading goes through ComfyUI's ``CheckpointLoaderSimple`` node; the lock
    only guards the cache lookup, not the load itself.
    """
    with _CKPT_LOCK:
        hit = _CKPT_CACHE.get(ckpt_name)
        if hit is not None:
            return hit

    import nodes

    loader = nodes.CheckpointLoaderSimple()
    load = getattr(loader, loader.FUNCTION)
    model, clip, vae = load(ckpt_name=ckpt_name)
    _CKPT_CACHE[ckpt_name] = (model, clip, vae)
    return model, clip, vae
def _load_controlnet_cached(control_net_name: str):
    """Return the controlnet named ``control_net_name``, loading it on first use.

    Loading goes through ComfyUI's ``ControlNetLoader`` node; the lock only
    guards the cache lookup, not the load itself.
    """
    with _CN_LOCK:
        if control_net_name in _CN_CACHE:
            return _CN_CACHE[control_net_name]

    import nodes

    loader = nodes.ControlNetLoader()
    load = getattr(loader, loader.FUNCTION)
    (controlnet,) = load(control_net_name=control_net_name)
    _CN_CACHE[control_net_name] = controlnet
    return controlnet
def _assets_images_dir() -> Path:
    """Return ``<PLUGIN_ROOT>/assets/images`` (the folder is not required to exist)."""
    return PLUGIN_ROOT.joinpath("assets", "images")
def _list_asset_pngs() -> list:
    """List every PNG below the assets image folder as sorted POSIX-style relative paths.

    Returns an empty list when the folder does not exist.
    """
    root = _assets_images_dir()
    if not root.is_dir():
        return []
    return sorted(
        entry.relative_to(root).as_posix()
        for entry in root.rglob("*")
        if entry.is_file() and entry.suffix.lower() == ".png"
    )
def _safe_asset_path(asset_rel_path: str) -> Path:
    """Resolve ``asset_rel_path`` inside the assets image folder, refusing escapes.

    Raises:
        FileNotFoundError: the assets folder or the requested file is missing.
        ValueError: the path is absolute, resolves outside the assets folder,
            or does not have a ``.png`` suffix.
    """
    images_dir = _assets_images_dir()
    if not images_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {images_dir}")

    root = images_dir.resolve()
    requested = Path(asset_rel_path)
    if requested.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    resolved = (root / requested).resolve()
    # After resolving symlinks and "..", the result must be the root or below it.
    stays_inside = resolved == root or root in resolved.parents
    if not stays_inside:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")
    if not resolved.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if resolved.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")
    return resolved
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset PNG as an RGB image tensor plus a mask tensor.

    Args:
        asset_rel_path: path relative to the assets image folder; validated by
            ``_safe_asset_path``.

    Returns:
        ``(image, mask)``: ``image`` is ``[1,H,W,3]`` float32 in [0, 1] and
        ``mask`` is ``[1,H,W]`` float32 holding ``1 - alpha``, so fully
        transparent pixels have mask value 1.
    """
    asset_path = _safe_asset_path(asset_rel_path)
    # Use a context manager so the file handle is closed as soon as the pixel
    # data has been converted (it was previously left open).
    with Image.open(asset_path) as opened:
        rgba = ImageOps.exif_transpose(opened).convert("RGBA")

    rgb_arr = np.array(rgba.convert("RGB")).astype(np.float32) / 255.0
    img_t = torch.from_numpy(rgb_arr)[None, ...]

    alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0
    mask_t = torch.from_numpy(1.0 - alpha)[None, ...]
    return img_t, mask_t
# Local folder for the depth-estimation model files; created at import time.
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# File name -> download URL for every file that must be present in MODEL_DIR.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}
# Model id tried by get_depth_pipeline when the local files cannot be used.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"
# (model source, device string) -> loaded pipeline, with the lock guarding lookups.
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}
_PIPE_LOCK = threading.Lock()
def _have_required_files() -> bool:
    """Return True when every file named in ``REQUIRED_FILES`` exists in ``MODEL_DIR``."""
    for name in REQUIRED_FILES.keys():
        if not (MODEL_DIR / name).exists():
            return False
    return True
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """Download ``url`` to ``dst`` via a temporary sibling file.

    The data is written to ``<dst>.tmp`` and moved over ``dst`` only after the
    transfer completed, so ``dst`` never holds a partial download. If anything
    fails, the temporary file is removed again before the error propagates.

    Args:
        url: source URL.
        dst: destination path; parent folders are created as needed.
        timeout: timeout in seconds passed to ``urllib.request.urlopen``.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    def _remove_tmp() -> None:
        # Best-effort removal; a missing or undeletable temp file is not fatal.
        try:
            tmp.unlink()
        except OSError:
            pass

    _remove_tmp()  # clear leftovers from an earlier interrupted download
    req = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    try:
        with urllib.request.urlopen(req, timeout=timeout) as r, open(tmp, "wb") as f:
            shutil.copyfileobj(r, f)
        tmp.replace(dst)
    except BaseException:
        _remove_tmp()
        raise
def ensure_local_model_files() -> bool:
    """Make sure all required depth-model files exist locally, downloading missing ones.

    Returns:
        True when every required file is present afterwards; False when a
        download failed (errors are swallowed on purpose so callers can fall
        back to another model source).
    """
    if _have_required_files():
        return True
    try:
        for name, url in REQUIRED_FILES.items():
            target = MODEL_DIR / name
            if not target.exists():
                _download_url_to_file(url, target)
        return _have_required_files()
    except Exception:
        return False
def HWC3(x: np.ndarray) -> np.ndarray:
    """Return a 3-channel uint8 HxWx3 version of ``x``.

    2-D and single-channel inputs are replicated over three channels,
    3-channel inputs are returned as-is, and 4-channel inputs are blended
    over a white background using their last channel as alpha.

    Raises:
        AssertionError: if ``x`` is not uint8 or has an unsupported shape.
    """
    assert x.dtype == np.uint8
    img = x[:, :, None] if x.ndim == 2 else x
    assert img.ndim == 3
    channels = img.shape[2]
    assert channels == 1 or channels == 3 or channels == 4

    if channels == 3:
        return img
    if channels == 1:
        return np.concatenate([img, img, img], axis=2)

    rgb = img[:, :, 0:3].astype(np.float32)
    opacity = img[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * opacity + 255.0 * (1.0 - opacity)
    return blended.clip(0, 255).astype(np.uint8)
def pad64(x: int) -> int:
    """Return how many units must be added to ``x`` to reach the next multiple of 64."""
    next_multiple = np.ceil(float(x) / 64.0) * 64
    return int(next_multiple - x)
def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return an independent, C-contiguous copy of ``x``."""
    duplicate = x.copy()
    contiguous = np.ascontiguousarray(duplicate)
    return contiguous.copy()
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """Scale an image so its shorter side equals ``resolution``, then pad to multiples of 64.

    OpenCV is used when it can be imported, PIL otherwise. Padding is added
    at the bottom and right edges only.

    Args:
        input_image: uint8 image array.
        resolution: target length of the shorter side; ``<= 0`` disables
            resizing and padding altogether.
        upscale_method: name of the OpenCV interpolation used when enlarging
            (unknown names fall back to INTER_CUBIC); shrinking always uses
            INTER_AREA. Ignored on the PIL path.
        skip_hwc3: skip the ``HWC3`` normalisation of the input.
        mode: ``numpy.pad`` mode used for the padding.

    Returns:
        ``(padded_image, remove_pad)`` where ``remove_pad`` crops an array of
        the padded size back to the resized (unpadded) size.
    """
    try:
        import cv2
    except Exception:
        cv2 = None

    img = input_image if skip_hwc3 else HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        return img, (lambda x: x)

    k = float(resolution) / float(min(H_raw, W_raw))
    enlarging = k > 1
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is None:
        # PIL fallback: bicubic when enlarging, Lanczos when shrinking.
        pil_filter = Image.BICUBIC if enlarging else Image.LANCZOS
        scaled = Image.fromarray(img).resize((W_target, H_target), resample=pil_filter)
        img = np.array(scaled, dtype=np.uint8)
    else:
        by_name = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        chosen = by_name.get(upscale_method, cv2.INTER_CUBIC)
        interpolation = chosen if enlarging else cv2.INTER_AREA
        img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)

    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """Pad an image at the bottom/right so both sides become multiples of 64 (no resizing).

    Returns:
        ``(padded_image, remove_pad)`` where ``remove_pad`` crops an array
        back to the original height and width.
    """
    img = HWC3(img_u8)
    height, width = img.shape[0], img.shape[1]
    extra_rows = pad64(height)
    extra_cols = pad64(width)
    padded = np.pad(img, [[0, extra_rows], [0, extra_cols], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Flatten an RGBA array onto white and hand back its alpha channel separately.

    Returns:
        ``(rgb, alpha)``. For 4-channel input ``rgb`` is the colour blended
        over white and ``alpha`` is a copy of the alpha plane; for any other
        input ``rgb`` is ``HWC3(inp_u8)`` and ``alpha`` is ``None``.
    """
    has_alpha = inp_u8.ndim == 3 and inp_u8.shape[2] == 4
    if not has_alpha:
        return HWC3(inp_u8), None

    rgba = inp_u8.astype(np.uint8)
    color = rgba[:, :, 0:3].astype(np.float32)
    opacity = rgba[:, :, 3:4].astype(np.float32) / 255.0
    on_white = color * opacity + 255.0 * (1.0 - opacity)
    return on_white.clip(0, 255).astype(np.uint8), rgba[:, :, 3].copy()
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Multiply an image by an alpha plane, i.e. composite it over black.

    ``alpha_u8`` is a 2-D uint8 array of the same height and width as the image.
    """
    image = HWC3(depth_rgb_u8).astype(np.float32)
    weight = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (image * weight).clip(0, 255).astype(np.uint8)
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert an image tensor (first item when 4-D) to a uint8 array.

    Values are clamped to [0, 1], scaled by 255 and rounded.
    """
    frame = img[0] if img.ndim == 4 else img
    unit = frame.detach().cpu().float().clamp(0, 1).numpy()
    return (unit * 255.0).round().astype(np.uint8)
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image array to a ``[1,H,W,3]`` float32 tensor in [0, 1]."""
    rgb = HWC3(img_u8)
    scaled = rgb.astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
def _try_load_pipeline(model_source: str, device: torch.device):
    """Create (or fetch from cache) a depth-estimation pipeline for ``model_source``.

    Args:
        model_source: local folder or model id passed to ``pipeline``.
        device: device the model is moved to; a failed move is ignored and
            the pipeline is still cached and returned.

    Raises:
        RuntimeError: if transformers could not be imported.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    cache_key = (model_source, str(device))
    with _PIPE_LOCK:
        if cache_key in _PIPE_CACHE:
            return _PIPE_CACHE[cache_key]

    depth_pipe = pipeline(task="depth-estimation", model=model_source)
    try:
        depth_pipe.model = depth_pipe.model.to(device)
        depth_pipe.device = device
    except Exception:
        # Best effort only: keep whatever device the pipeline was created on.
        pass
    _PIPE_CACHE[cache_key] = depth_pipe
    return depth_pipe
def get_depth_pipeline(device: torch.device):
    """Return a depth pipeline, preferring the local model files over the fallback repo.

    Returns ``None`` when neither source can be loaded.
    """
    sources = []
    if ensure_local_model_files():
        sources.append(str(MODEL_DIR))
    sources.append(ZOE_FALLBACK_REPO_ID)

    for source in sources:
        try:
            return _try_load_pipeline(source, device)
        except Exception:
            continue
    return None
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """Run the depth pipeline on one image and return a normalised 3-channel depth map.

    Args:
        pipe: callable depth pipeline; called with a PIL image, its result
            must provide a ``"depth"`` entry convertible to a numpy array.
        input_rgb_u8: uint8 input image.
        detect_resolution: ``-1`` pads the input to multiples of 64 without
            resizing; any other value is passed to
            ``resize_image_with_pad_min_side`` as the target short side.
        upscale_method: interpolation name forwarded to the resize helper.

    Returns:
        uint8 HxWx3 depth image with the padding removed. Depth is normalised
        between its 2nd and 85th percentile and then inverted.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)
    with torch.no_grad():
        result = pipe(pil_image)
    # The former isinstance() branch performed the same conversion in both arms.
    depth_array = np.array(result["depth"], dtype=np.float32)

    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))
    denom = vmax - vmin
    if abs(denom) < 1e-12:
        denom = 1e-6  # avoid dividing by zero on a flat depth map
    depth_array = (depth_array - vmin) / denom
    depth_array = 1.0 - depth_array

    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)
    return remove_pad(HWC3(depth_image))
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """Resize an array to ``w0`` x ``h0`` with bilinear interpolation.

    OpenCV is tried first; if importing or using it fails, PIL is used instead.
    """
    target_size = (w0, h0)
    try:
        import cv2

        resized = cv2.resize(depth_rgb_u8, target_size, interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        fallback = Image.fromarray(depth_rgb_u8).resize(target_size, resample=Image.BILINEAR)
        return np.array(fallback, dtype=np.uint8)
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """Compute a depth map for every frame of ``image`` (best effort).

    The function never raises for model or inference problems: when no depth
    pipeline is available the input is returned unchanged, and a frame whose
    processing fails is passed through in place of its depth map.

    Args:
        image: image tensor, ``[H,W,C]`` or ``[B,H,W,C]``.
        resolution: detect resolution forwarded to ``depth_estimate_zoe_style``
            (``-1`` means "pad only, do not resize").

    Returns:
        ``[B,H,W,3]`` depth batch, or ``image`` itself when no pipeline could
        be loaded or the batch is empty. For RGBA frames the depth map is
        multiplied by the frame's alpha.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None
    if pipe_obj is None:
        return image

    if image.ndim == 3:
        image = image.unsqueeze(0)
    if image.shape[0] == 0:
        # Nothing to process; torch.cat() would raise on an empty list.
        return image

    outs = []
    for frame in image:
        try:
            h0 = int(frame.shape[0])
            w0 = int(frame.shape[1])
            inp_u8 = comfy_tensor_to_u8(frame)
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )
            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)
            if alpha_u8 is not None:
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    alpha_u8 = resize_to_original(alpha_u8, w0=w0, h0=h0)
                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)
            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            outs.append(frame.unsqueeze(0))

    # Depth maps have 3 channels but a passed-through RGBA frame has 4; drop
    # the extra channel only when the batch would otherwise be impossible to join.
    if len({int(t.shape[-1]) for t in outs}) > 1:
        outs = [t[..., :3] for t in outs]
    return torch.cat(outs, dim=0)
def _salia_alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """Alpha-blend a square RGBA overlay onto ``base`` with its top-left corner at (x, y).

    ``base`` is not modified; a clamped copy is returned. When ``base`` has an
    alpha channel, it is updated inside the pasted square as well. A single
    overlay is broadcast over a larger base batch.

    Raises:
        ValueError: on wrong ranks, a non-RGBA or non-square overlay, a paste
            rectangle outside the base image, or a batch mismatch.
    """
    if base.ndim != 4 or overlay_rgba.ndim != 4:
        raise ValueError("base and overlay must be [B,H,W,C].")
    B, H, W, C = base.shape
    overlay_batch, overlay_h, overlay_w, overlay_c = overlay_rgba.shape
    if overlay_c != 4:
        raise ValueError("overlay_rgba must have 4 channels (RGBA).")
    if overlay_h != overlay_w:
        raise ValueError("overlay must be square.")
    s = overlay_h

    fits = x >= 0 and y >= 0 and x + s <= W and y + s <= H
    if not fits:
        raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}")

    if overlay_batch != B:
        if not (overlay_batch == 1 and B > 1):
            raise ValueError("Batch mismatch between base and overlay.")
        overlay_rgba = overlay_rgba.expand(B, -1, -1, -1)

    out = base.clone()
    rows = slice(y, y + s)
    cols = slice(x, x + s)
    src_rgb = overlay_rgba[..., 0:3].clamp(0, 1)
    src_a = overlay_rgba[..., 3:4].clamp(0, 1)

    dst_rgb = out[:, rows, cols, 0:3]
    out[:, rows, cols, 0:3] = src_rgb * src_a + dst_rgb * (1.0 - src_a)
    if C == 4:
        dst_a = out[:, rows, cols, 3:4].clamp(0, 1)
        out[:, rows, cols, 3:4] = src_a + dst_a * (1.0 - src_a)
    return out.clamp(0, 1)
# Fixed model files used by Salia_ezpz_gated_Duo2 (looked up by name through
# the checkpoint / controlnet loader nodes).
_HARDCODED_CKPT_NAME = "SaliaHighlady_Speedy.safetensors"
_HARDCODED_CONTROLNET_NAME = "diffusion_pytorch_model_promax.safetensors"
# ControlNet start/end percent, shared by both passes.
_HARDCODED_CN_START = 0.00
_HARDCODED_CN_END = 1.00
# Sampler settings of the first pass.
_PASS1_SAMPLER_NAME = "dpmpp_2m_sde_heun_gpu"
_PASS1_SCHEDULER = "karras"
_PASS1_STEPS = 29
_PASS1_CFG = 2.6
_PASS1_CONTROLNET_STRENGTH = 0.33
# Sampler settings of the second pass.
_PASS2_SAMPLER_NAME = "res_multistep_ancestral_cfg_pp"
_PASS2_SCHEDULER = "karras"
_PASS2_STEPS = 30
_PASS2_CFG = 1.7
_PASS2_CONTROLNET_STRENGTH = 0.5
class Salia_ezpz_gated_Duo2:
    """Node that re-renders a square region of an image in two img2img passes.

    Each pass crops a square at (X_coord, Y_coord), upscales it, samples it
    with a depth ControlNet, masks the result with the alpha of an asset PNG
    and pastes it back. An empty ``trigger_string`` bypasses all processing.
    """
    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("image", "image_cropped")
    FUNCTION = "run"
    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs; asset choices come from the assets image folder."""
        # Placeholder entry keeps the dropdown non-empty when no PNG exists.
        assets = _list_asset_pngs() or ["<no pngs found>"]
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),
                "asset_image": (assets, {}),
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }
    def run(
        self,
        image: torch.Tensor,
        trigger_string: str = "",
        X_coord: int = 0,
        Y_coord: int = 0,
        positive_prompt: str = "",
        negative_prompt: str = "",
        asset_image: str = "",
        square_size_1: int = 384,
        upscale_factor_1: str = "4",
        denoise_1: float = 0.35,
        square_size_2: int = 384,
        upscale_factor_2: str = "4",
        denoise_2: float = 0.35,
    ):
        """Run both passes and return ``(full image, square crop of size square_size_2)``.

        Raises:
            ValueError: bad image shape, out-of-bounds squares, unsupported
                upscale factor, or mismatched mask batch.
            FileNotFoundError: no asset PNG available, or loading the
                hard-coded checkpoint / controlnet failed.
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)
        if image.ndim != 4:
            raise ValueError("Input image must be [B,H,W,C].")
        B, H, W, C = image.shape
        if C not in (3, 4):
            raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.")
        x = int(X_coord)
        y = int(Y_coord)
        s1 = int(square_size_1)
        s2 = int(square_size_2)
        def _validate_square_bounds(s: int, label: str):
            """Raise ValueError unless the s x s square at (x, y) lies inside the image."""
            if s <= 0:
                raise ValueError(f"{label}: square_size must be > 0")
            if x < 0 or y < 0 or x + s > W or y + s > H:
                raise ValueError(f"{label}: out of bounds. image={W}x{H}, rect at ({x},{y}) size={s}")
        def _validate_upscale(up: int, s: int, label: str):
            """Raise ValueError for an unsupported factor or an upscaled size not divisible by 8."""
            if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16):
                raise ValueError(f"{label}: upscale_factor must be one of 1,2,4,6,8,10,12,14,16")
            if ((s * up) % 8) != 0:
                raise ValueError(f"{label}: square_size * upscale_factor must be divisible by 8 (VAE requirement).")
        def _crop_square(img: torch.Tensor, s: int) -> torch.Tensor:
            """Return the s x s square of ``img`` whose top-left corner is (x, y)."""
            return img[:, y:y + s, x:x + s, :]
        # The final crop is needed on every path, so its bounds are checked first.
        _validate_square_bounds(s2, "final crop (square_size_2)")
        # Gate: with an empty trigger string the input is passed through untouched.
        if trigger_string == "":
            out2 = image
            cropped = _crop_square(out2, s2)
            return (out2, cropped)
        _validate_square_bounds(s1, "pass1 (square_size_1)")
        # Same check as the one before the gate, repeated with the pass-2 label.
        _validate_square_bounds(s2, "pass2 (square_size_2)")
        up1 = int(upscale_factor_1)
        up2 = int(upscale_factor_2)
        _validate_upscale(up1, s1, "pass1")
        _validate_upscale(up2, s2, "pass2")
        # Clamp denoise strengths to [0, 1].
        d1 = float(max(0.0, min(1.0, denoise_1)))
        d2 = float(max(0.0, min(1.0, denoise_2)))
        if asset_image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images for this plugin.")
        # Only the mask (1 - alpha) of the asset is used; its RGB content is ignored.
        _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image)
        if asset_mask.ndim == 2:
            asset_mask = asset_mask.unsqueeze(0)
        if asset_mask.ndim != 3:
            raise ValueError("Asset mask must be [B,H,W].")
        if asset_mask.shape[0] != B:
            if asset_mask.shape[0] == 1 and B > 1:
                asset_mask = asset_mask.expand(B, -1, -1)
            else:
                raise ValueError("Batch mismatch for asset mask vs input image batch.")
        import nodes
        # NOTE(review): any loader failure is reported as "not found", which can
        # hide other causes; the original exception is kept as __cause__.
        try:
            model, clip, vae = _load_checkpoint_cached(_HARDCODED_CKPT_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("checkpoints") or []
            raise FileNotFoundError(
                f"Hardcoded ckpt not found: '{_HARDCODED_CKPT_NAME}'. "
                f"Put it in models/checkpoints. Available (first 50): {available[:50]}"
            ) from e
        try:
            controlnet = _load_controlnet_cached(_HARDCODED_CONTROLNET_NAME)
        except Exception as e:
            available = folder_paths.get_filename_list("controlnet") or []
            raise FileNotFoundError(
                f"Hardcoded controlnet not found: '{_HARDCODED_CONTROLNET_NAME}'. "
                f"Put it in models/controlnet. Available (first 50): {available[:50]}"
            ) from e
        # Encode both prompts once; the conditionings are reused by both passes.
        pos_enc = nodes.CLIPTextEncode()
        neg_enc = nodes.CLIPTextEncode()
        pos_fn = getattr(pos_enc, pos_enc.FUNCTION)
        neg_fn = getattr(neg_enc, neg_enc.FUNCTION)
        (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip)
        (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip)
        # Resolve the entry-point methods of the built-in nodes used per pass.
        cn_apply = nodes.ControlNetApplyAdvanced()
        cn_fn = getattr(cn_apply, cn_apply.FUNCTION)
        vae_enc = nodes.VAEEncode()
        vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION)
        ksampler = nodes.KSampler()
        k_fn = getattr(ksampler, ksampler.FUNCTION)
        vae_dec = nodes.VAEDecode()
        vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION)
        def _run_pass(
            pass_index: int,
            in_image: torch.Tensor,
            s: int,
            up: int,
            denoise_v: float,
            steps_v: int,
            cfg_v: float,
            sampler_v: str,
            scheduler_v: str,
            controlnet_strength_v: float,
        ) -> torch.Tensor:
            """Re-render the s x s square at (x, y) of ``in_image`` and paste it back."""
            up_w = s * up
            up_h = s * up
            crop = in_image[:, y:y + s, x:x + s, :]
            crop_rgb = crop[:, :, :, 0:3].contiguous()
            # Depth is estimated on the small crop, then upscaled with the crop.
            depth_small = _salia_depth_execute(crop_rgb, resolution=s)
            depth_up = _resize_image_lanczos(depth_small, up_w, up_h)
            crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h)
            asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h)
            pos_cn, neg_cn = cn_fn(
                strength=float(controlnet_strength_v),
                start_percent=float(_HARDCODED_CN_START),
                end_percent=float(_HARDCODED_CN_END),
                positive=pos_cond,
                negative=neg_cond,
                control_net=controlnet,
                image=depth_up,
                vae=vae,
            )
            (latent,) = vae_enc_fn(pixels=crop_up, vae=vae)
            # The seed is derived from a hash of all settings, so identical
            # settings always give the same seed.
            seed_material = (
                f"{_HARDCODED_CKPT_NAME}|{_HARDCODED_CONTROLNET_NAME}|{asset_image}|"
                f"pass={pass_index}|x={x}|y={y}|s={s}|up={up}|"
                f"steps={steps_v}|cfg={cfg_v}|sampler={sampler_v}|scheduler={scheduler_v}|denoise={denoise_v}|"
                f"cn_strength={controlnet_strength_v}|"
                f"{positive_prompt}|{negative_prompt}"
            ).encode("utf-8", errors="ignore")
            # First 16 hex digits of the SHA-256 digest -> 64-bit integer seed.
            seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16)
            (sampled_latent,) = k_fn(
                seed=seed64,
                steps=int(steps_v),
                cfg=float(cfg_v),
                sampler_name=str(sampler_v),
                scheduler=str(scheduler_v),
                denoise=float(denoise_v),
                model=model,
                positive=pos_cn,
                negative=neg_cn,
                latent_image=latent,
            )
            (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae)
            # Alpha = 1 - asset mask, then shrink back and blend over the input.
            rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up)
            rgba_square = _resize_image_lanczos(rgba_up, s, s)
            out = _salia_alpha_over_region(in_image, rgba_square, x=x, y=y)
            return out
        out1 = _run_pass(
            pass_index=1,
            in_image=image,
            s=s1,
            up=up1,
            denoise_v=d1,
            steps_v=_PASS1_STEPS,
            cfg_v=_PASS1_CFG,
            sampler_v=_PASS1_SAMPLER_NAME,
            scheduler_v=_PASS1_SCHEDULER,
            controlnet_strength_v=_PASS1_CONTROLNET_STRENGTH,
        )
        # The second pass works on the output of the first one.
        out2 = _run_pass(
            pass_index=2,
            in_image=out1,
            s=s2,
            up=up2,
            denoise_v=d2,
            steps_v=_PASS2_STEPS,
            cfg_v=_PASS2_CFG,
            sampler_v=_PASS2_SAMPLER_NAME,
            scheduler_v=_PASS2_SCHEDULER,
            controlnet_strength_v=_PASS2_CONTROLNET_STRENGTH,
        )
        cropped = out2[:, y:y + s2, x:x + s2, :]
        return (out2, cropped)
# ======================================================================================
# apply_segment_4 (standalone, embedded) - rename internal alpha paste helper to avoid clash
# ======================================================================================
# Expects: <this_file_dir>/assets/images/*.png
# Asset folder for the ap4_* helpers, resolved from this file's real path.
# NOTE(review): the Salia section resolves its asset folder through PLUGIN_ROOT
# instead; the two can point at different folders - confirm this is intended.
_AP4_ASSETS_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", "images")
def ap4_list_pngs() -> List[str]:
    """List all PNG files below ``_AP4_ASSETS_DIR`` as sorted, '/'-separated relative paths.

    Returns an empty list when the folder does not exist.
    """
    if not os.path.isdir(_AP4_ASSETS_DIR):
        return []

    found: List[str] = []
    for folder, _subdirs, names in os.walk(_AP4_ASSETS_DIR):
        for name in names:
            if not name.lower().endswith(".png"):
                continue
            absolute = os.path.join(folder, name)
            if not os.path.isfile(absolute):
                continue
            relative = os.path.relpath(absolute, _AP4_ASSETS_DIR)
            found.append(relative.replace("\\", "/"))
    found.sort()
    return found
def ap4_safe_path(filename: str) -> str:
    """Resolve *filename* inside the assets folder and reject escapes.

    Both the assets folder and the joined candidate are passed through
    ``os.path.realpath`` before comparison, so ``..`` segments and symlinks
    are resolved first.

    Returns:
        The resolved absolute path.

    Raises:
        ValueError: if the resolved path is neither the assets folder itself
            nor located underneath it.
    """
    assets_root = os.path.realpath(_AP4_ASSETS_DIR)
    resolved = os.path.realpath(os.path.join(_AP4_ASSETS_DIR, filename))
    # The trailing separator keeps a sibling such as "images_evil" from
    # passing the prefix test.
    is_below_root = resolved.startswith(assets_root + os.sep)
    if is_below_root or resolved == assets_root:
        return resolved
    raise ValueError("Unsafe path (path traversal detected).")
def ap4_file_hash(filename: str) -> str:
    """Return the SHA-256 hex digest of an asset file.

    The file is located through ``ap4_safe_path`` and read in 1 MiB chunks so
    large files are never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(ap4_safe_path(filename), "rb") as stream:
        while True:
            block = stream.read(1024 * 1024)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def ap4_load_image_from_assets(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset file as an RGB image tensor plus an inverted-alpha mask.

    The file is resolved through ``ap4_safe_path``, EXIF orientation is
    applied, and mode ``"I"`` images are scaled by ``1 / 255`` before
    conversion.

    Args:
        filename: Path of the asset relative to the assets folder.

    Returns:
        ``(image, mask)``. ``image`` is the RGB conversion as float32 in 0..1
        with a leading batch axis of size 1. ``mask`` is ``1 - alpha`` (clamped
        to 0..1, leading batch axis of size 1); when the file has no ``"A"``
        band, the ``"L"`` (grayscale) conversion is used in place of alpha.

    Raises:
        ValueError: propagated from ``ap4_safe_path`` for unsafe paths.
    """
    path = ap4_safe_path(filename)

    # Image.open is lazy; the context manager guarantees the underlying file
    # handle is released even if decoding or conversion raises. All pixel
    # access therefore happens inside the block.
    with Image.open(path) as source:
        picture = ImageOps.exif_transpose(source)
        if picture.mode == "I":
            picture = picture.point(lambda px: px * (1 / 255))

        rgb_np = np.array(picture.convert("RGB")).astype(np.float32) / 255.0

        if "A" in picture.getbands():
            alpha_np = np.array(picture.getchannel("A")).astype(np.float32) / 255.0
        else:
            # No alpha band: fall back to luminance as the opacity source.
            alpha_np = np.array(picture.convert("L")).astype(np.float32) / 255.0

    image = torch.from_numpy(rgb_np)[None, ...]
    mask = (1.0 - torch.from_numpy(alpha_np)).clamp(0.0, 1.0).unsqueeze(0)
    return image, mask
def ap4_as_image(img: torch.Tensor) -> torch.Tensor:
    """Validate that *img* is a 4-D tensor with 3 or 4 channels in the last axis.

    Returns:
        The same tensor object, unchanged.

    Raises:
        TypeError: if *img* is not a ``torch.Tensor``.
        ValueError: if *img* is not 4-D or its last axis is not 3 or 4 long.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("IMAGE must be a torch.Tensor")
    if img.dim() != 4:
        raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
    channels = img.shape[-1]
    if channels != 3 and channels != 4:
        raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={channels}")
    return img
def ap4_as_mask(mask: torch.Tensor) -> torch.Tensor:
    """Validate a mask tensor and make sure it carries a batch axis.

    A 2-D ``[H,W]`` mask gets a leading axis of size 1; a 3-D mask is returned
    as is.

    Raises:
        TypeError: if *mask* is not a ``torch.Tensor``.
        ValueError: if *mask* is neither 2-D nor 3-D.
    """
    if not isinstance(mask, torch.Tensor):
        raise TypeError("MASK must be a torch.Tensor")
    batched = mask.unsqueeze(0) if mask.dim() == 2 else mask
    if batched.dim() == 3:
        return batched
    raise ValueError(f"Expected MASK shape [B,H,W] or [H,W], got {tuple(batched.shape)}")
def ap4_ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Return *img* with four channels, appending an all-ones alpha if needed.

    The input is validated with ``ap4_as_image``. A 4-channel image is
    returned untouched; a 3-channel image gets an alpha channel of ones with
    the image's own device and dtype.
    """
    img = ap4_as_image(img)
    if img.shape[-1] != 4:
        batch, height, width = img.shape[:3]
        opaque = torch.ones((batch, height, width, 1), device=img.device, dtype=img.dtype)
        img = torch.cat([img, opaque], dim=-1)
    return img
def ap4_alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """Alpha-composite *overlay* onto *canvas* with its top-left corner at (x, y).

    Both inputs are validated with ``ap4_as_image``. A batch of size 1 on
    either side is expanded to match the other. The overlay is clipped to the
    canvas bounds; if nothing overlaps, an unmodified copy of the canvas is
    returned. Missing alpha channels are treated as fully opaque.

    Returns:
        A new tensor (clone of the canvas) with the same channel count as the
        canvas; only the overlapping rectangle is rewritten.

    Raises:
        ValueError: if the batch sizes differ and neither is 1.
    """
    overlay = ap4_as_image(overlay)
    canvas = ap4_as_image(canvas)

    n_overlay, n_canvas = overlay.shape[0], canvas.shape[0]
    if n_overlay != n_canvas:
        if n_overlay == 1 and n_canvas > 1:
            overlay = overlay.expand(n_canvas, *overlay.shape[1:])
        elif n_canvas == 1 and n_overlay > 1:
            canvas = canvas.expand(n_overlay, *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {n_overlay} vs canvas {n_canvas}")

    canvas_h, canvas_w, canvas_c = canvas.shape[1:]
    overlay_h, overlay_w = overlay.shape[1:3]
    left = int(x)
    top = int(y)

    result = canvas.clone()

    # Destination rectangle on the canvas, clipped to the canvas bounds.
    dst_x0, dst_y0 = max(0, left), max(0, top)
    dst_x1 = min(canvas_w, left + overlay_w)
    dst_y1 = min(canvas_h, top + overlay_h)
    if dst_x1 <= dst_x0 or dst_y1 <= dst_y0:
        return result

    # Matching source rectangle inside the overlay.
    src_x0, src_y0 = dst_x0 - left, dst_y0 - top
    src_x1 = src_x0 + (dst_x1 - dst_x0)
    src_y1 = src_y0 + (dst_y1 - dst_y0)

    below = ap4_ensure_rgba(result[:, dst_y0:dst_y1, dst_x0:dst_x1, :])
    above = ap4_ensure_rgba(overlay[:, src_y0:src_y1, src_x0:src_x1, :])

    above_a = above[..., 3:4].clamp(0.0, 1.0)
    below_a = below[..., 3:4].clamp(0.0, 1.0)
    # Premultiplied colours make the "over" operator a plain weighted sum.
    above_pm = above[..., :3].clamp(0.0, 1.0) * above_a
    below_pm = below[..., :3].clamp(0.0, 1.0) * below_a

    see_through = 1.0 - above_a
    blended_a = above_a + below_a * see_through
    blended_pm = above_pm + below_pm * see_through

    # Un-premultiply; pixels with (near) zero alpha become black instead of
    # dividing by zero.
    eps = 1e-6
    blended_rgb = torch.where(
        blended_a > eps,
        blended_pm / (blended_a + eps),
        torch.zeros_like(blended_pm),
    ).clamp(0.0, 1.0)

    if canvas_c == 3:
        patch = blended_rgb
    else:
        patch = torch.cat([blended_rgb, blended_a.clamp(0.0, 1.0)], dim=-1)
    result[:, dst_y0:dst_y1, dst_x0:dst_x1, :] = patch
    return result
class AP4_AILab_MaskCombiner_Exact:
    """Combine up to four masks by maximum, minimum or absolute difference."""

    def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
        """Merge the supplied masks and return the result as a 1-tuple.

        ``None`` arguments are ignored. With no mask at all a zero tensor of
        shape ``(1, 64, 64)`` is returned; with a single mask that mask is
        returned untouched. Otherwise every mask is brought to the shape of
        the first one and then merged:

        * ``"combine"``: element-wise maximum over all masks.
        * ``"intersection"``: element-wise minimum of the first two masks only.
        * any other value: absolute difference of the first two masks only.

        The merged result is clamped to 0..1.
        """
        present = [m for m in (mask_1, mask_2, mask_3, mask_4) if m is not None]
        if not present:
            return (torch.zeros((1, 64, 64), dtype=torch.float32),)
        if len(present) == 1:
            return (present[0],)

        reference_shape = present[0].shape
        aligned = [self._resize_if_needed(m, reference_shape) for m in present]
        first, second = aligned[0], aligned[1]

        if mode == "combine":
            merged = first
            for other in aligned[1:]:
                merged = torch.maximum(merged, other)
        elif mode == "intersection":
            merged = torch.minimum(first, second)
        else:
            merged = torch.abs(first - second)
        return (torch.clamp(merged, 0, 1),)

    def _resize_if_needed(self, mask, target_shape):
        """Return *mask* resized to the height/width given by *target_shape*.

        A mask whose shape already equals *target_shape* is returned as is.
        Otherwise each entry along the first axis goes through an 8-bit PIL
        image, is resized with LANCZOS, and the results are stacked again.
        """
        if mask.shape == target_shape:
            return mask

        rank = len(mask.shape)
        if rank == 2:
            mask = mask.unsqueeze(0)
        elif rank == 4:
            mask = mask.squeeze(1)

        if len(target_shape) >= 2:
            target_height, target_width = target_shape[-2], target_shape[-1]
        else:
            target_height, target_width = target_shape[0], target_shape[1]

        frames = []
        for frame in mask:
            as_bytes = (frame.cpu().numpy() * 255).astype(np.uint8)
            scaled = Image.fromarray(as_bytes).resize((target_width, target_height), Image.LANCZOS)
            frames.append(torch.from_numpy(np.array(scaled).astype(np.float32) / 255.0))
        return torch.stack(frames)
def ap4_resize_mask_comfy(alpha_mask: torch.Tensor, image_shape_hwc: Tuple[int, int, int]) -> torch.Tensor:
    """Bilinearly resize a mask to the height/width of an image shape.

    Only the first two entries of *image_shape_hwc* (height, width) are read.
    The mask is reshaped so its last two axes are treated as one-channel
    planes, interpolated, and returned with the channel axis removed again.
    """
    target_h = int(image_shape_hwc[0])
    target_w = int(image_shape_hwc[1])
    src_h, src_w = alpha_mask.shape[-2], alpha_mask.shape[-1]
    planes = alpha_mask.reshape((-1, 1, src_h, src_w))
    resized = F.interpolate(planes, size=(target_h, target_w), mode="bilinear")
    return resized.squeeze(1)
def ap4_join_image_with_alpha_comfy(image: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Attach an alpha channel derived from a mask to the RGB part of an image.

    The mask is moved to the image's device/dtype, resized to the image's
    height and width, and inverted (``1 - mask``) to form the alpha channel.
    Only the first three channels of the image are kept. The output batch
    length is the smaller of the image and mask batch lengths.
    """
    image = ap4_as_image(image)
    alpha = ap4_as_mask(alpha).to(device=image.device, dtype=image.dtype)

    opacity = 1.0 - ap4_resize_mask_comfy(alpha, image.shape[1:])
    # zip() stops at the shorter batch, which gives the min() batch length.
    joined = [
        torch.cat((frame[:, :, :3], frame_alpha.unsqueeze(2)), dim=2)
        for frame, frame_alpha in zip(image, opacity)
    ]
    return torch.stack(joined)
def ap4_try_get_comfy_model_management():
    """Return the ``comfy.model_management`` module, or ``None`` if importing it fails.

    Any exception raised by the import is swallowed on purpose so callers can
    fall back to plain torch device handling.
    """
    try:
        import comfy.model_management as model_mgmt  # type: ignore
    except Exception:
        return None
    return model_mgmt
def ap4_gaussian_kernel_1d(kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build a normalised 1-D Gaussian kernel.

    The taps are sampled at integer offsets centred on ``(kernel_size - 1) / 2``
    and divided by their sum, so the kernel sums to one.
    """
    midpoint = (kernel_size - 1) / 2.0
    offsets = torch.arange(kernel_size, device=device, dtype=dtype) - midpoint
    two_sigma_sq = 2.0 * sigma * sigma
    weights = torch.exp(-(offsets * offsets) / two_sigma_sq)
    return weights / weights.sum()
def ap4_mask_blur(mask: torch.Tensor, amount: int = 8, device: str = "gpu") -> torch.Tensor:
    """Gaussian-blur a mask with a separable kernel.

    Args:
        mask: ``[B,H,W]`` or ``[H,W]`` mask; values are clamped to 0..1 first.
        amount: Kernel size. ``0`` returns the clamped mask untouched; even
            sizes are bumped to the next odd number.
        device: ``"gpu"`` runs on comfy's torch device when comfy can be
            imported, otherwise on CUDA if available, otherwise on the mask's
            device. ``"cpu"`` forces the CPU. Any other value keeps the mask's
            device.

    Returns:
        The blurred mask in the input dtype, clamped to 0..1. When comfy is
        importable and *device* is ``"gpu"`` or ``"cpu"`` the result lives on
        comfy's intermediate device; otherwise on the mask's original device.
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)
    if amount == 0:
        return mask

    ksize = int(amount)
    if ksize % 2 == 0:
        ksize += 1
    # Sigma is derived from the kernel size alone.
    sigma = 0.3 * (((ksize - 1) * 0.5) - 1.0) + 0.8

    mm = ap4_try_get_comfy_model_management()
    source_device = mask.device

    if device == "cpu":
        compute_device = torch.device("cpu")
    elif device != "gpu":
        compute_device = source_device
    elif mm is not None:
        compute_device = mm.get_torch_device()
    elif torch.cuda.is_available():
        compute_device = torch.device("cuda")
    else:
        compute_device = source_device

    if mm is not None and device in ("gpu", "cpu"):
        result_device = mm.intermediate_device()
    else:
        result_device = source_device

    source_dtype = mask.dtype
    work = mask.to(device=compute_device, dtype=torch.float32)
    height, width = work.shape[-2:]
    half = ksize // 2

    # Reflect padding needs the pad to be smaller than the padded dimension;
    # fall back to edge replication for tiny masks.
    can_reflect = height > half and width > half and height > 1 and width > 1
    padded = F.pad(
        work.unsqueeze(1),
        (half, half, half, half),
        mode="reflect" if can_reflect else "replicate",
    )

    taps = ap4_gaussian_kernel_1d(ksize, sigma, device=compute_device, dtype=torch.float32)
    blurred = F.conv2d(padded, taps.view(1, 1, 1, ksize))   # horizontal pass
    blurred = F.conv2d(blurred, taps.view(1, 1, ksize, 1))  # vertical pass

    return blurred.squeeze(1).clamp(0.0, 1.0).to(device=result_device, dtype=source_dtype)
def ap4_dilate_mask(mask: torch.Tensor, dilation: int = 3) -> torch.Tensor:
    """Grow or shrink a mask with a square max/min filter.

    Args:
        mask: ``[B,H,W]`` or ``[H,W]`` mask; values are clamped to 0..1 first.
        dilation: Window size. Positive values dilate (max filter), negative
            values erode (min filter, implemented as a max filter on the
            negated mask), ``0`` returns the clamped mask untouched.

    Returns:
        A ``[B,H,W]`` mask clamped to 0..1 with the same height and width as
        the input.
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)
    dilation = int(dilation)
    if dilation == 0:
        return mask

    k = abs(dilation)
    height, width = mask.shape[-2:]
    x = mask.unsqueeze(1)
    if dilation > 0:
        y = F.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2)
    else:
        y = -F.max_pool2d(-x, kernel_size=k, stride=1, padding=k // 2)

    # With stride 1 and padding k // 2 the pooled size is H + 2 * (k // 2) - k + 1,
    # i.e. H + 1 for even k. Crop so the output always matches the input size
    # (a no-op for odd k).
    y = y[..., :height, :width]
    return y.squeeze(1).clamp(0.0, 1.0)
def ap4_fill_holes_grayscale_numpy_heap(f: np.ndarray, connectivity: int = 8) -> np.ndarray:
    """Fill grayscale holes in a 2-D array using a priority-queue flood.

    The input is clipped to 0..1 and cast to float32. Starting from every
    border pixel, each pixel is assigned the smallest possible "highest value
    along a path to the border"; enclosed dips are therefore raised to the
    level of the lowest rim surrounding them.

    Args:
        f: 2-D array.
        connectivity: ``4`` for edge neighbours only; any other value uses all
            eight neighbours.

    Returns:
        A float32 array of the same shape (the clipped input itself when the
        array has a zero-length axis).
    """
    f = np.clip(f, 0.0, 1.0).astype(np.float32, copy=False)
    rows, cols = f.shape
    if rows == 0 or cols == 0:
        return f

    level = np.full((rows, cols), np.inf, dtype=np.float32)
    settled = np.zeros((rows, cols), dtype=np.bool_)
    queue: List[Tuple[float, int, int]] = []

    # Seed the queue with the border; a border pixel's level is its own value.
    # Corners (and degenerate one-row/one-column cases) appear twice here but
    # the strict "<" test lets only the first visit through.
    border = [(r, c) for c in range(cols) for r in (0, rows - 1)]
    border += [(r, c) for r in range(rows) for c in (0, cols - 1)]
    for r, c in border:
        own = float(f[r, c])
        if own < float(level[r, c]):
            level[r, c] = own
            heapq.heappush(queue, (own, r, c))

    if connectivity == 4:
        offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    else:
        offsets = tuple(
            (dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)
        )

    tol = 1e-8
    while queue:
        water, r, c = heapq.heappop(queue)
        # Skip entries that were already finalised or superseded by a lower level.
        if settled[r, c] or water > float(level[r, c]) + tol:
            continue
        settled[r, c] = True

        for dr, dc in offsets:
            nr, nc = r + dr, c + dc
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            if settled[nr, nc]:
                continue
            ground = float(f[nr, nc])
            # Path level = highest value met so far along the path.
            candidate = water if water >= ground else ground
            if candidate < float(level[nr, nc]) - tol:
                level[nr, nc] = candidate
                heapq.heappush(queue, (candidate, nr, nc))

    return level
def ap4_fill_holes_mask(mask: torch.Tensor) -> torch.Tensor:
    """Fill enclosed holes in every mask of a batch.

    The mask is clamped to 0..1 and processed on the CPU as float32. The
    preferred path is scikit-image's ``reconstruction`` (by erosion, 3x3
    footprint) with a marker that equals the mask on its border and 1.0
    elsewhere. If that path raises for any reason (for example scikit-image is
    not installed), the whole batch is recomputed with
    ``ap4_fill_holes_grayscale_numpy_heap`` using 8-connectivity.

    Returns:
        A ``[B,H,W]`` tensor on the input's device and dtype, clamped to 0..1.
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)
    batch = mask.shape[0]
    source_device, source_dtype = mask.device, mask.dtype

    planes = np.ascontiguousarray(mask.detach().cpu().numpy().astype(np.float32, copy=False))
    filled = np.empty_like(planes)

    try:
        from skimage.morphology import reconstruction  # type: ignore

        square = np.ones((3, 3), dtype=bool)
        for idx in range(batch):
            plane = planes[idx]
            # Marker: 1.0 everywhere, then the four border lines copied from
            # the mask so the erosion can only "drain" from the outside.
            marker = np.ones_like(plane)
            marker[0, :] = plane[0, :]
            marker[-1, :] = plane[-1, :]
            marker[:, 0] = plane[:, 0]
            marker[:, -1] = plane[:, -1]
            filled[idx] = reconstruction(marker, plane, method="erosion", footprint=square).astype(np.float32)
    except Exception:
        # Best-effort fallback: pure NumPy/heapq implementation.
        for idx in range(batch):
            filled[idx] = ap4_fill_holes_grayscale_numpy_heap(planes[idx], connectivity=8)

    result = torch.from_numpy(filled).to(device=source_device, dtype=source_dtype)
    return result.clamp(0.0, 1.0)
class apply_segment_4:
    """Node that pastes a masked image onto a canvas at a given position.

    The incoming mask is blurred, dilated and hole-filled. It is then combined
    with the mask loaded from the selected asset file, and the result becomes
    the alpha channel of ``img``: a pixel is opaque only where the processed
    mask is set and the asset's alpha-derived mask does not block it. The
    resulting cut-out is alpha-composited onto ``canvas`` at ``(x, y)``.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Final_Image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs; the ``image`` dropdown lists the asset PNGs."""
        png_choices = ap4_list_pngs() or ["<no pngs found>"]
        offset_spec = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        required = {
            "mask": ("MASK",),
            "image": (png_choices, {}),
            "img": ("IMAGE",),
            "canvas": ("IMAGE",),
            "x": ("INT", dict(offset_spec)),
            "y": ("INT", dict(offset_spec)),
        }
        return {"required": required}

    def run(self, mask, image, img, canvas, x, y):
        """Execute the node and return ``(final_image,)``.

        Raises:
            FileNotFoundError: if the dropdown holds the "no PNGs" placeholder.
        """
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images next to this node")

        # Soften, grow and close the segmentation mask.
        segment = ap4_as_mask(mask).clamp(0.0, 1.0)
        segment = ap4_mask_blur(segment, amount=8, device="gpu")
        segment = ap4_dilate_mask(segment, dilation=3)
        segment = ap4_fill_holes_mask(segment)
        outside_segment = (1.0 - segment).detach().cpu()

        # Only the mask of the asset is used; its RGB content is discarded.
        _asset_rgb, asset_mask = ap4_load_image_from_assets(image)
        asset_opacity = 1.0 - ap4_as_mask(asset_mask).detach().cpu()

        (transparent_where,) = AP4_AILab_MaskCombiner_Exact().combine_masks(
            outside_segment, mode="combine", mask_2=asset_opacity
        )
        transparent_where = torch.clamp(transparent_where, 0.0, 1.0)

        cutout = ap4_join_image_with_alpha_comfy(img, transparent_where)
        canvas = ap4_as_image(canvas)
        cutout = cutout.to(device=canvas.device, dtype=canvas.dtype)
        return (ap4_alpha_over_region(cutout, canvas, x, y),)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        """Return the asset file's SHA-256 (or the placeholder string itself)."""
        return image if image == "<no pngs found>" else ap4_file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        """Return ``True`` if the selected asset exists, else an error message."""
        if image == "<no pngs found>":
            return "No PNGs found in assets/images next to this node"
        try:
            resolved = ap4_safe_path(image)
        except Exception as exc:
            return str(exc)
        if os.path.isfile(resolved):
            return True
        return f"File not found in assets/images: {image}"
# ======================================================================================
# Fused node: Salia_ezpz_gated_Duo2 -> SAM3Segment (hardcoded) -> apply_segment_4
# ======================================================================================
class SAM3Segment_Salia:
    """Fused node: Salia_ezpz_gated_Duo2 -> SAM3Segment -> apply_segment_4.

    The Duo2 stage produces a cropped image, SAM3 segments that crop with
    fixed settings, and apply_segment_4 pastes the segmented crop back onto
    the original input image at ``(X_coord, Y_coord)``. An empty
    ``trigger_string`` bypasses everything and returns the input image.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Final_Image",)
    FUNCTION = "run"

    def __init__(self):
        self._sam3 = SAM3Segment()
        self._salia = Salia_ezpz_gated_Duo2()
        self._ap4 = apply_segment_4()

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs, reusing the dropdown sources of the embedded nodes."""
        duo2_pngs = _list_asset_pngs() or ["<no pngs found>"]
        ap4_pngs = ap4_list_pngs() or ["<no pngs found>"]
        upscale_choices = [str(n) for n in (1, 2, 4, 6, 8, 10, 12, 14, 16)]

        # Small factories so every entry gets its own options dict.
        def coordinate():
            return ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1})

        def square_size():
            return ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1})

        def denoise():
            return ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01})

        def text_box(**extra):
            return ("STRING", {"default": "", "multiline": True, **extra})

        required = {
            "image": ("IMAGE",),
            "trigger_string": ("STRING", {"default": ""}),
            "X_coord": coordinate(),
            "Y_coord": coordinate(),
            "positive_prompt": text_box(),
            "negative_prompt": text_box(),
            "prompt": text_box(placeholder="SAM3 prompt"),
            "asset_image": (duo2_pngs, {}),
            "apply_asset_image": (ap4_pngs, {}),
            "square_size_1": square_size(),
            "upscale_factor_1": (upscale_choices, {"default": "4"}),
            "denoise_1": denoise(),
            "square_size_2": square_size(),
            "upscale_factor_2": (upscale_choices, {"default": "4"}),
            "denoise_2": denoise(),
        }
        return {"required": required}

    def run(
        self,
        image,
        trigger_string="",
        X_coord=0,
        Y_coord=0,
        positive_prompt="",
        negative_prompt="",
        prompt="",
        asset_image="",
        apply_asset_image="",
        square_size_1=384,
        upscale_factor_1="4",
        denoise_1=0.35,
        square_size_2=384,
        upscale_factor_2="4",
        denoise_2=0.35,
    ):
        """Run the three stages in sequence and return ``(final_image,)``."""
        # Bypass: an empty trigger string returns the input image untouched.
        if trigger_string == "":
            return (image,)

        # Stage 1 - Duo2; only its cropped output is used further on.
        duo2_outputs = self._salia.run(
            image=image,
            trigger_string=trigger_string,
            X_coord=int(X_coord),
            Y_coord=int(Y_coord),
            positive_prompt=str(positive_prompt),
            negative_prompt=str(negative_prompt),
            asset_image=str(asset_image),
            square_size_1=int(square_size_1),
            upscale_factor_1=str(upscale_factor_1),
            denoise_1=float(denoise_1),
            square_size_2=int(square_size_2),
            upscale_factor_2=str(upscale_factor_2),
            denoise_2=float(denoise_2),
        )
        _duo2_full, crop = duo2_outputs

        # Stage 2 - SAM3 segmentation of the crop with hard-coded settings.
        sam3_outputs = self._sam3.segment(
            image=crop,
            prompt=str(prompt),
            sam3_model="sam3",
            device="GPU",
            confidence_threshold=0.50,
            mask_blur=0,
            mask_offset=0,
            invert_output=False,
            unload_model=False,
            background="Alpha",
            background_color="#222222",
        )
        segmented, segment_mask, _mask_preview = sam3_outputs

        # Stage 3 - paste onto the ORIGINAL input image, not the Duo2 output.
        (pasted,) = self._ap4.run(
            mask=segment_mask,
            image=str(apply_asset_image),
            img=segmented,
            canvas=image,
            x=int(X_coord),
            y=int(Y_coord),
        )
        return (pasted,)
# ======================================================================================
# Node mappings (all nodes in this file)
# ======================================================================================
# Registry key -> implementing class for every node defined in this file.
# NOTE(review): presumably read by ComfyUI's custom-node loader - confirm.
NODE_CLASS_MAPPINGS = {
    "SAM3Segment": SAM3Segment,
    "Salia_ezpz_gated_Duo2": Salia_ezpz_gated_Duo2,
    "apply_segment_4": apply_segment_4,
    "SAM3Segment_Salia": SAM3Segment_Salia,
}
# Registry key -> human-readable label; keys mirror NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SAM3Segment": "SAM3 Segmentation (RMBG)",
    "Salia_ezpz_gated_Duo2": "Salia_ezpz_gated_Duo2",
    "apply_segment_4": "apply_segment_4",
    "SAM3Segment_Salia": "SAM3Segment_Salia (Duo2 → SAM3 → apply_segment_4)",
}