import torch


class Get_Correct_Batch_Img:
    """
    Pick representative frames from a "wave" of visible widths in an RGBA batch.

    Given a batch of RGBA images, scan a given Y row across a (sub)batch and
    treat the visible span width on that row as a 1D curve over time
    (batch index).

    This node:
      - Measures the visible width for EVERY image in the selected sub-batch.
      - Detects a "big wave" pattern and extracts 5 checkpoints:
          cp0: first major high (start-side high)
          cp1: first major low (first valley)
          cp2: next major high (peak after first valley)
          cp3: second major low (second valley)
          cp4: final major high (peak after second valley, then shifted
               5% back towards cp3)
      - For each consecutive checkpoint pair, also finds an "in-between"
        frame:
          mid_0_1: width closest to midpoint between cp0 and cp1
          mid_1_2: width closest to midpoint between cp1 and cp2
          mid_2_3: width closest to midpoint between cp2 and cp3
          mid_3_4: width closest to midpoint between cp3 and cp4

    Outputs (all RGBA, B=1):
        cp0_start_high, cp1_low_1, cp2_high_2, cp3_low_2, cp4_high_3,
        mid_0_1, mid_1_2, mid_2_3, mid_3_4

    Visibility is determined from the alpha channel (A > 0). Images with no
    visible pixels on that row are treated as width = 0 (completely thin).
    Only images within [start_index, end_index] (inclusive) are considered.
    """

    CATEGORY = "image/batch"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: an image batch plus three integers."""

        def non_negative_int():
            # A fresh spec per input so the three inputs never share a dict.
            return (
                "INT",
                {
                    "default": 0,
                    "min": 0,
                    "max": 2_147_483_647,
                    "step": 1,
                },
            )

        return {
            "required": {
                # RGBA image batch: torch.Tensor [B, H, W, 4]
                "images": ("IMAGE",),
                # Sub-batch start index (inclusive, 0-based)
                "start_index": non_negative_int(),
                # Sub-batch end index (inclusive, 0-based)
                "end_index": non_negative_int(),
                # Y coordinate (row) used for the horizontal scan
                "y_coord": non_negative_int(),
            }
        }

    # 5 checkpoints + 4 in-betweens = 9 outputs
    RETURN_TYPES = (
        "IMAGE",
        "IMAGE",
        "IMAGE",
        "IMAGE",
        "IMAGE",
        "IMAGE",
        "IMAGE",
        "IMAGE",
        "IMAGE",
    )
    RETURN_NAMES = (
        "cp0_start_high",
        "cp1_low_1",
        "cp2_high_2",
        "cp3_low_2",
        "cp4_high_3",
        "mid_0_1",
        "mid_1_2",
        "mid_2_3",
        "mid_3_4",
    )
    FUNCTION = "select"

    def _compute_widths(self, images, start, end, y, alpha_threshold=0.0):
        """
        Measure the visible width on row ``y`` for each image in [start, end].

        A pixel is visible when its alpha (channel index 3) is strictly
        greater than ``alpha_threshold``. The width is the inclusive distance
        between the leftmost and rightmost visible pixel on the row, so
        transparent gaps inside the span still count. Rows with no visible
        pixel yield 0.

        ``select`` passes indices that are already clamped so that
        ``0 <= start <= end < B`` and ``0 <= y < H``.

        Returns:
            A Python list of floats, one per image in the sub-batch.
        """
        # Slice and threshold the whole sub-batch once, then bring the small
        # boolean mask to host memory so the per-frame loop below does not
        # have to convert device scalars one at a time.
        visible_rows = (images[start:end + 1, y, :, 3] > alpha_threshold).cpu()

        widths = []
        for row_mask in visible_rows:
            # torch.nonzero returns indices in ascending order for a 1D
            # input, so the first/last entries are the leftmost/rightmost
            # visible columns.
            columns = torch.nonzero(row_mask, as_tuple=False).squeeze(1)
            if columns.numel() == 0:
                # No visible pixels -> treat as width 0
                widths.append(0.0)
            else:
                left_x = int(columns[0])
                right_x = int(columns[-1])
                widths.append(float(right_x - left_x + 1))
        return widths

    def _compute_checkpoints(self, widths, nudge_fraction=0.05):
        """
        Compute the 5 checkpoint indices (cp0..cp4) into ``widths``.

        Strategy (global-ish, not just tiny local wiggles):
          - Split the sequence into two halves that share the middle index.
          - cp1 = minimum in the first half (first big valley)
          - cp3 = minimum in the second half (second big valley); if that
            does not land after cp1, the minimum strictly after cp1 is used.
          - cp0 = maximum from start .. cp1
          - cp2 = maximum from cp1 .. cp3
          - cp4 = maximum from cp3 .. end, then nudged ``nudge_fraction`` of
            the index distance back towards cp3.

        Ties are resolved in favour of the earliest index.

        Args:
            widths: Sequence of widths, one per frame in the sub-batch.
            nudge_fraction: Fraction of the cp3..cp4 index distance by which
                cp4 is pulled back towards cp3 (default 0.05, i.e. 5%).

        Returns:
            ``[cp0, cp1, cp2, cp3, cp4]`` as local indices. An empty input
            returns five zeros.
        """
        n = len(widths)
        if n == 0:
            return [0, 0, 0, 0, 0]

        last = n - 1

        # Very small sequences: just spread indices out linearly.
        if n < 4:
            cp1 = max(0, min(last, n // 4))
            cp3 = max(0, min(last, (3 * n) // 4))
            cp2 = max(cp1, min(cp3, (cp1 + cp3) // 2))
            return [0, cp1, cp2, cp3, last]

        width_at = widths.__getitem__
        mid = n // 2

        # Valleys: lowest point of each half; both halves include `mid`.
        cp1 = min(range(0, mid + 1), key=width_at)
        cp3 = min(range(mid, n), key=width_at)

        # Ensure cp3 is strictly after cp1 where possible, so we genuinely
        # get a second valley.
        if cp3 <= cp1 and cp1 + 1 < n:
            cp3 = min(range(cp1 + 1, n), key=width_at)

        # Peaks: highest point of each (inclusive) segment around the valleys.
        cp0 = max(range(0, cp1 + 1), key=width_at)
        cp2 = max(range(cp1, cp3 + 1), key=width_at)
        cp4 = max(range(cp3, n), key=width_at)

        # Nudge cp4 towards cp3 along the index axis, staying within
        # [cp3, cp4].
        if cp4 > cp3:
            nudged = int(round(cp4 - nudge_fraction * (cp4 - cp3)))
            cp4 = max(cp3, min(cp4, nudged))

        return [cp0, cp1, cp2, cp3, cp4]

    def _find_mid_index(self, idx_a, idx_b, widths):
        """
        Find the frame whose width is closest to the average of two widths.

        The target is the mean of ``widths[idx_a]`` and ``widths[idx_b]``.
        Frames strictly between the two indices are preferred; if there are
        none (the indices are adjacent), the two endpoints are the candidates.
        Ties go to the lowest index. If both indices are equal, that index is
        returned.
        """
        if idx_a == idx_b:
            return idx_a

        lo, hi = (idx_a, idx_b) if idx_a < idx_b else (idx_b, idx_a)
        target = (widths[idx_a] + widths[idx_b]) / 2.0

        # Strictly between the indices if any, otherwise allow the endpoints.
        candidates = range(lo + 1, hi) or (lo, hi)

        # min() keeps the first of equally-close candidates.
        return min(candidates, key=lambda j: abs(widths[j] - target))

    def select(self, images, start_index, end_index, y_coord):
        """
        Select the 5 checkpoint frames and 4 in-between frames.

        Args:
            images: RGBA batch tensor of shape [B, H, W, 4].
            start_index: Sub-batch start (inclusive); clamped to [0, B-1].
            end_index: Sub-batch end (inclusive); clamped to [0, B-1]. If it
                ends up before the start, the two are swapped.
            y_coord: Row used for the horizontal scan; clamped to [0, H-1].

        Returns:
            A 9-tuple of [1, H, W, 4] tensors in RETURN_NAMES order:
            cp0..cp4 followed by mid_0_1..mid_3_4.

        Raises:
            TypeError: If ``images`` is not a torch.Tensor.
            ValueError: If ``images`` is not 4D, is not RGBA, is an empty
                batch, or has zero height.
        """
        # --- Basic sanity checks on the input tensor ---
        if not isinstance(images, torch.Tensor):
            raise TypeError(f"Expected IMAGE tensor, got {type(images)}")

        if images.ndim != 4:
            raise ValueError(
                f"Expected IMAGE of shape [B,H,W,C], got {tuple(images.shape)}"
            )

        batch_size, height, _width, channels = images.shape

        if channels != 4:
            raise ValueError(
                f"Expected RGBA image with 4 channels, got {channels}. "
                "Make sure your input batch is RGBA (not RGB)."
            )

        if batch_size == 0:
            raise ValueError("Empty image batch passed to Get_Correct_Batch_Img.")

        if height == 0:
            # Without this, the row lookup below fails with an opaque
            # IndexError because there is no row to scan.
            raise ValueError(
                "Expected images with non-zero height, "
                f"got shape {tuple(images.shape)}"
            )

        # --- Clamp and normalize indices ---
        start = max(0, min(int(start_index), batch_size - 1))
        end = max(0, min(int(end_index), batch_size - 1))
        if start > end:
            start, end = end, start  # swap so start <= end

        # Clamp Y coordinate into image bounds
        y = max(0, min(int(y_coord), height - 1))

        # --- 1) Measure width for every image in the sub-batch ---
        # start <= end holds here, so there is always at least one width.
        widths = self._compute_widths(images, start, end, y)
        last_local = len(widths) - 1

        # --- 2) Find the 5 checkpoints on this "wave" ---
        # Clamp to valid local indices, just in case.
        checkpoints = [
            max(0, min(last_local, int(cp)))
            for cp in self._compute_checkpoints(widths)
        ]

        # --- 3) Compute in-betweens between each consecutive pair ---
        mids = [
            self._find_mid_index(cp_a, cp_b, widths)
            for cp_a, cp_b in zip(checkpoints, checkpoints[1:])
        ]

        # --- 4) Extract the frames as individual 1-image batches ---
        # Local indices [0..n-1] map back to batch indices by adding `start`.
        return tuple(
            images[start + local_idx].unsqueeze(0)
            for local_idx in (*checkpoints, *mids)
        )


# Register node with ComfyUI
NODE_CLASS_MAPPINGS = {
    "Get_Correct_Batch_Img": Get_Correct_Batch_Img,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Get_Correct_Batch_Img": "Get_Correct_Batch_Img (Salia Wave)",
}