saliacoel committed on
Commit
2856d56
·
verified ·
1 Parent(s): 166476b

Upload Batch_6.py

Browse files
Files changed (1) hide show
  1. Batch_6.py +124 -0
Batch_6.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Save as: ComfyUI/custom_nodes/batch_merge_6_any.py
2
+ # Restart ComfyUI after saving.
3
+
4
+ import torch
5
+
6
+
7
class Batch_6:
    """
    Concatenate up to 6 IMAGE inputs into one batch.

    Each input may be a single image ``[H, W, C]`` or a batch ``[B, H, W, C]``,
    with C = 3 (RGB) or C = 4 (RGBA).

    - IMAGES_1 is required (so the node can always run).
    - IMAGES_2..IMAGES_6 are optional (can be left unconnected).

    Channel handling:
    - If ANY input is RGBA (C=4), the output is RGBA; RGB inputs are
      upgraded by appending an alpha channel of 1.
    - If all inputs are RGB, the output stays RGB.

    Requirements:
    - All images must share the same H and W (no resizing/cropping is done).
    - Channels must be 3 or 4.
    """

    CATEGORY = "image/batch"
    FUNCTION = "merge"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGES_OUT",)

    @classmethod
    def INPUT_TYPES(cls):
        # ComfyUI schema: one required socket, five optional ones.
        return {
            "required": {
                "IMAGES_1": ("IMAGE",),
            },
            "optional": {
                "IMAGES_2": ("IMAGE",),
                "IMAGES_3": ("IMAGE",),
                "IMAGES_4": ("IMAGE",),
                "IMAGES_5": ("IMAGE",),
                "IMAGES_6": ("IMAGE",),
            },
        }

    @staticmethod
    def _normalize_to_batch(t: torch.Tensor) -> torch.Tensor:
        """Accept [H,W,C] as a single image and return it as [1,H,W,C]; pass [B,H,W,C] through."""
        if t.dim() == 3:
            return t.unsqueeze(0)
        if t.dim() == 4:
            return t
        raise ValueError(f"Expected IMAGE tensor with 3 or 4 dims, got shape {tuple(t.shape)}")

    @staticmethod
    def _ensure_channels(t: torch.Tensor) -> int:
        """Validate that ``t`` is [B,H,W,C] with C in (3, 4); return C."""
        if t.dim() != 4:
            raise ValueError(f"Expected [B,H,W,C], got shape {tuple(t.shape)}")
        c = int(t.shape[-1])
        if c not in (3, 4):
            raise ValueError(f"Expected RGB/RGBA (C=3 or 4), got C={c}")
        return c

    def merge(self, IMAGES_1, IMAGES_2=None, IMAGES_3=None, IMAGES_4=None, IMAGES_5=None, IMAGES_6=None):
        """
        Validate, channel-align, and concatenate the connected inputs along dim 0.

        Returns:
            A 1-tuple ``(out,)`` where ``out`` is ``[sum(B_i), H, W, C]`` with
            C = 4 if any input was RGBA, else 3.

        Raises:
            TypeError: if a connected input is not a torch.Tensor.
            ValueError: on dim/channel/size mismatches.
        """
        inputs = [IMAGES_1, IMAGES_2, IMAGES_3, IMAGES_4, IMAGES_5, IMAGES_6]

        # Keep the original slot number alongside each tensor so error
        # messages name the actual input socket even when some optional
        # inputs are unconnected (previously the message used the index in
        # the filtered list, misreporting the offending input).
        tensors = []  # list of (slot_number, [B,H,W,C] tensor)
        for idx, x in enumerate(inputs, start=1):
            if x is None:
                continue
            if not isinstance(x, torch.Tensor):
                raise TypeError(f"IMAGES_{idx} is not a torch.Tensor (got {type(x)})")

            x = self._normalize_to_batch(x)
            self._ensure_channels(x)
            tensors.append((idx, x))

        if not tensors:
            # Shouldn't happen because IMAGES_1 is required, but keep it safe.
            raise ValueError("No images provided.")

        # Use the first connected input as reference for device/dtype/size.
        ref = tensors[0][1]
        device = ref.device
        dtype = ref.dtype
        H = int(ref.shape[1])
        W = int(ref.shape[2])

        # Decide output channels: RGBA if any input is RGBA.
        target_c = 4 if any(int(t.shape[-1]) == 4 for _, t in tensors) else 3

        prepared = []
        for slot, t in tensors:
            # Align device/dtype with the reference input.
            if t.device != device or t.dtype != dtype:
                t = t.to(device=device, dtype=dtype)

            # Validate spatial size (no resizing/cropping is performed).
            if int(t.shape[1]) != H or int(t.shape[2]) != W:
                raise ValueError(
                    f"Size mismatch: input #{slot} has [H,W]=[{int(t.shape[1])},{int(t.shape[2])}] "
                    f"but expected [{H},{W}]."
                )

            # Upgrade RGB -> RGBA if needed (alpha = fully opaque).
            if target_c == 4 and int(t.shape[-1]) == 3:
                alpha = torch.ones((int(t.shape[0]), H, W, 1), device=device, dtype=dtype)
                t = torch.cat([t, alpha], dim=-1)

            # (No need to drop alpha: target_c is 3 only if all inputs are RGB.)
            prepared.append(t)

        out = torch.cat(prepared, dim=0)
        return (out,)
116
+
117
+
118
# Registration tables read by ComfyUI when it imports this module:
# internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "Batch_6": Batch_6,
}

# internal node id -> human-readable name shown in the node picker UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Batch_6": "Batch 6",
}