saliacoel commited on
Commit
11855b6
·
verified ·
1 Parent(s): 6f46f5e

Upload vfi_utils.py

Browse files
Files changed (1) hide show
  1. vfi_utils.py +397 -0
vfi_utils.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import os
3
+ from torch.hub import download_url_to_file, get_dir
4
+ from urllib.parse import urlparse
5
+ import torch
6
+ import typing
7
+ import traceback
8
+ import einops
9
+ import gc
10
+ import torchvision.transforms.functional as transform
11
+ from comfy.model_management import soft_empty_cache, get_torch_device
12
+ import numpy as np
13
+
14
+
15
+ # -----------------------------
16
+ # Config
17
+ # -----------------------------
18
+ config_path = os.path.join(os.path.dirname(__file__), "./config.yaml")
19
+ if os.path.exists(config_path):
20
+ config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) or {}
21
+ else:
22
+ raise Exception(
23
+ "config.yaml file is neccessary, plz recreate the config file by downloading it from "
24
+ "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation"
25
+ )
26
+
27
+ DEVICE = get_torch_device()
28
+
29
+
30
+ # -----------------------------
31
+ # Model download sources
32
+ # -----------------------------
33
+ # Original GitHub release bases (some RIFE checkpoints are no longer hosted there -> 404)
34
+ DEFAULT_BASE_MODEL_DOWNLOAD_URLS = [
35
+ "https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/",
36
+ "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/",
37
+ "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/",
38
+ ]
39
+
40
+ # Optional: override via config.yaml:
41
+ # base_model_download_urls:
42
+ # - "https://..."
43
+ BASE_MODEL_DOWNLOAD_URLS = config.get("base_model_download_urls", DEFAULT_BASE_MODEL_DOWNLOAD_URLS)
44
+
45
+ # Optional: add extra base URLs via env var (comma-separated):
46
+ # COMFY_VFI_EXTRA_MODEL_DOWNLOAD_BASE_URLS="https://mirror1/.../,https://mirror2/.../"
47
+ _env_extra = os.getenv("COMFY_VFI_EXTRA_MODEL_DOWNLOAD_BASE_URLS", "")
48
+ if _env_extra.strip():
49
+ for u in [x.strip() for x in _env_extra.split(",") if x.strip()]:
50
+ if not u.endswith("/"):
51
+ u += "/"
52
+ if u not in BASE_MODEL_DOWNLOAD_URLS:
53
+ BASE_MODEL_DOWNLOAD_URLS.append(u)
54
+
55
+ # Optional: last-resort direct URL overrides per checkpoint.
56
+ # You asked for this HuggingFace mirror:
57
+ # https://huggingface.co/saliacoel/x/resolve/main/rife47.pth
58
+ #
59
+ # You can override it without editing code by setting:
60
+ # COMFY_VFI_RIFE47_URL="https://huggingface.co/.../rife47.pth"
61
+ #
62
+ # Or in config.yaml:
63
+ # ckpt_url_mirrors:
64
+ # rife47.pth: "https://huggingface.co/.../rife47.pth"
65
+ DEFAULT_CKPT_URL_MIRRORS = {
66
+ "rife47.pth": os.getenv(
67
+ "COMFY_VFI_RIFE47_URL",
68
+ "https://huggingface.co/saliacoel/x/resolve/main/rife47.pth",
69
+ ),
70
+ }
71
+ _ckpt_overrides = config.get("ckpt_url_mirrors", {}) or {}
72
+ CKPT_URL_MIRRORS = {**DEFAULT_CKPT_URL_MIRRORS, **_ckpt_overrides}
73
+
74
+
75
+ def _is_http_url(value: str) -> bool:
76
+ try:
77
+ parts = urlparse(value)
78
+ return parts.scheme in ("http", "https") and bool(parts.netloc)
79
+ except Exception:
80
+ return False
81
+
82
+
83
+ class InterpolationStateList:
84
+ def __init__(self, frame_indices: typing.List[int], is_skip_list: bool):
85
+ self.frame_indices = frame_indices
86
+ self.is_skip_list = is_skip_list
87
+
88
+ def is_frame_skipped(self, frame_index):
89
+ is_frame_in_list = frame_index in self.frame_indices
90
+ return self.is_skip_list and is_frame_in_list or (not self.is_skip_list and not is_frame_in_list)
91
+
92
+
93
+ class MakeInterpolationStateList:
94
+ @classmethod
95
+ def INPUT_TYPES(s):
96
+ return {
97
+ "required": {
98
+ "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}),
99
+ "is_skip_list": ("BOOLEAN", {"default": True}),
100
+ },
101
+ }
102
+
103
+ RETURN_TYPES = ("INTERPOLATION_STATES",)
104
+ FUNCTION = "create_options"
105
+ CATEGORY = "ComfyUI-Frame-Interpolation/VFI"
106
+
107
+ def create_options(self, frame_indices: str, is_skip_list: bool):
108
+ frame_indices_list = [int(item) for item in frame_indices.split(",")]
109
+
110
+ interpolation_state_list = InterpolationStateList(
111
+ frame_indices=frame_indices_list,
112
+ is_skip_list=is_skip_list,
113
+ )
114
+ return (interpolation_state_list,)
115
+
116
+
117
+ def get_ckpt_container_path(model_type):
118
+ return os.path.abspath(os.path.join(os.path.dirname(__file__), config["ckpts_path"], model_type))
119
+
120
+
121
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
122
+ """Load file from http url, will download models if necessary.
123
+
124
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
125
+
126
+ Args:
127
+ url (str): URL to be downloaded.
128
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
129
+ progress (bool): Whether to show the download progress.
130
+ file_name (str): The downloaded file name. If None, use the file name in the url.
131
+
132
+ Returns:
133
+ str: The path to the downloaded file.
134
+ """
135
+ if model_dir is None: # use the pytorch hub_dir
136
+ hub_dir = get_dir()
137
+ model_dir = os.path.join(hub_dir, "checkpoints")
138
+
139
+ os.makedirs(model_dir, exist_ok=True)
140
+
141
+ if file_name is None:
142
+ parts = urlparse(url)
143
+ file_name = os.path.basename(parts.path)
144
+
145
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
146
+ if not os.path.exists(cached_file):
147
+ print(f'Downloading: "{url}" to {cached_file}\n')
148
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
149
+ return cached_file
150
+
151
+
152
+ def load_file_from_github_release(model_type, ckpt_name):
153
+ """
154
+ Backwards-compatible name, but now supports:
155
+ - direct URL passed as ckpt_name
156
+ - a per-ckpt mirror URL fallback (e.g. HuggingFace)
157
+ """
158
+ # Allow passing a direct URL (future-proofing / manual overrides)
159
+ if isinstance(ckpt_name, str) and _is_http_url(ckpt_name):
160
+ return load_file_from_url(ckpt_name, get_ckpt_container_path(model_type))
161
+
162
+ error_strs = []
163
+
164
+ urls_to_try = [base + ckpt_name for base in BASE_MODEL_DOWNLOAD_URLS]
165
+
166
+ # Add last-resort mirror(s) if configured for this ckpt
167
+ mirror = CKPT_URL_MIRRORS.get(ckpt_name)
168
+ if mirror:
169
+ if isinstance(mirror, (list, tuple)):
170
+ urls_to_try.extend(list(mirror))
171
+ else:
172
+ urls_to_try.append(str(mirror))
173
+
174
+ # De-duplicate while preserving order
175
+ seen = set()
176
+ deduped = []
177
+ for u in urls_to_try:
178
+ if u not in seen:
179
+ seen.add(u)
180
+ deduped.append(u)
181
+ urls_to_try = deduped
182
+
183
+ for i, url in enumerate(urls_to_try):
184
+ try:
185
+ return load_file_from_url(url, get_ckpt_container_path(model_type))
186
+ except Exception:
187
+ traceback_str = traceback.format_exc()
188
+ if i < len(urls_to_try) - 1:
189
+ print("Failed! Trying another endpoint.")
190
+ error_strs.append(f"Error when downloading from: {url}\n\n{traceback_str}")
191
+
192
+ error_str = "\n\n".join(error_strs)
193
+ raise Exception(
194
+ f"Tried all endpoints to download {ckpt_name} but no success. Below is the error log:\n\n{error_str}"
195
+ )
196
+
197
+
198
+ def load_file_from_direct_url(model_type, url):
199
+ return load_file_from_url(url, get_ckpt_container_path(model_type))
200
+
201
+
202
+ def preprocess_frames(frames):
203
+ return einops.rearrange(frames[..., :3], "n h w c -> n c h w")
204
+
205
+
206
+ def postprocess_frames(frames):
207
+ return einops.rearrange(frames, "n c h w -> n h w c")[..., :3].cpu()
208
+
209
+
210
+ def assert_batch_size(frames, batch_size=2, vfi_name=None):
211
+ subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires"
212
+ assert len(frames) >= batch_size, (
213
+ f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. "
214
+ f"Please check the frame input using PreviewImage."
215
+ )
216
+
217
+
218
+ def _generic_frame_loop(
219
+ frames,
220
+ clear_cache_after_n_frames,
221
+ multiplier: typing.Union[typing.SupportsInt, typing.List],
222
+ return_middle_frame_function,
223
+ *return_middle_frame_function_args,
224
+ interpolation_states: InterpolationStateList = None,
225
+ use_timestep=True,
226
+ dtype=torch.float16,
227
+ final_logging=True,
228
+ ):
229
+ # https://github.com/hzwer/Practical-RIFE/blob/main/inference_video.py#L169
230
+ def non_timestep_inference(frame0, frame1, n):
231
+ middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args)
232
+ if n == 1:
233
+ return [middle]
234
+ first_half = non_timestep_inference(frame0, middle, n=n // 2)
235
+ second_half = non_timestep_inference(middle, frame1, n=n // 2)
236
+ if n % 2:
237
+ return [*first_half, middle, *second_half]
238
+ else:
239
+ return [*first_half, *second_half]
240
+
241
+ output_frames = torch.zeros(multiplier * frames.shape[0], *frames.shape[1:], dtype=dtype, device="cpu")
242
+ out_len = 0
243
+
244
+ number_of_frames_processed_since_last_cleared_cuda_cache = 0
245
+
246
+ for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it
247
+ frame0 = frames[frame_itr : frame_itr + 1]
248
+ output_frames[out_len] = frame0 # Start with first frame
249
+ out_len += 1
250
+
251
+ # Ensure that input frames are in fp32 - the same dtype as model
252
+ frame0 = frame0.to(dtype=torch.float32)
253
+ frame1 = frames[frame_itr + 1 : frame_itr + 2].to(dtype=torch.float32)
254
+
255
+ if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
256
+ continue
257
+
258
+ # Generate and append a batch of middle frames
259
+ middle_frame_batches = []
260
+
261
+ if use_timestep:
262
+ for middle_i in range(1, multiplier):
263
+ timestep = middle_i / multiplier
264
+
265
+ middle_frame = (
266
+ return_middle_frame_function(
267
+ frame0.to(DEVICE),
268
+ frame1.to(DEVICE),
269
+ timestep,
270
+ *return_middle_frame_function_args,
271
+ )
272
+ .detach()
273
+ .cpu()
274
+ )
275
+ middle_frame_batches.append(middle_frame.to(dtype=dtype))
276
+ else:
277
+ middle_frames = non_timestep_inference(frame0.to(DEVICE), frame1.to(DEVICE), multiplier - 1)
278
+ middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype))
279
+
280
+ # Copy middle frames to output
281
+ for middle_frame in middle_frame_batches:
282
+ output_frames[out_len] = middle_frame
283
+ out_len += 1
284
+
285
+ number_of_frames_processed_since_last_cleared_cuda_cache += 1
286
+ # Try to avoid a memory overflow by clearing cuda cache regularly
287
+ if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
288
+ print("Comfy-VFI: Clearing cache...", end=" ")
289
+ soft_empty_cache()
290
+ number_of_frames_processed_since_last_cleared_cuda_cache = 0
291
+ print("Done cache clearing")
292
+
293
+ gc.collect()
294
+
295
+ if final_logging:
296
+ print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
297
+
298
+ # Append final frame
299
+ output_frames[out_len] = frames[-1:]
300
+ out_len += 1
301
+
302
+ # clear cache for courtesy
303
+ if final_logging:
304
+ print("Comfy-VFI: Final clearing cache...", end=" ")
305
+ soft_empty_cache()
306
+ if final_logging:
307
+ print("Done cache clearing")
308
+
309
+ return output_frames[:out_len]
310
+
311
+
312
+ def generic_frame_loop(
313
+ model_name,
314
+ frames,
315
+ clear_cache_after_n_frames,
316
+ multiplier: typing.Union[typing.SupportsInt, typing.List],
317
+ return_middle_frame_function,
318
+ *return_middle_frame_function_args,
319
+ interpolation_states: InterpolationStateList = None,
320
+ use_timestep=True,
321
+ dtype=torch.float32,
322
+ ):
323
+ assert_batch_size(frames, vfi_name=model_name.replace("_", " ").replace("VFI", ""))
324
+ if type(multiplier) == int:
325
+ return _generic_frame_loop(
326
+ frames,
327
+ clear_cache_after_n_frames,
328
+ multiplier,
329
+ return_middle_frame_function,
330
+ *return_middle_frame_function_args,
331
+ interpolation_states=interpolation_states,
332
+ use_timestep=use_timestep,
333
+ dtype=dtype,
334
+ )
335
+ if type(multiplier) == list:
336
+ multipliers = list(map(int, multiplier))
337
+ multipliers += [2] * (len(frames) - len(multipliers) - 1)
338
+ frame_batches = []
339
+ for frame_itr in range(len(frames) - 1):
340
+ multiplier = multipliers[frame_itr]
341
+ if multiplier == 0:
342
+ continue
343
+ frame_batch = _generic_frame_loop(
344
+ frames[frame_itr : frame_itr + 2],
345
+ clear_cache_after_n_frames,
346
+ multiplier,
347
+ return_middle_frame_function,
348
+ *return_middle_frame_function_args,
349
+ interpolation_states=interpolation_states,
350
+ use_timestep=use_timestep,
351
+ dtype=dtype,
352
+ final_logging=False,
353
+ )
354
+ if frame_itr != len(frames) - 2: # Not append last frame unless this batch is the last one
355
+ frame_batch = frame_batch[:-1]
356
+ frame_batches.append(frame_batch)
357
+ output_frames = torch.cat(frame_batches)
358
+ print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
359
+ return output_frames
360
+ raise NotImplementedError(f"multipiler of {type(multiplier)}")
361
+
362
+
363
+ class FloatToInt:
364
+ @classmethod
365
+ def INPUT_TYPES(s):
366
+ return {"required": {"float": ("FLOAT", {"default": 0, "min": 0, "step": 0.01})}}
367
+
368
+ RETURN_TYPES = ("INT",)
369
+ FUNCTION = "convert"
370
+ CATEGORY = "ComfyUI-Frame-Interpolation"
371
+
372
+ def convert(self, float):
373
+ if hasattr(float, "__iter__"):
374
+ return (list(map(int, float)),)
375
+ return (int(float),)
376
+
377
+
378
+ """ def generic_4frame_loop(
379
+ frames,
380
+ clear_cache_after_n_frames,
381
+ multiplier: typing.SupportsInt,
382
+ return_middle_frame_function,
383
+ *return_middle_frame_function_args,
384
+ interpolation_states: InterpolationStateList = None,
385
+ use_timestep=False):
386
+
387
+ if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model")
388
+ def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n):
389
+ middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args)
390
+ if n == 1:
391
+ return [middle]
392
+ first_half = non_timestep_inference(frame_0, middle, n=n//2)
393
+ second_half = non_timestep_inference(middle, frame_1, n=n//2)
394
+ if n%2:
395
+ return [*first_half, middle, *second_half]
396
+ else:
397
+ return [*first_half, *second_half] """