File size: 4,127 Bytes
2856d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Save as: ComfyUI/custom_nodes/batch_merge_6_any.py
# Restart ComfyUI after saving.

import torch


class Batch_6:
    """
    Concatenate up to 6 IMAGE inputs into one batch.

    Each input can be a single image [H,W,C] or a batch [B,H,W,C],
    RGB (C=3) or RGBA (C=4).

    - IMAGES_1 is required (so the node can always run).
    - IMAGES_2..IMAGES_6 are optional (can be left unconnected).

    Channel handling:
      - If ANY input is RGBA (C=4), output will be RGBA.
      - RGB inputs (C=3) are then upgraded to RGBA by appending alpha=1.
      - If all inputs are RGB, output stays RGB.

    Requirements:
      - All images must share the same H and W (no resizing/cropping is done).
      - Channels must be 3 or 4.

    Every input is converted to the device and dtype of the first connected
    input before concatenation.
    """

    CATEGORY = "image/batch"
    FUNCTION = "merge"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGES_OUT",)

    @classmethod
    def INPUT_TYPES(cls):
        """Declare one required and five optional IMAGE input sockets."""
        return {
            "required": {
                "IMAGES_1": ("IMAGE",),
            },
            "optional": {
                "IMAGES_2": ("IMAGE",),
                "IMAGES_3": ("IMAGE",),
                "IMAGES_4": ("IMAGE",),
                "IMAGES_5": ("IMAGE",),
                "IMAGES_6": ("IMAGE",),
            },
        }

    @staticmethod
    def _normalize_to_batch(t: torch.Tensor) -> torch.Tensor:
        """Return *t* as a 4-D batch.

        A 3-D tensor [H,W,C] is treated as a single image and gets a leading
        batch axis ([1,H,W,C]); a 4-D tensor is returned unchanged.

        Raises:
            ValueError: if *t* is neither 3-D nor 4-D.
        """
        if t.dim() == 3:
            return t.unsqueeze(0)
        if t.dim() == 4:
            return t
        raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(t.shape)}")

    @staticmethod
    def _ensure_channels(t: torch.Tensor) -> int:
        """Validate that *t* is [B,H,W,C] with C in (3, 4) and return C.

        Raises:
            ValueError: if *t* is not 4-D or its last dimension is not 3 or 4.
        """
        if t.dim() != 4:
            raise ValueError(f"Expected [B,H,W,C], got shape {tuple(t.shape)}")
        c = int(t.shape[-1])
        if c not in (3, 4):
            raise ValueError(f"Expected RGB/RGBA (C=3 or 4), got C={c}")
        return c

    def merge(self, IMAGES_1, IMAGES_2=None, IMAGES_3=None, IMAGES_4=None, IMAGES_5=None, IMAGES_6=None):
        """Concatenate all connected inputs along the batch axis.

        Unconnected (None) inputs are skipped. The first connected input sets
        the reference H, W, device and dtype.

        Returns:
            A 1-tuple holding the concatenated [N,H,W,C] tensor, where C is 4
            if any input was RGBA and 3 otherwise.

        Raises:
            TypeError: if a connected input is not a torch.Tensor.
            ValueError: on bad rank, bad channel count, size mismatch, or
                when no input is connected at all.
        """
        inputs = [IMAGES_1, IMAGES_2, IMAGES_3, IMAGES_4, IMAGES_5, IMAGES_6]

        # Keep each tensor paired with its socket name so error messages point
        # at the real input even when earlier optional sockets are unconnected.
        named = []
        for idx, x in enumerate(inputs, start=1):
            if x is None:
                continue
            if not isinstance(x, torch.Tensor):
                raise TypeError(f"IMAGES_{idx} is not a torch.Tensor (got {type(x)})")

            x = self._normalize_to_batch(x)
            self._ensure_channels(x)
            named.append((f"IMAGES_{idx}", x))

        if not named:
            # Shouldn't happen because IMAGES_1 is required, but keep it safe.
            raise ValueError("No images provided.")

        # Use first connected input as reference for device/dtype/size.
        ref_name, ref = named[0]
        device = ref.device
        dtype = ref.dtype
        H = int(ref.shape[1])
        W = int(ref.shape[2])

        # Decide output channels: RGBA if any input is RGBA.
        target_c = 4 if any(int(t.shape[-1]) == 4 for _, t in named) else 3

        prepared = []
        for name, t in named:
            # Validate size first so a mismatched tensor is never copied.
            h, w = int(t.shape[1]), int(t.shape[2])
            if (h, w) != (H, W):
                raise ValueError(
                    f"Size mismatch: {name} has [H,W]=[{h},{w}] "
                    f"but expected [{H},{W}] (from {ref_name})."
                )

            # Align device/dtype with the reference input.
            if t.device != device or t.dtype != dtype:
                t = t.to(device=device, dtype=dtype)

            # Upgrade RGB -> RGBA if needed. No alpha is ever dropped, because
            # target_c is 3 only when every input is already RGB.
            if target_c == 4 and int(t.shape[-1]) == 3:
                alpha = torch.ones((int(t.shape[0]), H, W, 1), device=device, dtype=dtype)
                t = torch.cat([t, alpha], dim=-1)

            prepared.append(t)

        out = torch.cat(prepared, dim=0)
        return (out,)


# Registry picked up from this custom_nodes module: node type name -> class.
NODE_CLASS_MAPPINGS = {"Batch_6": Batch_6}

# Human-readable label for each registered node type
# (presumably what ComfyUI displays in its UI — confirm against ComfyUI docs).
NODE_DISPLAY_NAME_MAPPINGS = {"Batch_6": "Batch 6"}