saliacoel commited on
Commit
4b3f0d8
·
verified ·
1 Parent(s): 067dbfa

Update Salia_Croppytools.py

Browse files
Files changed (1) hide show
  1. Salia_Croppytools.py +552 -552
Salia_Croppytools.py CHANGED
@@ -1,553 +1,553 @@
1
- import os
2
- from typing import Tuple
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- import numpy as np
7
- from PIL import Image
8
-
9
-
10
# Salia utils (same style as your loader node)
try:
    from ..utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
except Exception:
    # Fallback if the file sits at a different package depth.
    try:
        from .utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
    except Exception as e:
        _UTILS_IMPORT_ERR = e

        def _missing(*args, **kwargs):
            # Stub bound to all four helper names below: fails loudly on
            # first use instead of breaking the whole module at import time.
            raise ImportError(
                "Could not import Salia utils (list_pngs/load_image_from_assets/file_hash/safe_path). "
                "Place this node file in the same package layout as your other Salia nodes.\n"
                f"Original import error: {_UTILS_IMPORT_ERR}"
            )

        list_pngs = _missing
        load_image_from_assets = _missing
        file_hash = _missing
        safe_path = _missing
31
-
32
-
33
- # -----------------------------
34
- # Helpers (IMAGE)
35
- # -----------------------------
36
-
37
- def _as_image(img: torch.Tensor) -> torch.Tensor:
38
- # ComfyUI IMAGE is usually [B,H,W,C]
39
- if not isinstance(img, torch.Tensor):
40
- raise TypeError("IMAGE must be a torch.Tensor")
41
- if img.dim() != 4:
42
- raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
43
- if img.shape[-1] not in (3, 4):
44
- raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}")
45
- return img
46
-
47
-
48
- def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
49
- """
50
- Crops [x,y] top-left, size w*h. If out of bounds, pads with zeros.
51
- image: [B,H,W,C]
52
- returns: [B,h,w,C]
53
- """
54
- image = _as_image(image)
55
- B, H, W, C = image.shape
56
- w = max(1, int(w))
57
- h = max(1, int(h))
58
- x = int(x)
59
- y = int(y)
60
-
61
- out = torch.zeros((B, h, w, C), device=image.device, dtype=image.dtype)
62
-
63
- # intersection in source
64
- x0s = max(0, x)
65
- y0s = max(0, y)
66
- x1s = min(W, x + w)
67
- y1s = min(H, y + h)
68
-
69
- if x1s <= x0s or y1s <= y0s:
70
- return out
71
-
72
- # destination offsets
73
- x0d = x0s - x
74
- y0d = y0s - y
75
- x1d = x0d + (x1s - x0s)
76
- y1d = y0d + (y1s - y0s)
77
-
78
- out[:, y0d:y1d, x0d:x1d, :] = image[:, y0s:y1s, x0s:x1s, :]
79
- return out
80
-
81
-
82
- def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
83
- """
84
- img: [B,H,W,C] where C is 3 or 4
85
- returns RGBA [B,H,W,4]
86
- """
87
- img = _as_image(img)
88
- if img.shape[-1] == 4:
89
- return img
90
- # RGB -> RGBA with alpha=1
91
- B, H, W, _ = img.shape
92
- alpha = torch.ones((B, H, W, 1), device=img.device, dtype=img.dtype)
93
- return torch.cat([img, alpha], dim=-1)
94
-
95
-
96
- def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
97
- """
98
- Places overlay at canvas pixel position (x,y) top-left corner.
99
- Supports RGB/RGBA for both. Uses alpha-over if overlay has alpha or canvas has alpha.
100
- Returns same channel count as canvas (3->3, 4->4).
101
- """
102
- overlay = _as_image(overlay)
103
- canvas = _as_image(canvas)
104
-
105
- # Simple batch handling (Comfy usually matches batches, but allow 1->N)
106
- if overlay.shape[0] != canvas.shape[0]:
107
- if overlay.shape[0] == 1 and canvas.shape[0] > 1:
108
- overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
109
- elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
110
- canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
111
- else:
112
- raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")
113
-
114
- B, Hc, Wc, Cc = canvas.shape
115
- _, Ho, Wo, _ = overlay.shape
116
-
117
- x = int(x)
118
- y = int(y)
119
-
120
- out = canvas.clone()
121
-
122
- # intersection on canvas
123
- x0c = max(0, x)
124
- y0c = max(0, y)
125
- x1c = min(Wc, x + Wo)
126
- y1c = min(Hc, y + Ho)
127
-
128
- if x1c <= x0c or y1c <= y0c:
129
- return out
130
-
131
- # corresponding region on overlay
132
- x0o = x0c - x
133
- y0o = y0c - y
134
- x1o = x0o + (x1c - x0c)
135
- y1o = y0o + (y1c - y0c)
136
-
137
- canvas_region = out[:, y0c:y1c, x0c:x1c, :]
138
- overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]
139
-
140
- # Convert both regions to RGBA for compositing
141
- canvas_rgba = _ensure_rgba(canvas_region)
142
- overlay_rgba = _ensure_rgba(overlay_region)
143
-
144
- over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
145
- over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)
146
-
147
- under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
148
- under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)
149
-
150
- # Premultiplied alpha composite: out = over + under*(1-over_a)
151
- over_pm = over_rgb * over_a
152
- under_pm = under_rgb * under_a
153
-
154
- out_a = over_a + under_a * (1.0 - over_a)
155
- out_pm = over_pm + under_pm * (1.0 - over_a)
156
-
157
- eps = 1e-6
158
- out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
159
- out_rgb = out_rgb.clamp(0.0, 1.0)
160
- out_a = out_a.clamp(0.0, 1.0)
161
-
162
- if Cc == 3:
163
- out[:, y0c:y1c, x0c:x1c, :] = out_rgb
164
- else:
165
- out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)
166
-
167
- return out
168
-
169
-
170
- # -----------------------------
171
- # RMBG EXACT MASK COMBINE LOGIC (copied solution)
172
- # -----------------------------
173
-
174
- class _AILab_MaskCombiner_Exact:
175
- def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
176
- try:
177
- masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
178
-
179
- if len(masks) <= 1:
180
- return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)
181
-
182
- ref_shape = masks[0].shape
183
- masks = [self._resize_if_needed(m, ref_shape) for m in masks]
184
-
185
- if mode == "combine":
186
- result = torch.maximum(masks[0], masks[1])
187
- for mask in masks[2:]:
188
- result = torch.maximum(result, mask)
189
- elif mode == "intersection":
190
- result = torch.minimum(masks[0], masks[1])
191
- else:
192
- result = torch.abs(masks[0] - masks[1])
193
-
194
- return (torch.clamp(result, 0, 1),)
195
- except Exception as e:
196
- print(f"Error in combine_masks: {str(e)}")
197
- print(f"Mask shapes: {[m.shape for m in masks]}")
198
- raise e
199
-
200
- def _resize_if_needed(self, mask, target_shape):
201
- try:
202
- if mask.shape == target_shape:
203
- return mask
204
-
205
- if len(mask.shape) == 2:
206
- mask = mask.unsqueeze(0)
207
- elif len(mask.shape) == 4:
208
- mask = mask.squeeze(1)
209
-
210
- target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
211
- target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]
212
-
213
- resized_masks = []
214
- for i in range(mask.shape[0]):
215
- mask_np = mask[i].cpu().numpy()
216
- img = Image.fromarray((mask_np * 255).astype(np.uint8))
217
- img_resized = img.resize((target_width, target_height), Image.LANCZOS)
218
- mask_resized = np.array(img_resized).astype(np.float32) / 255.0
219
- resized_masks.append(torch.from_numpy(mask_resized))
220
-
221
- return torch.stack(resized_masks)
222
-
223
- except Exception as e:
224
- print(f"Error in _resize_if_needed: {str(e)}")
225
- print(f"Input mask shape: {mask.shape}, Target shape: {target_shape}")
226
- raise e
227
-
228
-
229
# -----------------------------
# 1) Cropout_Square_From_IMG
# -----------------------------

class Cropout_Square_From_IMG:
    """Crop a square region (zero-padded when out of bounds) from an IMAGE."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        coord_opts = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", dict(coord_opts)),
                "y": ("INT", dict(coord_opts)),
                "square_size": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
            }
        }

    def run(self, img, x, y, square_size):
        """Return the square_size*square_size crop whose top-left corner is (x, y)."""
        return (_crop_with_padding(img, x, y, square_size, square_size),)
254
-
255
-
256
# -----------------------------
# 2) Cropout_Rect_From_IMG
# -----------------------------

class Cropout_Rect_From_IMG:
    """Crop a width*height rectangle (zero-padded when out of bounds) from an IMAGE."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        coord_opts = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        size_opts = {"default": 512, "min": 1, "max": 16384, "step": 1}
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", dict(coord_opts)),
                "y": ("INT", dict(coord_opts)),
                "width": ("INT", dict(size_opts)),
                "height": ("INT", dict(size_opts)),
            }
        }

    def run(self, img, x, y, width, height):
        """Return the width*height crop whose top-left corner is (x, y)."""
        return (_crop_with_padding(img, x, y, width, height),)
282
-
283
-
284
# -----------------------------
# 3) Paste_rect_to_img
# -----------------------------

class Paste_rect_to_img:
    """Alpha-composite an overlay IMAGE onto a canvas IMAGE at (x, y)."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        coord_opts = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        return {
            "required": {
                "overlay": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", dict(coord_opts)),
                "y": ("INT", dict(coord_opts)),
            }
        }

    def run(self, overlay, canvas, x, y):
        return (_alpha_over_region(overlay, canvas, x, y),)
309
-
310
-
311
# -----------------------------
# 4) Combine_2_masks (RMBG exact: torch.maximum + PIL resize)
# -----------------------------

class Combine_2_masks:
    """Union of two masks (element-wise maximum), via the RMBG-exact combiner."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        merged = _AILab_MaskCombiner_Exact().combine_masks(maskA, mode="combine", mask_2=maskB)
        return (merged[0],)
330
-
331
-
332
# -----------------------------
# 5) Combine_2_masks_invert_1 (invert A then RMBG combine)
# -----------------------------

class Combine_2_masks_invert_1:
    """Invert maskA, then max-combine it with maskB (RMBG-exact combine)."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        inverted_a = 1.0 - maskA
        merged = _AILab_MaskCombiner_Exact().combine_masks(inverted_a, mode="combine", mask_2=maskB)
        return (merged[0],)
352
-
353
-
354
# -----------------------------
# 6) Combine_2_masks_inverse
# invert both, combine, invert result (RMBG max logic)
# -----------------------------

class Combine_2_masks_inverse:
    """Invert both masks, max-combine them, then invert the result.

    By De Morgan this behaves like an element-wise minimum of the inputs.
    """

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        merged = _AILab_MaskCombiner_Exact().combine_masks(
            1.0 - maskA, mode="combine", mask_2=1.0 - maskB
        )[0]
        return (torch.clamp(1.0 - merged, 0, 1),)
378
-
379
-
380
# -----------------------------
# 7) combine_masks_with_loaded (RMBG exact combine)
# -----------------------------

class combine_masks_with_loaded:
    """Union (max) of an input MASK with the mask channel of a bundled asset PNG."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; a sentinel entry keeps the widget valid when empty.
        available = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (available, {}),
            }
        }

    def run(self, mask, image):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        _img, loaded_mask = load_image_from_assets(image)

        result = _AILab_MaskCombiner_Exact().combine_masks(mask, mode="combine", mask_2=loaded_mask)
        return (result[0],)

    @classmethod
    def IS_CHANGED(cls, mask, image):
        # Hash the asset file so on-disk edits invalidate ComfyUI's cache.
        return image if image == "<no pngs found>" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            resolved = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(resolved):
            return f"File not found in assets/images: {image}"
        return True
428
-
429
-
430
# -----------------------------
# 8) NEW: invert input mask, combine with loaded mask, apply to image alpha, paste on canvas
# -----------------------------

class apply_segment:
    """Invert the input mask, max-combine it with an asset mask, use the result
    as the overlay image's alpha, then alpha-composite the overlay onto a canvas."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        available = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (available, {}),  # dropdown asset (used ONLY for the loaded mask)
                "img": ("IMAGE",),         # receives the final mask as its alpha (overlay source)
                "canvas": ("IMAGE",),      # destination
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    def run(self, mask, image, img, canvas, x, y):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Asset mask is used as-is (NOT inverted).
        _asset_rgb, loaded_mask = load_image_from_assets(image)

        # Inverted input mask, unioned (max) with the loaded mask (RMBG exact combine).
        final_mask, = combiner.combine_masks(1.0 - mask, mode="combine", mask_2=loaded_mask)

        img = _as_image(img)
        B, H, W, C = img.shape

        # Match the mask's H/W to the overlay image (RMBG exact PIL resize);
        # the helper keeps the mask's own batch count.
        mask_hw = combiner._resize_if_needed(final_mask, (final_mask.shape[0], H, W))

        # Simple 1 -> N batch broadcasting only.
        if mask_hw.shape[0] != B:
            if mask_hw.shape[0] == 1 and B > 1:
                mask_hw = mask_hw.expand(B, H, W)
            elif B == 1 and mask_hw.shape[0] > 1:
                img = img.expand(mask_hw.shape[0], *img.shape[1:])
                B = img.shape[0]
            else:
                raise ValueError(f"Batch mismatch: img batch={B}, final_mask batch={mask_hw.shape[0]}")

        if C == 3:
            # RGB: the combined mask simply becomes the new alpha channel.
            alpha = mask_hw.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([img, alpha.unsqueeze(-1)], dim=-1)
        else:
            # RGBA: union (max) the existing alpha with the combined mask.
            # The combiner may route through PIL, so keep its inputs on CPU.
            existing_alpha = img[..., 3].detach().cpu()  # [B,H,W]
            combined_alpha, = combiner.combine_masks(
                existing_alpha, mode="combine", mask_2=mask_hw.detach().cpu()
            )
            combined_alpha = combined_alpha.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([img[..., :3], combined_alpha.unsqueeze(-1)], dim=-1)

        # Paste the alpha'd overlay onto the canvas at (x, y).
        canvas = _as_image(canvas)
        final_overlay = final_overlay.to(device=canvas.device, dtype=canvas.dtype)
        return (_alpha_over_region(final_overlay, canvas, x, y),)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Hash the asset file so on-disk edits invalidate ComfyUI's cache.
        return image if image == "<no pngs found>" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            resolved = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(resolved):
            return f"File not found in assets/images: {image}"
        return True
527
-
528
-
529
# -----------------------------
# Node mappings
# -----------------------------

# ComfyUI registration tables; display names intentionally mirror the class names.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        Cropout_Square_From_IMG,
        Cropout_Rect_From_IMG,
        Paste_rect_to_img,
        Combine_2_masks,
        Combine_2_masks_invert_1,
        Combine_2_masks_inverse,
        combine_masks_with_loaded,
        apply_segment,
    )
}

NODE_DISPLAY_NAME_MAPPINGS = {name: name for name in NODE_CLASS_MAPPINGS}
 
1
+ import os
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+
10
# Salia utils (same style as your loader node)
try:
    from ..utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
except Exception:
    # Fallback if the file sits at a different package depth.
    try:
        from .utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
    except Exception as e:
        _UTILS_IMPORT_ERR = e

        def _missing(*args, **kwargs):
            # Stub bound to all four helper names below: fails loudly on
            # first use instead of breaking the whole module at import time.
            raise ImportError(
                "Could not import Salia utils (list_pngs/load_image_from_assets/file_hash/safe_path). "
                "Place this node file in the same package layout as your other Salia nodes.\n"
                f"Original import error: {_UTILS_IMPORT_ERR}"
            )

        list_pngs = _missing
        load_image_from_assets = _missing
        file_hash = _missing
        safe_path = _missing
31
+
32
+
33
+ # -----------------------------
34
+ # Helpers (IMAGE)
35
+ # -----------------------------
36
+
37
+ def _as_image(img: torch.Tensor) -> torch.Tensor:
38
+ # ComfyUI IMAGE is usually [B,H,W,C]
39
+ if not isinstance(img, torch.Tensor):
40
+ raise TypeError("IMAGE must be a torch.Tensor")
41
+ if img.dim() != 4:
42
+ raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
43
+ if img.shape[-1] not in (3, 4):
44
+ raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}")
45
+ return img
46
+
47
+
48
+ def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
49
+ """
50
+ Crops [x,y] top-left, size w*h. If out of bounds, pads with zeros.
51
+ image: [B,H,W,C]
52
+ returns: [B,h,w,C]
53
+ """
54
+ image = _as_image(image)
55
+ B, H, W, C = image.shape
56
+ w = max(1, int(w))
57
+ h = max(1, int(h))
58
+ x = int(x)
59
+ y = int(y)
60
+
61
+ out = torch.zeros((B, h, w, C), device=image.device, dtype=image.dtype)
62
+
63
+ # intersection in source
64
+ x0s = max(0, x)
65
+ y0s = max(0, y)
66
+ x1s = min(W, x + w)
67
+ y1s = min(H, y + h)
68
+
69
+ if x1s <= x0s or y1s <= y0s:
70
+ return out
71
+
72
+ # destination offsets
73
+ x0d = x0s - x
74
+ y0d = y0s - y
75
+ x1d = x0d + (x1s - x0s)
76
+ y1d = y0d + (y1s - y0s)
77
+
78
+ out[:, y0d:y1d, x0d:x1d, :] = image[:, y0s:y1s, x0s:x1s, :]
79
+ return out
80
+
81
+
82
+ def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
83
+ """
84
+ img: [B,H,W,C] where C is 3 or 4
85
+ returns RGBA [B,H,W,4]
86
+ """
87
+ img = _as_image(img)
88
+ if img.shape[-1] == 4:
89
+ return img
90
+ # RGB -> RGBA with alpha=1
91
+ B, H, W, _ = img.shape
92
+ alpha = torch.ones((B, H, W, 1), device=img.device, dtype=img.dtype)
93
+ return torch.cat([img, alpha], dim=-1)
94
+
95
+
96
+ def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
97
+ """
98
+ Places overlay at canvas pixel position (x,y) top-left corner.
99
+ Supports RGB/RGBA for both. Uses alpha-over if overlay has alpha or canvas has alpha.
100
+ Returns same channel count as canvas (3->3, 4->4).
101
+ """
102
+ overlay = _as_image(overlay)
103
+ canvas = _as_image(canvas)
104
+
105
+ # Simple batch handling (Comfy usually matches batches, but allow 1->N)
106
+ if overlay.shape[0] != canvas.shape[0]:
107
+ if overlay.shape[0] == 1 and canvas.shape[0] > 1:
108
+ overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
109
+ elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
110
+ canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
111
+ else:
112
+ raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")
113
+
114
+ B, Hc, Wc, Cc = canvas.shape
115
+ _, Ho, Wo, _ = overlay.shape
116
+
117
+ x = int(x)
118
+ y = int(y)
119
+
120
+ out = canvas.clone()
121
+
122
+ # intersection on canvas
123
+ x0c = max(0, x)
124
+ y0c = max(0, y)
125
+ x1c = min(Wc, x + Wo)
126
+ y1c = min(Hc, y + Ho)
127
+
128
+ if x1c <= x0c or y1c <= y0c:
129
+ return out
130
+
131
+ # corresponding region on overlay
132
+ x0o = x0c - x
133
+ y0o = y0c - y
134
+ x1o = x0o + (x1c - x0c)
135
+ y1o = y0o + (y1c - y0c)
136
+
137
+ canvas_region = out[:, y0c:y1c, x0c:x1c, :]
138
+ overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]
139
+
140
+ # Convert both regions to RGBA for compositing
141
+ canvas_rgba = _ensure_rgba(canvas_region)
142
+ overlay_rgba = _ensure_rgba(overlay_region)
143
+
144
+ over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
145
+ over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)
146
+
147
+ under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
148
+ under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)
149
+
150
+ # Premultiplied alpha composite: out = over + under*(1-over_a)
151
+ over_pm = over_rgb * over_a
152
+ under_pm = under_rgb * under_a
153
+
154
+ out_a = over_a + under_a * (1.0 - over_a)
155
+ out_pm = over_pm + under_pm * (1.0 - over_a)
156
+
157
+ eps = 1e-6
158
+ out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
159
+ out_rgb = out_rgb.clamp(0.0, 1.0)
160
+ out_a = out_a.clamp(0.0, 1.0)
161
+
162
+ if Cc == 3:
163
+ out[:, y0c:y1c, x0c:x1c, :] = out_rgb
164
+ else:
165
+ out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)
166
+
167
+ return out
168
+
169
+
170
+ # -----------------------------
171
+ # RMBG EXACT MASK COMBINE LOGIC (copied solution)
172
+ # -----------------------------
173
+
174
+ class _AILab_MaskCombiner_Exact:
175
+ def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
176
+ try:
177
+ masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
178
+
179
+ if len(masks) <= 1:
180
+ return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)
181
+
182
+ ref_shape = masks[0].shape
183
+ masks = [self._resize_if_needed(m, ref_shape) for m in masks]
184
+
185
+ if mode == "combine":
186
+ result = torch.maximum(masks[0], masks[1])
187
+ for mask in masks[2:]:
188
+ result = torch.maximum(result, mask)
189
+ elif mode == "intersection":
190
+ result = torch.minimum(masks[0], masks[1])
191
+ else:
192
+ result = torch.abs(masks[0] - masks[1])
193
+
194
+ return (torch.clamp(result, 0, 1),)
195
+ except Exception as e:
196
+ print(f"Error in combine_masks: {str(e)}")
197
+ print(f"Mask shapes: {[m.shape for m in masks]}")
198
+ raise e
199
+
200
+ def _resize_if_needed(self, mask, target_shape):
201
+ try:
202
+ if mask.shape == target_shape:
203
+ return mask
204
+
205
+ if len(mask.shape) == 2:
206
+ mask = mask.unsqueeze(0)
207
+ elif len(mask.shape) == 4:
208
+ mask = mask.squeeze(1)
209
+
210
+ target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
211
+ target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]
212
+
213
+ resized_masks = []
214
+ for i in range(mask.shape[0]):
215
+ mask_np = mask[i].cpu().numpy()
216
+ img = Image.fromarray((mask_np * 255).astype(np.uint8))
217
+ img_resized = img.resize((target_width, target_height), Image.LANCZOS)
218
+ mask_resized = np.array(img_resized).astype(np.float32) / 255.0
219
+ resized_masks.append(torch.from_numpy(mask_resized))
220
+
221
+ return torch.stack(resized_masks)
222
+
223
+ except Exception as e:
224
+ print(f"Error in _resize_if_needed: {str(e)}")
225
+ print(f"Input mask shape: {mask.shape}, Target shape: {target_shape}")
226
+ raise e
227
+
228
+
229
# -----------------------------
# 1) Cropout_Square_From_IMG
# -----------------------------

class Cropout_Square_From_IMG:
    """Crop a square region (zero-padded when out of bounds) from an IMAGE."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        coord_opts = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", dict(coord_opts)),
                "y": ("INT", dict(coord_opts)),
                "square_size": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
            }
        }

    def run(self, img, x, y, square_size):
        """Return the square_size*square_size crop whose top-left corner is (x, y)."""
        return (_crop_with_padding(img, x, y, square_size, square_size),)
254
+
255
+
256
# -----------------------------
# 2) Cropout_Rect_From_IMG
# -----------------------------

class Cropout_Rect_From_IMG:
    """Crop a width*height rectangle (zero-padded when out of bounds) from an IMAGE."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        coord_opts = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        size_opts = {"default": 512, "min": 1, "max": 16384, "step": 1}
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", dict(coord_opts)),
                "y": ("INT", dict(coord_opts)),
                "width": ("INT", dict(size_opts)),
                "height": ("INT", dict(size_opts)),
            }
        }

    def run(self, img, x, y, width, height):
        """Return the width*height crop whose top-left corner is (x, y)."""
        return (_crop_with_padding(img, x, y, width, height),)
282
+
283
+
284
# -----------------------------
# 3) Paste_rect_to_img
# -----------------------------

class Paste_rect_to_img:
    """Alpha-composite an overlay IMAGE onto a canvas IMAGE at (x, y)."""

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        coord_opts = {"default": 0, "min": -100000, "max": 100000, "step": 1}
        return {
            "required": {
                "overlay": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", dict(coord_opts)),
                "y": ("INT", dict(coord_opts)),
            }
        }

    def run(self, overlay, canvas, x, y):
        return (_alpha_over_region(overlay, canvas, x, y),)
309
+
310
+
311
# -----------------------------
# 4) Combine_2_masks (RMBG exact: torch.maximum + PIL resize)
# -----------------------------

class Combine_2_masks:
    """Union of two masks (element-wise maximum), via the RMBG-exact combiner."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        merged = _AILab_MaskCombiner_Exact().combine_masks(maskA, mode="combine", mask_2=maskB)
        return (merged[0],)
330
+
331
+
332
# -----------------------------
# 5) Combine_2_masks_invert_1 (invert A then RMBG combine)
# -----------------------------

class Combine_2_masks_invert_1:
    """Invert maskA, then max-combine it with maskB (RMBG-exact combine)."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        inverted_a = 1.0 - maskA
        merged = _AILab_MaskCombiner_Exact().combine_masks(inverted_a, mode="combine", mask_2=maskB)
        return (merged[0],)
352
+
353
+
354
# -----------------------------
# 6) Combine_2_masks_inverse
# invert both, combine, invert result (RMBG max logic)
# -----------------------------

class Combine_2_masks_inverse:
    """Invert both masks, max-combine them, then invert the result.

    By De Morgan this behaves like an element-wise minimum of the inputs.
    """

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        merged = _AILab_MaskCombiner_Exact().combine_masks(
            1.0 - maskA, mode="combine", mask_2=1.0 - maskB
        )[0]
        return (torch.clamp(1.0 - merged, 0, 1),)
378
+
379
+
380
# -----------------------------
# 7) combine_masks_with_loaded (RMBG exact combine)
# -----------------------------

class combine_masks_with_loaded:
    """Union (max) of an input MASK with the INVERTED mask channel of a bundled asset PNG."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of asset PNGs; a sentinel entry keeps the widget valid when empty.
        available = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (available, {}),
            }
        }

    def run(self, mask, image):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        _img, loaded_mask = load_image_from_assets(image)

        # The loaded asset mask is inverted before the max-combine.
        result = _AILab_MaskCombiner_Exact().combine_masks(mask, mode="combine", mask_2=1.0 - loaded_mask)
        return (result[0],)

    @classmethod
    def IS_CHANGED(cls, mask, image):
        # Hash the asset file so on-disk edits invalidate ComfyUI's cache.
        return image if image == "<no pngs found>" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            resolved = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(resolved):
            return f"File not found in assets/images: {image}"
        return True
428
+
429
+
430
# -----------------------------
# 8) NEW: invert input mask, combine with loaded mask, apply to image alpha, paste on canvas
# -----------------------------

class apply_segment:
    """ComfyUI IMAGE node.

    Pipeline: invert the input MASK, combine it with a PNG mask loaded from
    assets (RMBG exact "combine" mode), apply the combined mask as the alpha
    channel of ``img``, then composite that RGBA overlay onto ``canvas`` at
    (x, y) via ``_alpha_over_region``.
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        # Dropdown of PNGs in assets/images; sentinel keeps the list non-empty.
        choices = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),  # dropdown asset (used ONLY for loaded mask)
                "img": ("IMAGE",),  # the image to receive final_mask as alpha (overlay source)
                "canvas": ("IMAGE",),  # destination
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Build the masked overlay from ``img`` and paste it onto ``canvas``.

        Raises FileNotFoundError when no asset PNG is available.
        Returns a 1-tuple with the composited IMAGE tensor [B,H,W,C].
        """
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Load asset mask (do NOT invert) — the asset's image part is unused.
        _img_asset, loaded_mask = load_image_from_assets(image)

        # Invert input mask, then combine with loaded mask (RMBG exact combine => maximum)
        inv_mask = 1.0 - mask
        final_mask, = combiner.combine_masks(inv_mask, mode="combine", mask_2=loaded_mask)

        # Apply final_mask as alpha to input image -> final_overlay (RGBA)
        img = _as_image(img)  # validates [B,H,W,C] with C in (3, 4)
        B, H, W, C = img.shape

        # Resize final_mask to match img H/W if needed (uses RMBG exact resize helper)
        # (target_shape must look like a mask shape [B,H,W], but resize keeps its own batch count)
        final_mask_resized = combiner._resize_if_needed(final_mask, (final_mask.shape[0], H, W))

        # Batch match (simple 1->N expansion only); expand() creates views,
        # no copy. Any other mismatch is rejected outright.
        if final_mask_resized.shape[0] != B:
            if final_mask_resized.shape[0] == 1 and B > 1:
                final_mask_resized = final_mask_resized.expand(B, H, W)
            elif B == 1 and final_mask_resized.shape[0] > 1:
                img = img.expand(final_mask_resized.shape[0], *img.shape[1:])
                B = img.shape[0]
            else:
                raise ValueError(f"Batch mismatch: img batch={B}, final_mask batch={final_mask_resized.shape[0]}")

        if C == 3:
            # RGB -> RGBA with alpha = final_mask (matched to img device/dtype)
            alpha = final_mask_resized.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([img, alpha.unsqueeze(-1)], dim=-1)
        else:
            # RGBA: combine existing alpha with final_mask using RMBG combine (maximum)
            rgb = img[..., :3]
            alpha_img = img[..., 3]  # [B,H,W]

            # RMBG combine uses PIL-resize sometimes, so keep combine inputs on CPU
            a1 = alpha_img.detach().cpu()
            a2 = final_mask_resized.detach().cpu()
            combined_alpha, = combiner.combine_masks(a1, mode="combine", mask_2=a2)

            # Move back to the overlay image's device/dtype before concatenating.
            combined_alpha = combined_alpha.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([rgb, combined_alpha.unsqueeze(-1)], dim=-1)

        # Paste final_overlay onto canvas at (x,y); overlay follows the
        # canvas's device/dtype so the compositing helper sees matched tensors.
        canvas = _as_image(canvas)
        final_overlay = final_overlay.to(device=canvas.device, dtype=canvas.dtype)

        out = _alpha_over_region(final_overlay, canvas, x, y)
        return (out,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        # Only the asset file participates in the change hash; the tensor
        # inputs are tracked by the graph itself.
        if image == "<no pngs found>":
            return image
        return file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        # Return True when valid, otherwise an error string for the UI.
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
527
+
528
+
529
# -----------------------------
# Node mappings
# -----------------------------

# Registry handed to ComfyUI; each key deliberately equals the class's own
# name, so build the mapping from the class objects themselves.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        Cropout_Square_From_IMG,
        Cropout_Rect_From_IMG,
        Paste_rect_to_img,
        Combine_2_masks,
        Combine_2_masks_invert_1,
        Combine_2_masks_inverse,
        combine_masks_with_loaded,
        apply_segment,
    )
}
543
+
544
# Display names mirror the registry keys one-to-one (UI label == node id).
NODE_DISPLAY_NAME_MAPPINGS = {node_id: node_id for node_id in NODE_CLASS_MAPPINGS}