import pathlib
import typing

import torch
import torch.nn as nn

from comfy.model_management import soft_empty_cache, get_torch_device
from vfi_utils import (
    load_file_from_github_release,
    preprocess_frames,
    postprocess_frames,
    generic_frame_loop,
    InterpolationStateList,
)

# Model type is derived from this package's directory name (e.g. "eisai");
# it selects the correct GitHub-release asset group for checkpoint downloads.
MODEL_TYPE = pathlib.Path(__file__).parent.name

# Checkpoint file names fetched from the project's GitHub releases.
MODEL_FILE_NAMES = {
    "ssl": "eisai_ssl.pt",
    "dtm": "eisai_dtm.pt",
    "raft": "eisai_anime_interp_full.ckpt",
}


class EISAI(nn.Module):
    """EISAI anime frame-interpolation model.

    Composed of three sub-networks loaded from released checkpoints:
    a RAFT optical-flow backbone, a SoftsplatLite (SSL) synthesis
    network, and a DTM refinement network.
    """

    def __init__(self, model_file_names) -> None:
        """Download (if needed) and load all three sub-networks.

        Args:
            model_file_names: mapping with keys "raft", "ssl", "dtm"
                naming the checkpoint files to fetch for MODEL_TYPE.
        """
        # Deferred import: eisai_arch is heavy and only needed here.
        from .eisai_arch import SoftsplatLite, DTM, RAFT

        super().__init__()
        device = get_torch_device()  # hoisted: same device for all sub-nets

        # Optical-flow backbone (checkpoint path passed to its constructor).
        self.raft = RAFT(
            load_file_from_github_release(MODEL_TYPE, model_file_names["raft"])
        )
        self.raft.to(device).eval()

        # Softsplat-based frame-synthesis network.
        self.ssl = SoftsplatLite()
        self.ssl.load_state_dict(
            torch.load(
                load_file_from_github_release(MODEL_TYPE, model_file_names["ssl"])
            )
        )
        self.ssl.to(device).eval()

        # Refinement network applied on top of the SSL output.
        self.dtm = DTM()
        self.dtm.load_state_dict(
            torch.load(
                load_file_from_github_release(MODEL_TYPE, model_file_names["dtm"])
            )
        )
        self.dtm.to(device).eval()

    def forward(self, img0, img1, t):
        """Interpolate a frame between `img0` and `img1` at time `t`.

        Args:
            img0, img1: the two bracketing frames (batched image tensors —
                exact shape contract is defined by eisai_arch; not visible here).
            t: interpolation timestep, forwarded to the SSL network.

        Returns:
            The first 3 channels of the DTM output (the interpolated frame).
        """
        with torch.no_grad():
            # Bidirectional optical flow between the two input frames.
            flow0, _ = self.raft(img0, img1)
            flow1, _ = self.raft(img1, img0)
            x = {
                "images": torch.stack([img0, img1], dim=1),
                "flows": torch.stack([flow0, flow1], dim=1),
            }
            # return_more=True yields auxiliary data that DTM consumes;
            # previously this was captured in `_`, which hid that it is used.
            out_ssl, ssl_extra = self.ssl(x, t=t, return_more=True)
            out_dtm, _ = self.dtm(x, out_ssl, ssl_extra, return_more=False)
        return out_dtm[:, :3]


class EISAI_VFI:
    """ComfyUI node: video frame interpolation using the EISAI model."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (["eisai"],),
                "frames": ("IMAGE",),
                "clear_cache_after_n_frames": (
                    "INT",
                    {"default": 10, "min": 1, "max": 1000},
                ),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames=10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs,
    ):
        """Interpolate `frames` to `multiplier`x the frame count.

        Args:
            ckpt_name: checkpoint selector (only "eisai" is offered).
            frames: input image batch from ComfyUI.
            clear_cache_after_n_frames: how often generic_frame_loop frees
                GPU cache while iterating.
            multiplier: output-to-input frame-count ratio (>= 2).
            optional_interpolation_states: per-frame skip/keep state.

        Returns:
            1-tuple with the postprocessed interpolated frame tensor.
        """
        interpolation_model = EISAI(MODEL_FILE_NAMES)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        # Callback invoked by generic_frame_loop for each in-between frame.
        def return_middle_frame(frame_0, frame_1, timestep, model):
            return model(frame_0, frame_1, t=timestep)

        scale = 1
        # NOTE(review): `scale` is forwarded through *args but
        # return_middle_frame does not declare a matching parameter —
        # confirm how generic_frame_loop dispatches extra args.
        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(
                type(self).__name__,
                frames,
                clear_cache_after_n_frames,
                multiplier,
                return_middle_frame,
                *args,
                interpolation_states=optional_interpolation_states,
                dtype=torch.float32,
            )
        )
        return (out,)