# Save as: ComfyUI/custom_nodes/batch_merge_6_any.py
# Restart ComfyUI after saving.

import torch


class Batch_6:
    """
    Concatenate up to 6 IMAGE inputs into a single batch.

    Each input may be a single image [H,W,C] or a batch [B,H,W,C], RGB or RGBA.

    - IMAGES_1 is required (so the node can always run).
    - IMAGES_2..IMAGES_6 are optional (can be left unconnected).

    Channel handling:
    - If ANY input is RGBA (C=4), output will be RGBA.
    - RGB inputs (C=3) are then upgraded to RGBA by appending alpha=1.0
      (fully opaque, assuming the usual float [0,1] image range).
    - If all inputs are RGB, output stays RGB.

    Requirements:
    - All images must share the same H and W (no resizing/cropping is done).
    - Channels must be 3 or 4.

    Device/dtype: every input is moved to the device and dtype of the first
    connected input before concatenation.

    Errors name the offending socket (e.g. "IMAGES_3") even when earlier
    optional inputs are left unconnected.
    """

    CATEGORY = "image/batch"
    FUNCTION = "merge"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGES_OUT",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare IMAGES_1 as required and IMAGES_2..IMAGES_6 as optional."""
        return {
            "required": {
                "IMAGES_1": ("IMAGE",),
            },
            "optional": {f"IMAGES_{i}": ("IMAGE",) for i in range(2, 7)},
        }

    @staticmethod
    def _normalize_to_batch(t: torch.Tensor) -> torch.Tensor:
        """Return ``t`` as a 4-D batch; a 3-D [H,W,C] image becomes [1,H,W,C].

        Raises:
            ValueError: if ``t`` is neither 3-D nor 4-D.
        """
        if t.dim() == 3:
            return t.unsqueeze(0)
        if t.dim() == 4:
            return t
        raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(t.shape)}")

    @staticmethod
    def _ensure_channels(t: torch.Tensor) -> int:
        """Validate that ``t`` is [B,H,W,C] with C in (3, 4) and return C.

        Raises:
            ValueError: if ``t`` is not 4-D or C is not 3 or 4.
        """
        if t.dim() != 4:
            raise ValueError(f"Expected [B,H,W,C], got shape {tuple(t.shape)}")
        c = int(t.shape[-1])
        if c not in (3, 4):
            raise ValueError(f"Expected RGB/RGBA (C=3 or 4), got C={c}")
        return c

    def merge(self, IMAGES_1, IMAGES_2=None, IMAGES_3=None, IMAGES_4=None, IMAGES_5=None, IMAGES_6=None):
        """Concatenate all connected inputs along the batch dimension.

        Returns:
            A 1-tuple ``(images,)`` holding a [sum(B_i),H,W,C] tensor, where C
            is 4 if any input is RGBA and 3 otherwise.

        Raises:
            TypeError: if a connected input is not a ``torch.Tensor``.
            ValueError: on bad rank, bad channel count, or H/W mismatch; the
                message names the offending ``IMAGES_n`` input.
        """
        inputs = [IMAGES_1, IMAGES_2, IMAGES_3, IMAGES_4, IMAGES_5, IMAGES_6]

        # Keep (socket name, tensor) pairs so errors can point at the real input
        # even when earlier optional inputs are unconnected.
        named = []
        for idx, x in enumerate(inputs, start=1):
            if x is None:
                continue
            name = f"IMAGES_{idx}"
            if not isinstance(x, torch.Tensor):
                raise TypeError(f"{name} is not a torch.Tensor (got {type(x)})")
            try:
                x = self._normalize_to_batch(x)
                self._ensure_channels(x)
            except ValueError as err:
                raise ValueError(f"{name}: {err}") from err
            named.append((name, x))

        if not named:
            # Shouldn't happen because IMAGES_1 is required, but keep it safe.
            raise ValueError("No images provided.")

        # The first connected input is the reference for device/dtype/size.
        ref_name, ref = named[0]
        device, dtype = ref.device, ref.dtype
        H, W = int(ref.shape[1]), int(ref.shape[2])

        # Validate every size before converting anything, so a mismatch fails
        # fast without first copying tensors across devices/dtypes.
        for name, t in named:
            h, w = int(t.shape[1]), int(t.shape[2])
            if (h, w) != (H, W):
                raise ValueError(
                    f"Size mismatch: {name} has [H,W]=[{h},{w}] "
                    f"but expected [{H},{W}] (from {ref_name})."
                )

        # Output is RGBA if any input is RGBA, otherwise RGB.
        target_c = 4 if any(int(t.shape[-1]) == 4 for _, t in named) else 3

        prepared = []
        for _, t in named:
            # Tensor.to returns the tensor unchanged when device/dtype already match.
            t = t.to(device=device, dtype=dtype)
            if target_c == 4 and int(t.shape[-1]) == 3:
                alpha = torch.ones((int(t.shape[0]), H, W, 1), device=device, dtype=dtype)
                t = torch.cat([t, alpha], dim=-1)
            # No alpha ever needs dropping: target_c is 3 only if all inputs are RGB.
            prepared.append(t)

        return (torch.cat(prepared, dim=0),)


NODE_CLASS_MAPPINGS = {
    "Batch_6": Batch_6,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Batch_6": "Batch 6",
}