saliacoel commited on
Commit
56031a1
·
verified ·
1 Parent(s): 0d007be

Upload salia_extract_loop.py

Browse files
Files changed (1) hide show
  1. salia_extract_loop.py +679 -0
salia_extract_loop.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ # ============================================================
5
+ # Basic helpers (standalone)
6
+ # ============================================================
7
+
8
+ def _bhwc_to_nchw(img: torch.Tensor) -> torch.Tensor:
9
+ if img.dim() != 4:
10
+ raise ValueError(f"Expected [B,H,W,C], got {tuple(img.shape)}")
11
+ return img.permute(0, 3, 1, 2).contiguous()
12
+
13
+ def _ensure_rgba_bhwc(images: torch.Tensor) -> torch.Tensor:
14
+ if images.dim() != 4:
15
+ raise ValueError(f"Expected [B,H,W,C], got {tuple(images.shape)}")
16
+ b, h, w, c = images.shape
17
+ if c == 4:
18
+ return images
19
+ if c == 3:
20
+ alpha = torch.ones((b, h, w, 1), device=images.device, dtype=images.dtype)
21
+ return torch.cat([images, alpha], dim=3)
22
+ raise ValueError(f"Expected 3 or 4 channels, got {c}")
23
+
24
+ def _to_luma(x: torch.Tensor) -> torch.Tensor:
25
+ # x: [B,3,H,W]
26
+ r = x[:, 0:1, :, :]
27
+ g = x[:, 1:2, :, :]
28
+ b = x[:, 2:3, :, :]
29
+ return (0.2989 * r + 0.5870 * g + 0.1140 * b)
30
+
31
+ def _resize_max(x: torch.Tensor, max_size: int) -> torch.Tensor:
32
+ if max_size <= 0:
33
+ return x
34
+ b, c, h, w = x.shape
35
+ m = max(h, w)
36
+ if m <= max_size:
37
+ return x
38
+ scale = max_size / float(m)
39
+ nh = max(1, int(round(h * scale)))
40
+ nw = max(1, int(round(w * scale)))
41
+ return F.interpolate(x, size=(nh, nw), mode="bilinear", align_corners=False)
42
+
43
+ def _gaussian_blur(x: torch.Tensor, sigma: float) -> torch.Tensor:
44
+ if sigma <= 0:
45
+ return x
46
+ radius = int(max(1, round(3.0 * sigma)))
47
+ ksize = 2 * radius + 1
48
+ device = x.device
49
+ dtype = x.dtype
50
+
51
+ coords = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
52
+ kernel1d = torch.exp(-(coords * coords) / (2.0 * sigma * sigma))
53
+ kernel1d = kernel1d / (kernel1d.sum() + 1e-12)
54
+
55
+ c = x.shape[1]
56
+ kh = kernel1d.view(1, 1, 1, ksize).repeat(c, 1, 1, 1)
57
+ kv = kernel1d.view(1, 1, ksize, 1).repeat(c, 1, 1, 1)
58
+
59
+ out = F.conv2d(x, kh, padding=(0, radius), groups=c)
60
+ out = F.conv2d(out, kv, padding=(radius, 0), groups=c)
61
+ return out
62
+
63
+ def _sobel_edges(y: torch.Tensor) -> torch.Tensor:
64
+ # y: [B,1,H,W]
65
+ device = y.device
66
+ dtype = y.dtype
67
+
68
+ kx = torch.tensor(
69
+ [[-1, 0, 1],
70
+ [-2, 0, 2],
71
+ [-1, 0, 1]],
72
+ device=device, dtype=dtype
73
+ ) / 8.0
74
+
75
+ ky = torch.tensor(
76
+ [[-1, -2, -1],
77
+ [ 0, 0, 0],
78
+ [ 1, 2, 1]],
79
+ device=device, dtype=dtype
80
+ ) / 8.0
81
+
82
+ kx = kx.view(1, 1, 3, 3)
83
+ ky = ky.view(1, 1, 3, 3)
84
+
85
+ gx = F.conv2d(y, kx, padding=1)
86
+ gy = F.conv2d(y, ky, padding=1)
87
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
88
+
89
+ def _make_gaussian_window(window_size: int, sigma: float, device, dtype):
90
+ radius = window_size // 2
91
+ coords = torch.arange(window_size, device=device, dtype=dtype) - radius
92
+ g = torch.exp(-(coords * coords) / (2.0 * sigma * sigma))
93
+ g = g / (g.sum() + 1e-12)
94
+ w2d = (g[:, None] * g[None, :]).view(1, 1, window_size, window_size)
95
+ return w2d, radius
96
+
97
+ def _ssim_fast(x: torch.Tensor, y: torch.Tensor, w2d: torch.Tensor, radius: int) -> torch.Tensor:
98
+ """
99
+ SSIM for luma only.
100
+ x,y: [B,1,H,W]
101
+ returns [B]
102
+ """
103
+ mu_x = F.conv2d(x, w2d, padding=radius)
104
+ mu_y = F.conv2d(y, w2d, padding=radius)
105
+
106
+ mu_x2 = mu_x * mu_x
107
+ mu_y2 = mu_y * mu_y
108
+ mu_xy = mu_x * mu_y
109
+
110
+ sigma_x2 = F.conv2d(x * x, w2d, padding=radius) - mu_x2
111
+ sigma_y2 = F.conv2d(y * y, w2d, padding=radius) - mu_y2
112
+ sigma_xy = F.conv2d(x * y, w2d, padding=radius) - mu_xy
113
+
114
+ C1 = (0.01) ** 2
115
+ C2 = (0.03) ** 2
116
+
117
+ num = (2.0 * mu_xy + C1) * (2.0 * sigma_xy + C2)
118
+ den = (mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2)
119
+
120
+ ssim_map = num / (den + 1e-12)
121
+ return ssim_map.mean(dim=[1, 2, 3])
122
+
123
+ def _hist_chi2_from_hists(hx: torch.Tensor, hy: torch.Tensor) -> torch.Tensor:
124
+ """
125
+ hx,hy: [B,3,bins] normalized
126
+ returns [B]
127
+ """
128
+ eps = 1e-12
129
+ return 0.5 * (((hx - hy) ** 2) / (hx + hy + eps)).sum(dim=2).mean(dim=1)
130
+
131
+ # ============================================================
132
+ # Fast frozen-tail diff (cheap)
133
+ # ============================================================
134
+
135
+ def _fast_tail_diff_bhwc(a_bhwc: torch.Tensor, b_bhwc: torch.Tensor) -> float:
136
+ DOWNSCALE_MAX = 128
137
+ BLUR_SIGMA = 1.2
138
+ SCALE = 1000.0
139
+
140
+ a = _bhwc_to_nchw(a_bhwc).clamp(0.0, 1.0)
141
+ b = _bhwc_to_nchw(b_bhwc).clamp(0.0, 1.0)
142
+
143
+ if a.shape[1] >= 4 and b.shape[1] >= 4:
144
+ aa = a[:, 3:4]
145
+ ba = b[:, 3:4]
146
+ ar = a[:, 0:3] * aa
147
+ br = b[:, 0:3] * ba
148
+ else:
149
+ ar = a[:, 0:3]
150
+ br = b[:, 0:3]
151
+
152
+ if ar.shape[2:] != br.shape[2:]:
153
+ br = F.interpolate(br, size=ar.shape[2:], mode="bilinear", align_corners=False)
154
+
155
+ ar = _resize_max(ar, DOWNSCALE_MAX)
156
+ br = _resize_max(br, DOWNSCALE_MAX)
157
+ ar = _gaussian_blur(ar, BLUR_SIGMA)
158
+ br = _gaussian_blur(br, BLUR_SIGMA)
159
+
160
+ mae = torch.mean(torch.abs(ar - br), dim=[1, 2, 3])
161
+ return float(mae.mean().item() * SCALE)
162
+
163
+ # ============================================================
164
+ # Waveform helpers (span widths in a y-band)
165
+ # ============================================================
166
+
167
+ def _smooth_1d(values: list) -> list:
168
+ n = len(values)
169
+ if n < 3:
170
+ return list(values)
171
+ t = torch.tensor(values, dtype=torch.float32).view(1, 1, n)
172
+ k = torch.tensor([0.25, 0.50, 0.25], dtype=torch.float32).view(1, 1, 3)
173
+ tpad = F.pad(t, (1, 1), mode="replicate")
174
+ out = F.conv1d(tpad, k)
175
+ return out.view(n).tolist()
176
+
177
+ def _is_local_min(values: list, i: int) -> bool:
178
+ n = len(values)
179
+ if n < 3:
180
+ return True
181
+ if i <= 0 or i >= n - 1:
182
+ return False
183
+ return (values[i] <= values[i - 1]) and (values[i] <= values[i + 1])
184
+
185
+ def _percentile(values: list, q: float) -> float:
186
+ if not values:
187
+ return 0.0
188
+ s = sorted(values)
189
+ q = max(0.0, min(1.0, float(q)))
190
+ pos = q * (len(s) - 1)
191
+ idx = int(round(pos))
192
+ idx = max(0, min(len(s) - 1, idx))
193
+ return float(s[idx])
194
+
195
+ def _compute_visible_y_bounds(images_bhwc: torch.Tensor, alpha_thr: float = 0.01):
196
+ b, h, w, c = images_bhwc.shape
197
+ if c < 4:
198
+ return (0, h - 1)
199
+ alpha = images_bhwc[:, :, :, 3]
200
+ vis_y = (alpha > alpha_thr).any(dim=2).any(dim=0) # [H]
201
+ idx = torch.where(vis_y)[0]
202
+ if idx.numel() == 0:
203
+ return (0, h - 1)
204
+ y_min = int(idx[0].item())
205
+ y_max = int(idx[-1].item())
206
+ return (max(0, y_min), min(h - 1, y_max))
207
+
208
+ def _compute_band_span_widths(images_bhwc: torch.Tensor,
209
+ y0: int,
210
+ y1: int,
211
+ alpha_thr: float = 0.01,
212
+ sample_rows: int = 32) -> list:
213
+ """
214
+ Robust span width per frame in a band:
215
+ - sample rows between y0..y1
216
+ - for each row, get left/right visible
217
+ - aggregate via 10% / 90% quantile using sorting (small row count)
218
+ """
219
+ b, h, w, c = images_bhwc.shape
220
+ if c < 4:
221
+ return [float(w)] * b
222
+
223
+ y0 = max(0, min(h - 1, int(y0)))
224
+ y1 = max(0, min(h - 1, int(y1)))
225
+ if y0 > y1:
226
+ y0, y1 = y1, y0
227
+
228
+ if sample_rows <= 1:
229
+ ys = [y0]
230
+ else:
231
+ ys_t = torch.linspace(y0, y1, steps=sample_rows)
232
+ ys = torch.unique(torch.round(ys_t).long()).tolist()
233
+ ys = [int(v) for v in ys]
234
+
235
+ widths = []
236
+ for i in range(b):
237
+ lefts = []
238
+ rights = []
239
+ for y in ys:
240
+ row_alpha = images_bhwc[i, y, :, 3]
241
+ vis = row_alpha > alpha_thr
242
+ if torch.any(vis):
243
+ idx = torch.where(vis)[0]
244
+ lefts.append(int(idx[0].item()))
245
+ rights.append(int(idx[-1].item()))
246
+ if not lefts:
247
+ widths.append(0.0)
248
+ continue
249
+ lefts.sort()
250
+ rights.sort()
251
+ # 10% and 90% quantiles (row count is small)
252
+ lq = lefts[int(round(0.10 * (len(lefts) - 1)))]
253
+ rq = rights[int(round(0.90 * (len(rights) - 1)))]
254
+ widths.append(float(max(0, rq - lq + 1)))
255
+ return widths
256
+
257
+ def _valley_candidates(signal: list, max_k: int, min_sep: int) -> list:
258
+ """
259
+ Strong valley candidates:
260
+ - local minima
261
+ - in low quantile band
262
+ - greedy separation
263
+ """
264
+ n = len(signal)
265
+ if n < 6:
266
+ return list(range(n))
267
+
268
+ low_thr = _percentile(signal, 0.50) # valley-like should be in lower half
269
+ mins = [i for i in range(1, n - 1) if _is_local_min(signal, i) and signal[i] <= low_thr]
270
+ if not mins:
271
+ # fallback: just take globally small points
272
+ mins = list(range(n))
273
+
274
+ mins.sort(key=lambda i: signal[i]) # deepest valleys first
275
+
276
+ chosen = []
277
+ for i in mins:
278
+ if all(abs(i - j) >= min_sep for j in chosen):
279
+ chosen.append(i)
280
+ if len(chosen) >= max_k:
281
+ break
282
+
283
+ return sorted(chosen)
284
+
285
+ # ============================================================
286
+ # Hybrid fixed precompute + batched scoring (THIS is the speed win)
287
+ # ============================================================
288
+
289
+ class _HybridFixedBatchScorer:
290
+ """
291
+ Hardcoded hybrid:
292
+ downscale_max=256
293
+ blur_sigma=1.2
294
+ hist_bins=32
295
+ scale=1000
296
+ w_pixel=1.00
297
+ w_ssim=1.00
298
+ w_edge=0.5
299
+ w_hist=0.2
300
+
301
+ For sprites: premultiply alpha ON.
302
+ Precomputes per-frame features once, then scores many pairs at once.
303
+ """
304
+
305
+ DOWNSCALE_MAX = 256
306
+ BLUR_SIGMA = 1.2
307
+ HIST_BINS = 32
308
+ SCALE = 1000.0
309
+ W_PIXEL = 1.0
310
+ W_SSIM = 1.0
311
+ W_EDGE = 0.5
312
+ W_HIST = 0.2
313
+
314
+ SSIM_W = 11
315
+ SSIM_SIGMA = 1.5
316
+
317
+ def __init__(self, images_bhwc: torch.Tensor, premultiply_alpha: bool = True):
318
+ # images_bhwc: [B,H,W,4]
319
+ self.device = images_bhwc.device
320
+ # keep float32 for stable ops
321
+ x = images_bhwc.clamp(0.0, 1.0).to(torch.float32)
322
+ x_nchw = _bhwc_to_nchw(x)
323
+
324
+ if x_nchw.shape[1] >= 4 and premultiply_alpha:
325
+ a = x_nchw[:, 3:4]
326
+ rgb = x_nchw[:, 0:3] * a
327
+ else:
328
+ rgb = x_nchw[:, 0:3]
329
+
330
+ # Downscale once
331
+ rgb_small = _resize_max(rgb, self.DOWNSCALE_MAX)
332
+
333
+ # Blur once
334
+ rgb_blur = _gaussian_blur(rgb_small, self.BLUR_SIGMA)
335
+
336
+ # Luma + edges once
337
+ luma_blur = _to_luma(rgb_blur)
338
+ edge = _sobel_edges(luma_blur)
339
+
340
+ # Histograms ONCE (no GPU->CPU ping-pong)
341
+ # Use histc (matches your original binning behavior closely)
342
+ b, c, h, w = rgb_small.shape
343
+ bins = self.HIST_BINS
344
+ eps = 1e-12
345
+ hists = torch.empty((b, 3, bins), device=rgb_small.device, dtype=torch.float32)
346
+ rgb_small = rgb_small.clamp(0.0, 1.0)
347
+ for i in range(b):
348
+ for ch in range(3):
349
+ hist = torch.histc(rgb_small[i, ch], bins=bins, min=0.0, max=1.0)
350
+ hist = hist / (hist.sum() + eps)
351
+ hists[i, ch] = hist
352
+
353
+ self.rgb_blur = rgb_blur
354
+ self.luma_blur = luma_blur
355
+ self.edge = edge
356
+ self.hists = hists
357
+
358
+ self.w2d, self.radius = _make_gaussian_window(self.SSIM_W, self.SSIM_SIGMA, self.device, torch.float32)
359
+
360
+ def scores_for_pairs(self, idx_i: list, idx_j: list) -> torch.Tensor:
361
+ """
362
+ idx_i, idx_j: python lists of same length M
363
+ returns: tensor [M] float32 (scaled by 1000)
364
+ """
365
+ if len(idx_i) != len(idx_j):
366
+ raise ValueError("idx_i and idx_j must have same length")
367
+ m = len(idx_i)
368
+ if m == 0:
369
+ return torch.zeros((0,), device=self.device, dtype=torch.float32)
370
+
371
+ ti = torch.tensor(idx_i, device=self.device, dtype=torch.long)
372
+ tj = torch.tensor(idx_j, device=self.device, dtype=torch.long)
373
+
374
+ a_rgb = self.rgb_blur.index_select(0, ti)
375
+ b_rgb = self.rgb_blur.index_select(0, tj)
376
+ pix = torch.mean(torch.abs(a_rgb - b_rgb), dim=[1, 2, 3])
377
+
378
+ a_y = self.luma_blur.index_select(0, ti)
379
+ b_y = self.luma_blur.index_select(0, tj)
380
+ ssim = _ssim_fast(a_y, b_y, self.w2d, self.radius)
381
+ ssim_diff = (1.0 - ssim).clamp(min=0.0)
382
+
383
+ a_e = self.edge.index_select(0, ti)
384
+ b_e = self.edge.index_select(0, tj)
385
+ ed = torch.mean(torch.abs(a_e - b_e), dim=[1, 2, 3])
386
+
387
+ ha = self.hists.index_select(0, ti)
388
+ hb = self.hists.index_select(0, tj)
389
+ hist = _hist_chi2_from_hists(ha, hb)
390
+
391
+ per = (self.W_PIXEL * pix) + (self.W_SSIM * ssim_diff) + (self.W_EDGE * ed) + (self.W_HIST * hist)
392
+ return per * self.SCALE
393
+
394
+ def score_one(self, i: int, j: int) -> float:
395
+ s = self.scores_for_pairs([i], [j])
396
+ return float(s[0].item())
397
+
398
+ # ============================================================
399
+ # Node 1: Hardcoded hybrid compare (2 images -> float)
400
+ # ============================================================
401
+
402
+ class ImageCompareHybrid:
403
+ """
404
+ Same hybrid as before, hardcoded.
405
+ Note: for general images, alpha is ignored (matches your original).
406
+ """
407
+ CATEGORY = "image/analysis"
408
+ RETURN_TYPES = ("FLOAT",)
409
+ RETURN_NAMES = ("difference",)
410
+ FUNCTION = "compare"
411
+
412
+ @classmethod
413
+ def INPUT_TYPES(cls):
414
+ return {"required": {"image_a": ("IMAGE",), "image_b": ("IMAGE",)}}
415
+
416
+ def compare(self, image_a, image_b):
417
+ # For single compare, keep behavior: drop alpha (no premultiply)
418
+ # We do it via scorer on a 2-frame batch.
419
+ a = _ensure_rgba_bhwc(image_a).to(torch.float32).clamp(0.0, 1.0)
420
+ b = _ensure_rgba_bhwc(image_b).to(torch.float32).clamp(0.0, 1.0)
421
+ x = torch.cat([a[0:1], b[0:1]], dim=0)
422
+ scorer = _HybridFixedBatchScorer(x, premultiply_alpha=False)
423
+ score = scorer.score_one(0, 1)
424
+ return (float(score),)
425
+
426
+ # ============================================================
427
+ # Node 2: FAST + VALLEY-TO-VALLEY auto loop
428
+ # ============================================================
429
+
430
+ class Salia_Extract_Loop:
431
+ """
432
+ FAST + always valley-to-valley.
433
+
434
+ - trims frozen tail (cheap)
435
+ - computes feet/hands span waveforms
436
+ - chooses valley candidates (minima)
437
+ - evaluates ALL candidate valley pairs in one batched hybrid scoring call
438
+ - refines by snapping to nearby local minima (still valley-to-valley)
439
+ """
440
+
441
+ CATEGORY = "image/batch"
442
+ RETURN_TYPES = ("IMAGE", "INT", "INT", "FLOAT", "STRING")
443
+ RETURN_NAMES = ("loop_batch", "start_index", "end_index", "match_score", "debug")
444
+ FUNCTION = "autoloop"
445
+
446
+ @classmethod
447
+ def INPUT_TYPES(cls):
448
+ return {"required": {"images": ("IMAGE",)}}
449
+
450
+ def _trim_frozen_tail(self, images_bhwc: torch.Tensor):
451
+ FREEZE_THR = 3.0
452
+ MIN_CONSEC = 2
453
+ b = images_bhwc.shape[0]
454
+ if b < 3:
455
+ return b, FREEZE_THR
456
+
457
+ tail = 0
458
+ for t in range(b - 1, 0, -1):
459
+ d = _fast_tail_diff_bhwc(images_bhwc[t - 1:t], images_bhwc[t:t + 1])
460
+ if d < FREEZE_THR:
461
+ tail += 1
462
+ else:
463
+ break
464
+
465
+ if tail >= MIN_CONSEC:
466
+ eff = max(2, b - tail)
467
+ return eff, FREEZE_THR
468
+
469
+ return b, FREEZE_THR
470
+
471
+ def _snap_valley_to_valley(self, scorer, feet_s, start, end, min_len):
472
+ """
473
+ Force both ends to be local minima, by searching nearby.
474
+ Evaluate candidates in one batch.
475
+ """
476
+ n = len(feet_s)
477
+ radius = 6
478
+
479
+ s_cands = []
480
+ for i in range(max(1, start - radius), min(n - 1, start + radius + 1)):
481
+ if _is_local_min(feet_s, i):
482
+ s_cands.append(i)
483
+ e_cands = []
484
+ for j in range(max(1, end - radius), min(n - 1, end + radius + 1)):
485
+ if _is_local_min(feet_s, j):
486
+ e_cands.append(j)
487
+
488
+ if not s_cands:
489
+ s_cands = [start]
490
+ if not e_cands:
491
+ e_cands = [end]
492
+
493
+ pairs_i = []
494
+ pairs_j = []
495
+ for i in s_cands:
496
+ for j in e_cands:
497
+ if j - i >= min_len:
498
+ pairs_i.append(i)
499
+ pairs_j.append(j)
500
+
501
+ if not pairs_i:
502
+ return start, end, float(scorer.score_one(start, end))
503
+
504
+ scores = scorer.scores_for_pairs(pairs_i, pairs_j) # [M]
505
+ k = int(torch.argmin(scores).item())
506
+ best_s = pairs_i[k]
507
+ best_e = pairs_j[k]
508
+ best_score = float(scores[k].item())
509
+ return best_s, best_e, best_score
510
+
511
+ def autoloop(self, images):
512
+ if not isinstance(images, torch.Tensor):
513
+ raise TypeError(f"Expected IMAGE tensor, got {type(images)}")
514
+ if images.ndim != 4:
515
+ raise ValueError(f"Expected IMAGE [B,H,W,C], got {tuple(images.shape)}")
516
+
517
+ images = _ensure_rgba_bhwc(images).to(torch.float32).clamp(0.0, 1.0)
518
+ b, h, w, c = images.shape
519
+
520
+ if b < 6:
521
+ return (images, 0, max(0, b - 1), 0.0, f"Too few frames (B={b})")
522
+
523
+ # 1) Trim frozen tail
524
+ eff_len, freeze_thr = self._trim_frozen_tail(images)
525
+ imgs = images[:eff_len]
526
+ n = imgs.shape[0]
527
+
528
+ if n < 6:
529
+ return (imgs, 0, max(0, n - 1), 0.0, f"After trim too few frames (B={n})")
530
+
531
+ # 2) Visible bounds + adaptive bands
532
+ alpha_thr = 0.01
533
+ y_min, y_max = _compute_visible_y_bounds(imgs, alpha_thr=alpha_thr)
534
+ vis_h = max(1, (y_max - y_min + 1))
535
+
536
+ # relative bands (walkcycle-ish defaults)
537
+ hands_y0 = y_min + int(round(0.45 * (vis_h - 1)))
538
+ hands_y1 = y_min + int(round(0.63 * (vis_h - 1)))
539
+ feet_y0 = y_min + int(round(0.70 * (vis_h - 1)))
540
+ feet_y1 = y_min + int(round(0.93 * (vis_h - 1)))
541
+
542
+ hands_y0 = max(0, min(h - 1, hands_y0))
543
+ hands_y1 = max(0, min(h - 1, hands_y1))
544
+ feet_y0 = max(0, min(h - 1, feet_y0))
545
+ feet_y1 = max(0, min(h - 1, feet_y1))
546
+
547
+ # 3) Waveforms
548
+ feet = _compute_band_span_widths(imgs, feet_y0, feet_y1, alpha_thr=alpha_thr, sample_rows=32)
549
+ hands = _compute_band_span_widths(imgs, hands_y0, hands_y1, alpha_thr=alpha_thr, sample_rows=24)
550
+
551
+ feet_s = _smooth_1d(_smooth_1d(feet))
552
+ hands_s = _smooth_1d(_smooth_1d(hands))
553
+
554
+ feet_range = max(feet_s) - min(feet_s)
555
+ if feet_range < 4.0:
556
+ # In this case valley detection is unreliable; return original trimmed batch
557
+ dbg = (
558
+ f"Feet waveform too flat (range={feet_range:.2f}). "
559
+ f"Returning trimmed batch.\norig_B={b}, eff_B={n}, freeze_thr={freeze_thr}"
560
+ )
561
+ return (imgs, 0, n - 1, 0.0, dbg)
562
+
563
+ # 4) Valley candidates (minima) + choose end valleys near end
564
+ min_sep = max(2, int(round(0.08 * n)))
565
+ valleys = _valley_candidates(feet_s, max_k=14, min_sep=min_sep)
566
+
567
+ end_region = int(round(0.50 * (n - 1)))
568
+ end_valleys = [v for v in valleys if v >= end_region]
569
+ if not end_valleys:
570
+ end_valleys = sorted(valleys)[-4:]
571
+ end_valleys = sorted(end_valleys, reverse=True)[:4]
572
+
573
+ # start valleys are earlier
574
+ min_loop_len = max(8, int(round(0.18 * n))) # prevents half-cycle accidental loops
575
+ start_valleys = [v for v in valleys if v <= (n - 1) - min_loop_len]
576
+
577
+ if not start_valleys or not end_valleys:
578
+ dbg = (
579
+ "No sufficient valley candidates. Returning trimmed batch.\n"
580
+ f"orig_B={b}, eff_B={n}, valleys={valleys}"
581
+ )
582
+ return (imgs, 0, n - 1, 0.0, dbg)
583
+
584
+ # 5) Precompute hybrid features ONCE (premultiply alpha ON for sprites)
585
+ scorer = _HybridFixedBatchScorer(imgs, premultiply_alpha=True)
586
+
587
+ # 6) Build candidate valley pairs and score in ONE batched call
588
+ pairs_i = []
589
+ pairs_j = []
590
+ feat_tie = []
591
+
592
+ foot_rng = (max(feet_s) - min(feet_s)) + 1e-6
593
+ hand_rng = (max(hands_s) - min(hands_s)) + 1e-6
594
+
595
+ for e in end_valleys:
596
+ for s in start_valleys:
597
+ if e - s < min_loop_len:
598
+ continue
599
+ # enforce valley-to-valley: both should be local minima (or at least in candidate list)
600
+ if not _is_local_min(feet_s, s):
601
+ continue
602
+ if not _is_local_min(feet_s, e):
603
+ continue
604
+
605
+ pairs_i.append(s)
606
+ pairs_j.append(e)
607
+ feat = abs(feet_s[s] - feet_s[e]) / foot_rng + abs(hands_s[s] - hands_s[e]) / hand_rng
608
+ feat_tie.append(float(feat))
609
+
610
+ if not pairs_i:
611
+ # fallback: allow candidate valleys even if not strict local minima
612
+ for e in end_valleys:
613
+ for s in start_valleys:
614
+ if e - s >= min_loop_len:
615
+ pairs_i.append(s)
616
+ pairs_j.append(e)
617
+ feat = abs(feet_s[s] - feet_s[e]) / foot_rng + abs(hands_s[s] - hands_s[e]) / hand_rng
618
+ feat_tie.append(float(feat))
619
+
620
+ scores = scorer.scores_for_pairs(pairs_i, pairs_j) # [M]
621
+
622
+ # Combine score with tiny tie-breaker (keeps correct pose if multiple are close)
623
+ tie_w = 10.0
624
+ total = scores + tie_w * torch.tensor(feat_tie, device=scores.device, dtype=scores.dtype)
625
+
626
+ # prefer late end valley if many are similarly good:
627
+ # we do: among scores <= GOOD, pick highest end; else pick min total
628
+ GOOD = 8.0
629
+ good_mask = (scores <= GOOD)
630
+ if torch.any(good_mask):
631
+ good_idx = torch.where(good_mask)[0].tolist()
632
+ # pick max end, then min total
633
+ max_end = max(pairs_j[k] for k in good_idx)
634
+ best_pool = [k for k in good_idx if pairs_j[k] == max_end]
635
+ best_k = min(best_pool, key=lambda k: float(total[k].item()))
636
+ else:
637
+ best_k = int(torch.argmin(total).item())
638
+
639
+ start = int(pairs_i[best_k])
640
+ end = int(pairs_j[best_k])
641
+ match_score = float(scores[best_k].item())
642
+
643
+ # 7) Snap/refine to nearby minima -> GUARANTEED valley-to-valley
644
+ start, end, match_score = self._snap_valley_to_valley(scorer, feet_s, start, end, min_loop_len)
645
+
646
+ dropped_end = end
647
+ end_out = end - 1
648
+
649
+ # Slice end-exclusive (so last returned frame is end-1)
650
+ loop = imgs[start:end] # [start .. end-1]
651
+
652
+ # Optional debug: closure score is now last_kept -> first
653
+ closure_score = float(scorer.score_one(start, end_out)) if end_out > start else float(match_score)
654
+
655
+ dbg = (
656
+ "AutoLoopSpriteBatch FAST (valley-to-valley)\n"
657
+ f"orig_B={b}, eff_B={n} (freeze_thr={freeze_thr})\n"
658
+ f"start={start}, end={end} (dropped), output_end={end_out}, len={end_out-start+1}, match_score(dup)={match_score:.4f}\n"
659
+ f"closure_score(last_kept->first)={closure_score:.4f}\n"
660
+ f"visible_y=[{y_min}..{y_max}] hands_y=[{hands_y0}..{hands_y1}] feet_y=[{feet_y0}..{feet_y1}]\n"
661
+ f"feet_range={feet_range:.2f}, min_sep={min_sep}, min_loop_len={min_loop_len}\n"
662
+ f"valleys={valleys}, end_valleys={end_valleys}"
663
+ )
664
+
665
+ return (loop, int(start), int(end_out), float(match_score), dbg)
666
+
667
+ # ============================================================
668
+ # Register
669
+ # ============================================================
670
+
671
+ NODE_CLASS_MAPPINGS = {
672
+ "ImageCompareHybrid": ImageCompareHybrid,
673
+ "Salia_Extract_Loop": Salia_Extract_Loop,
674
+ }
675
+
676
+ NODE_DISPLAY_NAME_MAPPINGS = {
677
+ "ImageCompareHybrid": "ImageCompareHybrid",
678
+ "Salia_Extract_Loop": "Salia_Extract_Loop",
679
+ }