Stkzzzz222 commited on
Commit
9cac472
·
verified ·
1 Parent(s): 7026e37

Upload image_batcher_by_indexz.py

Browse files
Files changed (1) hide show
  1. image_batcher_by_indexz.py +149 -0
image_batcher_by_indexz.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ # --- NEW CLASS NAME ---
5
+ class ImageBatcherByIndexProV2:
6
+ """
7
+ (V2) A ComfyUI node that creates a batch of images with advanced features.
8
+ - This version supports both single images and input batches.
9
+ - User specifies max_frames for the output batch.
10
+ - For each input image (up to 6), user can specify its start position (frame_index)
11
+ and mask behavior (mask_as_image_area_is_black or white).
12
+ - The 'repeat_count' parameter has dual functionality:
13
+ - If the input is a single image (batch size 1), 'repeat_count' dictates
14
+ how many times that single image is repeated.
15
+ - If the input is a batch of images (batch size > 1), 'repeat_count'
16
+ specifies how many images to take sequentially from that input batch.
17
+ - Output resolution is determined by the first connected input image.
18
+ - Frames not filled by an input image will be RGB(127,127,127).
19
+ - Outputs 'output_batch' and 'batch_masks'.
20
+ """
21
+
22
+ MASK_BEHAVIOR_OPTIONS = ["IMAGE_AREA_IS_BLACK", "IMAGE_AREA_IS_WHITE"]
23
+
24
+ @classmethod
25
+ def INPUT_TYPES(s):
26
+ inputs = {
27
+ "required": {
28
+ "max_frames": ("INT", {"default": 50, "min": 1, "max": 8192, "step": 1, "display": "number"}),
29
+ },
30
+ "optional": {}
31
+ }
32
+ for i in range(1, 7):
33
+ inputs["optional"][f"image_{i}"] = ("IMAGE",)
34
+ inputs["optional"][f"frame_index_{i}"] = ("INT", {"default": i, "min": 1, "max": 8192, "step": 1, "display": "number"})
35
+ inputs["optional"][f"repeat_count_{i}"] = ("INT", {"default": 1, "min": 1, "max": 8192, "step": 1, "display": "number"})
36
+ inputs["optional"][f"mask_behavior_{i}"] = (s.MASK_BEHAVIOR_OPTIONS, {"default": s.MASK_BEHAVIOR_OPTIONS[0]})
37
+ return inputs
38
+
39
+ RETURN_TYPES = ("IMAGE", "IMAGE",)
40
+ RETURN_NAMES = ("output_batch", "batch_masks",)
41
+ FUNCTION = "create_batch_pro"
42
+ CATEGORY = "utils/batching"
43
+
44
+ def _prepare_color_frame(self, color_tuple, target_h, target_w, target_c, dtype, device):
45
+ color_tensor = torch.tensor(color_tuple, dtype=dtype, device=device)
46
+ return color_tensor.reshape(1, 1, target_c).expand(target_h, target_w, target_c)
47
+
48
+ def _process_single_image(self, image_b1hwc, target_h, target_w, target_c, dtype, device):
49
+ current_image_orig = image_b1hwc
50
+ if current_image_orig.shape[3] != target_c:
51
+ current_image_adapted = torch.zeros((1, target_h, target_w, target_c), dtype=dtype, device=device)
52
+ common_channels = min(current_image_orig.shape[3], target_c)
53
+ temp_resized = F.interpolate(current_image_orig.permute(0, 3, 1, 2), size=(target_h, target_w), mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
54
+ current_image_adapted[..., :common_channels] = temp_resized[..., :common_channels]
55
+
56
+ if target_c == 4 and current_image_orig.shape[3] < 4:
57
+ current_image_adapted[..., 3] = 1.0
58
+ elif target_c == 1 and current_image_orig.shape[3] > 1:
59
+ current_image_adapted[..., 0] = temp_resized[..., :3].mean(dim=3)
60
+ current_image_orig = current_image_adapted
61
+
62
+ if current_image_orig.shape[1] != target_h or current_image_orig.shape[2] != target_w:
63
+ img_to_resize_permuted = current_image_orig.permute(0, 3, 1, 2)
64
+ resized_permuted = F.interpolate(img_to_resize_permuted, size=(target_h, target_w), mode='bilinear', align_corners=False)
65
+ processed_image = resized_permuted.permute(0, 2, 3, 1)[0]
66
+ else:
67
+ processed_image = current_image_orig[0]
68
+ return processed_image
69
+
70
+ def create_batch_pro(self, max_frames, **kwargs):
71
+ target_h, target_w, target_c = -1, -1, -1
72
+ first_valid_image_tensor = None
73
+ base_dtype = torch.float32
74
+ base_device = 'cpu'
75
+
76
+ for i in range(1, 7):
77
+ img_tensor = kwargs.get(f"image_{i}")
78
+ if img_tensor is not None:
79
+ first_valid_image_tensor = img_tensor
80
+ target_h, target_w, target_c = img_tensor.shape[1], img_tensor.shape[2], img_tensor.shape[3]
81
+ base_dtype = img_tensor.dtype
82
+ base_device = img_tensor.device
83
+ break
84
+
85
+ if first_valid_image_tensor is None:
86
+ empty_img = torch.empty(0, 1, 1, 3, dtype=base_dtype, device=base_device)
87
+ return (empty_img, empty_img,)
88
+
89
+ fill_value_rgb_norm = 127.0 / 255.0
90
+ fill_color_tuple = (fill_value_rgb_norm,) * min(target_c, 3)
91
+ white_color_tuple = (1.0,) * min(target_c, 3)
92
+ black_color_tuple = (0.0,) * min(target_c, 3)
93
+ if target_c > 3:
94
+ fill_color_tuple += (1.0,)
95
+ white_color_tuple += (1.0,)
96
+ black_color_tuple += (1.0,)
97
+
98
+ fill_frame = self._prepare_color_frame(fill_color_tuple, target_h, target_w, target_c, base_dtype, base_device)
99
+ white_frame_mask = self._prepare_color_frame(white_color_tuple, target_h, target_w, target_c, base_dtype, base_device)
100
+ black_frame_mask = self._prepare_color_frame(black_color_tuple, target_h, target_w, target_c, base_dtype, base_device)
101
+
102
+ output_batch = torch.empty((max_frames, target_h, target_w, target_c), dtype=base_dtype, device=base_device)
103
+ output_batch[:] = fill_frame
104
+ batch_masks = torch.empty((max_frames, target_h, target_w, target_c), dtype=base_dtype, device=base_device)
105
+ batch_masks[:] = white_frame_mask
106
+
107
+ for i in range(1, 7):
108
+ img_tensor = kwargs.get(f"image_{i}")
109
+ if img_tensor is None: continue
110
+
111
+ frame_index_user = kwargs.get(f"frame_index_{i}", i)
112
+ repeat_count = kwargs.get(f"repeat_count_{i}", 1)
113
+ mask_behavior = kwargs.get(f"mask_behavior_{i}", self.MASK_BEHAVIOR_OPTIONS[0])
114
+ start_idx = frame_index_user - 1
115
+ chosen_mask_frame = black_frame_mask if mask_behavior == self.MASK_BEHAVIOR_OPTIONS[0] else white_frame_mask
116
+
117
+ input_batch_size = img_tensor.shape[0]
118
+
119
+ if input_batch_size > 1:
120
+ num_frames_to_take = min(repeat_count, input_batch_size)
121
+ print(f"V2 Node: Input image_{i} is a batch of {input_batch_size}. Taking {num_frames_to_take} frames starting at index {frame_index_user}.")
122
+
123
+ for j in range(num_frames_to_take):
124
+ current_actual_idx = start_idx + j
125
+ if not (0 <= current_actual_idx < max_frames): break
126
+ image_to_process = img_tensor[j].unsqueeze(0)
127
+ processed_image = self._process_single_image(image_to_process, target_h, target_w, target_c, base_dtype, base_device)
128
+ output_batch[current_actual_idx] = processed_image
129
+ batch_masks[current_actual_idx] = chosen_mask_frame
130
+ else:
131
+ print(f"V2 Node: Input image_{i} is a single image. Repeating {repeat_count} times starting at index {frame_index_user}.")
132
+ image_to_process = img_tensor[0].unsqueeze(0)
133
+ processed_image = self._process_single_image(image_to_process, target_h, target_w, target_c, base_dtype, base_device)
134
+
135
+ for j in range(repeat_count):
136
+ current_actual_idx = start_idx + j
137
+ if not (0 <= current_actual_idx < max_frames): break
138
+ output_batch[current_actual_idx] = processed_image
139
+ batch_masks[current_actual_idx] = chosen_mask_frame
140
+
141
+ return (output_batch, batch_masks,)
142
+
143
+ # --- ComfyUI Boilerplate with NEW NAMES ---
144
+ NODE_CLASS_MAPPINGS = {
145
+ "ImageBatcherByIndexProV2": ImageBatcherByIndexProV2
146
+ }
147
+ NODE_DISPLAY_NAME_MAPPINGS = {
148
+ "ImageBatcherByIndexProV2": "Image Batcher by Index Pro V2"
149
+ }