import yaml import os from torch.hub import download_url_to_file, get_dir from urllib.parse import urlparse import torch import typing import traceback import einops import gc import torchvision.transforms.functional as transform from comfy.model_management import soft_empty_cache, get_torch_device import numpy as np BASE_MODEL_DOWNLOAD_URLS = [ "https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/", "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/", "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/" ] config_path = os.path.join(os.path.dirname(__file__), "./config.yaml") if os.path.exists(config_path): config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) else: raise Exception("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation") DEVICE = get_torch_device() class InterpolationStateList(): def __init__(self, frame_indices: typing.List[int], is_skip_list: bool): self.frame_indices = frame_indices self.is_skip_list = is_skip_list def is_frame_skipped(self, frame_index): is_frame_in_list = frame_index in self.frame_indices return self.is_skip_list and is_frame_in_list or not self.is_skip_list and not is_frame_in_list class MakeInterpolationStateList: @classmethod def INPUT_TYPES(s): return { "required": { "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}), "is_skip_list": ("BOOLEAN", {"default": True},), }, } RETURN_TYPES = ("INTERPOLATION_STATES",) FUNCTION = "create_options" CATEGORY = "ComfyUI-Frame-Interpolation/VFI" def create_options(self, frame_indices: str, is_skip_list: bool): frame_indices_list = [int(item) for item in frame_indices.split(',')] interpolation_state_list = InterpolationStateList( frame_indices=frame_indices_list, is_skip_list=is_skip_list, ) return (interpolation_state_list,) def get_ckpt_container_path(model_type): return os.path.abspath(os.path.join(os.path.dirname(__file__), config["ckpts_path"], model_type)) def load_file_from_url(url, model_dir=None, progress=True, file_name=None): """Load file form http url, will download models if necessary. Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py Args: url (str): URL to be downloaded. model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. Default: None. progress (bool): Whether to show the download progress. Default: True. file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. Returns: str: The path to the downloaded file. """ if model_dir is None: # use the pytorch hub_dir hub_dir = get_dir() model_dir = os.path.join(hub_dir, 'checkpoints') os.makedirs(model_dir, exist_ok=True) parts = urlparse(url) file_name = os.path.basename(parts.path) if file_name is not None: file_name = file_name cached_file = os.path.abspath(os.path.join(model_dir, file_name)) if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) return cached_file def load_file_from_github_release(model_type, ckpt_name): error_strs = [] for i, base_model_download_url in enumerate(BASE_MODEL_DOWNLOAD_URLS): try: return load_file_from_url(base_model_download_url + ckpt_name, get_ckpt_container_path(model_type)) except Exception: traceback_str = traceback.format_exc() if i < len(BASE_MODEL_DOWNLOAD_URLS) - 1: print("Failed! Trying another endpoint.") error_strs.append(f"Error when downloading from: {base_model_download_url + ckpt_name}\n\n{traceback_str}") error_str = '\n\n'.join(error_strs) raise Exception(f"Tried all GitHub base urls to download {ckpt_name} but no suceess. Below is the error log:\n\n{error_str}") def load_file_from_direct_url(model_type, url): return load_file_from_url(url, get_ckpt_container_path(model_type)) def preprocess_frames(frames): return einops.rearrange(frames, "n h w c -> n c h w") def postprocess_frames(frames): return einops.rearrange(frames, "n c h w -> n h w c").cpu() def assert_batch_size(frames, batch_size=2, vfi_name=None): subject_verb = "Most VFI models require" if vfi_name is None else f"{vfi_name} VFI model requires" assert len(frames) >= batch_size, f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage." def generic_frame_loop( frames, clear_cache_after_n_frames, multiplier: typing.SupportsInt, return_middle_frame_function, *return_middle_frame_function_args, interpolation_states: InterpolationStateList = None, use_timestep=True, dtype=torch.float16): #https://github.com/hzwer/Practical-RIFE/blob/main/inference_video.py#L169 def non_timestep_inference(frame0, frame1, n): middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args) if n == 1: return [middle] first_half = non_timestep_inference(frame0, middle, n=n//2) second_half = non_timestep_inference(middle, frame1, n=n//2) if n%2: return [*first_half, middle, *second_half] else: return [*first_half, *second_half] assert_batch_size(frames) # Too lazy to include model name lol output_frames = [] # List to store processed frames in the correct order number_of_frames_processed_since_last_cleared_cuda_cache = 0 for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it #Ensure that input frames are in fp32 - the same dtype as model frame0 = frames[frame_itr:frame_itr+1].float() frame1 = frames[frame_itr+1:frame_itr+2].float() output_frames.append(frame0.to(dtype=dtype)) # Start with first frame if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr): continue # Generate and append a batch of middle frames middle_frame_batches = [] if use_timestep: for middle_i in range(1, multiplier): timestep = middle_i/multiplier middle_frame = return_middle_frame_function( frame0.to(DEVICE), frame1.to(DEVICE), timestep, *return_middle_frame_function_args ).detach().cpu() middle_frame_batches.append(middle_frame.to(dtype=dtype)) else: middle_frames = non_timestep_inference(frame0.to(DEVICE), frame1.to(DEVICE), multiplier - 1) middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype)) # Extend output array by batch output_frames.extend(middle_frame_batches) number_of_frames_processed_since_last_cleared_cuda_cache += 1 # Try to avoid a memory overflow by clearing cuda cache regularly if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames: print("Comfy-VFI: Clearing cache...", end=' ') soft_empty_cache() number_of_frames_processed_since_last_cleared_cuda_cache = 0 print("Done cache clearing") gc.collect() print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}") output_frames.append(frames[-1:].to(dtype=dtype)) # Append final frame output_frames = [frame.cpu() for frame in output_frames] #Ensure all frames are in cpu out = torch.cat(output_frames, dim=0) # clear cache for courtesy print("Comfy-VFI: Final clearing cache...", end = ' ') soft_empty_cache() print("Done cache clearing") return out """ def generic_4frame_loop( frames, clear_cache_after_n_frames, multiplier: typing.SupportsInt, return_middle_frame_function, *return_middle_frame_function_args, interpolation_states: InterpolationStateList = None, use_timestep=False): if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model") def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n): middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args) if n == 1: return [middle] first_half = non_timestep_inference(frame_0, middle, n=n//2) second_half = non_timestep_inference(middle, frame_1, n=n//2) if n%2: return [*first_half, middle, *second_half] else: return [*first_half, *second_half] """