| | import os |
| | from typing import Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from PIL import Image |
| |
|
| |
|
| | |
# Import the shared Salia I/O helpers. Supports both package layouts: one
# level below utils/ (..utils.io) or alongside it (.utils.io).
try:
    from ..utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
except Exception:
    try:
        from .utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
    except Exception as e:
        # Both layouts failed: bind stub callables that raise a helpful
        # ImportError on first use, so the module itself still imports and
        # ComfyUI can render the node list. (Previously these fallbacks ran
        # unconditionally and clobbered a successful import.)
        _UTILS_IMPORT_ERR = e

        def _missing(*args, **kwargs):
            raise ImportError(
                "Could not import Salia utils (list_pngs/load_image_from_assets/file_hash/safe_path). "
                "Place this node file in the same package layout as your other Salia nodes.\n"
                f"Original import error: {_UTILS_IMPORT_ERR}"
            )

        list_pngs = _missing
        load_image_from_assets = _missing
        file_hash = _missing
        safe_path = _missing
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _as_image(img: torch.Tensor) -> torch.Tensor: |
| | |
| | if not isinstance(img, torch.Tensor): |
| | raise TypeError("IMAGE must be a torch.Tensor") |
| | if img.dim() != 4: |
| | raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}") |
| | if img.shape[-1] not in (3, 4): |
| | raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}") |
| | return img |
| |
|
| |
|
def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
    """Crop a w*h window whose top-left corner is at (x, y).

    Regions that fall outside the source image are zero-filled.
    image: [B,H,W,C] -> returns [B,h,w,C] on the same device/dtype.
    """
    image = _as_image(image)
    batch, src_h, src_w, channels = image.shape
    crop_w = max(1, int(w))
    crop_h = max(1, int(h))
    left = int(x)
    top = int(y)

    # Start from an all-zero canvas; only the overlapping part gets copied.
    result = torch.zeros((batch, crop_h, crop_w, channels), device=image.device, dtype=image.dtype)

    # Intersection of the crop window with the source bounds.
    src_x0 = max(0, left)
    src_y0 = max(0, top)
    src_x1 = min(src_w, left + crop_w)
    src_y1 = min(src_h, top + crop_h)

    # Fully outside: return the zero canvas as-is.
    if src_x1 <= src_x0 or src_y1 <= src_y0:
        return result

    # Where that intersection lands inside the output window.
    dst_x0 = src_x0 - left
    dst_y0 = src_y0 - top
    dst_x1 = dst_x0 + (src_x1 - src_x0)
    dst_y1 = dst_y0 + (src_y1 - src_y0)

    result[:, dst_y0:dst_y1, dst_x0:dst_x1, :] = image[:, src_y0:src_y1, src_x0:src_x1, :]
    return result
| |
|
| |
|
def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Promote an IMAGE tensor to RGBA.

    img: [B,H,W,C] with C in {3, 4}. RGBA input is returned unchanged;
    RGB input gets a fully-opaque alpha channel appended -> [B,H,W,4].
    """
    img = _as_image(img)
    if img.shape[-1] == 3:
        # Opaque alpha plane matching the image's batch/size/device/dtype.
        opaque = torch.ones(tuple(img.shape[:3]) + (1,), device=img.device, dtype=img.dtype)
        img = torch.cat([img, opaque], dim=-1)
    return img
| |
|
| |
|
def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """
    Places overlay at canvas pixel position (x,y) top-left corner.
    Supports RGB/RGBA for both. Uses alpha-over if overlay has alpha or canvas has alpha.
    Returns same channel count as canvas (3->3, 4->4).
    """
    overlay = _as_image(overlay)
    canvas = _as_image(canvas)

    # Broadcast a batch of 1 against the other side's batch; any other
    # mismatch is an error.
    if overlay.shape[0] != canvas.shape[0]:
        if overlay.shape[0] == 1 and canvas.shape[0] > 1:
            overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
        elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
            canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")

    B, Hc, Wc, Cc = canvas.shape
    _, Ho, Wo, _ = overlay.shape

    x = int(x)
    y = int(y)

    # Work on a copy so the caller's canvas tensor is never mutated.
    out = canvas.clone()

    # Clip the paste rectangle to the canvas bounds.
    x0c = max(0, x)
    y0c = max(0, y)
    x1c = min(Wc, x + Wo)
    y1c = min(Hc, y + Ho)

    # Overlay entirely outside the canvas: nothing to composite.
    if x1c <= x0c or y1c <= y0c:
        return out

    # Matching sub-rectangle inside the overlay.
    x0o = x0c - x
    y0o = y0c - y
    x1o = x0o + (x1c - x0c)
    y1o = y0o + (y1c - y0c)

    canvas_region = out[:, y0c:y1c, x0c:x1c, :]
    overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]

    # Promote both regions to RGBA (RGB gets alpha=1) so one code path
    # handles every 3/4-channel combination.
    canvas_rgba = _ensure_rgba(canvas_region)
    overlay_rgba = _ensure_rgba(overlay_region)

    over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
    over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)

    under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
    under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)

    # Standard "over" operator in premultiplied-alpha space.
    over_pm = over_rgb * over_a
    under_pm = under_rgb * under_a

    out_a = over_a + under_a * (1.0 - over_a)
    out_pm = over_pm + under_pm * (1.0 - over_a)

    # Un-premultiply; eps guards the divide and zeroes fully-transparent
    # pixels instead of producing NaN/inf.
    eps = 1e-6
    out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
    out_rgb = out_rgb.clamp(0.0, 1.0)
    out_a = out_a.clamp(0.0, 1.0)

    # Write back with the canvas's original channel count.
    if Cc == 3:
        out[:, y0c:y1c, x0c:x1c, :] = out_rgb
    else:
        out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)

    return out
| |
|
| |
|
| | |
| | |
| | |
| |
|
class _AILab_MaskCombiner_Exact:
    """Exact re-implementation of the AILab mask-combiner behaviour.

    NOTE(review): intentionally kept behaviour-for-behaviour identical to the
    upstream node ("_Exact"); do not "fix" its quirks (e.g. only "combine"
    uses more than two masks) without checking downstream callers.
    """

    def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
        """Combine up to four masks; returns a 1-tuple with the clamped result."""
        try:
            # Drop unconnected (None) inputs.
            masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]

            # Zero or one mask: nothing to combine (64x64 zeros as fallback).
            if len(masks) <= 1:
                return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)

            # Every mask is resized to match the first mask's shape.
            ref_shape = masks[0].shape
            masks = [self._resize_if_needed(m, ref_shape) for m in masks]

            if mode == "combine":
                # Union: element-wise max over ALL provided masks.
                result = torch.maximum(masks[0], masks[1])
                for mask in masks[2:]:
                    result = torch.maximum(result, mask)
            elif mode == "intersection":
                # Intersection: element-wise min (first two masks only).
                result = torch.minimum(masks[0], masks[1])
            else:
                # Any other mode: absolute difference (first two masks only).
                result = torch.abs(masks[0] - masks[1])

            return (torch.clamp(result, 0, 1),)
        except Exception as e:
            print(f"Error in combine_masks: {str(e)}")
            # NOTE(review): `masks` is unbound here if the list construction
            # itself raised — the debug print would then NameError.
            print(f"Mask shapes: {[m.shape for m in masks]}")
            raise e

    def _resize_if_needed(self, mask, target_shape):
        """Resize `mask` to `target_shape`'s H/W using PIL LANCZOS.

        Accepts [H,W], [B,H,W] or [B,1,H,W]-style input; resized output is
        stacked on CPU. The round-trip through 8-bit PIL images quantises
        mask values to 1/255 steps.
        """
        try:
            if mask.shape == target_shape:
                return mask

            # Normalise to [B,H,W].
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0)
            elif len(mask.shape) == 4:
                mask = mask.squeeze(1)

            target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
            target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]

            # Per-batch-item resize through PIL (CPU, uint8).
            resized_masks = []
            for i in range(mask.shape[0]):
                mask_np = mask[i].cpu().numpy()
                img = Image.fromarray((mask_np * 255).astype(np.uint8))
                img_resized = img.resize((target_width, target_height), Image.LANCZOS)
                mask_resized = np.array(img_resized).astype(np.float32) / 255.0
                resized_masks.append(torch.from_numpy(mask_resized))

            return torch.stack(resized_masks)

        except Exception as e:
            print(f"Error in _resize_if_needed: {str(e)}")
            print(f"Input mask shape: {mask.shape}, Target shape: {target_shape}")
            raise e
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Cropout_Square_From_IMG:
    """Crop a square region (zero-padded when out of bounds) from an IMAGE."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "square_size": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
            }
        }

    def run(self, img, x, y, square_size):
        # A square is just a rectangle with equal width and height.
        return (_crop_with_padding(img, x, y, square_size, square_size),)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Cropout_Rect_From_IMG:
    """Crop a width*height rectangle (zero-padded when out of bounds) from an IMAGE."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
            }
        }

    def run(self, img, x, y, width, height):
        return (_crop_with_padding(img, x, y, width, height),)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Paste_rect_to_img:
    """Alpha-composite an overlay IMAGE onto a canvas IMAGE at (x, y)."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "overlay": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    def run(self, overlay, canvas, x, y):
        return (_alpha_over_region(overlay, canvas, x, y),)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Combine_2_masks:
    """Union (element-wise max) of two MASK inputs."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        merged, = _AILab_MaskCombiner_Exact().combine_masks(maskA, mode="combine", mask_2=maskB)
        return (merged,)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class Combine_2_masks_invert_1:
    """Union of the inverted first MASK with the second MASK: max(1-A, B)."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        inverted_a = 1.0 - maskA
        merged, = _AILab_MaskCombiner_Exact().combine_masks(inverted_a, mode="combine", mask_2=maskB)
        return (merged,)
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
class Combine_2_masks_inverse:
    """De Morgan combination of two MASKs: 1 - max(1-A, 1-B), i.e. min(A, B)."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        # Union of the two inverted masks, then invert the result back.
        union_of_inverses, = _AILab_MaskCombiner_Exact().combine_masks(
            1.0 - maskA, mode="combine", mask_2=1.0 - maskB
        )
        return (torch.clamp(1.0 - union_of_inverses, 0, 1),)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class combine_masks_with_loaded:
    """Union of an input MASK with the inverted mask of a bundled asset PNG."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; placeholder keeps the widget valid when empty.
        available = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (available, {}),
            }
        }

    def run(self, mask, image):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        _rgb, asset_mask = load_image_from_assets(image)

        merged, = _AILab_MaskCombiner_Exact().combine_masks(
            mask, mode="combine", mask_2=1.0 - asset_mask
        )
        return (merged,)

    @classmethod
    def IS_CHANGED(cls, mask, image):
        # Re-run when the selected asset file's content changes.
        return image if image == "<no pngs found>" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            resolved = safe_path(image)
        except Exception as err:
            return str(err)
        if os.path.isfile(resolved):
            return True
        return f"File not found in assets/images: {image}"
| |
|
| |
|
| | |
| | |
| | |
| |
|
class apply_segment:
    """Composite `img` onto `canvas` at (x, y), gated by two masks.

    The overlay's alpha is the union of the inverted input MASK and the
    mask loaded from a bundled asset PNG; an existing image alpha channel
    is unioned in as well.
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; placeholder keeps the widget valid when empty.
        choices = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),
                "img": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Asset mask loaded from assets/images (CPU tensor).
        _img_asset, loaded_mask = load_image_from_assets(image)

        # Fix: align devices before combining, matching apply_segment_2 —
        # `mask` may live on GPU while the loaded asset mask is on CPU,
        # and torch.maximum requires both operands on the same device.
        inv_mask = (1.0 - mask).detach().cpu()
        loaded_mask = loaded_mask.detach().cpu()
        final_mask, = combiner.combine_masks(inv_mask, mode="combine", mask_2=loaded_mask)

        img = _as_image(img)
        B, H, W, C = img.shape

        # Resize the combined mask to the overlay image's resolution.
        final_mask_resized = combiner._resize_if_needed(final_mask, (final_mask.shape[0], H, W))

        # Broadcast a batch of 1 on either side; any other mismatch is an error.
        if final_mask_resized.shape[0] != B:
            if final_mask_resized.shape[0] == 1 and B > 1:
                final_mask_resized = final_mask_resized.expand(B, H, W)
            elif B == 1 and final_mask_resized.shape[0] > 1:
                img = img.expand(final_mask_resized.shape[0], *img.shape[1:])
                B = img.shape[0]
            else:
                raise ValueError(f"Batch mismatch: img batch={B}, final_mask batch={final_mask_resized.shape[0]}")

        if C == 3:
            # RGB input: the combined mask becomes the overlay's alpha channel.
            alpha = final_mask_resized.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([img, alpha.unsqueeze(-1)], dim=-1)
        else:
            # RGBA input: union the existing alpha with the combined mask.
            rgb = img[..., :3]
            alpha_img = img[..., 3]

            # Combiner works on CPU tensors (it round-trips through PIL).
            a1 = alpha_img.detach().cpu()
            a2 = final_mask_resized.detach().cpu()
            combined_alpha, = combiner.combine_masks(a1, mode="combine", mask_2=a2)

            combined_alpha = combined_alpha.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([rgb, combined_alpha.unsqueeze(-1)], dim=-1)

        # Match the canvas's device/dtype, then alpha-over paste at (x, y).
        canvas = _as_image(canvas)
        final_overlay = final_overlay.to(device=canvas.device, dtype=canvas.dtype)

        out = _alpha_over_region(final_overlay, canvas, x, y)
        return (out,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Re-run when the selected asset file's content changes.
        if image == "<no pngs found>":
            return image
        return file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
class apply_segment_2:
    """Composite `img` onto `canvas` at (x, y), gated by two masks.

    The overlay's alpha is the union of the inverted input MASK and the
    INVERTED asset mask; unlike `apply_segment`, an existing image alpha
    channel is multiplied (not unioned) with the combined mask.
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; placeholder keeps the widget valid when empty.
        choices = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),
                "img": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Keep the pixels NOT selected by the input mask.
        inverse_mask = (1.0 - mask)

        # Asset mask loaded from assets/images.
        _img_asset, loaded_mask = load_image_from_assets(image)

        # Align devices before combining (torch.maximum needs same device).
        inverse_mask_cpu = inverse_mask.detach().cpu()
        loaded_mask_cpu = loaded_mask.detach().cpu()

        alpha_mask, = combiner.combine_masks(
            inverse_mask_cpu,
            mode="combine",
            mask_2=(1.0 - loaded_mask_cpu),
        )
        alpha_mask = torch.clamp(alpha_mask, 0.0, 1.0)

        img = _as_image(img)
        B, H, W, C = img.shape

        # Resize the combined mask to the overlay image's resolution.
        alpha_mask_resized = combiner._resize_if_needed(alpha_mask, (alpha_mask.shape[0], H, W))

        # Broadcast a batch of 1 on either side; any other mismatch is an error.
        if alpha_mask_resized.shape[0] != B:
            if alpha_mask_resized.shape[0] == 1 and B > 1:
                alpha_mask_resized = alpha_mask_resized.expand(B, H, W)
            elif B == 1 and alpha_mask_resized.shape[0] > 1:
                img = img.expand(alpha_mask_resized.shape[0], *img.shape[1:])
                B = img.shape[0]
            else:
                raise ValueError(
                    f"Batch mismatch: img batch={B}, alpha_mask batch={alpha_mask_resized.shape[0]}"
                )

        alpha_mask_resized = alpha_mask_resized.to(device=img.device, dtype=img.dtype).clamp(0.0, 1.0)

        if C == 3:
            # RGB input: the combined mask becomes the overlay's alpha channel.
            overlay = torch.cat([img, alpha_mask_resized.unsqueeze(-1)], dim=-1)
        else:
            # RGBA input: multiply the existing alpha by the combined mask so
            # both must agree for a pixel to stay visible.
            rgb = img[..., :3]
            alpha_img = img[..., 3].clamp(0.0, 1.0)

            alpha_out = (alpha_img * alpha_mask_resized).clamp(0.0, 1.0)
            overlay = torch.cat([rgb, alpha_out.unsqueeze(-1)], dim=-1)

        # Match the canvas's device/dtype, then alpha-over paste at (x, y).
        canvas = _as_image(canvas)
        overlay = overlay.to(device=canvas.device, dtype=canvas.dtype)

        out = _alpha_over_region(overlay, canvas, x, y)
        return (out,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Re-run when the selected asset file's content changes.
        if image == "<no pngs found>":
            return image
        return file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
| |
|
| |
|
# Registration tables consumed by ComfyUI at load time.
NODE_CLASS_MAPPINGS = {
    "Cropout_Square_From_IMG": Cropout_Square_From_IMG,
    "Cropout_Rect_From_IMG": Cropout_Rect_From_IMG,
    "Paste_rect_to_img": Paste_rect_to_img,
    "Combine_2_masks": Combine_2_masks,
    "Combine_2_masks_invert_1": Combine_2_masks_invert_1,
    "Combine_2_masks_inverse": Combine_2_masks_inverse,
    "combine_masks_with_loaded": combine_masks_with_loaded,
    "apply_segment": apply_segment,
    "apply_segment_2": apply_segment_2,
}

# Display names are identical to the internal node identifiers.
NODE_DISPLAY_NAME_MAPPINGS = {node_id: node_id for node_id in NODE_CLASS_MAPPINGS}