File size: 4,276 Bytes
6f46f5e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 | # ComfyUI Custom Nodes: Batch Index Tools
# Put this file at:
# ComfyUI/custom_nodes/batch_index_tools/__init__.py
# Restart ComfyUI, then look under category: "Batch/Index"
import torch
def _clamp_index(index: int, batch_size: int) -> int:
"""Clamp index into [0, batch_size-1]."""
if batch_size <= 0:
raise ValueError("Input batch is empty (batch_size <= 0).")
if index < 0 or index >= batch_size:
print(
f"[BatchIndexTools] index {index} out of range for batch_size {batch_size}; "
f"clamping to valid range."
)
index = max(0, min(index, batch_size - 1))
return index
class BatchGetImageAtIndex:
    """Select one image out of an IMAGE batch (node 1).

    Inputs:
        images: an IMAGE tensor laid out as ``[B, H, W, C]``.
        index:  zero-based position of the image to pick (0 is the first).

    Output:
        A batch holding just that one image, i.e. the batch axis is kept.

    An index past the end of the batch is clamped to the closest valid
    position rather than rejected.
    """

    CATEGORY = "Batch/Index"
    FUNCTION = "get"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)

    @classmethod
    def INPUT_TYPES(cls):
        index_spec = {"default": 0, "min": 0, "max": 10**9}
        return {"required": {"images": ("IMAGE",), "index": ("INT", index_spec)}}

    def get(self, images, index):
        """Return a 1-element tuple holding the selected single-image batch."""
        if not torch.is_tensor(images):
            raise TypeError("Expected 'images' to be a torch Tensor (ComfyUI IMAGE type).")
        rank = images.ndim
        if rank != 4:
            raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={rank}.")
        position = _clamp_index(int(index), images.shape[0])
        # narrow() along dim 0 with length 1 is the same view as
        # images[position:position + 1]: batch axis kept, storage shared.
        return (images.narrow(0, position, 1),)
class BatchReplaceImageAtIndex:
    """Overwrite one image of an IMAGE batch with another image (node 2).

    Inputs:
        images: the IMAGE batch to edit, shape ``[B, H, W, C]``.
        index:  zero-based position to overwrite (0 is the first image);
                out-of-range values are clamped to the nearest valid position.
        image:  the replacement, also a 4-D IMAGE tensor. If it holds several
                images, only the first one is used.

    Output:
        A new batch equal to ``images`` except at the chosen position. The
        input tensor is never modified (the output is a clone).

    Raises (from ``replace``):
        TypeError:  if either input is not a torch tensor.
        ValueError: if either input is not 4-D, if ``image`` contains no
                    images, or if the H/W/C of the two inputs differ. An
                    empty ``images`` batch is rejected by ``_clamp_index``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {"default": 0, "min": 0, "max": 10**9}),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "replace"
    CATEGORY = "Batch/Index"

    def replace(self, images, index, image):
        """Return a 1-element tuple holding the edited copy of ``images``."""
        if not torch.is_tensor(images) or not torch.is_tensor(image):
            raise TypeError("Expected 'images' and 'image' to be torch Tensors (ComfyUI IMAGE type).")
        if images.ndim != 4:
            raise ValueError(f"Expected 'images' with shape [B,H,W,C], got ndim={images.ndim}.")
        if image.ndim != 4:
            raise ValueError(f"Expected 'image' with shape [B,H,W,C], got ndim={image.ndim}.")

        # An empty replacement batch would slip through the H/W/C check below
        # and then fail with an opaque IndexError on image[0]; reject it here.
        if image.shape[0] == 0:
            raise ValueError("Replacement 'image' batch is empty (batch_size == 0).")

        # Validate spatial/channel match before resolving the index, so a shape
        # error is reported without an unrelated clamp warning printed first.
        if image.shape[1:] != images.shape[1:]:
            raise ValueError(
                "Replacement image must match batch image shape [H,W,C]. "
                f"Batch has [H,W,C]={tuple(images.shape[1:])}, "
                f"replacement has [H,W,C]={tuple(image.shape[1:])}."
            )

        idx = _clamp_index(int(index), images.shape[0])

        # Copy so the caller's tensor is left untouched.
        out = images.clone()
        # Only the first replacement image is used; align it with the output's
        # device and dtype before writing it into the selected slot.
        out[idx] = image[0].to(device=out.device, dtype=out.dtype)
        return (out,)
# Node registration tables for ComfyUI: internal node name -> node class, and
# internal node name -> human-readable title shown in the UI.
# Each class is registered under its own class name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (BatchGetImageAtIndex, BatchReplaceImageAtIndex)
}
NODE_DISPLAY_NAME_MAPPINGS = dict(
    BatchGetImageAtIndex="Batch: Get Image @ Index",
    BatchReplaceImageAtIndex="Batch: Replace Image @ Index",
)
|