saliacoel committed on
Commit
00dd500
·
verified ·
1 Parent(s): 070007a

Update Salia_Croppytools.py

Browse files
Files changed (1) hide show
  1. Salia_Croppytools.py +36 -484
Salia_Croppytools.py CHANGED
@@ -1,437 +1,16 @@
1
- import os
2
- from typing import Tuple
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- import numpy as np
7
- from PIL import Image
8
-
9
-
10
- # Salia utils (same style as your loader node)
11
- try:
12
- from ..utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
13
- except Exception:
14
- # Fallback if you placed this file in a different package depth
15
- try:
16
- from .utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
17
- except Exception as e:
18
- _UTILS_IMPORT_ERR = e
19
-
20
- def _missing(*args, **kwargs):
21
- raise ImportError(
22
- "Could not import Salia utils (list_pngs/load_image_from_assets/file_hash/safe_path). "
23
- "Place this node file in the same package layout as your other Salia nodes.\n"
24
- f"Original import error: {_UTILS_IMPORT_ERR}"
25
- )
26
-
27
- list_pngs = _missing
28
- load_image_from_assets = _missing
29
- file_hash = _missing
30
- safe_path = _missing
31
-
32
-
33
- # -----------------------------
34
- # Helpers (IMAGE)
35
- # -----------------------------
36
-
37
- def _as_image(img: torch.Tensor) -> torch.Tensor:
38
- # ComfyUI IMAGE is usually [B,H,W,C]
39
- if not isinstance(img, torch.Tensor):
40
- raise TypeError("IMAGE must be a torch.Tensor")
41
- if img.dim() != 4:
42
- raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
43
- if img.shape[-1] not in (3, 4):
44
- raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}")
45
- return img
46
-
47
-
48
def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
    """Crop a w*h window whose top-left corner is at (x, y).

    Any part of the window that falls outside the source image is filled
    with zeros, so the result always has shape [B, h, w, C].  Width/height
    are clamped to at least 1.
    """
    image = _as_image(image)
    batch, src_h, src_w, channels = image.shape
    w, h = max(1, int(w)), max(1, int(h))
    x, y = int(x), int(y)

    result = torch.zeros((batch, h, w, channels), device=image.device, dtype=image.dtype)

    # Overlap of the requested window with the source, in source coordinates.
    left, top = max(0, x), max(0, y)
    right, bottom = min(src_w, x + w), min(src_h, y + h)

    if right > left and bottom > top:
        # Where that overlap lands inside the output window.
        dst_left = left - x
        dst_top = top - y
        result[:, dst_top:dst_top + (bottom - top), dst_left:dst_left + (right - left), :] = \
            image[:, top:bottom, left:right, :]
    return result
80
-
81
-
82
def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
    """Return *img* as RGBA [B, H, W, 4].

    RGBA input is passed through untouched; RGB input gains a fully
    opaque (all-ones) alpha channel.
    """
    img = _as_image(img)
    if img.shape[-1] != 4:
        opaque = torch.ones(tuple(img.shape[:3]) + (1,), device=img.device, dtype=img.dtype)
        img = torch.cat([img, opaque], dim=-1)
    return img
94
-
95
-
96
def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
    """
    Places overlay at canvas pixel position (x,y) top-left corner.
    Supports RGB/RGBA for both. Uses alpha-over if overlay has alpha or canvas has alpha.
    Returns same channel count as canvas (3->3, 4->4).

    Out-of-bounds parts of the overlay are silently clipped; if there is no
    overlap at all, an untouched clone of the canvas is returned.
    """
    overlay = _as_image(overlay)
    canvas = _as_image(canvas)

    # Simple batch handling (Comfy usually matches batches, but allow 1->N)
    if overlay.shape[0] != canvas.shape[0]:
        if overlay.shape[0] == 1 and canvas.shape[0] > 1:
            overlay = overlay.expand(canvas.shape[0], *overlay.shape[1:])
        elif canvas.shape[0] == 1 and overlay.shape[0] > 1:
            canvas = canvas.expand(overlay.shape[0], *canvas.shape[1:])
        else:
            raise ValueError(f"Batch mismatch: overlay {overlay.shape[0]} vs canvas {canvas.shape[0]}")

    B, Hc, Wc, Cc = canvas.shape
    _, Ho, Wo, _ = overlay.shape

    x = int(x)
    y = int(y)

    # clone() also materializes an expanded (broadcast) canvas so the
    # region assignment below writes to real memory.
    out = canvas.clone()

    # intersection on canvas
    x0c = max(0, x)
    y0c = max(0, y)
    x1c = min(Wc, x + Wo)
    y1c = min(Hc, y + Ho)

    if x1c <= x0c or y1c <= y0c:
        # No overlap: nothing to composite.
        return out

    # corresponding region on overlay
    x0o = x0c - x
    y0o = y0c - y
    x1o = x0o + (x1c - x0c)
    y1o = y0o + (y1c - y0c)

    canvas_region = out[:, y0c:y1c, x0c:x1c, :]
    overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]

    # Convert both regions to RGBA for compositing (RGB gets alpha=1)
    canvas_rgba = _ensure_rgba(canvas_region)
    overlay_rgba = _ensure_rgba(overlay_region)

    over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
    over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)

    under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
    under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)

    # Premultiplied alpha composite: out = over + under*(1-over_a)
    over_pm = over_rgb * over_a
    under_pm = under_rgb * under_a

    out_a = over_a + under_a * (1.0 - over_a)
    out_pm = over_pm + under_pm * (1.0 - over_a)

    # Un-premultiply; eps guards the divide where out_a is ~0.
    # NOTE(review): dividing by (out_a + eps) darkens results by a factor of
    # out_a/(out_a+eps) — negligible for eps=1e-6, but not an exact un-premultiply.
    eps = 1e-6
    out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
    out_rgb = out_rgb.clamp(0.0, 1.0)
    out_a = out_a.clamp(0.0, 1.0)

    # Write back with the canvas's original channel count (alpha dropped for RGB).
    if Cc == 3:
        out[:, y0c:y1c, x0c:x1c, :] = out_rgb
    else:
        out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)

    return out
168
-
169
-
170
- # -----------------------------
171
- # RMBG EXACT MASK COMBINE LOGIC (copied solution)
172
- # -----------------------------
173
-
174
class _AILab_MaskCombiner_Exact:
    """
    Verbatim copy of the AILab/RMBG mask-combining logic ("copied solution"
    per the section header above). Kept behaviorally identical on purpose so
    results match the upstream RMBG nodes exactly — do not "fix" quirks here.
    """

    def combine_masks(self, mask_1, mode="combine", mask_2=None, mask_3=None, mask_4=None):
        """Combine up to four masks; returns a 1-tuple (ComfyUI node convention)."""
        try:
            masks = [m for m in [mask_1, mask_2, mask_3, mask_4] if m is not None]

            if len(masks) <= 1:
                # 0 masks -> dummy 64x64 zero mask; exactly 1 mask -> passthrough.
                return (masks[0] if masks else torch.zeros((1, 64, 64), dtype=torch.float32),)

            # Every mask is resized to the FIRST mask's shape before combining.
            ref_shape = masks[0].shape
            masks = [self._resize_if_needed(m, ref_shape) for m in masks]

            if mode == "combine":
                # Union: element-wise max across all provided masks.
                result = torch.maximum(masks[0], masks[1])
                for mask in masks[2:]:
                    result = torch.maximum(result, mask)
            elif mode == "intersection":
                # NOTE(review): only the first two masks participate here
                # (upstream behavior, left unchanged).
                result = torch.minimum(masks[0], masks[1])
            else:
                # Any other mode: absolute difference of the first two masks.
                result = torch.abs(masks[0] - masks[1])

            return (torch.clamp(result, 0, 1),)
        except Exception as e:
            # NOTE(review): if the failure happens before `masks` is bound,
            # the second print itself raises NameError (kept to stay exact).
            print(f"Error in combine_masks: {str(e)}")
            print(f"Mask shapes: {[m.shape for m in masks]}")
            raise e

    def _resize_if_needed(self, mask, target_shape):
        """Resize `mask` to target_shape's H/W via PIL LANCZOS; returns a [B,H,W] float tensor.

        The round-trip through 8-bit PIL images quantizes mask values to
        1/255 steps — intentional, to match the RMBG implementation exactly.
        """
        try:
            if mask.shape == target_shape:
                return mask

            # Normalize rank to [B,H,W] before per-item resizing.
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0)
            elif len(mask.shape) == 4:
                mask = mask.squeeze(1)

            target_height = target_shape[-2] if len(target_shape) >= 2 else target_shape[0]
            target_width = target_shape[-1] if len(target_shape) >= 2 else target_shape[1]

            resized_masks = []
            for i in range(mask.shape[0]):
                mask_np = mask[i].cpu().numpy()
                img = Image.fromarray((mask_np * 255).astype(np.uint8))
                img_resized = img.resize((target_width, target_height), Image.LANCZOS)
                mask_resized = np.array(img_resized).astype(np.float32) / 255.0
                resized_masks.append(torch.from_numpy(mask_resized))

            return torch.stack(resized_masks)

        except Exception as e:
            print(f"Error in _resize_if_needed: {str(e)}")
            print(f"Input mask shape: {mask.shape}, Target shape: {target_shape}")
            raise e
227
-
228
-
229
- # -----------------------------
230
- # 1) Cropout_Square_From_IMG
231
- # -----------------------------
232
-
233
class Cropout_Square_From_IMG:
    """ComfyUI node: crop a square_size x square_size window from an IMAGE.

    Out-of-bounds regions are zero-padded by the underlying helper, so the
    output is always exactly square_size x square_size.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "square_size": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
            }
        }

    def run(self, img, x, y, square_size):
        # A square crop is just a rectangular crop with equal sides.
        return (_crop_with_padding(img, x, y, square_size, square_size),)
254
-
255
-
256
- # -----------------------------
257
- # 2) Cropout_Rect_From_IMG
258
- # -----------------------------
259
-
260
class Cropout_Rect_From_IMG:
    """ComfyUI node: crop a width x height rectangle from an IMAGE.

    Out-of-bounds regions are zero-padded by the underlying helper, so the
    output always has the requested size.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "img": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
            }
        }

    def run(self, img, x, y, width, height):
        return (_crop_with_padding(img, x, y, width, height),)
282
-
283
-
284
- # -----------------------------
285
- # 3) Paste_rect_to_img
286
- # -----------------------------
287
-
288
class Paste_rect_to_img:
    """ComfyUI node: alpha-composite `overlay` onto `canvas` at pixel (x, y).

    Compositing is delegated to _alpha_over_region, which clips out-of-bounds
    overlay pixels and preserves the canvas's channel count.
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "overlay": ("IMAGE",),
                "canvas": ("IMAGE",),
                "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
                "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
            }
        }

    def run(self, overlay, canvas, x, y):
        return (_alpha_over_region(overlay, canvas, x, y),)
309
-
310
-
311
- # -----------------------------
312
- # 4) Combine_2_masks (RMBG exact: torch.maximum + PIL resize)
313
- # -----------------------------
314
-
315
class Combine_2_masks:
    """ComfyUI node: union of two MASKs (element-wise max) via the RMBG-exact combiner."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        result, = _AILab_MaskCombiner_Exact().combine_masks(maskA, mode="combine", mask_2=maskB)
        return (result,)
330
-
331
-
332
- # -----------------------------
333
- # 5) Combine_2_masks_invert_1 (invert A then RMBG combine)
334
- # -----------------------------
335
-
336
class Combine_2_masks_invert_1:
    """ComfyUI node: invert maskA, then union it with maskB via the RMBG-exact combiner."""

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        inverted_a = 1.0 - maskA
        result, = _AILab_MaskCombiner_Exact().combine_masks(inverted_a, mode="combine", mask_2=maskB)
        return (result,)
352
-
353
-
354
- # -----------------------------
355
- # 6) Combine_2_masks_inverse
356
- # invert both, combine, invert result (RMBG max logic)
357
- # -----------------------------
358
-
359
class Combine_2_masks_inverse:
    """ComfyUI node: invert both masks, union them (RMBG-exact max), then invert the result.

    By De Morgan, 1 - max(1-A, 1-B) equals min(A, B), i.e. this yields the
    intersection of the two masks (modulo the combiner's resize/quantization).
    """

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"maskA": ("MASK",), "maskB": ("MASK",)}}

    def run(self, maskA, maskB):
        not_a = 1.0 - maskA
        not_b = 1.0 - maskB
        merged, = _AILab_MaskCombiner_Exact().combine_masks(not_a, mode="combine", mask_2=not_b)
        return (torch.clamp(1.0 - merged, 0, 1),)
378
-
379
-
380
- # -----------------------------
381
- # 7) combine_masks_with_loaded (RMBG exact combine)
382
- # -----------------------------
383
-
384
class combine_masks_with_loaded:
    """ComfyUI node: union an input MASK with the INVERTED mask of a PNG asset.

    The dropdown is populated from assets/images via list_pngs(); the chosen
    file's mask is loaded with load_image_from_assets and inverted before the
    RMBG-exact max-combine.
    """

    CATEGORY = "mask/salia"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        # Fall back to a sentinel entry so the widget is never empty.
        choices = list_pngs() or ["<no pngs found>"]
        return {
            "required": {
                "mask": ("MASK",),
                "image": (choices, {}),
            }
        }

    def run(self, mask, image):
        if image == "<no pngs found>":
            raise FileNotFoundError("No PNGs in assets/images")

        _img, loaded_mask = load_image_from_assets(image)

        result, = _AILab_MaskCombiner_Exact().combine_masks(
            mask, mode="combine", mask_2=1.0 - loaded_mask
        )
        return (result,)

    @classmethod
    def IS_CHANGED(cls, mask, image):
        # Hash the asset so the node re-executes when the file's content changes.
        return image if image == "<no pngs found>" else file_hash(image)

    @classmethod
    def VALIDATE_INPUTS(cls, mask, image):
        if image == "<no pngs found>":
            return "No PNGs in assets/images"
        try:
            path = safe_path(image)
        except Exception as e:
            return str(e)
        if not os.path.isfile(path):
            return f"File not found in assets/images: {image}"
        return True
428
-
429
-
430
  # -----------------------------
431
- # 8) NEW: invert input mask, combine with loaded mask, apply to image alpha, paste on canvas
 
 
 
 
 
 
 
 
432
  # -----------------------------
433
 
434
- class apply_segment:
435
  CATEGORY = "image/salia"
436
 
437
  @classmethod
@@ -458,40 +37,32 @@ class apply_segment:
458
 
459
  combiner = _AILab_MaskCombiner_Exact()
460
 
461
- # -------------------
462
- # 1) invert input mask -> inverse_mask
463
- # -------------------
464
- inverse_mask = 1.0 - mask
465
- inverse_mask = torch.clamp(inverse_mask, 0.0, 1.0)
466
 
467
- # Ensure CPU so we don't ever hit CPU/GPU device mismatches with loaded masks/PIL resizing
468
- inverse_mask = inverse_mask.detach().cpu()
469
-
470
- # -------------------
471
- # 2) alpha_mask = combine_masks_with_loaded(inverse_mask)
472
- # i.e. max(inverse_mask, 1 - loaded_mask) using RMBG exact combiner
473
- # -------------------
474
  _img_asset, loaded_mask = load_image_from_assets(image)
475
- loaded_mask = loaded_mask.detach().cpu()
 
 
 
476
 
477
  alpha_mask, = combiner.combine_masks(
478
- inverse_mask,
479
  mode="combine",
480
- mask_2=(1.0 - loaded_mask)
481
  )
482
  alpha_mask = torch.clamp(alpha_mask, 0.0, 1.0)
483
 
484
- # -------------------
485
- # 3) join image with alpha using alpha_mask -> overlay (RGBA)
486
- # IMPORTANT: replace alpha, don't max() with existing alpha
487
- # -------------------
488
  img = _as_image(img)
489
  B, H, W, C = img.shape
490
 
491
- # Resize alpha_mask to match img H/W if needed (same resize logic style as RMBG)
492
  alpha_mask_resized = combiner._resize_if_needed(alpha_mask, (alpha_mask.shape[0], H, W))
493
 
494
- # Batch match (allow 1 -> N expansion)
495
  if alpha_mask_resized.shape[0] != B:
496
  if alpha_mask_resized.shape[0] == 1 and B > 1:
497
  alpha_mask_resized = alpha_mask_resized.expand(B, H, W)
@@ -505,12 +76,19 @@ class apply_segment:
505
 
506
  alpha_mask_resized = alpha_mask_resized.to(device=img.device, dtype=img.dtype).clamp(0.0, 1.0)
507
 
508
- rgb = img[..., :3]
509
- overlay = torch.cat([rgb, alpha_mask_resized.unsqueeze(-1)], dim=-1) # RGBA
 
 
 
 
 
 
510
 
511
- # -------------------
512
- # 4) paste overlay onto canvas at (x,y) -> final output
513
- # -------------------
 
514
  canvas = _as_image(canvas)
515
  overlay = overlay.to(device=canvas.device, dtype=canvas.dtype)
516
 
@@ -533,30 +111,4 @@ class apply_segment:
533
  return str(e)
534
  if not os.path.isfile(path):
535
  return f"File not found in assets/images: {image}"
536
- return True
537
-
538
- # -----------------------------
539
- # Node mappings
540
- # -----------------------------
541
-
542
# Registry consumed by ComfyUI at import time: node key -> implementing class.
NODE_CLASS_MAPPINGS = {
    "Cropout_Square_From_IMG": Cropout_Square_From_IMG,
    "Cropout_Rect_From_IMG": Cropout_Rect_From_IMG,
    "Paste_rect_to_img": Paste_rect_to_img,
    "Combine_2_masks": Combine_2_masks,
    "Combine_2_masks_invert_1": Combine_2_masks_invert_1,
    "Combine_2_masks_inverse": Combine_2_masks_inverse,
    "combine_masks_with_loaded": combine_masks_with_loaded,
    "apply_segment": apply_segment,
}

# Labels shown in the ComfyUI node picker; intentionally identical to the keys.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Cropout_Square_From_IMG": "Cropout_Square_From_IMG",
    "Cropout_Rect_From_IMG": "Cropout_Rect_From_IMG",
    "Paste_rect_to_img": "Paste_rect_to_img",
    "Combine_2_masks": "Combine_2_masks",
    "Combine_2_masks_invert_1": "Combine_2_masks_invert_1",
    "Combine_2_masks_inverse": "Combine_2_masks_inverse",
    "combine_masks_with_loaded": "combine_masks_with_loaded",
    "apply_segment": "apply_segment",
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -----------------------------
2
+ # 9) NEW: apply_segment_2
3
+ # Steps:
4
+ # 1) inverse_mask = 1 - mask
5
+ # 2) alpha_mask = combine_masks_with_loaded(inverse_mask, selected_asset)
6
+ # (i.e. max(inverse_mask, 1 - loaded_mask))
7
+ # 3) overlay = join img with alpha using alpha_mask
8
+ # - RGB: create RGBA with alpha = alpha_mask
9
+ # - RGBA: alpha_out = alpha_img * alpha_mask (more transparent, never more opaque)
10
+ # 4) paste overlay onto canvas at (x,y) using alpha-over
11
  # -----------------------------
12
 
13
+ class apply_segment_2:
14
  CATEGORY = "image/salia"
15
 
16
  @classmethod
 
37
 
38
  combiner = _AILab_MaskCombiner_Exact()
39
 
40
+ # --- Step 1: invert input mask -> inverse_mask
41
+ inverse_mask = (1.0 - mask)
 
 
 
42
 
43
+ # --- Step 2: alpha_mask = combine_masks_with_loaded(inverse_mask, image)
44
+ # combine_masks_with_loaded does: max(mask, 1-loaded_mask)
 
 
 
 
 
45
  _img_asset, loaded_mask = load_image_from_assets(image)
46
+
47
+ # Make sure both are on CPU so combiner doesn't hit device mismatch
48
+ inverse_mask_cpu = inverse_mask.detach().cpu()
49
+ loaded_mask_cpu = loaded_mask.detach().cpu()
50
 
51
  alpha_mask, = combiner.combine_masks(
52
+ inverse_mask_cpu,
53
  mode="combine",
54
+ mask_2=(1.0 - loaded_mask_cpu),
55
  )
56
  alpha_mask = torch.clamp(alpha_mask, 0.0, 1.0)
57
 
58
+ # --- Step 3: join img with alpha using alpha_mask -> overlay
 
 
 
59
  img = _as_image(img)
60
  B, H, W, C = img.shape
61
 
62
+ # Resize alpha_mask to match img H/W if needed (RMBG exact resize helper)
63
  alpha_mask_resized = combiner._resize_if_needed(alpha_mask, (alpha_mask.shape[0], H, W))
64
 
65
+ # Batch match (simple 1->N expansion only)
66
  if alpha_mask_resized.shape[0] != B:
67
  if alpha_mask_resized.shape[0] == 1 and B > 1:
68
  alpha_mask_resized = alpha_mask_resized.expand(B, H, W)
 
76
 
77
  alpha_mask_resized = alpha_mask_resized.to(device=img.device, dtype=img.dtype).clamp(0.0, 1.0)
78
 
79
+ if C == 3:
80
+ # RGB -> RGBA with alpha = alpha_mask
81
+ overlay = torch.cat([img, alpha_mask_resized.unsqueeze(-1)], dim=-1)
82
+ else:
83
+ # RGBA: DO NOT replace alpha.
84
+ # Combine to become MORE transparent: multiply existing alpha by alpha_mask.
85
+ rgb = img[..., :3]
86
+ alpha_img = img[..., 3].clamp(0.0, 1.0)
87
 
88
+ alpha_out = (alpha_img * alpha_mask_resized).clamp(0.0, 1.0)
89
+ overlay = torch.cat([rgb, alpha_out.unsqueeze(-1)], dim=-1)
90
+
91
+ # --- Step 4: paste overlay onto canvas at (x,y)
92
  canvas = _as_image(canvas)
93
  overlay = overlay.to(device=canvas.device, dtype=canvas.dtype)
94
 
 
111
  return str(e)
112
  if not os.path.isfile(path):
113
  return f"File not found in assets/images: {image}"
114
+ return True