import torch


class Batch_Sprite_BBox_Cropper:
    """
    ComfyUI custom node:
      - Takes a batch of RGBA images (or RGB + MASK).
      - Alpha clamp: alpha <= (alpha_cutoff / 255) -> 0
      - Computes one global bounding box of visible pixels across the entire batch.
      - Crops every image to the same bbox (spritesheet-safe).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                # User requirement: cutoff default = 10 (out of 255)
                "alpha_cutoff": ("INT", {"default": 10, "min": 0, "max": 255, "step": 1}),
                "verbose": ("BOOLEAN", {"default": True}),
            },
            # Optional: if you only have RGB image + separate alpha mask (common in ComfyUI)
            "optional": {
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "INT", "INT", "INT", "INT")
    RETURN_NAMES = ("cropped_images", "left", "top", "right", "bottom", "crop_width", "crop_height")
    FUNCTION = "process"
    CATEGORY = "image/alpha"

    @staticmethod
    def _build_rgba(images, mask, B, H, W, C):
        """
        Build a fresh [B,H,W,4] RGBA tensor from the inputs.

        C == 4: clone the images as-is.
        C == 3: append the (validated, broadcast) mask as the alpha channel.
        Raises ValueError for any other channel count or mask shape mismatch.
        """
        if C == 4:
            return images.clone()
        if C == 3:
            if mask is None:
                raise ValueError(
                    "Input images are RGB (C=3). Provide a MASK input or pass RGBA (C=4)."
                )
            # Normalize mask to [B,H,W]
            if mask.dim() == 2:
                mask_b = mask.unsqueeze(0).expand(B, -1, -1)
            elif mask.dim() == 3:
                mask_b = mask
            else:
                raise ValueError(f"mask must be [H,W] or [B,H,W], got shape {tuple(mask.shape)}")
            if mask_b.shape[0] != B or mask_b.shape[1] != H or mask_b.shape[2] != W:
                raise ValueError(
                    f"mask shape {tuple(mask_b.shape)} must match images batch/height/width {(B,H,W)}"
                )
            # Fix: align dtype/device with images, otherwise torch.cat raises when the
            # mask arrives as a different float type or on a different device.
            mask_b = mask_b.to(dtype=images.dtype, device=images.device)
            return torch.cat([images, mask_b.unsqueeze(-1)], dim=-1).clone()
        raise ValueError(f"Expected images with 3 (RGB) or 4 (RGBA) channels, got C={C}")

    def process(self, images, alpha_cutoff=10, verbose=True, mask=None):
        """
        Clamp low alpha to zero, find the global visible bbox, crop the batch.

        Args:
            images: torch tensor [B,H,W,C], typically float32 in [0,1].
            alpha_cutoff: int 0..255; alpha values <= alpha_cutoff/255 become 0.
            verbose: print the computed crop rectangle.
            mask: optional [B,H,W] or [H,W] alpha mask, required when C == 3.

        Returns:
            (cropped_images, left, top, right, bottom, crop_width, crop_height),
            with right/bottom inclusive pixel coordinates.
        """
        if not isinstance(images, torch.Tensor):
            raise TypeError("images must be a torch.Tensor")
        if images.dim() != 4:
            raise ValueError(f"images must be [B,H,W,C], got shape {tuple(images.shape)}")
        B, H, W, C = images.shape

        rgba = self._build_rgba(images, mask, B, H, W, C)

        # 1) Alpha clamp: alpha <= (alpha_cutoff/255) -> 0
        threshold = float(alpha_cutoff) / 255.0
        alpha = rgba[..., 3]
        rgba[..., 3] = torch.where(alpha <= threshold, torch.zeros_like(alpha), alpha)

        # 2) Global bbox of visible pixels across batch (alpha > 0 after clamp)
        visible = rgba[..., 3] > 0  # [B,H,W] boolean
        if not torch.any(visible):
            # Nothing visible after clamp; skip crop and return the full frame.
            if verbose:
                print(
                    f"[Batch_Sprite_BBox_Cropper] No visible pixels after clamp "
                    f"(alpha_cutoff={alpha_cutoff}). Returning full-frame RGBA."
                )
            return (rgba, 0, 0, W - 1, H - 1, W, H)

        # Union visibility across batch -> [H,W]
        union = torch.any(visible, dim=0)
        ys = torch.any(union, dim=1)  # [H] rows containing any visible pixel
        xs = torch.any(union, dim=0)  # [W] cols containing any visible pixel
        y_idx = torch.nonzero(ys, as_tuple=False).squeeze(1)
        x_idx = torch.nonzero(xs, as_tuple=False).squeeze(1)
        top = int(y_idx[0].item())
        bottom = int(y_idx[-1].item())
        left = int(x_idx[0].item())
        right = int(x_idx[-1].item())

        # 3) Crop all images to the same rect (inclusive right/bottom)
        cropped = rgba[:, top:bottom + 1, left:right + 1, :]
        crop_w = right - left + 1
        crop_h = bottom - top + 1

        if verbose:
            print(
                f"[Batch_Sprite_BBox_Cropper] alpha_cutoff={alpha_cutoff} "
                f"-> rect: left={left}, top={top}, right={right}, bottom={bottom} "
                f"(w={crop_w}, h={crop_h}), batch={B}"
            )
        return (cropped, left, top, right, bottom, crop_w, crop_h)


NODE_CLASS_MAPPINGS = {
    "Batch_Sprite_BBox_Cropper": Batch_Sprite_BBox_Cropper
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Batch_Sprite_BBox_Cropper": "Batch_Sprite_BBox_Cropper"
}