Upload image_batcher_by_indexz.py
image_batcher_by_indexz.py (ADDED, +149 -0)
@@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F

# --- NEW CLASS NAME ---
class ImageBatcherByIndexProV2:
    """
    (V2) A ComfyUI node that creates a batch of images with advanced features.
    - This version supports both single images and input batches.
    - The user specifies max_frames for the output batch.
    - For each input image (up to 6), the user can specify its start position
      (frame_index) and mask behavior (IMAGE_AREA_IS_BLACK or IMAGE_AREA_IS_WHITE).
    - The 'repeat_count' parameter has dual functionality:
      - If the input is a single image (batch size 1), 'repeat_count' dictates
        how many times that single image is repeated.
      - If the input is a batch of images (batch size > 1), 'repeat_count'
        specifies how many images to take sequentially from that input batch.
    - Output resolution is determined by the first connected input image.
    - Frames not filled by an input image are RGB(127, 127, 127).
    - Outputs 'output_batch' and 'batch_masks'.
    """
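
    # Worked example of the placement logic: with max_frames=10, a single
    # image on slot 1, frame_index_1=3 and repeat_count_1=4, frames 3..6
    # (1-based) receive the image; all other frames stay gray with white masks.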

    MASK_BEHAVIOR_OPTIONS = ["IMAGE_AREA_IS_BLACK", "IMAGE_AREA_IS_WHITE"]

    @classmethod
    def INPUT_TYPES(s):
        inputs = {
            "required": {
                "max_frames": ("INT", {"default": 50, "min": 1, "max": 8192, "step": 1, "display": "number"}),
            },
            "optional": {}
        }
        for i in range(1, 7):
            inputs["optional"][f"image_{i}"] = ("IMAGE",)
            inputs["optional"][f"frame_index_{i}"] = ("INT", {"default": i, "min": 1, "max": 8192, "step": 1, "display": "number"})
            inputs["optional"][f"repeat_count_{i}"] = ("INT", {"default": 1, "min": 1, "max": 8192, "step": 1, "display": "number"})
            inputs["optional"][f"mask_behavior_{i}"] = (s.MASK_BEHAVIOR_OPTIONS, {"default": s.MASK_BEHAVIOR_OPTIONS[0]})
        return inputs
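
    # INPUT_TYPES above builds six optional slots (image_1 .. image_6), each
    # with its own frame_index / repeat_count / mask_behavior widget.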

    RETURN_TYPES = ("IMAGE", "IMAGE",)
    RETURN_NAMES = ("output_batch", "batch_masks",)
    FUNCTION = "create_batch_pro"
    CATEGORY = "utils/batching"

    def _prepare_color_frame(self, color_tuple, target_h, target_w, target_c, dtype, device):
        color_tensor = torch.tensor(color_tuple, dtype=dtype, device=device)
        return color_tensor.reshape(1, 1, target_c).expand(target_h, target_w, target_c)
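
    # Note: expand() returns a broadcast view rather than a materialized
    # (H, W, C) tensor; the slice assignments in create_batch_pro copy it
    # into real storage, so no aliasing leaks into the outputs.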

    def _process_single_image(self, image_b1hwc, target_h, target_w, target_c, dtype, device):
        current_image_orig = image_b1hwc
        if current_image_orig.shape[3] != target_c:
            current_image_adapted = torch.zeros((1, target_h, target_w, target_c), dtype=dtype, device=device)
            common_channels = min(current_image_orig.shape[3], target_c)
            temp_resized = F.interpolate(current_image_orig.permute(0, 3, 1, 2), size=(target_h, target_w), mode='bilinear', align_corners=False).permute(0, 2, 3, 1)
            current_image_adapted[..., :common_channels] = temp_resized[..., :common_channels]

            if target_c == 4 and current_image_orig.shape[3] < 4:
                current_image_adapted[..., 3] = 1.0
            elif target_c == 1 and current_image_orig.shape[3] > 1:
                current_image_adapted[..., 0] = temp_resized[..., :3].mean(dim=3)
            current_image_orig = current_image_adapted

        if current_image_orig.shape[1] != target_h or current_image_orig.shape[2] != target_w:
            img_to_resize_permuted = current_image_orig.permute(0, 3, 1, 2)
            resized_permuted = F.interpolate(img_to_resize_permuted, size=(target_h, target_w), mode='bilinear', align_corners=False)
            processed_image = resized_permuted.permute(0, 2, 3, 1)[0]
        else:
            processed_image = current_image_orig[0]
        return processed_image
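
    # Shape flow in _process_single_image: (1, H_in, W_in, C_in) -> channel
    # adaptation -> bilinear resize -> (target_h, target_w, target_c); the
    # leading batch dimension is dropped on return.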

    def create_batch_pro(self, max_frames, **kwargs):
        target_h, target_w, target_c = -1, -1, -1
        first_valid_image_tensor = None
        base_dtype = torch.float32
        base_device = 'cpu'

        for i in range(1, 7):
            img_tensor = kwargs.get(f"image_{i}")
            if img_tensor is not None:
                first_valid_image_tensor = img_tensor
                target_h, target_w, target_c = img_tensor.shape[1], img_tensor.shape[2], img_tensor.shape[3]
                base_dtype = img_tensor.dtype
                base_device = img_tensor.device
                break

        if first_valid_image_tensor is None:
            empty_img = torch.empty(0, 1, 1, 3, dtype=base_dtype, device=base_device)
            return (empty_img, empty_img,)

        fill_value_rgb_norm = 127.0 / 255.0
        fill_color_tuple = (fill_value_rgb_norm,) * min(target_c, 3)
        white_color_tuple = (1.0,) * min(target_c, 3)
        black_color_tuple = (0.0,) * min(target_c, 3)
        if target_c > 3:
            fill_color_tuple += (1.0,)
            white_color_tuple += (1.0,)
            black_color_tuple += (1.0,)

        fill_frame = self._prepare_color_frame(fill_color_tuple, target_h, target_w, target_c, base_dtype, base_device)
        white_frame_mask = self._prepare_color_frame(white_color_tuple, target_h, target_w, target_c, base_dtype, base_device)
        black_frame_mask = self._prepare_color_frame(black_color_tuple, target_h, target_w, target_c, base_dtype, base_device)

        output_batch = torch.empty((max_frames, target_h, target_w, target_c), dtype=base_dtype, device=base_device)
        output_batch[:] = fill_frame
        batch_masks = torch.empty((max_frames, target_h, target_w, target_c), dtype=base_dtype, device=base_device)
        batch_masks[:] = white_frame_mask

        for i in range(1, 7):
            img_tensor = kwargs.get(f"image_{i}")
            if img_tensor is None:
                continue

            frame_index_user = kwargs.get(f"frame_index_{i}", i)
            repeat_count = kwargs.get(f"repeat_count_{i}", 1)
            mask_behavior = kwargs.get(f"mask_behavior_{i}", self.MASK_BEHAVIOR_OPTIONS[0])
            start_idx = frame_index_user - 1
            chosen_mask_frame = black_frame_mask if mask_behavior == self.MASK_BEHAVIOR_OPTIONS[0] else white_frame_mask

            input_batch_size = img_tensor.shape[0]

            if input_batch_size > 1:
                num_frames_to_take = min(repeat_count, input_batch_size)
                print(f"V2 Node: Input image_{i} is a batch of {input_batch_size}. Taking {num_frames_to_take} frames starting at index {frame_index_user}.")

                for j in range(num_frames_to_take):
                    current_actual_idx = start_idx + j
                    if not (0 <= current_actual_idx < max_frames):
                        break
                    image_to_process = img_tensor[j].unsqueeze(0)
                    processed_image = self._process_single_image(image_to_process, target_h, target_w, target_c, base_dtype, base_device)
                    output_batch[current_actual_idx] = processed_image
                    batch_masks[current_actual_idx] = chosen_mask_frame
            else:
                print(f"V2 Node: Input image_{i} is a single image. Repeating {repeat_count} times starting at index {frame_index_user}.")
                image_to_process = img_tensor[0].unsqueeze(0)
                processed_image = self._process_single_image(image_to_process, target_h, target_w, target_c, base_dtype, base_device)

                for j in range(repeat_count):
                    current_actual_idx = start_idx + j
                    if not (0 <= current_actual_idx < max_frames):
                        break
                    output_batch[current_actual_idx] = processed_image
                    batch_masks[current_actual_idx] = chosen_mask_frame

        return (output_batch, batch_masks,)

# --- ComfyUI Boilerplate with NEW NAMES ---
NODE_CLASS_MAPPINGS = {
    "ImageBatcherByIndexProV2": ImageBatcherByIndexProV2
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageBatcherByIndexProV2": "Image Batcher by Index Pro V2"
}
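
For reference, a minimal usage sketch driving the class directly, outside ComfyUI; the module name follows the uploaded filename, and the dummy tensors are illustrative assumptions:

import torch
from image_batcher_by_indexz import ImageBatcherByIndexProV2

node = ImageBatcherByIndexProV2()
single = torch.rand(1, 64, 64, 3)  # one 64x64 RGB frame
clip = torch.rand(5, 64, 64, 3)    # a 5-frame input batch

frames, masks = node.create_batch_pro(
    max_frames=10,
    image_1=single, frame_index_1=1, repeat_count_1=3,  # frames 1-3: repeated single image
    image_2=clip, frame_index_2=5, repeat_count_2=4,    # frames 5-8: first 4 frames of clip
)
assert frames.shape == (10, 64, 64, 3)
assert masks.shape == (10, 64, 64, 3)
# Frames 4, 9 and 10 remain gray (127/255) with all-white masks.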