import torch
|
class Batch_6:
    """Concatenate up to six IMAGE inputs into a single batch.

    Each input may be a single image ([H,W,C]) or an existing batch
    ([B,H,W,C]), in RGB or RGBA. Only IMAGES_1 is required so the node can
    always run; IMAGES_2..IMAGES_6 may be left unconnected.

    Channel handling:
      * If any connected input carries an alpha channel (C=4), the output
        is RGBA; RGB inputs are promoted by appending alpha=1.
      * If every input is RGB, the output stays RGB.

    Constraints:
      * All images must already share the same H and W — no resizing or
        cropping is performed here.
      * Channel count must be 3 or 4.
    """

    CATEGORY = "image/batch"
    FUNCTION = "merge"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGES_OUT",)

    @classmethod
    def INPUT_TYPES(cls):
        # Only the first slot is mandatory; slots 2-6 are optional inputs.
        optional_slots = {f"IMAGES_{n}": ("IMAGE",) for n in range(2, 7)}
        return {
            "required": {"IMAGES_1": ("IMAGE",)},
            "optional": optional_slots,
        }

    @staticmethod
    def _normalize_to_batch(t: torch.Tensor) -> torch.Tensor:
        """Promote a single image [H,W,C] to [1,H,W,C]; pass batches through.

        Raises ValueError for any other rank.
        """
        ndim = t.dim()
        if ndim == 4:
            return t
        if ndim == 3:
            return t.unsqueeze(0)
        raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(t.shape)}")

    @staticmethod
    def _ensure_channels(t: torch.Tensor) -> int:
        """Validate that *t* is [B,H,W,C] with C in {3, 4} and return C."""
        if t.dim() != 4:
            raise ValueError(f"Expected [B,H,W,C], got shape {tuple(t.shape)}")
        c = int(t.shape[-1])
        if c in (3, 4):
            return c
        raise ValueError(f"Expected RGB/RGBA (C=3 or 4), got C={c}")

    def merge(self, IMAGES_1, IMAGES_2=None, IMAGES_3=None, IMAGES_4=None, IMAGES_5=None, IMAGES_6=None):
        """Validate, channel-align, and concatenate connected inputs along dim 0.

        Returns a 1-tuple holding the merged [B,H,W,C] tensor.
        Raises TypeError for non-tensor inputs, ValueError for rank/channel/
        size mismatches or when nothing usable was provided.
        """
        slots = (IMAGES_1, IMAGES_2, IMAGES_3, IMAGES_4, IMAGES_5, IMAGES_6)

        batches = []
        for slot_no, image in enumerate(slots, start=1):
            if image is None:
                continue  # unconnected optional slot
            if not isinstance(image, torch.Tensor):
                raise TypeError(f"IMAGES_{slot_no} is not a torch.Tensor (got {type(image)})")
            batch = self._normalize_to_batch(image)
            self._ensure_channels(batch)
            batches.append(batch)

        if not batches:
            raise ValueError("No images provided.")

        # The first connected input fixes device, dtype, and spatial size.
        reference = batches[0]
        device, dtype = reference.device, reference.dtype
        height, width = int(reference.shape[1]), int(reference.shape[2])

        # RGBA wins: a single alpha channel anywhere upgrades the whole batch.
        out_channels = 4 if any(int(b.shape[-1]) == 4 for b in batches) else 3

        aligned = []
        for pos, batch in enumerate(batches):
            # Move onto the reference device/dtype only when needed.
            if batch.device != device or batch.dtype != dtype:
                batch = batch.to(device=device, dtype=dtype)

            # No resizing is done — mismatched H/W is a hard error.
            if int(batch.shape[1]) != height or int(batch.shape[2]) != width:
                raise ValueError(
                    f"Size mismatch: input #{pos+1} has [H,W]=[{int(batch.shape[1])},{int(batch.shape[2])}] "
                    f"but expected [{height},{width}]."
                )

            if out_channels == 4 and int(batch.shape[-1]) == 3:
                # Promote RGB to RGBA with a fully opaque alpha plane.
                alpha = torch.ones((int(batch.shape[0]), height, width, 1), device=device, dtype=dtype)
                batch = torch.cat([batch, alpha], dim=-1)

            aligned.append(batch)

        return (torch.cat(aligned, dim=0),)
|
| |
|
| |
|
# Registration tables: internal node name -> implementing class, and
# internal node name -> human-readable label (ComfyUI node convention).
NODE_CLASS_MAPPINGS = {"Batch_6": Batch_6}

NODE_DISPLAY_NAME_MAPPINGS = {"Batch_6": "Batch 6"}
|
| |
|