import torch
class Batch_Sprite_BBox_Cropper:
    """
    ComfyUI custom node:
      - Takes a batch of RGBA images (or RGB + MASK).
      - Alpha clamp: alpha <= (alpha_cutoff / 255) -> 0
      - Computes one global bounding box of visible pixels across the entire batch
      - Crops every image to the same bbox (spritesheet-safe)
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                # User requirement: cutoff default = 10 (out of 255)
                "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
            # Optional: if you only have RGB image + separate alpha mask (common in ComfyUI)
            "optional": {
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
    FUNCTION = "process"
    CATEGORY = "image/alpha"

    @staticmethod
    def _rgba_from_inputs(images, mask, B, H, W, C):
        """Build a fresh [B, H, W, 4] RGBA tensor from images (+ optional mask).

        Raises ValueError for unsupported channel counts, a missing mask on
        RGB input, or a mask whose shape does not match the image batch.
        """
        if C == 4:
            # Clone so the alpha clamp below never mutates the caller's tensor.
            return images.clone()
        if C == 3:
            if mask is None:
                raise ValueError(
                    "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
                )
            # Normalize mask to [B,H,W]
            if mask.dim() == 2:
                mask_b = mask.unsqueeze(0).expand(B, -1, -1)
            elif mask.dim() == 3:
                mask_b = mask
            else:
                raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")
            if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
                raise ValueError(
                    f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
                )
            # Robustness: masks in ComfyUI graphs may arrive on a different
            # device or dtype than the images; torch.cat requires both to match.
            mask_b = mask_b.to(device=images.device, dtype=images.dtype)
            # torch.cat already allocates a new tensor, so no extra clone is needed.
            return torch.cat([images, mask_b.unsqueeze(-1)], dim=-1)
        raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")

    def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
        """Clamp low alpha to 0, find the batch-global visible bbox, crop to it.

        images: torch tensor [B, H, W, C], typically float32 in [0, 1]
        alpha_cutoff: pixels with alpha <= alpha_cutoff / 255 become fully transparent
        verbose: print the computed rectangle (or the no-visible-pixels notice)
        mask: optional torch tensor [B, H, W] or [H, W] in [0, 1]; required when C == 3

        Returns (cropped_images, left, top, right, bottom, crop_width, crop_height)
        where right/bottom are inclusive pixel coordinates. If nothing is visible
        after the clamp, the uncropped (but alpha-clamped) RGBA batch is returned
        with the full-frame rectangle.

        Raises TypeError / ValueError on malformed inputs.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")
        if images.dim() != 4:
            raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")
        B, H, W, C = images.shape

        rgba = self._rgba_from_inputs(images, mask, B, H, W, C)

        # 1) Alpha clamp: alpha <= (alpha_cutoff/255) -> 0
        threshold = float(alpha_cutoff) / 255.0
        alpha = rgba[..., 3]
        rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)

        # 2) Global bbox of visible pixels across batch (alpha > 0 after clamp)
        visible = rgba[..., 3] > 0  # [B,H,W] boolean
        if not torch.any(visible):
            # Nothing visible after clamp; skip crop and return the full frame.
            if verbose:
                print(
                    f"[RGBABatchAlphaClampGlobalCrop] No visible pixels after clamp "
                    f"(alpha_cutoff={alpha_cutoff}). Returning unchanged RGBA."
                )
            return (rgba, 0, 0, W - 1, H - 1, W, H)

        # Union visibility across batch -> [H,W], then collapse to row/col flags.
        union = torch.any(visible, dim=0)
        ys = torch.any(union, dim=1)  # [H]: row contains a visible pixel
        xs = torch.any(union, dim=0)  # [W]: column contains a visible pixel
        y_idx = torch.nonzero(ys, as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(xs, as_tuple=False).squeeze(1)
        top = int(y_idx[0].item())
        bottom = int(y_idx[-1].item())
        left = int(x_idx[0].item())
        right = int(x_idx[-1].item())

        # 3) Crop all images to the same rect (inclusive right/bottom).
        # .contiguous() detaches the crop from the full-size buffer so the
        # uncropped tensor is not kept alive by a view.
        cropped = rgba[:, top:bottom + 1, left:right + 1, :].contiguous()
        crop_w = right - left + 1
        crop_h = bottom - top + 1
        if verbose:
            print(
                f"[RGBABatchAlphaClampGlobalCrop] alpha_cutoff={alpha_cutoff} "
                f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
                f"(w={crop_w}, h={crop_h}), batch={B}"
            )
        return (cropped, left, top, right, bottom, crop_w, crop_h)
# Register the node with ComfyUI: internal class-name key and display-name key.
NODE_CLASS_MAPPINGS = {"Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper}
NODE_DISPLAY_NAME_MAPPINGS = {"Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper"}