saliacoel committed on
Commit
3fb7a98
·
verified ·
1 Parent(s): a336c71

Upload salia_special_rife_batch.py

Browse files
Files changed (1) hide show
  1. salia_special_rife_batch.py +274 -0
salia_special_rife_batch.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sys
5
+ import importlib
6
+ import threading
7
+ from typing import Dict, Optional, Tuple
8
+
9
+ import torch
10
+
11
# -----------------------------------------------------------------------------
# Lazy import wrapper around rife_lazy.py (which itself lazy-loads Frame-Interpolation)
# -----------------------------------------------------------------------------
# Serializes the one-time import/resolution below across concurrent node runs.
_RIFE_IMPORT_LOCK = threading.Lock()
# Cached module object for rife_lazy.py; populated on first use.
_RIFE_LAZY_MOD = None
# Cached RIFE_VFI class obtained via rife_lazy._lazy_get_rife_class(); populated on first use.
_RIFE_VFI_CLASS = None
17
+
18
+
19
def _lazy_import_rife_lazy_module():
    """Import and cache rife_lazy.py on first use.

    rife_lazy.py is expected to sit in the same folder as this file, so that
    folder is pushed onto sys.path before the plain module import. Uses
    double-checked locking so concurrent executions import only once.
    """
    global _RIFE_LAZY_MOD
    cached = _RIFE_LAZY_MOD
    if cached is not None:
        return cached

    with _RIFE_IMPORT_LOCK:
        # Re-check under the lock: another thread may have won the race.
        if _RIFE_LAZY_MOD is None:
            here = os.path.dirname(os.path.abspath(__file__))
            if here not in sys.path:
                sys.path.insert(0, here)
            _RIFE_LAZY_MOD = importlib.import_module("rife_lazy")
        return _RIFE_LAZY_MOD
38
+
39
+
40
def _lazy_get_rife_vfi_class():
    """Resolve and cache the real RIFE_VFI class.

    Delegates to rife_lazy._lazy_get_rife_class(); the result is memoized so
    the underlying lazy loader runs at most once.
    """
    global _RIFE_VFI_CLASS
    resolved = _RIFE_VFI_CLASS
    if resolved is not None:
        return resolved

    # Cache miss: make sure the wrapper module is loaded before locking.
    mod = _lazy_import_rife_lazy_module()

    with _RIFE_IMPORT_LOCK:
        # Another thread may have resolved the class while we imported.
        if _RIFE_VFI_CLASS is None:
            _RIFE_VFI_CLASS = mod._lazy_get_rife_class()
        return _RIFE_VFI_CLASS
55
+
56
+
57
+ def _unwrap_image_output(result):
58
+ """
59
+ Many ComfyUI nodes return tuples. We only want the IMAGE output.
60
+ """
61
+ if isinstance(result, (tuple, list)):
62
+ return result[0]
63
+ return result
64
+
65
+
66
+ def _normalize_to_batch(images: torch.Tensor) -> torch.Tensor:
67
+ """
68
+ Accept [H,W,C] and convert to [1,H,W,C].
69
+ """
70
+ if isinstance(images, torch.Tensor) and images.dim() == 3:
71
+ return images.unsqueeze(0)
72
+ return images
73
+
74
+
75
+ def _torch_inference_context():
76
+ """
77
+ Use torch.inference_mode if available, otherwise fall back to torch.no_grad.
78
+ """
79
+ return torch.inference_mode() if hasattr(torch, "inference_mode") else torch.no_grad()
80
+
81
+
82
+ # -----------------------------------------------------------------------------
83
+ # Node
84
+ # -----------------------------------------------------------------------------
85
class SALIA_SPECIAL_BATCH_RIFE47:
    """
    Input: IMAGE batch

    Pipeline:
      1) Custom_Batch_Output split:
         - Batch_UP = [7] + [9..25] + [27..31] + [33..36]
         - Rife_x3  = [37, 4]
      2) Rife_x3 -> RIFE(rife47, mult=3) => RifeOutput
      3) almost_final_batch = concat(Batch_UP, RifeOutput)

      4) Extra inserts (indices refer to almost_final_batch *before* inserts):
         - DUO_14_15:    interpolate (14,15) with mult=3 => keep middle 2 frames => insert after 14
         - SINGLE_25_26: interpolate (25,26) with mult=2 => keep middle 1 frame  => insert after 25
         - SINGLE_30_1:  interpolate (30, 1) with mult=2 => keep middle 1 frame  => insert after 30

    Output: IMAGE batch
    """

    CATEGORY = "salia_online/VFI"
    FUNCTION = "process"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGE",)

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"images": ("IMAGE",)}}

    # ---- Custom_Batch_Output indices (0-based) ----
    _BATCH_UP_INDICES = [7, *range(9, 26), *range(27, 32), *range(33, 37)]
    _RIFE_X3_INDICES = [37, 4]

    # ---- Insert specs (based on almost_final_batch BEFORE inserts) ----
    _DUO_PAIR: Tuple[int, int] = (14, 15)
    _DUO_MULT: int = 3

    _SINGLE_25_26_PAIR: Tuple[int, int] = (25, 26)
    _SINGLE_25_26_MULT: int = 2

    _SINGLE_30_1_PAIR: Tuple[int, int] = (30, 1)
    _SINGLE_30_1_MULT: int = 2

    def _call_rife(self, rife_node, rife_lazy_mod, frames: torch.Tensor, multiplier: int) -> torch.Tensor:
        """Run the underlying RIFE_VFI node with rife_lazy.py's hardcoded params.

        Falls back to sensible defaults when a hardcoded attribute is absent.
        Raises TypeError/ValueError when the node's IMAGE output is malformed.
        """
        raw = rife_node.vfi(
            ckpt_name=getattr(rife_lazy_mod, "_HARDCODED_CKPT_NAME", "rife47.pth"),
            frames=frames,
            clear_cache_after_n_frames=getattr(rife_lazy_mod, "_HARDCODED_CLEAR_CACHE_AFTER_N_FRAMES", 10),
            multiplier=int(multiplier),
            fast_mode=getattr(rife_lazy_mod, "_HARDCODED_FAST_MODE", True),
            ensemble=getattr(rife_lazy_mod, "_HARDCODED_ENSEMBLE", True),
            scale_factor=getattr(rife_lazy_mod, "_HARDCODED_SCALE_FACTOR", 1.0),
            optional_interpolation_states=None,
        )
        frames_out = _unwrap_image_output(raw)
        if not isinstance(frames_out, torch.Tensor):
            raise TypeError("RIFE output was not a torch.Tensor")
        frames_out = _normalize_to_batch(frames_out)
        if frames_out.dim() != 4:
            raise ValueError(f"RIFE output must be [B,H,W,C], got shape: {tuple(frames_out.shape)}")
        return frames_out

    def _rife_middle_frames(
        self,
        rife_node,
        rife_lazy_mod,
        base: torch.Tensor,
        i: int,
        j: int,
        multiplier: int,
    ) -> Optional[torch.Tensor]:
        """Interpolate between base[i] and base[j]; return only the interior frames.

        The endpoints of the interpolated run are discarded:
          - mult=3 => output len 4 => middle 2 frames
          - mult=2 => output len 3 => middle 1 frame
        Returns None when indices are out of range or no middle frames exist.
        """
        total = int(base.shape[0])
        if not (0 <= i < total and 0 <= j < total):
            return None

        sel = torch.tensor([i, j], dtype=torch.long, device=base.device)
        pair = torch.index_select(base, 0, sel)

        interpolated = self._call_rife(rife_node, rife_lazy_mod, pair, multiplier)

        # Need at least one frame strictly between the two endpoints.
        if interpolated.shape[0] < 3:
            return None
        middle = interpolated[1:-1]
        return middle if middle.shape[0] else None

    def _insert_after_indices(self, base: torch.Tensor, inserts: Dict[int, torch.Tensor]) -> torch.Tensor:
        """Rebuild the batch, splicing each inserts[i] right after base[i].

        Keys always refer to positions in the original base, so earlier
        inserts never shift the positions of later ones.
        """
        pieces = []
        for pos in range(int(base.shape[0])):
            pieces.append(base[pos : pos + 1])
            extra = inserts.get(pos)
            if extra is not None:
                pieces.append(_normalize_to_batch(extra))
        return torch.cat(pieces, dim=0)

    def process(self, images):
        """Node entry point: split, interpolate, concatenate, splice inserts."""
        # Guard clauses: pass through anything we cannot safely index.
        if not isinstance(images, torch.Tensor):
            return (images,)

        images = _normalize_to_batch(images)
        if images.dim() != 4:
            return (images,)

        # Index 37 is the largest source index we address, so B >= 38.
        if int(images.shape[0]) < 38:
            return (images,)

        dev = images.device

        # Custom_Batch_Output split
        up_idx = torch.tensor(self._BATCH_UP_INDICES, dtype=torch.long, device=dev)
        x3_idx = torch.tensor(self._RIFE_X3_INDICES, dtype=torch.long, device=dev)
        batch_up = torch.index_select(images, 0, up_idx)
        rife_x3 = torch.index_select(images, 0, x3_idx)

        # Lazy-load RIFE only now that we know we will run it.
        rife_lazy_mod = _lazy_import_rife_lazy_module()
        rife_node = _lazy_get_rife_vfi_class()()

        with _torch_inference_context():
            # Rife_x3 -> RIFE mult=3, then almost_final_batch = Batch_UP + RifeOutput
            interpolated = self._call_rife(rife_node, rife_lazy_mod, rife_x3, multiplier=3)
            almost_final = torch.cat([batch_up, interpolated], dim=0)

            # The insert specs address up to index 30 (and index 1).
            if int(almost_final.shape[0]) <= 30:
                return (almost_final,)

            # Compute all inserts against the pre-insert almost_final batch.
            inserts: Dict[int, torch.Tensor] = {}
            insert_specs = (
                (self._DUO_PAIR, self._DUO_MULT),
                (self._SINGLE_25_26_PAIR, self._SINGLE_25_26_MULT),
                (self._SINGLE_30_1_PAIR, self._SINGLE_30_1_MULT),
            )
            for (first, second), mult in insert_specs:
                middle = self._rife_middle_frames(
                    rife_node, rife_lazy_mod, almost_final,
                    i=first, j=second, multiplier=mult,
                )
                if middle is not None:
                    # Inserted right after the pair's first frame.
                    inserts[first] = middle

            # Apply all inserts in one pass, free of index-shift mistakes.
            return (self._insert_after_indices(almost_final, inserts),)
266
+
267
+
268
# ComfyUI registration: node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {
    "SALIA_SPECIAL_BATCH_RIFE47": SALIA_SPECIAL_BATCH_RIFE47,
}
271
+
272
# ComfyUI registration: node identifier -> human-readable display name shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SALIA_SPECIAL_BATCH_RIFE47": "Special Batch + RIFE Inserts (rife47)",
}