saliacoel commited on
Commit
acae2b0
·
verified ·
1 Parent(s): cf1376b

Upload batch_loop_clean_rife_fill.py

Browse files
Files changed (1) hide show
  1. batch_loop_clean_rife_fill.py +568 -0
batch_loop_clean_rife_fill.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import importlib
6
+ import threading
7
+ from typing import List, Tuple, Optional
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+
13
+ # ============================================================
14
+ # Hardcoded HYBRID compare settings (exactly as requested)
15
+ # ============================================================
16
+ _DOWNSCALE_MAX = 256
17
+ _BLUR_SIGMA = 1.2
18
+ _HIST_BINS = 32
19
+ _SCALE = 1000.0
20
+
21
+ _W_PIXEL = 1.00
22
+ _W_SSIM = 1.00
23
+ _W_EDGE = 0.50
24
+ _W_HIST = 0.20
25
+
26
+
27
+ # ============================================================
28
+ # Lazy RIFE import (mirrors your wrapper behavior)
29
+ # ============================================================
30
+ _IMPORT_LOCK = threading.Lock()
31
+ _RIFE_CLASS = None
32
+
33
+ _HARDCODED_CKPT_NAME = "rife47.pth"
34
+ _HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES = 10
35
+ _HARDCODED_FAST_MODE = True
36
+ _HARDCODED_ENSEMBLE = True
37
+ _HARDCODED_SCALE_FACTOR = 1.0
38
+
39
+
40
+ def _lazy_get_rife_class():
41
+ """
42
+ Lazily import ComfyUI-Frame-Interpolation's RIFE_VFI class.
43
+ Expected folder:
44
+ ComfyUI/custom_nodes/ComfyUI-Frame-Interpolation
45
+ """
46
+ global _RIFE_CLASS
47
+ if _RIFE_CLASS is not None:
48
+ return _RIFE_CLASS
49
+
50
+ with _IMPORT_LOCK:
51
+ if _RIFE_CLASS is not None:
52
+ return _RIFE_CLASS
53
+
54
+ this_dir = os.path.dirname(os.path.abspath(__file__))
55
+ custom_nodes_dir = os.path.abspath(os.path.join(this_dir, ".."))
56
+ cfi_dir = os.path.join(custom_nodes_dir, "ComfyUI-Frame-Interpolation")
57
+
58
+ if not os.path.isdir(cfi_dir):
59
+ raise FileNotFoundError(
60
+ f"Could not find ComfyUI-Frame-Interpolation folder at:\n {cfi_dir}\n"
61
+ f"Expected it at:\n {os.path.join(custom_nodes_dir, 'ComfyUI-Frame-Interpolation')}"
62
+ )
63
+
64
+ if cfi_dir not in sys.path:
65
+ sys.path.insert(0, cfi_dir)
66
+
67
+ rife_mod = importlib.import_module("vfi_models.rife")
68
+ rife_cls = getattr(rife_mod, "RIFE_VFI", None)
69
+ if rife_cls is None:
70
+ raise ImportError("vfi_models.rife imported, but RIFE_VFI class was not found.")
71
+
72
+ _RIFE_CLASS = rife_cls
73
+ return _RIFE_CLASS
74
+
75
+
76
+ def _run_rife(frames_bhwc: torch.Tensor, multiplier: int) -> torch.Tensor:
77
+ """
78
+ frames_bhwc: [2,H,W,C]
79
+ returns: [multiplier+1, H, W, C] (usually includes originals at ends)
80
+ """
81
+ RIFE_VFI = _lazy_get_rife_class()
82
+ rife_node = RIFE_VFI()
83
+
84
+ out = rife_node.vfi(
85
+ ckpt_name=_HARDCODED_CKPT_NAME,
86
+ frames=frames_bhwc,
87
+ clear_cache_after_n_frames=_HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES,
88
+ multiplier=int(multiplier),
89
+ fast_mode=_HARDCODED_FAST_MODE,
90
+ ensemble=_HARDCODED_ENSEMBLE,
91
+ scale_factor=_HARDCODED_SCALE_FACTOR,
92
+ optional_interpolation_states=None,
93
+ )
94
+
95
+ # Some versions may return (IMAGE,) or (IMAGE, states). We only want the IMAGE.
96
+ if isinstance(out, (tuple, list)):
97
+ return out[0]
98
+ return out
99
+
100
+
101
+ # ============================================================
102
+ # Image helpers
103
+ # ============================================================
104
+
105
+ def _bhwc_to_nchw(img: torch.Tensor) -> torch.Tensor:
106
+ if img.dim() != 4:
107
+ raise ValueError(f"Expected IMAGE tensor [B,H,W,C], got {tuple(img.shape)}")
108
+ return img.permute(0, 3, 1, 2).contiguous()
109
+
110
+ def _nchw_to_bhwc(img: torch.Tensor) -> torch.Tensor:
111
+ if img.dim() != 4:
112
+ raise ValueError(f"Expected NCHW tensor [B,C,H,W], got {tuple(img.shape)}")
113
+ return img.permute(0, 2, 3, 1).contiguous()
114
+
115
+ def _drop_alpha_if_any(x: torch.Tensor) -> torch.Tensor:
116
+ if x.shape[1] > 3:
117
+ return x[:, :3, :, :].contiguous()
118
+ return x
119
+
120
+ def _ensure_3ch(x: torch.Tensor) -> torch.Tensor:
121
+ if x.shape[1] == 1:
122
+ return x.repeat(1, 3, 1, 1)
123
+ return x
124
+
125
+ def _to_luma(x: torch.Tensor) -> torch.Tensor:
126
+ if x.shape[1] == 1:
127
+ return x
128
+ r = x[:, 0:1, :, :]
129
+ g = x[:, 1:2, :, :]
130
+ b = x[:, 2:3, :, :]
131
+ return (0.2989 * r + 0.5870 * g + 0.1140 * b)
132
+
133
+ def _resize_max(x: torch.Tensor, max_size: int) -> torch.Tensor:
134
+ if max_size <= 0:
135
+ return x
136
+ b, c, h, w = x.shape
137
+ m = max(h, w)
138
+ if m <= max_size:
139
+ return x
140
+ scale = max_size / float(m)
141
+ nh = max(1, int(round(h * scale)))
142
+ nw = max(1, int(round(w * scale)))
143
+ return F.interpolate(x, size=(nh, nw), mode="bilinear", align_corners=False)
144
+
145
+ def _gaussian_blur(x: torch.Tensor, sigma: float) -> torch.Tensor:
146
+ if sigma <= 0:
147
+ return x
148
+
149
+ radius = int(max(1, round(3.0 * sigma)))
150
+ ksize = 2 * radius + 1
151
+ device = x.device
152
+ dtype = x.dtype
153
+
154
+ coords = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
155
+ kernel1d = torch.exp(-(coords * coords) / (2.0 * sigma * sigma))
156
+ kernel1d = kernel1d / (kernel1d.sum() + 1e-12)
157
+
158
+ c = x.shape[1]
159
+ kh = kernel1d.view(1, 1, 1, ksize).repeat(c, 1, 1, 1)
160
+ kv = kernel1d.view(1, 1, ksize, 1).repeat(c, 1, 1, 1)
161
+
162
+ out = F.conv2d(x, kh, padding=(0, radius), groups=c)
163
+ out = F.conv2d(out, kv, padding=(radius, 0), groups=c)
164
+ return out
165
+
166
+ def _sobel_edges(y: torch.Tensor) -> torch.Tensor:
167
+ device = y.device
168
+ dtype = y.dtype
169
+ c = y.shape[1]
170
+
171
+ kx = torch.tensor(
172
+ [[-1, 0, 1],
173
+ [-2, 0, 2],
174
+ [-1, 0, 1]],
175
+ device=device, dtype=dtype
176
+ ) / 8.0
177
+
178
+ ky = torch.tensor(
179
+ [[-1, -2, -1],
180
+ [ 0, 0, 0],
181
+ [ 1, 2, 1]],
182
+ device=device, dtype=dtype
183
+ ) / 8.0
184
+
185
+ kx = kx.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
186
+ ky = ky.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
187
+
188
+ gx = F.conv2d(y, kx, padding=1, groups=c)
189
+ gy = F.conv2d(y, ky, padding=1, groups=c)
190
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
191
+
192
+
193
+ # ============================================================
194
+ # SSIM (vectorized for batch of pairs)
195
+ # ============================================================
196
+
197
+ def _make_ssim_kernel(device, dtype, window_size: int = 11, sigma: float = 1.5):
198
+ radius = window_size // 2
199
+ coords = torch.arange(window_size, device=device, dtype=dtype) - radius
200
+ g = torch.exp(-(coords * coords) / (2.0 * sigma * sigma))
201
+ g = g / (g.sum() + 1e-12)
202
+ w2d = (g[:, None] * g[None, :]).view(1, 1, window_size, window_size)
203
+ return w2d, radius
204
+
205
+ def _ssim_batch_luma(x: torch.Tensor, y: torch.Tensor, w2d: torch.Tensor, radius: int) -> torch.Tensor:
206
+ """
207
+ x,y: [N,1,H,W]
208
+ returns: [N] ssim values
209
+ """
210
+ C1 = (0.01) ** 2
211
+ C2 = (0.03) ** 2
212
+
213
+ mu_x = F.conv2d(x, w2d, padding=radius, groups=1)
214
+ mu_y = F.conv2d(y, w2d, padding=radius, groups=1)
215
+
216
+ mu_x2 = mu_x * mu_x
217
+ mu_y2 = mu_y * mu_y
218
+ mu_xy = mu_x * mu_y
219
+
220
+ sigma_x2 = F.conv2d(x * x, w2d, padding=radius, groups=1) - mu_x2
221
+ sigma_y2 = F.conv2d(y * y, w2d, padding=radius, groups=1) - mu_y2
222
+ sigma_xy = F.conv2d(x * y, w2d, padding=radius, groups=1) - mu_xy
223
+
224
+ num = (2.0 * mu_xy + C1) * (2.0 * sigma_xy + C2)
225
+ den = (mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2)
226
+ ssim_map = num / (den + 1e-12)
227
+ return ssim_map.mean(dim=[1, 2, 3])
228
+
229
+
230
+ # ============================================================
231
+ # Histogram (per-frame) + chi2 between frames
232
+ # ============================================================
233
+
234
+ def _compute_histograms(rgb_resized: torch.Tensor, bins: int) -> torch.Tensor:
235
+ """
236
+ rgb_resized: [B,3,H,W] in [0,1]
237
+ returns hist: [B,3,bins] normalized
238
+ Uses torch.histc. If device histc fails, falls back to CPU.
239
+ """
240
+ eps = 1e-12
241
+ B = rgb_resized.shape[0]
242
+ device = rgb_resized.device
243
+
244
+ try:
245
+ h = torch.zeros((B, 3, bins), device=device, dtype=torch.float32)
246
+ for i in range(B):
247
+ for c in range(3):
248
+ hc = torch.histc(rgb_resized[i, c], bins=bins, min=0.0, max=1.0)
249
+ hc = hc / (hc.sum() + eps)
250
+ h[i, c] = hc
251
+ return h
252
+ except Exception:
253
+ rgb_cpu = rgb_resized.detach().float().cpu()
254
+ h_cpu = torch.zeros((B, 3, bins), device="cpu", dtype=torch.float32)
255
+ for i in range(B):
256
+ for c in range(3):
257
+ hc = torch.histc(rgb_cpu[i, c], bins=bins, min=0.0, max=1.0)
258
+ hc = hc / (hc.sum() + eps)
259
+ h_cpu[i, c] = hc
260
+ return h_cpu.to(device)
261
+
262
+
263
+ def _chi2_from_hist(h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
264
+ """
265
+ h1,h2: [...,3,bins]
266
+ returns: [...] chi2 distance averaged across channels
267
+ """
268
+ eps = 1e-12
269
+ diff2 = (h1 - h2) ** 2
270
+ denom = (h1 + h2 + eps)
271
+ chi = 0.5 * torch.sum(diff2 / denom, dim=-1) # sum over bins -> [...,3]
272
+ return torch.mean(chi, dim=-1) # avg over channels -> [...]
273
+
274
+
275
+ # ============================================================
276
+ # Preprocess + HYBRID scores
277
+ # ============================================================
278
+
279
+ class _Pre:
280
+ def __init__(self, rgb_resized, rgb_blur, luma_blur, edges, hist, w2d, radius):
281
+ self.rgb_resized = rgb_resized # [B,3,h,w]
282
+ self.rgb_blur = rgb_blur # [B,3,h,w]
283
+ self.luma_blur = luma_blur # [B,1,h,w]
284
+ self.edges = edges # [B,1,h,w]
285
+ self.hist = hist # [B,3,bins]
286
+ self.w2d = w2d
287
+ self.radius = radius
288
+
289
+ def _preprocess(images_bhwc: torch.Tensor) -> _Pre:
290
+ x = _bhwc_to_nchw(images_bhwc)
291
+ x = _drop_alpha_if_any(x).clamp(0.0, 1.0)
292
+ x = _ensure_3ch(x)
293
+
294
+ rgb_resized = _resize_max(x, _DOWNSCALE_MAX)
295
+ rgb_blur = _gaussian_blur(rgb_resized, _BLUR_SIGMA)
296
+ luma_blur = _to_luma(rgb_blur)
297
+ edges = _sobel_edges(luma_blur)
298
+
299
+ hist = _compute_histograms(rgb_resized, _HIST_BINS)
300
+ w2d, radius = _make_ssim_kernel(device=luma_blur.device, dtype=luma_blur.dtype)
301
+ return _Pre(rgb_resized, rgb_blur, luma_blur, edges, hist, w2d, radius)
302
+
303
+ def _hybrid_scores_adj(pre: _Pre) -> torch.Tensor:
304
+ """
305
+ returns scores for adjacent pairs: [B-1] (scaled by _SCALE)
306
+ """
307
+ B = pre.rgb_blur.shape[0]
308
+ if B <= 1:
309
+ return torch.zeros((0,), device=pre.rgb_blur.device, dtype=torch.float32)
310
+
311
+ # Pixel MAE on blurred RGB
312
+ pix = torch.mean(torch.abs(pre.rgb_blur[:-1] - pre.rgb_blur[1:]), dim=[1, 2, 3]) # [B-1]
313
+
314
+ # SSIM diff on blurred luma
315
+ ssim_vals = _ssim_batch_luma(pre.luma_blur[:-1], pre.luma_blur[1:], pre.w2d, pre.radius) # [B-1]
316
+ ssim_diff = torch.clamp(1.0 - ssim_vals, min=0.0)
317
+
318
+ # Edge MAE
319
+ ed = torch.mean(torch.abs(pre.edges[:-1] - pre.edges[1:]), dim=[1, 2, 3])
320
+
321
+ # Hist chi2
322
+ hist = _chi2_from_hist(pre.hist[:-1], pre.hist[1:]) # [B-1]
323
+
324
+ score = (_W_PIXEL * pix) + (_W_SSIM * ssim_diff) + (_W_EDGE * ed) + (_W_HIST * hist)
325
+ return score * _SCALE
326
+
327
+ def _hybrid_score_pair(pre: _Pre, i: int, j: int) -> float:
328
+ pix = torch.mean(torch.abs(pre.rgb_blur[i] - pre.rgb_blur[j]))
329
+ ssim_val = _ssim_batch_luma(pre.luma_blur[i:i+1], pre.luma_blur[j:j+1], pre.w2d, pre.radius)[0]
330
+ ssim_diff = torch.clamp(1.0 - ssim_val, min=0.0)
331
+ ed = torch.mean(torch.abs(pre.edges[i] - pre.edges[j]))
332
+ hist = _chi2_from_hist(pre.hist[i:i+1], pre.hist[j:j+1])[0]
333
+ score = (_W_PIXEL * pix) + (_W_SSIM * ssim_diff) + (_W_EDGE * ed) + (_W_HIST * hist)
334
+ return float(score.item() * _SCALE)
335
+
336
+ def _hybrid_scores_to_anchor(pre: _Pre, anchor_idx: int, cand_indices: List[int]) -> torch.Tensor:
337
+ """
338
+ returns [N] scores (scaled) between anchor and each candidate
339
+ """
340
+ device = pre.rgb_blur.device
341
+ if len(cand_indices) == 0:
342
+ return torch.zeros((0,), device=device, dtype=torch.float32)
343
+
344
+ idx = torch.tensor(cand_indices, device=device, dtype=torch.long)
345
+
346
+ # gather candidates
347
+ rgb_c = pre.rgb_blur.index_select(0, idx) # [N,3,h,w]
348
+ luma_c = pre.luma_blur.index_select(0, idx) # [N,1,h,w]
349
+ edge_c = pre.edges.index_select(0, idx) # [N,1,h,w]
350
+ hist_c = pre.hist.index_select(0, idx) # [N,3,bins]
351
+
352
+ rgb_a = pre.rgb_blur[anchor_idx].unsqueeze(0).expand_as(rgb_c)
353
+ luma_a = pre.luma_blur[anchor_idx].unsqueeze(0).expand_as(luma_c)
354
+ edge_a = pre.edges[anchor_idx].unsqueeze(0).expand_as(edge_c)
355
+ hist_a = pre.hist[anchor_idx].unsqueeze(0).expand_as(hist_c)
356
+
357
+ pix = torch.mean(torch.abs(rgb_c - rgb_a), dim=[1, 2, 3]) # [N]
358
+ ssim_vals = _ssim_batch_luma(luma_a, luma_c, pre.w2d, pre.radius)
359
+ ssim_diff = torch.clamp(1.0 - ssim_vals, min=0.0)
360
+ ed = torch.mean(torch.abs(edge_c - edge_a), dim=[1, 2, 3])
361
+ hist = _chi2_from_hist(hist_c, hist_a)
362
+
363
+ score = (_W_PIXEL * pix) + (_W_SSIM * ssim_diff) + (_W_EDGE * ed) + (_W_HIST * hist)
364
+ return score * _SCALE
365
+
366
+
367
+ # ============================================================
368
+ # The Node
369
+ # ============================================================
370
+
371
+ class LoopCleanRifeFill51:
372
+ """
373
+ 1) Remove frozen tail
374
+ 2) Remove frozen frames across whole batch (dedup pass)
375
+ 3) Crop to looping segment [anchor .. best_end]
376
+ 4) Repeatedly insert RIFE interpolated frames into highest-diff adjacent gap
377
+ 5) Stop at target_frames
378
+ """
379
+
380
+ @classmethod
381
+ def INPUT_TYPES(cls):
382
+ return {
383
+ "required": {
384
+ "images": ("IMAGE",),
385
+
386
+ # you explicitly wanted this configurable
387
+ "loop_anchor": ("INT", {"default": 9, "min": 0, "max": 4096, "step": 1}),
388
+
389
+ # tail search window for loop end matching
390
+ "loop_tail_search": ("INT", {"default": 15, "min": 1, "max": 512, "step": 1}),
391
+
392
+ # keep as input (default 3.0), since you might tune this per dataset
393
+ "freeze_threshold": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
394
+
395
+ # final length
396
+ "target_frames": ("INT", {"default": 51, "min": 1, "max": 1000, "step": 1}),
397
+
398
+ # multiplier behavior (you mainly asked 2 or 3)
399
+ "max_multiplier": ("INT", {"default": 3, "min": 2, "max": 8, "step": 1}),
400
+ "big_gap_threshold": ("FLOAT", {"default": 20.0, "min": 0.0, "max": 10000.0, "step": 0.5}),
401
+ }
402
+ }
403
+
404
+ RETURN_TYPES = ("IMAGE",)
405
+ RETURN_NAMES = ("images",)
406
+ FUNCTION = "process"
407
+ CATEGORY = "image/analysis"
408
+
409
+ def process(
410
+ self,
411
+ images: torch.Tensor,
412
+ loop_anchor: int,
413
+ loop_tail_search: int,
414
+ freeze_threshold: float,
415
+ target_frames: int,
416
+ max_multiplier: int,
417
+ big_gap_threshold: float,
418
+ ):
419
+ # ---------------------------
420
+ # Basic sanity
421
+ # ---------------------------
422
+ if images.dim() != 4:
423
+ raise ValueError(f"Expected IMAGE [B,H,W,C], got {tuple(images.shape)}")
424
+ B = images.shape[0]
425
+ if B <= 1:
426
+ # If it's a single frame, just repeat to target (rare / your "should never happen")
427
+ if target_frames > 1:
428
+ images = images.repeat(target_frames, 1, 1, 1)
429
+ return (images[:target_frames],)
430
+
431
+ # =====================================================
432
+ # 1) Remove frozen tail
433
+ # =====================================================
434
+ pre = _preprocess(images)
435
+ scores_adj = _hybrid_scores_adj(pre) # [B-1]
436
+ keep_last = images.shape[0] - 1
437
+ while keep_last > 0 and float(scores_adj[keep_last - 1].item()) < freeze_threshold:
438
+ keep_last -= 1
439
+ images = images[: keep_last + 1]
440
+
441
+ if images.shape[0] <= 1:
442
+ if target_frames > 1:
443
+ images = images.repeat(target_frames, 1, 1, 1)
444
+ return (images[:target_frames],)
445
+
446
+ # =====================================================
447
+ # 2) Remove frozen frames across entire batch (dedup)
448
+ # Remove only ONE of a frozen pair => drop the later one.
449
+ # =====================================================
450
+ pre = _preprocess(images)
451
+ keep: List[int] = [0]
452
+ last_kept = 0
453
+ for i in range(1, images.shape[0]):
454
+ sc = _hybrid_score_pair(pre, last_kept, i)
455
+ if sc >= freeze_threshold:
456
+ keep.append(i)
457
+ last_kept = i
458
+
459
+ keep_t = torch.tensor(keep, device=images.device, dtype=torch.long)
460
+ images = images.index_select(0, keep_t)
461
+
462
+ if images.shape[0] <= 1:
463
+ if target_frames > 1:
464
+ images = images.repeat(target_frames, 1, 1, 1)
465
+ return (images[:target_frames],)
466
+
467
+ # =====================================================
468
+ # 3) Crop to looping segment using anchor + closest end
469
+ # =====================================================
470
+ L = images.shape[0]
471
+ anchor = int(max(0, min(loop_anchor, L - 1)))
472
+
473
+ # Candidates are from the last N frames, but must be > anchor.
474
+ tail_start = max(anchor + 1, L - int(loop_tail_search))
475
+ if tail_start <= L - 1:
476
+ cand = list(range(L - 1, tail_start - 1, -1)) # reverse from end
477
+ pre = _preprocess(images)
478
+ scores = _hybrid_scores_to_anchor(pre, anchor, cand) # [N]
479
+ best_k = int(torch.argmin(scores).item())
480
+ end_idx = cand[best_k]
481
+ if end_idx <= anchor:
482
+ # fallback: keep from anchor to end
483
+ images = images[anchor:]
484
+ else:
485
+ images = images[anchor : end_idx + 1]
486
+ else:
487
+ # No candidates after anchor; fallback: keep from anchor to end
488
+ images = images[anchor:]
489
+
490
+ if images.shape[0] <= 1:
491
+ if target_frames > 1:
492
+ images = images.repeat(target_frames, 1, 1, 1)
493
+ return (images[:target_frames],)
494
+
495
+ # =====================================================
496
+ # 4+5) Insert RIFE interpolations into highest gap until target_frames
497
+ # =====================================================
498
+ # Clamp max_multiplier (at least 2)
499
+ max_multiplier = int(max(2, max_multiplier))
500
+
501
+ safety = 0
502
+ while images.shape[0] < target_frames:
503
+ safety += 1
504
+ if safety > 500:
505
+ # Prevent infinite loops in pathological cases
506
+ break
507
+
508
+ n = images.shape[0]
509
+ if n < 2:
510
+ break
511
+
512
+ pre = _preprocess(images)
513
+ scores_adj = _hybrid_scores_adj(pre) # [n-1]
514
+ if scores_adj.numel() == 0:
515
+ break
516
+
517
+ # Highest-diff adjacent pair
518
+ idx = int(torch.argmax(scores_adj).item())
519
+ max_score = float(scores_adj[idx].item())
520
+
521
+ remaining = target_frames - n
522
+
523
+ # Choose multiplier (mostly 2, sometimes 3+ if gap is large and we have room)
524
+ m = 2
525
+ if remaining >= 2 and max_multiplier >= 3 and max_score >= big_gap_threshold:
526
+ m = 3
527
+
528
+ # If we still have lots of room, allow higher multipliers up to max_multiplier
529
+ # (optional, but useful if the batch got really short)
530
+ # Inserts (m-1) frames.
531
+ if remaining >= 3 and max_multiplier > 3 and max_score >= big_gap_threshold:
532
+ # try to use as much as we can without overshooting
533
+ m = min(max_multiplier, remaining + 1)
534
+
535
+ # Never overshoot target
536
+ if (m - 1) > remaining:
537
+ m = remaining + 1
538
+ m = int(max(2, m))
539
+
540
+ # Run RIFE on the pair (batch of 2)
541
+ pair = images[idx : idx + 2] # [2,H,W,C]
542
+ rife_out = _run_rife(pair, multiplier=m) # [m+1,H,W,C] typically
543
+
544
+ # Take only the inserted frames (exclude first and last originals)
545
+ inserted = rife_out[1:-1] # [m-1,H,W,C]
546
+ if inserted.shape[0] == 0:
547
+ # fallback: if something weird happens, just stop
548
+ break
549
+
550
+ # If we would overshoot due to some mismatch, clamp inserted
551
+ if inserted.shape[0] > remaining:
552
+ inserted = inserted[:remaining]
553
+
554
+ # Insert into the batch between idx and idx+1
555
+ images = torch.cat([images[:idx+1], inserted, images[idx+1:]], dim=0)
556
+
557
+ # If we somehow overshot (shouldn't), clamp
558
+ images = images[:target_frames]
559
+ return (images,)
560
+
561
+
562
+ NODE_CLASS_MAPPINGS = {
563
+ "LoopCleanRifeFill51": LoopCleanRifeFill51,
564
+ }
565
+
566
+ NODE_DISPLAY_NAME_MAPPINGS = {
567
+ "LoopCleanRifeFill51": "Loop Clean + RIFE Fill to 51 (Hybrid hardcoded)",
568
+ }