import os from typing import Tuple import torch import torch.nn.functional as F import numpy as np from PIL import Image # Salia utils (same style as your loader node) try: from ..utils.io import list_pngs, load_image_from_assets, file_hash, safe_path except Exception: # Fallback if you placed this file in a different package depth try: from .utils.io import list_pngs, load_image_from_assets, file_hash, safe_path except Exception as e: _UTILS_IMPORT_ERR = e def _missing(*args, **kwargs): raise ImportError( "Could not import Salia utils (list_pngs/load_image_from_assets/file_hash/safe_path). " "Place this node file in the same package layout as your other Salia nodes.\n" f"Original import error: {_UTILS_IMPORT_ERR}" ) list_pngs = _missing load_image_from_assets = _missing file_hash = _missing safe_path = _missing # ----------------------------- # Helpers (IMAGE) # ----------------------------- def _as_image(img: torch.Tensor) -> torch.Tensor: # ComfyUI IMAGE is usually [B,H,W,C] if not isinstance(img, torch.Tensor): raise TypeError("IMAGE must be a torch.Tensor") if img.dim() != 4: raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}") if img.shape[-1] not in (3, 4): raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}") return img def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor: """ Crops [x,y] top-left, size w*h. If out of bounds, pads with zeros. 
image: [B,H,W,C] returns: [B,h,w,C] """ image = _as_image(image) B, H, W, C = image.shape w = max(1, int(w)) h = max(1, int(h)) x = int(x) y = int(y) out = torch.zeros((B, h, w, C), device=image.device, dtype=image.dtype) # intersection in source x0s = max(0, x) y0s = max(0, y) x1s = min(W, x + w) y1s = min(H, y + h) if x1s <= x0s or y1s <= y0s: return out # destination offsets x0d = x0s - x y0d = y0s - y x1d = x0d + (x1s - x0s) y1d = y0d + (y1s - y0s) out[:, y0d:y1d, x0d:x1d, :] = image[:, y0s:y1s, x0s:x1s, :] return out def _ensure_rgba(img: torch.Tensor) -> torch.Tensor: """ img: [B,H,W,C] where C is 3 or 4 returns RGBA [B,H,W,4] """ img = _as_image(img) if img.shape[-1] == 4: return img # RGB -> RGBA with alpha=1 B, H, W, _ = img.shape alpha = torch.ones((B, H, W, 1), device=img.device, dtype=img.dtype) return torch.cat([img, alpha], dim=-1) def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor: """ Places overlay at canvas pixel position (x,y) top-left corner. Supports RGB/RGBA for both. Uses alpha-over if overlay has alpha or canvas has alpha. Returns same channel count as canvas (3->3, 4->4). 
""" overlay = _as_image(overlay) canvas = _as_image(canvas) # Simple batch handling (Comfy usually matches batches, but allow 1->N) if overlay.shape[0] != canvas.shape[0]: if overlay.shape[0] == 1 and canvas.shape[0] > 1: overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:]) elif canvas.shape[0] == 1 and overlay.shape[0] > 1: canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:]) else: raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}") B, Hc, Wc, Cc = canvas.shape _, Ho, Wo, _ = overlay.shape x = int(x) y = int(y) out = canvas.clone() # intersection on canvas x0c = max(0, x) y0c = max(0, y) x1c = min(Wc, x + Wo) y1c = min(Hc, y + Ho) if x1c <= x0c or y1c <= y0c: return out # corresponding region on overlay x0o = x0c - x y0o = y0c - y x1o = x0o + (x1c - x0c) y1o = y0o + (y1c - y0c) canvas_region = out[:, y0c:y1c, x0c:x1c, :] overlay_region = overlay[:, y0o:y1o, x0o:x1o, :] # Convert both regions to RGBA for compositing canvas_rgba = _ensure_rgba(canvas_region) overlay_rgba = _ensure_rgba(overlay_region) over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0) over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0) under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0) under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0) # Premultiplied alpha composite: out = over + under*(1-over_a) over_pm = over_rgb * over_a under_pm = under_rgb * under_a out_a = over_a + under_a * (1.0 - over_a) out_pm = over_pm + under_pm * (1.0 - over_a) eps = 1e-6 out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm)) out_rgb = out_rgb.clamp(0.0, 1.0) out_a = out_a.clamp(0.0, 1.0) if Cc == 3: out[:, y0c:y1c, x0c:x1c, :] = out_rgb else: out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1) return out # ----------------------------- # RMBG EXACT MASK COMBINE LOGIC (copied solution) # ----------------------------- class _AILab_MaskCombiner_Exact: def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, 
mask_4=None): try: masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None] if len(masks) <= 1: return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),) ref_shape = masks[0].shape masks = [self._resize_if_needed(m, ref_shape) for m in masks] if mode == "combine": result = torch.maximum(masks[0], masks[1]) for mask in masks[2:]: result = torch.maximum(result, mask) elif mode == "intersection": result = torch.minimum(masks[0], masks[1]) else: result = torch.abs(masks[0] - masks[1]) return (torch.clamp(result, 0, 1),) except Exception as e: print(f"Error in combine_masks: {str(e)}") print(f"Mask shapes: {[m.shape for m in masks]}") raise e def _resize_if_needed(self, mask, target_shape): try: if mask.shape == target_shape: return mask if len(mask.shape) == 2: mask = mask.unsqueeze(0) elif len(mask.shape) == 4: mask = mask.squeeze(1) target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0] target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1] resized_masks = [] for i in range(mask.shape[0]): mask_np = mask[i].cpu().numpy() img = Image.fromarray((mask_np * 255).astype(np.uint8)) img_resized = img.resize((target_width, target_height), Image.LANCZOS) mask_resized = np.array(img_resized).astype(np.float32) / 255.0 resized_masks.append(torch.from_numpy(mask_resized)) return torch.stack(resized_masks) except Exception as e: print(f"Error in _resize_if_needed: {str(e)}") print(f"Input mask shape: {mask.shape}, Target shape: {target_shape}") raise e # ----------------------------- # 1) Cropout_Square_From_IMG # ----------------------------- class Cropout_Square_From_IMG: CATEGORY = "image/salia" @classmethod def INPUT_TYPES(cls): return { "required": { "img": ("IMAGE",), "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), "square_size": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}), } } 
RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "run" def run(self, img, x, y, square_size): cropped = _crop_with_padding(img, x, y, square_size, square_size) return (cropped,) # ----------------------------- # 2) Cropout_Rect_From_IMG # ----------------------------- class Cropout_Rect_From_IMG: CATEGORY = "image/salia" @classmethod def INPUT_TYPES(cls): return { "required": { "img": ("IMAGE",), "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), "width": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}), "height": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "run" def run(self, img, x, y, width, height): cropped = _crop_with_padding(img, x, y, width, height) return (cropped,) # ----------------------------- # 3) Paste_rect_to_img # ----------------------------- class Paste_rect_to_img: CATEGORY = "image/salia" @classmethod def INPUT_TYPES(cls): return { "required": { "overlay": ("IMAGE",), "canvas": ("IMAGE",), "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "run" def run(self, overlay, canvas, x, y): out = _alpha_over_region(overlay, canvas, x, y) return (out,) # ----------------------------- # 4) Combine_2_masks (RMBG exact: torch.maximum + PIL resize) # ----------------------------- class Combine_2_masks: CATEGORY = "mask/salia" @classmethod def INPUT_TYPES(cls): return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}} RETURN_TYPES = ("MASK",) RETURN_NAMES = ("mask",) FUNCTION = "run" def run(self, maskA, maskB): combiner = _AILab_MaskCombiner_Exact() out, = combiner.combine_masks(maskA, mode="combine", mask_2=maskB) return (out,) # ----------------------------- # 5) Combine_2_masks_invert_1 
# -----------------------------
# 5) Combine_2_masks_invert_1 (invert A then RMBG combine)
# -----------------------------
class Combine_2_masks_invert_1:
    """MASK node: clamp(max(1 - maskA, maskB), 0, 1) via the RMBG combiner."""

    CATEGORY = "mask/salia"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    def run(self, maskA, maskB):
        combiner = _AILab_MaskCombiner_Exact()
        maskA = 1.0 - maskA
        out, = combiner.combine_masks(maskA, mode="combine", mask_2=maskB)
        return (out,)


# -----------------------------
# 6) Combine_2_masks_inverse
#    invert both, combine, invert result (RMBG max logic)
# -----------------------------
class Combine_2_masks_inverse:
    """MASK node: 1 - max(1 - maskA, 1 - maskB), i.e. an element-wise min for masks in [0, 1]."""

    CATEGORY = "mask/salia"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    def run(self, maskA, maskB):
        combiner = _AILab_MaskCombiner_Exact()
        maskA = 1.0 - maskA
        maskB = 1.0 - maskB
        combined, = combiner.combine_masks(maskA, mode="combine", mask_2=maskB)
        out = 1.0 - combined
        out = torch.clamp(out, 0, 1)
        return (out,)


# -----------------------------
# 7) combine_masks_with_loaded (RMBG exact combine)
# -----------------------------
class combine_masks_with_loaded:
    """MASK node: max(mask, 1 - asset_mask), where asset_mask comes from the selected asset PNG."""

    CATEGORY = "mask/salia"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; [""] keeps the widget valid when none exist.
        choices = list_pngs() or [""]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),
            }
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    def run(self, mask, image):
        """Combine ``mask`` with the inverted asset mask; the result is returned on ``mask``'s device.

        Raises:
            FileNotFoundError: if no asset is selected (``image == ""``).
        """
        if image == "":
            raise FileNotFoundError("No PNGs in assets/images")
        _img, loaded_mask = load_image_from_assets(image)
        combiner = _AILab_MaskCombiner_Exact()
        # Combine on CPU (as apply_segment_2 does) so a GPU input mask and the
        # asset mask never hit a device mismatch inside the combiner.
        out, = combiner.combine_masks(
            mask.detach().cpu(),
            mode="combine",
            mask_2=1.0 - loaded_mask.detach().cpu(),
        )
        return (out.to(mask.device),)

    @classmethod
    def IS_CHANGED(cls, mask, image):
        # Re-run whenever the selected asset file's hash changes.
        if image == "":
            return image
        return file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image):
        """Return True if the selected asset resolves to an existing file, else an error string."""
        if image == "":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True


# -----------------------------
# 8) NEW: invert input mask, combine with loaded mask, apply to image alpha, paste on canvas
# -----------------------------
class apply_segment:
    """IMAGE node: build alpha = max(1 - mask, asset_mask), attach it to ``img``, paste onto ``canvas``.

    For an RGBA ``img`` the existing alpha is combined with that mask by
    element-wise maximum, so pixels can become more opaque (apply_segment_2
    instead multiplies, which only ever makes pixels more transparent).
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        choices = list_pngs() or [""]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),  # dropdown asset (used ONLY for loaded mask)
                "img": ("IMAGE",),  # the image to receive final_mask as alpha (overlay source)
                "canvas": ("IMAGE",),  # destination
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Composite ``img`` (masked as described on the class) onto ``canvas`` at (x, y).

        Raises:
            FileNotFoundError: if no asset is selected (``image == ""``).
            ValueError: if mask and image batch sizes differ and neither is 1.
        """
        if image == "":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Load asset mask (do NOT invert)
        _img_asset, loaded_mask = load_image_from_assets(image)

        # Invert input mask, then combine with loaded mask (RMBG exact combine => maximum).
        # Both go to CPU first (as in apply_segment_2) to avoid a device mismatch.
        inv_mask = 1.0 - mask.detach().cpu()
        final_mask, = combiner.combine_masks(
            inv_mask, mode="combine", mask_2=loaded_mask.detach().cpu()
        )

        # Apply final_mask as alpha to input image -> final_overlay (RGBA)
        img = _as_image(img)
        B, H, W, C = img.shape

        # Resize final_mask to img's H/W if needed (RMBG exact resize helper).
        # target_shape only needs to look like [B,H,W]; the helper keeps its own batch count.
        final_mask_resized = combiner._resize_if_needed(final_mask, (final_mask.shape[0], H, W))

        # Batch match (simple 1->N expansion only)
        if final_mask_resized.shape[0] != B:
            if final_mask_resized.shape[0] == 1 and B > 1:
                final_mask_resized = final_mask_resized.expand(B, H, W)
            elif B == 1 and final_mask_resized.shape[0] > 1:
                img = img.expand(final_mask_resized.shape[0], *img.shape[1:])
                B = img.shape[0]
            else:
                raise ValueError(f"Batch mismatch: img batch={B}, final_mask batch={final_mask_resized.shape[0]}")

        if C == 3:
            # RGB -> RGBA with alpha = final_mask
            alpha = final_mask_resized.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([img, alpha.unsqueeze(-1)], dim=-1)
        else:
            # RGBA: combine existing alpha with final_mask using RMBG combine (maximum)
            rgb = img[..., :3]
            alpha_img = img[..., 3]  # [B,H,W]

            # The combiner may resize through PIL, so keep its inputs on CPU
            a1 = alpha_img.detach().cpu()
            a2 = final_mask_resized.detach().cpu()
            combined_alpha, = combiner.combine_masks(a1, mode="combine", mask_2=a2)
            combined_alpha = combined_alpha.to(device=img.device, dtype=img.dtype)

            final_overlay = torch.cat([rgb, combined_alpha.unsqueeze(-1)], dim=-1)

        # Paste final_overlay onto canvas at (x,y)
        canvas = _as_image(canvas)
        final_overlay = final_overlay.to(device=canvas.device, dtype=canvas.dtype)
        out = _alpha_over_region(final_overlay, canvas, x, y)
        return (out,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Re-run whenever the selected asset file's hash changes.
        if image == "":
            return image
        return file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        """Return True if the selected asset resolves to an existing file, else an error string."""
        if image == "":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
# -----------------------------
# 9) NEW: apply_segment_2
# Steps:
#   1) inverse_mask = 1 - mask
#   2) alpha_mask = combine_masks_with_loaded(inverse_mask, selected_asset)
#      (i.e. max(inverse_mask, 1 - loaded_mask))
#   3) overlay = join img with alpha using alpha_mask
#      - RGB:  create RGBA with alpha = alpha_mask
#      - RGBA: alpha_out = alpha_img * alpha_mask (more transparent, never more opaque)
#   4) paste overlay onto canvas at (x,y) using alpha-over
# -----------------------------
class apply_segment_2:
    """IMAGE node: cut ``img`` with max(1 - mask, 1 - asset_mask) as alpha and paste it onto ``canvas``.

    An RGB ``img`` receives that mask as its alpha. An RGBA ``img`` has its own
    alpha multiplied by the mask, so it can only become more transparent.
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; [""] keeps the widget valid when none exist.
        asset_choices = list_pngs() or [""]
        required = {
            "mask": ("MASK",),
            "image": (asset_choices, {}),  # asset dropdown; only its mask is used
            "img": ("IMAGE",),  # overlay source that receives the computed alpha
            "canvas": ("IMAGE",),  # paste destination
            "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Composite the masked ``img`` onto ``canvas`` with its top-left corner at (x, y).

        Raises:
            FileNotFoundError: if no asset is selected (``image == ""``).
            ValueError: if mask and image batch sizes differ and neither is 1.
        """
        if image == "":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Steps 1 + 2, evaluated on CPU so the combiner sees a single device.
        inverted_cpu = (1.0 - mask).detach().cpu()
        _asset_rgb, asset_mask = load_image_from_assets(image)
        asset_inverted_cpu = 1.0 - asset_mask.detach().cpu()
        (alpha,) = combiner.combine_masks(inverted_cpu, mode="combine", mask_2=asset_inverted_cpu)
        alpha = alpha.clamp(0.0, 1.0)

        # Step 3: bring the mask to img's spatial size and batch.
        img = _as_image(img)
        batch, height, width, channels = img.shape
        alpha = combiner._resize_if_needed(alpha, (alpha.shape[0], height, width))

        n_alpha = alpha.shape[0]
        if n_alpha != batch:
            if n_alpha == 1 and batch > 1:
                alpha = alpha.expand(batch, height, width)
            elif batch == 1 and n_alpha > 1:
                img = img.expand(n_alpha, *img.shape[1:])
            else:
                raise ValueError(
                    f"Batch mismatch: img batch={batch}, alpha_mask batch={n_alpha}"
                )

        alpha = alpha.to(device=img.device, dtype=img.dtype).clamp(0.0, 1.0)

        if channels == 3:
            overlay = torch.cat([img, alpha.unsqueeze(-1)], dim=-1)
        else:
            # Never replace an existing alpha; only scale it down.
            scaled_alpha = (img[..., 3].clamp(0.0, 1.0) * alpha).clamp(0.0, 1.0)
            overlay = torch.cat([img[..., :3], scaled_alpha.unsqueeze(-1)], dim=-1)

        # Step 4: paste with alpha-over.
        target = _as_image(canvas)
        overlay = overlay.to(device=target.device, dtype=target.dtype)
        return (_alpha_over_region(overlay, target, x, y),)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Cache key: the selected asset's file hash ("" when nothing is selected).
        return image if image == "" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        """Return True when the selected asset resolves to an existing file, else an error string."""
        if image == "":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as err:
            return str(err)
        return True if os.path.isfile(path) else f"File not found in assets/images: {image}"


# Node registry: each key is the class's own name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        Cropout_Square_From_IMG,
        Cropout_Rect_From_IMG,
        Paste_rect_to_img,
        Combine_2_masks,
        Combine_2_masks_invert_1,
        Combine_2_masks_inverse,
        combine_masks_with_loaded,
        apply_segment,
        apply_segment_2,
    )
}

# Display names are identical to the registry keys.
NODE_DISPLAY_NAME_MAPPINGS = {node_name: node_name for node_name in NODE_CLASS_MAPPINGS}