saliacoel commited on
Commit
b0b1cf0
·
verified ·
1 Parent(s): f260c3a

Upload Salia_Croppytools.py

Browse files
Files changed (1) hide show
  1. Salia_Croppytools.py +460 -0
Salia_Croppytools.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ # Salia utils (same style as your loader node)
8
+ try:
9
+ from ..utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
10
+ except Exception:
11
+ # Fallback if you placed this file in a different package depth
12
+ try:
13
+ from .utils.io import list_pngs, load_image_from_assets, file_hash, safe_path
14
+ except Exception as e:
15
+ _UTILS_IMPORT_ERR = e
16
+
17
+ def _missing(*args, **kwargs):
18
+ raise ImportError(
19
+ "Could not import Salia utils (list_pngs/load_image_from_assets/file_hash/safe_path). "
20
+ "Place this node file in the same package layout as your other Salia nodes.\n"
21
+ f"Original import error: {_UTILS_IMPORT_ERR}"
22
+ )
23
+
24
+ list_pngs = _missing
25
+ load_image_from_assets = _missing
26
+ file_hash = _missing
27
+ safe_path = _missing
28
+
29
+
30
+ # -----------------------------
31
+ # Helpers
32
+ # -----------------------------
33
+
34
+ def _as_image(img: torch.Tensor) -> torch.Tensor:
35
+ # ComfyUI IMAGE is usually [B,H,W,C]
36
+ if not isinstance(img, torch.Tensor):
37
+ raise TypeError("IMAGE must be a torch.Tensor")
38
+ if img.dim() != 4:
39
+ raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(img.shape)}")
40
+ if img.shape[-1] not in (3, 4):
41
+ raise ValueError(f"Expected IMAGE channels 3 (RGB) or 4 (RGBA), got C={img.shape[-1]}")
42
+ return img
43
+
44
+
45
+ def _as_mask(msk: torch.Tensor) -> torch.Tensor:
46
+ # ComfyUI MASK is usually [B,H,W] float 0..1
47
+ if not isinstance(msk, torch.Tensor):
48
+ raise TypeError("MASK must be a torch.Tensor")
49
+ if msk.dim() == 2:
50
+ msk = msk.unsqueeze(0)
51
+ if msk.dim() != 3:
52
+ raise ValueError(f"Expected MASK shape [B,H,W] (or [H,W]), got {tuple(msk.shape)}")
53
+ return msk
54
+
55
+
56
+ def _match_batch(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
57
+ ba = a.shape[0]
58
+ bb = b.shape[0]
59
+ if ba == bb:
60
+ return a, b
61
+ if ba == 1 and bb > 1:
62
+ return a.expand(bb, *a.shape[1:]), b
63
+ if bb == 1 and ba > 1:
64
+ return a, b.expand(ba, *b.shape[1:])
65
+ raise ValueError(f"Batch mismatch: A has batch {ba}, B has batch {bb} (and neither is 1).")
66
+
67
+
68
+ def _resize_mask_to(msk: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
69
+ # msk: [B,H,W] -> resize to [B,target_h,target_w]
70
+ if msk.shape[1] == target_h and msk.shape[2] == target_w:
71
+ return msk
72
+ x = msk.unsqueeze(1) # [B,1,H,W]
73
+ x = F.interpolate(x, size=(target_h, target_w), mode="bilinear", align_corners=False)
74
+ return x.squeeze(1)
75
+
76
+
77
+ def _combine_alpha_union(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
78
+ """
79
+ "Alpha combine" (union) like standard alpha coverage:
80
+ out = 1 - (1-a)*(1-b)
81
+ """
82
+ a = a.clamp(0.0, 1.0)
83
+ b = b.clamp(0.0, 1.0)
84
+ return (1.0 - (1.0 - a) * (1.0 - b)).clamp(0.0, 1.0)
85
+
86
+
87
+ def _crop_with_padding(image: torch.Tensor, x: int, y: int, w: int, h: int) -> torch.Tensor:
88
+ """
89
+ Crops [x,y] top-left, size w*h. If out of bounds, pads with zeros.
90
+ image: [B,H,W,C]
91
+ returns: [B,h,w,C]
92
+ """
93
+ image = _as_image(image)
94
+ B, H, W, C = image.shape
95
+ w = max(1, int(w))
96
+ h = max(1, int(h))
97
+ x = int(x)
98
+ y = int(y)
99
+
100
+ out = torch.zeros((B, h, w, C), device=image.device, dtype=image.dtype)
101
+
102
+ # intersection in source
103
+ x0s = max(0, x)
104
+ y0s = max(0, y)
105
+ x1s = min(W, x + w)
106
+ y1s = min(H, y + h)
107
+
108
+ if x1s <= x0s or y1s <= y0s:
109
+ return out
110
+
111
+ # destination offsets
112
+ x0d = x0s - x
113
+ y0d = y0s - y
114
+ x1d = x0d + (x1s - x0s)
115
+ y1d = y0d + (y1s - y0s)
116
+
117
+ out[:, y0d:y1d, x0d:x1d, :] = image[:, y0s:y1s, x0s:x1s, :]
118
+ return out
119
+
120
+
121
+ def _ensure_rgba(img: torch.Tensor) -> torch.Tensor:
122
+ """
123
+ img: [B,H,W,C] where C is 3 or 4
124
+ returns RGBA [B,H,W,4]
125
+ """
126
+ img = _as_image(img)
127
+ if img.shape[-1] == 4:
128
+ return img
129
+ # RGB -> RGBA with alpha=1
130
+ B, H, W, _ = img.shape
131
+ alpha = torch.ones((B, H, W, 1), device=img.device, dtype=img.dtype)
132
+ return torch.cat([img, alpha], dim=-1)
133
+
134
+
135
+ def _alpha_over_region(overlay: torch.Tensor, canvas: torch.Tensor, x: int, y: int) -> torch.Tensor:
136
+ """
137
+ Places overlay at canvas pixel position (x,y) top-left corner.
138
+ Supports RGB/RGBA for both. Uses alpha-over if overlay has alpha or canvas has alpha.
139
+ Returns same channel count as canvas (3->3, 4->4).
140
+ """
141
+ overlay = _as_image(overlay)
142
+ canvas = _as_image(canvas)
143
+
144
+ overlay, canvas = _match_batch(overlay, canvas)
145
+
146
+ B, Hc, Wc, Cc = canvas.shape
147
+ Bo, Ho, Wo, Co = overlay.shape
148
+
149
+ x = int(x)
150
+ y = int(y)
151
+
152
+ out = canvas.clone()
153
+
154
+ # intersection on canvas
155
+ x0c = max(0, x)
156
+ y0c = max(0, y)
157
+ x1c = min(Wc, x + Wo)
158
+ y1c = min(Hc, y + Ho)
159
+
160
+ if x1c <= x0c or y1c <= y0c:
161
+ return out
162
+
163
+ # corresponding region on overlay
164
+ x0o = x0c - x
165
+ y0o = y0c - y
166
+ x1o = x0o + (x1c - x0c)
167
+ y1o = y0o + (y1c - y0c)
168
+
169
+ canvas_region = out[:, y0c:y1c, x0c:x1c, :]
170
+ overlay_region = overlay[:, y0o:y1o, x0o:x1o, :]
171
+
172
+ # Convert both regions to RGBA for compositing
173
+ canvas_rgba = _ensure_rgba(canvas_region)
174
+ overlay_rgba = _ensure_rgba(overlay_region)
175
+
176
+ over_rgb = overlay_rgba[..., :3].clamp(0.0, 1.0)
177
+ over_a = overlay_rgba[..., 3:4].clamp(0.0, 1.0)
178
+
179
+ under_rgb = canvas_rgba[..., :3].clamp(0.0, 1.0)
180
+ under_a = canvas_rgba[..., 3:4].clamp(0.0, 1.0)
181
+
182
+ # Premultiplied alpha composite: out = over + under*(1-over_a)
183
+ over_pm = over_rgb * over_a
184
+ under_pm = under_rgb * under_a
185
+
186
+ out_a = over_a + under_a * (1.0 - over_a)
187
+ out_pm = over_pm + under_pm * (1.0 - over_a)
188
+
189
+ eps = 1e-6
190
+ out_rgb = torch.where(out_a > eps, out_pm / (out_a + eps), torch.zeros_like(out_pm))
191
+ out_rgb = out_rgb.clamp(0.0, 1.0)
192
+ out_a = out_a.clamp(0.0, 1.0)
193
+
194
+ if Cc == 3:
195
+ out[:, y0c:y1c, x0c:x1c, :] = out_rgb
196
+ else:
197
+ out[:, y0c:y1c, x0c:x1c, :] = torch.cat([out_rgb, out_a], dim=-1)
198
+
199
+ return out
200
+
201
+
202
+ # -----------------------------
203
+ # 1) Cropout_Square_From_IMG
204
+ # -----------------------------
205
+
206
+ class Cropout_Square_From_IMG:
207
+ CATEGORY = "image/salia"
208
+
209
+ @classmethod
210
+ def INPUT_TYPES(cls):
211
+ return {
212
+ "required": {
213
+ "img": ("IMAGE",),
214
+ "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
215
+ "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
216
+ "square_size": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
217
+ }
218
+ }
219
+
220
+ RETURN_TYPES = ("IMAGE",)
221
+ RETURN_NAMES = ("image",)
222
+ FUNCTION = "run"
223
+
224
+ def run(self, img, x, y, square_size):
225
+ cropped = _crop_with_padding(img, x, y, square_size, square_size)
226
+ return (cropped,)
227
+
228
+
229
+ # -----------------------------
230
+ # 2) Cropout_Rect_From_IMG
231
+ # -----------------------------
232
+
233
+ class Cropout_Rect_From_IMG:
234
+ CATEGORY = "image/salia"
235
+
236
+ @classmethod
237
+ def INPUT_TYPES(cls):
238
+ return {
239
+ "required": {
240
+ "img": ("IMAGE",),
241
+ "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
242
+ "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
243
+ "width": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
244
+ "height": ("INT", {"default": 512, "min": 1, "max": 16384, "step": 1}),
245
+ }
246
+ }
247
+
248
+ RETURN_TYPES = ("IMAGE",)
249
+ RETURN_NAMES = ("image",)
250
+ FUNCTION = "run"
251
+
252
+ def run(self, img, x, y, width, height):
253
+ cropped = _crop_with_padding(img, x, y, width, height)
254
+ return (cropped,)
255
+
256
+
257
+ # -----------------------------
258
+ # 3) Paste_rect_to_img
259
+ # -----------------------------
260
+
261
+ class Paste_rect_to_img:
262
+ CATEGORY = "image/salia"
263
+
264
+ @classmethod
265
+ def INPUT_TYPES(cls):
266
+ return {
267
+ "required": {
268
+ "overlay": ("IMAGE",),
269
+ "canvas": ("IMAGE",),
270
+ "x": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
271
+ "y": ("INT", {"default": 0, "min": -100000, "max": 100000, "step": 1}),
272
+ }
273
+ }
274
+
275
+ RETURN_TYPES = ("IMAGE",)
276
+ RETURN_NAMES = ("image",)
277
+ FUNCTION = "run"
278
+
279
+ def run(self, overlay, canvas, x, y):
280
+ out = _alpha_over_region(overlay, canvas, x, y)
281
+ return (out,)
282
+
283
+
284
+ # -----------------------------
285
+ # 4) Combine_2_masks
286
+ # -----------------------------
287
+
288
+ class Combine_2_masks:
289
+ CATEGORY = "mask/salia"
290
+
291
+ @classmethod
292
+ def INPUT_TYPES(cls):
293
+ return {
294
+ "required": {
295
+ "maskA": ("MASK",),
296
+ "maskB": ("MASK",),
297
+ }
298
+ }
299
+
300
+ RETURN_TYPES = ("MASK",)
301
+ RETURN_NAMES = ("mask",)
302
+ FUNCTION = "run"
303
+
304
+ def run(self, maskA, maskB):
305
+ a = _as_mask(maskA)
306
+ b = _as_mask(maskB)
307
+
308
+ a, b = _match_batch(a, b)
309
+ b = _resize_mask_to(b, a.shape[1], a.shape[2])
310
+
311
+ out = _combine_alpha_union(a, b)
312
+ return (out,)
313
+
314
+
315
+ # -----------------------------
316
+ # 5) Combine_2_masks_invert_1
317
+ # -----------------------------
318
+
319
+ class Combine_2_masks_invert_1:
320
+ CATEGORY = "mask/salia"
321
+
322
+ @classmethod
323
+ def INPUT_TYPES(cls):
324
+ return {
325
+ "required": {
326
+ "maskA": ("MASK",),
327
+ "maskB": ("MASK",),
328
+ }
329
+ }
330
+
331
+ RETURN_TYPES = ("MASK",)
332
+ RETURN_NAMES = ("mask",)
333
+ FUNCTION = "run"
334
+
335
+ def run(self, maskA, maskB):
336
+ a = _as_mask(maskA)
337
+ b = _as_mask(maskB)
338
+
339
+ a, b = _match_batch(a, b)
340
+ b = _resize_mask_to(b, a.shape[1], a.shape[2])
341
+
342
+ a_inv = (1.0 - a).clamp(0.0, 1.0)
343
+ out = _combine_alpha_union(a_inv, b)
344
+ return (out,)
345
+
346
+
347
+ # -----------------------------
348
+ # 6) Combine_2_masks_inverse
349
+ # -----------------------------
350
+
351
+ class Combine_2_masks_inverse:
352
+ CATEGORY = "mask/salia"
353
+
354
+ @classmethod
355
+ def INPUT_TYPES(cls):
356
+ return {
357
+ "required": {
358
+ "maskA": ("MASK",),
359
+ "maskB": ("MASK",),
360
+ }
361
+ }
362
+
363
+ RETURN_TYPES = ("MASK",)
364
+ RETURN_NAMES = ("mask",)
365
+ FUNCTION = "run"
366
+
367
+ def run(self, maskA, maskB):
368
+ a = _as_mask(maskA)
369
+ b = _as_mask(maskB)
370
+
371
+ a, b = _match_batch(a, b)
372
+ b = _resize_mask_to(b, a.shape[1], a.shape[2])
373
+
374
+ a_inv = (1.0 - a).clamp(0.0, 1.0)
375
+ b_inv = (1.0 - b).clamp(0.0, 1.0)
376
+
377
+ combined_inv = _combine_alpha_union(a_inv, b_inv)
378
+ out = (1.0 - combined_inv).clamp(0.0, 1.0) # == a*b (intersection)
379
+ return (out,)
380
+
381
+
382
+ # -----------------------------
383
+ # 7) combine_masks_with_loaded
384
+ # -----------------------------
385
+
386
+ class combine_masks_with_loaded:
387
+ CATEGORY = "mask/salia"
388
+
389
+ @classmethod
390
+ def INPUT_TYPES(cls):
391
+ choices = list_pngs() or ["<no pngs found>"]
392
+ return {
393
+ "required": {
394
+ "mask": ("MASK",),
395
+ "image": (choices, {}),
396
+ }
397
+ }
398
+
399
+ RETURN_TYPES = ("MASK",)
400
+ RETURN_NAMES = ("mask",)
401
+ FUNCTION = "run"
402
+
403
+ def run(self, mask, image):
404
+ if image == "<no pngs found>":
405
+ raise FileNotFoundError("No PNGs in assets/images")
406
+
407
+ base = _as_mask(mask)
408
+
409
+ # Load image+mask from assets (Salia util)
410
+ _img, loaded_mask = load_image_from_assets(image)
411
+ loaded = _as_mask(loaded_mask)
412
+
413
+ base, loaded = _match_batch(base, loaded)
414
+ loaded = _resize_mask_to(loaded, base.shape[1], base.shape[2])
415
+
416
+ out = _combine_alpha_union(base, loaded)
417
+ return (out,)
418
+
419
+ @classmethod
420
+ def IS_CHANGED(cls, mask, image):
421
+ if image == "<no pngs found>":
422
+ return image
423
+ return file_hash(image)
424
+
425
+ @classmethod
426
+ def VALIDATE_INPUTS(cls, mask, image):
427
+ if image == "<no pngs found>":
428
+ return "No PNGs in assets/images"
429
+ try:
430
+ path = safe_path(image)
431
+ except Exception as e:
432
+ return str(e)
433
+ if not os.path.isfile(path):
434
+ return f"File not found in assets/images: {image}"
435
+ return True
436
+
437
+
438
+ # -----------------------------
439
+ # Node mappings
440
+ # -----------------------------
441
+
442
+ NODE_CLASS_MAPPINGS = {
443
+ "Cropout_Square_From_IMG": Cropout_Square_From_IMG,
444
+ "Cropout_Rect_From_IMG": Cropout_Rect_From_IMG,
445
+ "Paste_rect_to_img": Paste_rect_to_img,
446
+ "Combine_2_masks": Combine_2_masks,
447
+ "Combine_2_masks_invert_1": Combine_2_masks_invert_1,
448
+ "Combine_2_masks_inverse": Combine_2_masks_inverse,
449
+ "combine_masks_with_loaded": combine_masks_with_loaded,
450
+ }
451
+
452
+ NODE_DISPLAY_NAME_MAPPINGS = {
453
+ "Cropout_Square_From_IMG": "Cropout_Square_From_IMG",
454
+ "Cropout_Rect_From_IMG": "Cropout_Rect_From_IMG",
455
+ "Paste_rect_to_img": "Paste_rect_to_img",
456
+ "Combine_2_masks": "Combine_2_masks",
457
+ "Combine_2_masks_invert_1": "Combine_2_masks_invert_1",
458
+ "Combine_2_masks_inverse": "Combine_2_masks_inverse",
459
+ "combine_masks_with_loaded": "combine_masks_with_loaded",
460
+ }