# MyCustomNodes / Get_Correct_Batch_Img.py
# (Hugging Face upload-page residue converted to comments: uploaded by
#  saliacoel, commit cfd41d9 verified.)
import torch
class Get_Correct_Batch_Img:
    """
    Scan one Y row of an RGBA image batch and treat the visible span width on
    that row as a 1D "wave" over the batch index (time).

    This node:
    - Measures the visible width for EVERY image in the selected sub-batch
      [start_index, end_index] (inclusive).
    - Detects a "big wave" pattern and extracts 5 checkpoints:
        cp0: first major high (start-side high)
        cp1: first major low (first valley)
        cp2: next major high (peak after first valley)
        cp3: second major low (second valley)
        cp4: final major high (peak after second valley, then shifted 5% of
             the index distance back towards cp3)
    - For each consecutive checkpoint pair, also finds an "in-between" frame
      whose width is closest to the average of the pair's widths:
        mid_0_1: between cp0 and cp1
        mid_1_2: between cp1 and cp2
        mid_2_3: between cp2 and cp3
        mid_3_4: between cp3 and cp4

    Outputs (all RGBA, batch size 1):
        cp0_start_high, cp1_low_1, cp2_high_2, cp3_low_2, cp4_high_3,
        mid_0_1, mid_1_2, mid_2_3, mid_3_4

    Visibility is determined from the alpha channel (A > 0). Images with no
    visible pixels on that row are treated as width = 0 (completely thin).
    Only images within [start_index, end_index] (inclusive) are considered.
    """

    CATEGORY = "image/batch"

    @classmethod
    def INPUT_TYPES(cls):
        # All three INT widgets share the same non-negative range; values are
        # clamped again inside select(), so stale widget values cannot crash.
        int_opts = {"default": 0, "min": 0, "max": 2_147_483_647, "step": 1}
        return {
            "required": {
                # RGBA image batch: torch.Tensor [B, H, W, 4]
                "images": ("IMAGE",),
                # Sub-batch start index (inclusive, 0-based)
                "start_index": ("INT", dict(int_opts)),
                # Sub-batch end index (inclusive, 0-based)
                "end_index": ("INT", dict(int_opts)),
                # Y coordinate (row) used for the horizontal scan
                "y_coord": ("INT", dict(int_opts)),
            }
        }

    # 5 checkpoints + 4 in-betweens = 9 outputs
    RETURN_TYPES = ("IMAGE",) * 9
    RETURN_NAMES = (
        "cp0_start_high",
        "cp1_low_1",
        "cp2_high_2",
        "cp3_low_2",
        "cp4_high_3",
        "mid_0_1",
        "mid_1_2",
        "mid_2_3",
        "mid_3_4",
    )
    FUNCTION = "select"

    def _compute_widths(self, images, start, end, y, alpha_threshold=0.0):
        """
        Compute the visible span width on row `y` for each image in
        [start, end] (inclusive).

        Width is the inclusive distance between the leftmost and rightmost
        pixel whose alpha exceeds `alpha_threshold` (fully transparent gaps
        inside the span still count towards the width). Frames with no
        visible pixel on that row get width 0.

        Vectorized over the whole sub-batch: one slice plus two reductions
        instead of a Python loop with one GPU->CPU sync per frame.

        Args:
            images: torch.Tensor [B, H, W, 4] RGBA batch.
            start, end: inclusive global batch indices, start <= end.
            y: row index (already clamped by the caller).
            alpha_threshold: pixels with alpha strictly above this count as
                visible.

        Returns:
            Python list of float widths, len == end - start + 1.
        """
        # [N, W] boolean visibility mask for row y of every selected frame.
        visible = images[start:end + 1, y, :, 3] > alpha_threshold
        n_cols = visible.shape[1]
        xs = torch.arange(n_cols, device=visible.device)
        # Leftmost visible x per frame (sentinel n_cols when nothing visible)
        # and rightmost visible x per frame (sentinel -1 when nothing
        # visible). torch.where + min/max is deterministic, unlike relying
        # on argmax tie-breaking.
        left = torch.where(visible, xs, torch.full_like(xs, n_cols)).min(dim=1).values
        right = torch.where(visible, xs, torch.full_like(xs, -1)).max(dim=1).values
        spans = right - left + 1  # inclusive distance; negative on empty rows
        spans = torch.clamp(spans, min=0)  # empty rows -> width 0
        return [float(v) for v in spans.tolist()]

    def _compute_checkpoints(self, widths):
        """
        From the per-frame width list, pick 5 checkpoint indices
        [cp0, cp1, cp2, cp3, cp4] (indices into `widths`).

        Strategy (global-ish, so tiny local wiggles are ignored):
        - Split the sequence into two halves.
        - cp1 = minimum in the first half  (first big valley)
        - cp3 = minimum in the second half (second big valley)
        - cp0 = maximum over [0 .. cp1]
        - cp2 = maximum over [cp1 .. cp3]
        - cp4 = maximum over [cp3 .. end], then nudged 5% of the index
          distance back towards cp3.

        All min/max scans resolve ties to the EARLIEST index.
        """
        n = len(widths)
        if n == 0:
            return [0, 0, 0, 0, 0]
        if n < 4:
            # Too short to detect a wave: just spread the indices linearly.
            cp0 = 0
            cp4 = n - 1
            cp1 = max(0, min(n - 1, n // 4))
            cp3 = max(0, min(n - 1, (3 * n) // 4))
            cp2 = max(cp1, min(cp3, (cp1 + cp3) // 2))
            return [cp0, cp1, cp2, cp3, cp4]
        # Normal case: n >= 4
        mid = n // 2
        # cp1: global min in the FIRST half [0 .. mid]
        cp1 = min(range(0, mid + 1), key=lambda i: widths[i])
        # cp3: global min in the SECOND half [mid .. n-1]
        cp3 = min(range(mid, n), key=lambda i: widths[i])
        # Ensure cp3 lands strictly after cp1 so we genuinely get a second
        # valley (both minima can coincide when cp1 == mid). For n >= 4,
        # cp1 <= mid <= n-2, so cp1 + 1 < n always holds here.
        if cp3 <= cp1 and cp1 + 1 < n:
            cp3 = min(range(cp1 + 1, n), key=lambda i: widths[i])
        # cp0: highest point before (and including) cp1
        cp0 = max(range(0, cp1 + 1), key=lambda i: widths[i])
        # cp2: highest point between cp1 and cp3 (inclusive)
        cp2 = max(range(cp1, cp3 + 1), key=lambda i: widths[i])
        # cp4: highest point from cp3 to the end
        cp4 = max(range(cp3, n), key=lambda i: widths[i])
        # Nudge cp4 5% of the index distance back towards cp3, clamped to
        # stay inside [cp3, cp4].
        if cp4 > cp3:
            dist = cp4 - cp3
            cp4 = max(cp3, min(cp4, int(round(cp4 - 0.05 * dist))))
        return [cp0, cp1, cp2, cp3, cp4]

    def _find_mid_index(self, idx_a, idx_b, widths):
        """
        Find the frame whose width is closest to the average of the widths at
        the two checkpoint indices.

        Prefers a TRUE in-between frame (index strictly between idx_a and
        idx_b); if the checkpoints are equal or adjacent, falls back to the
        endpoints themselves. Ties resolve to the earliest candidate.
        """
        if idx_a == idx_b:
            return idx_a
        lo, hi = (idx_a, idx_b) if idx_a < idx_b else (idx_b, idx_a)
        target = (widths[idx_a] + widths[idx_b]) / 2.0
        # Strictly in-between frames if any exist, else the endpoints.
        candidates = list(range(lo + 1, hi)) or [lo, hi]
        # min() returns the first minimal candidate, matching the earliest-
        # wins tie-break.
        return min(candidates, key=lambda j: abs(widths[j] - target))

    def select(self, images, start_index, end_index, y_coord):
        """
        Entry point. See the class docstring for the full contract.

        Args:
            images: torch.Tensor [B, H, W, 4] RGBA batch (only the alpha
                channel is inspected for the width measurement).
            start_index, end_index: inclusive sub-batch bounds; each is
                clamped into [0, B-1] and they are swapped if reversed.
            y_coord: scan row; clamped into [0, H-1].

        Returns:
            Tuple of nine [1, H, W, 4] tensors:
            (cp0, cp1, cp2, cp3, cp4, mid_0_1, mid_1_2, mid_2_3, mid_3_4).

        Raises:
            TypeError: images is not a torch.Tensor.
            ValueError: wrong rank, non-RGBA channel count, or empty batch.
        """
        # --- Basic sanity checks on the input tensor ---
        if not isinstance(images, torch.Tensor):
            raise TypeError(f"Expected IMAGE tensor, got {type(images)}")
        if images.ndim != 4:
            raise ValueError(
                f"Expected IMAGE of shape [B,H,W,C], got {tuple(images.shape)}"
            )
        batch_size, height, width, channels = images.shape
        if channels != 4:
            raise ValueError(
                f"Expected RGBA image with 4 channels, got {channels}. "
                "Make sure your input batch is RGBA (not RGB)."
            )
        if batch_size == 0:
            raise ValueError("Empty image batch passed to Get_Correct_Batch_Img.")
        # --- Clamp and normalize indices ---
        start = max(0, min(int(start_index), batch_size - 1))
        end = max(0, min(int(end_index), batch_size - 1))
        if start > end:
            start, end = end, start  # swap so start <= end
        # Clamp Y coordinate into image bounds
        y = max(0, min(int(y_coord), height - 1))
        # --- 1) Measure width for every image in the sub-batch ---
        widths = self._compute_widths(images, start, end, y)
        n = len(widths)
        if n == 0:
            # Defensive: cannot happen after the clamps above (start <= end
            # always yields at least one frame), but never crash the graph.
            base_img = images[start].unsqueeze(0)
            return (base_img,) * 9
        # --- 2) Find the 5 checkpoints on this "wave" ---
        cps = [
            max(0, min(n - 1, int(cp)))  # clamp to valid local indices
            for cp in self._compute_checkpoints(widths)
        ]
        cp0, cp1, cp2, cp3, cp4 = cps
        # --- 3) In-between frame for each consecutive checkpoint pair ---
        mids = [
            self._find_mid_index(a, b, widths)
            for a, b in ((cp0, cp1), (cp1, cp2), (cp2, cp3), (cp3, cp4))
        ]
        # --- 4) Map local indices [0..n-1] back to global batch indices and
        # extract each selected frame as a 1-image batch ---
        return tuple(images[start + idx].unsqueeze(0) for idx in cps + mids)
# Register node with ComfyUI. ComfyUI's loader reads these two module-level
# dicts when it imports the package (internal node id -> implementation).
NODE_CLASS_MAPPINGS = {
    "Get_Correct_Batch_Img": Get_Correct_Batch_Img,
}
# Internal node id -> human-readable label shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "Get_Correct_Batch_Img": "Get_Correct_Batch_Img (Salia Wave)",
}