saliacoel commited on
Commit
71aa491
·
verified ·
1 Parent(s): f8944c1

Upload Salia_RifeVFI_Insert.py

Browse files
Files changed (1) hide show
  1. Salia_RifeVFI_Insert.py +190 -0
Salia_RifeVFI_Insert.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # custom_nodes\comfyui-salia_online\nodes\rife_insert_between.py
2
+ #
3
+ # Node: Insert RIFE-generated in-between frames between two indices of a batch.
4
+ #
5
+ # Inputs:
6
+ # - batch (IMAGE): input batch of frames
7
+ # - start (INT): start index in batch
8
+ # - end (INT): end index in batch
9
+ # - multiplier (INT): number of in-between frames to INSERT (1 => insert 1 frame)
10
+ #
11
+ # Internals:
12
+ # - Extract batch[start] and batch[end]
13
+ # - Make a 2-frame batch and run ComfyUI-Frame-Interpolation's RIFE_VFI lazily
14
+ # - Call RIFE with (multiplier + 1) because upstream multiplier is a factor
15
+ # (2 frames, factor=2 => 1 middle frame; factor=3 => 2 middle frames; etc.)
16
+ # - Remove FIRST and LAST from RIFE output (keep only the in-betweens)
17
+ # - Insert in-betweens between start and end in the original batch
18
+ #
19
+ # Output:
20
+ # - IMAGE: new batch with inserted frames
21
+ #
22
+ # Notes:
23
+ # - If end > start+1, frames between (start+1 .. end-1) are REPLACED.
24
+ # (This matches "place inside between start and end" as immediate neighbors.)
25
+ #
26
+
27
+ from __future__ import annotations
28
+
29
+ import os
30
+ import sys
31
+ import importlib
32
+ import threading
33
+ from typing import Tuple
34
+
35
+ import torch
36
+
37
+ _IMPORT_LOCK = threading.Lock()
38
+ _RIFE_CLASS = None # cached class object (import-only cache)
39
+
40
+ # -----------------------------
41
+ # Hardcoded settings (match your lazy node)
42
+ # -----------------------------
43
+ _HARDCODED_CKPT_NAME = "rife47.pth"
44
+ _HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES = 10
45
+ _HARDCODED_FAST_MODE = True
46
+ _HARDCODED_ENSEMBLE = True
47
+ _HARDCODED_SCALE_FACTOR = 1.0
48
+
49
+
50
+ def _lazy_get_rife_class():
51
+ """
52
+ Lazily import ComfyUI-Frame-Interpolation's RIFE_VFI class without importing
53
+ the whole package at ComfyUI startup.
54
+ """
55
+ global _RIFE_CLASS
56
+ if _RIFE_CLASS is not None:
57
+ return _RIFE_CLASS
58
+
59
+ with _IMPORT_LOCK:
60
+ if _RIFE_CLASS is not None:
61
+ return _RIFE_CLASS
62
+
63
+ # This file lives at:
64
+ # ...\custom_nodes\comfyui-salia_online\nodes\rife_insert_between.py
65
+ # We want:
66
+ # ...\custom_nodes\ComfyUI-Frame-Interpolation
67
+ this_dir = os.path.dirname(os.path.abspath(__file__))
68
+ custom_nodes_dir = os.path.abspath(os.path.join(this_dir, "..", ".."))
69
+ cfi_dir = os.path.join(custom_nodes_dir, "ComfyUI-Frame-Interpolation")
70
+
71
+ if not os.path.isdir(cfi_dir):
72
+ raise FileNotFoundError(
73
+ f"Could not find ComfyUI-Frame-Interpolation folder at:\n {cfi_dir}\n"
74
+ f"Expected it at:\n {os.path.join(custom_nodes_dir, 'ComfyUI-Frame-Interpolation')}"
75
+ )
76
+
77
+ # Add the extension folder so:
78
+ # import vfi_models.rife
79
+ # and:
80
+ # import vfi_utils
81
+ # resolve correctly.
82
+ if cfi_dir not in sys.path:
83
+ sys.path.insert(0, cfi_dir)
84
+
85
+ rife_mod = importlib.import_module("vfi_models.rife")
86
+ rife_cls = getattr(rife_mod, "RIFE_VFI", None)
87
+ if rife_cls is None:
88
+ raise ImportError("vfi_models.rife imported, but RIFE_VFI class was not found.")
89
+
90
+ _RIFE_CLASS = rife_cls
91
+ return _RIFE_CLASS
92
+
93
+
94
+ class SALIA_RIFE_INSERT_BETWEEN:
95
+ @classmethod
96
+ def INPUT_TYPES(cls):
97
+ return {
98
+ "required": {
99
+ "batch": ("IMAGE",),
100
+ "start": ("INT", {"default": 0, "min": 0, "step": 1}),
101
+ "end": ("INT", {"default": 1, "min": 0, "step": 1}),
102
+ # user multiplier = number of inserted frames
103
+ "multiplier": ("INT", {"default": 1, "min": 1, "step": 1}),
104
+ }
105
+ }
106
+
107
+ RETURN_TYPES = ("IMAGE",)
108
+ RETURN_NAMES = ("IMAGE",)
109
+ FUNCTION = "insert"
110
+ CATEGORY = "salia_online/VFI"
111
+
112
+ def insert(self, batch: torch.Tensor, start: int, end: int, multiplier: int) -> Tuple[torch.Tensor]:
113
+ if batch is None or not hasattr(batch, "shape"):
114
+ raise ValueError("Input 'batch' must be an IMAGE tensor.")
115
+
116
+ if batch.shape[0] < 2:
117
+ raise ValueError(f"Input batch must have at least 2 frames, got {batch.shape[0]}.")
118
+
119
+ start = int(start)
120
+ end = int(end)
121
+ multiplier = int(multiplier)
122
+
123
+ n = int(batch.shape[0])
124
+ if not (0 <= start < n) or not (0 <= end < n):
125
+ raise ValueError(f"start/end out of range. batch has {n} frames, got start={start}, end={end}.")
126
+
127
+ if start == end:
128
+ raise ValueError("start and end must be different indices.")
129
+
130
+ if start > end:
131
+ raise ValueError(f"start must be < end. Got start={start}, end={end}.")
132
+
133
+ # Extract the two boundary frames
134
+ frame_start = batch[start:start + 1]
135
+ frame_end = batch[end:end + 1]
136
+
137
+ # Make a 2-frame batch for RIFE
138
+ frames = torch.cat([frame_start, frame_end], dim=0)
139
+
140
+ # Upstream RIFE multiplier is a *factor*:
141
+ # - 2 frames, factor=2 => output 3 frames => 1 in-between
142
+ # We want user multiplier = number of in-betweens,
143
+ # so factor = user_multiplier + 1
144
+ rife_multiplier = multiplier + 1
145
+
146
+ # Run RIFE lazily
147
+ RIFE_VFI = _lazy_get_rife_class()
148
+ rife_node = RIFE_VFI()
149
+
150
+ (rife_out,) = rife_node.vfi(
151
+ ckpt_name=_HARDCODED_CKPT_NAME,
152
+ frames=frames,
153
+ clear_cache_after_n_frames=_HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES,
154
+ multiplier=int(rife_multiplier),
155
+ fast_mode=_HARDCODED_FAST_MODE,
156
+ ensemble=_HARDCODED_ENSEMBLE,
157
+ scale_factor=_HARDCODED_SCALE_FACTOR,
158
+ optional_interpolation_states=None,
159
+ )
160
+
161
+ # Keep only the in-between frames: drop first and last
162
+ # (If something unexpected happens, this safely yields empty middle.)
163
+ middle = rife_out[1:-1] if rife_out.shape[0] >= 2 else rife_out[0:0]
164
+
165
+ # Optional sanity: ensure we got the expected number of inserted frames
166
+ # If it doesn't match, we still proceed with whatever RIFE returned.
167
+ # expected = multiplier
168
+ # if middle.shape[0] != expected:
169
+ # print(f"SALIA_RIFE_INSERT_BETWEEN: expected {expected} middle frames, got {middle.shape[0]}")
170
+
171
+ # Insert: keep frames up to start, then middle, then from end onward.
172
+ # This effectively REPLACES any existing frames between start and end.
173
+ before = batch[: start + 1]
174
+ after = batch[end:] # includes the end frame
175
+
176
+ # Match device if needed (usually everything is CPU)
177
+ if middle.device != before.device:
178
+ middle = middle.to(before.device)
179
+
180
+ out = torch.cat([before, middle, after], dim=0)
181
+ return (out,)
182
+
183
+
184
+ NODE_CLASS_MAPPINGS = {
185
+ "SALIA_RIFE_INSERT_BETWEEN": SALIA_RIFE_INSERT_BETWEEN,
186
+ }
187
+
188
+ NODE_DISPLAY_NAME_MAPPINGS = {
189
+ "SALIA_RIFE_INSERT_BETWEEN": "RIFE Insert Between (Lazy, hardcoded rife47)",
190
+ }