# Save as: ComfyUI/custom_nodes/batch_slice_start_end.py
# Restart ComfyUI after saving.

import torch


class Get_Batch_Range_Start_To_End:
    """Node that extracts an inclusive index range from an image batch.

    Inputs:
        start_id (INT): index of the first image to keep.
        end_id (INT): index of the last image to keep (inclusive).
        images (IMAGE): a ``torch.Tensor`` with 4 dims (batched) or 3 dims
            (a single image, which is given a leading batch axis of size 1).

    Outputs:
        images (IMAGE): the selected sub-batch, copied with ``clone()``.
        status (STRING): ``"ok"`` or a message starting with ``"error: "``.
        count (INT): number of images in the *input* batch.

    Failure behaviour:
        Nothing is raised. When the request cannot be honoured (non-tensor
        input, wrong rank, empty batch, start > end, negative or out-of-range
        index) the input is handed back instead of a slice, together with the
        error message. A 3-dim input that passed the rank check is handed
        back in its 4-dim (unsqueezed) form.
    """

    CATEGORY = "image/batch"
    FUNCTION = "slice_batch"
    RETURN_TYPES = ("IMAGE", "STRING", "INT")
    RETURN_NAMES = ("images", "status", "count")

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs: two INT indices and an IMAGE."""

        def index_widget():
            # A fresh tuple/dict per call so the two index inputs never share
            # one mutable options dict.
            return ("INT", {"default": 0, "min": 0, "max": 1_000_000, "step": 1})

        required = {
            "start_id": index_widget(),
            "end_id": index_widget(),
            "images": ("IMAGE",),
        }
        return {"required": required}

    def slice_batch(self, start_id, end_id, images):
        """Return ``(images, status, count)`` for the range [start_id, end_id].

        Args:
            start_id: first index to keep.
            end_id: last index to keep (inclusive).
            images: tensor with 3 or 4 dims; anything else is reported as an
                error through ``status``.

        Returns:
            A 3-tuple ``(tensor, status, count)``. On success ``tensor`` is a
            cloned slice and ``status`` is ``"ok"``. On failure ``tensor`` is
            the input (see the class docstring) and ``status`` is the error
            text. ``count`` is the input batch size, or 0 when the input is
            not a tensor, is 0-dim, or the batch is empty.
        """
        # Non-tensor input: there is no batch size to report, so count is 0.
        if not isinstance(images, torch.Tensor):
            return (images, "error: images is not a torch.Tensor", 0)

        rank = images.dim()
        if rank == 3:
            # Single image: add a leading batch axis of size 1.
            batch = images.unsqueeze(0)
        elif rank == 4:
            batch = images
        else:
            # Unsupported rank: hand the tensor back untouched; the reported
            # count is the leading dimension when there is one.
            leading = int(images.shape[0]) if rank > 0 else 0
            shape_text = tuple(images.shape)
            return (
                images,
                f"error: expected IMAGE with 3 or 4 dims, got {shape_text}",
                leading,
            )

        total = int(batch.shape[0])
        if total <= 0:
            return (batch, "error: empty batch (B=0)", 0)

        # Index validation; the order of the checks decides which message
        # wins when several conditions hold at once.
        problem = None
        if start_id > end_id:
            problem = f"error: start_id > end_id ({start_id} > {end_id})"
        elif start_id < 0 or end_id < 0:
            problem = (
                f"error: negative index not allowed "
                f"(start_id={start_id}, end_id={end_id})"
            )
        elif start_id >= total or end_id >= total:
            problem = (
                f"error: out of range "
                f"(start_id={start_id}, end_id={end_id}, batch_size={total})"
            )

        if problem is not None:
            return (batch, problem, total)

        # end_id is inclusive, hence the +1; clone() so the result does not
        # share storage with the input batch.
        stop = end_id + 1
        selected = batch[start_id:stop].clone()
        return (selected, "ok", total)


NODE_CLASS_MAPPINGS = {
    "Get_Batch_Range_Start_To_End": Get_Batch_Range_Start_To_End,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Get_Batch_Range_Start_To_End": "Get Batch from Batch (From Start ID to End ID)",
}