| | import yaml |
| | import os |
| | from torch.hub import download_url_to_file, get_dir |
| | from urllib.parse import urlparse |
| | import torch |
| | import typing |
| | import traceback |
| | import einops |
| | import gc |
| | import torchvision.transforms.functional as transform |
| | from comfy.model_management import soft_empty_cache, get_torch_device |
| | import numpy as np |
| |
|
# Mirror endpoints hosting the VFI model checkpoints. They are tried in
# order until one succeeds (each entry is a base URL the checkpoint file
# name gets appended to).
BASE_MODEL_DOWNLOAD_URLS = [
    "https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/",
    "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/",
    "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/"
]
| |
|
# Load the node-pack configuration; config.yaml must live next to this file.
config_path = os.path.join(os.path.dirname(__file__), "./config.yaml")
if os.path.exists(config_path):
    # Use a context manager so the file handle is closed (the previous
    # `yaml.load(open(...))` leaked it).
    with open(config_path, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
else:
    raise Exception("config.yaml file is necessary, please recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation")
# Device all VFI models run on, as selected by ComfyUI's model management.
DEVICE = get_torch_device()
| |
|
class InterpolationStateList():
    """Records which frame-pair indices to skip during interpolation.

    When ``is_skip_list`` is True the listed indices are skipped; when it is
    False the listed indices are the only ones processed (everything else is
    skipped).
    """

    def __init__(self, frame_indices: typing.List[int], is_skip_list: bool):
        self.frame_indices = frame_indices
        self.is_skip_list = is_skip_list

    def is_frame_skipped(self, frame_index):
        """Return True when *frame_index* should be skipped under the current mode."""
        listed = frame_index in self.frame_indices
        # Skip-list mode: skip exactly the listed indices.
        # Keep-list mode: skip everything that is NOT listed.
        if self.is_skip_list:
            return listed
        return not listed
| | |
| |
|
class MakeInterpolationStateList:
    """ComfyUI node that builds an InterpolationStateList from a comma-separated string."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}),
                "is_skip_list": ("BOOLEAN", {"default": True},),
            },
        }

    RETURN_TYPES = ("INTERPOLATION_STATES",)
    FUNCTION = "create_options"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def create_options(self, frame_indices: str, is_skip_list: bool):
        """Parse *frame_indices* (e.g. "1, 2,3") into an InterpolationStateList.

        Robustness fix: blank tokens (trailing commas, empty input, stray
        whitespace) are ignored instead of raising ValueError from int('').
        """
        frame_indices_list = [int(item) for item in frame_indices.split(',') if item.strip()]

        interpolation_state_list = InterpolationStateList(
            frame_indices=frame_indices_list,
            is_skip_list=is_skip_list,
        )
        return (interpolation_state_list,)
| | |
| | |
def get_ckpt_container_path(model_type):
    """Return the absolute directory where checkpoints of *model_type* live."""
    base_dir = os.path.dirname(__file__)
    container_dir = os.path.join(base_dir, config["ckpts_path"], model_type)
    return os.path.abspath(container_dir)
| |
|
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    # Bug fix: the explicit `file_name` argument used to be unconditionally
    # clobbered by the URL basename. Only derive the name from the URL when
    # the caller did not supply one.
    if file_name is None:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
| |
|
def load_file_from_github_release(model_type, ckpt_name):
    """Download *ckpt_name* by trying each mirror in BASE_MODEL_DOWNLOAD_URLS in order.

    Returns:
        str: local path of the cached/downloaded checkpoint.

    Raises:
        Exception: if every mirror fails; the message aggregates the
        traceback collected from each attempt.
    """
    error_strs = []
    for i, base_model_download_url in enumerate(BASE_MODEL_DOWNLOAD_URLS):
        try:
            return load_file_from_url(base_model_download_url + ckpt_name, get_ckpt_container_path(model_type))
        except Exception:
            traceback_str = traceback.format_exc()
            # Keep trying the remaining mirrors; only log between attempts.
            if i < len(BASE_MODEL_DOWNLOAD_URLS) - 1:
                print("Failed! Trying another endpoint.")
            error_strs.append(f"Error when downloading from: {base_model_download_url + ckpt_name}\n\n{traceback_str}")

    error_str = '\n\n'.join(error_strs)
    # Typo fix in the raised message: "suceess" -> "success".
    raise Exception(f"Tried all GitHub base urls to download {ckpt_name} but no success. Below is the error log:\n\n{error_str}")
| | |
| |
|
def load_file_from_direct_url(model_type, url):
    """Fetch *url* into the checkpoint folder for *model_type*; return the local path."""
    target_dir = get_ckpt_container_path(model_type)
    return load_file_from_url(url, target_dir)
| |
|
| | def preprocess_frames(frames): |
| | return einops.rearrange(frames[..., :3], "n h w c -> n c h w") |
| |
|
def postprocess_frames(frames):
    """Convert model output (n, c, h, w) back to ComfyUI IMAGE layout (n, h, w, c).

    Keeps only the first 3 channels and moves the result to CPU. Uses native
    ``Tensor.permute`` instead of einops (same permutation, one less dependency).
    """
    return frames.permute(0, 2, 3, 1)[..., :3].cpu()
| |
|
def assert_batch_size(frames, batch_size=2, vfi_name=None):
    """Raise AssertionError unless *frames* holds at least *batch_size* frames."""
    if vfi_name is None:
        subject_verb = "Most VFI models require"
    else:
        subject_verb = f"VFI model {vfi_name} requires"
    assert len(frames) >= batch_size, f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage."
| |
|
def _generic_frame_loop(
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: "InterpolationStateList" = None,
        use_timestep=True,
        dtype=torch.float16,
        final_logging=True):
    """Interpolate between consecutive frames with a single integer multiplier.

    Args:
        frames: (n, c, h, w) tensor of input frames (see preprocess_frames).
        clear_cache_after_n_frames: number of frame pairs to process before
            emptying the CUDA cache.
        multiplier: total output frames per input frame (int).
        return_middle_frame_function: model callback invoked as
            fn(frame0, frame1, timestep, *return_middle_frame_function_args);
            timestep is None for non-timestep models.
        interpolation_states: optional skip/keep list of frame-pair indices.
        use_timestep: True for models accepting an arbitrary timestep; False
            for midpoint-only models (recursive bisection is used instead).
        dtype: dtype of the preallocated output buffer.
        final_logging: whether to print the summary / final cache messages.

    Returns:
        (m, c, h, w) CPU tensor containing original and interpolated frames.
    """

    def non_timestep_inference(frame0, frame1, n):
        # Midpoint-only models can't take a timestep, so n in-between frames
        # are produced by recursive bisection of the interval.
        middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args)
        if n == 1:
            return [middle]
        first_half = non_timestep_inference(frame0, middle, n=n//2)
        second_half = non_timestep_inference(middle, frame1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half]

    # Preallocate the worst-case number of slots; out_len tracks how many
    # are actually filled (skipped pairs fill fewer than multiplier slots).
    output_frames = torch.zeros(multiplier*frames.shape[0], *frames.shape[1:], dtype=dtype, device="cpu")
    out_len = 0

    number_of_frames_processed_since_last_cleared_cuda_cache = 0

    for frame_itr in range(len(frames) - 1):
        frame0 = frames[frame_itr:frame_itr+1]
        # Always emit the original frame before any interpolated ones.
        output_frames[out_len] = frame0
        out_len += 1

        frame0 = frame0.to(dtype=torch.float32)
        frame1 = frames[frame_itr+1:frame_itr+2].to(dtype=torch.float32)

        # Honor the user's skip/keep list for this frame pair.
        if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
            continue

        middle_frame_batches = []

        if use_timestep:
            for middle_i in range(1, multiplier):
                timestep = middle_i/multiplier

                middle_frame = return_middle_frame_function(
                    frame0.to(DEVICE),
                    frame1.to(DEVICE),
                    timestep,
                    *return_middle_frame_function_args
                ).detach().cpu()
                middle_frame_batches.append(middle_frame.to(dtype=dtype))
        else:
            middle_frames = non_timestep_inference(frame0.to(DEVICE), frame1.to(DEVICE), multiplier - 1)
            middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype))

        for middle_frame in middle_frame_batches:
            output_frames[out_len] = middle_frame
            out_len += 1

        number_of_frames_processed_since_last_cleared_cuda_cache += 1
        # Periodically free the CUDA cache to avoid OOM on long sequences.
        if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
            print("Comfy-VFI: Clearing cache...", end=' ')
            soft_empty_cache()
            number_of_frames_processed_since_last_cleared_cuda_cache = 0
            print("Done cache clearing")

        gc.collect()

    # The loop only copies frames[0..n-2]; append the last original frame.
    output_frames[out_len] = frames[-1:]
    out_len += 1

    if final_logging:
        # Bug fix: report the number of frames actually generated (out_len),
        # not len(output_frames), which is the size of the over-allocated
        # buffer (multiplier * n) and over-counted when pairs were skipped.
        print(f"Comfy-VFI done! {out_len} frames generated at resolution: {output_frames[0].shape}")

    if final_logging:
        print("Comfy-VFI: Final clearing cache...", end=' ')
    soft_empty_cache()
    if final_logging:
        print("Done cache clearing")
    return output_frames[:out_len]
| |
|
def generic_frame_loop(
        model_name,
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: "InterpolationStateList" = None,
        use_timestep=True,
        dtype=torch.float32):
    """Top-level VFI loop accepting either a single int multiplier or a per-pair list.

    With an int multiplier the whole batch is processed in one pass; with a
    list, each consecutive frame pair gets its own multiplier (0 = skip the
    pair entirely, missing entries default to 2).

    Raises:
        NotImplementedError: if *multiplier* is neither an int nor a list.
    """
    assert_batch_size(frames, vfi_name=model_name.replace('_', ' ').replace('VFI', ''))
    if isinstance(multiplier, int):
        return _generic_frame_loop(
            frames,
            clear_cache_after_n_frames,
            multiplier,
            return_middle_frame_function,
            *return_middle_frame_function_args,
            interpolation_states=interpolation_states,
            use_timestep=use_timestep,
            dtype=dtype
        )
    if isinstance(multiplier, list):
        multipliers = list(map(int, multiplier))
        # Pad with the default multiplier (2) so every frame pair has one.
        multipliers += [2] * (len(frames) - len(multipliers) - 1)
        frame_batches = []
        for frame_itr in range(len(frames) - 1):
            multiplier = multipliers[frame_itr]
            if multiplier == 0: continue
            # NOTE(review): interpolation_states holds global frame indices,
            # but each sub-call sees a 2-frame slice (local index 0), so the
            # skip list may not line up here — verify the intended semantics.
            frame_batch = _generic_frame_loop(
                frames[frame_itr:frame_itr+2],
                clear_cache_after_n_frames,
                multiplier,
                return_middle_frame_function,
                *return_middle_frame_function_args,
                interpolation_states=interpolation_states,
                use_timestep=use_timestep,
                dtype=dtype,
                final_logging=False
            )
            # Drop the duplicated boundary frame except after the final pair.
            if frame_itr != len(frames) - 2:
                frame_batch = frame_batch[:-1]
            frame_batches.append(frame_batch)
        output_frames = torch.cat(frame_batches)
        print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
        return output_frames
    # Typo fix in the message: "multipiler" -> "multiplier".
    raise NotImplementedError(f"multiplier of {type(multiplier)}")
| |
|
class FloatToInt:
    """Utility node converting a FLOAT value (or an iterable of them) to INT."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float": ("FLOAT", {"default": 0, 'min': 0, 'step': 0.01})
            }
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "convert"
    CATEGORY = "ComfyUI-Frame-Interpolation"

    def convert(self, float):
        # Scalar input: truncate directly. (Parameter name is part of the
        # node's interface, so it intentionally shadows the builtin.)
        value = float
        if not hasattr(value, "__iter__"):
            return (int(value),)
        # Batched input: convert element-wise.
        return ([int(item) for item in value],)
| |
|
# NOTE(review): dead code — a commented-out (string-literal) prototype for
# 4-frame VFI models. It is never executed; the inner helper also ignores
# frame_2/frame_3. Kept verbatim for reference.
""" def generic_4frame_loop(
    frames,
    clear_cache_after_n_frames,
    multiplier: typing.SupportsInt,
    return_middle_frame_function,
    *return_middle_frame_function_args,
    interpolation_states: InterpolationStateList = None,
    use_timestep=False):

    if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model")
    def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n):
        middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args)
        if n == 1:
            return [middle]
        first_half = non_timestep_inference(frame_0, middle, n=n//2)
        second_half = non_timestep_inference(middle, frame_1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half] """