File size: 13,086 Bytes
2a5d451 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 | import os
import hashlib
from typing import List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
# ============================================================
# Standalone asset helpers (no external utils required).
# Expects PNG files in: <this_file_dir>/assets/images/*.png
# ============================================================
# The module path is passed through realpath so symlinks are resolved and the
# assets folder is located next to the real file.
_ASSETS_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], "assets", "images")
def list_pngs() -> List[str]:
    """Return the sorted names of regular files ending in ``.png`` (any case) in the assets dir.

    Returns an empty list when the assets directory does not exist.
    """
    if not os.path.isdir(_ASSETS_DIR):
        return []
    return sorted(
        name
        for name in os.listdir(_ASSETS_DIR)
        if name.lower().endswith(".png") and os.path.isfile(os.path.join(_ASSETS_DIR, name))
    )
def safe_path(filename: str) -> str:
    """Resolve ``filename`` inside the assets directory and return its real path.

    Symlinks are resolved on both sides before comparing, so ``..`` segments,
    absolute paths (which ``os.path.join`` lets replace the base), and links
    that point outside the folder are all rejected.
    An empty name or ``"."`` resolves to the assets directory itself and is accepted.

    Raises:
        ValueError: if the resolved path lies outside the assets directory.
    """
    root = os.path.realpath(_ASSETS_DIR)
    resolved = os.path.realpath(os.path.join(_ASSETS_DIR, filename))
    inside = resolved == root or resolved.startswith(root + os.sep)
    if not inside:
        raise ValueError("Unsafe path (path traversal detected).")
    return resolved
def file_hash(filename: str) -> str:
    """Return the hex SHA-256 digest of an asset file, reading it in 1 MiB chunks.

    The name is validated and resolved through ``safe_path`` first.
    """
    digest = hashlib.sha256()
    block_size = 1024 * 1024
    with open(safe_path(filename), "rb") as stream:
        while True:
            chunk = stream.read(block_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def load_image_from_assets(filename: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Loads a PNG from assets/images and returns:
      - image: IMAGE tensor [1,H,W,3] float32 in [0,1]
      - mask:  MASK tensor  [1,H,W]   float32 in [0,1]

    IMPORTANT: the mask uses the ComfyUI LoadImage convention ``mask = 1 - alpha``.
    The alpha source is chosen in this order:
      1. an explicit alpha band (e.g. RGBA, LA, PA);
      2. PNG ``tRNS`` transparency stored in ``info["transparency"]`` for
         1/L/P/RGB images. It is expanded to a real alpha channel with
         ``convert("RGBA")``.
      3. otherwise luminance is used as an alpha-like signal
         (mask = 1 - luminance).
         NOTE(review): this fallback is local to this file; core LoadImage does
         not derive a mask from luminance. Confirm it is intended.

    EXIF orientation is applied before anything else, so H and W are the
    displayed (transposed) size.

    Raises:
        ValueError: if ``filename`` resolves outside the assets directory (via safe_path).
        OSError: if PIL cannot open or decode the file.
    """
    path = safe_path(filename)
    # Context manager: release the file handle deterministically instead of
    # leaving it open until the PIL object is garbage-collected.
    # All decoding happens inside the block, while the source is still open.
    with Image.open(path) as src:
        i = ImageOps.exif_transpose(src)

        # Match Comfy style handling of mode 'I' (32-bit integer pixels)
        if i.mode == "I":
            i = i.point(lambda px: px * (1 / 255))

        # IMAGE output (RGB)
        rgb_np = np.array(i.convert("RGB")).astype(np.float32) / 255.0
        image = torch.from_numpy(rgb_np)[None, ...]  # [1,H,W,3]

        # MASK output: find the best available alpha-like signal.
        if "A" in i.getbands():
            a = np.array(i.getchannel("A")).astype(np.float32) / 255.0
        elif i.mode in ("1", "L", "P", "RGB") and "transparency" in i.info:
            # tRNS transparency (palette index/alpha table or a single
            # transparent colour) has no "A" band; converting to RGBA
            # materialises it as a proper alpha channel.
            a = np.array(i.convert("RGBA").getchannel("A")).astype(np.float32) / 255.0
        else:
            # fallback: use luminance as alpha-like signal
            a = np.array(i.convert("L")).astype(np.float32) / 255.0
        alpha = torch.from_numpy(a)  # [H,W]

    mask = 1.0 - alpha  # ComfyUI mask convention
    mask = mask.clamp(0.0, 1.0).unsqueeze(0)  # [1,H,W]
    return image, mask
# ============================================================
# Helpers (IMAGE / MASK validation + alpha paste)
# ============================================================
def _as_image(img: torch.Tensor) -> torch.Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError("IMAGE must be a torch.Tensor")
if img.dim() != 4:
raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
if img.shape[-1] not in (3, 4):
raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}")
return img
def _as_mask(mask: torch.Tensor) -> torch.Tensor:
if not isinstance(mask, torch.Tensor):
raise TypeError("MASK must be a torch.Tensor")
if mask.dim() == 2:
mask = mask.unsqueeze(0) # [1,H,W]
if mask.dim() != 3:
raise ValueError(f"Expected MASK shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
return mask
def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Return ``img`` as a 4-channel IMAGE tensor.

    RGB input gets an all-ones (opaque) alpha channel appended, with the same
    dtype and device. RGBA input is returned unchanged. Shape/channel
    validation is delegated to ``_as_image``.
    """
    img = _as_image(img)
    if img.shape[-1] != 4:
        opaque = img.new_ones((*img.shape[:-1], 1))  # [B,H,W,1]
        img = torch.cat([img, opaque], dim=-1)
    return img
def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """
    Alpha-over paste ``overlay`` on top of ``canvas`` with its top-left corner at (x, y).

    Both inputs are IMAGE tensors [B,H,W,C] with C = 3 (RGB, treated as fully
    opaque) or C = 4 (RGBA). Only the part of the overlay that intersects the
    canvas is composited. Negative or out-of-range offsets are clipped, and a
    paste entirely outside the canvas returns an untouched copy.

    Batch sizes must match, except that a batch of 1 on either side is
    broadcast to the other side's batch size.

    Colours and alphas are clamped to [0,1] and composited with the
    premultiplied "over" operator.

    Returns:
        A new tensor with the canvas's (broadcast) batch, height and width and
        the canvas's channel count. The input canvas is not modified.

    Raises:
        TypeError / ValueError: from ``_as_image`` for malformed inputs.
        ValueError: on a batch mismatch that cannot be broadcast.
    """
    overlay = _as_image(overlay)
    canvas = _as_image(canvas)

    # Batch handling: allow 1->N expansion
    if overlay.shape[0] != canvas.shape[0]:
        if overlay.shape[0] == 1 and canvas.shape[0] > 1:
            overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
        elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
            canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")

    _, Hc, Wc, Cc = canvas.shape
    _, Ho, Wo, _ = overlay.shape
    x = int(x)
    y = int(y)

    # clone() also materialises an expanded (broadcast) canvas so it can be written.
    out = canvas.clone()

    # intersection on canvas
    x0c = max(0, x)
    y0c = max(0, y)
    x1c = min(Wc, x + Wo)
    y1c = min(Hc, y + Ho)
    if x1c <= x0c or y1c <= y0c:
        return out

    # corresponding region on overlay
    x0o = x0c - x
    y0o = y0c - y
    x1o = x0o + (x1c - x0c)
    y1o = y0o + (y1c - y0c)

    canvas_rgba = _ensure_rgba(out[:, y0c:y1c, x0c:x1c, :])
    overlay_rgba = _ensure_rgba(overlay[:, y0o:y1o, x0o:x1o, :])

    over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
    over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)
    under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
    under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)

    # premultiplied alpha composite
    over_pm = over_rgb * over_a
    under_pm = under_rgb * under_a
    out_a = over_a + under_a * (1.0 - over_a)
    out_pm = over_pm + under_pm * (1.0 - over_a)

    # Un-premultiply. Divide by max(out_a, eps), NOT (out_a + eps).
    # The latter scales every pixel by 1/(1+eps), so an opaque 1.0 becomes
    # 0.999999, which an 8-bit "x*255 then truncate" save turns into 254.
    # Near-transparent pixels (out_a <= eps) are set to black.
    eps = 1e-6
    out_rgb = torch.where(out_a > eps, out_pm / out_a.clamp(min=eps), torch.zeros_like(out_pm))
    out_rgb = out_rgb.clamp(0.0, 1.0)
    out_a = out_a.clamp(0.0, 1.0)

    if Cc == 3:
        out[:, y0c:y1c, x0c:x1c, :] = out_rgb
    else:
        out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)
    return out
# ============================================================
# RMBG mask-combine logic, reproduced exactly (kept identical to the earlier RMBG-based node)
# torch.maximum + PIL resize (LANCZOS)
# ============================================================
class _AILab_MaskCombiner_Exact:
def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
if len(masks) <= 1:
return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)
ref_shape = masks[0].shape
masks = [self._resize_if_needed(m, ref_shape) for m in masks]
if mode == "combine":
result = torch.maximum(masks[0], masks[1])
for mask in masks[2:]:
result = torch.maximum(result, mask)
elif mode == "intersection":
result = torch.minimum(masks[0], masks[1])
else:
result = torch.abs(masks[0] - masks[1])
return (torch.clamp(result, 0, 1),)
def _resize_if_needed(self, mask, target_shape):
if mask.shape == target_shape:
return mask
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
elif len(mask.shape) == 4:
mask = mask.squeeze(1)
target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]
resized_masks = []
for i in range(mask.shape[0]):
mask_np = mask[i].cpu().numpy()
img = Image.fromarray((mask_np * 255).astype(np.uint8))
img_resized = img.resize((target_width, target_height), Image.LANCZOS)
mask_resized = np.array(img_resized).astype(np.float32) / 255.0
resized_masks.append(torch.from_numpy(mask_resized))
return torch.stack(resized_masks)
# ============================================================
# ComfyUI core "Join Image with Alpha" logic (EXACT)
# (from JoinImageWithAlpha implementation)
# ============================================================
def _resize_mask_comfy(alpha_mask: torch.Tensor, image_shape_hwc: Tuple[int, int, int]) -> torch.Tensor:
# image_shape_hwc is image.shape[1:] => (H,W,C)
H = int(image_shape_hwc[0])
W = int(image_shape_hwc[1])
return F.interpolate(
alpha_mask.reshape((-1, 1, alpha_mask.shape[-2], alpha_mask.shape[-1])),
size=(H, W),
mode="bilinear",
).squeeze(1)
def _join_image_with_alpha_comfy(image: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """
    Reproduce ComfyUI core "Join Image with Alpha" exactly:
        batch_size = min(len(image), len(alpha))
        alpha = 1.0 - resize_mask(alpha, image.shape[1:])
        out = cat(image[i][:,:,:3], alpha[i].unsqueeze(2))

    The MASK is inverted before being attached as the alpha channel. The
    result is an RGBA IMAGE [min(B_img, B_mask), H, W, 4] at the image's size.
    The mask is first moved to the image's device/dtype; the core node assumes
    they already match.
    """
    image = _as_image(image)
    alpha = _as_mask(alpha).to(device=image.device, dtype=image.dtype)
    count = min(len(image), len(alpha))
    opacity = 1.0 - _resize_mask_comfy(alpha, image.shape[1:])
    joined = [torch.cat((image[k][:, :, :3], opacity[k].unsqueeze(2)), dim=2) for k in range(count)]
    return torch.stack(joined)
# ============================================================
# NODE: apply_segment_3
# ============================================================
class apply_segment_3:
    """
    ComfyUI node, pipeline as implemented in ``run``:
      A. invert the incoming MASK (``1 - mask``);
      B. element-wise max of that and the selected asset PNG's alpha, recovered
         as ``1 - loaded_mask`` (computed on CPU);
      C. attach the result to ``img`` with ComfyUI "Join Image with Alpha" logic;
      D. alpha-over the RGBA result onto ``canvas`` at (x, y).
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare inputs; the asset dropdown shows a placeholder when no PNGs exist."""
        asset_names = list_pngs()
        if not asset_names:
            asset_names = ["<no pngs found>"]
        required = {
            "mask": ("MASK",),
            "image": (asset_names, {}),  # asset PNG whose loaded mask is used in step B
            "img": ("IMAGE",),  # image that receives the alpha channel in step C
            "canvas": ("IMAGE",),  # destination for the paste in step D
            "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
        }
        return {"required": required}

    RETURN_TYPES = ("MASK", "MASK", "IMAGE", "IMAGE")
    RETURN_NAMES = ("Inversed_Mask", "Alpha_Mask", "Alpha_Image", "Final_Image")
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Execute steps A-D.

        Returns ``(inversed_mask, alpha_mask, alpha_image, final)``.

        Raises:
            FileNotFoundError: if the placeholder asset entry is selected.
        """
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs found in assets/images next to apply_segment_3.py")

        # Step A: invert the input mask (keeps its original device).
        inversed_mask = 1.0 - _as_mask(mask)  # [B,H,W]

        # Step B: max(inverted mask, 1 - loaded_mask). The loader returns
        # mask = 1 - alpha, so (1 - loaded_mask) is the asset's alpha signal.
        _, loaded_mask = load_image_from_assets(image)
        asset_alpha = 1.0 - _as_mask(loaded_mask).detach().cpu()
        (alpha_mask,) = _AILab_MaskCombiner_Exact().combine_masks(
            inversed_mask.detach().cpu(), mode="combine", mask_2=asset_alpha
        )
        alpha_mask = alpha_mask.clamp(0.0, 1.0)  # [B,H,W] on CPU

        # Step C: exact ComfyUI Join Image with Alpha.
        alpha_image = _join_image_with_alpha_comfy(img, alpha_mask)

        # Step D: alpha-over paste onto the canvas (on the canvas's device/dtype).
        canvas = _as_image(canvas)
        alpha_image = alpha_image.to(device=canvas.device, dtype=canvas.dtype)
        final = _alpha_over_region(alpha_image, canvas, x, y)
        return (inversed_mask, alpha_mask, alpha_image, final)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        """Return the selected asset's SHA-256 digest (or the placeholder string as-is)."""
        return image if image == "<no pngs found>" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        """Return True when the asset exists inside assets/images, else an error message string."""
        if image == "<no pngs found>":
            return "No PNGs found in assets/images next to apply_segment_3.py"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        return True if os.path.isfile(path) else f"File not found in assets/images: {image}"
# ============================================================
# Node mappings (ONLY this node)
# ============================================================
# Node-id -> class and node-id -> display-name maps exported for ComfyUI.
NODE_CLASS_MAPPINGS = dict(apply_segment_3=apply_segment_3)
NODE_DISPLAY_NAME_MAPPINGS = dict(apply_segment_3="apply_segment_3")