# Save as: ComfyUI/custom_nodes/batch_merge_6_any.py
# Restart ComfyUI after saving.
import torch
class Batch_6:
    """
    Concatenate up to 6 IMAGE inputs into one batch.

    Each input can be a single image [H,W,C] or a batch [B,H,W,C], RGB or RGBA.
    - IMAGES_1 is required (so the node can always run).
    - IMAGES_2..IMAGES_6 are optional (can be left unconnected).
    Channel handling:
    - If ANY input is RGBA (C=4), output will be RGBA.
    - RGB inputs (C=3) will be upgraded to RGBA by adding alpha=1.
    - If all inputs are RGB, output stays RGB.
    Requirements:
    - All images must share the same H and W (no resizing/cropping is done).
    - Channels must be 3 or 4.
    """
    CATEGORY = "image/batch"
    FUNCTION = "merge"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGES_OUT",)

    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI input declaration: one required slot, five optional ones.
        return {
            "required": {
                "IMAGES_1": ("IMAGE",),
            },
            "optional": {
                "IMAGES_2": ("IMAGE",),
                "IMAGES_3": ("IMAGE",),
                "IMAGES_4": ("IMAGE",),
                "IMAGES_5": ("IMAGE",),
                "IMAGES_6": ("IMAGE",),
            },
        }

    @staticmethod
    def _normalize_to_batch(t: torch.Tensor) -> torch.Tensor:
        """Return *t* as [B,H,W,C]; a single [H,W,C] image becomes a batch of 1.

        Raises ValueError for any other dimensionality.
        """
        if t.dim() == 3:
            return t.unsqueeze(0)
        if t.dim() == 4:
            return t
        raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(t.shape)}")

    @staticmethod
    def _ensure_channels(t: torch.Tensor) -> int:
        """Validate that a [B,H,W,C] tensor has C=3 or C=4 and return C.

        Raises ValueError on wrong rank or unsupported channel count.
        """
        if t.dim() != 4:
            raise ValueError(f"Expected [B,H,W,C], got shape {tuple(t.shape)}")
        c = int(t.shape[-1])
        if c not in (3, 4):
            raise ValueError(f"Expected RGB/RGBA (C=3 or 4), got C={c}")
        return c

    def merge(self, IMAGES_1, IMAGES_2=None, IMAGES_3=None, IMAGES_4=None, IMAGES_5=None, IMAGES_6=None):
        """Concatenate all connected inputs along the batch dimension.

        Returns a 1-tuple holding one [B,H,W,C] tensor (ComfyUI convention).
        Raises TypeError for non-tensor inputs and ValueError for bad rank,
        unsupported channel counts, or mismatched H/W.
        """
        inputs = [IMAGES_1, IMAGES_2, IMAGES_3, IMAGES_4, IMAGES_5, IMAGES_6]
        # Keep each tensor paired with its original slot number so error
        # messages stay accurate even when earlier optional slots are empty.
        # (Previously the mismatch error used the index into the *filtered*
        # list, which misreported the slot whenever a None was skipped.)
        tensors = []
        for idx, x in enumerate(inputs, start=1):
            if x is None:
                continue
            if not isinstance(x, torch.Tensor):
                raise TypeError(f"IMAGES_{idx} is not a torch.Tensor (got {type(x)})")
            x = self._normalize_to_batch(x)
            self._ensure_channels(x)
            tensors.append((idx, x))
        if not tensors:
            # Shouldn't happen because IMAGES_1 is required, but keep it safe.
            raise ValueError("No images provided.")
        # Use first connected input as reference for device/dtype/size.
        ref = tensors[0][1]
        device = ref.device
        dtype = ref.dtype
        H = int(ref.shape[1])
        W = int(ref.shape[2])
        # Decide output channels: RGBA if any input is RGBA
        target_c = 4 if any(int(t.shape[-1]) == 4 for _, t in tensors) else 3
        prepared = []
        for idx, t in tensors:
            # Align device/dtype to the reference input.
            if t.device != device or t.dtype != dtype:
                t = t.to(device=device, dtype=dtype)
            # Validate size (no resizing/cropping is done).
            if int(t.shape[1]) != H or int(t.shape[2]) != W:
                raise ValueError(
                    f"Size mismatch: input #{idx} has [H,W]=[{int(t.shape[1])},{int(t.shape[2])}] "
                    f"but expected [{H},{W}]."
                )
            c = int(t.shape[-1])
            # Upgrade RGB -> RGBA by appending an opaque alpha channel.
            if target_c == 4 and c == 3:
                alpha = torch.ones((int(t.shape[0]), H, W, 1), device=device, dtype=dtype)
                t = torch.cat([t, alpha], dim=-1)
            # (No need to drop alpha because target_c is 3 only if all are 3)
            prepared.append(t)
        out = torch.cat(prepared, dim=0)
        return (out,)
# ComfyUI registration: node id -> implementing class, and node id -> UI label.
NODE_CLASS_MAPPINGS = {"Batch_6": Batch_6}
NODE_DISPLAY_NAME_MAPPINGS = {"Batch_6": "Batch 6"}