saliacoel commited on
Commit
c7a0808
·
verified ·
1 Parent(s): 7049a6e

Upload 2 files

Browse files
Files changed (2) hide show
  1. Batch_Sprites_BBox_Cropper.py +129 -0
  2. Batchfilter.py +113 -0
Batch_Sprites_BBox_Cropper.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class Batch_Sprite_BBox_Cropper:
4
+ """
5
+ ComfyUI custom node:
6
+ - Takes a batch of RGBA images (or RGB+MASK).
7
+ - Alpha clamp: alpha <= (alpha_cutoff / 255) -> 0
8
+ - Computes one global bounding box of visible pixels across the entire batch
9
+ - Crops every image to the same bbox (spritesheet-safe)
10
+ """
11
+
12
+ @classmethod
13
+ def INPUT_TYPES(cls):
14
+ return {
15
+ "required": {
16
+ "images": ("IMAGE",),
17
+ # User requirement: cutoff default = 10 (out of 255)
18
+ "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
19
+ "verbose": ("BOOLEAN", {"default": True}),
20
+ },
21
+ # Optional: if you only have RGB image + separate alpha mask (common in ComfyUI)
22
+ "optional": {
23
+ "mask": ("MASK",),
24
+ }
25
+ }
26
+
27
+ RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
28
+ RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
29
+ FUNCTION = "process"
30
+ CATEGORY = "image/alpha"
31
+
32
+ def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
33
+ """
34
+ images: torch tensor [B, H, W, C] typically float32 in [0,1]
35
+ mask: torch tensor [B, H, W] or [H, W] in [0,1] (optional)
36
+ """
37
+
38
+ if not isinstance(images, torch.Tensor):
39
+ raise TypeError("images must be a torch.Tensor")
40
+
41
+ if images.dim() != 4:
42
+ raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")
43
+
44
+ B, H, W, C = images.shape
45
+
46
+ # Build RGBA tensor
47
+ if C == 4:
48
+ rgba = images.clone()
49
+ elif C == 3:
50
+ if mask is None:
51
+ raise ValueError(
52
+ "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
53
+ )
54
+ # Normalize mask to [B,H,W]
55
+ if mask.dim() == 2:
56
+ mask_b = mask.unsqueeze(0).expand(B, -1, -1)
57
+ elif mask.dim() == 3:
58
+ mask_b = mask
59
+ else:
60
+ raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")
61
+
62
+ if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
63
+ raise ValueError(
64
+ f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
65
+ )
66
+
67
+ rgba = torch.cat([images, mask_b.unsqueeze(-1)], dim=-1).clone()
68
+ else:
69
+ raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")
70
+
71
+ # 1) Alpha clamp: alpha <= (alpha_cutoff/255) -> 0
72
+ threshold = float(alpha_cutoff) / 255.0
73
+ alpha = rgba[..., 3]
74
+ rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)
75
+
76
+ # 2) Global bbox of visible pixels across batch (alpha > 0 after clamp)
77
+ visible = rgba[..., 3] > 0 # [B,H,W] boolean
78
+ if not torch.any(visible):
79
+ # Nothing visible after clamp; skip crop
80
+ if verbose:
81
+ print(
82
+ f"[RGBABatchAlphaClampGlobalCrop] No visible pixels after clamp "
83
+ f"(alpha_cutoff={alpha_cutoff}). Returning unchanged RGBA."
84
+ )
85
+ left = 0
86
+ top = 0
87
+ right = W - 1
88
+ bottom = H - 1
89
+ crop_w = W
90
+ crop_h = H
91
+ return (rgba, left, top, right, bottom, crop_w, crop_h)
92
+
93
+ # Union visibility across batch -> [H,W]
94
+ union = torch.any(visible, dim=0)
95
+
96
+ ys = torch.any(union, dim=1) # [H]
97
+ xs = torch.any(union, dim=0) # [W]
98
+
99
+ y_idx = torch.nonzero(ys, as_tuple=False).squeeze(1)
100
+ x_idx = torch.nonzero(xs, as_tuple=False).squeeze(1)
101
+
102
+ top = int(y_idx[0].item())
103
+ bottom = int(y_idx[-1].item())
104
+ left = int(x_idx[0].item())
105
+ right = int(x_idx[-1].item())
106
+
107
+ # 3) Crop all images to the same rect (inclusive right/bottom)
108
+ cropped = rgba[:, top:bottom + 1, left:right + 1, :]
109
+
110
+ crop_w = right - left + 1
111
+ crop_h = bottom - top + 1
112
+
113
+ if verbose:
114
+ print(
115
+ f"[RGBABatchAlphaClampGlobalCrop] alpha_cutoff={alpha_cutoff} "
116
+ f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
117
+ f"(w={crop_w}, h={crop_h}), batch={B}"
118
+ )
119
+
120
+ return (cropped, left, top, right, bottom, crop_w, crop_h)
121
+
122
+
123
+ NODE_CLASS_MAPPINGS = {
124
+ "Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper
125
+ }
126
+
127
+ NODE_DISPLAY_NAME_MAPPINGS = {
128
+ "Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper"
129
+ }
Batchfilter.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import List
3
+
4
+
5
+ class BatchFilterKeepFirstLast:
6
+ """
7
+ Batch filter node (IMAGE -> IMAGE) that always keeps the first and last image.
8
+
9
+ Modes (int):
10
+ - 0 : passthrough (no changes)
11
+ - 10 : keep 1st, 3rd, 5th, ... (drop every 2nd), but always keep last
12
+ - <10: keep slightly MORE than mode 10 (adds back frames, evenly distributed)
13
+ - >10: keep slightly FEWER than mode 10 (removes extra frames, evenly distributed)
14
+
15
+ Notes:
16
+ - In ComfyUI, IMAGE is a batch (torch.Tensor) of shape [B, H, W, C]. We only filter B. :contentReference[oaicite:2]{index=2}
17
+ - Works with RGBA (C=4) or RGB (C=3) since we do not modify channels.
18
+ """
19
+
20
+ CATEGORY = "image/batch"
21
+ RETURN_TYPES = ("IMAGE",)
22
+ RETURN_NAMES = ("images",)
23
+ FUNCTION = "filter_batch"
24
+
25
+ # How much each +/-1 step away from mode 10 adjusts the batch, as a fraction of B.
26
+ # With your reference batch size B=40:
27
+ # round(40 * 0.05) = 2 images per step (i.e., mode 9 adds ~2; mode 11 removes ~2).
28
+ ADJUST_PER_STEP_FRACTION = 0.05
29
+
30
+ @classmethod
31
+ def INPUT_TYPES(cls):
32
+ return {
33
+ "required": {
34
+ "images": ("IMAGE",),
35
+ "mode": ("INT", {"default": 10, "min": 0, "max": 20, "step": 1}),
36
+ }
37
+ }
38
+
39
+ def filter_batch(self, images: torch.Tensor, mode: int):
40
+ if not isinstance(images, torch.Tensor):
41
+ raise TypeError("images must be a torch.Tensor")
42
+
43
+ if images.ndim != 4:
44
+ raise ValueError(f"Expected images with shape [B,H,W,C], got {tuple(images.shape)}")
45
+
46
+ b = int(images.shape[0])
47
+ if b <= 1 or mode == 0:
48
+ return (images,)
49
+
50
+ # Base pattern for mode 10: keep 0,2,4,... plus always keep last.
51
+ keep = list(range(0, b, 2))
52
+ if (b - 1) not in keep:
53
+ keep.append(b - 1)
54
+ keep = sorted(set(keep))
55
+
56
+ if mode != 10:
57
+ delta = mode - 10 # <0 => keep more, >0 => keep fewer
58
+ step = max(1, int(round(b * self.ADJUST_PER_STEP_FRACTION)))
59
+
60
+ min_keep = 1 if b == 1 else 2 # first+last (or just one if batch=1)
61
+
62
+ if delta < 0:
63
+ # Add frames back from those we dropped, evenly spread.
64
+ add_count = min((-delta) * step, b - len(keep))
65
+ if add_count > 0:
66
+ candidates = [i for i in range(b) if i not in keep and i not in (0, b - 1)]
67
+ add_idxs = self._evenly_pick(candidates, add_count)
68
+ keep = sorted(set(keep + add_idxs))
69
+
70
+ elif delta > 0:
71
+ # Remove extra frames from the kept set (but never first/last), evenly spread.
72
+ max_removable = max(0, len(keep) - min_keep)
73
+ remove_count = min(delta * step, max_removable)
74
+ if remove_count > 0:
75
+ removable = [i for i in keep if i not in (0, b - 1)]
76
+ remove_idxs = set(self._evenly_pick(removable, remove_count))
77
+ keep = [i for i in keep if i not in remove_idxs]
78
+ keep = sorted(set(keep))
79
+
80
+ # Enforce rule: always keep first and last.
81
+ if 0 not in keep:
82
+ keep.insert(0, 0)
83
+ if (b - 1) not in keep:
84
+ keep.append(b - 1)
85
+ keep = sorted(set(keep))
86
+
87
+ out = images[keep, ...]
88
+ return (out,)
89
+
90
+ @staticmethod
91
+ def _evenly_pick(items: List[int], k: int) -> List[int]:
92
+ """
93
+ Pick k unique elements from items, evenly distributed across the list.
94
+ Deterministic, preserves ordering of selected indices in 'items'.
95
+ """
96
+ m = len(items)
97
+ if k <= 0 or m == 0:
98
+ return []
99
+ k = min(k, m)
100
+
101
+ # Choose k positions in (0..m-1) spread out, avoiding endpoints bias.
102
+ # This yields strictly increasing positions for k<=m.
103
+ positions = [int((i + 1) * (m + 1) / (k + 1)) - 1 for i in range(k)]
104
+ return [items[p] for p in positions]
105
+
106
+
107
+ NODE_CLASS_MAPPINGS = {
108
+ "BatchFilterKeepFirstLast": BatchFilterKeepFirstLast,
109
+ }
110
+
111
+ NODE_DISPLAY_NAME_MAPPINGS = {
112
+ "BatchFilterKeepFirstLast": "Batch Filter (Keep First/Last)",
113
+ }