| import argparse |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import einops |
| from torch.utils.data import DataLoader |
| import pathlib |
| from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, InterpolationStateList |
| import typing |
| from comfy.model_management import get_torch_device |
|
|
# XVFInet hyperparameters for each supported checkpoint, keyed by checkpoint
# file name. Each inner mapping is unpacked into the argparse.Namespace that
# XVFInet is constructed from; S_tst and module_scale_factor also determine
# the padding divisor used when preparing frames.
CKPT_CONFIGS = {
    "XVFInet_X4K1000FPS_exp1_latest.pt": dict(
        module_scale_factor=4,
        S_trn=3,
        S_tst=5,
    ),
    "XVFInet_Vimeo_exp1_latest.pt": dict(
        module_scale_factor=2,
        S_trn=1,
        S_tst=1,
    ),
}
|
|
class XVFI_Inference(nn.Module):
    """Inference-only wrapper around XVFInet.

    Builds the network from a checkpoint config, loads the checkpoint's
    ``"state_dict_Model"`` weights and exposes a two-frame ``forward``.
    """

    def __init__(self, model_path, model_config) -> None:
        """Build XVFInet on the torch device and load its weights.

        Args:
            model_path: Path to a checkpoint file; the loaded object must be a
                mapping containing a ``"state_dict_Model"`` entry.
            model_config: Mapping of extra XVFInet hyperparameters; its items
                become attributes of the ``argparse.Namespace`` passed to
                ``XVFInet`` (alongside ``gpu``, ``nf=64`` and ``img_ch=3``).
        """
        super().__init__()
        # Imported here, so the architecture module is only loaded when a
        # model is actually constructed.
        from .xvfi_arch import XVFInet, weights_init

        device = get_torch_device()
        args = argparse.Namespace(
            gpu=device,
            nf=64,
            **model_config,
            img_ch=3,
        )
        self.model = XVFInet(args).apply(weights_init).to(device)
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint["state_dict_Model"])
        # This wrapper is used for inference only: switch off any
        # training-time layer behaviour (dropout / batch-norm updates).
        self.eval()

    def forward(self, I0, I1, timestep):
        """Run XVFInet on a pair of frame batches.

        Args:
            I0: First frames, rank 4 (b, c, h, w).
            I1: Second frames, same shape as ``I0``.
            timestep: Interpolation position(s), passed to XVFInet unchanged.

        Returns:
            Whatever ``XVFInet`` returns when called with
            ``is_training=False``.
        """
        # Stack on a new temporal axis at dim 2 -> (b, c, t=2, h, w). This is
        # the same layout as stacking on dim 0 and rearranging
        # "t b c h w -> b c t h w".
        x = torch.stack([I0, I1], dim=2)
        return self.model(x, timestep, is_training=False)
|
|
# Name of the directory that contains this module; used as the model-type
# argument to load_file_from_github_release when fetching checkpoints.
MODEL_TYPE = pathlib.PurePath(__file__).parent.name
|
|
class XVFI:
    """ComfyUI node that inserts XVFI-interpolated frames between inputs."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification for this node.

        ``multipler`` keeps its original spelling because it is the public
        input name and must match the ``vfi`` parameter of the same name.
        """
        return {
            "required": {
                "ckpt_name": (list(CKPT_CONFIGS.keys()), ),
                "frames": ("IMAGE", ),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 100}),
                "multipler": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", ),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        batch_size: typing.SupportsInt = 1,
        multipler: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None
    ):
        """Insert ``multipler - 1`` generated frames into each enabled pair.

        Args:
            ckpt_name: Key of ``CKPT_CONFIGS``; the checkpoint file is fetched
                with ``load_file_from_github_release``.
            frames: Input IMAGE batch, converted with ``preprocess_frames``.
            batch_size: Number of frame pairs sent to the model per call.
            multipler: Intermediate frames are generated at timesteps
                ``i / multipler`` for ``i`` in ``1 .. multipler - 1``.
            optional_interpolation_states: Per-pair flags. Pair ``i`` (frames
                ``i`` and ``i + 1``) is interpolated only when its entry is
                truthy. When omitted, every pair is interpolated.

        Returns:
            A one-element tuple holding the postprocessed IMAGE batch: the
            original frames with the generated frames inserted in temporal
            order, cropped back to the input height and width.
        """
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        ckpt_config = CKPT_CONFIGS[ckpt_name]
        # Kept local (previously a module global), so no reference to the
        # model or its device memory outlives this call.
        model = XVFI_Inference(model_path, ckpt_config)
        device = get_torch_device()

        frames = preprocess_frames(frames)

        # Pad H and W (bottom/right) up to a multiple of a divisor derived from
        # the checkpoint config; the padding is cropped off before returning.
        divide = 2 ** (ckpt_config["S_tst"]) * ckpt_config["module_scale_factor"] * 4
        H, W = frames.shape[-2:]
        H_padding = (divide - H % divide) % divide
        W_padding = (divide - W % divide) % divide
        if H_padding != 0 or W_padding != 0:
            frames = F.pad(frames, (0, W_padding, 0, H_padding), "constant")

        # Keys are integer tuples (source_frame_index, middle_index); original
        # frames use middle_index 0. Sorting the tuples yields true temporal
        # order. The previous string keys sorted "10" before "2" and
        # "0.10" before "0.2".
        frame_dict = {(i, 0): frames[i].unsqueeze(0) for i in range(frames.shape[0])}

        if optional_interpolation_states is None:
            interpolation_states = [True] * (frames.shape[0] - 1)
        else:
            interpolation_states = optional_interpolation_states

        enabled_former_idxs = [i for i, state in enumerate(interpolation_states) if state]
        former_idxs_loader = DataLoader(enabled_former_idxs, batch_size=batch_size)

        # Inference only: don't record autograd history.
        with torch.no_grad():
            for former_idxs_batch in former_idxs_loader:
                # The pair endpoints don't depend on middle_i, so gather them
                # once per batch.
                I0 = frames[former_idxs_batch].to(device)
                I1 = frames[former_idxs_batch + 1].to(device)
                for middle_i in range(1, multipler):
                    timestep = torch.full(
                        (len(former_idxs_batch), 1), middle_i / multipler, device=device
                    )
                    # Move results to the source frames' device so the final
                    # torch.cat never mixes devices.
                    middle_frames = model(I0, I1, timestep=timestep).to(frames.device)
                    for i, former_idx in enumerate(former_idxs_batch.tolist()):
                        frame_dict[(former_idx, middle_i)] = middle_frames[i].unsqueeze(0)

        out_frames = torch.cat([frame_dict[key] for key in sorted(frame_dict)], dim=0)[:, :, :H, :W]
        return (postprocess_frames(out_frames), )
| |
|
|