| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import sys |
| | import hashlib |
| | import shutil |
| | import threading |
| | import urllib.request |
| | import heapq |
| | from contextlib import nullcontext |
| | from pathlib import Path |
| | from typing import Any, Dict, Tuple, Optional, List |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from PIL import Image, ImageFilter, ImageOps |
| | from torch.hub import download_url_to_file |
| |
|
| | import folder_paths |
| | import comfy.model_management |
| | import comfy.model_management as model_management |
| |
|
| | from AILab_ImageMaskTools import pil2tensor, tensor2pil |
| |
|
| | |
| | |
| | |
| |
|
# Make the vendored "sam3" package importable before the imports below run.
CURRENT_DIR = os.path.dirname(__file__)
SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
if SAM3_LOCAL_DIR not in sys.path:
    sys.path.insert(0, SAM3_LOCAL_DIR)

# BPE vocabulary used by SAM3's text tokenizer; fail fast at import time if
# the vendored assets are incomplete rather than erroring later at inference.
SAM3_BPE_PATH = os.path.join(SAM3_LOCAL_DIR, "assets", "bpe_simple_vocab_16e6.txt.gz")
if not os.path.isfile(SAM3_BPE_PATH):
    raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")
| |
|
| | from sam3.model_builder import build_sam3_image_model |
| | from sam3.model.sam3_image_processor import Sam3Processor |
| |
|
# Default SAM3 checkpoint location (HuggingFace mirror) and local filename.
_DEFAULT_PT_ENTRY = {
    "model_url": "https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
    "filename": "sam3.pt",
}

# Registry of selectable SAM3 models; keys appear in the node's dropdown.
SAM3_MODELS = {
    "sam3": _DEFAULT_PT_ENTRY.copy(),
}
| |
|
| |
|
def get_sam3_pt_models():
    """Return a ``{"sam3": entry}`` mapping pointing at a .pt checkpoint.

    Prefers the registered "sam3" entry, then any entry with a .pt filename,
    then any sam3-named entry patched with the default URL/filename, and
    finally a copy of the built-in default.
    """
    preferred = SAM3_MODELS.get("sam3")
    if preferred and preferred.get("filename", "").endswith(".pt"):
        return {"sam3": preferred}
    for name, info in SAM3_MODELS.items():
        if info.get("filename", "").endswith(".pt"):
            return {"sam3": info}
        if "sam3" in name and info:
            # Reuse the entry but force the known-good default checkpoint.
            fallback = dict(info)
            fallback["model_url"] = _DEFAULT_PT_ENTRY["model_url"]
            fallback["filename"] = _DEFAULT_PT_ENTRY["filename"]
            return {"sam3": fallback}
    return {"sam3": _DEFAULT_PT_ENTRY.copy()}
| |
|
| |
|
def process_mask(mask_image, invert_output=False, mask_blur=0, mask_offset=0):
    """Post-process a PIL grayscale mask: invert, blur, then grow/shrink.

    mask_offset > 0 dilates the mask, < 0 erodes it; the morphological filter
    is applied abs(mask_offset) times with a kernel sized from the offset.
    """
    if invert_output:
        # 255 - pixel flips foreground/background.
        mask_image = Image.fromarray(255 - np.array(mask_image))
    if mask_blur > 0:
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))
    if mask_offset != 0:
        kernel = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        kernel_size = 2 * abs(mask_offset) + 1
        for _ in range(abs(mask_offset)):
            mask_image = mask_image.filter(kernel(kernel_size))
    return mask_image
| |
|
| |
|
def apply_background_color(image, mask_image, background="Alpha", background_color="#222222"):
    """Apply *mask_image* as the alpha channel of *image*.

    background == "Alpha" returns an RGBA image; "Color" composites over a
    solid background parsed from the ``#RRGGBB`` hex string and returns RGB.
    """
    rgba = image.copy().convert("RGBA")
    rgba.putalpha(mask_image.convert("L"))
    if background != "Color":
        return rgba
    hex_color = background_color.lstrip("#")
    rgb = tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    backdrop = Image.new("RGBA", image.size, rgb + (255,))
    return Image.alpha_composite(backdrop, rgba).convert("RGB")
| |
|
| |
|
def get_or_download_model_file(filename, url):
    """Locate *filename* via ComfyUI's folder_paths, downloading it if absent.

    Checks the registered "sam3" model folder first; otherwise downloads *url*
    into <models_dir>/sam3/<filename>. Returns the absolute local path.
    """
    if hasattr(folder_paths, "get_full_path"):
        local_path = folder_paths.get_full_path("sam3", filename)
        if local_path and os.path.isfile(local_path):
            return local_path
    base_models_dir = getattr(folder_paths, "models_dir", os.path.join(CURRENT_DIR, "models"))
    models_dir = os.path.join(base_models_dir, "sam3")
    os.makedirs(models_dir, exist_ok=True)
    local_path = os.path.join(models_dir, filename)
    if not os.path.exists(local_path):
        # Bug fix: the message previously printed the literal "(unknown)"
        # instead of the name of the file being fetched.
        print(f"Downloading {filename} from {url} ...")
        download_url_to_file(url, local_path)
    return local_path
| |
|
| |
|
def _resolve_device(user_choice):
    """Map a UI device choice ("Auto"/"CPU"/"GPU") to a torch.device.

    Raises RuntimeError when "GPU" is requested but CUDA is not the detected
    device; "Auto" defers entirely to ComfyUI's device selection.
    """
    detected = comfy.model_management.get_torch_device()
    if user_choice == "CPU":
        return torch.device("cpu")
    if user_choice == "GPU":
        if detected.type != "cuda":
            raise RuntimeError("GPU unavailable")
        return torch.device("cuda")
    return detected
| |
|
| |
|
class SAM3Segment:
    """ComfyUI node: text-prompted segmentation using the vendored SAM3 model.

    Outputs the composited image (alpha or solid-color background), the union
    mask of all detections, and an RGB preview of that mask.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """ComfyUI input schema for the node."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Describe the concept"}),
                "sam3_model": (list(SAM3_MODELS.keys()), {"default": "sam3"}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto"}),
                "confidence_threshold": ("FLOAT", {"default": 0.5, "min": 0.05, "max": 0.95, "step": 0.01}),
            },
            "optional": {
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
                "invert_output": ("BOOLEAN", {"default": False}),
                "unload_model": ("BOOLEAN", {"default": False}),
                "background": (["Alpha", "Color"], {"default": "Alpha"}),
                "background_color": ("COLORCODE", {"default": "#222222"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "segment"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # (model_name, device_str) -> Sam3Processor, reused across executions.
        self.processor_cache = {}

    def _load_processor(self, model_choice, device_choice):
        """Build (or fetch from cache) a Sam3Processor for model/device."""
        torch_device = _resolve_device(device_choice)
        device_str = "cuda" if torch_device.type == "cuda" else "cpu"
        cache_key = (model_choice, device_str)
        if cache_key not in self.processor_cache:
            model_info = SAM3_MODELS[model_choice]
            ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
            model = build_sam3_image_model(
                bpe_path=SAM3_BPE_PATH,
                device=device_str,
                eval_mode=True,
                checkpoint_path=ckpt_path,
                load_from_HF=False,
                enable_segmentation=True,
                enable_inst_interactivity=False,
            )
            processor = Sam3Processor(model, device=device_str)
            self.processor_cache[cache_key] = processor
        return self.processor_cache[cache_key], torch_device

    def _empty_result(self, img_pil, background, background_color):
        """Result for "no detections": composited image plus all-zero masks."""
        w, h = img_pil.size
        mask_image = Image.new("L", (w, h), 0)
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        if background == "Alpha":
            result_image = result_image.convert("RGBA")
        else:
            result_image = result_image.convert("RGB")
        empty_mask = torch.zeros((1, h, w), dtype=torch.float32)
        # Replicate the mask to 3 channels so it can be previewed as an IMAGE.
        mask_rgb = empty_mask.reshape((-1, 1, h, w)).movedim(1, -1).expand(-1, -1, -1, 3)
        return result_image, empty_mask, mask_rgb

    def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask_offset, invert, background, background_color):
        """Segment one image with a text prompt.

        Returns (PIL result image, [1,H,W] mask tensor, [1,H,W,3] mask RGB).
        """
        img_pil = tensor2pil(img_tensor)
        text = prompt.strip() or "object"  # empty prompt -> generic concept
        state = processor.set_image(img_pil)
        processor.reset_all_prompts(state)
        processor.set_confidence_threshold(confidence, state)
        state = processor.set_text_prompt(text, state)
        masks = state.get("masks")
        if masks is None or masks.numel() == 0:
            return self._empty_result(img_pil, background, background_color)
        masks = masks.float().to("cpu")
        if masks.ndim == 4:
            masks = masks.squeeze(1)  # drop singleton channel dimension
        # Union of all detected instances via per-pixel max.
        combined = masks.amax(dim=0)
        mask_np = (combined.clamp(0, 1).numpy() * 255).astype(np.uint8)
        mask_image = Image.fromarray(mask_np, mode="L")
        mask_image = process_mask(mask_image, invert, mask_blur, mask_offset)
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        if background == "Alpha":
            result_image = result_image.convert("RGBA")
        else:
            result_image = result_image.convert("RGB")
        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
        mask_rgb = mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width)).movedim(1, -1).expand(-1, -1, -1, 3)
        return result_image, mask_tensor, mask_rgb

    def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
        """Node entry point: segment a batch of images and stack the results."""
        if image.ndim == 3:
            image = image.unsqueeze(0)

        processor, torch_device = self._load_processor(sam3_model, device)
        autocast_device = comfy.model_management.get_autocast_device(torch_device)
        # bf16 autocast only on CUDA; MPS and CPU run at full precision.
        autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
        ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()

        result_images, result_masks, result_mask_images = [], [], []

        with ctx:
            for tensor_img in image:
                img_pil, mask_tensor, mask_rgb = self._run_single(
                    processor,
                    tensor_img,
                    prompt,
                    confidence_threshold,
                    mask_blur,
                    mask_offset,
                    invert_output,
                    background,
                    background_color,
                )
                result_images.append(pil2tensor(img_pil))
                result_masks.append(mask_tensor)
                result_mask_images.append(mask_rgb)

        if unload_model:
            # Drop the cached processor so the model can be garbage-collected;
            # empty_cache releases CUDA memory back to the driver eagerly.
            device_str = "cuda" if torch_device.type == "cuda" else "cpu"
            cache_key = (sam3_model, device_str)
            if cache_key in self.processor_cache:
                del self.processor_cache[cache_key]
            if torch_device.type == "cuda":
                torch.cuda.empty_cache()

        return torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | try: |
| | from transformers import pipeline |
| | except Exception as e: |
| | pipeline = None |
| | _TRANSFORMERS_IMPORT_ERROR = e |
| |
|
| | _CKPT_CACHE: Dict[str, Tuple[Any, Any, Any]] = {} |
| | _CN_CACHE: Dict[str, Any] = {} |
| | _CKPT_LOCK = threading.Lock() |
| | _CN_LOCK = threading.Lock() |
| |
|
| |
|
| | def _find_plugin_root() -> Path: |
| | """ |
| | Walk upwards from this file until we find an 'assets' folder. |
| | If not found, fall back to this file's directory. |
| | """ |
| | here = Path(__file__).resolve() |
| | for parent in [here.parent] + list(here.parents)[:12]: |
| | if (parent / "assets").is_dir(): |
| | return parent |
| | return here.parent |
| |
|
| |
|
# Resolved once at import time; all asset/model paths hang off this directory.
PLUGIN_ROOT = _find_plugin_root()
| |
|
| |
|
def _pil_lanczos():
    """Return the LANCZOS resampling constant across old and new Pillow versions."""
    try:
        return Image.Resampling.LANCZOS  # Pillow >= 9.1
    except AttributeError:
        return Image.LANCZOS  # legacy constant
| |
|
| |
|
def _image_tensor_to_pil(img: torch.Tensor) -> Image.Image:
    """Convert a [B,H,W,C] or [H,W,C] float tensor (0..1) to a PIL RGB/RGBA image."""
    if img.ndim == 4:
        img = img[0]  # take the first batch element
    arr = (img.detach().cpu().float().clamp(0, 1).numpy() * 255.0).round().astype(np.uint8)
    mode = "RGBA" if arr.shape[-1] == 4 else "RGB"
    return Image.fromarray(arr, mode=mode)
| |
|
| |
|
def _pil_to_image_tensor(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a [1,H,W,C] float tensor in 0..1 (RGB or RGBA)."""
    if pil.mode not in ("RGB", "RGBA"):
        # Preserve alpha only when the source actually carries an alpha band.
        pil = pil.convert("RGBA") if "A" in pil.getbands() else pil.convert("RGB")
    arr = np.array(pil).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
| |
|
| |
|
def _mask_tensor_to_pil(mask: torch.Tensor) -> Image.Image:
    """Convert a [B,H,W] or [H,W] float mask (0..1) to an 'L'-mode PIL image."""
    if mask.ndim == 3:
        mask = mask[0]  # take the first batch element
    arr = (mask.detach().cpu().float().clamp(0, 1).numpy() * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr, mode="L")
| |
|
| |
|
def _pil_to_mask_tensor(pil_l: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a [1,H,W] float mask in 0..1 (converting to 'L')."""
    gray = pil_l if pil_l.mode == "L" else pil_l.convert("L")
    arr = np.array(gray).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
| |
|
| |
|
def _resize_image_lanczos(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """LANCZOS-resize each image of a [B,H,W,C] batch to (w, h)."""
    if img.ndim != 4:
        raise ValueError("Expected IMAGE tensor with shape [B,H,W,C].")
    resample = _pil_lanczos()  # hoisted: same constant for every frame
    resized = []
    for frame in img:
        pil = _image_tensor_to_pil(frame.unsqueeze(0))
        resized.append(_pil_to_image_tensor(pil.resize((int(w), int(h)), resample=resample)))
    return torch.cat(resized, dim=0)
| |
|
| |
|
def _resize_mask_lanczos(mask: torch.Tensor, w: int, h: int) -> torch.Tensor:
    """LANCZOS-resize each mask of a [B,H,W] batch to (w, h)."""
    if mask.ndim != 3:
        raise ValueError("Expected MASK tensor with shape [B,H,W].")
    resample = _pil_lanczos()  # hoisted: same constant for every frame
    resized = []
    for frame in mask:
        pil = _mask_tensor_to_pil(frame.unsqueeze(0))
        resized.append(_pil_to_mask_tensor(pil.resize((int(w), int(h)), resample=resample)))
    return torch.cat(resized, dim=0)
| |
|
| |
|
| | def _rgb_to_rgba_with_comfy_mask(rgb: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: |
| | if rgb.ndim == 3: |
| | rgb = rgb.unsqueeze(0) |
| | if mask.ndim == 2: |
| | mask = mask.unsqueeze(0) |
| |
|
| | if rgb.ndim != 4 or rgb.shape[-1] != 3: |
| | raise ValueError(f"rgb must be [B,H,W,3], got {tuple(rgb.shape)}") |
| | if mask.ndim != 3: |
| | raise ValueError(f"mask must be [B,H,W], got {tuple(mask.shape)}") |
| |
|
| | if mask.shape[0] != rgb.shape[0]: |
| | if mask.shape[0] == 1 and rgb.shape[0] > 1: |
| | mask = mask.expand(rgb.shape[0], -1, -1) |
| | else: |
| | raise ValueError("Batch mismatch between rgb and mask.") |
| |
|
| | if mask.shape[1] != rgb.shape[1] or mask.shape[2] != rgb.shape[2]: |
| | raise ValueError( |
| | f"Mask size mismatch. rgb={rgb.shape[2]}x{rgb.shape[1]} mask={mask.shape[2]}x{mask.shape[1]}" |
| | ) |
| |
|
| | mask = mask.to(device=rgb.device, dtype=rgb.dtype).clamp(0, 1) |
| | alpha = (1.0 - mask).unsqueeze(-1).clamp(0, 1) |
| | rgba = torch.cat([rgb.clamp(0, 1), alpha], dim=-1) |
| | return rgba |
| |
|
| |
|
def _load_checkpoint_cached(ckpt_name: str):
    """Load (model, clip, vae) for *ckpt_name* once; later calls hit the cache.

    The whole load runs under the lock so concurrent callers never load the
    same checkpoint twice.
    """
    with _CKPT_LOCK:
        cached = _CKPT_CACHE.get(ckpt_name)
        if cached is not None:
            return cached
        import nodes
        loader = nodes.CheckpointLoaderSimple()
        load_fn = getattr(loader, loader.FUNCTION)
        model, clip, vae = load_fn(ckpt_name=ckpt_name)
        _CKPT_CACHE[ckpt_name] = (model, clip, vae)
        return _CKPT_CACHE[ckpt_name]
| |
|
| |
|
def _load_controlnet_cached(control_net_name: str):
    """Load a ControlNet by name once; later calls return the cached object."""
    with _CN_LOCK:
        if control_net_name not in _CN_CACHE:
            import nodes
            loader = nodes.ControlNetLoader()
            load_fn = getattr(loader, loader.FUNCTION)
            (controlnet,) = load_fn(control_net_name=control_net_name)
            _CN_CACHE[control_net_name] = controlnet
        return _CN_CACHE[control_net_name]
| |
|
| |
|
def _assets_images_dir() -> Path:
    """Directory containing the plugin's bundled PNG assets."""
    return PLUGIN_ROOT.joinpath("assets", "images")
| |
|
| |
|
def _list_asset_pngs() -> list:
    """List bundled PNGs recursively, as sorted POSIX paths relative to assets/images."""
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        return []
    return sorted(
        p.relative_to(img_dir).as_posix()
        for p in img_dir.rglob("*")
        if p.is_file() and p.suffix.lower() == ".png"
    )
| |
|
| |
|
def _safe_asset_path(asset_rel_path: str) -> Path:
    """Resolve *asset_rel_path* inside assets/images, rejecting path traversal.

    Raises FileNotFoundError when the assets folder or the file is missing,
    and ValueError for absolute paths, traversal attempts, or non-PNG targets.
    """
    img_dir = _assets_images_dir()
    if not img_dir.is_dir():
        raise FileNotFoundError(f"assets/images folder not found: {img_dir}")

    base = img_dir.resolve()
    rel = Path(asset_rel_path)

    if rel.is_absolute():
        raise ValueError("Absolute paths are not allowed for asset_image.")

    full = (base / rel).resolve()

    # resolve() collapses ".." components; the containment check then accepts
    # only paths that are the base itself or strictly inside it.
    if base != full and base not in full.parents:
        raise ValueError(f"Invalid asset path (path traversal blocked): {asset_rel_path}")

    if not full.is_file():
        raise FileNotFoundError(f"Asset PNG not found in assets/images: {asset_rel_path}")
    if full.suffix.lower() != ".png":
        raise ValueError(f"Asset is not a PNG: {asset_rel_path}")

    return full
| |
|
| |
|
def _load_asset_image_and_mask(asset_rel_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a bundled PNG as ([1,H,W,3] RGB tensor, [1,H,W] mask tensor).

    The mask is the inverted alpha channel (1 where the PNG is transparent),
    matching ComfyUI's mask convention.
    """
    path = _safe_asset_path(asset_rel_path)

    # Honor EXIF orientation before any conversion.
    pil = ImageOps.exif_transpose(Image.open(path))
    rgba = pil.convert("RGBA")

    rgb_np = np.array(rgba.convert("RGB"), dtype=np.float32) / 255.0
    image_tensor = torch.from_numpy(rgb_np).unsqueeze(0)

    alpha_np = np.array(rgba.getchannel("A"), dtype=np.float32) / 255.0
    mask_tensor = torch.from_numpy(1.0 - alpha_np).unsqueeze(0)
    return image_tensor, mask_tensor
| |
|
| |
|
# Local cache directory for the depth model; created eagerly at import time.
MODEL_DIR = PLUGIN_ROOT / "assets" / "depth"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Files required for the transformers depth pipeline to load fully offline.
REQUIRED_FILES = {
    "config.json": "https://huggingface.co/saliacoel/depth/resolve/main/config.json",
    "model.safetensors": "https://huggingface.co/saliacoel/depth/resolve/main/model.safetensors",
    "preprocessor_config.json": "https://huggingface.co/saliacoel/depth/resolve/main/preprocessor_config.json",
}

# Hub repo used when the local files cannot be obtained.
ZOE_FALLBACK_REPO_ID = "Intel/zoedepth-nyu-kitti"

# Pipeline cache keyed by (model_source, device string), guarded by a lock.
_PIPE_CACHE: Dict[Tuple[str, str], Any] = {}
_PIPE_LOCK = threading.Lock()
| |
|
| |
|
def _have_required_files() -> bool:
    """True when every required depth-model file already exists locally."""
    return all((MODEL_DIR / fname).exists() for fname in REQUIRED_FILES)
| |
|
| |
|
def _download_url_to_file(url: str, dst: Path, timeout: int = 180) -> None:
    """Download *url* to *dst* atomically: write a .tmp file, then rename."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    tmp = dst.with_suffix(dst.suffix + ".tmp")

    # Drop any stale partial download; ignore failures (missing file, perms).
    try:
        tmp.unlink()
    except Exception:
        pass

    request = urllib.request.Request(url, headers={"User-Agent": "ComfyUI-SaliaDepth/1.1"})
    with urllib.request.urlopen(request, timeout=timeout) as resp, open(tmp, "wb") as out:
        shutil.copyfileobj(resp, out)

    # Atomic publish: readers never observe a half-written destination file.
    tmp.replace(dst)
| |
|
| |
|
def ensure_local_model_files() -> bool:
    """Download any missing required files; return True when all are present.

    Best-effort: any download error simply reports False instead of raising.
    """
    if _have_required_files():
        return True
    try:
        for fname, url in REQUIRED_FILES.items():
            target = MODEL_DIR / fname
            if not target.exists():
                _download_url_to_file(url, target)
        return _have_required_files()
    except Exception:
        return False
| |
|
| |
|
def HWC3(x: np.ndarray) -> np.ndarray:
    """Normalize a uint8 image to 3-channel HWC.

    Grayscale is replicated to 3 channels; RGBA is alpha-composited over a
    white background. 3-channel input is returned unchanged (same object).
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels == 1 or channels == 3 or channels == 4
    if channels == 3:
        return x
    if channels == 1:
        return np.repeat(x, 3, axis=2)
    # RGBA: blend over white using the normalized alpha channel.
    rgb = x[:, :, :3].astype(np.float32)
    alpha = x[:, :, 3:4].astype(np.float32) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)
| |
|
| |
|
def pad64(x: int) -> int:
    """Number of pixels needed to pad *x* up to the next multiple of 64."""
    return (-int(x)) % 64
| |
|
| |
|
def safer_memory(x: np.ndarray) -> np.ndarray:
    """Return a C-contiguous copy of *x* that shares no memory with it.

    Fix: the original copied the array up to three times
    (``x.copy()`` -> ``ascontiguousarray`` -> ``.copy()``); this produces the
    same result (an independent, C-contiguous array) with fewer copies.
    """
    return np.ascontiguousarray(x).copy()
| |
|
| |
|
def resize_image_with_pad_min_side(
    input_image: np.ndarray,
    resolution: int,
    upscale_method: str = "INTER_CUBIC",
    skip_hwc3: bool = False,
    mode: str = "edge",
) -> Tuple[np.ndarray, Any]:
    """Resize so the short side equals *resolution*, then pad to multiples of 64.

    Returns (padded_image, remove_pad) where remove_pad crops an array back to
    the resized, pre-padding size. resolution <= 0 returns the input with an
    identity remove_pad. Uses cv2 when importable, otherwise falls back to PIL.
    """
    cv2 = None
    try:
        import cv2 as _cv2
        cv2 = _cv2
    except Exception:
        cv2 = None

    img = input_image if skip_hwc3 else HWC3(input_image)

    H_raw, W_raw, _ = img.shape
    if resolution <= 0:
        # No resizing requested: hand back the (possibly HWC3'd) image as-is.
        return img, (lambda x: x)

    # Scale factor mapping the shorter side onto `resolution`.
    k = float(resolution) / float(min(H_raw, W_raw))
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))

    if cv2 is not None:
        upscale_methods = {
            "INTER_NEAREST": cv2.INTER_NEAREST,
            "INTER_LINEAR": cv2.INTER_LINEAR,
            "INTER_AREA": cv2.INTER_AREA,
            "INTER_CUBIC": cv2.INTER_CUBIC,
            "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
        }
        method = upscale_methods.get(upscale_method, cv2.INTER_CUBIC)
        # When shrinking (k <= 1), INTER_AREA is used regardless of the
        # requested method.
        img = cv2.resize(img, (W_target, H_target), interpolation=method if k > 1 else cv2.INTER_AREA)
    else:
        pil = Image.fromarray(img)
        # PIL fallback: BICUBIC when enlarging, LANCZOS when shrinking.
        resample = Image.BICUBIC if k > 1 else Image.LANCZOS
        pil = pil.resize((W_target, H_target), resample=resample)
        img = np.array(pil, dtype=np.uint8)

    # Pad bottom/right to the next multiple of 64.
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the resized (pre-padding) dimensions.
        return safer_memory(x[:H_target, :W_target, ...])

    return safer_memory(img_padded), remove_pad
| |
|
| |
|
def pad_only_to_64(img_u8: np.ndarray, mode: str = "edge") -> Tuple[np.ndarray, Any]:
    """Pad an image (bottom/right) to multiples of 64 without resizing.

    Returns (padded_image, remove_pad) where remove_pad crops an array back to
    the original size.
    """
    img = HWC3(img_u8)
    height, width = img.shape[0], img.shape[1]
    pad_h, pad_w = pad64(height), pad64(width)
    padded = np.pad(img, [[0, pad_h], [0, pad_w], [0, 0]], mode=mode)

    def remove_pad(x: np.ndarray) -> np.ndarray:
        # Crop back to the pre-padding size.
        return safer_memory(x[:height, :width, ...])

    return safer_memory(padded), remove_pad
| |
|
| |
|
def composite_rgba_over_white_keep_alpha(inp_u8: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Composite an RGBA uint8 image over white.

    Returns (rgb_u8, alpha_u8) for RGBA input, or (HWC3(input), None) for
    anything else — the caller uses the alpha to re-mask the result later.
    """
    if inp_u8.ndim != 3 or inp_u8.shape[2] != 4:
        # Not RGBA: normalize to 3 channels; no alpha to report.
        return HWC3(inp_u8), None
    rgba = inp_u8.astype(np.uint8)
    rgb = rgba[:, :, :3].astype(np.float32)
    alpha = rgba[:, :, 3:4].astype(np.float32) / 255.0
    over_white = (rgb * alpha + 255.0 * (1.0 - alpha)).clip(0, 255).astype(np.uint8)
    return over_white, rgba[:, :, 3].copy()
| |
|
| |
|
def apply_alpha_then_black_background(depth_rgb_u8: np.ndarray, alpha_u8: np.ndarray) -> np.ndarray:
    """Multiply a depth visualization by alpha so transparent regions go black."""
    depth = HWC3(depth_rgb_u8)
    alpha = (alpha_u8.astype(np.float32) / 255.0)[:, :, None]
    return (depth.astype(np.float32) * alpha).clip(0, 255).astype(np.uint8)
| |
|
| |
|
def comfy_tensor_to_u8(img: torch.Tensor) -> np.ndarray:
    """Convert a [B,H,W,C] or [H,W,C] float tensor (0..1) to a uint8 HWC array."""
    if img.ndim == 4:
        img = img[0]  # take the first batch element
    clamped = img.detach().cpu().float().clamp(0, 1)
    return (clamped.numpy() * 255.0).round().astype(np.uint8)
| |
|
| |
|
def u8_to_comfy_tensor(img_u8: np.ndarray) -> torch.Tensor:
    """Convert a uint8 image to a [1,H,W,3] float tensor in 0..1."""
    rgb = HWC3(img_u8)
    return torch.from_numpy(rgb.astype(np.float32) / 255.0).unsqueeze(0)
| |
|
| |
|
def _try_load_pipeline(model_source: str, device: torch.device):
    """Build (or fetch from cache) a transformers depth-estimation pipeline.

    Raises RuntimeError when transformers failed to import. The cache key
    includes the device so the same model can serve CPU and GPU callers.
    """
    if pipeline is None:
        raise RuntimeError(f"transformers import failed: {_TRANSFORMERS_IMPORT_ERROR}")

    key = (model_source, str(device))
    with _PIPE_LOCK:
        if key in _PIPE_CACHE:
            return _PIPE_CACHE[key]

        p = pipeline(task="depth-estimation", model=model_source)
        # Best-effort device placement; some pipeline versions may not expose
        # these attributes, hence the broad except.
        try:
            p.model = p.model.to(device)
            p.device = device
        except Exception:
            pass

        _PIPE_CACHE[key] = p
        return p
| |
|
| |
|
def get_depth_pipeline(device: torch.device):
    """Return a depth-estimation pipeline, or None when none can be loaded.

    Tries the locally cached model files first, then the public ZoeDepth
    fallback repo; all failures are swallowed and reported as None.
    """
    if ensure_local_model_files():
        try:
            return _try_load_pipeline(str(MODEL_DIR), device)
        except Exception:
            # Local load failed; fall through to the hub fallback below.
            pass
    try:
        return _try_load_pipeline(ZOE_FALLBACK_REPO_ID, device)
    except Exception:
        return None
| |
|
| |
|
def depth_estimate_zoe_style(
    pipe,
    input_rgb_u8: np.ndarray,
    detect_resolution: int,
    upscale_method: str = "INTER_CUBIC",
) -> np.ndarray:
    """Run a depth pipeline ZoeDepth-style; return an inverted uint8 HWC map.

    detect_resolution == -1 pads only (no resize); otherwise the short side is
    resized to *detect_resolution* before padding to a multiple of 64.
    """
    if detect_resolution == -1:
        work_img, remove_pad = pad_only_to_64(input_rgb_u8, mode="edge")
    else:
        work_img, remove_pad = resize_image_with_pad_min_side(
            input_rgb_u8,
            int(detect_resolution),
            upscale_method=upscale_method,
            skip_hwc3=False,
            mode="edge",
        )

    pil_image = Image.fromarray(work_img)

    with torch.no_grad():
        result = pipe(pil_image)
        depth = result["depth"]

    # Fix: the original had an isinstance(depth, Image.Image) branch whose two
    # arms were identical; np.array handles both PIL images and arrays.
    depth_array = np.array(depth, dtype=np.float32)

    # Robust normalization against outliers (2nd..85th percentile window).
    vmin = float(np.percentile(depth_array, 2))
    vmax = float(np.percentile(depth_array, 85))

    depth_array = depth_array - vmin
    denom = vmax - vmin
    if abs(denom) < 1e-12:
        denom = 1e-6  # avoid divide-by-zero on constant depth maps
    depth_array = depth_array / denom

    # Invert so near regions are bright.
    depth_array = 1.0 - depth_array
    depth_image = (depth_array * 255.0).clip(0, 255).astype(np.uint8)

    return remove_pad(HWC3(depth_image))
| |
|
| |
|
def resize_to_original(depth_rgb_u8: np.ndarray, w0: int, h0: int) -> np.ndarray:
    """Bilinear-resize a uint8 image to (w0, h0), preferring cv2 over PIL.

    Any cv2 failure (including it being uninstalled) falls back to PIL.
    """
    try:
        import cv2
        resized = cv2.resize(depth_rgb_u8, (w0, h0), interpolation=cv2.INTER_LINEAR)
        return resized.astype(np.uint8)
    except Exception:
        pil = Image.fromarray(depth_rgb_u8).resize((w0, h0), resample=Image.BILINEAR)
        return np.array(pil, dtype=np.uint8)
| |
|
| |
|
def _salia_depth_execute(image: torch.Tensor, resolution: int = -1) -> torch.Tensor:
    """Depth-estimate a batch of IMAGE tensors.

    Best-effort throughout: when no pipeline is available the input is
    returned unchanged, and a per-image failure falls back to that frame.
    """
    try:
        device = model_management.get_torch_device()
    except Exception:
        device = torch.device("cpu")

    pipe_obj = None
    try:
        pipe_obj = get_depth_pipeline(device)
    except Exception:
        pipe_obj = None

    if pipe_obj is None:
        # No pipeline (transformers missing / downloads failed): pass through.
        return image

    if image.ndim == 3:
        image = image.unsqueeze(0)

    outs = []
    for i in range(image.shape[0]):
        try:
            h0 = int(image[i].shape[0])
            w0 = int(image[i].shape[1])

            inp_u8 = comfy_tensor_to_u8(image[i])

            # RGBA input is flattened over white for the model; the alpha is
            # kept so transparent regions can be blacked out afterwards.
            rgb_for_depth, alpha_u8 = composite_rgba_over_white_keep_alpha(inp_u8)
            had_rgba = alpha_u8 is not None

            depth_rgb = depth_estimate_zoe_style(
                pipe=pipe_obj,
                input_rgb_u8=rgb_for_depth,
                detect_resolution=int(resolution),
                upscale_method="INTER_CUBIC",
            )

            depth_rgb = resize_to_original(depth_rgb, w0=w0, h0=h0)

            if had_rgba:
                # Re-align the alpha to the output size if needed (cv2 first,
                # PIL as fallback, mirroring resize_to_original).
                if alpha_u8.shape[0] != h0 or alpha_u8.shape[1] != w0:
                    try:
                        import cv2
                        alpha_u8 = cv2.resize(alpha_u8, (w0, h0), interpolation=cv2.INTER_LINEAR).astype(np.uint8)
                    except Exception:
                        pil_a = Image.fromarray(alpha_u8)
                        pil_a = pil_a.resize((w0, h0), resample=Image.BILINEAR)
                        alpha_u8 = np.array(pil_a, dtype=np.uint8)

                depth_rgb = apply_alpha_then_black_background(depth_rgb, alpha_u8)

            outs.append(u8_to_comfy_tensor(depth_rgb))
        except Exception:
            # NOTE(review): this fallback appends the raw frame, which may be
            # RGBA (4 channels) while successful frames are RGB (3 channels);
            # torch.cat below would then fail on the channel mismatch — confirm
            # whether mixed-channel batches can actually occur here.
            outs.append(image[i].unsqueeze(0))

    return torch.cat(outs, dim=0)
| |
|
| |
|
| | def _salia_alpha_over_region(base: torch.Tensor, overlay_rgba: torch.Tensor, x: int, y: int) -> torch.Tensor: |
| | if base.ndim != 4 or overlay_rgba.ndim != 4: |
| | raise ValueError("base and overlay must be [B,H,W,C].") |
| |
|
| | B, H, W, C = base.shape |
| | b2, sH, sW, c2 = overlay_rgba.shape |
| | if c2 != 4: |
| | raise ValueError("overlay_rgba must have 4 channels (RGBA).") |
| | if sH != sW: |
| | raise ValueError("overlay must be square.") |
| | s = sH |
| |
|
| | if x < 0 or y < 0 or x + s > W or y + s > H: |
| | raise ValueError(f"Square paste out of bounds. base={W}x{H}, paste at ({x},{y}) size={s}") |
| |
|
| | if b2 != B: |
| | if b2 == 1 and B > 1: |
| | overlay_rgba = overlay_rgba.expand(B, -1, -1, -1) |
| | else: |
| | raise ValueError("Batch mismatch between base and overlay.") |
| |
|
| | out = base.clone() |
| |
|
| | overlay_rgb = overlay_rgba[..., 0:3].clamp(0, 1) |
| | overlay_a = overlay_rgba[..., 3:4].clamp(0, 1) |
| |
|
| | base_rgb = out[:, y:y + s, x:x + s, 0:3] |
| | comp_rgb = overlay_rgb * overlay_a + base_rgb * (1.0 - overlay_a) |
| | out[:, y:y + s, x:x + s, 0:3] = comp_rgb |
| |
|
| | if C == 4: |
| | base_a = out[:, y:y + s, x:x + s, 3:4].clamp(0, 1) |
| | comp_a = overlay_a + base_a * (1.0 - overlay_a) |
| | out[:, y:y + s, x:x + s, 3:4] = comp_a |
| |
|
| | return out.clamp(0, 1) |
| |
|
| |
|
# Model files the Salia node expects to find in ComfyUI's model directories.
_HARDCODED_CKPT_NAME = "SaliaHighlady_Speedy.safetensors"
_HARDCODED_CONTROLNET_NAME = "diffusion_pytorch_model_promax.safetensors"
# ControlNet applied over the full schedule (start 0% .. end 100%).
_HARDCODED_CN_START = 0.00
_HARDCODED_CN_END = 1.00

# Pass 1 sampling settings.
_PASS1_SAMPLER_NAME = "dpmpp_2m_sde_heun_gpu"
_PASS1_SCHEDULER = "karras"
_PASS1_STEPS = 29
_PASS1_CFG = 2.6
_PASS1_CONTROLNET_STRENGTH = 0.33

# Pass 2 sampling settings.
_PASS2_SAMPLER_NAME = "res_multistep_ancestral_cfg_pp"
_PASS2_SCHEDULER = "karras"
_PASS2_STEPS = 30
_PASS2_CFG = 1.7
_PASS2_CONTROLNET_STRENGTH = 0.5
| |
|
| |
|
| | class Salia_ezpz_gated_Duo2: |
| | CATEGORY = "image/salia" |
| | RETURN_TYPES = ("IMAGE", "IMAGE") |
| | RETURN_NAMES = ("image", "image_cropped") |
| | FUNCTION = "run" |
| |
|
    @classmethod
    def INPUT_TYPES(cls):
        """ComfyUI input schema; asset_image options come from bundled PNGs."""
        # Sentinel entry keeps the dropdown non-empty when no PNGs exist.
        assets = _list_asset_pngs() or ["<no pngs found>"]
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),
                "asset_image": (assets, {}),
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }
| |
|
| | def run( |
| | self, |
| | image: torch.Tensor, |
| | trigger_string: str = "", |
| | X_coord: int = 0, |
| | Y_coord: int = 0, |
| | positive_prompt: str = "", |
| | negative_prompt: str = "", |
| | asset_image: str = "", |
| | square_size_1: int = 384, |
| | upscale_factor_1: str = "4", |
| | denoise_1: float = 0.35, |
| | square_size_2: int = 384, |
| | upscale_factor_2: str = "4", |
| | denoise_2: float = 0.35, |
| | ): |
| | if image.ndim == 3: |
| | image = image.unsqueeze(0) |
| | if image.ndim != 4: |
| | raise ValueError("Input image must be [B,H,W,C].") |
| |
|
| | B, H, W, C = image.shape |
| | if C not in (3, 4): |
| | raise ValueError("Input image must have 3 (RGB) or 4 (RGBA) channels.") |
| |
|
| | x = int(X_coord) |
| | y = int(Y_coord) |
| | s1 = int(square_size_1) |
| | s2 = int(square_size_2) |
| |
|
| | def _validate_square_bounds(s: int, label: str): |
| | if s <= 0: |
| | raise ValueError(f"{label}: square_size must be > 0") |
| | if x < 0 or y < 0 or x + s > W or y + s > H: |
| | raise ValueError(f"{label}: out of bounds. image={W}x{H}, rect at ({x},{y}) size={s}") |
| |
|
| | def _validate_upscale(up: int, s: int, label: str): |
| | if up not in (1, 2, 4, 6, 8, 10, 12, 14, 16): |
| | raise ValueError(f"{label}: upscale_factor must be one of 1,2,4,6,8,10,12,14,16") |
| | if ((s * up) % 8) != 0: |
| | raise ValueError(f"{label}: square_size * upscale_factor must be divisible by 8 (VAE requirement).") |
| |
|
| | def _crop_square(img: torch.Tensor, s: int) -> torch.Tensor: |
| | return img[:, y:y + s, x:x + s, :] |
| |
|
| | _validate_square_bounds(s2, "final crop (square_size_2)") |
| |
|
| | if trigger_string == "": |
| | out2 = image |
| | cropped = _crop_square(out2, s2) |
| | return (out2, cropped) |
| |
|
| | _validate_square_bounds(s1, "pass1 (square_size_1)") |
| | _validate_square_bounds(s2, "pass2 (square_size_2)") |
| |
|
| | up1 = int(upscale_factor_1) |
| | up2 = int(upscale_factor_2) |
| | _validate_upscale(up1, s1, "pass1") |
| | _validate_upscale(up2, s2, "pass2") |
| |
|
| | d1 = float(max(0.0, min(1.0, denoise_1))) |
| | d2 = float(max(0.0, min(1.0, denoise_2))) |
| |
|
| | if asset_image == "<no pngs found>": |
| | raise FileNotFoundError("No PNGs found in assets/images for this plugin.") |
| | _asset_img_unused, asset_mask = _load_asset_image_and_mask(asset_image) |
| |
|
| | if asset_mask.ndim == 2: |
| | asset_mask = asset_mask.unsqueeze(0) |
| | if asset_mask.ndim != 3: |
| | raise ValueError("Asset mask must be [B,H,W].") |
| |
|
| | if asset_mask.shape[0] != B: |
| | if asset_mask.shape[0] == 1 and B > 1: |
| | asset_mask = asset_mask.expand(B, -1, -1) |
| | else: |
| | raise ValueError("Batch mismatch for asset mask vs input image batch.") |
| |
|
| | import nodes |
| |
|
| | try: |
| | model, clip, vae = _load_checkpoint_cached(_HARDCODED_CKPT_NAME) |
| | except Exception as e: |
| | available = folder_paths.get_filename_list("checkpoints") or [] |
| | raise FileNotFoundError( |
| | f"Hardcoded ckpt not found: '{_HARDCODED_CKPT_NAME}'. " |
| | f"Put it in models/checkpoints. Available (first 50): {available[:50]}" |
| | ) from e |
| |
|
| | try: |
| | controlnet = _load_controlnet_cached(_HARDCODED_CONTROLNET_NAME) |
| | except Exception as e: |
| | available = folder_paths.get_filename_list("controlnet") or [] |
| | raise FileNotFoundError( |
| | f"Hardcoded controlnet not found: '{_HARDCODED_CONTROLNET_NAME}'. " |
| | f"Put it in models/controlnet. Available (first 50): {available[:50]}" |
| | ) from e |
| |
|
| | pos_enc = nodes.CLIPTextEncode() |
| | neg_enc = nodes.CLIPTextEncode() |
| | pos_fn = getattr(pos_enc, pos_enc.FUNCTION) |
| | neg_fn = getattr(neg_enc, neg_enc.FUNCTION) |
| | (pos_cond,) = pos_fn(text=str(positive_prompt), clip=clip) |
| | (neg_cond,) = neg_fn(text=str(negative_prompt), clip=clip) |
| |
|
| | cn_apply = nodes.ControlNetApplyAdvanced() |
| | cn_fn = getattr(cn_apply, cn_apply.FUNCTION) |
| | vae_enc = nodes.VAEEncode() |
| | vae_enc_fn = getattr(vae_enc, vae_enc.FUNCTION) |
| | ksampler = nodes.KSampler() |
| | k_fn = getattr(ksampler, ksampler.FUNCTION) |
| | vae_dec = nodes.VAEDecode() |
| | vae_dec_fn = getattr(vae_dec, vae_dec.FUNCTION) |
| |
|
| | def _run_pass( |
| | pass_index: int, |
| | in_image: torch.Tensor, |
| | s: int, |
| | up: int, |
| | denoise_v: float, |
| | steps_v: int, |
| | cfg_v: float, |
| | sampler_v: str, |
| | scheduler_v: str, |
| | controlnet_strength_v: float, |
| | ) -> torch.Tensor: |
| | up_w = s * up |
| | up_h = s * up |
| |
|
| | crop = in_image[:, y:y + s, x:x + s, :] |
| | crop_rgb = crop[:, :, :, 0:3].contiguous() |
| |
|
| | depth_small = _salia_depth_execute(crop_rgb, resolution=s) |
| | depth_up = _resize_image_lanczos(depth_small, up_w, up_h) |
| |
|
| | crop_up = _resize_image_lanczos(crop_rgb, up_w, up_h) |
| |
|
| | asset_mask_up = _resize_mask_lanczos(asset_mask, up_w, up_h) |
| |
|
| | pos_cn, neg_cn = cn_fn( |
| | strength=float(controlnet_strength_v), |
| | start_percent=float(_HARDCODED_CN_START), |
| | end_percent=float(_HARDCODED_CN_END), |
| | positive=pos_cond, |
| | negative=neg_cond, |
| | control_net=controlnet, |
| | image=depth_up, |
| | vae=vae, |
| | ) |
| |
|
| | (latent,) = vae_enc_fn(pixels=crop_up, vae=vae) |
| |
|
| | seed_material = ( |
| | f"{_HARDCODED_CKPT_NAME}|{_HARDCODED_CONTROLNET_NAME}|{asset_image}|" |
| | f"pass={pass_index}|x={x}|y={y}|s={s}|up={up}|" |
| | f"steps={steps_v}|cfg={cfg_v}|sampler={sampler_v}|scheduler={scheduler_v}|denoise={denoise_v}|" |
| | f"cn_strength={controlnet_strength_v}|" |
| | f"{positive_prompt}|{negative_prompt}" |
| | ).encode("utf-8", errors="ignore") |
| | seed64 = int(hashlib.sha256(seed_material).hexdigest()[:16], 16) |
| |
|
| | (sampled_latent,) = k_fn( |
| | seed=seed64, |
| | steps=int(steps_v), |
| | cfg=float(cfg_v), |
| | sampler_name=str(sampler_v), |
| | scheduler=str(scheduler_v), |
| | denoise=float(denoise_v), |
| | model=model, |
| | positive=pos_cn, |
| | negative=neg_cn, |
| | latent_image=latent, |
| | ) |
| |
|
| | (decoded_rgb,) = vae_dec_fn(samples=sampled_latent, vae=vae) |
| |
|
| | rgba_up = _rgb_to_rgba_with_comfy_mask(decoded_rgb, asset_mask_up) |
| | rgba_square = _resize_image_lanczos(rgba_up, s, s) |
| | out = _salia_alpha_over_region(in_image, rgba_square, x=x, y=y) |
| | return out |
| |
|
| | out1 = _run_pass( |
| | pass_index=1, |
| | in_image=image, |
| | s=s1, |
| | up=up1, |
| | denoise_v=d1, |
| | steps_v=_PASS1_STEPS, |
| | cfg_v=_PASS1_CFG, |
| | sampler_v=_PASS1_SAMPLER_NAME, |
| | scheduler_v=_PASS1_SCHEDULER, |
| | controlnet_strength_v=_PASS1_CONTROLNET_STRENGTH, |
| | ) |
| |
|
| | out2 = _run_pass( |
| | pass_index=2, |
| | in_image=out1, |
| | s=s2, |
| | up=up2, |
| | denoise_v=d2, |
| | steps_v=_PASS2_STEPS, |
| | cfg_v=_PASS2_CFG, |
| | sampler_v=_PASS2_SAMPLER_NAME, |
| | scheduler_v=_PASS2_SCHEDULER, |
| | controlnet_strength_v=_PASS2_CONTROLNET_STRENGTH, |
| | ) |
| |
|
| | cropped = out2[:, y:y + s2, x:x + s2, :] |
| | return (out2, cropped) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
# Root of this node pack's bundled PNG assets. realpath() resolves symlinks so
# ap4_safe_path() can compare canonical paths when rejecting path traversal.
_AP4_ASSETS_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets", "images")
| |
|
| |
|
def ap4_list_pngs() -> List[str]:
    """Return sorted, forward-slash relative paths of every PNG under the assets dir.

    Returns an empty list when the assets directory does not exist. The
    ``.png`` check is case-insensitive.
    """
    if not os.path.isdir(_AP4_ASSETS_DIR):
        return []
    found: List[str] = []
    for dirpath, _dirs, names in os.walk(_AP4_ASSETS_DIR):
        for name in names:
            if not name.lower().endswith(".png"):
                continue
            full_path = os.path.join(dirpath, name)
            if not os.path.isfile(full_path):
                continue
            rel_path = os.path.relpath(full_path, _AP4_ASSETS_DIR)
            # Normalize Windows separators so dropdown values are stable.
            found.append(rel_path.replace("\\", "/"))
    found.sort()
    return found
| |
|
| |
|
def ap4_safe_path(filename: str) -> str:
    """Resolve *filename* inside the assets dir, rejecting traversal escapes.

    Returns the canonical (symlink-resolved) absolute path. Raises
    ValueError when the resolved path lands outside the assets directory.
    """
    assets_root = os.path.realpath(_AP4_ASSETS_DIR)
    resolved = os.path.realpath(os.path.join(_AP4_ASSETS_DIR, filename))
    # Accept the root itself or anything strictly below it; the os.sep suffix
    # prevents sibling-prefix escapes like ".../imagesEvil".
    inside = resolved == assets_root or resolved.startswith(assets_root + os.sep)
    if not inside:
        raise ValueError("Unsafe path (path traversal detected).")
    return resolved
| |
|
| |
|
def ap4_file_hash(filename: str) -> str:
    """Return the SHA-256 hex digest of an asset file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(ap4_safe_path(filename), "rb") as fh:
        while chunk := fh.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest()
| |
|
| |
|
def ap4_load_image_from_assets(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an asset PNG as ``(image [1,H,W,3], mask [1,H,W])`` float tensors.

    The mask is the inverted alpha channel (ComfyUI convention: 1 = masked);
    when the image has no alpha band, inverted luminance is used instead.
    """
    pil_img = Image.open(ap4_safe_path(filename))
    pil_img = ImageOps.exif_transpose(pil_img)

    # "I" mode stores 32-bit ints; scale by 1/255 before the RGB conversion.
    if pil_img.mode == "I":
        pil_img = pil_img.point(lambda px: px * (1 / 255))

    rgb_np = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0
    image = torch.from_numpy(rgb_np)[None, ...]

    if "A" in pil_img.getbands():
        alpha_np = np.array(pil_img.getchannel("A")).astype(np.float32) / 255.0
    else:
        # No alpha channel: fall back to the grayscale intensity.
        alpha_np = np.array(pil_img.convert("L")).astype(np.float32) / 255.0
    alpha = torch.from_numpy(alpha_np)

    mask = (1.0 - alpha).clamp(0.0, 1.0).unsqueeze(0)
    return image, mask
| |
|
| |
|
def ap4_as_image(img: torch.Tensor) -> torch.Tensor:
    """Validate that *img* is a ComfyUI IMAGE tensor [B,H,W,C] with C in {3, 4}.

    Returns the tensor unchanged; raises TypeError/ValueError otherwise.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError("IMAGE must be a torch.Tensor")
    if img.dim() != 4:
        raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
    channels = img.shape[-1]
    if channels not in (3, 4):
        raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={channels}")
    return img
| |
|
| |
|
def ap4_as_mask(mask: torch.Tensor) -> torch.Tensor:
    """Validate a MASK tensor, promoting [H,W] to [1,H,W].

    Returns a [B,H,W] tensor; raises TypeError/ValueError otherwise.
    """
    if not isinstance(mask, torch.Tensor):
        raise TypeError("MASK must be a torch.Tensor")
    if mask.dim() == 2:
        return mask.unsqueeze(0)
    if mask.dim() == 3:
        return mask
    raise ValueError(f"Expected MASK shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
| |
|
| |
|
def ap4_ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Return *img* with an alpha channel, appending opaque alpha to RGB input."""
    img = ap4_as_image(img)
    if img.shape[-1] != 4:
        b, h, w, _ = img.shape
        opaque = torch.ones((b, h, w, 1), device=img.device, dtype=img.dtype)
        img = torch.cat([img, opaque], dim=-1)
    return img
| |
|
| |
|
def ap4_alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """Composite *overlay* onto *canvas* at (x, y) with "over" alpha blending.

    Both inputs are ComfyUI IMAGE tensors ([B,H,W,3] or [B,H,W,4]); a batch
    of 1 is broadcast against the other operand's batch. The overlay may
    extend past any canvas edge — only the intersecting region is blended.
    Blending is done in premultiplied-alpha space. Returns a new tensor with
    the canvas's channel count (alpha is dropped for an RGB canvas).
    """
    overlay = ap4_as_image(overlay)
    canvas = ap4_as_image(canvas)

    # Broadcast a singleton batch against the other operand.
    if overlay.shape[0] != canvas.shape[0]:
        if overlay.shape[0] == 1 and canvas.shape[0] > 1:
            overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
        elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
            canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")

    _, Hc, Wc, Cc = canvas.shape
    _, Ho, Wo, _ = overlay.shape

    x = int(x)
    y = int(y)

    # clone() also materializes a possibly-expanded canvas so the in-place
    # region write below is safe.
    out = canvas.clone()

    # Clip the overlay rectangle to the canvas bounds (canvas coordinates).
    x0c = max(0, x)
    y0c = max(0, y)
    x1c = min(Wc, x + Wo)
    y1c = min(Hc, y + Ho)

    # No intersection: overlay lies entirely off-canvas.
    if x1c <= x0c or y1c <= y0c:
        return out

    # Matching rectangle in overlay coordinates.
    x0o = x0c - x
    y0o = y0c - y
    x1o = x0o + (x1c - x0c)
    y1o = y0o + (y1c - y0c)

    canvas_region = out[:, y0c:y1c, x0c:x1c, :]
    overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]

    # Promote both regions to RGBA (RGB gets opaque alpha).
    canvas_rgba = ap4_ensure_rgba(canvas_region)
    overlay_rgba = ap4_ensure_rgba(overlay_region)

    over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
    over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)

    under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
    under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)

    # Premultiply, then apply the Porter-Duff "over" operator.
    over_pm = over_rgb * over_a
    under_pm = under_rgb * under_a

    out_a = over_a + under_a * (1.0 - over_a)
    out_pm = over_pm + under_pm * (1.0 - over_a)

    # Un-premultiply; eps guards against division by a fully transparent result.
    eps = 1e-6
    out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
    out_rgb = out_rgb.clamp(0.0, 1.0)
    out_a = out_a.clamp(0.0, 1.0)

    # Write back with the canvas's original channel count.
    if Cc == 3:
        out[:, y0c:y1c, x0c:x1c, :] = out_rgb
    else:
        out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)

    return out
| |
|
| |
|
class AP4_AILab_MaskCombiner_Exact:
    """Local replica of an AILab/RMBG-style mask-combiner node.

    The "_Exact" suffix indicates this intentionally mirrors the upstream
    node's behavior, including its quirks: "combine" folds max() over every
    provided mask, while "intersection" (min) and the fallback difference
    mode (abs diff) only consider the first two masks. Presumably that
    asymmetry matches the upstream implementation — confirm before changing.
    """

    def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
        """Combine up to four masks; returns a 1-tuple with the clamped result.

        With fewer than two masks, returns the single mask (or a 64x64 zero
        mask when none are given) unmodified.
        """
        masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
        if len(masks) <= 1:
            return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)

        # All masks are resized to match the first mask's shape.
        ref_shape = masks[0].shape
        masks = [self._resize_if_needed(m, ref_shape) for m in masks]

        if mode == "combine":
            # Union: element-wise max across all masks.
            result = torch.maximum(masks[0], masks[1])
            for mask in masks[2:]:
                result = torch.maximum(result, mask)
        elif mode == "intersection":
            # Note: only the first two masks participate.
            result = torch.minimum(masks[0], masks[1])
        else:
            # Fallback: absolute difference of the first two masks.
            result = torch.abs(masks[0] - masks[1])

        return (torch.clamp(result, 0, 1),)

    def _resize_if_needed(self, mask, target_shape):
        """Resize *mask* to *target_shape*'s H/W via per-slice PIL LANCZOS.

        Returns the input unchanged when shapes already match; otherwise the
        result is a CPU float32 tensor (values round-trip through uint8).
        """
        if mask.shape == target_shape:
            return mask

        # Normalize to [B,H,W] before resizing.
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        elif len(mask.shape) == 4:
            mask = mask.squeeze(1)

        target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
        target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]

        resized_masks = []
        for i in range(mask.shape[0]):
            mask_np = mask[i].cpu().numpy()
            img = Image.fromarray((mask_np * 255).astype(np.uint8))
            img_resized = img.resize((target_width, target_height), Image.LANCZOS)
            mask_resized = np.array(img_resized).astype(np.float32) / 255.0
            resized_masks.append(torch.from_numpy(mask_resized))

        return torch.stack(resized_masks)
| |
|
| |
|
def ap4_resize_mask_comfy(alpha_mask: torch.Tensor, image_shape_hwc: Tuple[int, int, int]) -> torch.Tensor:
    """Bilinearly resize a [..,H,W] mask to the (H, W) of *image_shape_hwc*.

    Leading dimensions are flattened into the batch and restored as a single
    batch axis, so the result is [B,H,W].
    """
    target_h = int(image_shape_hwc[0])
    target_w = int(image_shape_hwc[1])
    flat = alpha_mask.reshape((-1, 1, alpha_mask.shape[-2], alpha_mask.shape[-1]))
    resized = F.interpolate(flat, size=(target_h, target_w), mode="bilinear")
    return resized.squeeze(1)
| |
|
| |
|
def ap4_join_image_with_alpha_comfy(image: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Attach *alpha* (ComfyUI mask, 1 = masked) to *image* as an RGBA batch.

    The mask is resized to the image resolution and inverted back into an
    alpha channel; the output batch is min(len(image), len(alpha)).
    """
    image = ap4_as_image(image)
    alpha = ap4_as_mask(alpha).to(device=image.device, dtype=image.dtype)

    # Mask convention is inverted relative to alpha: alpha = 1 - mask.
    alpha_resized = 1.0 - ap4_resize_mask_comfy(alpha, image.shape[1:])

    count = min(len(image), len(alpha))
    joined = [
        torch.cat((image[i][:, :, :3], alpha_resized[i].unsqueeze(2)), dim=2)
        for i in range(count)
    ]
    return torch.stack(joined)
| |
|
| |
|
def ap4_try_get_comfy_model_management():
    """Return the comfy.model_management module if importable, else None."""
    try:
        import comfy.model_management as mm
    except Exception:
        return None
    return mm
| |
|
| |
|
def ap4_gaussian_kernel_1d(kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Return a normalized (sums to 1) 1-D Gaussian kernel of length *kernel_size*."""
    half = (kernel_size - 1) / 2.0
    offsets = torch.arange(kernel_size, device=device, dtype=dtype) - half
    weights = torch.exp(-(offsets * offsets) / (2.0 * sigma * sigma))
    return weights / weights.sum()
| |
|
| |
|
def ap4_mask_blur(mask: torch.Tensor, amount: int = 8, device: str = "gpu") -> torch.Tensor:
    """Gaussian-blur a mask with a separable (two 1-D passes) kernel.

    Args:
        mask: MASK tensor ([H,W] promoted to [B,H,W]); values clamped to [0,1].
        amount: approximate kernel size; rounded up to the next odd value.
            Values <= 0 return the validated, clamped mask unchanged.
        device: "gpu"/"cpu" selects the processing device (through comfy's
            model management when available); any other value keeps the
            mask's own device.

    Returns:
        Blurred mask clamped to [0,1], in the input dtype, moved to comfy's
        intermediate device when model management is available.
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)

    # Guard non-positive amounts: 0 was already a no-op, and a negative value
    # would produce a negative pad/kernel size and crash F.pad / conv2d.
    if amount <= 0:
        return mask

    # Force an odd kernel size so the blur stays centered.
    k = int(amount)
    if k % 2 == 0:
        k += 1

    # Sigma derived from kernel size (OpenCV getGaussianKernel heuristic).
    sigma = 0.3 * (((k - 1) * 0.5) - 1.0) + 0.8

    mm = ap4_try_get_comfy_model_management()

    # Pick the processing device...
    if device == "gpu":
        if mm is not None:
            proc_device = mm.get_torch_device()
        else:
            proc_device = torch.device("cuda") if torch.cuda.is_available() else mask.device
    elif device == "cpu":
        proc_device = torch.device("cpu")
    else:
        proc_device = mask.device

    # ...and the output device (comfy's intermediate device when managed).
    out_device = mask.device
    if device in ("gpu", "cpu") and mm is not None:
        out_device = mm.intermediate_device()

    orig_dtype = mask.dtype
    x = mask.to(device=proc_device, dtype=torch.float32)

    _, H, W = x.shape
    pad = k // 2

    # "reflect" padding requires the image to be larger than the pad width.
    pad_mode = "reflect" if (H > pad and W > pad and H > 1 and W > 1) else "replicate"

    x4 = x.unsqueeze(1)
    x4 = F.pad(x4, (pad, pad, pad, pad), mode=pad_mode)

    # Separable Gaussian: one horizontal pass, then one vertical pass.
    kern1d = ap4_gaussian_kernel_1d(k, sigma, device=proc_device, dtype=torch.float32)
    w_h = kern1d.view(1, 1, 1, k)
    w_v = kern1d.view(1, 1, k, 1)

    x4 = F.conv2d(x4, w_h)
    x4 = F.conv2d(x4, w_v)

    out = x4.squeeze(1).clamp(0.0, 1.0)
    return out.to(device=out_device, dtype=orig_dtype)
| |
|
| |
|
def ap4_dilate_mask(mask: torch.Tensor, dilation: int = 3) -> torch.Tensor:
    """Dilate (positive) or erode (negative) a mask via stride-1 max-pooling.

    Args:
        mask: MASK tensor ([H,W] promoted to [B,H,W]); values clamped to [0,1].
        dilation: structuring-element size; the sign selects dilate vs.
            erode, 0 is a no-op. Even sizes are rounded up to the next odd
            value so the output keeps the input's spatial shape.

    Returns:
        Mask of the same [B,H,W] shape, clamped to [0,1].
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)
    dilation = int(dilation)
    if dilation == 0:
        return mask

    k = abs(dilation)
    # Bug fix: an even pooling kernel with stride 1 and padding k//2 grows
    # the output by one pixel per axis; bump to odd to preserve shape.
    if k % 2 == 0:
        k += 1
    x = mask.unsqueeze(1)

    if dilation > 0:
        # Neighborhood max grows foreground regions.
        y = F.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2)
    else:
        # Erosion as the dual operation: min(x) == -max(-x).
        y = -F.max_pool2d(-x, kernel_size=k, stride=1, padding=k // 2)

    return y.squeeze(1).clamp(0.0, 1.0)
| |
|
| |
|
def ap4_fill_holes_grayscale_numpy_heap(f: np.ndarray, connectivity: int = 8) -> np.ndarray:
    """Fill holes in a grayscale [H,W] image using only numpy + heapq.

    For every pixel this computes the minimum, over all paths from the image
    border, of the maximum value along that path (a Dijkstra-style minimax
    flood seeded at the border). For binary masks this is classic hole
    filling: interior regions enclosed by higher values are raised to the
    enclosing level. Serves as the fallback for scikit-image's grayscale
    reconstruction-by-erosion in ap4_fill_holes_mask. Returns float32,
    values in [0, 1].
    """
    f = np.clip(f, 0.0, 1.0).astype(np.float32, copy=False)
    H, W = f.shape
    if H == 0 or W == 0:
        return f

    # cost[y, x] = lowest "max along a border-connected path" found so far.
    cost = np.full((H, W), np.inf, dtype=np.float32)
    finalized = np.zeros((H, W), dtype=np.bool_)
    heap: List[Tuple[float, int, int]] = []

    def push(y: int, x: int):
        # Seed a border pixel with its own value as the initial path cost.
        c = float(f[y, x])
        if c < float(cost[y, x]):
            cost[y, x] = c
            heapq.heappush(heap, (c, y, x))

    # Seed the flood from every border pixel (guards avoid double-push on
    # 1-pixel-wide images).
    for x in range(W):
        push(0, x)
        if H > 1:
            push(H - 1, x)
    for y in range(H):
        push(y, 0)
        if W > 1:
            push(y, W - 1)

    if connectivity == 4:
        neigh = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        neigh = [(-1, -1), (-1, 0), (-1, 1),
                 (0, -1), (0, 1),
                 (1, -1), (1, 0), (1, 1)]

    eps = 1e-8
    while heap:
        c, y, x = heapq.heappop(heap)

        # Skip stale heap entries (a cheaper path was found after this push).
        if finalized[y, x]:
            continue
        if c > float(cost[y, x]) + eps:
            continue

        finalized[y, x] = True

        for dy, dx in neigh:
            ny = y + dy
            nx = x + dx
            if ny < 0 or ny >= H or nx < 0 or nx >= W:
                continue
            if finalized[ny, nx]:
                continue

            # Extending the path: its cost is the max value seen so far.
            v = float(f[ny, nx])
            nc = c if c >= v else v
            if nc < float(cost[ny, nx]) - eps:
                cost[ny, nx] = nc
                heapq.heappush(heap, (nc, ny, nx))

    return cost
| |
|
| |
|
def ap4_fill_holes_mask(mask: torch.Tensor) -> torch.Tensor:
    """Fill enclosed holes in each mask of a [B,H,W] batch.

    Prefers scikit-image's grayscale reconstruction by erosion (seeded at
    the image border); when scikit-image is unavailable or reconstruction
    fails, falls back to the pure numpy/heapq minimax flood. The result is
    returned on the input's device and dtype, clamped to [0, 1].
    """
    mask = ap4_as_mask(mask).clamp(0.0, 1.0)

    B, H, W = mask.shape
    device = mask.device
    dtype = mask.dtype

    # Work on a contiguous float32 CPU copy; results are filled per batch item.
    mask_np = np.ascontiguousarray(mask.detach().cpu().numpy().astype(np.float32, copy=False))
    filled_np = np.empty_like(mask_np)

    try:
        from skimage.morphology import reconstruction
        footprint = np.ones((3, 3), dtype=bool)

        for b in range(B):
            f = mask_np[b]
            seed = f.copy()

            # Reconstruction-by-erosion seed: 1.0 everywhere except the
            # border, which keeps the original values (seed >= mask holds).
            if H > 2 and W > 2:
                seed[1:-1, 1:-1] = 1.0
            else:
                # Degenerate (<=2 px wide) images: keep the whole border.
                seed[:, :] = 1.0
                seed[0, :] = f[0, :]
                seed[-1, :] = f[-1, :]
                seed[:, 0] = f[:, 0]
                seed[:, -1] = f[:, -1]

            filled_np[b] = reconstruction(seed, f, method="erosion", footprint=footprint).astype(np.float32)

    # Deliberate best-effort fallback: any failure (skimage missing, or
    # reconstruction rejecting the inputs) switches to the pure-numpy path.
    except Exception:
        for b in range(B):
            filled_np[b] = ap4_fill_holes_grayscale_numpy_heap(mask_np[b], connectivity=8)

    out = torch.from_numpy(filled_np).to(device=device, dtype=dtype)
    return out.clamp(0.0, 1.0)
| |
|
| |
|
class apply_segment_4:
    """ComfyUI node: composite a segmented image onto a canvas via an asset mask.

    Pipeline (see run): clean up the input mask (blur -> dilate -> fill
    holes), invert it, union it with the inverted mask of a bundled asset
    PNG, attach the combined mask to *img* as alpha, then alpha-composite
    the result onto *canvas* at (x, y).
    """
    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        # "image" is a dropdown of bundled asset PNGs (relative paths).
        choices = ap4_list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),
                "img": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Final_Image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Execute the node; returns a 1-tuple with the composited IMAGE."""
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images next to this node")

        mask_in = ap4_as_mask(mask).clamp(0.0, 1.0)

        # Mask cleanup: soften edges, grow slightly, then fill enclosed holes.
        blurred = ap4_mask_blur(mask_in, amount=8, device="gpu")
        dilated = ap4_dilate_mask(blurred, dilation=3)
        filled = ap4_fill_holes_mask(dilated)

        inversed_mask = 1.0 - filled

        # Only the asset's mask is used; its RGB content is discarded.
        _asset_img, loaded_mask = ap4_load_image_from_assets(image)

        combiner = AP4_AILab_MaskCombiner_Exact()

        # Combiner resizes through PIL, so move both masks to CPU first.
        inv_cpu = inversed_mask.detach().cpu()
        loaded_cpu = ap4_as_mask(loaded_mask).detach().cpu()

        # Union (max) of the inverted segment mask and the inverted asset mask.
        (alpha_mask,) = combiner.combine_masks(inv_cpu, mode="combine", mask_2=(1.0 - loaded_cpu))
        alpha_mask = torch.clamp(alpha_mask, 0.0, 1.0)

        # Attach the combined mask to *img* as its alpha channel.
        alpha_image = ap4_join_image_with_alpha_comfy(img, alpha_mask)

        canvas = ap4_as_image(canvas)
        alpha_image = alpha_image.to(device=canvas.device, dtype=canvas.dtype)
        final = ap4_alpha_over_region(alpha_image, canvas, x, y)

        return (final,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Re-run when the selected asset file's content changes (sha256).
        if image == "<no pngs found>":
            return image
        return ap4_file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        # Returns True when valid, or an error string for the ComfyUI UI.
        if image == "<no pngs found>":
            return "No PNGs found in assets/images next to this node"
        try:
            path = ap4_safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
| |
|
| |
|
| | |
| | |
| | |
| |
|
class SAM3Segment_Salia:
    """Composite ComfyUI node chaining three sibling nodes in sequence:

    1. Salia_ezpz_gated_Duo2.run — two-pass gated upscale/denoise of a square
       region; its cropped output feeds step 2.
    2. SAM3Segment.segment — SAM3 text-prompted segmentation of the crop.
    3. apply_segment_4.run — composites the segmented crop back onto the
       original image through an asset mask at (X_coord, Y_coord).

    An empty trigger_string bypasses the whole pipeline and returns the
    input image unchanged.
    """
    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Final_Image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        # Two independent asset dropdowns: one for the Duo2 stage, one for
        # the apply_segment_4 stage.
        salia_assets = _list_asset_pngs() or ["<no pngs found>"]
        ap4_assets = ap4_list_pngs() or ["<no pngs found>"]
        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]
        return {
            "required": {
                "image": ("IMAGE",),
                "trigger_string": ("STRING", {"default": ""}),

                # Top-left corner of the processed square region.
                "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
                "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),

                # positive/negative prompts drive the Duo2 diffusion passes;
                # "prompt" is the SAM3 segmentation text prompt.
                "positive_prompt": ("STRING", {"default": "", "multiline": True}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True}),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "SAM3 prompt"}),

                "asset_image": (salia_assets, {}),
                "apply_asset_image": (ap4_assets, {}),

                # Per-pass Duo2 parameters.
                "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_1": (upscale_choices, {"default": "4"}),
                "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),

                "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
                "upscale_factor_2": (upscale_choices, {"default": "4"}),
                "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            }
        }

    def __init__(self):
        # Delegate node instances; each stage keeps its own configuration.
        self._sam3 = SAM3Segment()
        self._salia = Salia_ezpz_gated_Duo2()
        self._ap4 = apply_segment_4()

    def run(
        self,
        image,
        trigger_string="",
        X_coord=0,
        Y_coord=0,
        positive_prompt="",
        negative_prompt="",
        prompt="",
        asset_image="",
        apply_asset_image="",
        square_size_1=384,
        upscale_factor_1="4",
        denoise_1=0.35,
        square_size_2=384,
        upscale_factor_2="4",
        denoise_2=0.35,
    ):
        """Execute the three-stage pipeline; returns a 1-tuple IMAGE."""
        # Gate: empty trigger means pass the input through untouched.
        if trigger_string == "":
            return (image,)

        # Stage 1: two-pass region refinement; only the crop is used below.
        _out_image, image_cropped = self._salia.run(
            image=image,
            trigger_string=trigger_string,
            X_coord=int(X_coord),
            Y_coord=int(Y_coord),
            positive_prompt=str(positive_prompt),
            negative_prompt=str(negative_prompt),
            asset_image=str(asset_image),
            square_size_1=int(square_size_1),
            upscale_factor_1=str(upscale_factor_1),
            denoise_1=float(denoise_1),
            square_size_2=int(square_size_2),
            upscale_factor_2=str(upscale_factor_2),
            denoise_2=float(denoise_2),
        )

        # Stage 2: SAM3 segmentation of the refined crop with fixed settings
        # (alpha background, 0.5 confidence, model kept loaded).
        seg_image, seg_mask, _mask_image = self._sam3.segment(
            image=image_cropped,
            prompt=str(prompt),
            sam3_model="sam3",
            device="GPU",
            confidence_threshold=0.50,
            mask_blur=0,
            mask_offset=0,
            invert_output=False,
            unload_model=False,
            background="Alpha",
            background_color="#222222",
        )

        # Stage 3: composite the segmented crop back onto the original image
        # at the same (X_coord, Y_coord) through the selected asset mask.
        (final_image,) = self._ap4.run(
            mask=seg_mask,
            image=str(apply_asset_image),
            img=seg_image,
            canvas=image,
            x=int(X_coord),
            y=int(Y_coord),
        )

        return (final_image,)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# ComfyUI registration: node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "SAM3Segment": SAM3Segment,
    "Salia_ezpz_gated_Duo2": Salia_ezpz_gated_Duo2,
    "apply_segment_4": apply_segment_4,
    "SAM3Segment_Salia": SAM3Segment_Salia,
}

# Node id -> label shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SAM3Segment": "SAM3 Segmentation (RMBG)",
    "Salia_ezpz_gated_Duo2": "Salia_ezpz_gated_Duo2",
    "apply_segment_4": "apply_segment_4",
    "SAM3Segment_Salia": "SAM3Segment_Salia (Duo2 → SAM3 → apply_segment_4)",
}
| |
|