# ComfyUI Custom Nodes: Batch Index Tools
# Put this file at:
#   ComfyUI/custom_nodes/batch_index_tools/__init__.py
# Restart ComfyUI, then look under category: "Batch/Index"

import torch


def _clamp_index(index: int, batch_size: int) -> int:
    """Clamp ``index`` into ``[0, batch_size - 1]``.

    An out-of-range index is not treated as an error: a notice is printed
    and the nearest valid index is returned instead.

    Args:
        index: Requested zero-based index.
        batch_size: Number of items in the batch.

    Returns:
        A valid index in ``[0, batch_size - 1]``.

    Raises:
        ValueError: If ``batch_size`` is not positive (empty batch).
    """
    if batch_size <= 0:
        raise ValueError("Input batch is empty (batch_size <= 0).")
    if index < 0 or index >= batch_size:
        print(
            f"[BatchIndexTools] index {index} out of range for batch_size {batch_size}; "
            f"clamping to valid range."
        )
        index = max(0, min(index, batch_size - 1))
    return index


def _require_4d(tensor: torch.Tensor, name: str) -> None:
    """Raise ValueError unless ``tensor`` is rank 4 (the [B,H,W,C] IMAGE layout).

    Only the rank is checked here; the axis meaning follows the ComfyUI IMAGE
    convention named in the error message.
    """
    if tensor.ndim != 4:
        raise ValueError(f"Expected '{name}' with shape [B,H,W,C], got ndim={tensor.ndim}.")


class BatchGetImageAtIndex:
    """
    Node 1:
    - Takes an IMAGE batch and an integer index
    - Outputs the image at that index (as a batch of size 1)

    Notes:
    - Index is zero-based (0 is the first image).
    - If index is out of range, it is clamped to the nearest valid index.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "get"
    CATEGORY = "Batch/Index"

    def get(self, images, index):
        """Return a 1-tuple holding ``images[idx : idx + 1]``.

        ``idx`` is ``index`` clamped to the batch range (see ``_clamp_index``).

        Raises:
            TypeError: If ``images`` is not a torch tensor.
            ValueError: If ``images`` is not rank 4 or the batch is empty.
        """
        if not torch.is_tensor(images):
            raise TypeError("Expected 'images' to be a torch Tensor (ComfyUI IMAGE type).")
        _require_4d(images, "images")

        idx = _clamp_index(int(index), images.shape[0])
        # Slice instead of plain indexing so the batch dimension is kept:
        # the output is a batch containing exactly one image.
        out = images[idx : idx + 1]
        return (out,)


class BatchReplaceImageAtIndex:
    """
    Node 2:
    - Takes an IMAGE batch, an integer index, and a single IMAGE
    - Replaces the batch item at that index with the provided image
    - Outputs the modified batch

    Notes:
    - Index is zero-based (0 is the first image).
    - If index is out of range, it is clamped to the nearest valid index.
    - The replacement image must have the same H/W/C as the batch images.
    - If 'image' is a batch, only the first image is used.
    - The input batch is never modified; a copy is returned.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "replace"
    CATEGORY = "Batch/Index"

    def replace(self, images, index, image):
        """Return a copy of ``images`` with item ``index`` replaced by ``image[0]``.

        Raises:
            TypeError: If either input is not a torch tensor.
            ValueError: If either input is not rank 4, either batch is empty,
                or the replacement's trailing [H,W,C] dims differ from the batch's.
        """
        if not torch.is_tensor(images) or not torch.is_tensor(image):
            raise TypeError("Expected 'images' and 'image' to be torch Tensors (ComfyUI IMAGE type).")
        _require_4d(images, "images")
        _require_4d(image, "image")
        # Without this check an empty replacement batch would pass the shape
        # comparison below and then fail with an opaque IndexError.
        if image.shape[0] == 0:
            raise ValueError("Replacement 'image' batch is empty; expected at least one image.")

        idx = _clamp_index(int(index), images.shape[0])

        # Use the first image if a batch is provided.
        replacement = image[0]

        # Validate spatial/channel match.
        if replacement.shape != images.shape[1:]:
            raise ValueError(
                "Replacement image must match batch image shape [H,W,C]. "
                f"Batch has [H,W,C]={tuple(images.shape[1:])}, "
                f"replacement has [H,W,C]={tuple(replacement.shape)}."
            )

        # Build the output without mutating the input tensor.
        out = images.clone()
        # Match the batch's dtype/device before assignment.
        out[idx] = replacement.to(device=out.device, dtype=out.dtype)
        return (out,)


NODE_CLASS_MAPPINGS = {
    "BatchGetImageAtIndex": BatchGetImageAtIndex,
    "BatchReplaceImageAtIndex": BatchReplaceImageAtIndex,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "BatchGetImageAtIndex": "Batch: Get Image @ Index",
    "BatchReplaceImageAtIndex": "Batch: Replace Image @ Index",
}