saliacoel committed on
Commit
067dbfa
·
verified ·
1 Parent(s): eb0d4e7

Upload Salia_Croppytools.py

Browse files
Files changed (1) hide show
  1. Salia_Croppytools.py +193 -100
Salia_Croppytools.py CHANGED
@@ -3,6 +3,9 @@ from typing import Tuple
3
 
4
  import torch
5
  import torch.nn.functional as F
 
 
 
6
 
7
  # Salia utils (same style as your loader node)
8
  try:
@@ -28,7 +31,7 @@ except Exception:
28
 
29
 
30
  # -----------------------------
31
- # Helpers
32
  # -----------------------------
33
 
34
  def _as_image(img: torch.Tensor) -> torch.Tensor:
@@ -42,48 +45,6 @@ def _as_image(img: torch.Tensor) -> torch.Tensor:
42
  return img
43
 
44
 
45
- def _as_mask(msk: torch.Tensor) -> torch.Tensor:
46
- # ComfyUI MASK is usually [B,H,W] float 0..1
47
- if not isinstance(msk, torch.Tensor):
48
- raise TypeError("MASK must be a torch.Tensor")
49
- if msk.dim() == 2:
50
- msk = msk.unsqueeze(0)
51
- if msk.dim() != 3:
52
- raise ValueError(f"Expected MASK shape [B,H,W] (or [H,W]), got {tuple(msk.shape)}")
53
- return msk
54
-
55
-
56
- def _match_batch(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
57
- ba = a.shape[0]
58
- bb = b.shape[0]
59
- if ba == bb:
60
- return a, b
61
- if ba == 1 and bb > 1:
62
- return a.expand(bb, *a.shape[1:]), b
63
- if bb == 1 and ba > 1:
64
- return a, b.expand(ba, *b.shape[1:])
65
- raise ValueError(f"Batch mismatch: A has batch {ba}, B has batch {bb} (and neither is 1).")
66
-
67
-
68
- def _resize_mask_to(msk: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
69
- # msk: [B,H,W] -> resize to [B,target_h,target_w]
70
- if msk.shape[1] == target_h and msk.shape[2] == target_w:
71
- return msk
72
- x = msk.unsqueeze(1) # [B,1,H,W]
73
- x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)
74
- return x.squeeze(1)
75
-
76
-
77
- def _combine_alpha_union(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
78
- """
79
- "Alpha combine" (union) like standard alpha coverage:
80
- out = 1 - (1-a)*(1-b)
81
- """
82
- a = a.clamp(0.0, 1.0)
83
- b = b.clamp(0.0, 1.0)
84
- return (1.0 - (1.0 - a) * (1.0 - b)).clamp(0.0, 1.0)
85
-
86
-
87
  def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
88
  """
89
  Crops [x,y] top-left, size w*h. If out of bounds, pads with zeros.
@@ -141,10 +102,17 @@ def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: i
141
  overlay = _as_image(overlay)
142
  canvas = _as_image(canvas)
143
 
144
- overlay, canvas = _match_batch(overlay, canvas)
 
 
 
 
 
 
 
145
 
146
  B, Hc, Wc, Cc = canvas.shape
147
- Bo, Ho, Wo, Co = overlay.shape
148
 
149
  x = int(x)
150
  y = int(y)
@@ -199,6 +167,65 @@ def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: i
199
  return out
200
 
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # -----------------------------
203
  # 1) Cropout_Square_From_IMG
204
  # -----------------------------
@@ -282,7 +309,7 @@ class Paste_rect_to_img:
282
 
283
 
284
  # -----------------------------
285
- # 4) Combine_2_masks
286
  # -----------------------------
287
 
288
  class Combine_2_masks:
@@ -290,30 +317,20 @@ class Combine_2_masks:
290
 
291
  @classmethod
292
  def INPUT_TYPES(cls):
293
- return {
294
- "required": {
295
- "maskA": ("MASK",),
296
- "maskB": ("MASK",),
297
- }
298
- }
299
 
300
  RETURN_TYPES = ("MASK",)
301
  RETURN_NAMES = ("mask",)
302
  FUNCTION = "run"
303
 
304
  def run(self, maskA, maskB):
305
- a = _as_mask(maskA)
306
- b = _as_mask(maskB)
307
-
308
- a, b = _match_batch(a, b)
309
- b = _resize_mask_to(b, a.shape[1], a.shape[2])
310
-
311
- out = _combine_alpha_union(a, b)
312
  return (out,)
313
 
314
 
315
  # -----------------------------
316
- # 5) Combine_2_masks_invert_1
317
  # -----------------------------
318
 
319
  class Combine_2_masks_invert_1:
@@ -321,31 +338,22 @@ class Combine_2_masks_invert_1:
321
 
322
  @classmethod
323
  def INPUT_TYPES(cls):
324
- return {
325
- "required": {
326
- "maskA": ("MASK",),
327
- "maskB": ("MASK",),
328
- }
329
- }
330
 
331
  RETURN_TYPES = ("MASK",)
332
  RETURN_NAMES = ("mask",)
333
  FUNCTION = "run"
334
 
335
  def run(self, maskA, maskB):
336
- a = _as_mask(maskA)
337
- b = _as_mask(maskB)
338
-
339
- a, b = _match_batch(a, b)
340
- b = _resize_mask_to(b, a.shape[1], a.shape[2])
341
-
342
- a_inv = (1.0 - a).clamp(0.0, 1.0)
343
- out = _combine_alpha_union(a_inv, b)
344
  return (out,)
345
 
346
 
347
  # -----------------------------
348
  # 6) Combine_2_masks_inverse
 
349
  # -----------------------------
350
 
351
  class Combine_2_masks_inverse:
@@ -353,34 +361,24 @@ class Combine_2_masks_inverse:
353
 
354
  @classmethod
355
  def INPUT_TYPES(cls):
356
- return {
357
- "required": {
358
- "maskA": ("MASK",),
359
- "maskB": ("MASK",),
360
- }
361
- }
362
 
363
  RETURN_TYPES = ("MASK",)
364
  RETURN_NAMES = ("mask",)
365
  FUNCTION = "run"
366
 
367
  def run(self, maskA, maskB):
368
- a = _as_mask(maskA)
369
- b = _as_mask(maskB)
370
-
371
- a, b = _match_batch(a, b)
372
- b = _resize_mask_to(b, a.shape[1], a.shape[2])
373
-
374
- a_inv = (1.0 - a).clamp(0.0, 1.0)
375
- b_inv = (1.0 - b).clamp(0.0, 1.0)
376
-
377
- combined_inv = _combine_alpha_union(a_inv, b_inv)
378
- out = (1.0 - combined_inv).clamp(0.0, 1.0) # == a*b (intersection)
379
  return (out,)
380
 
381
 
382
  # -----------------------------
383
- # 7) combine_masks_with_loaded
384
  # -----------------------------
385
 
386
  class combine_masks_with_loaded:
@@ -404,16 +402,10 @@ class combine_masks_with_loaded:
404
  if image == "<no pngs found>":
405
  raise FileNotFoundError("No PNGs in assets/images")
406
 
407
- base = _as_mask(mask)
408
-
409
- # Load image+mask from assets (Salia util)
410
  _img, loaded_mask = load_image_from_assets(image)
411
- loaded = _as_mask(loaded_mask)
412
 
413
- base, loaded = _match_batch(base, loaded)
414
- loaded = _resize_mask_to(loaded, base.shape[1], base.shape[2])
415
-
416
- out = _combine_alpha_union(base, loaded)
417
  return (out,)
418
 
419
  @classmethod
@@ -435,6 +427,105 @@ class combine_masks_with_loaded:
435
  return True
436
 
437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  # -----------------------------
439
  # Node mappings
440
  # -----------------------------
@@ -447,6 +538,7 @@ NODE_CLASS_MAPPINGS = {
447
  "Combine_2_masks_invert_1": Combine_2_masks_invert_1,
448
  "Combine_2_masks_inverse": Combine_2_masks_inverse,
449
  "combine_masks_with_loaded": combine_masks_with_loaded,
 
450
  }
451
 
452
  NODE_DISPLAY_NAME_MAPPINGS = {
@@ -457,4 +549,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
457
  "Combine_2_masks_invert_1": "Combine_2_masks_invert_1",
458
  "Combine_2_masks_inverse": "Combine_2_masks_inverse",
459
  "combine_masks_with_loaded": "combine_masks_with_loaded",
 
460
  }
 
3
 
4
  import torch
5
  import torch.nn.functional as F
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
 
10
  # Salia utils (same style as your loader node)
11
  try:
 
31
 
32
 
33
  # -----------------------------
34
+ # Helpers (IMAGE)
35
  # -----------------------------
36
 
37
  def _as_image(img: torch.Tensor) -> torch.Tensor:
 
45
  return img
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
49
  """
50
  Crops [x,y] top-left, size w*h. If out of bounds, pads with zeros.
 
102
  overlay = _as_image(overlay)
103
  canvas = _as_image(canvas)
104
 
105
+ # Simple batch handling (Comfy usually matches batches, but allow 1->N)
106
+ if overlay.shape[0] != canvas.shape[0]:
107
+ if overlay.shape[0] == 1 and canvas.shape[0] > 1:
108
+ overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
109
+ elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
110
+ canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
111
+ else:
112
+ raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")
113
 
114
  B, Hc, Wc, Cc = canvas.shape
115
+ _, Ho, Wo, _ = overlay.shape
116
 
117
  x = int(x)
118
  y = int(y)
 
167
  return out
168
 
169
 
170
+ # -----------------------------
171
+ # RMBG EXACT MASK COMBINE LOGIC (copied solution)
172
+ # -----------------------------
173
+
174
+ class _AILab_MaskCombiner_Exact:
175
+ def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
176
+ try:
177
+ masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]
178
+
179
+ if len(masks) <= 1:
180
+ return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)
181
+
182
+ ref_shape = masks[0].shape
183
+ masks = [self._resize_if_needed(m, ref_shape) for m in masks]
184
+
185
+ if mode == "combine":
186
+ result = torch.maximum(masks[0], masks[1])
187
+ for mask in masks[2:]:
188
+ result = torch.maximum(result, mask)
189
+ elif mode == "intersection":
190
+ result = torch.minimum(masks[0], masks[1])
191
+ else:
192
+ result = torch.abs(masks[0] - masks[1])
193
+
194
+ return (torch.clamp(result, 0, 1),)
195
+ except Exception as e:
196
+ print(f"Error in combine_masks: {str(e)}")
197
+ print(f"Mask shapes: {[m.shape for m in masks]}")
198
+ raise e
199
+
200
+ def _resize_if_needed(self, mask, target_shape):
201
+ try:
202
+ if mask.shape == target_shape:
203
+ return mask
204
+
205
+ if len(mask.shape) == 2:
206
+ mask = mask.unsqueeze(0)
207
+ elif len(mask.shape) == 4:
208
+ mask = mask.squeeze(1)
209
+
210
+ target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
211
+ target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]
212
+
213
+ resized_masks = []
214
+ for i in range(mask.shape[0]):
215
+ mask_np = mask[i].cpu().numpy()
216
+ img = Image.fromarray((mask_np * 255).astype(np.uint8))
217
+ img_resized = img.resize((target_width, target_height), Image.LANCZOS)
218
+ mask_resized = np.array(img_resized).astype(np.float32) / 255.0
219
+ resized_masks.append(torch.from_numpy(mask_resized))
220
+
221
+ return torch.stack(resized_masks)
222
+
223
+ except Exception as e:
224
+ print(f"Error in _resize_if_needed: {str(e)}")
225
+ print(f"Input mask shape: {mask.shape}, Target shape: {target_shape}")
226
+ raise e
227
+
228
+
229
  # -----------------------------
230
  # 1) Cropout_Square_From_IMG
231
  # -----------------------------
 
309
 
310
 
311
  # -----------------------------
312
+ # 4) Combine_2_masks (RMBG exact: torch.maximum + PIL resize)
313
  # -----------------------------
314
 
315
  class Combine_2_masks:
 
317
 
318
  @classmethod
319
  def INPUT_TYPES(cls):
320
+ return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}
 
 
 
 
 
321
 
322
  RETURN_TYPES = ("MASK",)
323
  RETURN_NAMES = ("mask",)
324
  FUNCTION = "run"
325
 
326
  def run(self, maskA, maskB):
327
+ combiner = _AILab_MaskCombiner_Exact()
328
+ out, = combiner.combine_masks(maskA, mode="combine", mask_2=maskB)
 
 
 
 
 
329
  return (out,)
330
 
331
 
332
  # -----------------------------
333
+ # 5) Combine_2_masks_invert_1 (invert A then RMBG combine)
334
  # -----------------------------
335
 
336
  class Combine_2_masks_invert_1:
 
338
 
339
  @classmethod
340
  def INPUT_TYPES(cls):
341
+ return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}
 
 
 
 
 
342
 
343
  RETURN_TYPES = ("MASK",)
344
  RETURN_NAMES = ("mask",)
345
  FUNCTION = "run"
346
 
347
  def run(self, maskA, maskB):
348
+ combiner = _AILab_MaskCombiner_Exact()
349
+ maskA = 1.0 - maskA
350
+ out, = combiner.combine_masks(maskA, mode="combine", mask_2=maskB)
 
 
 
 
 
351
  return (out,)
352
 
353
 
354
  # -----------------------------
355
  # 6) Combine_2_masks_inverse
356
+ # invert both, combine, invert result (RMBG max logic)
357
  # -----------------------------
358
 
359
  class Combine_2_masks_inverse:
 
361
 
362
  @classmethod
363
  def INPUT_TYPES(cls):
364
+ return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}
 
 
 
 
 
365
 
366
  RETURN_TYPES = ("MASK",)
367
  RETURN_NAMES = ("mask",)
368
  FUNCTION = "run"
369
 
370
  def run(self, maskA, maskB):
371
+ combiner = _AILab_MaskCombiner_Exact()
372
+ maskA = 1.0 - maskA
373
+ maskB = 1.0 - maskB
374
+ combined, = combiner.combine_masks(maskA, mode="combine", mask_2=maskB)
375
+ out = 1.0 - combined
376
+ out = torch.clamp(out, 0, 1)
 
 
 
 
 
377
  return (out,)
378
 
379
 
380
  # -----------------------------
381
+ # 7) combine_masks_with_loaded (RMBG exact combine)
382
  # -----------------------------
383
 
384
  class combine_masks_with_loaded:
 
402
  if image == "<no pngs found>":
403
  raise FileNotFoundError("No PNGs in assets/images")
404
 
 
 
 
405
  _img, loaded_mask = load_image_from_assets(image)
 
406
 
407
+ combiner = _AILab_MaskCombiner_Exact()
408
+ out, = combiner.combine_masks(mask, mode="combine", mask_2=loaded_mask)
 
 
409
  return (out,)
410
 
411
  @classmethod
 
427
  return True
428
 
429
 
430
+ # -----------------------------
431
+ # 8) NEW: invert input mask, combine with loaded mask, apply to image alpha, paste on canvas
432
+ # -----------------------------
433
+
434
class apply_segment:
    """ComfyUI node: build an alpha from masks, apply it to ``img``, paste on ``canvas``.

    Steps performed by ``run``:
      1. Invert the input ``mask`` and combine it (element-wise maximum) with
         the mask loaded from the selected asset.
      2. Use the combined mask as the alpha channel of ``img``.
      3. Paste the resulting overlay onto ``canvas`` at ``(x, y)``.
    """

    CATEGORY = "image/salia"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs; the ``image`` dropdown lists the available assets."""
        choices = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),  # dropdown asset (used ONLY for loaded mask)
                "img": ("IMAGE",),  # the image to receive final_mask as alpha (overlay source)
                "canvas": ("IMAGE",),  # destination
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    def run(self, mask, image, img, canvas, x, y):
        """Composite ``img`` (with a mask-derived alpha) onto ``canvas``.

        Args:
            mask: Input mask; it is inverted (``1.0 - mask``) before use.
            image: Asset name from the dropdown; only its mask is used.
            img: Image that receives the final mask as its alpha channel.
            canvas: Destination image.
            x, y: Paste position passed to ``_alpha_over_region``.

        Returns:
            ``(out,)`` where ``out`` is the result of ``_alpha_over_region``.

        Raises:
            FileNotFoundError: If the dropdown holds the "no pngs" placeholder.
            ValueError: If the batch sizes of ``img`` and the final mask differ
                and neither is 1.
        """
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        combiner = _AILab_MaskCombiner_Exact()

        # Load asset mask (do NOT invert)
        _img_asset, loaded_mask = load_image_from_assets(image)

        # Invert input mask, then combine with loaded mask (RMBG exact combine => maximum)
        inv_mask = 1.0 - mask
        final_mask, = combiner.combine_masks(inv_mask, mode="combine", mask_2=loaded_mask)

        # Apply final_mask as alpha to input image -> final_overlay (RGBA)
        img = _as_image(img)
        B, H, W, C = img.shape

        # Resize final_mask to match img H/W if needed (uses RMBG exact resize helper).
        # The target shape is mask-like [B,H,W]; the helper keeps the mask's own batch count.
        final_mask_resized = combiner._resize_if_needed(final_mask, (final_mask.shape[0], H, W))

        # Batch match (simple 1->N expansion only)
        if final_mask_resized.shape[0] != B:
            if final_mask_resized.shape[0] == 1 and B > 1:
                final_mask_resized = final_mask_resized.expand(B, H, W)
            elif B == 1 and final_mask_resized.shape[0] > 1:
                img = img.expand(final_mask_resized.shape[0], *img.shape[1:])
                B = img.shape[0]
            else:
                raise ValueError(f"Batch mismatch: img batch={B}, final_mask batch={final_mask_resized.shape[0]}")

        if C == 3:
            # RGB -> RGBA with alpha = final_mask
            alpha = final_mask_resized.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([img, alpha.unsqueeze(-1)], dim=-1)
        else:
            # RGBA: combine existing alpha with final_mask using RMBG combine (maximum).
            # NOTE(review): this branch indexes channel 3, so it assumes img has
            # at least 4 channels when C != 3 -- confirm against _as_image.
            rgb = img[..., :3]
            alpha_img = img[..., 3]  # [B,H,W]

            # RMBG combine uses PIL-resize sometimes, so keep combine inputs on CPU
            a1 = alpha_img.detach().cpu()
            a2 = final_mask_resized.detach().cpu()
            combined_alpha, = combiner.combine_masks(a1, mode="combine", mask_2=a2)

            combined_alpha = combined_alpha.to(device=img.device, dtype=img.dtype)
            final_overlay = torch.cat([rgb, combined_alpha.unsqueeze(-1)], dim=-1)

        # Paste final_overlay onto canvas at (x,y)
        canvas = _as_image(canvas)
        final_overlay = final_overlay.to(device=canvas.device, dtype=canvas.dtype)

        out = _alpha_over_region(final_overlay, canvas, x, y)
        return (out,)

    @classmethod
    def IS_CHANGED(cls, mask, image, img, canvas, x, y):
        """Return a change token derived from the selected asset only."""
        if image == "<no pngs found>":
            return image
        return file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image, img, canvas, x, y):
        """Return ``True`` if the selected asset resolves to a file, else an error string."""
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
527
+
528
+
529
  # -----------------------------
530
  # Node mappings
531
  # -----------------------------
 
538
  "Combine_2_masks_invert_1": Combine_2_masks_invert_1,
539
  "Combine_2_masks_inverse": Combine_2_masks_inverse,
540
  "combine_masks_with_loaded": combine_masks_with_loaded,
541
+ "apply_segment": apply_segment,
542
  }
543
 
544
  NODE_DISPLAY_NAME_MAPPINGS = {
 
549
  "Combine_2_masks_invert_1": "Combine_2_masks_invert_1",
550
  "Combine_2_masks_inverse": "Combine_2_masks_inverse",
551
  "combine_masks_with_loaded": "combine_masks_with_loaded",
552
+ "apply_segment": "apply_segment",
553
  }