# Mirrored from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation (commit 61029c7, verified)
import pathlib
from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
import typing
import torch
import torch.nn as nn
from comfy.model_management import soft_empty_cache, get_torch_device
# Directory name containing this file; used as the model-category key when
# fetching checkpoints through load_file_from_github_release.
MODEL_TYPE = pathlib.Path(__file__).parent.name
# Checkpoint file name for each EISAI sub-network (downloaded on demand).
MODEL_FILE_NAMES = {
    "ssl": "eisai_ssl.pt",
    "dtm": "eisai_dtm.pt",
    "raft": "eisai_anime_interp_full.ckpt"
}
class EISAI(nn.Module):
    """EISAI anime frame-interpolation model.

    Bundles the three sub-networks used by EISAI: a RAFT optical-flow
    estimator, a SoftsplatLite warping network, and a DTM refinement
    network. All sub-modules are moved to the active torch device and set
    to eval mode at construction time.
    """

    def __init__(self, model_file_names) -> None:
        """Download (if needed) and load all three checkpoints.

        model_file_names: mapping with keys "raft", "ssl" and "dtm" whose
        values are checkpoint file names resolvable by
        load_file_from_github_release.
        """
        # Deferred import: eisai_arch is only needed when this node is used.
        from .eisai_arch import SoftsplatLite, DTM, RAFT
        super().__init__()
        # RAFT estimates optical flow between the two input frames.
        self.raft = RAFT(load_file_from_github_release(MODEL_TYPE, model_file_names["raft"]))
        self.raft.to(get_torch_device()).eval()
        # SoftsplatLite warps/blends the frames using the estimated flows.
        self.ssl = SoftsplatLite()
        self.ssl.load_state_dict(torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["ssl"])))
        self.ssl.to(get_torch_device()).eval()
        # DTM refines the SoftsplatLite output.
        self.dtm = DTM()
        self.dtm.load_state_dict(torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["dtm"])))
        self.dtm.to(get_torch_device()).eval()

    def forward(self, img0, img1, t):
        """Interpolate a frame between img0 and img1 at timestep t.

        img0, img1: batched image tensors (same shape/device).
        t: interpolation timestep, with t=0.5 the midpoint.
        Returns the first 3 channels of the DTM output (the RGB frame).
        """
        with torch.no_grad():
            # Bidirectional flow: img0 -> img1 and img1 -> img0.
            flow0, _ = self.raft(img0, img1)
            flow1, _ = self.raft(img1, img0)
            x = {
                "images": torch.stack([img0, img1], dim=1),
                "flows": torch.stack([flow0, flow1], dim=1),
            }
            # return_more=True yields auxiliary data that DTM consumes.
            # (Originally this was captured as `_` and then passed on, which
            # hid the fact that it is a real input to DTM.)
            out_ssl, ssl_aux = self.ssl(x, t=t, return_more=True)
            out_dtm, _ = self.dtm(x, out_ssl, ssl_aux, return_more=False)
        return out_dtm[:, :3]
class EISAI_VFI:
    """ComfyUI node that runs EISAI video frame interpolation over a batch of frames."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (["eisai"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames`, inserting (multiplier - 1) frames between each pair.

        ckpt_name: checkpoint selector (only "eisai" is offered).
        frames: input image batch tensor.
        clear_cache_after_n_frames: how often to flush the CUDA cache.
        multiplier: output frame-rate multiplier (>= 2).
        optional_interpolation_states: per-pair skip mask, if provided.
        Returns a 1-tuple with the interpolated frame batch.
        """
        interpolation_model = EISAI(MODEL_FILE_NAMES)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        # NOTE: `args` forwards (model, scale) as extra positionals into this
        # callback via generic_frame_loop. The original signature omitted
        # `scale`, so the extra positional argument would raise a TypeError
        # when the loop calls callback(f0, f1, t, *args). A defaulted
        # parameter keeps it safe under either calling convention.
        def return_middle_frame(frame_0, frame_1, timestep, model, scale=1):
            # `scale` is unused by EISAI but must be accepted (see NOTE above).
            return model(frame_0, frame_1, t=timestep)

        scale = 1
        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)