File size: 4,809 Bytes
c7a0808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch

class Batch_Sprite_BBox_Cropper:
    """ComfyUI custom node: batch-consistent alpha clamp + global bbox crop.

    - Takes a batch of RGBA images (or RGB images plus a separate MASK).
    - Alpha clamp: any alpha <= (alpha_cutoff / 255) is forced to 0.
    - Computes ONE global bounding box of visible pixels across the whole
      batch, then crops every image to that same rectangle, so frames of a
      spritesheet stay aligned after cropping (spritesheet-safe).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                # User requirement: cutoff default = 10 (out of 255)
                "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
            # Optional: if you only have RGB image + separate alpha mask (common in ComfyUI)
            "optional": {
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
    FUNCTION = "process"
    CATEGORY = "image/alpha"

    @staticmethod
    def _build_rgba(images, mask, B, H, W, C):
        """Return a writable RGBA tensor [B,H,W,4] from images (+ optional mask).

        RGBA input is cloned (the caller mutates the alpha channel in place).
        RGB input requires a mask of shape [H,W], [1,H,W] (broadcast over the
        batch) or [B,H,W]; the mask is cast to the images' device/dtype.

        Raises:
            ValueError: on unsupported channel count, missing mask for RGB
                input, or a mask whose shape cannot match the batch.
        """
        if C == 4:
            # Clone so the in-place alpha clamp never touches the caller's tensor.
            return images.clone()

        if C != 3:
            raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")

        if mask is None:
            raise ValueError(
                "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
            )

        # Normalize mask to [B,H,W]
        if mask.dim() == 2:
            mask_b = mask.unsqueeze(0).expand(B, -1, -1)
        elif mask.dim() == 3:
            # Broadcast a single-frame mask across the whole batch (common
            # ComfyUI MASK shape [1,H,W]); otherwise require an exact match.
            mask_b = mask.expand(B, -1, -1) if mask.shape[0] == 1 and B > 1 else mask
        else:
            raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")

        if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
            raise ValueError(
                f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
            )

        # Match device/dtype so torch.cat cannot fail on mixed inputs;
        # cat allocates a fresh tensor, so no extra clone is needed.
        mask_b = mask_b.to(device=images.device, dtype=images.dtype)
        return torch.cat([images, mask_b.unsqueeze(-1)], dim=-1)

    def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
        """Clamp alpha, find the batch-global bbox, crop every frame to it.

        Args:
            images: torch tensor [B, H, W, C], typically float32 in [0, 1];
                C must be 4 (RGBA) or 3 (RGB, in which case `mask` is required).
            alpha_cutoff: int in [0, 255]; alpha <= alpha_cutoff/255 becomes 0.
            verbose: print the computed rectangle (or the no-op notice).
            mask: optional alpha source, [H, W], [1, H, W] or [B, H, W] in
                [0, 1]; only consulted when images are RGB.

        Returns:
            (cropped_images, left, top, right, bottom, crop_width, crop_height)
            where right/bottom are inclusive pixel indices. If no pixel is
            visible after the clamp, the uncropped RGBA batch is returned
            together with the full-frame rectangle.

        Raises:
            TypeError: images is not a torch.Tensor.
            ValueError: bad images rank/channels or an incompatible mask.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")

        if images.dim() != 4:
            raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")

        B, H, W, C = images.shape

        rgba = self._build_rgba(images, mask, B, H, W, C)

        # 1) Alpha clamp: alpha <= (alpha_cutoff/255) -> 0
        threshold = float(alpha_cutoff) / 255.0
        alpha = rgba[..., 3]
        rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)

        # 2) Global bbox of visible pixels across batch (alpha > 0 after clamp)
        visible = rgba[..., 3] > 0  # [B,H,W] boolean
        if not torch.any(visible):
            # Nothing visible after clamp; skip crop and report the full frame.
            if verbose:
                print(
                    f"[RGBABatchAlphaClampGlobalCrop] No visible pixels after clamp "
                    f"(alpha_cutoff={alpha_cutoff}). Returning unchanged RGBA."
                )
            return (rgba, 0, 0, W - 1, H - 1, W, H)

        # Union visibility across batch -> [H,W]
        union = torch.any(visible, dim=0)

        ys = torch.any(union, dim=1)  # [H] rows containing any visible pixel
        xs = torch.any(union, dim=0)  # [W] columns containing any visible pixel

        y_idx = torch.nonzero(ys, as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(xs, as_tuple=False).squeeze(1)

        # nonzero returns sorted indices, so first/last give the extremes.
        top = int(y_idx[0].item())
        bottom = int(y_idx[-1].item())
        left = int(x_idx[0].item())
        right = int(x_idx[-1].item())

        # 3) Crop all images to the same rect (inclusive right/bottom)
        cropped = rgba[:, top:bottom + 1, left:right + 1, :]

        crop_w = right - left + 1
        crop_h = bottom - top + 1

        if verbose:
            print(
                f"[RGBABatchAlphaClampGlobalCrop] alpha_cutoff={alpha_cutoff} "
                f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
                f"(w={crop_w}, h={crop_h}), batch={B}"
            )

        return (cropped, left, top, right, bottom, crop_w, crop_h)


# ComfyUI registration: node-id -> implementing class.
NODE_CLASS_MAPPINGS = {"Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper}

# ComfyUI registration: node-id -> label shown in the node picker.
NODE_DISPLAY_NAME_MAPPINGS = {"Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper"}