File size: 11,845 Bytes
cfd41d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 |
import torch
class Get_Correct_Batch_Img:
    """
    Given a batch of RGBA images, scan a given Y row across a (sub)batch and
    treat the visible span width on that row as a 1D curve over time (batch index).

    This node:
      - Measures the visible width for EVERY image in the selected sub-batch.
      - Detects a "big wave" pattern and extracts 5 checkpoints:
          cp0: first major high (start-side high)
          cp1: first major low (first valley)
          cp2: next major high (peak after first valley)
          cp3: second major low (second valley)
          cp4: final major high (peak after second valley), then shifted back
               towards cp3 by 5% of the cp3->cp4 index distance
               (rounded to the nearest frame, halves rounded up)
      - For each consecutive checkpoint pair, also finds an "in-between" frame:
          mid_0_1: width closest to midpoint between cp0 and cp1
          mid_1_2: width closest to midpoint between cp1 and cp2
          mid_2_3: width closest to midpoint between cp2 and cp3
          mid_3_4: width closest to midpoint between cp3 and cp4

    Outputs (all RGBA, B=1):
        cp0_start_high
        cp1_low_1
        cp2_high_2
        cp3_low_2
        cp4_high_3
        mid_0_1
        mid_1_2
        mid_2_3
        mid_3_4

    Visibility is determined from the alpha channel (A > 0). Images with no
    visible pixels on that row are treated as width = 0 (completely thin).
    Only images within [start_index, end_index] (inclusive) are considered.
    """

    CATEGORY = "image/batch"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # RGBA image batch: torch.Tensor [B, H, W, 4]
                "images": ("IMAGE",),
                # Sub-batch start index (inclusive, 0-based)
                "start_index": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 2_147_483_647,
                        "step": 1,
                    },
                ),
                # Sub-batch end index (inclusive, 0-based)
                "end_index": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 2_147_483_647,
                        "step": 1,
                    },
                ),
                # Y coordinate (row) used for the horizontal scan
                "y_coord": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 2_147_483_647,
                        "step": 1,
                    },
                ),
            }
        }

    # 5 checkpoints + 4 inbetweens = 9 outputs
    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE", "IMAGE", "IMAGE",
                    "IMAGE", "IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = (
        "cp0_start_high",
        "cp1_low_1",
        "cp2_high_2",
        "cp3_low_2",
        "cp4_high_3",
        "mid_0_1",
        "mid_1_2",
        "mid_2_3",
        "mid_3_4",
    )
    FUNCTION = "select"

    def _compute_widths(self, images, start, end, y, alpha_threshold=0.0):
        """
        For each image in [start, end], compute the visible width on row y.

        A pixel is visible when alpha > alpha_threshold. The width is the
        inclusive span from the leftmost to the rightmost visible pixel
        (transparent gaps inside the span are counted). If no pixel on the
        row is visible, the width is 0.

        Args:
            images: tensor indexed as [batch, row, column, channel]; channel 3
                is read as alpha.
            start, end: inclusive batch range (callers pass clamped values).
            y: row index (callers pass a clamped value).
            alpha_threshold: alpha values strictly above this are visible.

        Returns:
            Python list of float widths (len = end - start + 1).
        """
        # Row y of the alpha channel for the whole sub-batch at once: [n, W].
        visible = images[start:end + 1, y, :, 3] > alpha_threshold
        n, row_width = visible.shape
        if n == 0 or row_width == 0:
            return [0.0] * n

        # Vectorized instead of a per-image loop: on GPU tensors the loop
        # forced host syncs for every frame (torch.any + torch.nonzero).
        as_int = visible.to(torch.int32)
        has_visible = visible.any(dim=1)
        # argmax returns the FIRST maximal index -> first visible pixel.
        left_x = as_int.argmax(dim=1)
        # Same trick on the mirrored row -> last visible pixel.
        right_x = (row_width - 1) - as_int.flip(dims=[1]).argmax(dim=1)
        spans = torch.where(
            has_visible, right_x - left_x + 1, torch.zeros_like(left_x)
        )
        return [float(w) for w in spans.tolist()]

    def _compute_checkpoints(self, widths):
        """
        From a list of widths (one per frame in sub-batch), compute 5
        checkpoints cp0..cp4 as indices into `widths`.

        Strategy (global-ish, not just tiny local wiggles):
          - Split the sequence at mid = n // 2 (both halves include mid).
          - cp1 = minimum in first half  [0 .. mid]   (first big valley)
          - cp3 = minimum in second half [mid .. n-1] (second big valley),
            forced strictly after cp1 when both land on mid.
          - cp0 = maximum from start .. cp1
          - cp2 = maximum from cp1 .. cp3
          - cp4 = maximum from cp3 .. end
          - Then move cp4 back towards cp3 by 5% of (cp4 - cp3), rounded to
            the nearest frame with halves rounded up.
        Ties resolve to the earliest index (min/max keep the first extreme).
        Sequences shorter than 4 get linearly spread indices instead.
        """
        n = len(widths)
        if n == 0:
            return [0, 0, 0, 0, 0]

        # Very small sequences: just spread indices out linearly.
        if n < 4:
            cp1 = max(0, min(n - 1, n // 4))
            cp3 = max(0, min(n - 1, (3 * n) // 4))
            cp2 = max(cp1, min(cp3, (cp1 + cp3) // 2))
            return [0, cp1, cp2, cp3, n - 1]

        def argmin(indices):
            return min(indices, key=widths.__getitem__)

        def argmax(indices):
            return max(indices, key=widths.__getitem__)

        mid = n // 2
        cp1 = argmin(range(0, mid + 1))
        cp3 = argmin(range(mid, n))
        # The halves overlap at `mid`; if both minima landed there, look
        # strictly after cp1 so we genuinely get a second valley.
        if cp3 <= cp1 and cp1 + 1 < n:
            cp3 = argmin(range(cp1 + 1, n))

        cp0 = argmax(range(0, cp1 + 1))
        cp2 = argmax(range(cp1, cp3 + 1))
        cp4 = argmax(range(cp3, n))

        if cp4 > cp3:
            dist = cp4 - cp3
            # 5% of dist rounded half-up, in integer math. The previous
            # round(cp4 - 0.05 * dist) used banker's rounding, so e.g. with
            # dist == 10 the nudge happened only for odd cp4.
            shift = (dist + 10) // 20
            cp4 = max(cp3, cp4 - shift)
        return [cp0, cp1, cp2, cp3, cp4]

    def _find_mid_index(self, idx_a, idx_b, widths):
        """
        Return the index whose width is closest to the average of the widths
        at idx_a and idx_b.

        Only indices strictly between the two are considered when any exist;
        if the checkpoints are adjacent, one of the endpoints is returned, and
        if they are equal that index is returned. Ties go to the lowest index.
        """
        if idx_a == idx_b:
            return idx_a
        lo, hi = sorted((idx_a, idx_b))
        target = (widths[idx_a] + widths[idx_b]) / 2.0
        # An empty range is falsy -> fall back to the endpoints.
        candidates = range(lo + 1, hi) or (lo, hi)
        return min(candidates, key=lambda j: abs(widths[j] - target))

    def select(self, images, start_index, end_index, y_coord):
        """
        Node entry point: pick the 5 checkpoint frames and 4 in-between frames.

        Indices are clamped into the batch and swapped if reversed; y_coord is
        clamped into the image height.

        Returns:
            Tuple of 9 single-image batches ([1, H, W, 4]) in RETURN_NAMES order.

        Raises:
            TypeError: if `images` is not a torch.Tensor.
            ValueError: if `images` is not 4-D, has != 4 channels, or is empty.
        """
        # --- Basic sanity checks on the input tensor ---
        if not isinstance(images, torch.Tensor):
            raise TypeError(f"Expected IMAGE tensor, got {type(images)}")
        if images.ndim != 4:
            raise ValueError(
                f"Expected IMAGE of shape [B,H,W,C], got {tuple(images.shape)}"
            )
        batch_size, height, width, channels = images.shape
        if channels != 4:
            raise ValueError(
                f"Expected RGBA image with 4 channels, got {channels}. "
                "Make sure your input batch is RGBA (not RGB)."
            )
        if batch_size == 0:
            raise ValueError("Empty image batch passed to Get_Correct_Batch_Img.")

        # --- Clamp and normalize indices ---
        start = max(0, min(int(start_index), batch_size - 1))
        end = max(0, min(int(end_index), batch_size - 1))
        if start > end:
            start, end = end, start  # swap so start <= end
        y = max(0, min(int(y_coord), height - 1))

        # --- 1) Measure width for every image in the sub-batch ---
        # start <= end, so this always holds at least one width.
        widths = self._compute_widths(images, start, end, y)
        n = len(widths)

        # --- 2) Find the 5 checkpoints (clamped defensively to [0, n-1]) ---
        checkpoints = [
            max(0, min(n - 1, int(cp))) for cp in self._compute_checkpoints(widths)
        ]

        # --- 3) In-betweens for (cp0,cp1), (cp1,cp2), (cp2,cp3), (cp3,cp4) ---
        mids = [
            self._find_mid_index(a, b, widths)
            for a, b in zip(checkpoints, checkpoints[1:])
        ]

        # --- 4) Map local -> global batch index; emit 1-image batches ---
        return tuple(
            images[start + local_idx].unsqueeze(0)
            for local_idx in checkpoints + mids
        )
# ComfyUI registration: node id -> implementing class, and
# node id -> label shown in the UI.
NODE_CLASS_MAPPINGS = dict(Get_Correct_Batch_Img=Get_Correct_Batch_Img)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    Get_Correct_Batch_Img="Get_Correct_Batch_Img (Salia Wave)",
)
|