diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..8a703c561132057791f1855c8247871214638267 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +All_in_one_v1_3.png filter=lfs diff=lfs merge=lfs -text +demo_frames/anime0.png filter=lfs diff=lfs merge=lfs -text +demo_frames/anime1.png filter=lfs diff=lfs merge=lfs -text +demo_frames/bocchi0.jpg filter=lfs diff=lfs merge=lfs -text +demo_frames/bocchi1.jpg filter=lfs diff=lfs merge=lfs -text +demo_frames/real0.png filter=lfs diff=lfs merge=lfs -text +demo_frames/real1.png filter=lfs diff=lfs merge=lfs -text +demo_frames/rick/00003.png filter=lfs diff=lfs merge=lfs -text +demo_frames/rick/00004.png filter=lfs diff=lfs merge=lfs -text +demo_frames/rick/00005.png filter=lfs diff=lfs merge=lfs -text +demo_frames/violet0.png filter=lfs diff=lfs merge=lfs -text +demo_frames/violet1.png filter=lfs diff=lfs merge=lfs -text +example.png filter=lfs diff=lfs merge=lfs -text +interpolation_schedule.png filter=lfs diff=lfs merge=lfs -text +test_vfi_schedule.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..eb5b7d5f5484abc954a89dde2b97d327de44b43b --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,25 @@ +name: Publish to Comfy registry +on: + workflow_dispatch: + push: + branches: + - main + paths: + - "pyproject.toml" + +permissions: + issues: write + +jobs: + publish-node: + name: Publish Custom Node to registry + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'Fannovel16' }} + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@v1 + with: + ## Add your own 
personal access token to your Github Repository secrets and reference it here. + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c16001cc49079aaeba0e41bb39dcbf6c4ac3733d --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +ckpts +__pycache__ +test_result \ No newline at end of file diff --git a/All_in_one_v1_3.png b/All_in_one_v1_3.png new file mode 100644 index 0000000000000000000000000000000000000000..364c54b9001c4a450dc434f9fd310c59a2a98af2 --- /dev/null +++ b/All_in_one_v1_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90735b644e0c35634642b65f2a8041a9a4da380d27b9bcc4d3bbef47869bd92a +size 1462273 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2a8000ad9540222a0f8f50ac7fb8b04fa8dd0cd3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Fannovel16 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f5d0f3c06b520f1f7a9724e67cc780d8709e2cbb --- /dev/null +++ b/README.md @@ -0,0 +1,194 @@ +# ComfyUI Frame Interpolation (ComfyUI VFI) (WIP) + +A custom node set for Video Frame Interpolation in ComfyUI. +**UPDATE** Memory management is improved. Now this extension takes less RAM and VRAM than before. + +**UPDATE 2** VFI nodes now accept scheduling multipiler values + +![](./interpolation_schedule.png) +![](./test_vfi_schedule.gif) + +## Nodes +* KSampler Gradually Adding More Denoise (efficient) +* GMFSS Fortuna VFI +* IFRNet VFI +* IFUnet VFI +* M2M VFI +* RIFE VFI (4.0 - 4.9) (Note that option `fast_mode` won't do anything from v4.5+ as `contextnet` is removed) +* FILM VFI +* Sepconv VFI +* AMT VFI +* Make Interpolation State List +* STMFNet VFI (requires at least 4 frames, can only do 2x interpolation for now) +* FLAVR VFI (same conditions as STMFNet) + +## Install +### ComfyUI Manager +Incompatibile issue with it is now fixed + +Following this guide to install this extension + +https://github.com/ltdrdata/ComfyUI-Manager#how-to-use +### Command-line +#### Windows +Run install.bat + +For Window users, if you are having trouble with cupy, please run `install.bat` instead of `install-cupy.py` or `python install.py`. +#### Linux +Open your shell app and start venv if it is used for ComfyUI. 
Run: +``` +python install.py +``` +## Support for non-CUDA device (experimental) +If you don't have a NVidia card, you can try `taichi` ops backend powered by [Taichi Lang](https://www.taichi-lang.org/) + +On Windows, you can install it by running `install.bat` or `pip install taichi` on Linux + +Then change value of `ops_backend` from `cupy` to `taichi` in `config.yaml` + +If `NotImplementedError` appears, a VFI node in the workflow isn't supported by taichi + +## Usage +All VFI nodes can be accessed in **category** `ComfyUI-Frame-Interpolation/VFI` if the installation is successful and require a `IMAGE` containing frames (at least 2, or at least 4 for STMF-Net/FLAVR). + +Regarding STMFNet and FLAVR, if you only have two or three frames, you should use: Load Images -> Other VFI node (FILM is recommended in this case) with `multiplier=4` -> STMFNet VFI/FLAVR VFI + +`clear_cache_after_n_frames` is used to avoid out-of-memory. Decreasing it makes the chance lower but also increases processing time. + +It is recommended to use LoadImages (LoadImagesFromDirectory) from [ComfyUI-Advanced-ControlNet](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/) and [ComfyUI-VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) along side with this extension. + +## Example +### Simple workflow +Workflow metadata isn't embeded +Download these two images [anime0.png](./demo_frames/anime0.png) and [anime1.png](./demo_frames/anime0.png) and put them into a folder like `E:\test` in this image. +![](./example.png) + +### Complex workflow +It's used in AnimationDiff (can load workflow metadata) +![](All_in_one_v1_3.png) + +## Credit +Big thanks for styler00dollar for making [VSGAN-tensorrt-docker](https://github.com/styler00dollar/VSGAN-tensorrt-docker). About 99% the code of this repo comes from it. 
+ +Citation for each VFI node: +### GMFSS Fortuna +The All-In-One GMFSS: Dedicated for Anime Video Frame Interpolation + +https://github.com/98mxr/GMFSS_Fortuna + +### IFRNet +```bibtex +@InProceedings{Kong_2022_CVPR, + author = {Kong, Lingtong and Jiang, Boyuan and Luo, Donghao and Chu, Wenqing and Huang, Xiaoming and Tai, Ying and Wang, Chengjie and Yang, Jie}, + title = {IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2022} +} +``` + +### IFUnet +RIFE with IFUNet, FusionNet and RefineNet + +https://github.com/98mxr/IFUNet +### M2M +```bibtex +@InProceedings{hu2022m2m, + title={Many-to-many Splatting for Efficient Video Frame Interpolation}, + author={Hu, Ping and Niklaus, Simon and Sclaroff, Stan and Saenko, Kate}, + journal={CVPR}, + year={2022} + } +``` + +### RIFE +```bibtex +@inproceedings{huang2022rife, + title={Real-Time Intermediate Flow Estimation for Video Frame Interpolation}, + author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang}, + booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, + year={2022} +} +``` + +### FILM +[Frame interpolation in PyTorch](https://github.com/dajes/frame-interpolation-pytorch) + +```bibtex +@inproceedings{reda2022film, + title = {FILM: Frame Interpolation for Large Motion}, + author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless}, + booktitle = {European Conference on Computer Vision (ECCV)}, + year = {2022} +} +``` + +```bibtex +@misc{film-tf, + title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"}, + author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = 
{\url{https://github.com/google-research/frame-interpolation}} +} +``` + +### Sepconv +```bibtex +[1] @inproceedings{Niklaus_WACV_2021, + author = {Simon Niklaus and Long Mai and Oliver Wang}, + title = {Revisiting Adaptive Convolutions for Video Frame Interpolation}, + booktitle = {IEEE Winter Conference on Applications of Computer Vision}, + year = {2021} + } +``` + +```bibtex +[2] @inproceedings{Niklaus_ICCV_2017, + author = {Simon Niklaus and Long Mai and Feng Liu}, + title = {Video Frame Interpolation via Adaptive Separable Convolution}, + booktitle = {IEEE International Conference on Computer Vision}, + year = {2017} + } +``` + +```bibtex +[3] @inproceedings{Niklaus_CVPR_2017, + author = {Simon Niklaus and Long Mai and Feng Liu}, + title = {Video Frame Interpolation via Adaptive Convolution}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, + year = {2017} + } +``` + +### AMT + ```bibtex + @inproceedings{licvpr23amt, + title={AMT: All-Pairs Multi-Field Transforms for Efficient Frame Interpolation}, + author={Li, Zhen and Zhu, Zuo-Liang and Han, Ling-Hao and Hou, Qibin and Guo, Chun-Le and Cheng, Ming-Ming}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year={2023} + } + ``` + +### ST-MFNet +```bibtex +@InProceedings{Danier_2022_CVPR, + author = {Danier, Duolikun and Zhang, Fan and Bull, David}, + title = {ST-MFNet: A Spatio-Temporal Multi-Flow Network for Frame Interpolation}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2022}, + pages = {3521-3531} +} +``` + +### FLAVR +```bibtex +@article{kalluri2021flavr, + title={FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation}, + author={Kalluri, Tarun and Pathak, Deepak and Chandraker, Manmohan and Tran, Du}, + booktitle={arxiv}, + year={2021} +} +``` diff --git a/__init__.py b/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..b362e2c1dbec391c923bda205c46da3a235d63ee --- /dev/null +++ b/__init__.py @@ -0,0 +1,42 @@ +import os +import sys +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +from .other_nodes import Gradually_More_Denoise_KSampler + +#Some models are commented out because the code is not completed +#from vfi_models.eisai import EISAI_VFI +from vfi_models.gmfss_fortuna import GMFSS_Fortuna_VFI +from vfi_models.ifrnet import IFRNet_VFI +from vfi_models.ifunet import IFUnet_VFI +from vfi_models.m2m import M2M_VFI +from vfi_models.rife import RIFE_VFI +from vfi_models.sepconv import SepconvVFI +from vfi_models.amt import AMT_VFI +from vfi_models.film import FILM_VFI +from vfi_models.stmfnet import STMFNet_VFI +from vfi_models.flavr import FLAVR_VFI +from vfi_models.cain import CAIN_VFI +from vfi_utils import MakeInterpolationStateList, FloatToInt + +NODE_CLASS_MAPPINGS = { + "KSampler Gradually Adding More Denoise (efficient)": Gradually_More_Denoise_KSampler, +# "EISAI VFI": EISAI_VFI, + "GMFSS Fortuna VFI": GMFSS_Fortuna_VFI, + "IFRNet VFI": IFRNet_VFI, + "IFUnet VFI": IFUnet_VFI, + "M2M VFI": M2M_VFI, + "RIFE VFI": RIFE_VFI, + "Sepconv VFI": SepconvVFI, + "AMT VFI": AMT_VFI, + "FILM VFI": FILM_VFI, + "Make Interpolation State List": MakeInterpolationStateList, + "STMFNet VFI": STMFNet_VFI, + "FLAVR VFI": FLAVR_VFI, + "CAIN VFI": CAIN_VFI, + "VFI FloatToInt": FloatToInt +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RIFE VFI": "RIFE VFI (recommend rife47 and rife49)" +} \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b99d4a76ab08cc1157e349745770dd1f37c9ffca --- /dev/null +++ b/config.yaml @@ -0,0 +1,3 @@ +#Plz don't delete this file, just edit it when neccessary. 
+ckpts_path: "./ckpts" +ops_backend: "cupy" #Either "taichi" or "cupy" \ No newline at end of file diff --git a/demo_frames/anime0.png b/demo_frames/anime0.png new file mode 100644 index 0000000000000000000000000000000000000000..d42b064bac5f4973f8e0119462fcfe11a53dedea --- /dev/null +++ b/demo_frames/anime0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:734039ac77a89cf8d52fed8989bd4335392a1d246b099979d1c58a145c629ace +size 340936 diff --git a/demo_frames/anime1.png b/demo_frames/anime1.png new file mode 100644 index 0000000000000000000000000000000000000000..69ffde9dced457a5781aa5ce82e3127bcfc15952 --- /dev/null +++ b/demo_frames/anime1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd24bdafe9a0cfc82eada33c40962e9977ed5b6711ae6d89bf28b07cbded712a +size 329357 diff --git a/demo_frames/bocchi0.jpg b/demo_frames/bocchi0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b8c02eb11b949797b3785e10ac5a1c51e63b8ef --- /dev/null +++ b/demo_frames/bocchi0.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c607fae213b83d4c15fa10d6939b612f7f2242afd0b8716b203ace51774f6718 +size 129774 diff --git a/demo_frames/bocchi1.jpg b/demo_frames/bocchi1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..432a8d455f504cbfa82ae632d4ab5e8c9605b04d --- /dev/null +++ b/demo_frames/bocchi1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f03f067142490d4353d3f5af8bd51b0f9f4bdd3d2094dde6a28f4fec062fbe16 +size 139543 diff --git a/demo_frames/real0.png b/demo_frames/real0.png new file mode 100644 index 0000000000000000000000000000000000000000..863cfb642782237734eb87916e1f24b25d2bc28b --- /dev/null +++ b/demo_frames/real0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4792023ccf17c8231c6eb5ee40de528d515e2f8c419b3949985411a122a4de4f +size 1230238 diff --git a/demo_frames/real1.png b/demo_frames/real1.png new file mode 
100644 index 0000000000000000000000000000000000000000..beabe73d95c2114baa3b2d4a4d88ef0dbfd6adb3 --- /dev/null +++ b/demo_frames/real1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37c8e6ec527c81895e5a66ea49cdd18b85045f9fed6fdfb75b45f438649235bf +size 1213845 diff --git a/demo_frames/rick/00003.png b/demo_frames/rick/00003.png new file mode 100644 index 0000000000000000000000000000000000000000..dfa698b75547a7da2e5f4a457b60a02f922fa729 --- /dev/null +++ b/demo_frames/rick/00003.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98f5dba7557ba55d13f494425d340ca84af8b56e35f929fab5df39e54015e265 +size 456245 diff --git a/demo_frames/rick/00004.png b/demo_frames/rick/00004.png new file mode 100644 index 0000000000000000000000000000000000000000..bf50362fecfdb40e67625d7123f823b01c1b2937 --- /dev/null +++ b/demo_frames/rick/00004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61bcf7933b192d84870b80910f7f983371c642d5c7100b34e8cc6dbd01cba7e6 +size 354894 diff --git a/demo_frames/rick/00005.png b/demo_frames/rick/00005.png new file mode 100644 index 0000000000000000000000000000000000000000..f886e35f4c3705b321336f8d2edad8bbd3be0243 --- /dev/null +++ b/demo_frames/rick/00005.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f795d06e93ad4f9c19db578e9378a48b6008cc3df81fb2cd9fbbd5ed91bd8cf7 +size 357224 diff --git a/demo_frames/violet0.png b/demo_frames/violet0.png new file mode 100644 index 0000000000000000000000000000000000000000..b1b60f7c60ae594fc5087ef77172f42e414eeaec --- /dev/null +++ b/demo_frames/violet0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6844899b551801ee22d4f57993ab66fd4b6fbe00eab916d6b987bdf083eadfe +size 888543 diff --git a/demo_frames/violet1.png b/demo_frames/violet1.png new file mode 100644 index 0000000000000000000000000000000000000000..9b8f96923b006f878ea8f6ab7aa03abed402ce57 --- /dev/null +++ 
b/demo_frames/violet1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66ee9a9a486f57eb80ba5d41140eaca4ca46f0d946a3cff93eabb0ee3b1e29d0 +size 951287 diff --git a/example.png b/example.png new file mode 100644 index 0000000000000000000000000000000000000000..2c949e9406acc3c59d07c95ab96ae8b0b3d4834d --- /dev/null +++ b/example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9a5e9310bfba63b109990b326402d42477688682858bc64f146ef546e6662ead +size 182103 diff --git a/install-taichi.bat b/install-taichi.bat new file mode 100644 index 0000000000000000000000000000000000000000..d601f71c20e8ea2d768ee710a277666f2bd68643 --- /dev/null +++ b/install-taichi.bat @@ -0,0 +1,11 @@ +@echo off +echo Installing Taichi lang backend... + +if exist "%python_exec%" ( + %python_exec% -s -m pip install taichi +) else ( + echo Installing with system Python + pip install taichi +) + +pause \ No newline at end of file diff --git a/install.bat b/install.bat new file mode 100644 index 0000000000000000000000000000000000000000..84e0f7eb536a2c9bfc3313a377086b9cf4aa508f --- /dev/null +++ b/install.bat @@ -0,0 +1,16 @@ +@echo off + +set "requirements_txt=%~dp0\requirements-no-cupy.txt" +set "python_exec=..\..\..\python_embeded\python.exe" + +echo Installing ComfyUI Frame Interpolation.. 
+ +if exist "%python_exec%" ( + echo Installing with ComfyUI Portable + %python_exec% -s install.py +) else ( + echo Installing with system Python + python install.py +) + +pause \ No newline at end of file diff --git a/install.py b/install.py new file mode 100644 index 0000000000000000000000000000000000000000..ecbe35da8bd39829134a6ee7759251e0a203d25b --- /dev/null +++ b/install.py @@ -0,0 +1,59 @@ +import os +from pathlib import Path +import sys +import platform + +def get_cuda_ver_from_dir(cuda_home): + nvrtc = filter(lambda lib_file: "nvrtc-builtins" in lib_file, os.listdir(cuda_home)) + nvrtc = list(nvrtc) + if len(nvrtc) == 0: + return + nvrtc = nvrtc[0] + if ('102' in nvrtc) or ('10.2' in nvrtc): + return '102' + if '110' in nvrtc or ('11.0' in nvrtc): + return '110' + if '111' in nvrtc or ('11.1' in nvrtc): + return '111' + if '11' in nvrtc: + return '11x' + if '12' in nvrtc: + return '12x' + +s_param = '-s' if "python_embeded" in sys.executable else '' + +def get_cuda_home_path(): + if "CUDA_HOME" in os.environ: + return os.environ["CUDA_HOME"] + import torch + torch_lib_path = Path(torch.__file__).parent / "lib" + torch_lib_path = str(torch_lib_path.resolve()) + if os.path.exists(torch_lib_path): + nvrtc = filter(lambda lib_file: "nvrtc-builtins" in lib_file, os.listdir(torch_lib_path)) + nvrtc = list(nvrtc) + return torch_lib_path if len(nvrtc) > 0 else None + +def install_cupy(): + cuda_home = get_cuda_home_path() + try: + if cuda_home is not None: + os.environ["CUDA_HOME"] = cuda_home + os.environ["CUDA_PATH"] = cuda_home + import cupy + print("CuPy is already installed.") + except: + print("Uninstall cupy if existed...") + os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x') + print("Installing cupy...") + cuda_ver = get_cuda_ver_from_dir(cuda_home) + cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel" + os.system(f'"{sys.executable}" 
{s_param} -m pip install {cupy_package}') + +with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r') as f: + for package in f.readlines(): + package = package.strip() + print(f"Installing {package}...") + os.system(f'"{sys.executable}" {s_param} -m pip install {package}') + +print("Checking cupy...") +install_cupy() \ No newline at end of file diff --git a/interpolation_schedule.png b/interpolation_schedule.png new file mode 100644 index 0000000000000000000000000000000000000000..1c308d6c811936b5f3383ebc8f9243707ed95be6 --- /dev/null +++ b/interpolation_schedule.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6999ee4a5fd6222b7b05adb8afa4994053bfe8e0f9c6b5cccf25992638b586c +size 378153 diff --git a/other_nodes.py b/other_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..75b76fe109424a389eb4802dc4aa8c59c53329ec --- /dev/null +++ b/other_nodes.py @@ -0,0 +1,88 @@ +import latent_preview +import comfy +import einops +import torch + +def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + device = comfy.model_management.get_torch_device() + latent_image = latent["samples"] + + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + preview_format = "JPEG" + if preview_format not in ["JPEG", "PNG"]: + preview_format = "JPEG" + + previewer = latent_preview.get_previewer(device, model.model.latent_format) + + pbar = comfy.utils.ProgressBar(steps) + def callback(step, x0, x, total_steps): + preview_bytes = None + if previewer: + preview_bytes = 
previewer.decode_latent_to_preview_image(preview_format, x0) + pbar.update_absolute(step + 1, total_steps, preview_bytes) + + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) + out = latent.copy() + out["samples"] = samples + return (out, ) + +class Gradually_More_Denoise_KSampler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent_image": ("LATENT", ), + + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + + "start_denoise": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "denoise_increment": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.1}), + "denoise_increment_steps": ("INT", {"default": 20, "min": 1, "max": 10000}) + }, + "optional": { "optional_vae": ("VAE",) } + } + + RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT", "VAE", ) + RETURN_NAMES = ("MODEL", "CONDITIONING+", "CONDITIONING-", "LATENT", "VAE", ) + OUTPUT_NODE = True + FUNCTION = "sample" + CATEGORY = "ComfyUI-Frame-Interpolation/others" + + def sample(self, model, positive, negative, latent_image, optional_vae, + seed, steps, cfg, sampler_name, scheduler,start_denoise, denoise_increment, denoise_increment_steps): + if start_denoise + denoise_increment * denoise_increment_steps > 1.0: + raise Exception(f"Max denoise strength can't over 1.0 (start_denoise={start_denoise}, denoise_increment={denoise_increment}, 
denoise_increment_steps={denoise_increment_steps}") + + copied_latent = latent_image.copy() + out_samples = [] + + for latent_sample in copied_latent["samples"]: + latent = {"samples": einops.rearrange(latent_sample, "c h w -> 1 c h w")} + #Latent's shape is NCHW + gradually_denoising_samples = [ + common_ksampler( + model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=start_denoise + denoise_increment * i + )[0]["samples"] + for i in range(denoise_increment_steps) + ] + out_samples.extend(gradually_denoising_samples) + + copied_latent["samples"] = torch.cat(out_samples, dim=0) + return (model, positive, negative, copied_latent, optional_vae) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..9878e796c78d0231e7ecb21c7c361a6038bd067f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "comfyui-frame-interpolation" +description = "A custom node suite for Video Frame Interpolation in ComfyUI" +version = "1.0.7" +license = { file = "LICENSE" } + +[project.urls] +Repository = "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation" + +[tool.comfy] +PublisherId = "fannovel16" +DisplayName = "ComfyUI-Frame-Interpolation" +Icon = "" diff --git a/requirements-no-cupy.txt b/requirements-no-cupy.txt new file mode 100644 index 0000000000000000000000000000000000000000..c490ca515d47b6623861b051fd822d24bc2ecd7e --- /dev/null +++ b/requirements-no-cupy.txt @@ -0,0 +1,9 @@ +torch +numpy +einops +opencv-contrib-python +kornia +scipy +Pillow +torchvision +tqdm \ No newline at end of file diff --git a/requirements-with-cupy.txt b/requirements-with-cupy.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdfeb47253006b2897a2dd3ff1e7c91ccb41e1b2 --- /dev/null +++ b/requirements-with-cupy.txt @@ -0,0 +1,10 @@ +torch +numpy +einops +opencv-contrib-python +kornia +scipy +Pillow +torchvision +tqdm +cupy-wheel \ No newline at 
end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..3c06ae20dc974344ed091dce2e410666a9fdd56f --- /dev/null +++ b/test.py @@ -0,0 +1,38 @@ +import os +import sys +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) + +import shutil +import torch +import torch.nn.functional as F +import PIL +import torchvision.transforms.functional as transform +from vfi_utils import load_file_from_github_release +from vfi_models import gmfss_fortuna, ifrnet, ifunet, m2m, rife, sepconv, amt, xvfi, cain, flavr +import numpy as np + +frame_0 = torch.from_numpy(np.array(PIL.Image.open("demo_frames/anime0.png").convert("RGB")).astype(np.float32) / 255.0).unsqueeze(0) +frame_1 = torch.from_numpy(np.array(PIL.Image.open("demo_frames/anime1.png").convert("RGB")).astype(np.float32) / 255.0).unsqueeze(0) + + +if os.path.exists("test_result"): + shutil.rmtree("test_result") + +vfi_node_class = gmfss_fortuna.GMFSS_Fortuna_VFI() +for i, ckpt_name in enumerate(vfi_node_class.INPUT_TYPES()["required"]["ckpt_name"][0][:2]): + result = vfi_node_class.vfi(ckpt_name, torch.cat([ + frame_0, + frame_1, + frame_0, + frame_1 + ], dim=0).cuda(), multipler=4, batch_size=2)[0] + print(result.shape) + print(f"Generated {result.size(0)} frames") + frames = [PIL.Image.fromarray(np.clip((frame * 255).numpy(), 0, 255).astype(np.uint8)) for frame in result] + print(result[0].shape) + os.makedirs(f"test_result/video{i}", exist_ok=True) + for j, frame in enumerate(frames): + frame.save(f"test_result/video{i}/{j}.jpg") + frames[0].save(f"test_result/video{i}.gif", save_all=True, append_images=frames[1:], optimize=True, duration=1/3, loop=0) + os.startfile(f"test_result{os.path.sep}video{i}.gif") +#torchvision.io.video.write_video("test.mp4", einops.rearrange(result, "n c h w -> n h w c").cpu(), fps=1) \ No newline at end of file diff --git a/test_vfi_schedule.gif b/test_vfi_schedule.gif new file mode 100644 index 
0000000000000000000000000000000000000000..52fabecc38ae3050e723e8b4eb213fd41ad36195 --- /dev/null +++ b/test_vfi_schedule.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:931fcd4c2cc84b457cbc1b1c3b8745a2bf292ff7dc43d4f733a2c510ad90353d +size 8409697 diff --git a/vfi_models/amt/__init__.py b/vfi_models/amt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20740c2a4d9d669a625db6f729ef11b18c449e9a --- /dev/null +++ b/vfi_models/amt/__init__.py @@ -0,0 +1,87 @@ +import pathlib +import torch +from torch.utils.data import DataLoader +import pathlib +from vfi_utils import load_file_from_direct_url, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList +import typing +from comfy.model_management import get_torch_device +from .amt_arch import AMT_S, AMT_L, AMT_G, InputPadder + +#https://github.com/MCG-NKU/AMT/tree/main/cfgs +CKPT_CONFIGS = { + "amt-s.pth": { + "network": AMT_S, + "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 3 } + }, + "amt-l.pth": { + "network": AMT_L, + "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 5 } + }, + "amt-g.pth": { + "network": AMT_G, + "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 5 } + }, + "gopro_amt-s.pth": { + "network": AMT_S, + "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 3 } + } +} + + +MODEL_TYPE = pathlib.Path(__file__).parent.name + +class AMT_VFI: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (list(CKPT_CONFIGS.keys()), ), + "frames": ("IMAGE", ), + "clear_cache_after_n_frames": ("INT", {"default": 1, "min": 1, "max": 100}), + "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}) + }, + "optional": { + "optional_interpolation_states": ("INTERPOLATION_STATES", ) + } + } + + RETURN_TYPES = ("IMAGE", ) + FUNCTION = "vfi" + CATEGORY = "ComfyUI-Frame-Interpolation/VFI" + + def vfi( + self, + ckpt_name: typing.AnyStr, + frames: torch.Tensor, + 
clear_cache_after_n_frames: typing.SupportsInt = 1, + multiplier: typing.SupportsInt = 2, + optional_interpolation_states: InterpolationStateList = None, + **kwargs + ): + model_path = load_file_from_direct_url(MODEL_TYPE, f"https://huggingface.co/lalala125/AMT/resolve/main/{ckpt_name}") + ckpt_config = CKPT_CONFIGS[ckpt_name] + + interpolation_model = ckpt_config["network"](**ckpt_config["params"]) + interpolation_model.load_state_dict(torch.load(model_path)["state_dict"]) + interpolation_model.eval().to(get_torch_device()) + + frames = preprocess_frames(frames) + padder = InputPadder(frames.shape, 16) + frames = padder.pad(frames) + + def return_middle_frame(frame_0, frame_1, timestep, model): + return model( + frame_0, + frame_1, + embt=torch.FloatTensor([timestep] * frame_0.shape[0]).view(frame_0.shape[0], 1, 1, 1).to(get_torch_device()), + scale_factor=1.0, + eval=True + )["imgt_pred"] + + args = [interpolation_model] + out = generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args, + interpolation_states=optional_interpolation_states, dtype=torch.float32) + out = padder.unpad(out) + out = postprocess_frames(out) + return (out,) + diff --git a/vfi_models/amt/amt_arch.py b/vfi_models/amt/amt_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..448c4d56193a833d3a1c89ea3e8475258f11958a --- /dev/null +++ b/vfi_models/amt/amt_arch.py @@ -0,0 +1,1590 @@ +""" +https://github.com/MCG-NKU/AMT/blob/main/utils/dist_utils.py +https://github.com/MCG-NKU/AMT/blob/main/utils/flow_utils.py +https://github.com/MCG-NKU/AMT/blob/main/utils/utils.py +https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/feat_enc.py +https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/ifrnet.py +https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/multi_flow.py +https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/raft.py +https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-S.py 
# Consolidated utilities from the upstream AMT repository:
#   https://github.com/MCG-NKU/AMT (utils/, networks/blocks/, networks/AMT-*.py)
# imageio-based readImage/writeImage were removed: ComfyUI hands this model
# image tensors directly, so file-based image I/O is unnecessary here.

import re
import sys
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pillow is optional here (no image files are read/written in this module);
# only enable truncated-image tolerance when it is actually installed.
try:
    from PIL import ImageFile
    ImageFile.LOAD_TRUNCATED_IMAGES = True
except ImportError:
    pass


def warp(img, flow):
    """Backward-warp `img` by the dense pixel-space `flow`, sampling bilinearly.

    flow is [B, 2, H, W] with channel 0 = x displacement, channel 1 = y.
    Border padding, align_corners=True (matches upstream AMT).
    """
    B, _, H, W = flow.shape
    # Base sampling grid in normalized [-1, 1] coordinates.
    xs = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
    ys = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
    base_grid = torch.cat([xs, ys], 1).to(img)
    # Convert the pixel-space flow to the same normalized coordinate system.
    norm_flow = torch.cat(
        [flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1
    )
    sample_grid = (base_grid + norm_flow).permute(0, 2, 3, 1)
    return F.grid_sample(
        input=img, grid=sample_grid, mode='bilinear', padding_mode='border', align_corners=True
    )


def make_colorwheel():
    """Build the 55x3 Middlebury optical-flow color wheel.

    Baker et al. "A Database and Evaluation Methodology for Optical Flow"
    (ICCV 2007); follows Daniel Scharstein's C++ and Deqing Sun's Matlab code.

    Returns:
        np.ndarray: [55, 3] color wheel.
    """
    # (segment length, channel held at 255, ramped channel, ramp direction)
    segments = [
        (15, 0, 1, +1),  # red     -> yellow
        (6,  1, 0, -1),  # yellow  -> green
        (4,  1, 2, +1),  # green   -> cyan
        (11, 2, 1, -1),  # cyan    -> blue
        (13, 2, 0, +1),  # blue    -> magenta
        (6,  0, 2, -1),  # magenta -> red
    ]
    wheel = np.zeros((sum(seg[0] for seg in segments), 3))
    row = 0
    for length, hold_ch, ramp_ch, direction in segments:
        ramp = np.floor(255 * np.arange(0, length) / length)
        wheel[row:row + length, hold_ch] = 255
        wheel[row:row + length, ramp_ch] = ramp if direction > 0 else 255 - ramp
        row += length
    return wheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """Map flow components u, v ([H, W] each, pre-normalized) onto the color wheel.

    Returns:
        np.ndarray: uint8 visualization of shape [H, W, 3].
    """
    image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # [55, 3]
    ncols = wheel.shape[0]
    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi
    # Fractional index into the wheel per pixel; interpolate adjacent colors.
    fk = (angle + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    frac = fk - k0
    for ch in range(wheel.shape[1]):
        column = wheel[:, ch]
        col0 = column[k0] / 255.0
        col1 = column[k1] / 255.0
        col = (1 - frac) * col0 + frac * col1
        in_range = magnitude <= 1
        col[in_range] = 1 - magnitude[in_range] * (1 - col[in_range])
        col[~in_range] = col[~in_range] * 0.75  # desaturate out-of-range flow
        # 2-ch flips the channel order when BGR output is requested.
        image[:, :, 2 - ch if convert_to_bgr else ch] = np.floor(255 * col)
    return image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """Visualize a [H, W, 2] flow field as a [H, W, 3] uint8 color image."""
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    # Normalize by the largest magnitude so colors span the full wheel.
    max_mag = np.max(np.sqrt(np.square(u) + np.square(v)))
    eps = 1e-5
    return flow_uv_to_colors(u / (max_mag + eps), v / (max_mag + eps), convert_to_bgr)


class AverageMeter():
    """Tracks the latest value plus a running sum/count/average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class AverageMeterGroups:
    """A named collection of AverageMeters, created lazily on first update."""

    def __init__(self) -> None:
        self.meter_dict = dict()

    def update(self, dict, n=1):
        # NOTE: parameter name `dict` shadows the builtin; kept for interface compatibility.
        for name, val in dict.items():
            if self.meter_dict.get(name) is None:
                self.meter_dict[name] = AverageMeter()
            self.meter_dict[name].update(val, n)

    def reset(self, name=None):
        """Reset one meter by name, or all meters when name is None."""
        if name is None:
            for v in self.meter_dict.values():
                v.reset()
        else:
            meter = self.meter_dict.get(name)
            if meter is not None:
                meter.reset()

    def avg(self, name):
        """Average of the named meter, or None if it does not exist."""
        meter = self.meter_dict.get(name)
        if meter is not None:
            return meter.avg


class InputPadder:
    """Pads images such that spatial dimensions are divisible by `divisor`."""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # [left, right, top, bottom] as expected by F.pad.
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, input_tensor):
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        return self._unpad(input_tensor)

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        crop = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., crop[0]:crop[1], crop[2]:crop[3]]


def img2tensor(img):
    """HWC uint8 ndarray -> 1CHW float tensor in [0, 1]; drops any alpha channel."""
    if img.shape[-1] > 3:
        img = img[:, :, :3]
    return torch.tensor(img).permute(2, 0, 1).unsqueeze(0) / 255.0


def tensor2img(img_t):
    """1CHW float tensor in [0, 1] -> HWC uint8 ndarray."""
    return (img_t * 255.).detach(
        ).squeeze(0).permute(1, 2, 0).cpu().numpy(
        ).clip(0, 255).astype(np.uint8)


def seed_all(seed):
    """Seed Python, NumPy and torch (all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def readPFM(file):
    """Read a PFM file; returns (data, scale) with data as a top-down float array."""
    with open(file, 'rb') as f:
        header = f.readline().rstrip().decode("ascii")
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')

        scale = float(f.readline().decode("ascii").rstrip())
        if scale < 0:  # negative scale flags little-endian payload
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    # PFM stores rows bottom-up; flip to top-down.
    data = np.flipud(np.reshape(data, shape))
    return data, scale


def writePFM(file, image, scale=1):
    """Write a float32 image ([H,W], [H,W,1] or [H,W,3]) as a PFM file.

    Fixes the upstream precedence bug where `'PF\\n' if color else 'Pf\\n'.encode()`
    encoded only the grayscale header, so writing a color image crashed with a
    str-to-binary-file TypeError. Also closes the file deterministically.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    flipped = np.flipud(image)  # PFM expects bottom-up row order

    with open(file, 'wb') as f:
        f.write(('PF\n' if color else 'Pf\n').encode())
        f.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode())

        endian = image.dtype.byteorder
        if endian == '<' or endian == '=' and sys.byteorder == 'little':
            scale = -scale  # negative scale marks little-endian data
        f.write(('%f\n' % scale).encode())

        flipped.tofile(f)


def readFlow(name):
    """Read a Middlebury .flo (or .pfm) optical-flow file into float32 [H, W, 2]."""
    if name.endswith('.pfm') or name.endswith('.PFM'):
        return readPFM(name)[0][:, :, 0:2]

    with open(name, 'rb') as f:
        if f.read(4).decode("utf-8") != 'PIEH':
            raise Exception('Flow file header does not contain PIEH')
        width = np.fromfile(f, np.int32, 1).squeeze()
        height = np.fromfile(f, np.int32, 1).squeeze()
        flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
    return flow.astype(np.float32)


def writeFlow(name, flow):
    """Write a float32 [H, W, 2] flow field in Middlebury .flo format."""
    with open(name, 'wb') as f:
        f.write('PIEH'.encode('utf-8'))
        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)


def readFloat(name):
    """Read a '.float' blob produced by writeFloat; returns an ndarray."""
    with open(name, 'rb') as f:
        if (f.readline().decode("utf-8")) != 'float\n':
            raise Exception('float file %s did not contain keyword' % name)

        dim = int(f.readline())

        dims = []
        count = 1
        for i in range(0, dim):
            d = int(f.readline())
            dims.append(d)
            count *= d

        dims = list(reversed(dims))
        data = np.fromfile(f, np.float32, count).reshape(dims)

    if dim > 2:
        # Stored channel-first; transpose back to (H, W, C).
        data = np.transpose(data, (2, 1, 0))
        data = np.transpose(data, (1, 0, 2))
    return data


def writeFloat(name, data):
    """Write an up-to-3D float array as a '.float' blob readable by readFloat.

    Fixes the upstream crash for 1-D input (the 3-axis transpose was applied
    to every non-2D array, including 1-D ones).
    """
    dim = len(data.shape)
    if dim > 3:
        raise Exception('bad float file dimension: %d' % dim)

    with open(name, 'wb') as f:
        f.write(('float\n').encode('ascii'))
        f.write(('%d\n' % dim).encode('ascii'))

        if dim == 1:
            f.write(('%d\n' % data.shape[0]).encode('ascii'))
        else:
            f.write(('%d\n' % data.shape[1]).encode('ascii'))
            f.write(('%d\n' % data.shape[0]).encode('ascii'))
            for i in range(2, dim):
                f.write(('%d\n' % data.shape[i]).encode('ascii'))

        data = data.astype(np.float32)
        if dim <= 2:
            data.tofile(f)
        else:
            np.transpose(data, (2, 0, 1)).tofile(f)


def check_dim_and_resize(tensor_list):
    """Resize all frame tensors to the first frame's spatial size when they differ."""
    shape_list = [t.shape[2:] for t in tensor_list]

    if len(set(shape_list)) > 1:
        desired_shape = shape_list[0]
        print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
        tensor_list = [
            torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')
            for t in tensor_list
        ]

    return tensor_list
def _norm_layer(kind, channels, num_groups):
    """Build the normalization layer selected by `kind` ('group'/'batch'/'instance'/'none')."""
    factories = {
        'group': lambda: nn.GroupNorm(num_groups=num_groups, num_channels=channels),
        'batch': lambda: nn.BatchNorm2d(channels),
        'instance': lambda: nn.InstanceNorm2d(channels),
        'none': lambda: nn.Sequential(),
    }
    return factories[kind]()


def _init_encoder_weights(module):
    """Kaiming init for convs; unit weight / zero bias for norm layers."""
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class BottleneckBlock(nn.Module):
    """Residual bottleneck (1x1 -> 3x3 -> 1x1) with selectable normalization."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        quarter = planes // 4
        self.conv1 = nn.Conv2d(in_planes, quarter, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(quarter, quarter, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(quarter, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        groups = planes // 8
        self.norm1 = _norm_layer(norm_fn, quarter, groups)
        self.norm2 = _norm_layer(norm_fn, quarter, groups)
        self.norm3 = _norm_layer(norm_fn, planes, groups)

        if stride == 1:
            self.downsample = None
        else:
            # norm4 is registered both directly and inside the Sequential, matching
            # the upstream state_dict key layout for checkpoint compatibility.
            self.norm4 = _norm_layer(norm_fn, planes, groups)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + y)


class ResidualBlock(nn.Module):
    """Two 3x3 convs with a residual shortcut and selectable normalization."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        groups = planes // 8
        self.norm1 = _norm_layer(norm_fn, planes, groups)
        self.norm2 = _norm_layer(norm_fn, planes, groups)

        if stride == 1:
            self.downsample = None
        else:
            # Same dual registration of norm3 as upstream (see BottleneckBlock).
            self.norm3 = _norm_layer(norm_fn, planes, groups)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + y)


class SmallEncoder(nn.Module):
    """3-stage bottleneck feature encoder (RAFT-style); output at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.norm1 = _norm_layer(norm_fn, 32, 8)

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        _init_encoder_weights(self)

    def _make_layer(self, dim, stride=1):
        blocks = (
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # Tuple/list input: run both images through as one batch, split after.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x


class BasicEncoder(nn.Module):
    """3-stage residual feature encoder; output at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.norm1 = _norm_layer(norm_fn, 64, 8)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(72, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # Output projection to the requested feature dimension.
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        _init_encoder_weights(self)

    def _make_layer(self, dim, stride=1):
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # Tuple/list input: run both images through as one batch, split after.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
class LargeEncoder(nn.Module):
    """4-stage residual feature encoder (extra stage at 1/8 scale); output at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(LargeEncoder, self).__init__()
        self.norm_fn = norm_fn

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(112, stride=2)
        self.layer3 = self._make_layer(160, stride=2)
        self.layer3_2 = self._make_layer(160, stride=1)

        # Output projection to the requested feature dimension.
        self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # Tuple/list input: run both images through as one batch, split after.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3_2(self.layer3(self.layer2(self.layer1(x))))
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x


def resize(x, scale_factor):
    """Bilinear rescale by `scale_factor` (align_corners=False)."""
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    """Conv2d followed by a channel-wise PReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.PReLU(out_channels)
    )
class ResBlock(nn.Module):
    """Residual block that routes the trailing `side_channels` channels through extra convs."""

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels

        def _conv_prelu(channels):
            return nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=bias),
                nn.PReLU(channels)
            )

        self.conv1 = _conv_prelu(in_channels)
        self.conv2 = _conv_prelu(side_channels)
        self.conv3 = _conv_prelu(in_channels)
        self.conv4 = _conv_prelu(side_channels)
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        # Refine only the trailing side channels, then mix them back in.
        main, side = out[:, :-self.side_channels, ...], out[:, -self.side_channels:, :, :]
        out = self.conv3(torch.cat([main, self.conv2(side)], 1))

        main, side = out[:, :-self.side_channels, ...], out[:, -self.side_channels:, :, :]
        out = self.conv5(torch.cat([main, self.conv4(side)], 1))

        return self.prelu(x + out)


class Encoder(nn.Module):
    """Image pyramid encoder: one stride-2 conv pair per level of `channels`."""

    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
        self.channels = channels
        prev_ch = 3
        for idx, ch in enumerate(channels, 1):
            # The 'large' variant uses a 7x7 stem at the first level.
            k = 7 if large and idx == 1 else 3
            p = 3 if k == 7 else 1
            self.register_module(
                f'pyramid{idx}',
                nn.Sequential(convrelu(prev_ch, ch, k, 2, p), convrelu(ch, ch, 3, 1, 1)),
            )
            prev_ch = ch

    def forward(self, in_x):
        features = []
        for idx in range(len(self.channels)):
            in_x = getattr(self, f'pyramid{idx + 1}')(in_x)
            features.append(in_x)
        return features


class InitDecoder(nn.Module):
    """Coarsest decoder: predicts initial bidirectional flows plus a feature map from f0, f1, t."""

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch * 2 + 1, in_ch * 2),
            ResBlock(in_ch * 2, skip_ch),
            nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True)
        )

    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        t_plane = embt.repeat(1, 1, h, w)  # broadcast scalar t over the spatial grid
        out = self.convblock(torch.cat([f0, f1, t_plane], 1))
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        return flow0, flow1, out[:, 4:, ...]


class IntermediateDecoder(nn.Module):
    """Mid-level decoder: refines incoming flows using backward-warped pyramid features."""

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch * 3 + 4, in_ch * 3),
            ResBlock(in_ch * 3, skip_ch),
            nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        warped0 = warp(f0, flow0_in)
        warped1 = warp(f1, flow1_in)
        out = self.convblock(torch.cat([ft_, warped0, warped1, flow0_in, flow1_in], 1))
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_next = out[:, 4:, ...]
        # Upsample incoming flows 2x and double magnitudes to match the new resolution.
        flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
        flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
        return flow0, flow1, ft_next


def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
    A parallel implementation of multiple flow field warping.
    comb_block: An nn.Sequential producing a learned correction from all warps.
    img shape: [b, c, h, w]
    flow shape: [b, 2*num_flows, h, w]
    mask (opt): fusion weights between the two warp directions.
    img_res (opt): if None, zero is added instead.
    mean (opt): if None, zero is added instead.
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the flow hypotheses into the batch dimension so warp() runs once.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    mask = mask.reshape(b, num_flows, 1, h, w
                        ).reshape(-1, 1, h, w) if mask is not None else None
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    warp0 = warp(img0, flow0)
    warp1 = warp(img1, flow1)
    candidates = mask * warp0 + (1 - mask) * warp1 + mean + img_res
    candidates = candidates.reshape(b, num_flows, 3, h, w)
    # Average of the candidates plus a learned correction from their concatenation.
    return candidates.mean(1) + comb_block(candidates.view(b, -1, h, w))


class MultiFlowDecoder(nn.Module):
    """Finest decoder: predicts `num_flows` flow pairs plus fusion masks and image residuals."""

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(MultiFlowDecoder, self).__init__()
        self.num_flows = num_flows
        self.convblock = nn.Sequential(
            convrelu(in_ch * 3 + 4, in_ch * 3),
            ResBlock(in_ch * 3, skip_ch),
            nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0, flow1):
        n = self.num_flows
        warped0 = warp(f0, flow0)
        warped1 = warp(f1, flow1)
        out = self.convblock(torch.cat([ft_, warped0, warped1, flow0, flow1], 1))
        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2 * n, 2 * n, n, 3 * n], 1)
        mask = torch.sigmoid(mask)

        # Each of the n flow hypotheses starts from the same upsampled coarse flow.
        flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0).repeat(1, n, 1, 1)
        flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0).repeat(1, n, 1, 1)

        return flow0, flow1, mask, img_res
def resize(x, scale_factor):
    """Bilinear rescale by `scale_factor` (align_corners=False).

    NOTE: redefinition kept to mirror the upstream file concatenation.
    """
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def bilinear_sampler(img, coords, mask=False):
    """grid_sample wrapper that takes pixel coordinates ([..., (x, y)]) instead of normalized ones."""
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates to the [-1, 1] range grid_sample expects.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if mask:
        valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return sampled, valid.float()
    return sampled


def coords_grid(batch, ht, wd, device):
    """Pixel-coordinate grid of shape [batch, 2, ht, wd], channel order (x, y)."""
    ys, xs = torch.meshgrid(torch.arange(ht, device=device),
                            torch.arange(wd, device=device),
                            indexing='ij')
    grid = torch.stack([xs, ys], dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)


class SmallUpdateBlock(nn.Module):
    """Single-pass update: refines the feature map and both flows from correlation lookups."""

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2  # lookup channels per direction
        self.scale_factor = scale_factor

        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(corr_dim + flow_dim, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        # Optionally run the update at a coarser scale and upsample the deltas.
        if self.scale_factor is not None:
            net = resize(net, 1 / self.scale_factor)
        corr_feat = self.lrelu(self.convc1(corr))
        flow_feat = self.lrelu(self.convf2(self.lrelu(self.convf1(flow))))
        fused = self.lrelu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        out = self.gru(torch.cat([fused, flow, net], dim=1))
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            # Scale flow magnitudes along with the spatial upsampling.
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)

        return delta_net, delta_flow


class BasicUpdateBlock(nn.Module):
    """Larger update block with an extra correlation conv and optional multi-flow output."""

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        cor_planes = corr_levels * (2 * radius + 1) ** 2

        self.scale_factor = scale_factor
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim * 2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim * 2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim + corr_dim2, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim + 4 + cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4 * out_num, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        # Optionally run the update at a coarser scale and upsample the deltas.
        if self.scale_factor is not None:
            net = resize(net, 1 / self.scale_factor)
        corr_feat = self.lrelu(self.convc2(self.lrelu(self.convc1(corr))))
        flow_feat = self.lrelu(self.convf2(self.lrelu(self.convf1(flow))))
        fused = self.lrelu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))

        out = self.gru(torch.cat([fused, flow, net], dim=1))
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
        return delta_net, delta_flow


class BidirCorrBlock:
    """All-pairs correlation pyramid queried in both directions (0->1 and 1->0)."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)  # transposed volume for the reverse lookup

        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch * h2 * w2, dim, h1, w1)

        self.corr_pyramid = [corr]
        self.corr_pyramid_T = [corr_T]
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            # (2r+1)x(2r+1) window of offsets around each query location.
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords0.device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)

            centroid_0 = coords0.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i
            centroid_1 = coords1.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i

            sampled = bilinear_sampler(corr, centroid_0 + delta_lvl)
            sampled_T = bilinear_sampler(corr_T, centroid_1 + delta_lvl)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))
            out_pyramid_T.append(sampled_T.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs dot-product correlation, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
[{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + + centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + + + + + + + + + + +class AMT_S(nn.Module): + def __init__(self, + corr_radius=3, + corr_lvls=4, + num_flows=3, + channels=[20, 32, 44, 56], + skip_channels=20): + super(AMT_S, self).__init__() + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + self.channels = channels + self.skip_channels = skip_channels + + self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) 
class AMT_S(nn.Module):
    """AMT (small variant): coarse-to-fine bidirectional flow + frame synthesis.

    Predicts the intermediate frame between img0 and img1 at time embt by
    decoding bidirectional flows at 1/16 -> 1/2 scale, refining each level
    with a correlation lookup, then fusing ``num_flows`` flow hypotheses.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=3,
                 channels=[20, 32, 44, 56],
                 skip_channels=20):
        super(AMT_S, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows
        self.channels = channels
        self.skip_channels = skip_channels

        # Feature encoder feeds the correlation volume; Encoder feeds the decoders.
        self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels)

        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Update blocks refine (features, flow) from correlation lookups;
        # scale_factor re-samples between decoder and lookup resolutions.
        self.update4 = self._get_updateblock(44)
        self.update3 = self._get_updateblock(32, 2)
        self.update2 = self._get_updateblock(20, 4)

        # Fuses the num_flows warped candidates into one RGB residual image.
        self.comb_block = nn.Sequential(
            nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1),
            nn.PReLU(6*num_flows),
            nn.Conv2d(6*num_flows, 3, 3, 1, 1),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build one correlation-driven update block for feature dim ``cdim``."""
        return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
                                fc_dim=68, scale_factor=scale_factor,
                                corr_levels=self.corr_levels, radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation at flow-displaced coordinates.

        Flows from time t are rescaled to full 0<->1 flows before the lookup.
        ``downsample`` matches flows at a finer decoder level to the 1/8-scale
        correlation grid.
        """
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Interpolate between img0 and img1 at time ``embt``.

        Returns a dict with 'imgt_pred' (and per-level flow/feature
        predictions when ``eval`` is False, for training losses).
        """
        # Remove the joint mean so both inputs are zero-centered; added back
        # inside multi_flow_combine.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        # Undo the input rescale: flows are both resized and re-magnituded.
        if scale_factor != 1.0:
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
class AMT_L(nn.Module):
    """AMT (large variant): same pipeline as AMT_S with wider channels,
    BasicEncoder features, BasicUpdateBlock refiners and 5 flow hypotheses.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=5,
                 channels=[48, 64, 72, 128],
                 skip_channels=48
                 ):
        super(AMT_L, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows

        self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        # NOTE(review): channel list is hard-coded here instead of using the
        # ``channels`` argument — a non-default ``channels`` would desync the
        # encoder from the decoders. Confirm before parameterizing.
        self.encoder = Encoder([48, 64, 72, 128], large=True)

        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # scale_factor re-samples between decoder and 1/8 lookup resolutions.
        self.update4 = self._get_updateblock(72, None)
        self.update3 = self._get_updateblock(64, 2.0)
        self.update2 = self._get_updateblock(48, 4.0)

        # Fuses the num_flows warped candidates into one RGB residual image.
        self.comb_block = nn.Sequential(
            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
            nn.PReLU(6*self.num_flows),
            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build one correlation-driven update block for feature dim ``cdim``."""
        return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48,
                                corr_dim=256, corr_dim2=160, fc_dim=124,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation at flow-displaced coordinates;
        ``downsample`` matches finer-level flows to the 1/8 correlation grid.
        """
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Interpolate between img0 and img1 at time ``embt``; see AMT_S.forward."""
        # Zero-center both inputs with their joint mean (restored at combine).
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        # Undo the input rescale: flows are both resized and re-magnituded.
        if scale_factor != 1.0:
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
class AMT_G(nn.Module):
    """AMT (giant variant): like AMT_L but with LargeEncoder features and an
    extra high-resolution residue update (``*_high``) at the 3rd/2nd levels.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=5,
                 channels=[84, 96, 112, 128],
                 skip_channels=84):
        super(AMT_G, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows

        self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels, large=True)
        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # *_low refine at the 1/8 lookup resolution; *_high refine again at
        # the decoder's native resolution using an upsampled correlation.
        self.update4 = self._get_updateblock(112, None)
        self.update3_low = self._get_updateblock(96, 2.0)
        self.update2_low = self._get_updateblock(84, 4.0)

        self.update3_high = self._get_updateblock(96, None)
        self.update2_high = self._get_updateblock(84, None)

        # Fuses the num_flows warped candidates into one RGB residual image.
        self.comb_block = nn.Sequential(
            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
            nn.PReLU(6*self.num_flows),
            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build one correlation-driven update block for feature dim ``cdim``."""
        return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64,
                                corr_dim=256, corr_dim2=192, fc_dim=188,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation at flow-displaced coordinates;
        ``downsample`` matches finer-level flows to the 1/8 correlation grid.
        """
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Interpolate between img0 and img1 at time ``embt``; see AMT_S.forward."""
        # Zero-center both inputs with their joint mean (restored at combine).
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        # residue update with lookup corr (hr)
        # Second refinement at native level resolution: the correlation map is
        # upsampled x2 and the concatenated flows are refined in place.
        corr_3 = resize(corr_3, scale_factor=2.0)
        up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1)
        delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3)
        ft_2_ += delta_ft_2_
        up_flow0_3 += delta_up_flow_3[:, 0:2]
        up_flow1_3 += delta_up_flow_3[:, 2:4]

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        # residue update with lookup corr (hr)
        corr_2 = resize(corr_2, scale_factor=4.0)
        up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1)
        delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2)
        ft_1_ += delta_ft_1_
        up_flow0_2 += delta_up_flow_2[:, 0:2]
        up_flow1_2 += delta_up_flow_2[:, 2:4]

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        # Undo the input rescale: flows are both resized and re-magnituded.
        if scale_factor != 1.0:
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
class CAIN_VFI:
    """ComfyUI node wrapping the CAIN frame-interpolation model.

    Loads the selected checkpoint, then runs the generic frame loop to insert
    ``multiplier - 1`` intermediate frames between every pair of input frames.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate ``frames``; returns a single-element tuple with the
        enlarged IMAGE batch, as ComfyUI expects."""
        from .cain_arch import CAIN
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        # map_location keeps loading working on CPU-only hosts even if the
        # checkpoint was saved with CUDA tensors; weights are moved to the
        # target device explicitly below.
        sd = torch.load(model_path, map_location="cpu")["state_dict"]
        # Checkpoint was saved from nn.DataParallel; strip the "module." prefix.
        sd = {key.replace('module.', ''): value for key, value in sd.items()}

        # Local variable (previously a needless `global`, which kept the model
        # alive in module scope after the node finished).
        interpolation_model = CAIN(depth=3)
        interpolation_model.load_state_dict(sd)
        interpolation_model.eval().to(get_torch_device())
        del sd

        frames = preprocess_frames(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model):
            # CAIN does some direct modifications to input frame tensors so we need to clone them
            return model(frame_0.detach().clone(), frame_1.detach().clone())[0]

        args = [interpolation_model]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, use_timestep=False, dtype=torch.float32)
        )
        return (out,)
class Encoder(nn.Module):
    """CAIN encoder: PixelShuffle-based spread followed by feature fusion."""

    def __init__(self, in_channels=3, depth=3):
        super(Encoder, self).__init__()

        # Shuffle pixels to expand in channel dimension
        # shuffler_list = [PixelShuffle(0.5) for i in range(depth)]
        # self.shuffler = nn.Sequential(*shuffler_list)
        # One unshuffle by 2**depth: HxW shrinks, channels grow by 4**depth.
        self.shuffler = PixelShuffle(1 / 2**depth)

        relu = nn.LeakyReLU(0.2, True)

        # FF_RCAN or FF_Resblocks
        self.interpolate = Interpolation(5, 12, in_channels * (4**depth), act=relu)

    def forward(self, x1, x2):
        """
        Encoder: Shuffle-spread --> Feature Fusion --> Return fused features
        """
        feats1 = self.shuffler(x1)
        feats2 = self.shuffler(x2)

        feats = self.interpolate(feats1, feats2)

        return feats


class Decoder(nn.Module):
    """CAIN decoder: a single PixelShuffle back to image resolution."""

    def __init__(self, depth=3):
        super(Decoder, self).__init__()

        # shuffler_list = [PixelShuffle(2) for i in range(depth)]
        # self.shuffler = nn.Sequential(*shuffler_list)
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        out = self.shuffler(feats)
        return out


class CAIN(nn.Module):
    """Channel-Attention frame interpolation: predicts the middle frame of
    (x1, x2). Returns (interpolated frame, fused features)."""

    def __init__(self, depth=3):
        super(CAIN, self).__init__()

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        # NOTE: sub_mean modifies its argument in place; callers that need the
        # originals must pass clones.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # At inference, pad to the size multiple the shuffler requires and
        # crop back afterwards.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Restore brightness: add back the average of both input means.
        mi = (m1 + m2) / 2
        out += mi

        return out, feats
class Encoder(nn.Module):
    """CAIN_EncDec encoder: strided-conv feature extraction + fusion."""

    def __init__(self, in_channels=3, depth=3, nf_start=32, norm=False):
        super(Encoder, self).__init__()
        self.device = get_torch_device()

        nf = nf_start
        relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Three stride-2 stages: spatial /8, channels up to nf*6.
        self.body = nn.Sequential(
            ConvNorm(in_channels, nf * 1, 7, stride=1, norm=norm),
            relu,
            ConvNorm(nf * 1, nf * 2, 5, stride=2, norm=norm),
            relu,
            ConvNorm(nf * 2, nf * 4, 5, stride=2, norm=norm),
            relu,
            ConvNorm(nf * 4, nf * 6, 5, stride=2, norm=norm)
        )

        self.interpolate = Interpolation(5, 12, nf * 6, reduction=16, act=relu)

    def forward(self, x1, x2):
        """
        Encoder: Feature Extraction --> Feature Fusion --> Return
        """
        feats1 = self.body(x1)
        feats2 = self.body(x2)

        feats = self.interpolate(feats1, feats2)

        return feats


class Decoder(nn.Module):
    """CAIN_EncDec decoder: three upsampling stages back to RGB."""

    def __init__(self, in_channels=192, out_channels=3, depth=3, norm=False, up_mode='shuffle'):
        super(Decoder, self).__init__()
        self.device = get_torch_device()

        relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        nf = [in_channels, (in_channels*2)//3, in_channels//3, in_channels//6]
        #nf = [192, 128, 64, 32]
        #nf = [186, 124, 62, 31]
        self.body = nn.Sequential(
            UpConvNorm(nf[0], nf[1], mode=up_mode, norm=norm),
            ResBlock(nf[1], nf[1], norm=norm, act=relu),
            UpConvNorm(nf[1], nf[2], mode=up_mode, norm=norm),
            ResBlock(nf[2], nf[2], norm=norm, act=relu),
            UpConvNorm(nf[2], nf[3], mode=up_mode, norm=norm),
            ResBlock(nf[3], nf[3], norm=norm, act=relu),
            conv7x7(nf[3], out_channels)
        )

    def forward(self, feats):
        out = self.body(feats)
        #out = self.conv_final(out)

        return out


class CAIN_EncDec(nn.Module):
    """CAIN variant with a conv encoder/decoder instead of pixel shuffling.
    Returns (interpolated frame, fused features)."""

    def __init__(self, depth=3, n_resblocks=3, start_filts=32, up_mode='shuffle'):
        super(CAIN_EncDec, self).__init__()
        self.depth = depth

        self.encoder = Encoder(in_channels=3, depth=depth, norm=False)
        self.decoder = Decoder(in_channels=start_filts*6, depth=depth, norm=False, up_mode=up_mode)

    def forward(self, x1, x2):
        # NOTE: sub_mean modifies its argument in place; callers that need the
        # originals must pass clones.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # At inference, pad to the required size multiple and crop back after.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Restore brightness: add back the average of both input means.
        mi = (m1 + m2)/2
        out += mi

        return out, feats
class Encoder(nn.Module):
    """CAIN_NoCA encoder: PixelShuffle spread + residual fusion (no channel
    attention, hence Interpolation_res instead of Interpolation)."""

    def __init__(self, in_channels=3, depth=3):
        super(Encoder, self).__init__()
        self.device = get_torch_device()

        # One unshuffle by 2**depth: HxW shrinks, channels grow by 4**depth.
        self.shuffler = PixelShuffle(1/2**depth)
        # self.shuffler = nn.Sequential(
        #     PixelShuffle(1/2),
        #     PixelShuffle(1/2),
        #     PixelShuffle(1/2))
        self.interpolate = Interpolation_res(5, 12, in_channels * (4**depth))

    def forward(self, x1, x2):
        feats1 = self.shuffler(x1)
        feats2 = self.shuffler(x2)

        feats = self.interpolate(feats1, feats2)

        return feats


class Decoder(nn.Module):
    """CAIN_NoCA decoder: single PixelShuffle back to image resolution."""

    def __init__(self, depth=3):
        super(Decoder, self).__init__()
        self.device = get_torch_device()

        self.shuffler = PixelShuffle(2**depth)
        # self.shuffler = nn.Sequential(
        #     PixelShuffle(2),
        #     PixelShuffle(2),
        #     PixelShuffle(2))

    def forward(self, feats):
        out = self.shuffler(feats)
        return out


class CAIN_NoCA(nn.Module):
    """CAIN ablation without channel attention.
    Returns (interpolated frame, fused features)."""

    def __init__(self, depth=3):
        super(CAIN_NoCA, self).__init__()
        self.depth = depth

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        # NOTE: sub_mean modifies its argument in place; callers that need the
        # originals must pass clones.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # At inference, pad to the required size multiple and crop back after.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Restore brightness: add back the average of both input means.
        mi = (m1 + m2) / 2
        out += mi

        return out, feats
def sub_mean(x):
    """Subtract the per-image channel-wise spatial mean IN PLACE.

    Returns (zero-mean x, mean). NOTE: mutates the caller's tensor (``x -=``);
    callers that need the original must pass a clone.
    """
    mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
    x -= mean
    return x, mean

def InOutPaddings(x):
    """Build reflection pads that round H and W up to a multiple of 128
    (``>> 7 << 7``) and the matching negative pads that crop back.

    Returns (paddingInput, paddingOutput) modules.
    """
    w, h = x.size(3), x.size(2)
    padding_width, padding_height = 0, 0
    if w != ((w >> 7) << 7):
        padding_width = (((w >> 7) + 1) << 7) - w
    if h != ((h >> 7) << 7):
        padding_height = (((h >> 7) + 1) << 7) - h
    paddingInput = nn.ReflectionPad2d(padding=[padding_width // 2, padding_width - padding_width // 2,
                                               padding_height // 2, padding_height - padding_height // 2])
    # Negative padding crops the exact amount added above.
    paddingOutput = nn.ReflectionPad2d(padding=[0 - padding_width // 2, padding_width // 2 - padding_width,
                                                0 - padding_height // 2, padding_height // 2 - padding_height])
    return paddingInput, paddingOutput


class ConvNorm(nn.Module):
    """Reflection-padded conv with optional Instance/Batch norm ('IN'/'BN')."""

    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super(ConvNorm, self).__init__()

        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv = nn.Conv2d(in_feat, out_feat, stride=stride, kernel_size=kernel_size, bias=True)

        # self.norm doubles as both the flag and the norm module.
        self.norm = norm
        if norm == 'IN':
            self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
        elif norm == 'BN':
            self.norm = nn.BatchNorm2d(out_feat)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv(out)
        if self.norm:
            out = self.norm(out)
        return out


class UpConvNorm(nn.Module):
    """x2 upsampling block: 'transpose' conv, 'shuffle' (conv + PixelShuffle),
    or bilinear upsample + 1x1 conv otherwise."""

    def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
        super(UpConvNorm, self).__init__()

        if mode == 'transpose':
            self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        elif mode == 'shuffle':
            self.upconv = nn.Sequential(
                ConvNorm(in_channels, 4*out_channels, kernel_size=3, stride=1, norm=norm),
                PixelShuffle(2))
        else:
            # out_channels is always going to be the same as in_channels
            self.upconv = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))

    def forward(self, x):
        out = self.upconv(x)
        return out



class meanShift(nn.Module):
    """Frozen 1x1 conv that adds/subtracts a fixed RGB mean (``sign`` picks
    direction). Supports 1-, 3- and 6-channel inputs."""

    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super(meanShift, self).__init__()
        if nChannel == 1:
            l = rgbMean[0] * rgbRange * float(sign)

            self.shifter = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1)
            self.shifter.bias.data = torch.Tensor([l])
        elif nChannel == 3:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)

            # Identity weight: the conv only applies the bias shift.
            self.shifter = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b])
        else:
            r = rgbMean[0] * rgbRange * float(sign)
            g = rgbMean[1] * rgbRange * float(sign)
            b = rgbMean[2] * rgbRange * float(sign)
            self.shifter = nn.Conv2d(6, 6, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1)
            self.shifter.bias.data = torch.Tensor([r, g, b, r, g, b])

        # Freeze the meanShift layer
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        x = self.shifter(x)

        return x


""" CONV - (BN) - RELU - CONV - (BN) """
class ResBlock(nn.Module):
    """Residual block; optional stride-2 downscale with a matching 1x1 conv
    on the skip path."""

    def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True, # 'reduction' is just for placeholder
                 norm=False, act=nn.ReLU(True), downscale=False):
        super(ResBlock, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=2 if downscale else 1),
            act,
            ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1)
        )

        self.downscale = None
        if downscale:
            self.downscale = nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2)

    def forward(self, x):
        res = x
        out = self.body(x)
        if self.downscale is not None:
            res = self.downscale(res)
        out += res

        return out
## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Returns (reweighted features, attention weights).
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y, y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual block whose body ends in channel attention.

    Note: CALayer is the last Sequential member and returns a tuple, so
    ``self.body(x)`` yields (features, attention) — unpacked below.
    """

    def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
                 norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
        super(RCAB, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm),
            act,
            ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
            CALayer(out_feat, reduction)
        )
        self.downscale = downscale
        if downscale:
            # Skip path must also downscale to match the residual shape.
            self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
        self.return_ca = return_ca

    def forward(self, x):
        res = x
        out, ca = self.body(x)
        if self.downscale:
            res = self.downConv(res)
        out += res

        if self.return_ca:
            return out, ca
        else:
            return out


## Residual Group (RG)
class ResidualGroup(nn.Module):
    """``n_resblocks`` Blocks + a closing conv, with a long skip connection."""

    def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
        super(ResidualGroup, self).__init__()

        modules_body = [Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)
                        for _ in range(n_resblocks)]
        modules_body.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
batch_size, channels, in_height, in_width = input.size() + + out_channels = int(int(channels / scale_factor) / scale_factor) + out_height = int(in_height * scale_factor) + out_width = int(in_width * scale_factor) + + if scale_factor >= 1: + input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor, in_height, in_width) + shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() + else: + block_size = int(1 / scale_factor) + input_view = input.contiguous().view(batch_size, channels, out_height, block_size, out_width, block_size) + shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() + + return shuffle_out.view(batch_size, out_channels, out_height, out_width) + + +class PixelShuffle(nn.Module): + def __init__(self, scale_factor): + super(PixelShuffle, self).__init__() + self.scale_factor = scale_factor + + def forward(self, x): + return pixel_shuffle(x, self.scale_factor) + def extra_repr(self): + return 'scale_factor={}'.format(self.scale_factor) + + +def conv(in_channels, out_channels, kernel_size, + stride=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size//2, + stride=1, + bias=bias, + groups=groups) + + +def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=bias, + groups=groups) + +def conv3x3(in_channels, out_channels, stride=1, + padding=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def conv5x5(in_channels, out_channels, stride=1, + padding=2, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=5, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def conv7x7(in_channels, out_channels, stride=1, + padding=3, bias=True, groups=1): + return nn.Conv2d( + in_channels, + 
out_channels, + kernel_size=7, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def upconv2x2(in_channels, out_channels, mode='shuffle'): + if mode == 'transpose': + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=4, + stride=2, + padding=1) + elif mode == 'shuffle': + return nn.Sequential( + conv3x3(in_channels, 4*out_channels), + PixelShuffle(2)) + else: + # out_channels is always going to be the same as in_channels + return nn.Sequential( + nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False), + conv1x1(in_channels, out_channels)) + + + +class Interpolation(nn.Module): + def __init__(self, n_resgroups, n_resblocks, n_feats, + reduction=16, act=nn.LeakyReLU(0.2, True), norm=False): + super(Interpolation, self).__init__() + + # define modules: head, body, tail + self.headConv = conv3x3(n_feats * 2, n_feats) + + modules_body = [ + ResidualGroup( + RCAB, + n_resblocks=n_resblocks, + n_feat=n_feats, + kernel_size=3, + reduction=reduction, + act=act, + norm=norm) + for _ in range(n_resgroups)] + self.body = nn.Sequential(*modules_body) + + self.tailConv = conv3x3(n_feats, n_feats) + + def forward(self, x0, x1): + # Build input tensor + x = torch.cat([x0, x1], dim=1) + x = self.headConv(x) + + res = self.body(x) + res += x + + out = self.tailConv(res) + return out + + +class Interpolation_res(nn.Module): + def __init__(self, n_resgroups, n_resblocks, n_feats, + act=nn.LeakyReLU(0.2, True), norm=False): + super(Interpolation_res, self).__init__() + + # define modules: head, body, tail (reduces concatenated inputs to n_feat) + self.headConv = conv3x3(n_feats * 2, n_feats) + + modules_body = [ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats, kernel_size=3, + reduction=0, act=act, norm=norm) + for _ in range(n_resgroups)] + self.body = nn.Sequential(*modules_body) + + self.tailConv = conv3x3(n_feats, n_feats) + + def forward(self, x0, x1): + # Build input tensor + x = torch.cat([x0, x1], dim=1) + x = 
import pathlib
import typing

import torch
import torch.nn as nn

from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
from comfy.model_management import soft_empty_cache, get_torch_device

# Checkpoint filenames for the three EISAI sub-models.
MODEL_TYPE = pathlib.Path(__file__).parent.name
MODEL_FILE_NAMES = {
    "ssl": "eisai_ssl.pt",
    "dtm": "eisai_dtm.pt",
    "raft": "eisai_anime_interp_full.ckpt"
}

class EISAI(nn.Module):
    """Bundle of the EISAI sub-networks (RAFT flow, SoftsplatLite, DTM)
    that runs one frame-interpolation step."""

    def __init__(self, model_file_names) -> None:
        from .eisai_arch import SoftsplatLite, DTM, RAFT
        super(EISAI, self).__init__()
        self.raft = RAFT(load_file_from_github_release(MODEL_TYPE, model_file_names["raft"]))
        self.raft.to(get_torch_device()).eval()

        # BUGFIX: load checkpoints onto CPU first so CUDA-saved weights also
        # load on CPU-only machines; .to(get_torch_device()) moves them after.
        self.ssl = SoftsplatLite()
        self.ssl.load_state_dict(
            torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["ssl"]),
                       map_location="cpu"))
        self.ssl.to(get_torch_device()).eval()

        self.dtm = DTM()
        self.dtm.load_state_dict(
            torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["dtm"]),
                       map_location="cpu"))
        self.dtm.to(get_torch_device()).eval()

    def forward(self, img0, img1, t):
        """Return the interpolated frame at time ``t`` between img0 and img1."""
        with torch.no_grad():
            # Bidirectional optical flow between the two input frames.
            flow0, _ = self.raft(img0, img1)
            flow1, _ = self.raft(img1, img0)
            x = {
                "images": torch.stack([img0, img1], dim=1),
                "flows": torch.stack([flow0, flow1], dim=1),
            }
            out_ssl, _ = self.ssl(x, t=t, return_more=True)
            # NOTE(review): passing ``_`` (the ssl aux output) as DTM's third
            # argument mirrors upstream; confirm DTM's signature before changing.
            out_dtm, _ = self.dtm(x, out_ssl, _, return_more=False)
            return out_dtm[:, :3]

class EISAI_VFI:
    """ComfyUI node wrapping :class:`EISAI` for video frame interpolation."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (["eisai"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate ``frames`` by ``multiplier`` using the EISAI model.

        Returns a one-tuple holding the postprocessed IMAGE batch, as
        expected by ComfyUI.
        """
        interpolation_model = EISAI(MODEL_FILE_NAMES)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model):
            # Callback used by the generic frame loop for each frame pair.
            return model(frame_0, frame_1, t=timestep)

        scale = 1

        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)
+https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/sketchers_v1.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/interpolator_v0.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/gridnet_v1.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/flow_v0.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/softsplat_v0.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/rfr_new.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/extractor.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/update.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/corr.py +https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/utils.py +""" + +import copy +import cv2 +import torch.nn.functional as F +import torchvision.transforms.functional as F +import gc +from PIL import Image, ImageFile, ImageFont, ImageDraw +import inspect +from scipy import interpolate +import kornia +import math +from argparse import Namespace +import torch.nn as nn +import numpy as np +import os +from functools import partial +import pathlib +import PIL +import re +import requests +from scipy.spatial.transform import Rotation +import scipy +import shutil +import torchvision.transforms as T +import time +import torch +import torchvision as tv +import zlib +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm as std_tqdm +from tqdm.auto import trange as std_trange +from vfi_models.ops import FunctionSoftsplat, batch_edt +from comfy.model_management import 
get_torch_device + +device = get_torch_device() +autocast = torch.autocast +tqdm = partial(std_tqdm, dynamic_ncols=True) +trange = partial(std_trange, dynamic_ncols=True) +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def pixel_ij(x, rounding=True): + if isinstance(x, np.ndarray): + x = x.tolist() + return tuple( + pixel_rounder(i, rounding) + for i in (x if isinstance(x, tuple) or isinstance(x, list) else (x, x)) + ) + + +def rescale_dry(x, factor): + h, w = x[-2:] if isinstance(x, tuple) or isinstance(x, list) else I(x).size + return (h * factor, w * factor) + + +def pixel_rounder(n, mode): + if mode == True or mode == "round": + return round(n) + elif mode == "ceil": + return math.ceil(n) + elif mode == "floor": + return math.floor(n) + else: + return n + + +def diam(x): + if isinstance(x, tuple) or isinstance(x, list): + h, w = x[-2:] + elif isinstance(x, I): + h, w = x.size + else: + h, w = x.shape[-2:] + return np.sqrt(h**2 + w**2) + + +def pixel_logit(x, pixel_margin=1): + x = (x * (255 - 2 * pixel_margin) + pixel_margin) / 255 + return torch.log(x / (1 - x)) + + +class InputPadder: + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, dims): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode="replicate") for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < 
ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata((x1, y1), dx, (x0, y0), method="cubic", fill_value=0) + + flow_y = interpolate.griddata((x1, y1), dy, (x0, y0), method="cubic", fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + # print(img.size()) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode="bilinear"): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = 
torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = 
nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = 
F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = 
nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + 
self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + 
nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class BasicEncoder1(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(BasicEncoder1, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + 
nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if 
m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +################################################## +# RFR is implemented based on RAFT optical flow # +################################################## + + +def backwarp(img, flow): + _, _, H, W = img.size() + + u = flow[:, 0, :, :] + v = flow[:, 1, :, :] + + gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) + + gridX = torch.tensor( + gridX, + requires_grad=False, + ).cuda() + gridY = torch.tensor( + gridY, + requires_grad=False, + ).cuda() + x = gridX.unsqueeze(0).expand_as(u).float() + u + y = gridY.unsqueeze(0).expand_as(v).float() + v + # range -1 to 1 + x = 2 * (x / (W - 1) - 0.5) + y = 2 * (y / (H - 1) - 0.5) + # stacking X and Y + grid = torch.stack((x, y), dim=3) + # Sample pixels using bilinear interpolation. 
class ErrorAttention(nn.Module):
    """A three-layer network for predicting mask"""

    def __init__(self, input, output):
        super(ErrorAttention, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        # conv3 sees conv2's 32 channels concatenated with the raw input (6) = 38.
        self.conv3 = nn.Conv2d(38, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        x = self.prelu1(self.conv1(x1))
        x = self.prelu2(torch.cat([self.conv2(x), x1], dim=1))
        return self.conv3(x)


class RFR(nn.Module):
    """RFR optical-flow estimator built on RAFT, with an attention-gated
    warm start from an externally supplied initial flow.

    BUGFIX: the no-initial-flow branch allocated zero tensors with
    ``.cuda()``, breaking CPU execution; they now follow the input device.
    """

    def __init__(self, args):
        super(RFR, self).__init__()
        self.attention2 = ErrorAttention(6, 1)
        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0
        self.args = args

        # feature network, context network, and update block
        # (fnet doubles as the context network — see forward_pred).
        self.fnet = BasicEncoder(output_dim=256, norm_fn="none", dropout=args.dropout)
        # self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put all BatchNorm layers in eval mode (frozen statistics)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
        coords1 = coords_grid(N, H // 8, W // 8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(
        self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
    ):
        """Estimate flow image1 -> image2, optionally warm-started from
        ``flow_init``. Returns (flow, initial_flow, warp_error)."""
        H, W = image1.size()[2:4]
        # Snap working resolution to a multiple of 8 (RAFT stride).
        H8 = H // 8 * 8
        W8 = W // 8 * 8

        if flow_init is not None:
            # Resize the warm-start flow to 1/8 resolution and rescale vectors.
            flow_init_resize = F.interpolate(
                flow_init, size=(H8 // 8, W8 // 8), mode="nearest"
            )

            flow_init_resize[:, :1] = (
                flow_init_resize[:, :1].clone() * (W8 // 8 * 1.0) / flow_init.size()[3]
            )
            flow_init_resize[:, 1:] = (
                flow_init_resize[:, 1:].clone() * (H8 // 8 * 1.0) / flow_init.size()[2]
            )

            if not hasattr(self.args, "not_use_rfr_mask") or (
                hasattr(self.args, "not_use_rfr_mask")
                and (not self.args.not_use_rfr_mask)
            ):
                # Gate the warm-start flow by an attention mask derived from
                # the backward-warp error.
                im18 = F.interpolate(image1, size=(H8 // 8, W8 // 8), mode="bilinear")
                im28 = F.interpolate(image2, size=(H8 // 8, W8 // 8), mode="bilinear")

                warp21 = backwarp(im28, flow_init_resize)
                error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
                # NOTE(review): ``f12init`` is computed but never read, and
                # ``error21`` is undefined here when not_use_rfr_mask is set —
                # both quirks mirror upstream and are left unchanged.
                f12init = (
                    torch.exp(
                        -self.attention2(
                            torch.cat([im18, error21, flow_init_resize], dim=1)
                        )
                        ** 2
                    )
                    * flow_init_resize
                )
        else:
            flow_init_resize = None
            # Allocate on the input's device (was hard-coded .cuda()).
            flow_init = torch.zeros(
                image1.size()[0], 2, image1.size()[2] // 8, image1.size()[3] // 8,
                device=image1.device,
            )
            error21 = torch.zeros(
                image1.size()[0], 1, image1.size()[2] // 8, image1.size()[3] // 8,
                device=image1.device,
            )

            f12_init = flow_init

        image1 = F.interpolate(image1, size=(H8, W8), mode="bilinear")
        image2 = F.interpolate(image2, size=(H8, W8), mode="bilinear")

        f12s, f12, f12_init = self.forward_pred(
            image1, image2, iters, flow_init_resize, upsample, test_mode
        )

        if hasattr(self.args, "requires_sq_flow") and self.args.requires_sq_flow:
            # Return the whole iteration sequence, rescaled to input size.
            for ii in range(len(f12s)):
                f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode="bilinear")
                f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0 * W8) * W
                f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0 * H8) * H
            if self.training:
                return f12s
            else:
                return [f12s[-1]], f12_init
        else:
            # Rescale the final flow back to the original resolution.
            f12[:, :1] = f12[:, :1].clone() / (1.0 * W8) * W
            f12[:, 1:] = f12[:, 1:].clone() / (1.0 * H8) * H

            f12 = F.interpolate(f12, size=(H, W), mode="bilinear")
            return (
                f12,
                f12_init,
                error21,
            )

    def forward_pred(
        self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
    ):
        """Estimate optical flow between pair of frames"""

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(device.type, enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network (fnet is reused as the context encoder)
        with autocast(device.type, enabled=self.args.mixed_precision):
            cnet = self.fnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            if itr == 0:
                # NOTE(review): this adds flow_init a second time on the
                # first iteration (it was already added above) — upstream
                # behavior, intentionally preserved.
                if flow_init is not None:
                    coords1 = coords1 + flow_init
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(device.type, enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        return flow_predictions, flow_up, flow_init

####################### WARPING #######################


# expects batched tensors,
# considered low-level operation
# img: bs, ch, h, w
# flow: bs, xy (pix displace), h, w
def flow_backwarp(
    img, flow, resample="bilinear", padding_mode="border", align_corners=False
):
    """Backward-warp `img` along `flow` (pixel displacements) with grid_sample."""
    # Promote unbatched inputs to singleton batches.
    if len(img.shape) != 4:
        img = img[None,]
    if len(flow.shape) != 4:
        flow = flow[None,]
    h, w = flow.shape[-2], flow.shape[-1]
    # Pixel displacements -> offsets in normalized [-1, 1] coordinates.
    scale = torch.tensor([h, w], device=flow.device, dtype=torch.float)
    offsets = 2 * flow / scale[None, :, None, None]
    # Identity sampling grid in normalized coordinates.
    identity = torch.stack(
        torch.meshgrid(
            torch.linspace(-1, 1, h),
            torch.linspace(-1, 1, w),
        )
    )[None,].to(flow.device)
    grid = offsets + identity
    if img.dtype != grid.dtype:
        img = img.type(grid.dtype)
    # grid_sample expects (bs, h, w, xy); flip channel order before permuting.
    return nn.functional.grid_sample(
        img,
        grid.flip(dims=(1,)).permute(0, 2, 3, 1),
        mode=resample,  # nearest, bicubic, bilinear
        padding_mode=padding_mode,  # border, zeros, reflection
        align_corners=align_corners,
    )


backwarp = flow_warp = flow_backwarp


# mode: sum, avg, lin, softmax
# lin/softmax w/out metric defaults to avg
# must use gpu, move back to cpu if retain_device
# typical metric: -20 * | img0 - backwarp(img1,flow) |
# From Fannovel16: Changed mode params for common ops.
+def flow_forewarp( + img, flow, mode="average", metric=None, mask=False, retain_device=True +): + # setup + #if mode == "sum": + # mode = "summation" + #elif mode == "avg": + # mode = "average" + if mode in ["lin", "linear"]: + #mode = "linear" if metric is not None else "average" + mode = "linear" if metric is not None else "avg" + elif mode in ["sm", "softmax"]: + #mode = "softmax" if metric is not None else "average" + mode = "soft" if metric is not None else "avg" + if len(img.shape) != 4: + img = img[None,] + if len(flow.shape) != 4: + flow = flow[None,] + if metric is not None and len(metric.shape) != 4: + metric = metric[None,] + flow = flow.flip(dims=(1,)) + if img.dtype != torch.float32: + img = img.type(torch.float32) + if flow.dtype != torch.float32: + flow = flow.type(torch.float32) + if metric is not None and metric.dtype != torch.float32: + metric = metric.type(torch.float32) + + # move to gpu if necessary + assert img.device == flow.device + if metric is not None: + assert img.device == metric.device + was_cpu = img.device.type == "cpu" + if was_cpu: + img = img.to("cuda") + flow = flow.to("cuda") + if metric is not None: + metric = metric.to("cuda") + + # add mask + if mask: + bs, ch, h, w = img.shape + img = torch.cat( + [img, torch.ones(bs, 1, h, w, dtype=img.dtype, device=img.device)], dim=1 + ) + + # forward, move back to cpu if desired + ans = FunctionSoftsplat(img, flow, metric, mode) + if was_cpu and retain_device: + ans = ans.cpu() + return ans + + +forewarp = flow_forewarp + + +# resizing utility +def flow_resize(flow, size, mode="nearest", align_corners=False): + # flow: bs,xy,h,w + size = pixel_ij(size, rounding=True) + if flow.dtype != torch.float: + flow = flow.float() + if len(flow.shape) == 3: + flow = flow[None,] + if flow.shape[-2:] == size: + return flow + return ( + nn.functional.interpolate( + flow, + size=size, + mode=mode, + align_corners=align_corners if mode != "nearest" else None, + ) + * torch.tensor( + [b / a for a, b in 
zip(flow.shape[-2:], size)], + device=flow.device, + )[None, :, None, None] + ) + + +####################### TRADITIONAL ####################### + +# dense +_lucaskanade = lambda a, b: np.moveaxis( + cv2.optflow.calcOpticalFlowSparseToDense( + a, + b, # grid_step=5, sigma=0.5, + ), + 2, + 0, +)[ + None, +] +_farneback = lambda a, b: np.moveaxis( + cv2.calcOpticalFlowFarneback( + a, + b, + None, + 0.6, + 3, + 25, + 7, + 5, + 1.2, + cv2.OPTFLOW_FARNEBACK_GAUSSIAN, + ), + 2, + 0, +)[ + None, +] +_dtvl1_ = cv2.optflow.createOptFlow_DualTVL1() +_dtvl1 = lambda a, b: np.moveaxis( + _dtvl1_.calc( + a, + b, + None, + ), + 2, + 0, +)[ + None, +] +_simple = lambda a, b: np.moveaxis( + cv2.optflow.calcOpticalFlowSF( + a, + b, + 3, + 5, + 5, + ), + 2, + 0, +)[ + None, +] +_pca_ = cv2.optflow.createOptFlow_PCAFlow() +_pca = lambda a, b: np.moveaxis( + _pca_.calc( + a, + b, + None, + ), + 2, + 0, +)[ + None, +] +_drlof = lambda a, b: np.moveaxis( + cv2.optflow.calcOpticalFlowDenseRLOF( + a, + b, + None, + ), + 2, + 0, +)[ + None, +] +_deepflow_ = cv2.optflow.createOptFlow_DeepFlow() +_deepflow = lambda a, b: np.moveaxis( + _deepflow_.calc( + a, + b, + None, + ), + 2, + 0, +)[ + None, +] + + +def cv2flow(a, b, method="lucaskanade", back=False): + if method == "lucaskanade": + f = _lucaskanade + a = a.convert("L").cv2() + b = b.convert("L").cv2() + elif method == "farneback": + f = _farneback + a = a.convert("L").cv2() + b = b.convert("L").cv2() + elif method == "dtvl1": + f = _dtvl1 + a = a.convert("L").cv2() + b = b.convert("L").cv2() + elif method == "simple": + f = _simple + a = a.convert("RGB").cv2() + b = b.convert("RGB").cv2() + elif method == "pca": + f = _pca + a = a.convert("L").cv2() + b = b.convert("L").cv2() + elif method == "drlof": + f = _drlof + a = a.convert("RGB").cv2() + b = b.convert("RGB").cv2() + elif method == "deepflow": + f = _deepflow + a = a.convert("L").cv2() + b = b.convert("L").cv2() + else: + assert 0 + ans = f(b, a) + if back: + ans = 
np.concatenate( + [ + ans, + f(a, b), + ] + ) + return torch.tensor(ans).flip(dims=(1,)) + + +####################### FLOWNET2 ####################### + + +def flownet2(img_a, img_b, mode="shm", back=False): + # package + url = f"http://localhost:8109/get-flow" + if mode == "shm": + t = time.time() + fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png")) + fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png")) + elif mode == "net": + assert False, "not impl" + q = u2d.img2uri(img.pil("RGB")) + q.decode() + resp = requests.get( + url, + params={ + "img_a": fn_a, + "img_b": fn_b, + "mode": mode, + "back": back, + # 'vis': vis, + }, + ) + + # return + ans = {"response": resp} + if resp.status_code == 200: + j = resp.json() + ans["time"] = j["time"] + ans["output"] = { + "flow": torch.tensor(load(j["fn_flow"])), + } + # if vis: + # ans['output']['vis'] = I(j['fn_vis']) + if mode == "shm": + shutil.rmtree(f"/dev/shm/_flownet2/{t}") + return ans + + +####################### VISUALIZATION ####################### + + +class Gridnet(nn.Module): + def __init__(self, channels_0, channels_1, channels_2, total_dropout_p, depth): + super().__init__() + self.channels_0 = ch0 = channels_0 + self.channels_1 = ch1 = channels_1 + self.channels_2 = ch2 = channels_2 + self.total_dropout_p = p = total_dropout_p + self.depth = depth + self.encoders = nn.ModuleList( + [GridnetEncoder(ch0, ch1, ch2) for i in range(self.depth)] + ) + self.decoders = nn.ModuleList( + [GridnetDecoder(ch0, ch1, ch2) for i in range(self.depth)] + ) + self.total_dropout = GridnetTotalDropout(p) + return + + def forward(self, x): + for e, enc in enumerate(self.encoders): + t = [self.total_dropout(i) for i in t] if e != 0 else x + t = enc(t) + for d, dec in enumerate(self.decoders): + t = [self.total_dropout(i) for i in t] + t = dec(t) + return t + + +class GridnetEncoder(nn.Module): + def __init__(self, channels_0, channels_1, channels_2): + super().__init__() + self.channels_0 = ch0 = channels_0 
        self.channels_1 = ch1 = channels_1
        self.channels_2 = ch2 = channels_2
        self.resnet_0 = GridnetResnet(ch0)
        self.resnet_1 = GridnetResnet(ch1)
        self.resnet_2 = GridnetResnet(ch2)
        # Downsamplers feed each scale's output into the next-coarser row.
        self.downsample_01 = GridnetDownsample(ch0, ch1)
        self.downsample_12 = GridnetDownsample(ch1, ch2)
        return

    def forward(self, x):
        """x: list of 3 feature maps (fine -> coarse); returns same structure."""
        out = [
            None,
        ] * 3
        # Fine-to-coarse: each coarser row receives the downsampled finer output.
        out[0] = self.resnet_0(x[0])
        out[1] = self.resnet_1(x[1]) + self.downsample_01(out[0])
        out[2] = self.resnet_2(x[2]) + self.downsample_12(out[1])
        return out


class GridnetDecoder(nn.Module):
    """One decoder column of a GridNet: a residual block per scale, with
    upsampled coarser rows added into finer rows (coarse -> fine)."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = ch0 = channels_0
        self.channels_1 = ch1 = channels_1
        self.channels_2 = ch2 = channels_2
        self.resnet_0 = GridnetResnet(ch0)
        self.resnet_1 = GridnetResnet(ch1)
        self.resnet_2 = GridnetResnet(ch2)
        self.upsample_10 = GridnetUpsample(ch1, ch0)
        self.upsample_21 = GridnetUpsample(ch2, ch1)
        return

    def forward(self, x):
        """x: list of 3 feature maps (fine -> coarse); returns same structure."""
        out = [
            None,
        ] * 3
        # Coarse-to-fine: each finer row receives the upsampled coarser output.
        out[2] = self.resnet_2(x[2])
        out[1] = self.resnet_1(x[1]) + self.upsample_21(out[2])
        out[0] = self.resnet_0(x[0]) + self.upsample_10(out[1])
        return out


class GridnetConverter(nn.Module):
    """Per-scale 1x1 projection (PReLU -> Conv1x1 -> BN) mapping channel
    counts channels_in[i] -> channels_out[i] for each scale i."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = cin = channels_in
        self.channels_out = cout = channels_out
        self.nets = nn.ModuleList(
            [
                nn.Sequential(
                    nn.PReLU(a),
                    nn.Conv2d(a, b, kernel_size=1, padding=0),
                    nn.BatchNorm2d(b),
                )
                for a, b in zip(cin, cout)
            ]
        )
        return

    def forward(self, x):
        # Apply the i-th projection to the i-th scale's feature map.
        return [m(q) for m, q in zip(self.nets, x)]


class GridnetResnet(nn.Module):
    """Residual block: (PReLU -> Conv3x3 -> BN) twice, plus identity skip."""

    def __init__(self, channels):
        super().__init__()
        self.channels = ch = channels
        self.net = nn.Sequential(
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
        )
        return

    def forward(self, x):
        return x + self.net(x)


class GridnetDownsample(nn.Module):
    """Halve spatial resolution (stride-2 conv) and map chin -> chout channels."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = chin = channels_in
        self.channels_out = chout = channels_out
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, chin, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(chin),
            nn.PReLU(chin),
            nn.Conv2d(chin, chout, kernel_size=3, padding=1),
            nn.BatchNorm2d(chout),
        )
        return

    def forward(self, x):
        return self.net(x)


class GridnetUpsample(nn.Module):
    """Double spatial resolution (nearest upsample) and map chin -> chout channels."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = chin = channels_in
        self.channels_out = chout = channels_out
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.PReLU(chin),
            nn.Conv2d(chin, chout, kernel_size=3, padding=1),
            nn.BatchNorm2d(chout),
            nn.PReLU(chout),
            nn.Conv2d(chout, chout, kernel_size=3, padding=1),
            nn.BatchNorm2d(chout),
        )
        return

    def forward(self, x):
        return self.net(x)


class GridnetTotalDropout(nn.Module):
    """Drop whole samples (entire feature maps) with probability p, scaling
    survivors by 1/(1-p); identity in eval mode unless force_drop=True."""

    def __init__(self, p):
        super().__init__()
        assert 0 <= p < 1
        self.p = p
        # Inverse keep-probability keeps the expected activation unchanged.
        self.weight = 1 / (1 - p)
        return

    def get_drop(self, x):
        # Per-sample Bernoulli keep mask, broadcast over (ch, h, w).
        d = torch.rand(len(x))[:, None, None, None] < self.p
        d = (1 - d.float()).to(x.device) * self.weight
        return d

    def forward(self, x, force_drop=None):
        # force_drop=True/False overrides training/eval behavior.
        if force_drop is True:
            ans = x * self.get_drop(x)
        elif force_drop is False:
            ans = x
        else:
            if self.training:
                ans = x * self.get_drop(x)
            else:
                ans = x
        return ans


class Interpolator(nn.Module):
    """Resize tensors to a fixed `size`; optionally rescale flow magnitudes
    so displacement values stay consistent with the new resolution."""

    def __init__(self, size, mode="bilinear"):
        super().__init__()
        self.size = size
        self.mode = mode
        return

    def forward(self, x, is_flow=False):
        # NOTE(review): compares an int height against `self.size` (likely an
        # (h, w) tuple) — may never be equal; confirm intended short-circuit.
        if x.shape[-2] == self.size:
            return x
        if len(x.shape) == 4:
            # bs,ch,h,w
            bs, ch, h, w = x.shape
            ans = nn.functional.interpolate(
                x,
                size=self.size,
                mode=self.mode,
                align_corners=(False, None)[self.mode == "nearest"],
            )
            if is_flow:
                # Scale flow channels by the per-axis resize ratio.
                ans = (
                    ans
                    * torch.tensor(
[b / a for a, b in zip((h, w), self.size)], + device=ans.device, + )[None, :, None, None] + ) + return ans + elif len(x.shape) == 5: + # bs,k,ch,h,w (merge bs and k) + bs, k, ch, h, w = x.shape + return self.forward( + x.view(bs * k, ch, h, w), + is_flow=is_flow, + ).view(bs, k, ch, *self.size) + else: + assert 0 + + +###################### CANNY ###################### + + +def canny(img, a=100, b=200): + img = I(img).convert("L") + return I(cv2.Canny(img.cv2(), a, b)) + + +# https://www.pyimagesearch.com/2015/04/06/zero-parameter-automatic-canny-edge-detection-with-python-and-opencv/ +def canny_pis(img, sigma=0.33): + # compute the median of the single channel pixel intensities + img = I(img).convert("L").uint8(ch_last=False) + v = np.median(img) + # apply automatic Canny edge detection using the computed median + lower = int(max(0, (1.0 - sigma) * v)) + upper = int(min(255, (1.0 + sigma) * v)) + edged = cv2.Canny(img[0], lower, upper) + # return the edged image + return I(edged) + + +# https://en.wikipedia.org/wiki/Otsu%27s_method +def canny_otsu(img): + img = I(img).convert("L").uint8(ch_last=False) + high, _ = cv2.threshold(img[0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + low = 0.5 * high + return I(cv2.Canny(img[0], low, high)) + + +def xdog(img, t=1.0, epsilon=0.04, phi=100, sigma=3, k=1.6): + img = I(img).convert("L").uint8(ch_last=False) + grey = np.asarray(img, dtype=np.float32) + g0 = scipy.ndimage.gaussian_filter(grey, sigma) + g1 = scipy.ndimage.gaussian_filter(grey, sigma * k) + + # ans = ((1+p) * g0 - p * g1) / 255 + ans = (g0 - t * g1) / 255 + ans = 1 + np.tanh(phi * (ans - epsilon)) * (ans < epsilon) + return ans + + +def dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True): + img = I(img).convert("L").tensor()[None] + kern0 = max(2 * int(sigma * kernel_factor) + 1, 3) + kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3) + g0 = kornia.filters.gaussian_blur2d( + img, + (kern0, kern0), + (sigma, sigma), + 
border_type="replicate", + ) + g1 = kornia.filters.gaussian_blur2d( + img, + (kern1, kern1), + (sigma * k, sigma * k), + border_type="replicate", + ) + ans = 0.5 + t * (g1 - g0) - epsilon + ans = ans.clip(0, 1) if clip else ans + return ans[0].numpy() + + +# input: (bs,rgb(a),h,w) or (bs,1,h,w) +# returns: (bs,1,h,w) +def batch_dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True): + # to grayscale if needed + bs, ch, h, w = img.shape + if ch in [3, 4]: + img = kornia.color.rgb_to_grayscale(img[:, :3]) + else: + assert ch == 1 + + # calculate dog + kern0 = max(2 * int(sigma * kernel_factor) + 1, 3) + kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3) + g0 = kornia.filters.gaussian_blur2d( + img, + (kern0, kern0), + (sigma, sigma), + border_type="replicate", + ) + g1 = kornia.filters.gaussian_blur2d( + img, + (kern1, kern1), + (sigma * k, sigma * k), + border_type="replicate", + ) + ans = 0.5 + t * (g1 - g0) - epsilon + ans = ans.clip(0, 1) if clip else ans + return ans + + +############### DERIVED DISTANCES ############### + +# input: (bs,h,w) or (bs,1,h,w) +# returns: (bs,) +# normalized s.t. 
# metric is same across proportional image scales


# average of two asymmetric distances
# normalized by diameter and area
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    """Symmetric chamfer distance: mean of the two directed distances."""
    t = batch_chamfer_distance_t(gt, pred, block=block)
    p = batch_chamfer_distance_p(gt, pred, block=block)
    cd = (t + p) / 2
    return cd


def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    """Directed chamfer: gt mass weighted by the EDT of pred, normalized by
    the image diagonal. `return_more` is currently unused here."""
    assert gt.device == pred.device and gt.shape == pred.shape
    bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dpred = batch_edt(pred, block=block)
    cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    # Squeeze a singleton channel dim if present.
    if len(cd.shape) == 2:
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd


def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    """Directed chamfer in the other direction: pred mass weighted by EDT of gt."""
    assert gt.device == pred.device and gt.shape == pred.shape
    bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if len(cd.shape) == 2:
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd


# normalized by diameter
# always between [0,1]
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    """Hausdorff distance via EDTs: max over both directed max distances,
    normalized by the image diagonal."""
    assert gt.device == pred.device and gt.shape == pred.shape
    bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    dpred = batch_edt(pred, block=block)
    hd = torch.stack(
        [
            (dgt * pred).amax(dim=(-2, -1)),
            (dpred * gt).amax(dim=(-2, -1)),
        ]
    ).amax(dim=0).float() / np.sqrt(h**2 + w**2)
    if len(hd.shape) == 2:
        assert hd.shape[1] == 1
        hd = hd.squeeze(1)
    return hd


#################### UTILITIES ####################


def reset_parameters(model):
    """Re-initialize every direct child layer exposing reset_parameters().
    Note: children() is not recursive (modules() would be)."""
    for layer in model.children():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
    return model


def channel_squeeze(x, dim=1):
    """Merge dims `dim` and `dim+1` of `x` into a single dimension."""
    a = x.shape[:dim]
    b = x.shape[dim + 2 :]
    return x.reshape(*a, -1, *b)


def channel_unsqueeze(x, shape, dim=1):
    """Inverse of channel_squeeze: expand dim `dim` back into `shape`."""
    a
= x.shape[:dim] + b = x.shape[dim + 1 :] + return x.reshape(*a, *shape, *b) + + +def default_collate(items, device=None): + return to(dict(torch.utils.data.dataloader.default_collate(items)), device) + + +def to(x, device): + if device is None: + return x + if issubclass(x.__class__, dict): + return dict( + { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in x.items() + } + ) + if isinstance(x, torch.Tensor): + return x.to(device) + if isinstance(x, np.ndarray): + return torch.tensor(x).to(device) + assert 0, "data not understood" + + +################ PARSING ################ + +from argparse import Namespace + +# args: all args +# bargs: base args +# pargs: data processing args +# largs: data loading args +# margs: model args +# targs: training args + + +# typically used to read dataset filters +def read_filter(fn, cast=None, sort=True, sort_key=None): + if cast is None: + cast = lambda x: x + ans = [cast(line) for line in read(fn).split("\n") if line != ""] + if sort: + return sorted(ans, key=sort_key) + else: + return ans + + +################ FILE MANAGEMENT ################ + + +def mkfile(fn, parents=True, exist_ok=True): + dn = "/".join(fn.split("/")[:-1]) + mkdir(dn, parents=parents, exist_ok=exist_ok) + return fn + + +def mkdir(dn, parents=True, exist_ok=True): + pathlib.Path(dn).mkdir(parents=parents, exist_ok=exist_ok) + return dn if (not dn[-1] == "/" or dn == "/") else dn[:-1] + + +def fstrip(fn, return_more=False): + dspl = fn.split("/") + dn = "/".join(dspl[:-1]) if len(dspl) > 1 else "." 
+ fn = dspl[-1] + fspl = fn.split(".") + if len(fspl) == 1: + bn = fspl[0] + ext = "" + else: + bn = ".".join(fspl[:-1]) + ext = fspl[-1] + if return_more: + return Namespace( + dn=dn, + fn=fn, + path=f"{dn}/{fn}", + bn_path=f"{dn}/{bn}", + bn=bn, + ext=ext, + ) + else: + return bn + + +def read(fn, mode="r"): + with open(fn, mode) as handle: + return handle.read() + + +def write(text, fn, mode="w"): + mkfile(fn, parents=True, exist_ok=True) + with open(fn, mode) as handle: + return handle.write(text) + + +import pickle + + +def dump(obj, fn, mode="wb"): + mkfile(fn, parents=True, exist_ok=True) + with open(fn, mode) as handle: + return pickle.dump(obj, handle) + + +def load(fn, mode="rb"): + with open(fn, mode) as handle: + return pickle.load(handle) + + +import json + + +def jwrite(x, fn, mode="w", indent="\t", sort_keys=False): + mkfile(fn, parents=True, exist_ok=True) + with open(fn, mode) as handle: + return json.dump(x, handle, indent=indent, sort_keys=sort_keys) + + +def jread(fn, mode="r"): + with open(fn, mode) as handle: + return json.load(handle) + + +try: + import yaml + + def ywrite(x, fn, mode="w", default_flow_style=False): + mkfile(fn, parents=True, exist_ok=True) + with open(fn, mode) as handle: + return yaml.dump(x, handle, default_flow_style=default_flow_style) + + def yread(fn, mode="r"): + with open(fn, mode) as handle: + return yaml.safe_load(handle) + +except: + pass + +try: + import pyunpack +except: + pass + +try: + import mysql + import mysql.connector +except: + pass + + +################ MISC ################ + +hakase = "./env/__hakase__.jpg" +if not os.path.isfile(hakase): + hakase = "./__env__/__hakase__.jpg" + + +def mem(units="m"): + return ( + psProcess(os.getpid()).memory_info().rss + / { + "b": 1, + "k": 1e3, + "m": 1e6, + "g": 1e9, + "t": 1e12, + }[units[0].lower()] + ) + + +def chunk(array, length, colwise=True): + if colwise: + return [array[i : i + length] for i in range(0, len(array), length)] + else: + return chunk(array, 
int(math.ceil(len(array) / length)), colwise=True) + + +def classtree(x): + return inspect.getclasstree(inspect.getmro(x)) + + +################ AESTHETIC ################ + + +class Table: + def __init__( + self, + table, + delimiter=" ", + orientation="br", + double_colon=True, + ): + self.delimiter = delimiter + self.orientation = orientation + self.t = Table.parse(table, delimiter, orientation, double_colon) + return + + # rendering + def __str__(self): + return self.render() + + def __repr__(self): + return self.render() + + def render(self): + # set up empty entry + empty = ("", Table._spec(self.orientation, transpose=False)) + + # calculate table size + t = copy.deepcopy(self.t) + totalrows = len(t) + totalcols = [len(r) for r in t] + assert min(totalcols) == max(totalcols) + totalcols = totalcols[0] + + # string-ify + for i in range(totalrows): + for j in range(totalcols): + x, s = t[i][j] + sp = s[11] + if sp: + x = eval(f'f"{{{x}{sp}}}"') + Table._put((str(x), s), t, (i, j), empty) + + # expand delimiters + _repl = ( + lambda s: s[:2] + (1, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:] + if s[2] + else s[:2] + (0, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:] + ) + for i, row in enumerate(t): + for j, (x, s_own) in enumerate(row): + # expand delim_up(^) + if s_own[3]: + u, v = i, j + while 0 <= u: + _, s = t[u][v] + if (i, j) != (u, v) and (s[2] and not s[10]): + break + Table._put((x, _repl(s)), t, (u, v), empty) + u -= 1 + + # expand delim_down(v) + if s_own[4]: + u, v = i, j + while u < totalrows: + _, s = t[u][v] + if (i, j) != (u, v) and (s[2] and not s[10]): + break + Table._put((x, _repl(s)), t, (u, v), empty) + u += 1 + + # expand delim_right(>) + if s_own[5]: + u, v = i, j + while v < totalcols: + _, s = t[u][v] + if (i, j) != (u, v) and (s[2] and not s[10]): + break + Table._put((x, _repl(s)), t, (u, v), empty) + v += 1 + + # expand delim_left(<) + if s_own[6]: + u, v = i, j + while 0 <= v: + _, s = t[u][v] + if (i, j) != (u, v) and (s[2] and not s[10]): + 
break + Table._put((x, _repl(s)), t, (u, v), empty) + v -= 1 + + # justification calculation + widths = [ + 0, + ] * totalcols # j + heights = [ + 0, + ] * totalrows # i + for i, row in enumerate(t): + for j, (x, s) in enumerate(row): + # height caclulation + heights[i] = max(heights[i], x.count("\n")) + + # width calculation; non-delim fillers no contribution + if s[2] or not s[10]: + w = max(len(q) for q in x.split("\n")) + widths[j] = max(widths[j], w) + # no newline ==> height=1 + heights = [h + 1 for h in heights] + + # render table + rend = [] + roff = 0 + for i, row in enumerate(t): + for j, (x, s) in enumerate(row): + w, h = widths[j], heights[i] + + # expand fillers and delimiters + if s[2] or s[10]: + xs = x.split("\n") + xw0 = min(len(l) for l in xs) + xw1 = max(len(l) for l in xs) + xh = len(xs) + if (xw0 == xw1 == w) and (xh == h): + pass + elif xw0 == xw1 == w: + x = "\n".join( + [ + xs[0], + ] + * h + ) + elif xh == h: + x = "\n".join([(l[0] if l else "") * w for l in xs]) + else: + x = x[0] if x else " " + x = "\n".join( + [ + x * w, + ] + * h + ) + + # justify horizontally + x = [l.rjust(w) if s[0] else l.ljust(w) for l in x.split("\n")] + + # justify vertically + plus = [ + " " * w, + ] * (h - len(x)) + x = plus + x if not s[1] else x + plus + + # input to table + for r, xline in enumerate(x): + Table._put(xline, rend, (roff + r, j), None) + roff += h + + # return rendered string + return "\n".join(["".join(r) for r in rend]) + + # parsing + def _spec(s, transpose=False): + if ":" in s: + i = s.index(":") + sp = s[i:] + s = s[:i] + else: + sp = "" + s = s.lower() + return ( + int("r" in s), # 0:: 0:left(l) 1:right(r) + int("t" in s), # 1:: 0:bottom(b) 1:top(t) + int(any([i in s for i in [".", "<", ">", "^", "v"]])), # 2:: delim_here(.) 
+ int("^" in s if not transpose else "<" in s), # 3:: delim_up(^) + int("v" in s if not transpose else ">" in s), # 4:: delim_down(v) + int(">" in s if not transpose else "v" in s), # 5:: delim_right(>) + int("<" in s if not transpose else "^" in s), # 6:: delim_left(<) + int("+" in s), # 7:: subtable(+) + int("-" in s if not transpose else "|" in s), # 8:: subtable_horiz(-) + int("|" in s if not transpose else "-" in s), # 9:: subtable_vert(|) + int("_" in s), # 10:: fill(_); if delim, overwrite; else fit + sp, # 11:: special(:) f-string for numbers + ) + + def _put(obj, t, ij, empty): + i, j = ij + while i >= len(t): + t.append([]) + while j >= len(t[i]): + t[i].append(empty) + t[i][j] = obj + return + + def parse( + table, + delimiter=" ", + orientation="br", + double_colon=True, + ): + # disabling transpose + transpose = False + + # set up empty entry + empty = ("", Table._spec(orientation, transpose)) + + # transpose + t = [] + for i, row in enumerate(table): + for j, item in enumerate(row): + ij = (i, j) if not transpose else (j, i) + if type(item) == tuple and len(item) == 2 and type(item[1]) == str: + item = (item[0], Table._spec(item[1], transpose)) + elif double_colon and type(item) == str and "::" in item: + x, s = item.split("::") + item = (x, Table._spec(s, transpose)) + else: + item = (item, Table._spec(orientation, transpose)) + Table._put(item, t, ij, empty) + + # normalization + maxcol = 0 + maxrow = len(t) + for i, row in enumerate(t): + # take element number into account + maxcol = max(maxcol, len([i for i in row if not i[1][2]])) + + # take subtables into account + for j, (x, s) in enumerate(row): + if s[7]: + r = len(x) + maxrow = max(maxrow, i + r) + c = max(len(q) for q in x) + maxcol = max(maxcol, j + c) + elif s[8]: + c = len(x) + maxcol = max(maxcol, j + c) + elif s[9]: + r = len(x) + maxrow = max(maxrow, i + r) + totalcols = 2 * maxcol + 1 + totalrows = maxrow + t += [[]] * (totalrows - len(t)) + newt = [] + delim = (delimiter, 
Table._spec("._" + orientation, transpose)) + for i, row in enumerate(t): + wasd = False + tcount = 0 + for j in range(totalcols): + item = t[i][tcount] if tcount < len(t[i]) else empty + isd = item[1][2] + if wasd and isd: + Table._put(empty, newt, (i, j), empty) + wasd = False + elif wasd and not isd: + Table._put(item, newt, (i, j), empty) + tcount += 1 + wasd = False + elif not wasd and isd: + Table._put(item, newt, (i, j), empty) + tcount += 1 + wasd = True + elif not wasd and not isd: + Table._put(delim, newt, (i, j), empty) + wasd = True + t = newt + + # normalization: add dummy last column for delimiter + for row in t: + row.append(empty) + + # expand subtables + delim_cols = [i for i in range(totalcols) if i % 2 == 0] + while True: + # find a table + ij = None + for i, row in enumerate(t): + for j, item in enumerate(row): + st, s = item + if s[7]: + ij = i, j, 7, st, s + break + elif s[8]: + ij = i, j, 8, st, s + break + elif s[9]: + ij = i, j, 9, st, s + break + if ij is not None: + break + if ij is None: + break + + # replace its specs + i, j, k, st, s = ij + s = list(s) + s[7] = s[8] = s[9] = 0 + s = tuple(s) + + # expand it + if k == 7: # 2d table + for x, row in enumerate(st): + for y, obj in enumerate(row): + a = i + x if not transpose else i + y + b = j + 2 * y if not transpose else j + 2 * x + Table._put((obj, s), t, (a, b), None) + if k == 8: # subtable_horiz + for y, obj in enumerate(st): + Table._put((obj, s), t, (i, j + 2 * y), None) + if k == 9: # subtable_vert + for x, obj in enumerate(st): + Table._put((obj, s), t, (i + x, j), None) + + # return, finally + return t + + +class Resnet(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = ch = channels + self.net = nn.Sequential( + nn.PReLU(ch), + nn.Conv2d(ch, ch, kernel_size=3, padding=1), + nn.BatchNorm2d(ch), + nn.PReLU(ch), + nn.Conv2d(ch, ch, kernel_size=3, padding=1), + nn.BatchNorm2d(ch), + ) + return + + def forward(self, x): + return x + self.net(x) + + 
+class Synthesizer(nn.Module): + def __init__( + self, size, channels_image, channels_flow, channels_mask, channels_feature + ): + super().__init__() + self.size = size + self.diam = diam(self.size) + self.channels_image = cimg = channels_image + self.channels_flow = cflow = channels_flow + self.channels_mask = cmask = channels_mask + self.channels_feature = cfeat = channels_feature + self.channels = ch = cimg + cflow // 2 + cmask + cfeat + self.interpolator = Interpolator(self.size, mode="bilinear") + self.net = nn.Sequential( + nn.Conv2d(ch + 3, 64, kernel_size=1, padding=0), + Resnet(64), + nn.Sequential( + nn.PReLU(64), + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + ), + Resnet(32), + nn.Sequential( + nn.PReLU(32), + nn.Conv2d(32, 16, kernel_size=3, padding=1), + nn.BatchNorm2d(16), + ), + Resnet(16), + nn.Sequential( + nn.PReLU(16), + nn.Conv2d(16, 3, kernel_size=3, padding=1), + ), + ) + return + + def forward(self, images, flows, masks, features, return_more=False): + itp = self.interpolator + images = [ + (images[0] + images[1]) / 2, + ] + images + logimgs = [itp(pixel_logit(i[:, :3])) for i in images] + cat = torch.cat( + [ + *logimgs, + *[itp(f).norm(dim=1, keepdim=True) / self.diam for f in flows], + *[itp(m) for m in masks], + *[itp(f) for f in features], + ], + dim=1, + ) + residual = self.net(cat) + return torch.sigmoid(logimgs[0] + 0.5 * residual), ( + locals() if return_more else None + ) + + +class FlowZMetric(nn.Module): + def __init__(self): + super().__init__() + return + + def forward(self, img0, img1, flow0, flow1, return_more=False): + # B(i0,f0) = i1 + # B(i1,f1) = i0 + # F(x,f0,z0) + # F(x,f1,z1) + img0 = kornia.color.rgb_to_lab(img0[:, :3]) + img1 = kornia.color.rgb_to_lab(img1[:, :3]) + return [ + -0.1 * (img1 - flow_backwarp(img0, flow0)).norm(dim=1, keepdim=True), # z0 + -0.1 * (img0 - flow_backwarp(img1, flow1)).norm(dim=1, keepdim=True), # z1 + ], (locals() if return_more else None) + + +class NEDT(nn.Module): 
+ def __init__(self): + super().__init__() + return + + def forward( + self, + img, + t=2.0, + sigma_factor=1 / 540, + k=1.6, + epsilon=0.01, + kernel_factor=4, + exp_factor=540 / 15, + return_more=False, + ): + with torch.no_grad(): + dog = batch_dog( + img, + t=t, + sigma=img.shape[-2] * sigma_factor, + k=k, + epsilon=epsilon, + kernel_factor=kernel_factor, + clip=False, + ) + edt = batch_edt((dog > 0.5).float()) + ans = 1 - (-edt * exp_factor / max(edt.shape[-2:])).exp() + return ans, (locals() if return_more else None) + + +class HalfWarper(nn.Module): + def __init__(self): + super().__init__() + self.channels_image = 4 * 3 + self.channels_flow = 2 * 2 + self.channels_mask = 2 * 1 + self.channels = self.channels_image + self.channels_flow + self.channels_mask + + def morph_open(self, x, k): + if k == 0: + return x + else: + with torch.no_grad(): + return kornia.morphology.opening(x, torch.ones(k, k, device=x.device)) + + def forward(self, img0, img1, flow0, flow1, z0, z1, k, t=0.5, return_more=False): + # forewarps + flow0_ = (1 - t) * flow0 + flow1_ = t * flow1 + f01 = forewarp(img0, flow1_, mode="sm", metric=z1, mask=True) + f10 = forewarp(img1, flow0_, mode="sm", metric=z0, mask=True) + f01i, f01m = f01[:, :-1], self.morph_open(f01[:, -1:], k=k) + f10i, f10m = f10[:, :-1], self.morph_open(f10[:, -1:], k=k) + + # base guess + base0 = f01m * f01i + (1 - f01m) * f10i + base1 = f10m * f10i + (1 - f10m) * f01i + ans = [ + [ # images + base0, + base1, + f01i, + f10i, + ], + [ # flows + flow0_, + flow1_, + ], + [ # masks + f01m, + f10m, + ], + ] + return ans, (locals() if return_more else None) + + +class ResnetFeatureExtractor(nn.Module): + def __init__(self, inferserve_query, size_in=None): + super().__init__() + self.inferserve_query = iq = inferserve_query + self.size_in = si = size_in + if iq[0] == "torchvision": + # use pytorch pretrained resnet50 + self.base_hparams = None + resnet = tv.models.resnet50(pretrained=True) + + self.resize = T.Resize(256) + 
self.resnet_preprocess = T.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ) + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu # 64ch, 128p (assuming 256p input) + self.maxpool = resnet.maxpool + self.layer1 = resnet.layer1 # 256ch, 64p + self.layer2 = resnet.layer2 # 512ch, 32p + else: + base = userving.infer_model_load(*iq).eval() + self.base_hparams = base.hparams + + self.resize = T.Resize(base.hparams.largs.size) + self.resnet_preprocess = base.resnet_preprocess + self.conv1 = base.resnet.conv1 + self.bn1 = base.resnet.bn1 + self.relu = base.resnet.relu # 64ch, 128p (assuming 256p input) + self.maxpool = base.resnet.maxpool + self.layer1 = base.resnet.layer1 # 256ch, 64p + self.layer2 = base.resnet.layer2 # 512ch, 32p + if self.size_in is None: + self.sizes_out = None + else: + s = self.resize.size + self.sizes_out = [ + pixel_ij( + rescale_dry(si, (s // 2) / si[0]), rounding="ceil" + ), # conv1, 128p + pixel_ij( + rescale_dry(si, (s // 4) / si[0]), rounding="ceil" + ), # layer1, 64p + pixel_ij( + rescale_dry(si, (s // 8) / si[0]), rounding="ceil" + ), # layer2, 32p + ] + self.channels = [ + 64, + 256, + 512, + ] + return + + def forward(self, x, force_sizes_out=False, return_more=False): + ans = [] + x = x[:, :3] + x = self.resize(x) + x = self.resnet_preprocess(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + ans.append(x) # conv1 + x = self.maxpool(x) + x = self.layer1(x) + ans.append(x) # layer1 + x = self.layer2(x) + ans.append(x) # layer2 + if force_sizes_out or (self.sizes_out is None): + self.sizes_out = [tuple(q.shape[-2:]) for q in ans] + return ans, (locals() if return_more else None) + + +class NetNedt(nn.Module): + def __init__(self): + super().__init__() + chin = 3 + 1 + 4 + 4 + 1 + 1 + ch = 16 + chout = 1 + self.net = nn.Sequential( + nn.PReLU(chin), + nn.Conv2d(chin, ch, kernel_size=3, padding=1), + nn.BatchNorm2d(ch), + nn.PReLU(ch), + nn.Conv2d(ch, ch, kernel_size=3, padding=1), 
+ nn.BatchNorm2d(ch), + nn.PReLU(ch), + nn.Conv2d(ch, chout, kernel_size=3, padding=1), + ) + return + + def forward(self, out_base, out_base_nedt, hw_imgs, hw_masks, return_more=False): + cat = torch.cat( + [ + out_base, # 3 + out_base_nedt, # 1 + hw_imgs[0], # 4 + hw_imgs[1], # 4 + hw_masks[0], # 1 + hw_masks[1], # 1 + ], + dim=1, + ) + log = pixel_logit(cat.clip(0, 1)) + ans = torch.sigmoid(self.net(log)) + return ans, (locals() if return_more else None) + + +class NetTail(nn.Module): + def __init__(self): + super().__init__() + chin = 3 + 1 + 1 + ch = 16 + chout = 3 + self.net = nn.Sequential( + nn.PReLU(chin), + nn.Conv2d(chin, ch, kernel_size=3, padding=1), + nn.BatchNorm2d(ch), + nn.PReLU(ch), + nn.Conv2d(ch, ch, kernel_size=3, padding=1), + nn.BatchNorm2d(ch), + nn.PReLU(ch), + nn.Conv2d(ch, ch, kernel_size=3, padding=1), + nn.BatchNorm2d(ch), + nn.PReLU(ch), + nn.Conv2d(ch, chout, kernel_size=3, padding=1), + ) + return + + def forward(self, out_base, out_base_nedt, pred_nedt, return_more=False): + cat = torch.cat( + [ + out_base, # 3 + out_base_nedt, # 1 + pred_nedt, # 1 + ], + dim=1, + ) + log = pixel_logit(cat.clip(0, 1)) + ans = torch.sigmoid(log[:, :3] + self.net(log)) + return ans, (locals() if return_more else None) + + +class SoftsplatLite(nn.Module): + def __init__(self): + super().__init__() + self.feature_extractor = ResnetFeatureExtractor( + ("torchvision", "resnet50"), + (540, 960), + ) + self.z_metric = FlowZMetric() + self.flow_downsamplers = [ + Interpolator(s, mode="bilinear") for s in self.feature_extractor.sizes_out + ] + self.gridnet_converter = GridnetConverter( + self.feature_extractor.channels, + [32, 64, 128], + ) + self.gridnet = Gridnet( + *[32, 64, 128], + total_dropout_p=0.0, + depth=1, # equivalent to u-net + ) + self.nedt = NEDT() + self.half_warper = HalfWarper() + self.synthesizer = Synthesizer( + (540, 960), + self.half_warper.channels_image, + self.half_warper.channels_flow, + self.half_warper.channels_mask, + 
self.gridnet.channels_0, + ) + return + + def forward(self, x, t=0.5, k=5, return_more=False): + rm = return_more + flow0, flow1 = x["flows"].swapaxes(0, 1) + img0, img1 = x["images"][:, 0], x["images"][:, -1] + (z0, z1), locs_z = self.z_metric(img0, img1, flow0, flow1, return_more=rm) + img0 = torch.cat([img0, self.nedt(img0)[0]], dim=1) + img1 = torch.cat([img1, self.nedt(img1)[0]], dim=1) + + # images and flows + (hw_imgs, hw_flows, hw_masks), locs_hw = self.half_warper( + img0, + img1, + flow0, + flow1, + z0, + z1, + k, + t=t, + return_more=rm, + ) + + # features + feats0, locs_fe0 = self.feature_extractor(img0, return_more=rm) + feats1, locs_fe1 = self.feature_extractor(img1, return_more=rm) + warps = [] + for ft0, ft1, ds in zip(feats0, feats1, self.flow_downsamplers): + (w, _, _), _ = self.half_warper( + ft0, + ft1, + ds(flow0, 1), + ds(flow1, 1), + ds(z0), + ds(z1), + k, + t=t, + ) + warps.append((w[0] + w[1]) / 2) + feats = self.gridnet(self.gridnet_converter(warps)) + + # synthesis + pred, locs_synth = self.synthesizer( + hw_imgs, + hw_flows, + hw_masks, + [ + feats[0], + ], + return_more=rm, + ) + return pred, (locals() if rm else None) + + +class DTM(nn.Module): + def __init__(self): + super().__init__() + self.net_nedt = NetNedt() + self.net_tail = NetTail() + self.nedt = NEDT() + return + + def forward(self, x, out_base, locs_base, return_more=False): + rm = return_more + with torch.no_grad(): + out_base_nedt, locs_base_nedt = self.nedt(out_base, return_more=rm) + hw_imgs, hw_masks = locs_base["hw_imgs"], locs_base["hw_masks"] + pred_nedt, locs_nedt = self.net_nedt( + out_base, out_base_nedt, hw_imgs, hw_masks, return_more=rm + ) + pred, locs_tail = self.net_tail( + out_base, out_base_nedt, pred_nedt.clone().detach(), return_more=rm + ) + return torch.cat([pred, pred_nedt], dim=1), (locals() if rm else None) + + +class RAFT(nn.Module): + def __init__(self, path="/workspace/tensorrt/models/anime_interp_full.ckpt"): + super().__init__() + self.raft = 
RFR( + Namespace( + small=False, + mixed_precision=False, + ) + ) + if path is not None: + sd = torch.load(path)["model_state_dict"] + self.raft.load_state_dict( + { + k[len("module.flownet.") :]: v + for k, v in sd.items() + if k.startswith("module.flownet.") + }, + strict=False, + ) + return + + def forward(self, img0, img1, flow0=None, iters=12, return_more=False): + if flow0 is not None: + flow0 = flow0.flip(dims=(1,)) + out = self.raft(img1, img0, iters=iters, flow_init=flow0) + return out[0].flip(dims=(1,)), (locals() if return_more else None) diff --git a/vfi_models/film/__init__.py b/vfi_models/film/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f89f1bdc63d149a36c72614b3f1c7ed7ffa7962 --- /dev/null +++ b/vfi_models/film/__init__.py @@ -0,0 +1,113 @@ +import torch +from comfy.model_management import get_torch_device, soft_empty_cache +import bisect +import numpy as np +import typing +from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames +import pathlib +import gc + +MODEL_TYPE = pathlib.Path(__file__).parent.name +DEVICE = get_torch_device() +def inference(model, img_batch_1, img_batch_2, inter_frames): + results = [ + img_batch_1, + img_batch_2 + ] + + idxes = [0, inter_frames + 1] + remains = list(range(1, inter_frames + 1)) + + splits = torch.linspace(0, 1, inter_frames + 2) + + for _ in range(len(remains)): + starts = splits[idxes[:-1]] + ends = splits[idxes[1:]] + distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs() + matrix = torch.argmin(distances).item() + start_i, step = np.unravel_index(matrix, distances.shape) + end_i = start_i + 1 + + x0 = results[start_i].to(DEVICE) + x1 = results[end_i].to(DEVICE) + dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]]) + + with torch.no_grad(): + prediction = model(x0, x1, dt) + insert_position = 
bisect.bisect_left(idxes, remains[step]) + idxes.insert(insert_position, remains[step]) + results.insert(insert_position, prediction.clamp(0, 1).float()) + del remains[step] + + return [tensor.flip(0) for tensor in results] + +class FILM_VFI: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (["film_net_fp32.pt"], ), + "frames": ("IMAGE", ), + "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}), + "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}), + }, + "optional": { + "optional_interpolation_states": ("INTERPOLATION_STATES", ) + } + } + + RETURN_TYPES = ("IMAGE", ) + FUNCTION = "vfi" + CATEGORY = "ComfyUI-Frame-Interpolation/VFI" + + def vfi( + self, + ckpt_name: typing.AnyStr, + frames: torch.Tensor, + clear_cache_after_n_frames = 10, + multiplier: typing.SupportsInt = 2, + optional_interpolation_states: InterpolationStateList = None, + **kwargs + ): + interpolation_states = optional_interpolation_states + model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name) + model = torch.jit.load(model_path, map_location='cpu') + model.eval() + model = model.to(DEVICE) + dtype = torch.float32 + + frames = preprocess_frames(frames) + number_of_frames_processed_since_last_cleared_cuda_cache = 0 + output_frames = [] + + if type(multiplier) == int: + multipliers = [multiplier] * len(frames) + else: + multipliers = list(map(int, multiplier)) + multipliers += [2] * (len(frames) - len(multipliers) - 1) + for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it + if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr): + continue + #Ensure that input frames are in fp32 - the same dtype as model + frame_0 = frames[frame_itr:frame_itr+1].to(DEVICE).float() + frame_1 = frames[frame_itr+1:frame_itr+2].to(DEVICE).float() + relust = inference(model, frame_0, frame_1, multipliers[frame_itr] - 1) + 
output_frames.extend([frame.detach().cpu().to(dtype=dtype) for frame in relust[:-1]]) + + number_of_frames_processed_since_last_cleared_cuda_cache += 1 + # Try to avoid a memory overflow by clearing cuda cache regularly + if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames: + print("Comfy-VFI: Clearing cache...", end = ' ') + soft_empty_cache() + number_of_frames_processed_since_last_cleared_cuda_cache = 0 + print("Done cache clearing") + gc.collect() + + output_frames.append(frames[-1:].to(dtype=dtype)) # Append final frame + output_frames = [frame.cpu() for frame in output_frames] #Ensure all frames are in cpu + out = torch.cat(output_frames, dim=0) + # clear cache for courtesy + print("Comfy-VFI: Final clearing cache...", end = ' ') + soft_empty_cache() + print("Done cache clearing") + return (postprocess_frames(out), ) diff --git a/vfi_models/film/film_arch.py b/vfi_models/film/film_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..a527e842f8e97d2c0f2c6997d8fa4d3134e59167 --- /dev/null +++ b/vfi_models/film/film_arch.py @@ -0,0 +1,798 @@ +""" +https://github.com/dajes/frame-interpolation-pytorch/blob/main/feature_extractor.py +https://github.com/dajes/frame-interpolation-pytorch/blob/main/fusion.py +https://github.com/dajes/frame-interpolation-pytorch/blob/main/interpolator.py +https://github.com/dajes/frame-interpolation-pytorch/blob/main/pyramid_flow_estimator.py +https://github.com/dajes/frame-interpolation-pytorch/blob/main/util.py +""" + +"""PyTorch layer for extracting image features for the film_net interpolator. + +The feature extractor implemented here converts an image pyramid into a pyramid +of deep features. The feature pyramid serves a similar purpose as U-Net +architecture's encoder, but we use a special cascaded architecture described in +Multi-view Image Fusion [1]. + +For comprehensiveness, below is a short description of the idea. 
+ While the
+description is a bit involved, the cascaded feature pyramid can be used just
+like any image feature pyramid.
+
+Why cascaded architecture?
+=========================
+To understand the concept it is worth reviewing a traditional feature pyramid
+first: *A traditional feature pyramid* as in U-net or in many optical flow
+networks is built by alternating between convolutions and pooling, starting
+from the input image.
+
+It is well known that early features of such architecture correspond to low
+level concepts such as edges in the image whereas later layers extract
+semantically higher level concepts such as object classes etc. In other words,
+the meaning of the filters in each resolution level is different. For problems
+such as semantic segmentation and many others this is a desirable property.
+
+However, the asymmetric features preclude sharing weights across resolution
+levels in the feature extractor itself and in any subsequent neural networks
+that follow. This can be a downside, since optical flow prediction, for
+instance is symmetric across resolution levels. The cascaded feature
+architecture addresses this shortcoming.
+
+How is it built?
+================
+The *cascaded* feature pyramid contains feature vectors that have constant
+length and meaning on each resolution level, except few of the finest ones. The
+advantage of this is that the subsequent optical flow layer can learn
+synergically from many resolutions. This means that coarse level prediction can
+benefit from finer resolution training examples, which can be useful with
+moderately sized datasets to avoid overfitting.
+
+The cascaded feature pyramid is built by extracting shallower subtree pyramids,
+each one of them similar to the traditional architecture. Each subtree
+pyramid S_i is extracted starting from each resolution level:
+
+image resolution 0 -> S_0
+image resolution 1 -> S_1
+image resolution 2 -> S_2
+...
+ +If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid +is constructed by concatenating features as follows (assuming subtree depth=3): + +lvl +feat_0 = concat( S_0_0 ) +feat_1 = concat( S_1_0 S_0_1 ) +feat_2 = concat( S_2_0 S_1_1 S_0_2 ) +feat_3 = concat( S_3_0 S_2_1 S_1_2 ) +feat_4 = concat( S_4_0 S_3_1 S_2_2 ) +feat_5 = concat( S_5_0 S_4_1 S_3_2 ) + .... + +In above, all levels except feat_0 and feat_1 have the same number of features +with similar semantic meaning. This enables training a single optical flow +predictor module shared by levels 2,3,4,5... . For more details and evaluation +see [1]. + +[1] Multi-view Image Fusion, Trinidad et al. 2019 +""" +from typing import List + +import torch +from torch import nn +from torch.nn import functional as F + + +class SubTreeExtractor(nn.Module): + """Extracts a hierarchical set of features from an image. + + This is a conventional, hierarchical image feature extractor, that extracts + [k, k*2, k*4... ] filters for the image pyramid where k=options.sub_levels. + Each level is followed by average pooling. + """ + + def __init__(self, in_channels=3, channels=64, n_layers=4): + super().__init__() + convs = [] + for i in range(n_layers): + convs.append(nn.Sequential( + conv(in_channels, (channels << i), 3), + conv((channels << i), (channels << i), 3) + )) + in_channels = channels << i + self.convs = nn.ModuleList(convs) + + def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]: + """Extracts a pyramid of features from the image. + + Args: + image: TORCH.Tensor with shape BATCH_SIZE x HEIGHT x WIDTH x CHANNELS. + n: number of pyramid levels to extract. This can be less or equal to + options.sub_levels given in the __init__. + Returns: + The pyramid of features, starting from the finest level. Each element + contains the output after the last convolution on the corresponding + pyramid level. 
+ """ + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, kernel_size=2, stride=2) + return pyramid + + +class FeatureExtractor(nn.Module): + """Extracts features from an image pyramid using a cascaded architecture. + """ + + def __init__(self, in_channels=3, channels=64, sub_levels=4): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels) + self.sub_levels = sub_levels + + def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]: + """Extracts a cascaded feature pyramid. + + Args: + image_pyramid: Image pyramid as a list, starting from the finest level. + Returns: + A pyramid of cascaded features. + """ + sub_pyramids: List[List[torch.Tensor]] = [] + for i in range(len(image_pyramid)): + # At each level of the image pyramid, creates a sub_pyramid of features + # with 'sub_levels' pyramid levels, re-using the same SubTreeExtractor. + # We use the same instance since we want to share the weights. + # + # However, we cap the depth of the sub_pyramid so we don't create features + # that are beyond the coarsest level of the cascaded feature pyramid we + # want to generate. + capped_sub_levels = min(len(image_pyramid) - i, self.sub_levels) + sub_pyramids.append(self.extract_sublevels(image_pyramid[i], capped_sub_levels)) + # Below we generate the cascades of features on each level of the feature + # pyramid. Assuming sub_levels=3, The layout of the features will be + # as shown in the example on file documentation above. + feature_pyramid: List[torch.Tensor] = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + return feature_pyramid + + + + + + + + + + + +"""The final fusion stage for the film_net frame interpolator. 
+ +The inputs to this module are the warped input images, image features and +flow fields, all aligned to the target frame (often midway point between the +two original inputs). The output is the final image. FILM has no explicit +occlusion handling -- instead using the abovementioned information this module +automatically decides how to best blend the inputs together to produce content +in areas where the pixels can only be borrowed from one of the inputs. + +Similarly, this module also decides on how much to blend in each input in case +of fractional timestep that is not at the halfway point. For example, if the two +inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1, +it often makes most sense to favor the first input. However, this is not +always the case -- in particular in occluded pixels. + +The architecture of the Fusion module follows U-net [1] architecture's decoder +side, e.g. each pyramid level consists of concatenation with upsampled coarser +level output, and two 3x3 convolutions. + +The upsampling is implemented as 'resize convolution', e.g. nearest neighbor +upsampling followed by 2x2 convolution as explained in [2]. The classic U-net +uses max-pooling which has a tendency to create checkerboard artifacts. + +[1] Ronneberger et al. 
#     U-Net: Convolutional Networks for Biomedical Image
#     Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
# [2] https://distill.pub/2016/deconv-checkerboard/
from typing import List

import torch
from torch import nn
from torch.nn import functional as F


_NUMBER_OF_COLOR_CHANNELS = 3


def get_channels_at_level(level, filters):
    """Channel count the fusion decoder sees at a given pyramid level.

    Each of the two warped inputs contributes RGB (3) + flow (2) channels plus
    cascaded features (filters << i) for every level finer than `level`.
    """
    n_images = 2
    channels = _NUMBER_OF_COLOR_CHANNELS
    flows = 2

    return (sum(filters << i for i in range(level)) + channels + flows) * n_images


class Fusion(nn.Module):
    """The decoder."""

    def __init__(self, n_layers=4, specialized_layers=3, filters=64):
        """
        Args:
            m: specialized levels
        """
        super().__init__()

        # The final convolution that outputs RGB:
        self.output_conv = nn.Conv2d(filters, 3, kernel_size=1)

        # convs[i] holds the convolutions applied at pyramid level i,
        # created in fine-to-coarse order (index 0 = finest level).
        self.convs = nn.ModuleList()

        # Roughly mirroring the feature extractor, the filter count doubles
        # each time the resolution halves, but only up to specialized_layers;
        # all coarser levels share the same width.
        # in_channels: tuple = (128, 202, 256, 522, 512, 1162, 1930, 2442)

        in_channels = get_channels_at_level(n_layers, filters)
        increase = 0
        for i in range(n_layers)[::-1]:
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            convs = nn.ModuleList([
                conv(in_channels, num_filters, size=2, activation=None),
                conv(in_channels + (increase or num_filters), num_filters, size=3),
                conv(num_filters, num_filters, size=3)]
            )
            self.convs.append(convs)
            in_channels = num_filters
            increase = get_channels_at_level(i, filters) - num_filters // 2

    def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
        """Runs the fusion module.

        Args:
            pyramid: input feature pyramid as a list of (B, C, H, W) tensors,
                finest level first.
        Returns:
            A batch of RGB images.
        Raises:
            ValueError, if len(pyramid) != config.fusion_pyramid_levels as
            provided in the constructor.
        """
        # The coarsest level gets no extra convolutions of its own; it is only
        # passed upward for concatenation — features are already spatially
        # aligned by the preceding warp.
        net = pyramid[-1]

        # Loop starting from the 2nd coarsest level:
        for k, layers in enumerate(self.convs):
            i = len(self.convs) - 1 - k
            # 'resize convolution': nearest-neighbor upsample + conv avoids the
            # checkerboard artifacts of transposed convolutions.
            level_size = pyramid[i].shape[2:4]
            net = F.interpolate(net, size=level_size, mode='nearest')
            net = layers[0](net)
            net = torch.cat([pyramid[i], net], dim=1)
            net = layers[1](net)
            net = layers[2](net)
        net = self.output_conv(net)
        return net


# ---------------------------------------------------------------------------
# The film_net frame interpolator main model code.
#
# Basics
# ======
# The film_net is an end-to-end learned neural frame interpolator implemented
# as a PyTorch model. It has the following inputs and outputs:
#
# Inputs:
#   x0: image A.
#   x1: image B.
#   time: desired sub-frame time.
#
# Outputs:
#   image: the predicted in-between image at the chosen time in range [0, 1].
#
# Additional outputs include forward and backward warped image pyramids, flow
# pyramids, etc., that can be visualized for debugging and analysis.
#
# Note that many training sets only contain triplets with ground truth at
# time=0.5. If a model has been trained with such training set, it will only
# work well for synthesizing frames at time=0.5. Such models can only generate
# more in-between frames using recursion.
#
# Architecture
# ============
# The inference consists of three main stages: 1) feature extraction 2) warping
# 3) fusion. On high-level, the architecture has similarities to Context-aware
# Synthesis for Video Frame Interpolation [1], but the exact architecture is
# closer to Multi-view Image Fusion [2] with some modifications for the frame
# interpolation use-case.
#
# Feature extraction stage employs the cascaded multi-scale architecture
# described in [2]. The advantage of this architecture is that coarse level
# flow prediction can be learned from finer resolution image samples. This is
# especially useful to avoid overfitting with moderately sized datasets.
#
# The warping stage uses a residual flow prediction idea that is similar to
# PWC-Net [3], Multi-view Image Fusion [2] and many others.
#
# The fusion stage is similar to U-Net's decoder where the skip connections are
# connected to warped image and feature pyramids.
# This is described in [2].
#
# Implementation Conventions
# ==========================
# Pyramids
# --------
# Throughout the model, all image and feature pyramids are stored as python
# lists with finest level first followed by downscaled versions obtained by
# successively halving the resolution. The depths of all pyramids are
# determined by options.pyramid_levels. The only exception to this is internal
# to the feature extractor, where smaller feature pyramids are temporarily
# constructed with depth options.sub_levels.
#
# Color ranges & gamma
# --------------------
# The model code makes no assumptions on whether the images are in gamma or
# linearized space or what is the range of RGB color values. So a model can be
# trained with different choices. This does not mean that all the choices lead
# to similar results. In practice the model has been proven to work well with
# RGB scale = [0,1] with gamma-space images (i.e. not linearized).
#
# [1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu
# [2] Multi-view Image Fusion, Trinidad et al, 2019
# [3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
from typing import Dict, List

import torch
from torch import nn


class Interpolator(nn.Module):
    """End-to-end film_net interpolator: extract, warp, fuse."""

    def __init__(
        self,
        pyramid_levels=7,
        fusion_pyramid_levels=5,
        specialized_levels=3,
        sub_levels=4,
        filters=64,
        flow_convs=(3, 3, 3, 3),
        flow_filters=(32, 64, 128, 256),
    ):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels

        self.extract = FeatureExtractor(3, filters, sub_levels)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters)
        self.fuse = Fusion(sub_levels, specialized_levels, filters)

    def shuffle_images(self, x0, x1):
        # build one image pyramid per input frame
        return [
            build_image_pyramid(x0, self.pyramid_levels),
            build_image_pyramid(x1, self.pyramid_levels)
        ]

    def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
        """Full forward pass that also returns intermediate pyramids."""
        image_pyramids = self.shuffle_images(x0, x1)

        # Siamese feature pyramids (shared extractor weights):
        feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])]

        # Predict forward flow.
        forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1])

        # Predict backward flow.
        backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0])

        # Only the finest 'fusion_pyramid_levels' levels feed the fusion module.
        forward_flow_pyramid = flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        backward_flow_pyramid = flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        # We multiply the flows with t and 1-t to warp to the desired fractional
        # time. Note: film_net fixes time to 0.5 and recursively invokes the
        # interpolator for multi-frame interpolation, so a constant mid_time
        # tensor is used here (batch_dt only provides the batch size).
        mid_time = torch.full_like(batch_dt, .5)
        backward_flow = multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
        forward_flow = multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])

        pyramids_to_warp = [
            concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
                                 feature_pyramids[0][:self.fusion_pyramid_levels]),
            concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels],
                                 feature_pyramids[1][:self.fusion_pyramid_levels])
        ]

        # Backward warping: backward flow reads from image 0, forward flow
        # reads from image 1.
        forward_warped_pyramid = pyramid_warp(pyramids_to_warp[0], backward_flow)
        backward_warped_pyramid = pyramid_warp(pyramids_to_warp[1], forward_flow)

        aligned_pyramid = concatenate_pyramids(forward_warped_pyramid,
                                               backward_warped_pyramid)
        aligned_pyramid = concatenate_pyramids(aligned_pyramid, backward_flow)
        aligned_pyramid = concatenate_pyramids(aligned_pyramid, forward_flow)

        return {
            'image': [self.fuse(aligned_pyramid)],
            'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
            'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
            'forward_flow_pyramid': forward_flow_pyramid,
            'backward_flow_pyramid': backward_flow_pyramid,
        }

    def forward(self, x0, x1, batch_dt) -> torch.Tensor:
        return self.debug_forward(x0, x1, batch_dt)['image'][0]


# ---------------------------------------------------------------------------
# PyTorch layer for estimating optical flow by a residual flow pyramid.
#
# This approach of estimating optical flow between two images can be traced
# back to [1], but is also used by later neural optical flow computation
# methods such as SpyNet [2] and PWC-Net [3].
#
# The basic idea is that the optical flow is first estimated in a coarse
# resolution, then the flow is upsampled to warp the higher resolution image
# and then a residual correction is computed and added to the estimated flow.
# This process is repeated in a pyramid on coarse to fine order to successively
# increase the resolution of both optical flow and the warped image.
#
# In here, the optical flow predictor is used as an internal component for the
# film_net frame interpolator, to warp the two input images into the inbetween,
# target frame.
#
# [1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
# [2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial
#     Pyramid Network. 2016
# [3] D. Sun X. Yang, M-Y. Liu and J.
Kautz, PWC-Net: CNNs for Optical Flow Using + Pyramid, Warping, and Cost Volume, 2017 +""" +from typing import List + +import torch +from torch import nn +from torch.nn import functional as F + + + +class FlowEstimator(nn.Module): + """Small-receptive field predictor for computing the flow between two images. + + This is used to compute the residual flow fields in PyramidFlowEstimator. + + Note that while the number of 3x3 convolutions & filters to apply is + configurable, two extra 1x1 convolutions are appended to extract the flow in + the end. + + Attributes: + name: The name of the layer + num_convs: Number of 3x3 convolutions to apply + num_filters: Number of filters in each 3x3 convolution + """ + + def __init__(self, in_channels: int, num_convs: int, num_filters: int): + super(FlowEstimator, self).__init__() + + self._convs = nn.ModuleList() + for i in range(num_convs): + self._convs.append(conv(in_channels=in_channels, out_channels=num_filters, size=3)) + in_channels = num_filters + self._convs.append(conv(in_channels, num_filters // 2, size=1)) + in_channels = num_filters // 2 + # For the final convolution, we want no activation at all to predict the + # optical flow vector values. We have done extensive testing on explicitly + # bounding these values using sigmoid, but it turned out that having no + # activation gives better results. + self._convs.append(conv(in_channels, 2, size=1, activation=None)) + + def forward(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor: + """Estimates optical flow between two images. + + Args: + features_a: per pixel feature vectors for image A (B x H x W x C) + features_b: per pixel feature vectors for image B (B x H x W x C) + + Returns: + A tensor with optical flow from A to B + """ + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + """Predicts optical flow by coarse-to-fine refinement. 
+ """ + + def __init__(self, filters: int = 64, + flow_convs: tuple = (3, 3, 3, 3), + flow_filters: tuple = (32, 64, 128, 256)): + super(PyramidFlowEstimator, self).__init__() + + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append( + FlowEstimator( + in_channels=in_channels, + num_convs=flow_convs[i], + num_filters=flow_filters[i])) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a: List[torch.Tensor], + feature_pyramid_b: List[torch.Tensor]) -> List[torch.Tensor]: + """Estimates residual flow pyramids between two image pyramids. + + Each image pyramid is represented as a list of tensors in fine-to-coarse + order. Each individual image is represented as a tensor where each pixel is + a vector of image features. + + flow_pyramid_synthesis can be used to convert the residual flow + pyramid returned by this method into a flow pyramid, where each level + encodes the flow instead of a residual correction. + + Args: + feature_pyramid_a: image pyramid as a list in fine-to-coarse order + feature_pyramid_b: image pyramid as a list in fine-to-coarse order + + Returns: + List of flow tensors, in fine-to-coarse order, each level encoding the + difference against the bilinearly upsampled version from the coarser + level. The coarsest flow tensor, e.g. the last element in the array is the + 'DC-term', e.g. not a residual (alternatively you can think of it being a + residual against zero). + """ + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + for i in range(levels - 2, len(self._predictors) - 1, -1): + # Upsamples the flow to match the current pyramid level. Also, scales the + # magnitude by two to reflect the new size. 
def pad_batch(batch, align):
    """Zero-pads a BxHxWxC numpy batch so H and W become multiples of `align`.

    Padding is split as evenly as possible between the two sides of each
    axis (the extra pixel, if any, goes to the bottom/right).

    Returns:
        The padded batch and `[top, left, bottom, right]` coordinates of the
        original image region inside the padded batch.
    """
    height, width = batch.shape[1:3]
    # `-x % align` is 0 when already aligned, else the shortfall to the next
    # multiple — equivalent to the usual (align - x % align) guard.
    pad_h = -height % align
    pad_w = -width % align

    top, left = pad_h >> 1, pad_w >> 1
    crop_region = [top, left, height + top, width + left]
    batch = np.pad(
        batch,
        ((0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode='constant')
    return batch, crop_region
def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Builds an image pyramid from a given image.

    The original image is included as the finest level; each subsequent level
    halves the resolution with 2x2 average pooling.

    Args:
        image: the input image batch (channels-first, as consumed by
            F.avg_pool2d).
        pyramid_levels: number of pyramid levels to produce.
            (BUGFIX-docs: the previous docstring described a stale
            `options: film_net options object` parameter that does not exist.)

    Returns:
        A list of `pyramid_levels` images starting from the finest.
    """
    pyramid = []
    for i in range(pyramid_levels):
        pyramid.append(image)
        if i < pyramid_levels - 1:
            image = F.avg_pool2d(image, 2, 2)
    return pyramid
def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward warps the image using the given flow.

    Output pixel (b, y, x) is the bilinear lookup of `image` at
    (y + flow[b, 1, y, x], x + flow[b, 0, y, x]) — flow channels are (dx, dy).

    Args:
        image: image batch, channels-first.
        flow: flow batch with 2 channels (dx, dy), same spatial size as image.

    Returns:
        The warped image.
    """
    # Negate and swap the channel order so the math below subtracts the
    # (dy, dx) offsets from the identity grid.
    flow = -flow.flip(1)

    dtype = flow.dtype
    device = flow.device
    height, width = flow.shape[2], flow.shape[3]

    # Equivalent of tfa_image.dense_image_warp, expressed via grid_sample.
    # With align_corners=False, pixel centers live at +/-(1 - 1/size).
    span_x = 1 - 1 / width
    span_y = 1 - 1 / height

    offsets = flow.permute(0, 2, 3, 1) / torch.tensor(
        [height * .5, width * .5], dtype=dtype, device=device)[None, None, None]
    grid = torch.stack([
        torch.linspace(-span_x, span_x, width, dtype=dtype, device=device)[None, None, :] - offsets[..., 1],
        torch.linspace(-span_y, span_y, height, dtype=dtype, device=device)[None, :, None] - offsets[..., 0],
    ], dim=3)

    padding_mode = "border"
    if device.type == "mps":
        # https://github.com/pytorch/pytorch/issues/125098
        padding_mode = "zeros"
    # Clamping to [-1, 1] is a no-op under border padding, and keeps the MPS
    # zeros-padding workaround close to border behaviour.
    grid = grid.clamp(-1, 1)
    warped = F.grid_sample(
        input=image,
        grid=grid,
        mode='bilinear',
        padding_mode=padding_mode,
        align_corners=False,
    )
    return warped.reshape(image.shape)
def flow_pyramid_synthesis(
        residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Converts a residual flow pyramid into a flow pyramid.

    The coarsest level is taken as-is; every finer level is the bilinearly
    upsampled (magnitude-doubled) coarser flow plus that level's residual.
    """
    flow = residual_pyramid[-1]
    coarse_to_fine = [flow]
    for residual in reversed(residual_pyramid[:-1]):
        upsampled = F.interpolate(2 * flow, size=residual.shape[2:4], mode='bilinear')
        flow = residual + upsampled
        coarse_to_fine.append(flow)
    # Return in fine-to-coarse order, matching the input convention.
    return coarse_to_fine[::-1]
def conv(in_channels, out_channels, size, activation: Optional[str] = 'relu'):
    """Returns a 'same'-padded Conv2d, optionally followed by an activation.

    PyTorch's Conv2d has no built-in activation, so a Sequential is used to
    bundle the convolution with the nonlinearity.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        size: square kernel size.
        activation: 'relu' (the only supported value; note the actual layer
            is LeakyReLU with slope 0.2, as in the original film_net code)
            or None for a bare convolution.
    """
    _conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=size,
        padding='same')
    if activation is None:
        return _conv
    assert activation == 'relu'
    return nn.Sequential(_conv, nn.LeakyReLU(.2))
"ComfyUI-Frame-Interpolation/VFI" + + #Reference: https://github.com/danier97/ST-MFNet/blob/main/interpolate_yuv.py#L93 + def vfi( + self, + ckpt_name: typing.AnyStr, + frames: torch.Tensor, + clear_cache_after_n_frames = 10, + multiplier: typing.SupportsInt = 2, + duplicate_first_last_frames: bool = False, + optional_interpolation_states: InterpolationStateList = None, + **kwargs + ): + if multiplier != 2: + warnings.warn("Currently, FLAVR only supports 2x interpolation. The process will continue but please set multiplier=2 afterward") + + assert_batch_size(frames, batch_size=4, vfi_name="ST-MFNet") + interpolation_states = optional_interpolation_states + model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name) + model = build_flavr(model_path) + frames = preprocess_frames(frames) + padder = InputPadder(frames.shape, 16) + frames = padder.pad(frames) + + number_of_frames_processed_since_last_cleared_cuda_cache = 0 + output_frames = [] + for frame_itr in range(len(frames) - 3): + #Does skipping frame i+1 make sanse in this case? 
+ if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr) and interpolation_states.is_frame_skipped(frame_itr + 1): + continue + + #Ensure that input frames are in fp32 - the same dtype as model + frame0, frame1, frame2, frame3 = ( + frames[frame_itr:frame_itr+1].float(), + frames[frame_itr+1:frame_itr+2].float(), + frames[frame_itr+2:frame_itr+3].float(), + frames[frame_itr+3:frame_itr+4].float() + ) + new_frame = model([frame0.to(device), frame1.to(device), frame2.to(device), frame3.to(device)])[0].detach().cpu() + number_of_frames_processed_since_last_cleared_cuda_cache += 2 + + if frame_itr == 0: + output_frames.append(frame0) + if duplicate_first_last_frames: + output_frames.append(frame0) # repeat the first frame + output_frames.append(frame1) + output_frames.append(new_frame) + output_frames.append(frame2) + if frame_itr == len(frames) - 4: + output_frames.append(frame3) + if duplicate_first_last_frames: + output_frames.append(frame3) # repeat the last frame + + # Try to avoid a memory overflow by clearing cuda cache regularly + if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames: + print("Comfy-VFI: Clearing cache...", end = ' ') + soft_empty_cache() + number_of_frames_processed_since_last_cleared_cuda_cache = 0 + print("Done cache clearing") + gc.collect() + + dtype = torch.float32 + output_frames = [frame.cpu().to(dtype=dtype) for frame in output_frames] #Ensure all frames are in cpu + out = torch.cat(output_frames, dim=0) + out = padder.unpad(out) + # clear cache for courtesy + print("Comfy-VFI: Final clearing cache...", end=' ') + soft_empty_cache() + print("Done cache clearing") + return (postprocess_frames(out), ) diff --git a/vfi_models/flavr/flavr_arch.py b/vfi_models/flavr/flavr_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3f26be0968f1192a95005900a081ac84b229323e --- /dev/null +++ b/vfi_models/flavr/flavr_arch.py @@ -0,0 +1,217 @@ +""" 
def joinTensors(X1 , X2 , type="concat"):
    """Combines two equally-shaped tensors.

    Args:
        X1, X2: input tensors.
        type: "concat" concatenates along the channel dim (dim=1);
            "add" sums elementwise; any other value returns X1 unchanged.
    """
    if type == "concat":
        return torch.cat([X1, X2], dim=1)
    if type == "add":
        return X1 + X2
    return X1
class upConv2D(nn.Module):
    """2x spatial upsampling block.

    Either a single transposed convolution (`upmode="transpose"`) or a
    bilinear upsample followed by a 1x1 convolution, optionally followed by
    BatchNorm2d.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose" , batchnorm=False):
        super().__init__()
        self.upmode = upmode

        if upmode == "transpose":
            layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding)]
        else:
            layers = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))
        self.upconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.upconv(x)
class InputPadder:
    """Pads images so the last two dimensions are divisible by `divisor`.

    Replicate-padding is split (almost) evenly between both sides of each
    axis; `unpad` crops a padded tensor back to the original region.
    """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # `-x % divisor` is the shortfall to the next multiple (0 if aligned).
        pad_ht = -self.ht % divisor
        pad_wd = -self.wd % divisor
        # F.pad ordering for 2D: (left, right, top, bottom).
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
                     pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, input_tensor):
        """Replicate-pad the tensor to the aligned size."""
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        """Crop a padded tensor back to the original height/width."""
        return self._unpad(input_tensor)

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
class SEGating(nn.Module):
    """Squeeze-and-Excitation-style channel gate for 5D (video) tensors.

    Global-average-pools over T/H/W, maps the pooled vector through a 1x1x1
    convolution + sigmoid, and scales the input channel-wise.
    """

    def __init__(self , inplanes , reduction=16):
        super().__init__()
        # NOTE(review): `reduction` is accepted for API compatibility but is
        # unused — the gate is a single full-width conv, not a bottleneck.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self , x):
        gate = self.attn_layer(self.pool(x))
        return x * gate
class VideoResNet(nn.Module):
    """Generic 3D-ResNet video encoder.

    Returns the stem output plus the outputs of all four residual stages so
    callers can build U-Net-style decoders on top.
    """

    def __init__(self, block, conv_makers, layers,
                 stem, zero_init_residual=False):
        """
        Args:
            block (nn.Module): resnet building block.
            conv_makers (list(functions)): generator function for each layer.
            layers (List[int]): number of blocks per layer.
            stem (nn.Module): resnet stem class (conv-bn-relu by default).
            zero_init_residual (bool): zero-init each block's last BN weight.
        """
        super().__init__()
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2, temporal_stride=1)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2, temporal_stride=1)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=1, temporal_stride=1)

        self._initialize_weights()

        if zero_init_residual:
            # NOTE(review): `Bottleneck` is not defined in this file (carried
            # over from the torchvision original); this path assumes
            # bottleneck blocks are in scope — confirm before enabling.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        """Returns stem + per-stage features, finest first."""
        x_0 = self.stem(x)
        x_1 = self.layer1(x_0)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)
        return x_0, x_1, x_2, x_3, x_4

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None):
        """Builds one residual stage, downsampling the identity path if needed."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride, temporal_stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                batchnorm(planes * block.expansion)
            )
            stride = ds_stride

        stage = [block(self.inplanes, planes, conv_builder, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(block(self.inplanes, planes, conv_builder) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming init for convs; constant init for norms and linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    """Instantiates a VideoResNet; optionally loads pretrained weights."""
    model = VideoResNet(**kwargs)
    ## TODO: Other 3D resnet models, like S3D, r(2+1)D.
    if pretrained:
        # NOTE(review): `load_state_dict_from_url` / `model_urls` are not
        # defined in this file, so pretrained=True would raise NameError
        # (kept as in the upstream code).
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def unet_18(pretrained=False, bn=False, progress=True, **kwargs):
    """Constructs the 18-layer Unet3D encoder as in
    https://arxiv.org/abs/1711.11248

    Side effect: sets the module-level `batchnorm` factory to BatchNorm3d or
    the identity module, depending on `bn`.

    Returns:
        nn.Module: R3D-18 encoder.
    """
    global batchnorm
    batchnorm = nn.BatchNorm3d if bn else identity

    return _video_resnet('r3d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)


def unet_34(pretrained=False, bn=False, progress=True, **kwargs):
    """Constructs the 34-layer Unet3D encoder as in
    https://arxiv.org/abs/1711.11248

    Side effect: sets the module-level `batchnorm` factory to BatchNorm3d or
    the identity module, depending on `bn`.

    Returns:
        nn.Module: R3D-34 encoder.
    """
    global batchnorm
    batchnorm = nn.BatchNorm3d if bn else identity

    return _video_resnet('r3d_34',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[3, 4, 6, 3],
                         stem=BasicStem, **kwargs)
class GMFSS_Fortuna:
    """Thin inference wrapper around the GMFSS-Fortuna interpolation model."""

    def __init__(self):
        self.cache = False
        self.amount_input_img = 2

        # Global torch configuration for inference-only use.
        torch.set_grad_enabled(False)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        self.model = Model_inference()
        self.model.eval()

    def execute(self, I0, I1, timestep):
        """Interpolates the frame between I0 and I1 at `timestep`,
        returning the result on the CPU."""
        with torch.inference_mode():
            return self.model(I0, I1, timestep).cpu()
def warp(tenInput, tenFlow):
    """Backward-warps `tenInput` by the pixel-space flow `tenFlow` (B,2,H,W).

    The base [-1, 1] sampling grid is cached in the module-level
    `backwarp_tenGrid` dict, keyed by (device, flow size), so repeated calls
    with the same geometry skip grid construction.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Identity grid in normalized coordinates, built once per key.
        xs = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
            .view(1, 1, 1, tenFlow.shape[3])
            .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        )
        ys = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
            .view(1, 1, tenFlow.shape[2], 1)
            .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        )
        backwarp_tenGrid[k] = torch.cat([xs, ys], 1).to(device)

    # Convert pixel displacements into normalized grid offsets
    # (align_corners=True convention: full range spans size-1 pixels).
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
class MultiScaleTridentConv(nn.Module):
    """Weight-shared convolution applied at several strides ("trident conv").

    All branches share one weight/bias; each branch has its own stride and
    padding, producing multi-scale outputs from the same kernel.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        # Broadcast scalar per-branch settings to one entry per branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        self.bias = nn.Parameter(torch.Tensor(out_channels)) if bias else None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        num_branch = (
            self.num_branch if self.training or self.test_branch_idx == -1 else 1
        )
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            # One shared-weight conv per branch, each with its own
            # stride/padding.
            outputs = [
                F.conv2d(inp, self.weight, self.bias, stride, padding,
                         self.dilation, self.groups)
                for inp, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # NOTE(review): here test_branch_idx != -1, so the inner
            # conditionals below always select the LAST stride/padding —
            # kept exactly as in the upstream GMFlow implementation.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.strides[-1],
                    self.paddings[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(o) for o in outputs]
        if self.activation is not None:
            outputs = [self.activation(o) for o in outputs]
        return outputs
self._make_layer( + feature_dims[0], stride=1, norm_layer=norm_layer + ) # 1/2 + self.layer2 = self._make_layer( + feature_dims[1], stride=2, norm_layer=norm_layer + ) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer( + feature_dims[2], + stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv( + output_dim, + output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock_class( + self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation + ) + layer2 = ResidualBlock_class( + dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation + ) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / 
(q.size(2) ** 0.5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask( + input_resolution, + window_size_h, + window_size_w, + shift_size_h, + shift_size_w, + device=get_torch_device(), +): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = ( + slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None), + ) + w_slices = ( + slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature( + img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True + ) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + return attn_mask + + +def single_head_split_window_attention( + q, + k, + v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, +): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c**0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = 
torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature( + q, num_splits=num_splits, channel_last=True + ) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = ( + torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) + / scale_factor + ) # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits( + out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, + channel_last=True, + ) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +class TransformerLayer(nn.Module): + def __init__( + self, + d_model=256, + nhead=1, + attention_type="swin", + no_ffn=False, + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.attention_type = attention_type + self.no_ffn = no_ffn + + self.with_shift = with_shift + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + 
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """Single-head attention from `source` (query) to `target` (key/value).

        source, target: [B, L, C] token sequences (L = H*W).
        Returns: [B, L, C] — `source` plus the attention message (residual).
        When `self.no_ffn` is False, the message additionally passes through
        the MLP on concat([source, message]) before the residual add.
        """
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                # windowed (optionally shifted) attention over attn_num_splits^2 windows
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=self.with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            # FFN consumes both the input tokens and the attention message
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerBlock, self).__init__()

        # self-attention layer has no FFN; the FFN lives in the cross-attention layer
        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )

        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        # source, target: [B, L, C]

        # self attention
        source = self.self_attn(
            source,
            source,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )

        # cross attention and ffn
        source = self.cross_attn_ffn(
            source,
            target,
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )

        return source


class FeatureTransformer(nn.Module):
    """Stack of TransformerBlocks applied symmetrically to two feature maps.

    Odd-indexed blocks use shifted windows when attention_type == "swin"
    (Swin-style alternation). Both directions (0->1 and 1->0) are computed
    in one pass by concatenating along the batch dimension.
    """

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        **kwargs,
    ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    with_shift=True
                    if attention_type == "swin" and i % 2 == 1
                    else False,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_num_splits=None,
        **kwargs,
    ):
        # feature0/feature1: [B, C, H, W]; returns the pair transformed, same shapes
        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in 
parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for layer in self.layers: + concat0 = layer( + concat0, + concat1, + height=h, + width=w, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = ( + feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() + ) # [B, C, H, W] + feature1 = ( + feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() + ) # [B, C, H, W] + + return feature0, feature1 + + +class FeatureFlowAttention(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__( + self, + in_channels, + **kwargs, + ): + super(FeatureFlowAttention, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + feature0, + flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn( + feature0, flow, local_window_radius=local_window_radius + ) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] 
        # NOTE: key is projected from the already-projected query — this is the
        # intentional upstream quirk explained in the note above (the two linear
        # projections compose, so pretrained weights remain valid).
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        # global self-attention over all H*W positions; flow values are propagated
        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        """Local-window variant: each pixel attends only to a (2R+1)x(2R+1)
        neighborhood (R = local_window_radius), gathered with F.unfold.

        feature0: [B, C, H, W] query/key features; flow: [B, 2, H, W] values.
        Returns the propagated flow, [B, 2, H, W].
        """
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        feature0_reshape = self.q_proj(
            feature0.view(b, c, -1).permute(0, 2, 1)
        ).reshape(
            b * h * w, 1, c
        )  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        # gather the key window around every pixel (zero-padded at borders)
        feature0_window = F.unfold(
            feature0_proj, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, C*(2R+1)^2), H*W]

        feature0_window = (
            feature0_window.view(b, c, kernel_size**2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size**2)
        )  # [B*H*W, C, (2R+1)^2]

        # gather the flow (value) window the same way
        flow_window = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, 2*(2R+1)^2), H*W]

        flow_window = (
            flow_window.view(b, 2, kernel_size**2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size**2, 2)
        )  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (
            c**0.5
        )  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = (
            torch.matmul(prob, flow_window)
            .view(b, h, w, 2)
            .permute(0, 3, 1, 2)
            .contiguous()
        )  # [B, 2, H, W]

        return out


def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    """Dense global matching: all-pairs correlation between feature0 and
    feature1, turned into flow via a softmax-weighted expectation of the
    target coordinate grid. Returns (flow [B, 2, H, W], matching prob).
    With pred_bidir_flow=True, the batch is doubled to [2*B, ...] with the
    forward flow first, backward flow second.
    """
    # global correlation
    b, c, h, w = feature0.shape
    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
        c**0.5
    )  # [B, H, W, H, W]

    # flow from softmax
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        # transposed correlation gives the backward (1->0) matching costs
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    # soft-argmax: expected target coordinate under the matching distribution
    correspondence = (
        torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    flow = correspondence - init_grid

    return flow, prob


def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    """Local matching: correlate each feature0 pixel with a
    (2R+1)x(2R+1) window in feature1 (R = local_radius) and take the
    softmax-weighted expected coordinate. Returns (flow [B, 2, H, W],
    matching prob [B, H*W, (2R+1)^2]).
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (
        sample_coords[:, :, :, 0] < w
    )  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (
        sample_coords[:, :, :, 1] < h
    )  # [B, H*W, (2R+1)^2]

    valid = (
        valid_x & valid_y
    )  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob


def coords_grid(b, h, w, homogeneous=False, device=None):
    """Pixel-coordinate grid, (x, y) order, as a float tensor [B, 2, H, W]
    (or [B, 3, H, W] with a ones row when homogeneous=True).

    NOTE(review): relies on torch.meshgrid's default 'ij' indexing (no
    `indexing=` argument passed); newer torch versions warn about this —
    behavior is unchanged but worth confirming against the installed version.
    """
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid


def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Regular (x, y) grid of len_h x len_w points spanning the given
    h/w ranges inclusive; returns [H, W, 2] on `device`."""
    assert device is not None

    x, y = torch.meshgrid(
        [
            torch.linspace(w_min, w_max, len_w, device=device),
            torch.linspace(h_min, h_max, len_h, device=device),
        ],
    )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid


def normalize_coords(coords, h, w):
    # coords: [B, H, W, 2]
    # map pixel coordinates to grid_sample's [-1, 1] range
    c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
    return (coords - c) / c  # [-1, 1]


def bilinear_sample(
    img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False
):
    """Sample `img` at absolute pixel coordinates via F.grid_sample
    (align_corners=True). With return_mask=True also returns a bool
    [B, H, W] mask of coordinates that fell inside the image."""
    # img: [B, C, H, W]
    # sample_coords: [B, 2, H, W] in image scale
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Normalize to [-1, 1]
    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(
        img, grid, mode=mode, padding_mode=padding_mode, align_corners=True
    )

    if return_mask:
        mask = (
            (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)
        )  # [B, H, W]

        return img, mask

    return img


def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    """Backward-warp `feature` by `flow`: output(p) = feature(p + flow(p))."""
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature, grid, padding_mode=padding_mode, return_mask=mask)


def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    """Occlusion detection: a pixel is occluded when forward flow and the
    warped backward flow disagree by more than alpha*|flow| + beta.
    Returns (fwd_occ, bwd_occ) as float [B, H, W] masks (1 = occluded)."""
    # fwd_flow, bwd_flow: [B, 2, H, W]
    # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    # consistent flows should cancel: fwd + warped bwd ≈ 0
    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +def split_feature( + feature, + num_splits=2, + channel_last=False, +): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c) + .permute(0, 1, 3, 2, 4, 5) + .reshape(b_new, h_new, w_new, c) + ) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * 
num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits) + .permute(0, 2, 4, 1, 3, 5) + .reshape(b_new, c, h_new, w_new) + ) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits( + splits, + num_splits=2, + channel_last=False, +): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = ( + splits.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(new_b, num_splits * h, num_splits * w, c) + ) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = ( + splits.permute(0, 3, 1, 4, 2, 5) + .contiguous() + .view(new_b, c, num_splits * h, num_splits * w) + ) # [B, C, H, W] + + return merge + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 - mean) / std + img1 = (img1 - mean) / std + + return img0, img1 + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + 
position + feature1 = feature1 + position + + return feature0, feature1 + + +class GMFlow(nn.Module): + def __init__( + self, + num_scales=2, + upsample_factor=4, + feature_channels=128, + attention_type="swin", + num_transformer_layers=6, + ffn_dim_expansion=4, + num_head=1, + **kwargs, + ): + super(GMFlow, self).__init__() + + self.num_scales = num_scales + self.feature_channels = feature_channels + self.upsample_factor = upsample_factor + self.attention_type = attention_type + self.num_transformer_layers = num_transformer_layers + + # CNN backbone + self.backbone = CNNEncoder( + output_dim=feature_channels, num_output_scales=num_scales + ) + + # Transformer + self.transformer = FeatureTransformer( + num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # flow propagation with self-attn + self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels) + + # convex upsampling: concat feature0 and flow as input + self.upsampler = nn.Sequential( + nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0), + ) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone( + concat + ) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow( + self, + flow, + feature, + bilinear=False, + upsample_factor=8, + ): + if bilinear: + up_flow = ( + F.interpolate( + flow, + scale_factor=upsample_factor, + mode="bilinear", + align_corners=True, + ) + * upsample_factor + ) + + else: + # convex upsampling + concat = 
torch.cat((flow, feature), dim=1) + + mask = self.upsampler(concat) + b, flow_channel, h, w = flow.shape + mask = mask.view( + b, 1, 9, self.upsample_factor, self.upsample_factor, h, w + ) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) + up_flow = up_flow.view( + b, flow_channel, 9, 1, 1, h, w + ) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape( + b, flow_channel, self.upsample_factor * h, self.upsample_factor * w + ) # [B, 2, K*H, K*W] + + return up_flow + + def forward( + self, + img0, + img1, + attn_splits_list=[2, 8], + corr_radius_list=[-1, 4], + prop_radius_list=[-1, 1], + pred_bidir_flow=False, + **kwargs, + ): + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # resolution low to high + feature0_list, feature1_list = self.extract_feature( + img0, img1 + ) # list of features + + flow = None + + assert ( + len(attn_splits_list) + == len(corr_radius_list) + == len(prop_radius_list) + == self.num_scales + ) + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat( + (feature1, feature0), dim=0 + ) + + upsample_factor = self.upsample_factor * ( + 2 ** (self.num_scales - 1 - scale_idx) + ) + + if scale_idx > 0: + flow = ( + F.interpolate( + flow, scale_factor=2, mode="bilinear", align_corners=True + ) + * 2 + ) + + if flow is not None: + flow = flow.detach() + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + + attn_splits = attn_splits_list[scale_idx] + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = 
feature_add_position( + feature0, feature1, attn_splits, self.feature_channels + ) + + # Transformer + feature0, feature1 = self.transformer( + feature0, feature1, attn_num_splits=attn_splits + ) + + # correlation and softmax + if corr_radius == -1: # global matching + flow_pred = global_correlation_softmax( + feature0, feature1, pred_bidir_flow + )[0] + else: # local matching + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[ + 0 + ] + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + # upsample to the original resolution for supervison + if ( + self.training + ): # only need to upsample intermediate flow predictions at training time + flow_bilinear = self.upsample_flow( + flow, None, bilinear=True, upsample_factor=upsample_factor + ) + + # flow propagation with self-attn + if pred_bidir_flow and scale_idx == 0: + feature0 = torch.cat( + (feature0, feature1), dim=0 + ) # [2*B, C, H, W] for propagation + flow = self.feature_flow_attn( + feature0, + flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius, + ) + + # bilinear upsampling at training time except the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow( + flow, feature0, bilinear=True, upsample_factor=upsample_factor + ) + + if scale_idx == self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0) + + return flow_up + + +backwarp_tenGrid = {} + + +def backwarp(tenIn, tenflow): + if str(tenflow.shape) not in backwarp_tenGrid: + tenHor = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenflow.shape[3], + dtype=tenflow.dtype, + device=tenflow.device, + ) + .view(1, 1, 1, -1) + .repeat(1, 1, tenflow.shape[2], 1) + ) + tenVer = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenflow.shape[2], + dtype=tenflow.dtype, + device=tenflow.device, + ) + .view(1, 1, -1, 1) + .repeat(1, 1, 1, tenflow.shape[3]) + ) + + backwarp_tenGrid[str(tenflow.shape)] = 
torch.cat([tenHor, tenVer], 1).to(get_torch_device()) + # end + + tenflow = torch.cat( + [ + tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), + tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + return torch.nn.functional.grid_sample( + input=tenIn, + grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + + +class MetricNet(nn.Module): + def __init__(self): + super(MetricNet, self).__init__() + self.metric_in = nn.Conv2d(14, 64, 3, 1, 1) + self.metric_net1 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1)) + self.metric_net2 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1)) + self.metric_net3 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1)) + self.metric_out = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 2, 3, 1, 1)) + + def forward(self, img0, img1, flow01, flow10): + metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction="none").mean( + [1], True + ) + metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction="none").mean( + [1], True + ) + + fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10) + + flow01 = torch.cat( + [ + flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0), + flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0), + ], + 1, + ) + flow10 = torch.cat( + [ + flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0), + flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + img = torch.cat((img0, img1), 1) + metric = torch.cat((-metric0, -metric1), 1) + flow = torch.cat((flow01, flow10), 1) + occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1) + + feat = self.metric_in(torch.cat((img, metric, flow, occ), 1)) + feat = self.metric_net1(feat) + feat + feat = self.metric_net2(feat) + feat + feat = self.metric_net3(feat) + feat + metric = self.metric_out(feat) + + metric = torch.tanh(metric) * 10 + + return metric[:, :1], metric[:, 1:2] + + +class 
FeatureNet(nn.Module): + """The quadratic model""" + + def __init__(self): + super(FeatureNet, self).__init__() + self.block1 = nn.Sequential( + nn.PReLU(), + nn.Conv2d(3, 64, 3, 2, 1), + nn.PReLU(), + nn.Conv2d(64, 64, 3, 1, 1), + ) + self.block2 = nn.Sequential( + nn.PReLU(), + nn.Conv2d(64, 128, 3, 2, 1), + nn.PReLU(), + nn.Conv2d(128, 128, 3, 1, 1), + ) + self.block3 = nn.Sequential( + nn.PReLU(), + nn.Conv2d(128, 192, 3, 2, 1), + nn.PReLU(), + nn.Conv2d(192, 192, 3, 1, 1), + ) + + def forward(self, x): + x1 = self.block1(x) + x2 = self.block2(x1) + x3 = self.block3(x2) + + return x1, x2, x3 + + +# Residual Block +def ResidualBlock(in_channels, out_channels, stride=1): + return torch.nn.Sequential( + nn.PReLU(), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + ), + nn.PReLU(), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + ), + ) + + +# downsample block +def DownsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + nn.PReLU(), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + ), + nn.PReLU(), + nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True + ), + ) + + +# upsample block +def UpsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + nn.PReLU(), + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=4, + stride=stride, + padding=1, + bias=True, + ), + nn.PReLU(), + nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True + ), + ) + + +class PixelShuffleBlcok(nn.Module): + def __init__(self, in_feat, num_feat, num_out_ch): + super(PixelShuffleBlcok, self).__init__() + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(in_feat, num_feat, 3, 1, 1), nn.PReLU() + ) + self.upsample = nn.Sequential( + nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), 
nn.PixelShuffle(2) + ) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + return x + + +# grid network +class GridNet(nn.Module): + def __init__( + self, + in_channels=12, + in_channels1=128, + in_channels2=256, + in_channels3=384, + out_channels=3, + ): + super(GridNet, self).__init__() + + self.residual_model_head = ResidualBlock(in_channels, 64) + self.residual_model_head1 = ResidualBlock(in_channels1, 64) + self.residual_model_head2 = ResidualBlock(in_channels2, 128) + self.residual_model_head3 = ResidualBlock(in_channels3, 192) + + self.residual_model_01 = ResidualBlock(64, 64) + # self.residual_model_02=ResidualBlock(64, 64) + # self.residual_model_03=ResidualBlock(64, 64) + self.residual_model_04 = ResidualBlock(64, 64) + self.residual_model_05 = ResidualBlock(64, 64) + self.residual_model_tail = PixelShuffleBlcok(64, 64, out_channels) + + self.residual_model_11 = ResidualBlock(128, 128) + # self.residual_model_12=ResidualBlock(128, 128) + # self.residual_model_13=ResidualBlock(128, 128) + self.residual_model_14 = ResidualBlock(128, 128) + self.residual_model_15 = ResidualBlock(128, 128) + + self.residual_model_21 = ResidualBlock(192, 192) + # self.residual_model_22=ResidualBlock(192, 192) + # self.residual_model_23=ResidualBlock(192, 192) + self.residual_model_24 = ResidualBlock(192, 192) + self.residual_model_25 = ResidualBlock(192, 192) + + # + + self.downsample_model_10 = DownsampleBlock(64, 128) + self.downsample_model_20 = DownsampleBlock(128, 192) + + self.downsample_model_11 = DownsampleBlock(64, 128) + self.downsample_model_21 = DownsampleBlock(128, 192) + + # self.downsample_model_12=DownsampleBlock(64, 128) + # self.downsample_model_22=DownsampleBlock(128, 192) + + # + + # self.upsample_model_03=UpsampleBlock(128, 64) + # self.upsample_model_13=UpsampleBlock(192, 128) + + self.upsample_model_04 = UpsampleBlock(128, 64) + 
self.upsample_model_14 = UpsampleBlock(192, 128) + + self.upsample_model_05 = UpsampleBlock(128, 64) + self.upsample_model_15 = UpsampleBlock(192, 128) + + def forward(self, x, x1, x2, x3): + X00 = self.residual_model_head(x) + self.residual_model_head1( + x1 + ) # --- 182 ~ 185 + # X10 = self.residual_model_head1(x1) + + X01 = self.residual_model_01(X00) + X00 # --- 208 ~ 211 ,AddBackward1213 + + X10 = self.downsample_model_10(X00) + self.residual_model_head2( + x2 + ) # --- 186 ~ 189 + X20 = self.downsample_model_20(X10) + self.residual_model_head3( + x3 + ) # --- 190 ~ 193 + + residual_11 = ( + self.residual_model_11(X10) + X10 + ) # 201 ~ 204 , sum AddBackward1206 + downsample_11 = self.downsample_model_11(X01) # 214 ~ 217 + X11 = residual_11 + downsample_11 # --- AddBackward1218 + + residual_21 = ( + self.residual_model_21(X20) + X20 + ) # 194 ~ 197 , sum AddBackward1199 + downsample_21 = self.downsample_model_21(X11) # 219 ~ 222 + X21 = residual_21 + downsample_21 # AddBackward1223 + + X24 = self.residual_model_24(X21) + X21 # --- 224 ~ 227 , AddBackward1229 + X25 = self.residual_model_25(X24) + X24 # --- 230 ~ 233 , AddBackward1235 + + upsample_14 = self.upsample_model_14(X24) # 242 ~ 246 + residual_14 = self.residual_model_14(X11) + X11 # 248 ~ 251, AddBackward1253 + X14 = upsample_14 + residual_14 # --- AddBackward1254 + + upsample_04 = self.upsample_model_04(X14) # 268 ~ 272 + residual_04 = self.residual_model_04(X01) + X01 # 274 ~ 277, AddBackward1279 + X04 = upsample_04 + residual_04 # --- AddBackward1280 + + upsample_15 = self.upsample_model_15(X25) # 236 ~ 240 + residual_15 = self.residual_model_15(X14) + X14 # 255 ~ 258, AddBackward1260 + X15 = upsample_15 + residual_15 # AddBackward1261 + + upsample_05 = self.upsample_model_05(X15) # 262 ~ 266 + residual_05 = self.residual_model_05(X04) + X04 # 281 ~ 284,AddBackward1286 + X05 = upsample_05 + residual_05 # AddBackward1287 + + X_tail = self.residual_model_tail(X05) # 288 ~ 291 + + return X_tail +# end 
+ +class Model: + def __init__(self): + self.flownet = GMFlow() + self.metricnet = MetricNet() + self.feat_ext = FeatureNet() + self.fusionnet = GridNet() + self.version = 3.9 + + def eval(self): + self.flownet.eval() + self.metricnet.eval() + self.feat_ext.eval() + self.fusionnet.eval() + + def device(self): + self.flownet.to(device) + self.metricnet.to(device) + self.feat_ext.to(device) + self.fusionnet.to(device) + + def load_model(self, path_dict): + #models/GMFSS_fortuna_flownet.pkl + self.flownet.load_state_dict(torch.load(path_dict["flownet"])) + #models/GMFSS_fortuna_metric.pkl + self.metricnet.load_state_dict(torch.load(path_dict["metricnet"])) + #models/GMFSS_fortuna_feat.pkl + self.feat_ext.load_state_dict(torch.load(path_dict["feat_ext"])) + #models/GMFSS_fortuna_fusionnet.pkl + self.fusionnet.load_state_dict(torch.load(path_dict["fusionnet"])) + + def reuse(self, img0, img1, scale): + feat11, feat12, feat13 = self.feat_ext(img0) + feat21, feat22, feat23 = self.feat_ext(img1) + + img0 = F.interpolate( + img0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + img1 = F.interpolate( + img1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + + if scale != 1.0: + imgf0 = F.interpolate( + img0, scale_factor=scale, mode="bilinear", align_corners=False + ) + imgf1 = F.interpolate( + img1, scale_factor=scale, mode="bilinear", align_corners=False + ) + else: + imgf0 = img0 + imgf1 = img1 + flow01 = self.flownet(imgf0, imgf1, return_flow=True) + flow10 = self.flownet(imgf1, imgf0, return_flow=True) + if scale != 1.0: + flow01 = ( + F.interpolate( + flow01, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + ) + / scale + ) + flow10 = ( + F.interpolate( + flow10, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + ) + / scale + ) + + metric0, metric1 = self.metricnet(img0, img1, flow01, flow10) + + return ( + flow01, + flow10, + metric0, + metric1, + feat11, + feat12, + feat13, + feat21, + feat22, + 
feat23, + ) + + def inference( + self, + img0, + img1, + flow01, + flow10, + metric0, + metric1, + feat11, + feat12, + feat13, + feat21, + feat22, + feat23, + timestep, + ): + F1t = timestep * flow01 + F2t = (1 - timestep) * flow10 + + Z1t = timestep * metric0 + Z2t = (1 - timestep) * metric1 + + img0 = F.interpolate( + img0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + I1t = softsplat(img0, F1t, Z1t, strMode="soft") + img1 = F.interpolate( + img1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + I2t = softsplat(img1, F2t, Z2t, strMode="soft") + + feat1t1 = softsplat(feat11, F1t, Z1t, strMode="soft") + feat2t1 = softsplat(feat21, F2t, Z2t, strMode="soft") + + F1td = ( + F.interpolate(F1t, scale_factor=0.5, mode="bilinear", align_corners=False) + * 0.5 + ) + Z1d = F.interpolate(Z1t, scale_factor=0.5, mode="bilinear", align_corners=False) + feat1t2 = softsplat(feat12, F1td, Z1d, strMode="soft") + F2td = ( + F.interpolate(F2t, scale_factor=0.5, mode="bilinear", align_corners=False) + * 0.5 + ) + Z2d = F.interpolate(Z2t, scale_factor=0.5, mode="bilinear", align_corners=False) + feat2t2 = softsplat(feat22, F2td, Z2d, strMode="soft") + + F1tdd = ( + F.interpolate(F1t, scale_factor=0.25, mode="bilinear", align_corners=False) + * 0.25 + ) + Z1dd = F.interpolate( + Z1t, scale_factor=0.25, mode="bilinear", align_corners=False + ) + feat1t3 = softsplat(feat13, F1tdd, Z1dd, strMode="soft") + F2tdd = ( + F.interpolate(F2t, scale_factor=0.25, mode="bilinear", align_corners=False) + * 0.25 + ) + Z2dd = F.interpolate( + Z2t, scale_factor=0.25, mode="bilinear", align_corners=False + ) + feat2t3 = softsplat(feat23, F2tdd, Z2dd, strMode="soft") + + out = self.fusionnet( + torch.cat([img0, I1t, I2t, img1], dim=1), + torch.cat([feat1t1, feat2t1], dim=1), + torch.cat([feat1t2, feat2t2], dim=1), + torch.cat([feat1t3, feat2t3], dim=1), + ) + + return torch.clamp(out, 0, 1) diff --git a/vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py 
b/vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py new file mode 100644 index 0000000000000000000000000000000000000000..41e92ddfa62fae5877eb12b43677d28ee9d3e29e --- /dev/null +++ b/vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py @@ -0,0 +1,23 @@ +import itertools +import numpy as np +import vapoursynth as vs +from .GMFSS_Fortuna_union_arch import Model_inference +import torch + + +class GMFSS_Fortuna_union: + def __init__(self): + self.cache = False + self.amount_input_img = 2 + + torch.set_grad_enabled(False) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + self.model = Model_inference() + self.model.eval() + + def execute(self, I0, I1, timestep): + with torch.inference_mode(): + middle = self.model(I0, I1, timestep).cpu() + return middle diff --git a/vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py b/vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..665e56534b647fe884d236693833f12060daa148 --- /dev/null +++ b/vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py @@ -0,0 +1,1857 @@ +""" +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/GMFSS_infer_u.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/softsplat.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FusionNet_u.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FeatureNet.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/MetricNet.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/IFNet_HDv3.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/gmflow.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/utils.py 
+https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/position.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/geometry.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/matching.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/transformer.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/backbone.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/trident_conv.py +https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/warplayer.py +""" + +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch +import math +from vfi_models.rife.rife_arch import IFNet +from vfi_models.ops import softsplat +from comfy.model_management import get_torch_device + +device = get_torch_device() +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device) + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample( + input=tenInput, + 
grid=g, + mode="bilinear", + padding_mode="border", + align_corners=True, + ) + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = ( + self.num_branch if self.training or self.test_branch_idx == -1 else 1 + ) + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d( + input, + self.weight, + self.bias, + stride, + padding, + self.dilation, + self.groups, + ) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + 
inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] + if self.test_branch_idx == -1 + else self.strides[-1], + self.paddings[self.test_branch_idx] + if self.test_branch_idx == -1 + else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs + + +class ResidualBlock_class(nn.Module): + def __init__( + self, + in_planes, + planes, + norm_layer=nn.InstanceNorm2d, + stride=1, + dilation=1, + ): + super(ResidualBlock_class, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, + planes, + kernel_size=3, + dilation=dilation, + padding=dilation, + stride=stride, + bias=False, + ) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + dilation=dilation, + padding=dilation, + bias=False, + ) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__( + self, + output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d( + 3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False + ) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = 
self._make_layer( + feature_dims[0], stride=1, norm_layer=norm_layer + ) # 1/2 + self.layer2 = self._make_layer( + feature_dims[1], stride=2, norm_layer=norm_layer + ) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer( + feature_dims[2], + stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv( + output_dim, + output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock_class( + self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation + ) + layer2 = ResidualBlock_class( + dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation + ) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / 
(q.size(2) ** 0.5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask( + input_resolution, + window_size_h, + window_size_w, + shift_size_h, + shift_size_w, + device=get_torch_device(), +): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = ( + slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None), + ) + w_slices = ( + slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature( + img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True + ) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + return attn_mask + + +def single_head_split_window_attention( + q, + k, + v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, +): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c**0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = 
torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature( + q, num_splits=num_splits, channel_last=True + ) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = ( + torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)) + / scale_factor + ) # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits( + out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, + channel_last=True, + ) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +class TransformerLayer(nn.Module): + def __init__( + self, + d_model=256, + nhead=1, + attention_type="swin", + no_ffn=False, + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.attention_type = attention_type + self.no_ffn = no_ffn + + self.with_shift = with_shift + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + 
+ self.norm2 = nn.LayerNorm(d_model) + + def forward( + self, + source, + target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if self.attention_type == "swin" and attn_num_splits > 1: + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention( + query, + key, + value, + num_splits=attn_num_splits, + with_shift=self.with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__( + self, + d_model=256, + nhead=1, + attention_type="swin", + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer( + d_model=d_model, + nhead=nhead, + attention_type=attention_type, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + self.cross_attn_ffn = TransformerLayer( + d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + def forward( + self, + source, + target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # 
source, target: [B, L, C] + + # self attention + source = self.self_attn( + source, + source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn( + source, + target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__( + self, + num_layers=6, + d_model=128, + nhead=1, + attention_type="swin", + ffn_dim_expansion=4, + **kwargs, + ): + super(FeatureTransformer, self).__init__() + + self.attention_type = attention_type + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList( + [ + TransformerBlock( + d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=True + if attention_type == "swin" and i % 2 == 1 + else False, + ) + for i in range(num_layers) + ] + ) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + feature0, + feature1, + attn_num_splits=None, + **kwargs, + ): + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + if self.attention_type == "swin" and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # concat feature0 and feature1 in batch dimension to compute in 
parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for layer in self.layers: + concat0 = layer( + concat0, + concat1, + height=h, + width=w, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = ( + feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() + ) # [B, C, H, W] + feature1 = ( + feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() + ) # [B, C, H, W] + + return feature0, feature1 + + +class FeatureFlowAttention(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__( + self, + in_channels, + **kwargs, + ): + super(FeatureFlowAttention, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward( + self, + feature0, + flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn( + feature0, flow, local_window_radius=local_window_radius + ) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] 
+ key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn( + self, + feature0, + flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + feature0_reshape = self.q_proj( + feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape( + b * h * w, 1, c + ) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = ( + self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)) + .permute(0, 2, 1) + .reshape(b, c, h, w) + ) + + feature0_window = F.unfold( + feature0_proj, kernel_size=kernel_size, padding=local_window_radius + ) # [B, C*(2R+1)^2), H*W] + + feature0_window = ( + feature0_window.view(b, c, kernel_size**2, h, w) + .permute(0, 3, 4, 1, 2) + .reshape(b * h * w, c, kernel_size**2) + ) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold( + flow, kernel_size=kernel_size, padding=local_window_radius + ) # [B, 2*(2R+1)^2), H*W] + + flow_window = ( + flow_window.view(b, 2, kernel_size**2, h, w) + .permute(0, 3, 4, 2, 1) + .reshape(b * h * w, kernel_size**2, 2) + ) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / ( + c**0.5 + ) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = ( + torch.matmul(prob, flow_window) + .view(b, h, w, 2) + .permute(0, 3, 1, 2) + .contiguous() + ) # [B, 2, H, W] + + return out + + +def global_correlation_softmax( + feature0, + feature1, + pred_bidir_flow=False, +): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + 
correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / ( + c**0.5 + ) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat( + (correlation, correlation.permute(0, 2, 1)), dim=0 + ) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = ( + torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) + ) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax( + feature0, + feature1, + local_radius, + padding_mode="zeros", +): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid( + -local_radius, + local_radius, + -local_radius, + local_radius, + local_h, + local_w, + device=feature0.device, + ) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & ( + sample_coords[:, :, :, 0] < w + ) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & ( + sample_coords[:, :, :, 1] < h + ) # [B, H*W, (2R+1)^2] + + valid = ( + valid_x & valid_y + ) # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # 
normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample( + feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True + ).permute( + 0, 2, 1, 3 + ) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / ( + c**0.5 + ) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = ( + torch.matmul(prob.unsqueeze(-2), sample_coords_softmax) + .squeeze(-2) + .view(b, h, w, 2) + .permute(0, 3, 1, 2) + ) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid( + [ + torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device), + ], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample( + img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False +): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, 
H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample( + img, grid, mode=mode, padding_mode=padding_mode, align_corners=True + ) + + if return_mask: + mask = ( + (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) + ) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode="zeros"): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +def split_feature( + feature, + num_splits=2, + channel_last=False, +): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c) + .permute(0, 1, 3, 2, 4, 5) + .reshape(b_new, h_new, w_new, c) + ) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * 
num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits) + .permute(0, 2, 4, 1, 3, 5) + .reshape(b_new, c, h_new, w_new) + ) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits( + splits, + num_splits=2, + channel_last=False, +): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = ( + splits.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(new_b, num_splits * h, num_splits * w, c) + ) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = ( + splits.permute(0, 3, 1, 4, 2, 5) + .contiguous() + .view(new_b, c, num_splits * h, num_splits * w) + ) # [B, C, H, W] + + return merge + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 - mean) / std + img1 = (img1 - mean) / std + + return img0, img1 + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + 
position + feature1 = feature1 + position + + return feature0, feature1 + + +class GMFlow(nn.Module): + def __init__( + self, + num_scales=2, + upsample_factor=4, + feature_channels=128, + attention_type="swin", + num_transformer_layers=6, + ffn_dim_expansion=4, + num_head=1, + **kwargs, + ): + super(GMFlow, self).__init__() + + self.num_scales = num_scales + self.feature_channels = feature_channels + self.upsample_factor = upsample_factor + self.attention_type = attention_type + self.num_transformer_layers = num_transformer_layers + + # CNN backbone + self.backbone = CNNEncoder( + output_dim=feature_channels, num_output_scales=num_scales + ) + + # Transformer + self.transformer = FeatureTransformer( + num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # flow propagation with self-attn + self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels) + + # convex upsampling: concat feature0 and flow as input + self.upsampler = nn.Sequential( + nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0), + ) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone( + concat + ) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow( + self, + flow, + feature, + bilinear=False, + upsample_factor=8, + ): + if bilinear: + up_flow = ( + F.interpolate( + flow, + scale_factor=upsample_factor, + mode="bilinear", + align_corners=True, + ) + * upsample_factor + ) + + else: + # convex upsampling + concat = 
torch.cat((flow, feature), dim=1) + + mask = self.upsampler(concat) + b, flow_channel, h, w = flow.shape + mask = mask.view( + b, 1, 9, self.upsample_factor, self.upsample_factor, h, w + ) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) + up_flow = up_flow.view( + b, flow_channel, 9, 1, 1, h, w + ) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape( + b, flow_channel, self.upsample_factor * h, self.upsample_factor * w + ) # [B, 2, K*H, K*W] + + return up_flow + + def forward( + self, + img0, + img1, + attn_splits_list=[2, 8], + corr_radius_list=[-1, 4], + prop_radius_list=[-1, 1], + pred_bidir_flow=False, + **kwargs, + ): + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # resolution low to high + feature0_list, feature1_list = self.extract_feature( + img0, img1 + ) # list of features + + flow = None + + assert ( + len(attn_splits_list) + == len(corr_radius_list) + == len(prop_radius_list) + == self.num_scales + ) + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat( + (feature1, feature0), dim=0 + ) + + upsample_factor = self.upsample_factor * ( + 2 ** (self.num_scales - 1 - scale_idx) + ) + + if scale_idx > 0: + flow = ( + F.interpolate( + flow, scale_factor=2, mode="bilinear", align_corners=True + ) + * 2 + ) + + if flow is not None: + flow = flow.detach() + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + + attn_splits = attn_splits_list[scale_idx] + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = 
feature_add_position( + feature0, feature1, attn_splits, self.feature_channels + ) + + # Transformer + feature0, feature1 = self.transformer( + feature0, feature1, attn_num_splits=attn_splits + ) + + # correlation and softmax + if corr_radius == -1: # global matching + flow_pred = global_correlation_softmax( + feature0, feature1, pred_bidir_flow + )[0] + else: # local matching + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[ + 0 + ] + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + # upsample to the original resolution for supervison + if ( + self.training + ): # only need to upsample intermediate flow predictions at training time + flow_bilinear = self.upsample_flow( + flow, None, bilinear=True, upsample_factor=upsample_factor + ) + + # flow propagation with self-attn + if pred_bidir_flow and scale_idx == 0: + feature0 = torch.cat( + (feature0, feature1), dim=0 + ) # [2*B, C, H, W] for propagation + flow = self.feature_flow_attn( + feature0, + flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius, + ) + + # bilinear upsampling at training time except the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow( + flow, feature0, bilinear=True, upsample_factor=upsample_factor + ) + + if scale_idx == self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0) + + return flow_up + + +backwarp_tenGrid = {} + + +def backwarp(tenIn, tenflow): + if str(tenflow.shape) not in backwarp_tenGrid: + tenHor = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenflow.shape[3], + dtype=tenflow.dtype, + device=tenflow.device, + ) + .view(1, 1, 1, -1) + .repeat(1, 1, tenflow.shape[2], 1) + ) + tenVer = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenflow.shape[2], + dtype=tenflow.dtype, + device=tenflow.device, + ) + .view(1, 1, -1, 1) + .repeat(1, 1, 1, tenflow.shape[3]) + ) + + backwarp_tenGrid[str(tenflow.shape)] = 
torch.cat([tenHor, tenVer], 1).to(get_torch_device()) + # end + + tenflow = torch.cat( + [ + tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), + tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + return torch.nn.functional.grid_sample( + input=tenIn, + grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + + +class MetricNet(nn.Module): + def __init__(self): + super(MetricNet, self).__init__() + self.metric_in = nn.Conv2d(14, 64, 3, 1, 1) + self.metric_net1 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1)) + self.metric_net2 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1)) + self.metric_net3 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1)) + self.metric_out = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 2, 3, 1, 1)) + + def forward(self, img0, img1, flow01, flow10): + metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction="none").mean( + [1], True + ) + metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction="none").mean( + [1], True + ) + + fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10) + + flow01 = torch.cat( + [ + flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0), + flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0), + ], + 1, + ) + flow10 = torch.cat( + [ + flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0), + flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + img = torch.cat((img0, img1), 1) + metric = torch.cat((-metric0, -metric1), 1) + flow = torch.cat((flow01, flow10), 1) + occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1) + + feat = self.metric_in(torch.cat((img, metric, flow, occ), 1)) + feat = self.metric_net1(feat) + feat + feat = self.metric_net2(feat) + feat + feat = self.metric_net3(feat) + feat + metric = self.metric_out(feat) + + metric = torch.tanh(metric) * 10 + + return metric[:, :1], metric[:, 1:2] + + +class 
FeatureNet(nn.Module): + """The quadratic model""" + + def __init__(self): + super(FeatureNet, self).__init__() + self.block1 = nn.Sequential( + nn.PReLU(), + nn.Conv2d(3, 64, 3, 2, 1), + nn.PReLU(), + nn.Conv2d(64, 64, 3, 1, 1), + ) + self.block2 = nn.Sequential( + nn.PReLU(), + nn.Conv2d(64, 128, 3, 2, 1), + nn.PReLU(), + nn.Conv2d(128, 128, 3, 1, 1), + ) + self.block3 = nn.Sequential( + nn.PReLU(), + nn.Conv2d(128, 192, 3, 2, 1), + nn.PReLU(), + nn.Conv2d(192, 192, 3, 1, 1), + ) + + def forward(self, x): + x1 = self.block1(x) + x2 = self.block2(x1) + x3 = self.block3(x2) + + return x1, x2, x3 + + +# Residual Block +def ResidualBlock(in_channels, out_channels, stride=1): + return torch.nn.Sequential( + nn.PReLU(), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + ), + nn.PReLU(), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + ), + ) + + +# downsample block +def DownsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + nn.PReLU(), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=True, + ), + nn.PReLU(), + nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True + ), + ) + + +# upsample block +def UpsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + nn.PReLU(), + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=4, + stride=stride, + padding=1, + bias=True, + ), + nn.PReLU(), + nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True + ), + ) + + +class PixelShuffleBlcok(nn.Module): + def __init__(self, in_feat, num_feat, num_out_ch): + super(PixelShuffleBlcok, self).__init__() + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(in_feat, num_feat, 3, 1, 1), nn.PReLU() + ) + self.upsample = nn.Sequential( + nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), 
nn.PixelShuffle(2) + ) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + return x + + +# grid network +class GridNet(nn.Module): + def __init__( + self, + in_channels=9, + in_channels1=128, + in_channels2=256, + in_channels3=384, + out_channels=3, + ): + super(GridNet, self).__init__() + + self.residual_model_head0 = ResidualBlock(in_channels, 64) + self.residual_model_head1 = ResidualBlock(in_channels1, 64) + self.residual_model_head2 = ResidualBlock(in_channels2, 128) + self.residual_model_head3 = ResidualBlock(in_channels3, 192) + + self.residual_model_01 = ResidualBlock(64, 64) + # self.residual_model_02=ResidualBlock(64, 64) + # self.residual_model_03=ResidualBlock(64, 64) + self.residual_model_04 = ResidualBlock(64, 64) + self.residual_model_05 = ResidualBlock(64, 64) + self.residual_model_tail = PixelShuffleBlcok(64, 64, out_channels) + + self.residual_model_11 = ResidualBlock(128, 128) + # self.residual_model_12=ResidualBlock(128, 128) + # self.residual_model_13=ResidualBlock(128, 128) + self.residual_model_14 = ResidualBlock(128, 128) + self.residual_model_15 = ResidualBlock(128, 128) + + self.residual_model_21 = ResidualBlock(192, 192) + # self.residual_model_22=ResidualBlock(192, 192) + # self.residual_model_23=ResidualBlock(192, 192) + self.residual_model_24 = ResidualBlock(192, 192) + self.residual_model_25 = ResidualBlock(192, 192) + + # + + self.downsample_model_10 = DownsampleBlock(64, 128) + self.downsample_model_20 = DownsampleBlock(128, 192) + + self.downsample_model_11 = DownsampleBlock(64, 128) + self.downsample_model_21 = DownsampleBlock(128, 192) + + # self.downsample_model_12=DownsampleBlock(64, 128) + # self.downsample_model_22=DownsampleBlock(128, 192) + + # + + # self.upsample_model_03=UpsampleBlock(128, 64) + # self.upsample_model_13=UpsampleBlock(192, 128) + + self.upsample_model_04 = UpsampleBlock(128, 64) + 
self.upsample_model_14 = UpsampleBlock(192, 128) + + self.upsample_model_05 = UpsampleBlock(128, 64) + self.upsample_model_15 = UpsampleBlock(192, 128) + + def forward(self, x, x1, x2, x3): + X00 = self.residual_model_head0(x) + self.residual_model_head1( + x1 + ) # --- 182 ~ 185 + # X10 = self.residual_model_head1(x1) + + X01 = self.residual_model_01(X00) + X00 # --- 208 ~ 211 ,AddBackward1213 + + X10 = self.downsample_model_10(X00) + self.residual_model_head2( + x2 + ) # --- 186 ~ 189 + X20 = self.downsample_model_20(X10) + self.residual_model_head3( + x3 + ) # --- 190 ~ 193 + + residual_11 = ( + self.residual_model_11(X10) + X10 + ) # 201 ~ 204 , sum AddBackward1206 + downsample_11 = self.downsample_model_11(X01) # 214 ~ 217 + X11 = residual_11 + downsample_11 # --- AddBackward1218 + + residual_21 = ( + self.residual_model_21(X20) + X20 + ) # 194 ~ 197 , sum AddBackward1199 + downsample_21 = self.downsample_model_21(X11) # 219 ~ 222 + X21 = residual_21 + downsample_21 # AddBackward1223 + + X24 = self.residual_model_24(X21) + X21 # --- 224 ~ 227 , AddBackward1229 + X25 = self.residual_model_25(X24) + X24 # --- 230 ~ 233 , AddBackward1235 + + upsample_14 = self.upsample_model_14(X24) # 242 ~ 246 + residual_14 = self.residual_model_14(X11) + X11 # 248 ~ 251, AddBackward1253 + X14 = upsample_14 + residual_14 # --- AddBackward1254 + + upsample_04 = self.upsample_model_04(X14) # 268 ~ 272 + residual_04 = self.residual_model_04(X01) + X01 # 274 ~ 277, AddBackward1279 + X04 = upsample_04 + residual_04 # --- AddBackward1280 + + upsample_15 = self.upsample_model_15(X25) # 236 ~ 240 + residual_15 = self.residual_model_15(X14) + X14 # 255 ~ 258, AddBackward1260 + X15 = upsample_15 + residual_15 # AddBackward1261 + + upsample_05 = self.upsample_model_05(X15) # 262 ~ 266 + residual_05 = self.residual_model_05(X04) + X04 # 281 ~ 284,AddBackward1286 + X05 = upsample_05 + residual_05 # AddBackward1287 + + X_tail = self.residual_model_tail(X05) # 288 ~ 291 + + return X_tail +# 
end + + +class Model: + def __init__(self): + self.flownet = GMFlow() + self.ifnet = IFNet(arch_ver="4.6") + self.metricnet = MetricNet() + self.feat_ext = FeatureNet() + self.fusionnet = GridNet() + self.version = 3.9 + + def eval(self): + self.flownet.eval() + self.ifnet.eval() + self.metricnet.eval() + self.feat_ext.eval() + self.fusionnet.eval() + + def device(self): + self.flownet.to(device) + self.ifnet.to(device) + self.metricnet.to(device) + self.feat_ext.to(device) + self.fusionnet.to(device) + + def load_model(self, path_dict): + #models/rife46.pth + self.ifnet.load_state_dict(torch.load(path_dict["ifnet"])) + #models/GMFSS_fortuna_flownet.pkl + self.flownet.load_state_dict(torch.load(path_dict["flownet"])) + #models/GMFSS_fortuna_union_metric.pkl + self.metricnet.load_state_dict(torch.load(path_dict["metricnet"])) + #models/GMFSS_fortuna_union_feat.pkl + self.feat_ext.load_state_dict(torch.load(path_dict["feat_ext"])) + #models/GMFSS_fortuna_union_fusionnet.pkl + self.fusionnet.load_state_dict(torch.load(path_dict["fusionnet"])) + + def reuse(self, img0, img1, scale): + feat11, feat12, feat13 = self.feat_ext(img0) + feat21, feat22, feat23 = self.feat_ext(img1) + + img0 = F.interpolate( + img0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + img1 = F.interpolate( + img1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + + if scale != 1.0: + imgf0 = F.interpolate( + img0, scale_factor=scale, mode="bilinear", align_corners=False + ) + imgf1 = F.interpolate( + img1, scale_factor=scale, mode="bilinear", align_corners=False + ) + else: + imgf0 = img0 + imgf1 = img1 + flow01 = self.flownet(imgf0, imgf1, return_flow=True) + flow10 = self.flownet(imgf1, imgf0, return_flow=True) + if scale != 1.0: + flow01 = ( + F.interpolate( + flow01, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + ) + / scale + ) + flow10 = ( + F.interpolate( + flow10, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + ) + / 
scale + ) + + metric0, metric1 = self.metricnet(img0, img1, flow01, flow10) + + return ( + flow01, + flow10, + metric0, + metric1, + feat11, + feat12, + feat13, + feat21, + feat22, + feat23, + ) + + def inference( + self, + img0, + img1, + flow01, + flow10, + metric0, + metric1, + feat11, + feat12, + feat13, + feat21, + feat22, + feat23, + timestep, + ): + F1t = timestep * flow01 + F2t = (1 - timestep) * flow10 + + Z1t = timestep * metric0 + Z2t = (1 - timestep) * metric1 + + img0 = F.interpolate( + img0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + I1t = softsplat(img0, F1t, Z1t, strMode="soft") + img1 = F.interpolate( + img1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + I2t = softsplat(img1, F2t, Z2t, strMode="soft") + + rife = self.ifnet(img0, img1, timestep, scale_list=[8, 4, 2, 1]) + + feat1t1 = softsplat(feat11, F1t, Z1t, strMode="soft") + feat2t1 = softsplat(feat21, F2t, Z2t, strMode="soft") + + F1td = ( + F.interpolate(F1t, scale_factor=0.5, mode="bilinear", align_corners=False) + * 0.5 + ) + Z1d = F.interpolate(Z1t, scale_factor=0.5, mode="bilinear", align_corners=False) + feat1t2 = softsplat(feat12, F1td, Z1d, strMode="soft") + F2td = ( + F.interpolate(F2t, scale_factor=0.5, mode="bilinear", align_corners=False) + * 0.5 + ) + Z2d = F.interpolate(Z2t, scale_factor=0.5, mode="bilinear", align_corners=False) + feat2t2 = softsplat(feat22, F2td, Z2d, strMode="soft") + + F1tdd = ( + F.interpolate(F1t, scale_factor=0.25, mode="bilinear", align_corners=False) + * 0.25 + ) + Z1dd = F.interpolate( + Z1t, scale_factor=0.25, mode="bilinear", align_corners=False + ) + feat1t3 = softsplat(feat13, F1tdd, Z1dd, strMode="soft") + F2tdd = ( + F.interpolate(F2t, scale_factor=0.25, mode="bilinear", align_corners=False) + * 0.25 + ) + Z2dd = F.interpolate( + Z2t, scale_factor=0.25, mode="bilinear", align_corners=False + ) + feat2t3 = softsplat(feat23, F2tdd, Z2dd, strMode="soft") + + out = self.fusionnet( + torch.cat([I1t, rife, I2t], dim=1), + 
torch.cat([feat1t1, feat2t1], dim=1), + torch.cat([feat1t2, feat2t2], dim=1), + torch.cat([feat1t3, feat2t3], dim=1), + ) + + return torch.clamp(out, 0, 1) diff --git a/vfi_models/gmfss_fortuna/__init__.py b/vfi_models/gmfss_fortuna/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e55bb606f30d7d42eeac118fe8832bd7493957f --- /dev/null +++ b/vfi_models/gmfss_fortuna/__init__.py @@ -0,0 +1,143 @@ +import pathlib +from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList +import typing +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.model_management import get_torch_device + + +GLOBAL_MODEL_TYPE = pathlib.Path(__file__).parent.name +CKPTS_PATH_CONFIG = { + "GMFSS_fortuna_union": { + "ifnet": ("rife", "rife46.pth"), + "flownet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_flownet.pkl"), + "metricnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_metric.pkl"), + "feat_ext": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_feat.pkl"), + "fusionnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_fusionnet.pkl") + }, + "GMFSS_fortuna": { + "flownet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_flownet.pkl"), + "metricnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_metric.pkl"), + "feat_ext": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_feat.pkl"), + "fusionnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_fusionnet.pkl") + } +} + +class CommonModelInference(nn.Module): + def __init__(self, model_type): + super(CommonModelInference, self).__init__() + from .GMFSS_Fortuna_arch import Model as GMFSS + from .GMFSS_Fortuna_union_arch import Model as GMFSS_Union + self.model = GMFSS_Union() if "union" in model_type else GMFSS() + self.model.eval() + self.model.device() + _model_path_config = CKPTS_PATH_CONFIG[model_type] + self.model.load_model({ + key: load_file_from_github_release(*_model_path_config[key]) + for key in _model_path_config + }) + + def forward(self, I0, I1, timestep, scale=1.0): + 
n, c, h, w = I0.shape + tmp = max(64, int(64 / scale)) + ph = ((h - 1) // tmp + 1) * tmp + pw = ((w - 1) // tmp + 1) * tmp + padding = (0, pw - w, 0, ph - h) + I0 = F.pad(I0, padding) + I1 = F.pad(I1, padding) + ( + flow01, + flow10, + metric0, + metric1, + feat11, + feat12, + feat13, + feat21, + feat22, + feat23, + ) = self.model.reuse(I0, I1, scale) + + output = self.model.inference( + I0, + I1, + flow01, + flow10, + metric0, + metric1, + feat11, + feat12, + feat13, + feat21, + feat22, + feat23, + timestep, + ) + return output[:, :, :h, :w] + +class GMFSS_Fortuna_VFI: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (list(CKPTS_PATH_CONFIG.keys()), ), + "frames": ("IMAGE", ), + "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}), + "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}), + }, + "optional": { + "optional_interpolation_states": ("INTERPOLATION_STATES", ) + } + } + + RETURN_TYPES = ("IMAGE", ) + FUNCTION = "vfi" + CATEGORY = "ComfyUI-Frame-Interpolation/VFI" + + def vfi( + self, + ckpt_name: typing.AnyStr, + frames: torch.Tensor, + clear_cache_after_n_frames = 10, + multiplier: typing.SupportsInt = 2, + optional_interpolation_states: InterpolationStateList = None, + **kwargs + ): + """ + Perform video frame interpolation using a given checkpoint model. + + Args: + ckpt_name (str): The name of the checkpoint model to use. + frames (torch.Tensor): A tensor containing input video frames. + clear_cache_after_n_frames (int, optional): The number of frames to process before clearing CUDA cache + to prevent memory overflow. Defaults to 10. Lower numbers are safer but mean more processing time. + How high you should set it depends on how many input frames there are, input resolution (after upscaling), + how many times you want to multiply them, and how long you're willing to wait for the process to complete. + multiplier (int, optional): The multiplier for each input frame. 
60 input frames * 2 = 120 output frames. Defaults to 2. + + Returns: + tuple: A tuple containing the output interpolated frames. + + Note: + This method interpolates frames in a video sequence using a specified checkpoint model. + It processes each frame sequentially, generating interpolated frames between them. + + To prevent memory overflow, it clears the CUDA cache after processing a specified number of frames. + """ + + interpolation_model = CommonModelInference(model_type=ckpt_name) + interpolation_model.eval().to(get_torch_device()) + frames = preprocess_frames(frames) + + def return_middle_frame(frame_0, frame_1, timestep, model, scale): + return model(frame_0, frame_1, timestep, scale) + + scale = 1 + + args = [interpolation_model, scale] + out = postprocess_frames( + generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args, + interpolation_states=optional_interpolation_states, dtype=torch.float32) + ) + return (out,) diff --git a/vfi_models/ifrnet/IFRNet_L_arch.py b/vfi_models/ifrnet/IFRNet_L_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..f9554a0bbfbe274ccfa021c618512f5220bbb220 --- /dev/null +++ b/vfi_models/ifrnet/IFRNet_L_arch.py @@ -0,0 +1,293 @@ +# https://github.com/ltkong218/IFRNet/blob/main/models/IFRNet_L.py +# https://github.com/ltkong218/IFRNet/blob/main/utils.py +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.model_management import get_torch_device + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat( + [ + flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), + flow[:, 1:2, :, :] / ((H - 1.0) / 2.0), + ], + 1, + ) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample( + input=img, + grid=grid_, + mode="bilinear", + 
padding_mode="border", + align_corners=True, + ) + return output + + +def get_robust_weight(flow_pred, flow_gt, beta): + epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=1, keepdim=True) ** 0.5 + robust_weight = torch.exp(-beta * epe) + return robust_weight + + +def resize(x, scale_factor): + return F.interpolate( + x, scale_factor=scale_factor, mode="bilinear", align_corners=False + ) + + +def convrelu( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=True, +): + return nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias=bias, + ), + nn.PReLU(out_channels), + ) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ), + nn.PReLU(in_channels), + ) + self.conv2 = nn.Sequential( + nn.Conv2d( + side_channels, + side_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ), + nn.PReLU(side_channels), + ) + self.conv3 = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ), + nn.PReLU(in_channels), + ) + self.conv4 = nn.Sequential( + nn.Conv2d( + side_channels, + side_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ), + nn.PReLU(side_channels), + ) + self.conv5 = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias + ) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + out[:, -self.side_channels :, :, :] = self.conv2( + out[:, -self.side_channels :, :, :] + ) + out = self.conv3(out) + out[:, -self.side_channels :, :, :] = self.conv4( + out[:, -self.side_channels :, :, :] + ) + out = self.prelu(x + self.conv5(out)) + return out + + +class Encoder(nn.Module): + def 
__init__(self): + super(Encoder, self).__init__() + self.pyramid1 = nn.Sequential( + convrelu(3, 64, 7, 2, 3), convrelu(64, 64, 3, 1, 1) + ) + self.pyramid2 = nn.Sequential( + convrelu(64, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1) + ) + self.pyramid3 = nn.Sequential( + convrelu(96, 144, 3, 2, 1), convrelu(144, 144, 3, 1, 1) + ) + self.pyramid4 = nn.Sequential( + convrelu(144, 192, 3, 2, 1), convrelu(192, 192, 3, 1, 1) + ) + + def forward(self, img): + f1 = self.pyramid1(img) + f2 = self.pyramid2(f1) + f3 = self.pyramid3(f2) + f4 = self.pyramid4(f3) + return f1, f2, f3, f4 + + +class Decoder4(nn.Module): + def __init__(self): + super(Decoder4, self).__init__() + self.convblock = nn.Sequential( + convrelu(384 + 1, 384), + ResBlock(384, 64), + nn.ConvTranspose2d(384, 148, 4, 2, 1, bias=True), + ) + + def forward(self, f0, f1, embt): + b, c, h, w = f0.shape + embt = embt.repeat(1, 1, h, w) + f_in = torch.cat([f0, f1, embt], 1) + f_out = self.convblock(f_in) + return f_out + + +class Decoder3(nn.Module): + def __init__(self): + super(Decoder3, self).__init__() + self.convblock = nn.Sequential( + convrelu(436, 432), + ResBlock(432, 64), + nn.ConvTranspose2d(432, 100, 4, 2, 1, bias=True), + ) + + def forward(self, ft_, f0, f1, up_flow0, up_flow1): + f0_warp = warp(f0, up_flow0) + f1_warp = warp(f1, up_flow1) + f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) + f_out = self.convblock(f_in) + return f_out + + +class Decoder2(nn.Module): + def __init__(self): + super(Decoder2, self).__init__() + self.convblock = nn.Sequential( + convrelu(292, 288), + ResBlock(288, 64), + nn.ConvTranspose2d(288, 68, 4, 2, 1, bias=True), + ) + + def forward(self, ft_, f0, f1, up_flow0, up_flow1): + f0_warp = warp(f0, up_flow0) + f1_warp = warp(f1, up_flow1) + f_in = torch.cat([ft_, f0_warp, f1_warp, up_flow0, up_flow1], 1) + f_out = self.convblock(f_in) + return f_out + + +class Decoder1(nn.Module): + def __init__(self): + super(Decoder1, self).__init__() + self.convblock = 
class IRFNet_L(nn.Module):
    """Large IFRNet: coarse-to-fine bidirectional flow estimation and frame synthesis."""

    def __init__(self):
        super(IRFNet_L, self).__init__()
        self.encoder = Encoder()
        self.decoder4 = Decoder4()
        self.decoder3 = Decoder3()
        self.decoder2 = Decoder2()
        self.decoder1 = Decoder1()

    def forward(self, img0, img1, scale_factor=1.0, timestep=0.5):
        """Predict the intermediate frame between img0 and img1 at `timestep`."""
        n, _, h, w = img0.shape

        # Pad both frames to the next multiple of 64 (four stride-2 stages).
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        pads = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, pads)
        img1 = F.pad(img1, pads)

        # One time-embedding scalar per batch element (supports multiple batches).
        embt = (
            torch.tensor([timestep] * n).view(n, 1, 1, 1).float().to(get_torch_device())
        )
        if "HalfTensor" in str(img0.type()):
            embt = embt.half()

        # Zero-center both frames with a shared mean; restored before output.
        mean_ = (
            torch.cat([img0, img1], 2)
            .mean(1, keepdim=True)
            .mean(2, keepdim=True)
            .mean(3, keepdim=True)
        )
        img0 = img0 - mean_
        img1 = img1 - mean_

        f0_1, f0_2, f0_3, f0_4 = self.encoder(resize(img0, scale_factor=scale_factor))
        f1_1, f1_2, f1_3, f1_4 = self.encoder(resize(img1, scale_factor=scale_factor))

        out4 = self.decoder4(f0_4, f1_4, embt)
        up_flow0, up_flow1, ft_ = out4[:, 0:2], out4[:, 2:4], out4[:, 4:]

        # Refine at each finer level; flow values double with resolution.
        for decoder, f0_l, f1_l in (
            (self.decoder3, f0_3, f1_3),
            (self.decoder2, f0_2, f1_2),
        ):
            out = decoder(ft_, f0_l, f1_l, up_flow0, up_flow1)
            up_flow0 = out[:, 0:2] + 2.0 * resize(up_flow0, scale_factor=2.0)
            up_flow1 = out[:, 2:4] + 2.0 * resize(up_flow1, scale_factor=2.0)
            ft_ = out[:, 4:]

        out1 = self.decoder1(ft_, f0_1, f1_1, up_flow0, up_flow1)
        up_flow0 = out1[:, 0:2] + 2.0 * resize(up_flow0, scale_factor=2.0)
        up_flow1 = out1[:, 2:4] + 2.0 * resize(up_flow1, scale_factor=2.0)
        up_mask = torch.sigmoid(out1[:, 4:5])
        up_res = out1[:, 5:]

        # Undo the initial spatial rescale; flow magnitudes rescale with it.
        inv = 1.0 / scale_factor
        up_flow0 = resize(up_flow0, scale_factor=inv) * inv
        up_flow1 = resize(up_flow1, scale_factor=inv) * inv
        up_mask = resize(up_mask, scale_factor=inv)
        up_res = resize(up_res, scale_factor=inv)

        warp0 = warp(img0, up_flow0)
        warp1 = warp(img1, up_flow1)
        merged = up_mask * warp0 + (1 - up_mask) * warp1 + mean_
        pred = torch.clamp(merged + up_res, 0, 1)
        return pred[:, :, :h, :w]
def warp(img, flow):
    """Backward-warp `img` by a pixel-space `flow` using bilinear grid_sample.

    flow is (B, 2, H, W) with x-displacement in channel 0 and y in channel 1;
    sampling uses border padding and align_corners=True.
    """
    B, _, H, W = flow.shape
    # Base sampling grid in normalized [-1, 1] coordinates.
    xs = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
    ys = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
    base_grid = torch.cat([xs, ys], 1).to(img)
    # Convert the pixel flow to the same normalized coordinate system.
    norm_flow = torch.cat(
        [
            flow[:, 0:1, :, :] / ((W - 1.0) / 2.0),
            flow[:, 1:2, :, :] / ((H - 1.0) / 2.0),
        ],
        1,
    )
    sample_grid = (base_grid + norm_flow).permute(0, 2, 3, 1)
    return F.grid_sample(
        input=img,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
def get_robust_weight(flow_pred, flow_gt, beta):
    """Exponentially down-weight pixels whose predicted flow deviates from GT."""
    epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=1, keepdim=True) ** 0.5
    return torch.exp(-beta * epe)


def resize(x, scale_factor):
    """Bilinear resize by a scale factor (align_corners=False)."""
    return F.interpolate(
        x, scale_factor=scale_factor, mode="bilinear", align_corners=False
    )


def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Conv2d followed by a channel-wise PReLU."""
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias=bias,
        ),
        nn.PReLU(out_channels),
    )


class ResBlock(nn.Module):
    """Residual block whose last `side_channels` channels get extra convolutions.

    FIX: replaced the original in-place slice assignment on the activation
    (breaks autograd during training) with an equivalent torch.cat rebuild;
    outputs are numerically identical.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.PReLU(in_channels),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                side_channels,
                side_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
            nn.PReLU(side_channels),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.PReLU(in_channels),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                side_channels,
                side_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
            nn.PReLU(side_channels),
        )
        self.conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        s = self.side_channels
        out = self.conv1(x)
        out = torch.cat([out[:, :-s], self.conv2(out[:, -s:])], 1)
        out = self.conv3(out)
        out = torch.cat([out[:, :-s], self.conv4(out[:, -s:])], 1)
        return self.prelu(x + self.conv5(out))


class Encoder(nn.Module):
    """Four-level feature pyramid for the small IFRNet (24/36/54/72 channels)."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.pyramid1 = nn.Sequential(
            convrelu(3, 24, 3, 2, 1), convrelu(24, 24, 3, 1, 1)
        )
        self.pyramid2 = nn.Sequential(
            convrelu(24, 36, 3, 2, 1), convrelu(36, 36, 3, 1, 1)
        )
        self.pyramid3 = nn.Sequential(
            convrelu(36, 54, 3, 2, 1), convrelu(54, 54, 3, 1, 1)
        )
        self.pyramid4 = nn.Sequential(
            convrelu(54, 72, 3, 2, 1), convrelu(72, 72, 3, 1, 1)
        )

    def forward(self, img):
        f1 = self.pyramid1(img)
        f2 = self.pyramid2(f1)
        f3 = self.pyramid3(f2)
        f4 = self.pyramid4(f3)
        return f1, f2, f3, f4
class Decoder4(nn.Module):
    """Coarsest decoder of the small IFRNet: initial flows + mid feature."""

    def __init__(self):
        super(Decoder4, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(144 + 1, 144),
            ResBlock(144, 24),
            nn.ConvTranspose2d(144, 58, 4, 2, 1, bias=True),
        )

    def forward(self, f0, f1, embt):
        _, _, h, w = f0.shape
        # Broadcast the scalar time embedding over the spatial grid.
        stacked = torch.cat([f0, f1, embt.repeat(1, 1, h, w)], 1)
        return self.convblock(stacked)


class Decoder3(nn.Module):
    """Level-3 refinement decoder of the small IFRNet."""

    def __init__(self):
        super(Decoder3, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(166, 162),
            ResBlock(162, 24),
            nn.ConvTranspose2d(162, 40, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        stacked = torch.cat([ft_, warped0, warped1, up_flow0, up_flow1], 1)
        return self.convblock(stacked)


class Decoder2(nn.Module):
    """Level-2 refinement decoder of the small IFRNet."""

    def __init__(self):
        super(Decoder2, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(112, 108),
            ResBlock(108, 24),
            nn.ConvTranspose2d(108, 28, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        stacked = torch.cat([ft_, warped0, warped1, up_flow0, up_flow1], 1)
        return self.convblock(stacked)


class Decoder1(nn.Module):
    """Finest decoder of the small IFRNet: flows, mask logit and residual."""

    def __init__(self):
        super(Decoder1, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(76, 72),
            ResBlock(72, 24),
            nn.ConvTranspose2d(72, 8, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        stacked = torch.cat([ft_, warped0, warped1, up_flow0, up_flow1], 1)
        return self.convblock(stacked)
class IRFNet_S(nn.Module):
    """Small IFRNet: coarse-to-fine bidirectional flow estimation and frame synthesis."""

    def __init__(self):
        super(IRFNet_S, self).__init__()
        self.encoder = Encoder()
        self.decoder4 = Decoder4()
        self.decoder3 = Decoder3()
        self.decoder2 = Decoder2()
        self.decoder1 = Decoder1()

    def forward(self, img0, img1, scale_factor=1.0, timestep=0.5):
        """Predict the intermediate frame between img0 and img1 at `timestep`."""
        n, _, h, w = img0.shape

        # Pad to the next multiple of 64 so four stride-2 stages divide evenly.
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        pads = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, pads)
        img1 = F.pad(img1, pads)

        # Per-batch time embedding (supports multiple batches).
        embt = (
            torch.tensor([timestep] * n).view(n, 1, 1, 1).float().to(get_torch_device())
        )
        if "HalfTensor" in str(img0.type()):
            embt = embt.half()

        # Shared mean normalization; added back when merging.
        mean_ = (
            torch.cat([img0, img1], 2)
            .mean(1, keepdim=True)
            .mean(2, keepdim=True)
            .mean(3, keepdim=True)
        )
        img0 = img0 - mean_
        img1 = img1 - mean_

        f0_1, f0_2, f0_3, f0_4 = self.encoder(resize(img0, scale_factor=scale_factor))
        f1_1, f1_2, f1_3, f1_4 = self.encoder(resize(img1, scale_factor=scale_factor))

        out4 = self.decoder4(f0_4, f1_4, embt)
        up_flow0, up_flow1, ft_ = out4[:, 0:2], out4[:, 2:4], out4[:, 4:]

        for decoder, f0_l, f1_l in (
            (self.decoder3, f0_3, f1_3),
            (self.decoder2, f0_2, f1_2),
        ):
            out = decoder(ft_, f0_l, f1_l, up_flow0, up_flow1)
            up_flow0 = out[:, 0:2] + 2.0 * resize(up_flow0, scale_factor=2.0)
            up_flow1 = out[:, 2:4] + 2.0 * resize(up_flow1, scale_factor=2.0)
            ft_ = out[:, 4:]

        out1 = self.decoder1(ft_, f0_1, f1_1, up_flow0, up_flow1)
        up_flow0 = out1[:, 0:2] + 2.0 * resize(up_flow0, scale_factor=2.0)
        up_flow1 = out1[:, 2:4] + 2.0 * resize(up_flow1, scale_factor=2.0)
        up_mask = torch.sigmoid(out1[:, 4:5])
        up_res = out1[:, 5:]

        # Undo the initial spatial rescale; flow magnitudes rescale with it.
        inv = 1.0 / scale_factor
        up_flow0 = resize(up_flow0, scale_factor=inv) * inv
        up_flow1 = resize(up_flow1, scale_factor=inv) * inv
        up_mask = resize(up_mask, scale_factor=inv)
        up_res = resize(up_res, scale_factor=inv)

        warp0 = warp(img0, up_flow0)
        warp1 = warp(img1, up_flow1)
        merged = up_mask * warp0 + (1 - up_mask) * warp1 + mean_
        pred = torch.clamp(merged + up_res, 0, 1)
        return pred[:, :, :h, :w]
MODEL_TYPE = pathlib.Path(__file__).parent.name
CKPT_NAMES = ["IFRNet_S_Vimeo90K.pth", "IFRNet_L_Vimeo90K.pth", "IFRNet_S_GoPro.pth", "IFRNet_L_GoPro.pth"]

class IFRNet_VFI:
    """ComfyUI node wrapping the IFRNet S/L frame-interpolation models."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
                "scale_factor": ([0.25, 0.5, 1.0, 2.0, 4.0], {"default": 1.0}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        scale_factor: typing.SupportsFloat = 1.0,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames` to `multiplier`x frame count with IFRNet."""
        from .IFRNet_S_arch import IRFNet_S
        from .IFRNet_L_arch import IRFNet_L
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        # Checkpoint name encodes the variant ("S" small / "L" large).
        interpolation_model = IRFNet_S() if 'S' in ckpt_name else IRFNet_L()
        interpolation_model.load_state_dict(torch.load(model_path))
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model, scale_factor):
            # BUG FIX: the original called model(frame_0, frame_1, timestep, scale_factor)
            # positionally, but IRFNet_S/L.forward is (img0, img1, scale_factor, timestep)
            # — timestep and scale_factor were swapped. Pass by keyword.
            return model(frame_0, frame_1, scale_factor=scale_factor, timestep=timestep)

        args = [interpolation_model, scale_factor]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d (with bias) followed by a channel-wise PReLU."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Bias-free Conv2d + BatchNorm + PReLU."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes),
    )


class DegCNN(nn.Module):
    """Degradation CNN: multi-scale encoder with heavy dropout, decoded to an RGB image."""

    def __init__(self):
        super(DegCNN, self).__init__()
        self.conv0 = conv(3, 32, 3, 2, 1)
        self.conv1 = conv(32, 32, 3, 2, 1)
        self.conv2 = conv(32, 32, 3, 2, 1)
        self.conv3 = conv(32, 32, 3, 2, 1)
        self.deconv = nn.Sequential(
            nn.Dropout2d(0.95),
            nn.ConvTranspose2d(4 * 32, 32, 4, 2, 1),
            nn.PReLU(32),
            nn.Conv2d(32, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        f0 = self.conv0(x)
        f1 = self.conv1(f0)
        f2 = self.conv2(f1)
        f3 = self.conv3(f2)
        # Upsample each deeper level back to f0's resolution before fusing.
        f1 = F.interpolate(f1, scale_factor=2.0, mode="bilinear", align_corners=False)
        f2 = F.interpolate(f2, scale_factor=4.0, mode="bilinear", align_corners=False)
        f3 = F.interpolate(f3, scale_factor=8.0, mode="bilinear", align_corners=False)
        return self.deconv(torch.cat((f0, f1, f2, f3), 1))


class FlowBlock(nn.Module):
    """Single-scale flow estimator: downsample, residual conv stack, upsampled flow+mask."""

    def __init__(self, in_planes, c=64):
        super(FlowBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv_bn(in_planes, c // 2, 3, 2, 1),
            conv_bn(c // 2, c, 3, 2, 1),
            conv_bn(c, 2 * c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            conv_bn(2 * c, 2 * c),
            conv_bn(2 * c, 2 * c),
            conv_bn(2 * c, 2 * c),
            conv_bn(2 * c, 2 * c),
            conv_bn(2 * c, 2 * c),
            conv_bn(2 * c, 2 * c),
        )
        self.lastconv = nn.ConvTranspose2d(2 * c, 4, 4, 2, 1)

    def forward(self, x, flow, scale=1):
        """Estimate (flow, mask) at input resolution, processing at 1/scale internally."""
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            # Flow values shrink with the spatial downscale.
            flow = (
                F.interpolate(
                    flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
                )
                * 1.0
                / scale
            )
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat) + feat
        raw = self.lastconv(feat)
        raw = F.interpolate(
            raw, scale_factor=scale * 4, mode="bilinear", align_corners=False
        )
        return raw[:, :2] * scale * 4, raw[:, 2:3]
class ResynNet(nn.Module):
    """Residual synthesis network: aligns each input frame to a (degraded) target and blends."""

    def __init__(self):
        super(ResynNet, self).__init__()
        self.block0 = FlowBlock(6, c=128)
        self.block1 = FlowBlock(12, c=128)
        self.block2 = FlowBlock(12, c=128)
        self.degrad = DegCNN()
        # Contextual refinement: per-frame context encoders + shared decoder.
        self.context0 = nn.Sequential(
            conv(3, 16, 3, 2, 1),
            conv(16, 32, 3, 2, 1),
        )
        self.context1 = nn.Sequential(
            conv(3, 16, 3, 2, 1),
            conv(16, 32, 3, 2, 1),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def calflow(self, img0, lowres, scale):
        """Iteratively estimate flow from img0 toward lowres, then refine the warp."""
        stages = (self.block0, self.block1, self.block2)
        flow = mask = warped = None
        for idx, stage in enumerate(stages):
            if flow is None:
                flow, mask = stage(torch.cat((img0, lowres), 1), None, scale=scale[idx])
            else:
                d_flow, d_mask = stage(
                    torch.cat((img0, lowres, warped, mask), 1),
                    flow,
                    scale=scale[idx],
                )
                flow = flow + d_flow
                mask = mask + d_mask
            warped = warp(img0, flow)
        flow_q = (
            F.interpolate(flow, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        ctx = torch.cat((warp(self.context0(img0), flow_q), self.context1(warped)), 1)
        warped = warped + self.decode(ctx)
        return flow, mask, torch.clamp(warped, 0, 1)

    def forward(
        self, x, deg=None, gt=None, scale=[4, 2, 1], training=False, blend=True
    ):
        """Blend every 3-channel frame in x, each aligned to `deg` (or to degrad(gt))."""
        if training:
            deg = self.degrad(gt)
            loss_cons = (gt - deg).abs().mean()
        else:
            # NOTE(review): created on CPU regardless of input device — confirm callers ignore it.
            loss_cons = torch.tensor([0])
        num = x.shape[1] // 3
        frames = [x[:, 3 * i : 3 * i + 3] for i in range(num)]
        mask_list = []
        warped_list = []
        m = None
        for frame in frames:
            # NOTE(review): deg must not be None here (deg.detach() would crash).
            f, m, aligned = self.calflow(frame, deg.detach(), scale)
            mask_list.append(m)
            warped_list.append(aligned)
        if blend:
            # Add the degraded target itself as an extra blend candidate.
            num += 1
            mask_list.append(m * 0)
            warped_list.append(deg)
        weights = F.softmax(torch.clamp(torch.cat(mask_list, 1), -4, 4), dim=1)
        merged = 0
        for i in range(num):
            merged = merged + warped_list[i] * weights[:, i : i + 1]
        return merged, loss_cons
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack `num_basic_block` instances of `basic_block(**kwarg)` in a Sequential."""
    return nn.Sequential(*(basic_block(**kwarg) for _ in range(num_basic_block)))


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block as used inside ESRGAN's RRDB.

    Args:
        num_feat: channel count of the block's input/output features.
        num_grow_ch: channels added by each dense layer.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # NOTE: the upstream default_init_weights (from basicsr) is deliberately
        # not applied here; weights are loaded from a checkpoint anyway.

    def forward(self, x):
        # Dense connectivity: each conv sees all previous feature maps.
        f1 = self.lrelu(self.conv1(x))
        f2 = self.lrelu(self.conv2(torch.cat((x, f1), 1)))
        f3 = self.lrelu(self.conv3(torch.cat((x, f1, f2), 1)))
        f4 = self.lrelu(self.conv4(torch.cat((x, f1, f2, f3), 1)))
        f5 = self.conv5(torch.cat((x, f1, f2, f3, f4), 1))
        # 0.2 residual scaling, per the ESRGAN authors.
        return f5 * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block (ESRGAN)."""

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        y = self.rdb3(self.rdb2(self.rdb1(x)))
        # 0.2 residual scaling, per the ESRGAN authors.
        return y * 0.2 + x


class RRDBNet(nn.Module):
    """ESRGAN-style RRDB trunk, used here to predict a per-pixel fusion mask.

    Input: two frames, their warped counterparts (12 channels at 1/4 scale)
    plus the flow (4 channels, scaled by 1/4) -> 16 input channels.
    Output: sigmoid mask with `num_out_ch` channels at the original resolution.
    """

    def __init__(
        self, num_in_ch=16, num_out_ch=1, num_feat=64, num_block=6, num_grow_ch=32
    ):
        super(RRDBNet, self).__init__()
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch
        )
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # Two nearest-neighbour 2x upsampling stages restore the 1/4 downscale.
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, img0, img1, warped_img0, warped_img1, flow):
        frames = torch.cat((img0, img1, warped_img0, warped_img1), 1)
        frames = F.interpolate(
            frames, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        flow_q = (
            F.interpolate(flow, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        feat = self.conv_first(torch.cat((frames, flow_q), 1))
        feat = feat + self.conv_body(self.body(feat))
        for up in (self.conv_up1, self.conv_up2):
            feat = self.lrelu(up(F.interpolate(feat, scale_factor=2.0, mode="nearest")))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return torch.sigmoid(out)
def warp(tenInput, tenFlow):
    """Backward-warp tenInput by the pixel-space flow tenFlow via grid_sample.

    The base sampling grid is cached per (device, flow size).
    NOTE(review): the cache is never evicted, so it grows with the number of
    distinct resolutions encountered.
    """
    key = (str(tenFlow.device), str(tenFlow.size()))
    if key not in backwarp_tenGrid:
        B, _, H, W = tenFlow.shape
        xs = (
            torch.linspace(-1.0, 1.0, W, device=device)
            .view(1, 1, 1, W)
            .expand(B, -1, H, -1)
        )
        ys = (
            torch.linspace(-1.0, 1.0, H, device=device)
            .view(1, 1, H, 1)
            .expand(B, -1, -1, W)
        )
        backwarp_tenGrid[key] = torch.cat([xs, ys], 1).to(device)

    # Convert the pixel flow into normalized [-1, 1] grid offsets.
    norm_flow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )
    sample_grid = (backwarp_tenGrid[key] + norm_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
class BasicConv(nn.Module):
    """Conv2d with optional BatchNorm and ReLU (CBAM building block)."""

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu=True,
        bn=True,
        bias=False,
    ):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.bn = (
            nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
            if bn
            else None
        )
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    """Flatten all dims but the batch dimension."""

    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    """CBAM channel attention: shared MLP over pooled descriptors, sigmoid gate.

    FIX: uses torch.sigmoid instead of the deprecated F.sigmoid (same values,
    no deprecation warning on modern PyTorch).
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"]):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == "avg":
                avg_pool = F.avg_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == "max":
                max_pool = F.max_pool2d(
                    x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == "lp":
                lp_pool = F.lp_pool2d(
                    x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
                )
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == "lse":
                # Log-sum-exp pooling (smooth max).
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
def logsumexp_2d(tensor):
    """Numerically-stable log-sum-exp over the spatial dims; returns (N, C, 1)."""
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    return s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()


class ChannelPool(nn.Module):
    """Stack channel-wise max and mean into a 2-channel map (N, 2, H, W)."""

    def forward(self, x):
        return torch.cat(
            (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
        )


class SpatialGate(nn.Module):
    """CBAM spatial attention: 7x7 conv over pooled channels, sigmoid gate.

    FIX: uses torch.sigmoid instead of the deprecated F.sigmoid (same values).
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )

    def forward(self, x):
        scale = torch.sigmoid(self.spatial(self.compress(x)))  # broadcasting
        return x * scale


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate + optional spatial gate."""

    def __init__(
        self,
        gate_channels,
        reduction_ratio=16,
        pool_types=["avg", "max"],
        no_spatial=False,
    ):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d (with bias) + channel-wise PReLU.

    NOTE(review): this shadows the identical `conv` defined earlier in this
    module — kept for fidelity with the upstream sources.
    """
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    )


class UNetConv(nn.Module):
    """Downsampling UNet stage: stride-2 conv, refine conv, optional CBAM."""

    def __init__(self, in_planes, out_planes, att=True):
        super(UNetConv, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, 2, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
        # CBAM reduction of 16 puts a floor of 128 on attention channel counts.
        self.cbam = CBAM(out_planes, 16) if att else None

    def forward(self, x):
        y = self.conv2(self.conv1(x))
        return y if self.cbam is None else self.cbam(y)


class UpConv(nn.Module):
    """Upsampling UNet stage: deconv, fuse with the skip, optional CBAM."""

    def __init__(self, in_planes, out_planes, att=True):
        super(UpConv, self).__init__()
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(in_planes, in_planes // 2, 4, 2, 1),
            nn.PReLU(in_planes // 2),
        )
        self.conv1 = conv(in_planes, in_planes // 2, 3, 1, 1)
        self.conv2 = conv(in_planes // 2, out_planes, 3, 1, 1)
        self.cbam = CBAM(out_planes, 16) if att else None

    def forward(self, x1, x2):
        up = self.deconv(x1)
        y = self.conv2(self.conv1(torch.cat((up, x2), 1)))
        return y if self.cbam is None else self.cbam(y)


class FeatureNet(nn.Module):
    """Shared UNet feature extractor with early exits per pyramid level.

    level=0 returns deconv5 output, level=1 adds deconv4, level=2 adds deconv3,
    matching the channel counts expected by IFBlock 0/1/2.
    """

    def __init__(self, in_planes, out_planes):
        super(FeatureNet, self).__init__()
        # Adapts 7-channel inputs (first IFBlock pass) to in_planes channels.
        self.conv0 = conv(7, in_planes, 1, 1, 0)

        self.conv1 = UNetConv(in_planes, out_planes // 8, att=False)
        self.conv2 = UNetConv(out_planes // 8, out_planes // 4, att=True)
        self.conv3 = UNetConv(out_planes // 4, out_planes // 2, att=True)
        self.conv4 = UNetConv(out_planes // 2, out_planes, att=True)
        self.conv5 = UNetConv(out_planes, 2 * out_planes, att=True)

        self.deconv5 = UpConv(2 * out_planes, out_planes, att=True)
        self.deconv4 = UpConv(out_planes, out_planes // 2, att=False)
        self.deconv3 = UpConv(out_planes // 2, out_planes // 4, att=False)

    def forward(self, x, level=0):
        # 17-channel inputs already match; everything else goes through conv0.
        if x.shape[1] != 17:
            x = self.conv0(x)
        x2 = self.conv1(x)
        x4 = self.conv2(x2)
        x8 = self.conv3(x4)
        x16 = self.conv4(x8)
        x32 = self.conv5(x16)
        y = self.deconv5(x32, x16)
        # Early exit: deeper levels stop here so one UNet serves all IFBlocks.
        if level != 0:
            y = self.deconv4(y, x8)
            if level == 2:
                y = self.deconv3(y, x4)
        return y
class IFBlock(nn.Module):
    """Flow block: residual conv stack, then convex upsampling of the coarse flow.

    `level` selects the convex-upsampling factor and must be 4, 8 or 16
    (the style follows RAFT's mask-weighted upsampling — see upsample_flow).
    """

    def __init__(self, c=64, level=0):
        super(IFBlock, self).__init__()
        self.convblock = nn.Sequential(
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
        )
        self.flowconv = nn.Conv2d(c, 4, 3, 1, 1)
        self.maskconvx16 = nn.Conv2d(c, 16 * 16 * 9, 1, 1, 0)
        self.maskconvx8 = nn.Conv2d(c, 8 * 8 * 9, 1, 1, 0)
        self.maskconvx4 = nn.Conv2d(c, 4 * 4 * 9, 1, 1, 0)

        self.level = level
        # FIX: validate with a descriptive exception instead of an assert with
        # an unprofessional message; asserts also vanish under `python -O`.
        if self.level not in (4, 8, 16):
            raise ValueError(f"level must be 4, 8 or 16, got {self.level!r}")

    def mask_conv(self, x):
        """Select the mask head matching the upsampling level."""
        heads = {4: self.maskconvx4, 8: self.maskconvx8, 16: self.maskconvx16}
        return heads[self.level](x)

    def upsample_flow(self, flow, mask):
        """Convex upsampling: each fine pixel is a softmax-weighted combination
        of the 3x3 coarse neighbourhood (flow scaled by the level factor)."""
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, self.level, self.level, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(self.level * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 4, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 4, self.level * H, self.level * W)

    def forward(self, x, scale):
        # ResNet-style f(x) + x over the conv stack.
        x = self.convblock(x) + x
        tmp = self.flowconv(x)
        up_mask = self.mask_conv(x)
        flow_up = self.upsample_flow(tmp, up_mask)
        # Restore the caller's working resolution; flow scales with it.
        flow = (
            F.interpolate(
                flow_up, scale_factor=scale, mode="bilinear", align_corners=False
            )
            * scale
        )
        return flow
class IFUNet(nn.Module):
    """Coarse-to-fine flow estimator: one shared UNet extractor + three IFBlocks."""

    def __init__(self):
        super(IFUNet, self).__init__()
        # block0's channel count must be a multiple of 128.
        self.fmap = FeatureNet(in_planes=17, out_planes=256)
        self.block0 = IFBlock(c=256, level=16)
        self.block1 = IFBlock(c=128, level=8)
        self.block2 = IFBlock(c=64, level=4)

    def forward(self, x, scale=1.0, timestep=0.5, ensemble=True):
        """Return (flow, warped_img0, warped_img1) for the channel-concatenated pair x.

        With ensemble=True, each stage also runs on the swapped pair at 1-timestep
        and averages the two flow estimates.
        """
        channel = x.shape[1] // 2
        img0 = x[:, :channel]
        img1 = x[:, channel:]
        # Accept either a scalar timestep or a (N,1,1,1) tensor; expand to a plane.
        if not torch.is_tensor(timestep):
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            if flow is not None:  # FIX: was `flow != None` (non-idiomatic None test)
                x = torch.cat((img0, img1, timestep, warped_img0, warped_img1), 1)
                flowtmp = flow
                if scale != 1:
                    x = F.interpolate(
                        x, scale_factor=scale, mode="bilinear", align_corners=False
                    )
                    flowtmp = (
                        F.interpolate(
                            flow,
                            scale_factor=scale,
                            mode="bilinear",
                            align_corners=False,
                        )
                        * scale
                    )
                x = torch.cat((x, flowtmp), 1)
                Fmap = self.fmap(x, level=i)
                flow_d = block[i](Fmap, scale=1.0 / scale)
                flow = flow + flow_d

                if ensemble:
                    # Re-estimate with swapped frames and mirrored timestep.
                    x = torch.cat(
                        (img1, img0, 1 - timestep, warped_img0, warped_img1), 1
                    )
                    flowtmp = flow
                    if scale != 1:
                        x = F.interpolate(
                            x, scale_factor=scale, mode="bilinear", align_corners=False
                        )
                        flowtmp = (
                            F.interpolate(
                                flow,
                                scale_factor=scale,
                                mode="bilinear",
                                align_corners=False,
                            )
                            * scale
                        )
                    x = torch.cat((x, flowtmp), 1)
                    Fmap = self.fmap(x, level=i)
                    flow_d = block[i](Fmap, scale=1.0 / scale)
                    flow2 = flow + flow_d
                    flow = (flow + flow2) / 2
            else:
                x = torch.cat((img0, img1, timestep), 1)
                if scale != 1:
                    x = F.interpolate(
                        x, scale_factor=scale, mode="bilinear", align_corners=False
                    )
                Fmap = self.fmap(x, level=i)
                flow = block[i](Fmap, scale=1.0 / scale)

                if ensemble:
                    x = torch.cat((img1, img0, 1 - timestep), 1)
                    if scale != 1:
                        x = F.interpolate(
                            x, scale_factor=scale, mode="bilinear", align_corners=False
                        )
                    Fmap = self.fmap(x, level=i)
                    flow2 = block[i](Fmap, scale=1.0 / scale)
                    flow = (flow + flow2) / 2

            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
        return flow, warped_img0, warped_img1
F.interpolate( + x, scale_factor=scale, mode="bilinear", align_corners=False + ) + Fmap = self.fmap(x, level=i) + flow = block[i](Fmap, scale=1.0 / scale) + + if ensemble: + x = torch.cat((img1, img0, 1 - timestep), 1) + if scale != 1: + x = F.interpolate( + x, scale_factor=scale, mode="bilinear", align_corners=False + ) + Fmap = self.fmap(x, level=i) + flow2 = block[i](Fmap, scale=1.0 / scale) + flow = (flow + flow2) / 2 + + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + return flow, warped_img0, warped_img1 + + +class IFUNetModel(nn.Module): + def __init__(self, local_rank=-1): + super(IFUNetModel, self).__init__() + self.flownet = IFUNet() + self.fusionnet = RRDBNet() + self.refinenet = ResynNet() + + def forward(self, img0, img1, timestep=0.5, scale=1.0, ensemble=False): + n, c, h, w = img0.shape + ph = ((h - 1) // 64 + 1) * 64 + pw = ((w - 1) // 64 + 1) * 64 + padding = (0, pw - w, 0, ph - h) + img0 = F.pad(img0, padding) + img1 = F.pad(img1, padding) + + imgs = torch.cat((img0, img1), 1) + flow, warped_img0, warped_img1 = self.flownet(imgs, scale, timestep, ensemble) + mask = self.fusionnet(img0, img1, warped_img0, warped_img1, flow) + merged = warped_img0 * mask + warped_img1 * (1 - mask) + merged, _ = self.refinenet(imgs, deg=merged, scale=[4, 2, 1]) + return merged[:, :, :h, :w] diff --git a/vfi_models/ifunet/__init__.py b/vfi_models/ifunet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec98675f6cc459a026a0837c90915016c058d79 --- /dev/null +++ b/vfi_models/ifunet/__init__.py @@ -0,0 +1,59 @@ +import torch +from torch.utils.data import DataLoader +import pathlib +from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList +import typing +from comfy.model_management import get_torch_device + +MODEL_TYPE = pathlib.Path(__file__).parent.name +CKPT_NAMES = ["IFUNet.pth"] + +class IFUnet_VFI: + @classmethod + def 
INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (CKPT_NAMES, ), + "frames": ("IMAGE", ), + "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}), + "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}), + "scale_factor": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 100, "step": 0.1}), + "ensemble": ("BOOLEAN", {"default":True}) + }, + "optional": { + "optional_interpolation_states": ("INTERPOLATION_STATES", ) + } + } + + RETURN_TYPES = ("IMAGE", ) + FUNCTION = "vfi" + CATEGORY = "ComfyUI-Frame-Interpolation/VFI" + + def vfi( + self, + ckpt_name: typing.AnyStr, + frames: torch.Tensor, + clear_cache_after_n_frames: typing.SupportsInt = 1, + multiplier: typing.SupportsInt = 2, + scale_factor: typing.SupportsFloat = 1.0, + ensemble: bool = True, + optional_interpolation_states: InterpolationStateList = None, + **kwargs + ): + from .IFUNet_arch import IFUNetModel + model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name) + interpolation_model = IFUNetModel() + interpolation_model.load_state_dict(torch.load(model_path)) + interpolation_model.eval().to(get_torch_device()) + frames = preprocess_frames(frames) + + def return_middle_frame(frame_0, frame_1, timestep, model, scale_factor, ensemble): + return model(frame_0, frame_1, timestep=timestep, scale=scale_factor, ensemble=ensemble) + + args = [interpolation_model, scale_factor, ensemble] + out = postprocess_frames( + generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args, + interpolation_states=optional_interpolation_states, dtype=torch.float32) + ) + return (out,) + diff --git a/vfi_models/m2m/M2M_arch.py b/vfi_models/m2m/M2M_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..536f915efa4678bd55884852ad6e3ea386fd3ffe --- /dev/null +++ b/vfi_models/m2m/M2M_arch.py @@ -0,0 +1,1037 @@ +""" +https://github.com/feinanshan/M2M_VFI/blob/main/Test/model/py 
+https://raw.githubusercontent.com/feinanshan/M2M_VFI/main/Test/model/py +https://github.com/feinanshan/M2M_VFI/blob/main/Test/model/py +https://github.com/feinanshan/M2M_VFI/blob/main/Test/model/py +https://github.com/feinanshan/M2M_VFI/blob/main/Test/model/m2m.py +""" + +import collections +import math +import os +import re +import torch +import typing +from vfi_models.ops import softsplat_func +from vfi_models.ops import costvol_func + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn: torch.Tensor, tenFlow: torch.Tensor): + if ( + "grid" + + str(tenFlow.dtype) + + str(tenFlow.device) + + str(tenFlow.shape[2]) + + str(tenFlow.shape[3]) + not in objBackwarpcache + ): + tenHor = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenFlow.shape[3], + dtype=tenFlow.dtype, + device=tenFlow.device, + ) + .view(1, 1, 1, -1) + .repeat(1, 1, tenFlow.shape[2], 1) + ) + tenVer = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenFlow.shape[2], + dtype=tenFlow.dtype, + device=tenFlow.device, + ) + .view(1, 1, -1, 1) + .repeat(1, 1, 1, tenFlow.shape[3]) + ) + + objBackwarpcache[ + "grid" + + str(tenFlow.dtype) + + str(tenFlow.device) + + str(tenFlow.shape[2]) + + str(tenFlow.shape[3]) + ] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor( + data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], + dtype=tenFlow.dtype, + device=tenFlow.device, + ).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample( + input=tenIn, + grid=( + objBackwarpcache[ + "grid" + + str(tenFlow.dtype) + + str(tenFlow.device) + + str(tenFlow.shape[2]) + + str(tenFlow.shape[3]) + ] + + tenFlow + ).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + + +# end + 
+########################################################## + + +class Basic(torch.nn.Module): + def __init__( + self, + strType: str, + intChans: typing.List[int], + objScratch: typing.Optional[typing.Dict] = None, + ): + super().__init__() + + self.strType = strType + self.netEvenize = None + self.netMain = None + self.netShortcut = None + + intIn = intChans[0] + intOut = intChans[-1] + netMain = [] + intChans = intChans.copy() + fltStride = 1.0 + + for intPart, strPart in enumerate(self.strType.split("+")[0].split("-")): + if strPart.startswith("evenize") == True and intPart == 0: + + class Evenize(torch.nn.Module): + def __init__(self, strPad): + super().__init__() + + self.strPad = strPad + + # end + + def forward(self, tenIn: torch.Tensor) -> torch.Tensor: + intPad = [0, 0, 0, 0] + + if tenIn.shape[3] % 2 != 0: + intPad[1] = 1 + if tenIn.shape[2] % 2 != 0: + intPad[3] = 1 + + if min(intPad) != 0 or max(intPad) != 0: + tenIn = torch.nn.functional.pad( + input=tenIn, + pad=intPad, + mode=self.strPad + if self.strPad != "zeros" + else "constant", + value=0.0, + ) + # end + + return tenIn + + # end + + # end + + strPad = "zeros" + + if "(" in strPart: + if "replpad" in strPart.split("(")[1].split(")")[0].split(","): + strPad = "replicate" + if "reflpad" in strPart.split("(")[1].split(")")[0].split(","): + strPad = "reflect" + # end + + self.netEvenize = Evenize(strPad) + + elif strPart.startswith("conv") == True: + intKsize = 3 + intPad = 1 + strPad = "zeros" + + if "(" in strPart: + intKsize = int(strPart.split("(")[1].split(")")[0].split(",")[0]) + intPad = int(math.floor(0.5 * (intKsize - 1))) + + if "replpad" in strPart.split("(")[1].split(")")[0].split(","): + strPad = "replicate" + if "reflpad" in strPart.split("(")[1].split(")")[0].split(","): + strPad = "reflect" + # end + + if "nopad" in self.strType.split("+"): + intPad = 0 + # end + + netMain += [ + torch.nn.Conv2d( + in_channels=intChans[0], + out_channels=intChans[1], + kernel_size=intKsize, + 
stride=1, + padding=intPad, + padding_mode=strPad, + bias="nobias" not in self.strType.split("+"), + ) + ] + intChans = intChans[1:] + fltStride *= 1.0 + + elif strPart.startswith("sconv") == True: + intKsize = 3 + intPad = 1 + strPad = "zeros" + + if "(" in strPart: + intKsize = int(strPart.split("(")[1].split(")")[0].split(",")[0]) + intPad = int(math.floor(0.5 * (intKsize - 1))) + + if "replpad" in strPart.split("(")[1].split(")")[0].split(","): + strPad = "replicate" + if "reflpad" in strPart.split("(")[1].split(")")[0].split(","): + strPad = "reflect" + # end + + if "nopad" in self.strType.split("+"): + intPad = 0 + # end + + netMain += [ + torch.nn.Conv2d( + in_channels=intChans[0], + out_channels=intChans[1], + kernel_size=intKsize, + stride=2, + padding=intPad, + padding_mode=strPad, + bias="nobias" not in self.strType.split("+"), + ) + ] + intChans = intChans[1:] + fltStride *= 2.0 + + elif strPart.startswith("up") == True: + + class Up(torch.nn.Module): + def __init__(self, strType): + super().__init__() + + self.strType = strType + + # end + + def forward(self, tenIn: torch.Tensor) -> torch.Tensor: + if self.strType == "nearest": + return torch.nn.functional.interpolate( + input=tenIn, + scale_factor=2.0, + mode="nearest-exact", + align_corners=False, + ) + + elif self.strType == "bilinear": + return torch.nn.functional.interpolate( + input=tenIn, + scale_factor=2.0, + mode="bilinear", + align_corners=False, + ) + + elif self.strType == "pyramid": + return pyramid(tenIn, None, "up") + + elif self.strType == "shuffle": + return torch.nn.functional.pixel_shuffle( + tenIn, upscale_factor=2 + ) # https://github.com/pytorch/pytorch/issues/62854 + + # end + + assert False # to make torchscript happy + + # end + + # end + + strType = "bilinear" + + if "(" in strPart: + if "nearest" in strPart.split("(")[1].split(")")[0].split(","): + strType = "nearest" + if "pyramid" in strPart.split("(")[1].split(")")[0].split(","): + strType = "pyramid" + if "shuffle" in 
strPart.split("(")[1].split(")")[0].split(","): + strType = "shuffle" + # end + + netMain += [Up(strType)] + fltStride *= 0.5 + + elif strPart.startswith("prelu") == True: + netMain += [ + torch.nn.PReLU( + num_parameters=1, + init=float(strPart.split("(")[1].split(")")[0].split(",")[0]), + ) + ] + fltStride *= 1.0 + + elif True: + assert False + + # end + # end + + self.netMain = torch.nn.Sequential(*netMain) + + for strPart in self.strType.split("+")[1:]: + if strPart.startswith("skip") == True: + if intIn == intOut and fltStride == 1.0: + self.netShortcut = torch.nn.Identity() + + elif intIn != intOut and fltStride == 1.0: + self.netShortcut = torch.nn.Conv2d( + in_channels=intIn, + out_channels=intOut, + kernel_size=1, + stride=1, + padding=0, + bias="nobias" not in self.strType.split("+"), + ) + + elif intIn == intOut and fltStride != 1.0: + + class Down(torch.nn.Module): + def __init__(self, fltScale): + super().__init__() + + self.fltScale = fltScale + + # end + + def forward(self, tenIn: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.interpolate( + input=tenIn, + scale_factor=self.fltScale, + mode="bilinear", + align_corners=False, + ) + + # end + + # end + + self.netShortcut = Down(1.0 / fltStride) + + elif intIn != intOut and fltStride != 1.0: + + class Down(torch.nn.Module): + def __init__(self, fltScale): + super().__init__() + + self.fltScale = fltScale + + # end + + def forward(self, tenIn: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.interpolate( + input=tenIn, + scale_factor=self.fltScale, + mode="bilinear", + align_corners=False, + ) + + # end + + # end + + self.netShortcut = torch.nn.Sequential( + Down(1.0 / fltStride), + torch.nn.Conv2d( + in_channels=intIn, + out_channels=intOut, + kernel_size=1, + stride=1, + padding=0, + bias="nobias" not in self.strType.split("+"), + ), + ) + + # end + + elif strPart.startswith("...") == True: + pass + + # end + # end + + assert len(intChans) == 1 + + # end + + def forward(self, 
tenIn: torch.Tensor) -> torch.Tensor: + if self.netEvenize is not None: + tenIn = self.netEvenize(tenIn) + # end + + tenOut = self.netMain(tenIn) + + if self.netShortcut is not None: + tenOut = tenOut + self.netShortcut(tenIn) + # end + + return tenOut + + # end + + +# end + + +########################################################## + + +class Network(torch.nn.Module): + def __init__(self): + super().__init__() + + class Extractor(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netOne = Basic( + "evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)", + [3, 32, 32, 32], + None, + ) + self.netTwo = Basic( + "evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)", + [32, 32, 32, 32], + None, + ) + self.netThr = Basic( + "evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)", + [32, 32, 32, 32], + None, + ) + + # end + + def forward(self, tenIn): + tenOne = self.netOne(tenIn) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = torch.nn.functional.avg_pool2d( + input=tenThr, kernel_size=2, stride=2, count_include_pad=False + ) + tenFiv = torch.nn.functional.avg_pool2d( + input=tenFou, kernel_size=2, stride=2, count_include_pad=False + ) + + return [tenOne, tenTwo, tenThr, tenFou, tenFiv] + + # end + + # end + + class Decoder(torch.nn.Module): + def __init__(self, intChannels): + super().__init__() + + self.netCostacti = torch.nn.PReLU(num_parameters=1, init=0.25) + self.netMain = Basic( + "conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)", + [intChannels, 128, 128, 96, 64, 32, 2], + None, + ) + + # end + + def forward(self, tenOne, tenTwo, tenFlow): + if tenFlow is not None: + tenFlow = 2.0 * torch.nn.functional.interpolate( + input=tenFlow, + scale_factor=2.0, + mode="bilinear", + 
align_corners=False, + ) + # end + + tenMain = [] + + if tenFlow is None: + tenMain.append(tenOne) + tenMain.append(self.netCostacti(costvol_func.apply(tenOne, tenTwo))) + + elif tenFlow is not None: + tenMain.append(tenOne) + tenMain.append( + self.netCostacti( + costvol_func.apply( + tenOne, backwarp(tenTwo, tenFlow.detach()) + ) + ) + ) + tenMain.append(tenFlow) + + # end + + return (tenFlow if tenFlow is not None else 0.0) + self.netMain( + torch.cat(tenMain, 1) + ) + + # end + + # end + + self.netExtractor = Extractor() + + self.netFiv = Decoder(32 + 81 + 0) + self.netFou = Decoder(32 + 81 + 2) + self.netThr = Decoder(32 + 81 + 2) + self.netTwo = Decoder(32 + 81 + 2) + self.netOne = Decoder(32 + 81 + 2) + + # end + + def bidir(self, tenOne, tenTwo): + tenOne, tenTwo = list( + zip( + *[ + torch.split(tenFeat, [tenOne.shape[0], tenTwo.shape[0]], 0) + for tenFeat in self.netExtractor(torch.cat([tenOne, tenTwo], 0)) + ] + ) + ) + + tenFwd = None + tenFwd = self.netFiv(tenOne[-1], tenTwo[-1], tenFwd) + tenFwd = self.netFou(tenOne[-2], tenTwo[-2], tenFwd) + tenFwd = self.netThr(tenOne[-3], tenTwo[-3], tenFwd) + tenFwd = self.netTwo(tenOne[-4], tenTwo[-4], tenFwd) + tenFwd = self.netOne(tenOne[-5], tenTwo[-5], tenFwd) + + tenBwd = None + tenBwd = self.netFiv(tenTwo[-1], tenOne[-1], tenBwd) + tenBwd = self.netFou(tenTwo[-2], tenOne[-2], tenBwd) + tenBwd = self.netThr(tenTwo[-3], tenOne[-3], tenBwd) + tenBwd = self.netTwo(tenTwo[-4], tenOne[-4], tenBwd) + tenBwd = self.netOne(tenTwo[-5], tenOne[-5], tenBwd) + + return tenFwd, tenBwd + + # end + + +# end + +########################################################## + + +def forwarp_mframe_mask( + tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None +): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat( + [ + tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), + td * (tenMetric).clip(-20.0, 20.0).exp(), + ], + 1, + ) + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + return 
tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOut = 0 + tenNormalize = 0 + for idx in range(flow_num): + tenOutF, tenNormalizeF = one_fdir( + tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx] + ) + tenOutB, tenNormalizeB = one_fdir( + tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx] + ) + + tenOut += tenOutF + tenOutB + tenNormalize += tenNormalizeF + tenNormalizeB + + return tenOut / tenNormalize, tenNormalize < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + torch.nn.PReLU(out_planes), + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d( + in_channels=in_planes, + out_channels=out_planes, + kernel_size=4, + stride=2, + padding=1, + bias=True, + ), + torch.nn.PReLU(out_planes), + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + + +class 
ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, 1, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d( + 16 * c, + 16 * 16 * c, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ), + torch.nn.Sigmoid(), + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d( + 16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True + ), + torch.nn.Sigmoid(), + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d( + 16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True + ), + torch.nn.Sigmoid(), + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = ( + torch.nn.functional.interpolate( 
+ flow0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + flow1 = ( + torch.nn.functional.interpolate( + flow1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = ( + torch.nn.functional.interpolate( + flow0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + flow1 = ( + torch.nn.functional.interpolate( + flow1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = ( + torch.nn.functional.interpolate( + flow0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + flow1 = ( + torch.nn.functional.interpolate( + flow1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, 
-1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = ( + torch.nn.functional.interpolate( + flow0, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + flow1 = ( + torch.nn.functional.interpolate( + flow1, scale_factor=0.5, mode="bilinear", align_corners=False + ) + * 0.5 + ) + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0.repeat(1, self.branch, 1, 1), m1.repeat(1, self.branch, 1, 1) + + +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.netFlow = Network() + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate( + input=flow0, + scale_factor=ratio, + mode="bilinear", + align_corners=False, + ) + flow1 = ratio * torch.nn.functional.interpolate( + input=flow1, + scale_factor=ratio, + 
mode="bilinear", + align_corners=False, + ) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, im0, im1, fltTimes=[0.5], ratio=None): + if ratio is None: + ratio = self.ratio + + intWidth = im0.shape[3] and im1.shape[3] + intHeight = im0.shape[2] and im1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + im0 = torch.nn.functional.pad( + input=im0, pad=[0, intPadr, 0, intPadb], mode="replicate" + ) + im1 = torch.nn.functional.pad( + input=im1, pad=[0, intPadr, 0, intPadb], mode="replicate" + ) + + N_, C_, H_, W_ = im0.shape + + outputs = [] + + with torch.set_grad_enabled(False): + tenStats = [im0, im1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len( + tenStats + ) + tenStd_ = ( + sum( + [ + tenIn.std([1, 2, 3], False, True).square() + + (tenMean_ - tenIn.mean([1, 2, 3], True)).square() + for tenIn in tenStats + ] + ) + / len(tenStats) + ).sqrt() + + im0_o = (im0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (im1 - tenMean_) / (tenStd_ + 0.0000001) + + im0 = (im0 - tenMean_) / (tenStd_ + 0.0000001) + im1 = (im1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate( + input=im0, scale_factor=2.0 / ratio, mode="bilinear", align_corners=False + ) + im1_ = torch.nn.functional.interpolate( + input=im1, scale_factor=2.0 / ratio, mode="bilinear", align_corners=False + ) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, im0, im1, ratio) + + for fltTime_ in fltTimes: + im0 = im0_o.repeat(1, self.branch, 1, 1) + im1 = im1_o.repeat(1, self.branch, 1, 1) + 
tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = fltTime_.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view( + N_ * self.branch, 2, H_, W_ + ) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view( + N_ * self.branch, 2, H_, W_ + ) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view( + N_ * self.branch, 1, H_, W_ + ) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view( + N_ * self.branch, 1, H_, W_ + ) + + im0 = im0.reshape(N_, self.branch, 3, H_, W_).view( + N_ * self.branch, 3, H_, W_ + ) + im1 = im1.reshape(N_, self.branch, 3, H_, W_).view( + N_ * self.branch, 3, H_, W_ + ) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view( + N_ * self.branch, 1, 1, 1 + ) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view( + N_ * self.branch, 1, 1, 1 + ) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view( + N_ * self.branch, 1, 1, 1 + ) + + tenPhotoone = ( + ( + 1.0 + - ( + WeiMF + * (im0 - backwarp(im1, tenFwd).detach()).abs().mean([1], True) + ) + ) + .clip(0.001, None) + .square() + ) + tenPhototwo = ( + ( + 1.0 + - ( + WeiMB + * (im1 - backwarp(im0, tenBwd).detach()).abs().mean([1], True) + ) + ) + .clip(0.001, None) + .square() + ) + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + im0 = im0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + im1 = im1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, 
import pathlib
import typing

import torch

from comfy.model_management import get_torch_device
from vfi_utils import (
    InterpolationStateList,
    generic_frame_loop,
    load_file_from_github_release,
    postprocess_frames,
    preprocess_frames,
)

# MODEL_TYPE is derived from this package's directory name ("m2m") and is used
# to locate the checkpoint in the GitHub release assets.
MODEL_TYPE = pathlib.Path(__file__).parent.name
CKPT_NAMES = ["M2M.pth"]


class M2M_VFI:
    """ComfyUI node wrapping the M2M-PWC video frame interpolation model.

    Takes a batch of frames and inserts `multiplier - 1` interpolated frames
    between every consecutive pair via `generic_frame_loop`.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames`, returning a 1-tuple with the expanded batch.

        ckpt_name: one of CKPT_NAMES, downloaded on demand.
        clear_cache_after_n_frames: how often the frame loop flushes caches.
        multiplier: output frame count per input pair.
        optional_interpolation_states: per-pair skip flags, or None.
        """
        from .M2M_arch import M2M_PWC
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        interpolation_model = M2M_PWC()
        # map_location="cpu" so a checkpoint serialized on a CUDA device also
        # loads on CPU-only hosts; the model is moved to the target device below.
        interpolation_model.load_state_dict(torch.load(model_path, map_location="cpu"))
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        def return_middle_frame(frame_0, frame_1, int_timestep, model):
            # M2M expects the timestep as a list holding one (N, 1, 1, 1)
            # tensor, and returns one output frame per requested step.
            tenSteps = [
                torch.FloatTensor([int_timestep] * len(frame_0)).view(len(frame_0), 1, 1, 1).to(get_torch_device())
            ]
            return model(frame_0, frame_1, tenSteps)[0]

        args = [interpolation_model]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)
+from .adacof import * +from .correlation import * +from comfy.model_management import is_nvidia, get_torch_device_name, get_torch_device + +def init(): + if not is_nvidia(): + raise NotImplementedError(f"CuPy ops backend only support CUDA device but found {get_torch_device_name(get_torch_device())} instead. Try Taichi ops backend by editing config.yaml") + return \ No newline at end of file diff --git a/vfi_models/ops/cupy_ops/adacof.py b/vfi_models/ops/cupy_ops/adacof.py new file mode 100644 index 0000000000000000000000000000000000000000..378469045ab9dcb31df62080ba80b2cf5aa9b616 --- /dev/null +++ b/vfi_models/ops/cupy_ops/adacof.py @@ -0,0 +1,491 @@ +import torch +from .utils import cuda_kernel, cuda_launch, cuda_int32 +import math + +kernel_AdaCoF_updateOutput = """ + extern "C" __global__ void kernel_AdaCoF_updateOutput( + const int n, + const float* input, + const float* weight, + const float* offset_i, + const float* offset_j, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float dblOutput = 0.0; + + const int intSample = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int c = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int i = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int j = ( intIndex ) % SIZE_3(output); + + for (int k = 0; k < F_SIZE; k += 1) { + for (int l = 0; l < F_SIZE; l += 1) { + float w = VALUE_4(weight, intSample, k*F_SIZE+l, i, j); + float alpha = VALUE_4(offset_i, intSample, k*F_SIZE+l, i, j); + float beta = VALUE_4(offset_j, intSample, k*F_SIZE+l, i, j); + int A = (int) alpha; + int B = (int) beta; + + int i_k_A = i+k*DILATION+A; + if(i_k_A < 0) + i_k_A = 0; + if(i_k_A > SIZE_2(input) - 1) + i_k_A = SIZE_2(input) - 1; + + int j_l_B = j+l*DILATION+B; + if(j_l_B < 0) + j_l_B = 0; + if(j_l_B > SIZE_3(input) - 1) + j_l_B = SIZE_3(input) - 1; + + int i_k_A_1 = i+k*DILATION+A+1; 
+ if(i_k_A_1 < 0) + i_k_A_1 = 0; + if(i_k_A_1 > SIZE_2(input) - 1) + i_k_A_1 = SIZE_2(input) - 1; + + int j_l_B_1 = j+l*DILATION+B+1; + if(j_l_B_1 < 0) + j_l_B_1 = 0; + if(j_l_B_1 > SIZE_3(input) - 1) + j_l_B_1 = SIZE_3(input) - 1; + + dblOutput += w * ( + VALUE_4(input, intSample, c, i_k_A, j_l_B)*(1-(alpha-(float)A))*(1-(beta-(float)B)) + + VALUE_4(input, intSample, c, i_k_A_1, j_l_B)*(alpha-(float)A)*(1-(beta-(float)B)) + + VALUE_4(input, intSample, c, i_k_A, j_l_B_1)*(1-(alpha-(float)A))*(beta-(float)B) + + VALUE_4(input, intSample, c, i_k_A_1, j_l_B_1)*(alpha-(float)A)*(beta-(float)B) + ); + } + } + + output[intIndex] = dblOutput; + } } +""" + +kernel_AdaCoF_updateGradWeight = """ + extern "C" __global__ void kernel_AdaCoF_updateGradWeight( + const int n, + const float* gradLoss, + const float* input, + const float* offset_i, + const float* offset_j, + float* gradWeight + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float floatOutput = 0.0; + + const int intSample = ( intIndex / SIZE_3(gradWeight) / SIZE_2(gradWeight) / SIZE_1(gradWeight) ) % SIZE_0(gradWeight); + const int intDepth = ( intIndex / SIZE_3(gradWeight) / SIZE_2(gradWeight) ) % SIZE_1(gradWeight); + const int i = ( intIndex / SIZE_3(gradWeight) ) % SIZE_2(gradWeight); + const int j = ( intIndex ) % SIZE_3(gradWeight); + + int k = intDepth / F_SIZE; + int l = intDepth % F_SIZE; + + for (int c = 0; c < 3; c++) + { + float delta = VALUE_4(gradLoss, intSample, c, i, j); + float alpha = VALUE_4(offset_i, intSample, k*F_SIZE+l, i, j); + float beta = VALUE_4(offset_j, intSample, k*F_SIZE+l, i, j); + int A = (int) alpha; + int B = (int) beta; + + int i_k_A = i+k*DILATION+A; + if(i_k_A < 0) + i_k_A = 0; + if(i_k_A > SIZE_2(input) - 1) + i_k_A = SIZE_2(input) - 1; + + int j_l_B = j+l*DILATION+B; + if(j_l_B < 0) + j_l_B = 0; + if(j_l_B > SIZE_3(input) - 1) + j_l_B = SIZE_3(input) - 1; + + int i_k_A_1 = i+k*DILATION+A+1; + if(i_k_A_1 < 
0) + i_k_A_1 = 0; + if(i_k_A_1 > SIZE_2(input) - 1) + i_k_A_1 = SIZE_2(input) - 1; + + int j_l_B_1 = j+l*DILATION+B+1; + if(j_l_B_1 < 0) + j_l_B_1 = 0; + if(j_l_B_1 > SIZE_3(input) - 1) + j_l_B_1 = SIZE_3(input) - 1; + + floatOutput += delta * ( + VALUE_4(input, intSample, c, i_k_A, j_l_B)*(1-(alpha-(float)A))*(1-(beta-(float)B)) + + VALUE_4(input, intSample, c, i_k_A_1, j_l_B)*(alpha-(float)A)*(1-(beta-(float)B)) + + VALUE_4(input, intSample, c, i_k_A, j_l_B_1)*(1-(alpha-(float)A))*(beta-(float)B) + + VALUE_4(input, intSample, c, i_k_A_1, j_l_B_1)*(alpha-(float)A)*(beta-(float)B) + ); + } + + gradWeight[intIndex] = floatOutput; + } } +""" + +kernel_AdaCoF_updateGradAlpha = """ + extern "C" __global__ void kernel_AdaCoF_updateGradAlpha( + const int n, + const float* gradLoss, + const float* input, + const float* weight, + const float* offset_i, + const float* offset_j, + float* gradOffset_i + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float floatOutput = 0.0; + + const int intSample = ( intIndex / SIZE_3(gradOffset_i) / SIZE_2(gradOffset_i) / SIZE_1(gradOffset_i) ) % SIZE_0(gradOffset_i); + const int intDepth = ( intIndex / SIZE_3(gradOffset_i) / SIZE_2(gradOffset_i) ) % SIZE_1(gradOffset_i); + const int i = ( intIndex / SIZE_3(gradOffset_i) ) % SIZE_2(gradOffset_i); + const int j = ( intIndex ) % SIZE_3(gradOffset_i); + + int k = intDepth / F_SIZE; + int l = intDepth % F_SIZE; + + for (int c = 0; c < 3; c++) + { + float delta = VALUE_4(gradLoss, intSample, c, i, j); + float w = VALUE_4(weight, intSample, k*F_SIZE+l, i, j); + float alpha = VALUE_4(offset_i, intSample, k*F_SIZE+l, i, j); + float beta = VALUE_4(offset_j, intSample, k*F_SIZE+l, i, j); + int A = (int) alpha; + int B = (int) beta; + + int i_k_A = i+k*DILATION+A; + if(i_k_A < 0) + i_k_A = 0; + if(i_k_A > SIZE_2(input) - 1) + i_k_A = SIZE_2(input) - 1; + + int j_l_B = j+l*DILATION+B; + if(j_l_B < 0) + j_l_B = 0; + if(j_l_B > 
SIZE_3(input) - 1) + j_l_B = SIZE_3(input) - 1; + + int i_k_A_1 = i+k*DILATION+A+1; + if(i_k_A_1 < 0) + i_k_A_1 = 0; + if(i_k_A_1 > SIZE_2(input) - 1) + i_k_A_1 = SIZE_2(input) - 1; + + int j_l_B_1 = j+l*DILATION+B+1; + if(j_l_B_1 < 0) + j_l_B_1 = 0; + if(j_l_B_1 > SIZE_3(input) - 1) + j_l_B_1 = SIZE_3(input) - 1; + + floatOutput += delta * w * ( + - VALUE_4(input, intSample, c, i_k_A, j_l_B)*(1-(beta-(float)B)) + + VALUE_4(input, intSample, c, i_k_A_1, j_l_B)*(1-(beta-(float)B)) - + VALUE_4(input, intSample, c, i_k_A, j_l_B_1)*(beta-(float)B) + + VALUE_4(input, intSample, c, i_k_A_1, j_l_B_1)*(beta-(float)B) + ); + } + + gradOffset_i[intIndex] = floatOutput; + } } +""" + +kernel_AdaCoF_updateGradBeta = """ + extern "C" __global__ void kernel_AdaCoF_updateGradBeta( + const int n, + const float* gradLoss, + const float* input, + const float* weight, + const float* offset_i, + const float* offset_j, + float* gradOffset_j + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float floatOutput = 0.0; + + const int intSample = ( intIndex / SIZE_3(gradOffset_j) / SIZE_2(gradOffset_j) / SIZE_1(gradOffset_j) ) % SIZE_0(gradOffset_j); + const int intDepth = ( intIndex / SIZE_3(gradOffset_j) / SIZE_2(gradOffset_j) ) % SIZE_1(gradOffset_j); + const int i = ( intIndex / SIZE_3(gradOffset_j) ) % SIZE_2(gradOffset_j); + const int j = ( intIndex ) % SIZE_3(gradOffset_j); + + int k = intDepth / F_SIZE; + int l = intDepth % F_SIZE; + + for (int c = 0; c < 3; c++) + { + float delta = VALUE_4(gradLoss, intSample, c, i, j); + float w = VALUE_4(weight, intSample, k*F_SIZE+l, i, j); + float alpha = VALUE_4(offset_i, intSample, k*F_SIZE+l, i, j); + float beta = VALUE_4(offset_j, intSample, k*F_SIZE+l, i, j); + int A = (int) alpha; + int B = (int) beta; + + int i_k_A = i+k*DILATION+A; + if(i_k_A < 0) + i_k_A = 0; + if(i_k_A > SIZE_2(input) - 1) + i_k_A = SIZE_2(input) - 1; + + int j_l_B = j+l*DILATION+B; + if(j_l_B < 0) + 
class FunctionAdaCoF(torch.autograd.Function):
    """Autograd wrapper around the AdaCoF CUDA kernels.

    Each output pixel is a weighted bilinear sample of `input` at
    F_SIZE x F_SIZE dilated tap positions displaced by the per-pixel
    offsets.  Only CUDA execution is implemented; CPU raises
    NotImplementedError.
    """

    @staticmethod
    def forward(ctx, input, weight, offset_i, offset_j, dilation):
        # input:    (N, C, H_in, W_in), contiguous CUDA tensor.
        # weight:   (N, F*F, H_out, W_out) per-pixel kernel weights.
        # offset_i/offset_j: (N, F*F, H_out, W_out) vertical/horizontal offsets.
        # dilation: integer spacing between kernel taps.
        ctx.save_for_backward(input, weight, offset_i, offset_j)
        ctx.dilation = dilation

        intSample = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = int(math.sqrt(weight.size(1)))
        intOutputHeight = weight.size(2)
        intOutputWidth = weight.size(3)

        # Spatial sizes must satisfy the "valid" sampling relation
        # H_in - ((F - 1) * dilation + 1) == H_out - 1.
        assert (
            intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1
        )
        assert (
            intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1
        )

        # The kernels index raw data_ptr()s, so contiguity is mandatory.
        assert input.is_contiguous() == True
        assert weight.is_contiguous() == True
        assert offset_i.is_contiguous() == True
        assert offset_j.is_contiguous() == True

        output = input.new_zeros(
            intSample, intInputDepth, intOutputHeight, intOutputWidth
        )

        if input.is_cuda == True:

            class Stream:
                ptr = torch.cuda.current_stream().cuda_stream

            n = output.nelement()
            cuda_launch(
                cuda_kernel(
                    "kernel_AdaCoF_updateOutput",
                    kernel_AdaCoF_updateOutput,
                    {
                        "input": input,
                        "weight": weight,
                        "offset_i": offset_i,
                        "offset_j": offset_j,
                        "output": output,
                    },
                    F_SIZE=str(intFilterSize),
                    DILATION=str(dilation)
                ),
            )(
                grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[
                    n,
                    input.data_ptr(),
                    weight.data_ptr(),
                    offset_i.data_ptr(),
                    offset_j.data_ptr(),
                    output.data_ptr(),
                ],
                stream=Stream,
            )

        elif input.is_cuda == False:
            raise NotImplementedError()

        return output

    @staticmethod
    def backward(ctx, gradOutput):
        """Compute grads for (input, weight, offset_i, offset_j, dilation)."""
        input, weight, offset_i, offset_j = ctx.saved_tensors
        dilation = ctx.dilation

        intSample = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = int(math.sqrt(weight.size(1)))
        intOutputHeight = weight.size(2)
        intOutputWidth = weight.size(3)

        assert (
            intInputHeight - ((intFilterSize - 1) * dilation + 1) == intOutputHeight - 1
        )
        assert (
            intInputWidth - ((intFilterSize - 1) * dilation + 1) == intOutputWidth - 1
        )

        assert gradOutput.is_contiguous() == True

        # NOTE(review): no kernel ever fills gradInput, so a requested input
        # gradient comes back as zeros -- preserved from the original code.
        gradInput = (
            input.new_zeros(intSample, intInputDepth, intInputHeight, intInputWidth)
            if ctx.needs_input_grad[0] == True
            else None
        )
        gradWeight = (
            input.new_zeros(
                intSample, intFilterSize**2, intOutputHeight, intOutputWidth
            )
            if ctx.needs_input_grad[1] == True
            else None
        )
        gradOffset_i = (
            input.new_zeros(
                intSample, intFilterSize**2, intOutputHeight, intOutputWidth
            )
            if ctx.needs_input_grad[2] == True
            else None
        )
        # BUG FIX: the original keyed this allocation on needs_input_grad[2];
        # offset_j is the FOURTH differentiable argument, i.e. index 3.
        gradOffset_j = (
            input.new_zeros(
                intSample, intFilterSize**2, intOutputHeight, intOutputWidth
            )
            if ctx.needs_input_grad[3] == True
            else None
        )

        if input.is_cuda == True:

            class Stream:
                ptr = torch.cuda.current_stream().cuda_stream

            # Each launch is guarded so callers that do not require a given
            # gradient no longer crash on `None.nelement()`.
            if gradWeight is not None:
                # weight grad
                n_w = gradWeight.nelement()
                cuda_launch(
                    cuda_kernel(
                        "kernel_AdaCoF_updateGradWeight",
                        kernel_AdaCoF_updateGradWeight,
                        {
                            "gradLoss": gradOutput,
                            "input": input,
                            "offset_i": offset_i,
                            "offset_j": offset_j,
                            "gradWeight": gradWeight,
                        },
                        F_SIZE=str(intFilterSize),
                        DILATION=str(dilation)
                    ),
                )(
                    grid=tuple([int((n_w + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[
                        n_w,
                        gradOutput.data_ptr(),
                        input.data_ptr(),
                        offset_i.data_ptr(),
                        offset_j.data_ptr(),
                        gradWeight.data_ptr(),
                    ],
                    stream=Stream,
                )

            if gradOffset_i is not None:
                # alpha (vertical offset) grad
                n_i = gradOffset_i.nelement()
                cuda_launch(
                    cuda_kernel(
                        "kernel_AdaCoF_updateGradAlpha",
                        kernel_AdaCoF_updateGradAlpha,
                        {
                            "gradLoss": gradOutput,
                            "input": input,
                            "weight": weight,
                            "offset_i": offset_i,
                            "offset_j": offset_j,
                            "gradOffset_i": gradOffset_i,
                        },
                        F_SIZE=str(intFilterSize),
                        DILATION=str(dilation)
                    ),
                )(
                    grid=tuple([int((n_i + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[
                        n_i,
                        gradOutput.data_ptr(),
                        input.data_ptr(),
                        weight.data_ptr(),
                        offset_i.data_ptr(),
                        offset_j.data_ptr(),
                        gradOffset_i.data_ptr(),
                    ],
                    stream=Stream,
                )

            if gradOffset_j is not None:
                # beta (horizontal offset) grad
                n_j = gradOffset_j.nelement()
                cuda_launch(
                    cuda_kernel(
                        "kernel_AdaCoF_updateGradBeta",
                        kernel_AdaCoF_updateGradBeta,
                        {
                            "gradLoss": gradOutput,
                            "input": input,
                            "weight": weight,
                            "offset_i": offset_i,
                            "offset_j": offset_j,
                            "gradOffset_j": gradOffset_j,
                        },
                        F_SIZE=str(intFilterSize),
                        DILATION=str(dilation)
                    ),
                )(
                    grid=tuple([int((n_j + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[
                        n_j,
                        gradOutput.data_ptr(),
                        input.data_ptr(),
                        weight.data_ptr(),
                        offset_i.data_ptr(),
                        offset_j.data_ptr(),
                        gradOffset_j.data_ptr(),
                    ],
                    stream=Stream,
                )

        elif input.is_cuda == False:
            raise NotImplementedError()

        # Last None corresponds to the non-tensor `dilation` argument.
        return gradInput, gradWeight, gradOffset_i, gradOffset_j, None


__all__ = ["FunctionAdaCoF"]
############### DISTANCE TRANSFORM ###############
# img tensor: (bs,h,w) or (bs,1,h,w)
# returns same shape
# expects white lines, black whitespace
# defaults to diameter if empty image
from .utils import cuda_kernel, cuda_launch, cuda_int32, cuda_float32
import torch

# Raw CUDA kernel for one 1-D pass of the squared distance transform: for
# every pixel, scan its row and keep the minimum of data[j] + (pj-j)^2.
# It is applied twice (second time on the transposed tensor) to cover both
# axes; `diam2` (h^2 + w^2) is the "infinite" distance for empty images.
_batch_edt_kernel = (
    "kernel_dt",
    """
    extern "C" __global__ void kernel_dt(
        const int bs,
        const int h,
        const int w,
        const float diam2,
        float* data,
        float* output
    ) {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= bs*h*w) {
            return;
        }
        int pb = idx / (h*w);
        int pi = (idx - h*w*pb) / w;
        int pj = (idx - h*w*pb - w*pi);

        float cost;
        float mincost = diam2;
        for (int j = 0; j < w; j++) {
            cost = data[h*w*pb + w*pi + j] + (pj-j)*(pj-j);
            if (cost < mincost) {
                mincost = cost;
            }
        }
        output[idx] = mincost;
        return;
    }
""",
)
# Compiled kernel handle; lazily built on first call (see note below).
_batch_edt = None


def batch_edt(img, block=1024):
    """Batched Euclidean distance transform on the GPU.

    img: (bs, h, w) or (bs, 1, h, w) tensor; nonzero pixels are treated as
    "lines" (distance 0), zeros get their distance to the nearest nonzero.
    Returns a tensor of the same shape and dtype as `img`.
    Raises NotImplementedError for CPU tensors (scipy fallback is disabled).
    """
    # must initialize cuda/cupy after forking
    global _batch_edt
    if _batch_edt is None:
        _batch_edt = cuda_launch(*_batch_edt_kernel)

    # bookkeeping: squeeze an optional singleton channel dim, remember to
    # restore it on the way out.
    if len(img.shape) == 4:
        assert img.shape[1] == 1
        img = img.squeeze(1)
        expand = True
    else:
        expand = False
    bs, h, w = img.shape
    diam2 = h**2 + w**2  # squared image diagonal = upper bound on any distance
    odtype = img.dtype
    grid = (img.nelement() + block - 1) // block  # one thread per pixel

    # cupy implementation
    if img.is_cuda:
        # first pass, y-axis: foreground -> 0, background -> diam2, then scan
        data = ((1 - img.type(torch.float32)) * diam2).contiguous()
        intermed = torch.zeros_like(data)
        _batch_edt(
            grid=(grid, 1, 1),
            block=(block, 1, 1),  # < 1024
            args=[
                cuda_int32(bs),
                cuda_int32(h),
                cuda_int32(w),
                cuda_float32(diam2),
                data.data_ptr(),
                intermed.data_ptr(),
            ],
        )

        # second pass, x-axis: transpose so the same row-scan kernel covers
        # the other axis (note swapped h/w arguments).
        intermed = intermed.permute(0, 2, 1).contiguous()
        out = torch.zeros_like(intermed)
        _batch_edt(
            grid=(grid, 1, 1),
            block=(block, 1, 1),
            args=[
                cuda_int32(bs),
                cuda_int32(w),
                cuda_int32(h),
                cuda_float32(diam2),
                intermed.data_ptr(),
                out.data_ptr(),
            ],
        )
        # un-transpose and convert squared distances back to distances
        ans = out.permute(0, 2, 1).sqrt()
        ans = ans.type(odtype) if odtype != ans.dtype else ans

    # default to scipy cpu implementation
    else:
        raise NotImplementedError()
        """ sums = img.sum(dim=(1, 2))
        ans = torch.tensor(
            np.stack(
                [
                    scipy.ndimage.morphology.distance_transform_edt(i)
                    if s != 0
                    else np.ones_like(i)  # change scipy behavior for empty image
                    * np.sqrt(diam2)
                    for i, s in zip(1 - img, sums)
                ]
            ),
            dtype=odtype,
        ) """

    if expand:
        ans = ans.unsqueeze(1)
    return ans

__all__ = ["batch_edt"]
neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +""" + +kernel_Correlation_updateGradFirst = """ + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = 
intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + 
intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +""" + +kernel_Correlation_updateGradSecond = """ + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + 
class _FunctionCorrelation(torch.autograd.Function):
    """Autograd function computing a 9x9 local cost-volume correlation
    between `first` and `second` (81 output channels). CUDA only.
    """

    @staticmethod
    def forward(self, first, second):
        # rbot0/rbot1: channel-last, 4-pixel-padded rearrangements of the
        # inputs; they are filled in place by the rearrange kernel below.
        rbot0 = first.new_zeros(
            [first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]
        )
        rbot1 = first.new_zeros(
            [first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]
        )

        # NOTE(review): saved before the .contiguous() calls below, so the
        # pre-contiguous tensors are what backward sees -- confirm intended.
        self.save_for_backward(first, second, rbot0, rbot1)

        first = first.contiguous()
        assert first.is_cuda == True
        second = second.contiguous()
        assert second.is_cuda == True

        # 81 channels = one correlation score per (dy, dx) in [-4, 4]^2.
        output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]])

        if first.is_cuda == True:
            # Rearrange `first` into padded channel-last layout (rbot0).
            n = first.shape[2] * first.shape[3]
            cuda_launch(
                cuda_kernel(
                    "kernel_Correlation_rearrange", kernel_Correlation_rearrange, {"input": first, "output": rbot0}
                ),
            )(
                grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]),
                block=tuple([16, 1, 1]),
                args=[n, first.data_ptr(), rbot0.data_ptr()],
            )

            # Same rearrangement for `second` (rbot1).
            n = second.shape[2] * second.shape[3]
            cuda_launch(
                cuda_kernel(
                    "kernel_Correlation_rearrange", kernel_Correlation_rearrange, {"input": second, "output": rbot1}
                ),
            )(
                grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]),
                block=tuple([16, 1, 1]),
                args=[n, second.data_ptr(), rbot1.data_ptr()],
            )

            # Correlation proper: one block per output pixel, 32 threads
            # cooperating over channels; shared memory holds one patch
            # (first.shape[1] floats, 4 bytes each).
            n = output.shape[1] * output.shape[2] * output.shape[3]
            cuda_launch(
                cuda_kernel(
                    "kernel_Correlation_updateOutput",
                    kernel_Correlation_updateOutput,
                    {"rbot0": rbot0, "rbot1": rbot1, "top": output},
                ),
            )(
                grid=tuple([output.shape[3], output.shape[2], output.shape[0]]),
                block=tuple([32, 1, 1]),
                shared_mem=first.shape[1] * 4,
                args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()],
            )

        elif first.is_cuda == False:
            raise NotImplementedError()

        return output

    @staticmethod
    def backward(self, gradOutput):
        """Gradients w.r.t. `first` and `second`; launched per sample."""
        first, second, rbot0, rbot1 = self.saved_tensors

        gradOutput = gradOutput.contiguous()
        assert gradOutput.is_cuda == True

        gradFirst = (
            first.new_zeros(
                [first.shape[0], first.shape[1], first.shape[2], first.shape[3]]
            )
            if self.needs_input_grad[0] == True
            else None
        )
        gradSecond = (
            first.new_zeros(
                [first.shape[0], first.shape[1], first.shape[2], first.shape[3]]
            )
            if self.needs_input_grad[1] == True
            else None
        )

        if first.is_cuda == True:
            if gradFirst is not None:
                # One kernel launch per batch sample (kernel takes intSample).
                for intSample in range(first.shape[0]):
                    n = first.shape[1] * first.shape[2] * first.shape[3]
                    cuda_launch(
                        cuda_kernel(
                            "kernel_Correlation_updateGradFirst",
                            kernel_Correlation_updateGradFirst,
                            {
                                "rbot0": rbot0,
                                "rbot1": rbot1,
                                "gradOutput": gradOutput,
                                "gradFirst": gradFirst,
                                "gradSecond": None,
                            },
                        ),
                    )(
                        grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
                        block=tuple([512, 1, 1]),
                        args=[
                            n,
                            intSample,
                            rbot0.data_ptr(),
                            rbot1.data_ptr(),
                            gradOutput.data_ptr(),
                            gradFirst.data_ptr(),
                            None,
                        ],
                    )
                # end
            # end

            if gradSecond is not None:
                for intSample in range(first.shape[0]):
                    n = first.shape[1] * first.shape[2] * first.shape[3]
                    cuda_launch(
                        cuda_kernel(
                            "kernel_Correlation_updateGradSecond",
                            kernel_Correlation_updateGradSecond,
                            {
                                "rbot0": rbot0,
                                "rbot1": rbot1,
                                "gradOutput": gradOutput,
                                "gradFirst": None,
                                "gradSecond": gradSecond,
                            },
                        ),
                    )(
                        grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
                        block=tuple([512, 1, 1]),
                        args=[
                            n,
                            intSample,
                            rbot0.data_ptr(),
                            rbot1.data_ptr(),
                            gradOutput.data_ptr(),
                            None,
                            gradSecond.data_ptr(),
                        ],
                    )
                # end
            # end

        elif first.is_cuda == False:
            raise NotImplementedError()

        return gradFirst, gradSecond


def FunctionCorrelation(tenFirst, tenSecond):
    # Functional convenience wrapper around the autograd Function.
    return _FunctionCorrelation.apply(tenFirst, tenSecond)


class ModuleCorrelation(torch.nn.Module):
    """nn.Module wrapper so the correlation can sit inside a Sequential/graph."""

    def __init__(self):
        super(ModuleCorrelation, self).__init__()

    def forward(self, tenFirst, tenSecond):
        return _FunctionCorrelation.apply(tenFirst, tenSecond)

__all__ = ["_FunctionCorrelation", "FunctionCorrelation", "ModuleCorrelation"]
<= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } +""" + +costvol_onegrad = """ + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, 
intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } +""" + +costvol_twograd = """ + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], 
+tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } +""" + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty( + [tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]] + ) + + cuda_launch( + cuda_kernel( + "costvol_out", + costvol_out, + { + "intChans": tenOne.shape[1], + "tenOne": tenOne, + "tenTwo": tenTwo, + "tenOut": tenOut, + }, + ) + )( + grid=tuple( + [ + int( + ( + (tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + + 512 + - 1 + ) + / 512 + ), + 1, + 1, + ] + ), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), + tenOne.data_ptr(), + tenTwo.data_ptr(), + tenOut.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert tenOutgrad.is_cuda == True + + tenOnegrad = ( + tenOne.new_zeros( + [tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]] + ) + if self.needs_input_grad[0] == True + else None + ) + tenTwograd = ( + tenTwo.new_zeros( + [tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]] + ) + if self.needs_input_grad[1] == True + else None + ) + + if tenOnegrad is not None: + cuda_launch( + cuda_kernel( + "costvol_onegrad", + costvol_onegrad, + { + "intChans": tenOne.shape[1], + "tenOne": tenOne, + "tenTwo": tenTwo, + "tenOutgrad": tenOutgrad, + "tenOnegrad": tenOnegrad, + "tenTwograd": tenTwograd, + }, + ) + )( + grid=tuple( + [ + int( + ( + ( + tenOnegrad.shape[0] + * tenOnegrad.shape[2] + * tenOnegrad.shape[3] + ) + + 512 + - 1 + ) + / 512 + ), + 1, + 
1, + ] + ), + block=tuple([512, 1, 1]), + args=[ + cuda_int32( + tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3] + ), + tenOne.data_ptr(), + tenTwo.data_ptr(), + tenOutgrad.data_ptr(), + tenOnegrad.data_ptr(), + tenTwograd.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + if tenTwograd is not None: + cuda_launch( + cuda_kernel( + "costvol_twograd", + costvol_twograd, + { + "intChans": tenOne.shape[1], + "tenOne": tenOne, + "tenTwo": tenTwo, + "tenOutgrad": tenOutgrad, + "tenOnegrad": tenOnegrad, + "tenTwograd": tenTwograd, + }, + ) + )( + grid=tuple( + [ + int( + ( + ( + tenTwograd.shape[0] + * tenTwograd.shape[2] + * tenTwograd.shape[3] + ) + + 512 + - 1 + ) + / 512 + ), + 1, + 1, + ] + ), + block=tuple([512, 1, 1]), + args=[ + cuda_int32( + tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3] + ), + tenOne.data_ptr(), + tenTwo.data_ptr(), + tenOutgrad.data_ptr(), + tenOnegrad.data_ptr(), + tenTwograd.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + return tenOnegrad, tenTwograd, None, None + + # end + + +# end + +__all__ = ["costvol_func"] \ No newline at end of file diff --git a/vfi_models/ops/cupy_ops/sepconv.py b/vfi_models/ops/cupy_ops/sepconv.py new file mode 100644 index 0000000000000000000000000000000000000000..c334cdca77674f2566dd0075903e4d590e6d9eca --- /dev/null +++ b/vfi_models/ops/cupy_ops/sepconv.py @@ -0,0 +1,332 @@ +import torch +from .utils import cuda_launch, cuda_kernel, cuda_int32 + +sepconv_vergrad = """ + extern "C" __global__ void __launch_bounds__(512) sepconv_vergrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenVer, + const {{type}}* __restrict__ tenHor, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenVergrad, + {{type}}* __restrict__ tenHorgrad + ) { for 
(int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenVergrad) / SIZE_2(tenVergrad) / SIZE_1(tenVergrad) ) % SIZE_0(tenVergrad); + const int intC = ( intIndex / SIZE_3(tenVergrad) / SIZE_2(tenVergrad) ) % SIZE_1(tenVergrad); + const int intY = ( intIndex / SIZE_3(tenVergrad) ) % SIZE_2(tenVergrad); + const int intX = ( intIndex ) % SIZE_3(tenVergrad); + + {{type}} fltVergrad = 0.0; + + {{type}} fltKahanc = 0.0; + {{type}} fltKahany = 0.0; + {{type}} fltKahant = 0.0; + + for (int intI = 0; intI < SIZE_1(tenIn); intI += 1) { + for (int intFx = 0; intFx < SIZE_1(tenHor); intFx += 1) { + fltKahany = VALUE_4(tenHor, intN, intFx, intY, intX) * VALUE_4(tenIn, intN, intI, intY + intC, intX + intFx) * VALUE_4(tenOutgrad, intN, intI, intY, intX); + fltKahany = fltKahany - fltKahanc; + fltKahant = fltVergrad + fltKahany; + fltKahanc = (fltKahant - fltVergrad) - fltKahany; + fltVergrad = fltKahant; + } + } + + tenVergrad[intIndex] = fltVergrad; + } } +""" + +sepconv_ingrad = """ + extern "C" __global__ void __launch_bounds__(512) sepconv_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenVer, + const {{type}}* __restrict__ tenHor, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenVergrad, + {{type}}* __restrict__ tenHorgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + {{type}} fltIngrad = 0.0; + + {{type}} fltKahanc = 0.0; + {{type}} fltKahany = 0.0; + {{type}} fltKahant = 0.0; + 
+ for (int intFy = 0; intFy < SIZE_1(tenVer); intFy += 1) { + int intKy = intY + intFy - (SIZE_1(tenVer) - 1); + + if (intKy < 0) { continue; } + if (intKy >= SIZE_2(tenVer)) { continue; } + + for (int intFx = 0; intFx < SIZE_1(tenHor); intFx += 1) { + int intKx = intX + intFx - (SIZE_1(tenHor) - 1); + + if (intKx < 0) { continue; } + if (intKx >= SIZE_3(tenHor)) { continue; } + + fltKahany = VALUE_4(tenVer, intN, (SIZE_1(tenVer) - 1) - intFy, intKy, intKx) * VALUE_4(tenHor, intN, (SIZE_1(tenHor) - 1) - intFx, intKy, intKx) * VALUE_4(tenOutgrad, intN, intC, intKy, intKx); + fltKahany = fltKahany - fltKahanc; + fltKahant = fltIngrad + fltKahany; + fltKahanc = (fltKahant - fltIngrad) - fltKahany; + fltIngrad = fltKahant; + } + } + + tenIngrad[intIndex] = fltIngrad; + } } +""" + +sepconv_out = """ + extern "C" __global__ void __launch_bounds__(512) sepconv_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenVer, + const {{type}}* __restrict__ tenHor, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOut = 0.0; + + {{type}} fltKahanc = 0.0; + {{type}} fltKahany = 0.0; + {{type}} fltKahant = 0.0; + + for (int intFy = 0; intFy < SIZE_1(tenVer); intFy += 1) { + for (int intFx = 0; intFx < SIZE_1(tenHor); intFx += 1) { + fltKahany = VALUE_4(tenIn, intN, intC, intY + intFy, intX + intFx) * VALUE_4(tenVer, intN, intFy, intY, intX) * VALUE_4(tenHor, intN, intFx, intY, intX); + fltKahany = fltKahany - fltKahanc; + fltKahant = fltOut + fltKahany; + fltKahanc = (fltKahant - fltOut) - fltKahany; + fltOut = fltKahant; + } + } + + 
tenOut[intIndex] = fltOut; + } } +""" + +sepconv_horgrad = """ + extern "C" __global__ void __launch_bounds__(512) sepconv_horgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenVer, + const {{type}}* __restrict__ tenHor, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenVergrad, + {{type}}* __restrict__ tenHorgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenHorgrad) / SIZE_2(tenHorgrad) / SIZE_1(tenHorgrad) ) % SIZE_0(tenHorgrad); + const int intC = ( intIndex / SIZE_3(tenHorgrad) / SIZE_2(tenHorgrad) ) % SIZE_1(tenHorgrad); + const int intY = ( intIndex / SIZE_3(tenHorgrad) ) % SIZE_2(tenHorgrad); + const int intX = ( intIndex ) % SIZE_3(tenHorgrad); + + {{type}} fltHorgrad = 0.0; + + {{type}} fltKahanc = 0.0; + {{type}} fltKahany = 0.0; + {{type}} fltKahant = 0.0; + + for (int intI = 0; intI < SIZE_1(tenIn); intI += 1) { + for (int intFy = 0; intFy < SIZE_1(tenVer); intFy += 1) { + fltKahany = VALUE_4(tenVer, intN, intFy, intY, intX) * VALUE_4(tenIn, intN, intI, intY + intFy, intX + intC) * VALUE_4(tenOutgrad, intN, intI, intY, intX); + fltKahany = fltKahany - fltKahanc; + fltKahant = fltHorgrad + fltKahany; + fltKahanc = (fltKahant - fltHorgrad) - fltKahany; + fltHorgrad = fltKahant; + } + } + + tenHorgrad[intIndex] = fltHorgrad; + } } +""" + +class sepconv_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenVer, tenHor): + tenOut = tenIn.new_empty( + [ + tenIn.shape[0], + tenIn.shape[1], + tenVer.shape[2] and tenHor.shape[2], + tenVer.shape[3] and tenHor.shape[3], + ] + ) + + if tenIn.is_cuda == True: + cuda_launch( + cuda_kernel( + "sepconv_out", + sepconv_out, + { + "tenIn": tenIn, + "tenVer": tenVer, + "tenHor": tenHor, + "tenOut": tenOut, + }, + ) + )( + 
grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenOut.nelement()), + tenIn.data_ptr(), + tenVer.data_ptr(), + tenHor.data_ptr(), + tenOut.data_ptr(), + ], + ) + + elif tenIn.is_cuda != True: + assert False + + # end + + self.save_for_backward(tenIn, tenVer, tenHor) + + return tenOut + + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenVer, tenHor = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert tenOutgrad.is_cuda == True + + tenIngrad = ( + tenIn.new_empty( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + if self.needs_input_grad[0] == True + else None + ) + tenVergrad = ( + tenVer.new_empty( + [tenVer.shape[0], tenVer.shape[1], tenVer.shape[2], tenVer.shape[3]] + ) + if self.needs_input_grad[1] == True + else None + ) + tenHorgrad = ( + tenHor.new_empty( + [tenHor.shape[0], tenHor.shape[1], tenHor.shape[2], tenHor.shape[3]] + ) + if self.needs_input_grad[2] == True + else None + ) + + if tenIngrad is not None: + cuda_launch( + cuda_kernel( + "sepconv_ingrad", + sepconv_ingrad, + { + "tenIn": tenIn, + "tenVer": tenVer, + "tenHor": tenHor, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenVergrad": tenVergrad, + "tenHorgrad": tenHorgrad, + }, + ) + )( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenIngrad.nelement()), + tenIn.data_ptr(), + tenVer.data_ptr(), + tenHor.data_ptr(), + tenOutgrad.data_ptr(), + tenIngrad.data_ptr(), + None, + None, + ], + ) + # end + + if tenVergrad is not None: + cuda_launch( + cuda_kernel( + "sepconv_vergrad", + sepconv_vergrad, + { + "tenIn": tenIn, + "tenVer": tenVer, + "tenHor": tenHor, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenVergrad": tenVergrad, + "tenHorgrad": tenHorgrad, + }, + ) + )( + grid=tuple([int((tenVergrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 
1, 1]), + args=[ + cuda_int32(tenVergrad.nelement()), + tenIn.data_ptr(), + tenVer.data_ptr(), + tenHor.data_ptr(), + tenOutgrad.data_ptr(), + None, + tenVergrad.data_ptr(), + None, + ], + ) + # end + + if tenHorgrad is not None: + cuda_launch( + cuda_kernel( + "sepconv_horgrad", + sepconv_horgrad, + { + "tenIn": tenIn, + "tenVer": tenVer, + "tenHor": tenHor, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenVergrad": tenVergrad, + "tenHorgrad": tenHorgrad, + }, + ) + )( + grid=tuple([int((tenHorgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenHorgrad.nelement()), + tenIn.data_ptr(), + tenVer.data_ptr(), + tenHor.data_ptr(), + tenOutgrad.data_ptr(), + None, + None, + tenHorgrad.data_ptr(), + ], + ) + # end + + return tenIngrad, tenVergrad, tenHorgrad + + # end + + +# end +__all__ = ["sepconv_func"] \ No newline at end of file diff --git a/vfi_models/ops/cupy_ops/softsplat.py b/vfi_models/ops/cupy_ops/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2ae47a638c5d025e85169759b587947640e012 --- /dev/null +++ b/vfi_models/ops/cupy_ops/softsplat.py @@ -0,0 +1,440 @@ +import torch +from .utils import cuda_launch, cuda_kernel, cuda_int32 +import cupy +import collections + +softsplat_flowgrad = """ + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + 
const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < 
SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } +""" + +softsplat_ingrad = """ + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; 
+ int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } +""" + +softsplat_out = """ + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int 
intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < 
SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } +""" + + +# end + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + + if tenIn.is_cuda == True: + cuda_launch( + cuda_kernel( + "softsplat_out", + softsplat_out, + {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut}, + ) + )( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenOut.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOut.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + + elif tenIn.is_cuda != True: + assert False + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert tenOutgrad.is_cuda == True + + tenIngrad = ( + tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] + ) + if self.needs_input_grad[0] == True + else None + ) + tenFlowgrad = ( + tenFlow.new_zeros( + [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]] + ) + if self.needs_input_grad[1] == True + else None + ) + + if tenIngrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_ingrad", + softsplat_ingrad, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenIngrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + 
tenIngrad.data_ptr(), + None, + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + if tenFlowgrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_flowgrad", + softsplat_flowgrad, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenFlowgrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + None, + tenFlowgrad.data_ptr(), + ], + stream=collections.namedtuple("Stream", "ptr")( + torch.cuda.current_stream().cuda_stream + ), + ) + # end + + return tenIngrad, tenFlowgrad + + # end + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert tenMetric is None or tenMetric.shape[1] == 1 + assert strType in ["summation", "average", "linear", "softmax"] + + if strType == "average": + tenInput = torch.cat( + [ + tenInput, + tenInput.new_ones( + tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3] + ), + ], + 1, + ) + + elif strType == "linear": + tenInput = torch.cat([tenInput * tenMetric, tenMetric], 1) + + elif strType == "softmax": + tenInput = torch.cat([tenInput * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOutput = softsplat_func.apply(tenInput, tenFlow) + + if strType != "summation": + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + # end + + return tenOutput + + +# end + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super().__init__() + + self.strType = strType + + # end + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) + + # end + + +# end + + + +def softsplat( + tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, 
strMode: str +): + assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"] + + if strMode == "sum": + assert tenMetric is None + if strMode == "avg": + assert tenMetric is None + if strMode.split("-")[0] == "linear": + assert tenMetric is not None + if strMode.split("-")[0] == "soft": + assert tenMetric is not None + + if strMode == "avg": + tenIn = torch.cat( + [ + tenIn, + tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]), + ], + 1, + ) + + elif strMode.split("-")[0] == "linear": + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split("-")[0] == "soft": + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split("-")[0] in ["avg", "linear", "soft"]: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split("-")) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "addeps": + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "zeroeps": + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split("-")[1] == "clipeps": + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut + + +# end + +__all__ = ["FunctionSoftsplat", "ModuleSoftsplat", "softsplat", "softsplat_func"] diff --git a/vfi_models/ops/cupy_ops/utils.py b/vfi_models/ops/cupy_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec29a48e11955eb9f9f8aa36b1542721dd362345 --- /dev/null +++ b/vfi_models/ops/cupy_ops/utils.py @@ -0,0 +1,242 @@ +import cupy +import os +import re +import torch +import typing +from pathlib import Path +import platform + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn: int): + return cupy.int32(intIn) + + +# end + + +def cuda_float32(fltIn: float): + return cupy.float32(fltIn) + + +# end + + +def 
cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict, **replace_kwargs): + if "device" not in objCudacache: + objCudacache["device"] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert False + + # end + # end + + strKey += objCudacache["device"] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace("{{" + strVariable + "}}", objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace("{{type}}", "unsigned char") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace("{{type}}", "half") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace("{{type}}", "float") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace("{{type}}", "double") + + elif type(objValue) == torch.Tensor and objValue.dtype 
== torch.int32: + strKernel = strKernel.replace("{{type}}", "int") + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace("{{type}}", "long") + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert False + + elif True: + print(strVariable, type(objValue)) + assert False + + # end + # end + + while True: + objMatch = re.search("(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search("(OFFSET_)([0-4])(\()([^\)]+)(\))", strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(",") + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str(intStrides[intArg]) + + ")" + for intArg in range(intArgs) + ] + + strKernel = strKernel.replace( + objMatch.group(0), "(" + str.join("+", strIndex) + ")" + ) + # end + + while True: + objMatch = re.search("(VALUE_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + 
+ ")*" + + str(intStrides[intArg]) + + ")" + ) + # end + + strKernel = strKernel.replace( + "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", + strTensor + "[" + str.join("+", strIndex) + "]", + ) + # end + + for replace_key, value in replace_kwargs.items(): + strKernel = strKernel.replace(replace_key, value) + + objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel} + # end + + return strKey + + +# end +def get_cuda_home_path(): + if "CUDA_HOME" in os.environ: + return os.environ["CUDA_HOME"] + import torch + torch_lib_path = Path(torch.__file__).parent / "lib" + torch_lib_path = str(torch_lib_path.resolve()) + if os.path.exists(torch_lib_path): + nvrtc = filter(lambda lib_file: "nvrtc-builtins" in lib_file, os.listdir(torch_lib_path)) + nvrtc = list(nvrtc) + return torch_lib_path if len(nvrtc) > 0 else None + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey: str): + if True:#"CUDA_HOME" not in os.environ: + cuda_home = get_cuda_home_path() + if cuda_home is not None: + os.environ["CUDA_HOME"] = cuda_home + os.environ["CUDA_PATH"] = cuda_home + else: + os.environ["CUDA_HOME"] = "/usr/local/cuda/" + os.environ["CUDA_PATH"] = "/usr/local/cuda/" + # print(objCudacache[strKey]['strKernel']) + # return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) + return cupy.RawModule(code=objCudacache[strKey]["strKernel"]).get_function( + objCudacache[strKey]["strFunction"] + ) diff --git a/vfi_models/ops/taichi_ops/__init__.py b/vfi_models/ops/taichi_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e36bb6cf71d64b17637b96712784de777eef1b28 --- /dev/null +++ b/vfi_models/ops/taichi_ops/__init__.py @@ -0,0 +1,150 @@ +import comfy.model_management as model_management +import torch +import torch.multiprocessing as mp +from .worker_process import f 
from .utils import to_shared_memory

# Worker-process plumbing: taichi kernels run in a separate spawned process so
# taichi never shares an interpreter with comfy's model management.
parent_conn, child_conn, process = None, None, None
device = model_management.get_torch_device()


def req_to_taichi_process(op_name, *tensors):
    """Run op_name in the taichi worker process and return its result tensors.

    The worker is spawned lazily on first call. It answers with a list/tuple
    of tensors on success, or a formatted traceback string on failure, in
    which case this raises.
    """
    global parent_conn, child_conn, process
    if parent_conn is None:
        # 'spawn' avoids inheriting CUDA context into the child process.
        mp.set_start_method('spawn', force=True)
        parent_conn, child_conn = mp.Pipe()
        process = mp.Process(target=f, args=(child_conn, device))
        process.start()

    tensors = to_shared_memory(tensors)
    parent_conn.send((op_name, tensors))
    result = parent_conn.recv()
    del tensors

    if not isinstance(result, (tuple, list)):
        # The worker sent a traceback string instead of tensors.
        raise Exception(result)

    return [tensor.to(device) for tensor in result]


def softsplat(
    tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str
):
    """Forward-splat tenIn along tenFlow.

    strMode is one of sum/avg/linear/soft, optionally suffixed with an
    epsilon policy, e.g. "soft-addeps" / "soft-zeroeps" / "soft-clipeps".
    tenMetric is required for linear/soft and must be None for sum/avg.
    """
    assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"]

    if strMode == "sum":
        assert tenMetric is None
    if strMode == "avg":
        assert tenMetric is None
    if strMode.split("-")[0] == "linear":
        assert tenMetric is not None
    if strMode.split("-")[0] == "soft":
        assert tenMetric is not None

    # For the normalised modes, append a weight channel that accumulates the
    # total splatted weight per output pixel.
    if strMode == "avg":
        tenIn = torch.cat(
            [
                tenIn,
                tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]),
            ],
            1,
        )
    elif strMode.split("-")[0] == "linear":
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
    elif strMode.split("-")[0] == "soft":
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    tenOut = req_to_taichi_process("softsplat_out", tenIn, tenFlow)[0]

    if strMode.split("-")[0] in ["avg", "linear", "soft"]:
        tenNormalize = tenOut[:, -1:, :, :]

        if len(strMode.split("-")) == 1:
            tenNormalize = tenNormalize + 0.0000001
        elif strMode.split("-")[1] == "addeps":
            tenNormalize = tenNormalize + 0.0000001
        elif strMode.split("-")[1] == "zeroeps":
            tenNormalize[tenNormalize == 0.0] = 1.0
        elif strMode.split("-")[1] == "clipeps":
            tenNormalize = tenNormalize.clip(0.0000001, None)

        tenOut = tenOut[:, :-1, :, :] / tenNormalize

    return tenOut


def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
    """Legacy softsplat entry point (summation/average/linear/softmax naming)."""
    assert tenMetric is None or tenMetric.shape[1] == 1
    assert strType in ["summation", "average", "linear", "softmax"]

    if strType == "average":
        tenInput = torch.cat(
            [
                tenInput,
                tenInput.new_ones(
                    tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]
                ),
            ],
            1,
        )
    elif strType == "linear":
        tenInput = torch.cat([tenInput * tenMetric, tenMetric], 1)
    elif strType == "softmax":
        tenInput = torch.cat([tenInput * tenMetric.exp(), tenMetric.exp()], 1)

    tenOutput = req_to_taichi_process("softsplat_out", tenInput, tenFlow)[0]

    if strType != "summation":
        # Normalise by the accumulated weight channel; zeros become 1 to avoid
        # division by zero for pixels nothing splatted onto.
        tenNormalize = tenOutput[:, -1:, :, :]
        tenNormalize[tenNormalize == 0.0] = 1.0
        tenOutput = tenOutput[:, :-1, :, :] / tenNormalize

    return tenOutput


class ModuleSoftsplat(torch.nn.Module):
    """nn.Module wrapper around FunctionSoftsplat with a fixed splat type."""

    def __init__(self, strType):
        # BUG FIX: the original called super(self).__init__(); super() takes a
        # class (or no argument inside a method), not an instance, so every
        # construction raised TypeError.
        super().__init__()
        self.strType = strType

    def forward(self, tenInput, tenFlow, tenMetric):
        return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType)


def softsplat_func(tenIn, tenFlow):
    """Raw forward splat (no normalisation), via the worker process."""
    return req_to_taichi_process("softsplat_out", tenIn, tenFlow)[0]


class costvol_func:
    @staticmethod
    def apply(tenOne, tenTwo):
        return req_to_taichi_process("costvol_out", tenOne, tenTwo)[0]


class sepconv_func:
    @staticmethod
    def apply(tenIn, tenVer, tenHor):
        return req_to_taichi_process("sepconv_out", tenIn, tenVer, tenHor)[0]


def init():
    """Warm up the worker process by running each op once on a dummy input."""
    one_sample = torch.ones(1, 3, 16, 16, dtype=torch.float32, device=device)
    softsplat_func(one_sample, one_sample)
    costvol_func.apply(one_sample, one_sample)
    sepconv_func.apply(one_sample, one_sample, one_sample)
import torch


class FunctionAdaCoF(torch.autograd.Function):
    """Stub for the AdaCoF op: not implemented in the taichi backend."""

    @staticmethod
    def forward(ctx, input, weight, offset_i, offset_j, dilation):
        raise NotImplementedError()


def batch_edt(img, block=1024):
    """Stub for the batched Euclidean distance transform (not implemented)."""
    raise NotImplementedError()


class _FunctionCorrelation(torch.autograd.Function):
    """Stub for the correlation op used by flow models."""

    @staticmethod
    def forward(self, first, second):
        raise NotImplementedError()


def FunctionCorrelation(tenFirst, tenSecond):
    """Correlation entry point; unconditionally unimplemented for taichi.

    The original had an unreachable `return _FunctionCorrelation.apply(...)`
    after the raise; that dead line is removed.
    """
    raise NotImplementedError()


class ModuleCorrelation(torch.nn.Module):
    """Stub module; constructing it signals the op is unavailable.

    The original's `super().__init__()` after the raise was unreachable and
    has been removed.
    """

    def __init__(self):
        raise NotImplementedError()
# Separate taichi kernels into their own module so comfy.model_management is
# never imported inside the worker process.

import taichi as ti
import taichi.math as tm


@ti.func
def put_to_tenOut(tenOut: ti.types.ndarray(), fltIn: ti.f32, flt: ti.f32, pos: tm.ivec2, i: ti.i32, ch: ti.i32):
    # Scatter one bilinear corner contribution, skipping out-of-frame corners.
    # BUG FIX: fltIn/flt were annotated ti.i32, truncating the splatted value
    # and bilinear weight to integers, and pos was tm.uvec2 even though corner
    # coordinates can be negative after flooring the flow.
    N, C, H, W = tenOut.shape
    if (pos.x >= 0) and (pos.x < W) and (pos.y >= 0) and (pos.y < H):
        tenOut[i, ch, pos.y, pos.x] += fltIn * flt


@ti.kernel
def softsplat_out(tenIn: ti.types.ndarray(), tenFlow: ti.types.ndarray(), tenOut: ti.types.ndarray()):
    """Forward-splat tenIn along tenFlow into tenOut with bilinear weights."""
    N, C, H, W = tenIn.shape
    for i, ch, y, x in ti.ndrange(N, C, H, W):
        fltX = x + tenFlow[i, 0, y, x]
        fltY = y + tenFlow[i, 1, y, x]
        fltIn = tenIn[i, ch, y, x]

        northWest = tm.ivec2(ti.floor(fltX), ti.floor(fltY))
        northEast = northWest + [1, 0]
        southWest = northWest + [0, 1]
        southEast = northWest + [1, 1]

        fltNorthwest = (southEast.x - fltX) * (southEast.y - fltY)
        fltNortheast = (fltX - southWest.x) * (southWest.y - fltY)
        fltSouthwest = (northEast.x - fltX) * (fltY - northEast.y)
        fltSoutheast = (fltX - northWest.x) * (fltY - northWest.y)

        put_to_tenOut(tenOut, fltIn, fltNorthwest, northWest, i, ch)
        put_to_tenOut(tenOut, fltIn, fltNortheast, northEast, i, ch)
        put_to_tenOut(tenOut, fltIn, fltSouthwest, southWest, i, ch)
        put_to_tenOut(tenOut, fltIn, fltSoutheast, southEast, i, ch)


@ti.func
def flowgrad_corner(tenOutgrad: ti.types.ndarray(), fltIn: ti.f32, flt: ti.f32, pos: tm.ivec2, i: ti.i32, ch: ti.i32) -> ti.f32:
    # BUG FIX: the original helper received the accumulator as a scalar
    # argument and added to it; taichi passes scalars by value, so the
    # additions were silently discarded. Return the contribution instead.
    N, C, H, W = tenOutgrad.shape
    contrib = 0.0
    if (pos.x >= 0) and (pos.x < W) and (pos.y >= 0) and (pos.y < H):
        contrib = tenOutgrad[i, ch, pos.y, pos.x] * fltIn * flt
    return contrib


@ti.kernel
def softsplat_flowgrad(
    tenIn: ti.types.ndarray(),
    tenFlow: ti.types.ndarray(),
    tenOutgrad: ti.types.ndarray(),
    tenIngrad: ti.types.ndarray(),
    tenFlowgrad: ti.types.ndarray()
):
    """Backward pass w.r.t. the flow; ch selects the flow component (0=x, 1=y)."""
    N, C, H, W = tenFlowgrad.shape
    for i, ch, y, x in ti.ndrange(N, C, H, W):
        fltFlowgrad = 0.0
        fltX = x + tenFlow[i, 0, y, x]
        fltY = y + tenFlow[i, 1, y, x]

        northWest = tm.ivec2(ti.floor(fltX, dtype=ti.i32), ti.floor(fltY, dtype=ti.i32))
        northEast = northWest + [1, 0]
        southWest = northWest + [0, 1]
        southEast = northWest + [1, 1]

        # Derivative of the bilinear weights w.r.t. flow component `ch`.
        # Initialised to 0.0 so they are defined even for an unexpected ch.
        fltNorthwest = 0.0
        fltNortheast = 0.0
        fltSouthwest = 0.0
        fltSoutheast = 0.0
        if ch == 0:
            fltNorthwest = -1.0 * (southEast.y - fltY)
            fltNortheast = +1.0 * (southWest.y - fltY)
            fltSouthwest = -1.0 * (fltY - northEast.y)
            fltSoutheast = +1.0 * (fltY - northWest.y)
        elif ch == 1:
            fltNorthwest = -1.0 * (southEast.x - fltX)
            fltNortheast = -1.0 * (fltX - southWest.x)
            fltSouthwest = +1.0 * (northEast.x - fltX)
            fltSoutheast = +1.0 * (fltX - northWest.x)

        for outgrad_ch in ti.ndrange(tenOutgrad.shape[1]):
            fltIn = tenIn[i, outgrad_ch, y, x]
            fltFlowgrad += flowgrad_corner(tenOutgrad, fltIn, fltNorthwest, northWest, i, outgrad_ch)
            fltFlowgrad += flowgrad_corner(tenOutgrad, fltIn, fltNortheast, northEast, i, outgrad_ch)
            fltFlowgrad += flowgrad_corner(tenOutgrad, fltIn, fltSouthwest, southWest, i, outgrad_ch)
            fltFlowgrad += flowgrad_corner(tenOutgrad, fltIn, fltSoutheast, southEast, i, outgrad_ch)

        # BUG FIX: the original wrote `tenFlowgrad[i] = ...` — a single index
        # into a 4-D ndarray (with the author's own "Is 'i' the same as
        # intIndex?" doubt). Write to the element actually being computed.
        tenFlowgrad[i, ch, y, x] = fltFlowgrad


@ti.func
def ingrad_corner(tenOutgrad: ti.types.ndarray(), flt: ti.f32, pos: tm.ivec2, i: ti.i32, ch: ti.i32) -> ti.f32:
    # Same pass-by-value fix as flowgrad_corner: return the contribution.
    N, C, H, W = tenOutgrad.shape
    contrib = 0.0
    if (pos.x >= 0) and (pos.x < W) and (pos.y >= 0) and (pos.y < H):
        contrib = tenOutgrad[i, ch, pos.y, pos.x] * flt
    return contrib


@ti.kernel
def softsplat_ingrad(
    tenIn: ti.types.ndarray(),
    tenFlow: ti.types.ndarray(),
    tenOutgrad: ti.types.ndarray(),
    tenIngrad: ti.types.ndarray(),
    tenFlowgrad: ti.types.ndarray()
):
    """Backward pass w.r.t. the input tensor."""
    N, C, H, W = tenIngrad.shape
    for i, ch, y, x in ti.ndrange(N, C, H, W):
        fltIngrad = 0.0
        fltX = x + tenFlow[i, 0, y, x]
        fltY = y + tenFlow[i, 1, y, x]

        northWest = tm.ivec2(ti.floor(fltX, dtype=ti.i32), ti.floor(fltY, dtype=ti.i32))
        northEast = northWest + [1, 0]
        southWest = northWest + [0, 1]
        southEast = northWest + [1, 1]

        fltNorthwest = (southEast.x - fltX) * (southEast.y - fltY)
        fltNortheast = (fltX - southWest.x) * (southWest.y - fltY)
        fltSouthwest = (northEast.x - fltX) * (fltY - northEast.y)
        fltSoutheast = (fltX - northWest.x) * (fltY - northWest.y)

        fltIngrad += ingrad_corner(tenOutgrad, fltNorthwest, northWest, i, ch)
        fltIngrad += ingrad_corner(tenOutgrad, fltNortheast, northEast, i, ch)
        fltIngrad += ingrad_corner(tenOutgrad, fltSouthwest, southWest, i, ch)
        fltIngrad += ingrad_corner(tenOutgrad, fltSoutheast, southEast, i, ch)
        # BUG FIX: original wrote tenIngrad[i] (single index into 4-D ndarray).
        tenIngrad[i, ch, y, x] = fltIngrad


def worker_interface(op_name, tensors):
    """Dispatch softsplat ops for the taichi worker process."""
    if op_name == "softsplat_out":
        tenIn, tenFlow = tensors
        tenOut = tenIn.new_zeros(tenIn.shape)
        softsplat_out(tenIn, tenFlow, tenOut)
        return (tenOut, )

    raise NotImplementedError(op_name)


__all__ = ["worker_interface"]
import taichi as ti
import taichi.math as tm
from functools import reduce


@ti.kernel
def sepconv_out(tenIn: ti.types.ndarray(), tenVer: ti.types.ndarray(), tenHor: ti.types.ndarray(), tenOut: ti.types.ndarray()):
    """Separable convolution into the flat buffer tenOut:

        out[i,ch,y,x] = sum_{fy,fx} in[i,ch,y+fy,x+fx] * ver[i,fy,y,x] * hor[i,fx,y,x]

    Kahan summation keeps the f32 accumulation stable.
    BUG FIX: the original iterated over the *input* H/W and used a serially
    incremented `intIndex` as the flat output index; taichi parallelises the
    outermost loop, so the counter raced, and the iteration count did not
    match the output size (reading out of bounds at the borders). The flat
    index is now computed from (i, ch, y, x) over the *output* spatial dims.
    """
    N, C, H, W = tenIn.shape
    outH = tenVer.shape[2]
    outW = tenVer.shape[3]
    for i, ch, y, x in ti.ndrange(N, C, outH, outW):
        fltOut, fltKahanc, fltKahany, fltKahant = 0.0, 0.0, 0.0, 0.0
        for intFy, intFx in ti.ndrange(tenVer.shape[1], tenHor.shape[1]):
            fltKahany = tenIn[i, ch, y + intFy, x + intFx] * tenVer[i, intFy, y, x] * tenHor[i, intFx, y, x]
            fltKahany = fltKahany - fltKahanc
            fltKahant = fltOut + fltKahany
            fltKahanc = (fltKahant - fltOut) - fltKahany
            fltOut = fltKahant
        tenOut[((i * C + ch) * outH + y) * outW + x] = fltOut


def worker_interface(op_name, tensors):
    """Dispatch sepconv ops for the taichi worker process."""
    if op_name == "sepconv_out":
        tenIn, tenVer, tenHor = tensors
        # tenVer and tenHor must agree on the output spatial dims; the original
        # "checked" this with `a and b`, which just evaluates to b.
        assert tenVer.shape[2] == tenHor.shape[2]
        assert tenVer.shape[3] == tenHor.shape[3]
        real_tenOut_shape = [
            tenIn.shape[0],
            tenIn.shape[1],
            tenVer.shape[2],
            tenVer.shape[3],
        ]
        tenOut = tenIn.new_zeros([
            int(reduce(lambda a, b: a * b, real_tenOut_shape))
        ])
        sepconv_out(tenIn, tenVer, tenHor, tenOut)
        tenOut = tenOut.view(*real_tenOut_shape)
        return (tenOut, )

    raise NotImplementedError(op_name)


__all__ = ["worker_interface"]
import torch.multiprocessing as mp
import torch
from .raw_softsplat import worker_interface as raw_softsplat
from .costvol import worker_interface as costvol
from .sepconv import worker_interface as sepconv
from .utils import to_shared_memory, to_device
import taichi as ti
import traceback


def f(child_conn, device: torch.DeviceObjType):
    """Worker-process main loop.

    Receives (op_name, tensors) over the pipe, runs the matching taichi op on
    `device`, and sends the result tensors back (or a formatted traceback
    string on failure, which the parent turns into an exception).
    """
    ti.init(arch=ti.gpu)
    while True:
        op_name, tensors = child_conn.recv()
        tensors = to_device(tensors, device)
        try:
            if "softsplat" in op_name:
                result = raw_softsplat(op_name, tensors)
            elif "costvol" in op_name:
                result = costvol(op_name, tensors)
            elif "sepconv" in op_name:
                result = sepconv(op_name, tensors)
            else:
                raise NotImplementedError(op_name)
            child_conn.send(to_shared_memory(result))
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and made the worker unkillable.
            child_conn.send(traceback.format_exc())
"rife47.pth": "4.7", + "rife48.pth": "4.7", + "rife49.pth": "4.7", + "sudo_rife4_269.662_testV1_scale1.pth": "4.0" + #Arch 4.10 doesn't work due to state dict mismatch + #TODO: Investigating and fix it + #"rife410.pth": "4.10", + #"rife411.pth": "4.10", + #"rife412.pth": "4.10" +} + +class RIFE_VFI: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": ( + sorted(list(CKPT_NAME_VER_DICT.keys()), key=lambda ckpt_name: version.parse(CKPT_NAME_VER_DICT[ckpt_name])), + {"default": "rife47.pth"} + ), + "frames": ("IMAGE", ), + "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}), + "multiplier": ("INT", {"default": 2, "min": 1}), + "fast_mode": ("BOOLEAN", {"default":True}), + "ensemble": ("BOOLEAN", {"default":True}), + "scale_factor": ([0.25, 0.5, 1.0, 2.0, 4.0], {"default": 1.0}) + }, + "optional": { + "optional_interpolation_states": ("INTERPOLATION_STATES", ) + } + } + + RETURN_TYPES = ("IMAGE", ) + FUNCTION = "vfi" + CATEGORY = "ComfyUI-Frame-Interpolation/VFI" + + def vfi( + self, + ckpt_name: typing.AnyStr, + frames: torch.Tensor, + clear_cache_after_n_frames = 10, + multiplier: typing.SupportsInt = 2, + fast_mode = False, + ensemble = False, + scale_factor = 1.0, + optional_interpolation_states: InterpolationStateList = None, + **kwargs + ): + """ + Perform video frame interpolation using a given checkpoint model. + + Args: + ckpt_name (str): The name of the checkpoint model to use. + frames (torch.Tensor): A tensor containing input video frames. + clear_cache_after_n_frames (int, optional): The number of frames to process before clearing CUDA cache + to prevent memory overflow. Defaults to 10. Lower numbers are safer but mean more processing time. + How high you should set it depends on how many input frames there are, input resolution (after upscaling), + how many times you want to multiply them, and how long you're willing to wait for the process to complete. 
+ multiplier (int, optional): The multiplier for each input frame. 60 input frames * 2 = 120 output frames. Defaults to 2. + + Returns: + tuple: A tuple containing the output interpolated frames. + + Note: + This method interpolates frames in a video sequence using a specified checkpoint model. + It processes each frame sequentially, generating interpolated frames between them. + + To prevent memory overflow, it clears the CUDA cache after processing a specified number of frames. + """ + from .rife_arch import IFNet + model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name) + arch_ver = CKPT_NAME_VER_DICT[ckpt_name] + interpolation_model = IFNet(arch_ver=arch_ver) + interpolation_model.load_state_dict(torch.load(model_path)) + interpolation_model.eval().to(get_torch_device()) + frames = preprocess_frames(frames) + + def return_middle_frame(frame_0, frame_1, timestep, model, scale_list, in_fast_mode, in_ensemble): + return model(frame_0, frame_1, timestep, scale_list, in_fast_mode, in_ensemble) + + scale_list = [8 / scale_factor, 4 / scale_factor, 2 / scale_factor, 1 / scale_factor] + + args = [interpolation_model, scale_list, fast_mode, ensemble] + out = postprocess_frames( + generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args, + interpolation_states=optional_interpolation_states, dtype=torch.float32) + ) + return (out,) diff --git a/vfi_models/rife/rife_arch.py b/vfi_models/rife/rife_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2b4545298a1f34d8429e93f0aa4f43c06c090e --- /dev/null +++ b/vfi_models/rife/rife_arch.py @@ -0,0 +1,586 @@ +""" +26-Dez-21 +https://github.com/hzwer/Practical-RIFE +https://github.com/hzwer/Practical-RIFE/blob/main/model/warplayer.py +https://github.com/HolyWu/vs-rife/blob/master/vsrife/__init__.py +""" +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import AdamW +import torch +import torch.nn.functional 
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import warnings
from comfy.model_management import get_torch_device

device = get_torch_device()
# Cache of base sampling grids keyed by (device, flow size); rebuilt only when
# a new resolution/device combination is seen.
backwarp_tenGrid = {}


class ResConv(nn.Module):
    """Residual 3x3 conv block used by RIFE >= 4.5: relu(conv(x) * beta + x)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


def warp(tenInput, tenFlow):
    """Backward-warp tenInput by the per-pixel flow tenFlow via grid_sample.

    The flow is given in pixels and normalised here to [-1, 1] grid
    coordinates expected by grid_sample.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
            .view(1, 1, 1, tenFlow.shape[3])
            .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
            .view(1, 1, tenFlow.shape[2], 1)
            .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        )
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)

    # BUG FIX: the original compared tenInput.type() against the literal
    # "torch.cuda.HalfTensor", so half-precision inputs on CPU/MPS kept a
    # float32 grid and grid_sample raised a dtype mismatch. Compare dtypes.
    if tenInput.dtype == torch.half:
        g = g.half()

    padding_mode = "border"
    if device.type == "mps":
        # https://github.com/pytorch/pytorch/issues/125098
        padding_mode = "zeros"
        g = g.clamp(-1, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode=padding_mode,
        align_corners=True,
    )
def conv(
    in_planes,
    out_planes,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    arch_ver="4.0",
):
    """Conv2d + activation; RIFE 4.0 uses PReLU, 4.2+ use LeakyReLU(0.2)."""
    if arch_ver == "4.0":
        return nn.Sequential(
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=True,
            ),
            nn.PReLU(out_planes),
        )
    if arch_ver in ["4.2", "4.3", "4.5", "4.6", "4.7", "4.10"]:
        return nn.Sequential(
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=True,
            ),
            nn.LeakyReLU(0.2, True),
        )
    # Previously fell through and returned None, which only failed later.
    raise ValueError(f"unsupported RIFE arch version: {arch_ver}")


def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d without activation.

    BUG FIX: this function was defined twice back-to-back; the duplicate
    (which merely shadowed the first, identical definition) is removed.
    """
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        )
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1, arch_ver="4.0"):
    """ConvTranspose2d + activation (PReLU for 4.0, LeakyReLU for 4.2+).

    BUG FIX: kernel_size/stride/padding were accepted but ignored (hard-coded
    4/2/1); they are now honoured. Defaults match the old hard-coded values,
    so existing call sites behave identically.
    """
    if arch_ver == "4.0":
        return nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=in_planes,
                out_channels=out_planes,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=True,
            ),
            nn.PReLU(out_planes),
        )
    if arch_ver in ["4.2", "4.3", "4.5", "4.6", "4.7", "4.10"]:
        return nn.Sequential(
            torch.nn.ConvTranspose2d(
                in_channels=in_planes,
                out_channels=out_planes,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=True,
            ),
            nn.LeakyReLU(0.2, True),
        )
    raise ValueError(f"unsupported RIFE arch version: {arch_ver}")


class Conv2(nn.Module):
    """Two stacked conv+act layers; the first may downsample via `stride`."""

    def __init__(self, in_planes, out_planes, stride=2, arch_ver="4.0"):
        super(Conv2, self).__init__()
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1, arch_ver=arch_ver)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1, arch_ver=arch_ver)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
class IFBlock(nn.Module):
    """One level of RIFE's coarse-to-fine flow estimator.

    Estimates a 4-channel bidirectional flow and a 1-channel fusion mask at
    the resolution of `x` as passed in (work happens at 1/scale internally).
    """

    def __init__(self, in_planes, c=64, arch_ver="4.0"):
        super(IFBlock, self).__init__()
        # NOTE: the original assigned self.arch_ver twice; one assignment kept.
        self.arch_ver = arch_ver
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1, arch_ver=arch_ver),
            conv(c // 2, c, 3, 2, 1, arch_ver=arch_ver),
        )

        if arch_ver in ["4.0", "4.2", "4.3"]:
            # Eight plain conv layers; Sequential indices 0-7 keep the
            # state-dict layout identical to the original listing.
            self.convblock = nn.Sequential(
                *[conv(c, c, arch_ver=arch_ver) for _ in range(8)]
            )
            self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

        if arch_ver in ["4.5", "4.6", "4.7", "4.10"]:
            self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
            if arch_ver == "4.5":
                self.lastconv = nn.Sequential(
                    nn.ConvTranspose2d(c, 4 * 5, 4, 2, 1), nn.PixelShuffle(2)
                )
            if arch_ver in ["4.6", "4.7", "4.10"]:
                self.lastconv = nn.Sequential(
                    nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
                )

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; flow magnitudes are rescaled to match.
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            flow = (
                F.interpolate(
                    flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
                )
                * 1.0
                / scale
            )
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        if self.arch_ver == "4.0":
            # 4.0 adds a residual over the whole conv block.
            feat = self.convblock(feat) + feat
        if self.arch_ver in ["4.2", "4.3", "4.5", "4.6", "4.7", "4.10"]:
            feat = self.convblock(feat)

        tmp = self.lastconv(feat)
        # 4.0-4.3: lastconv upsamples once, so interpolate by scale*2;
        # 4.5+: PixelShuffle already restores resolution, interpolate by scale.
        if self.arch_ver in ["4.0", "4.2", "4.3"]:
            tmp = F.interpolate(
                tmp, scale_factor=scale * 2, mode="bilinear", align_corners=False
            )
            flow = tmp[:, :4] * scale * 2
        if self.arch_ver in ["4.5", "4.6", "4.7", "4.10"]:
            tmp = F.interpolate(
                tmp, scale_factor=scale, mode="bilinear", align_corners=False
            )
            flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask
class Contextnet(nn.Module):
    """Four-level feature pyramid; each level's features are backward-warped
    by a progressively downscaled (and correspondingly rescaled) flow.
    """

    def __init__(self, arch_ver="4.0"):
        super(Contextnet, self).__init__()
        c = 16
        self.conv1 = Conv2(3, c, arch_ver=arch_ver)
        self.conv2 = Conv2(c, 2 * c, arch_ver=arch_ver)
        self.conv3 = Conv2(2 * c, 4 * c, arch_ver=arch_ver)
        self.conv4 = Conv2(4 * c, 8 * c, arch_ver=arch_ver)

    def forward(self, x, flow):
        pyramid = []
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
            # Each stage halves the resolution, so halve the flow to match.
            flow = (
                F.interpolate(
                    flow, scale_factor=0.5, mode="bilinear", align_corners=False
                )
                * 0.5
            )
            pyramid.append(warp(x, flow))
        return pyramid


class Unet(nn.Module):
    """Refinement U-Net fusing the images, their warps, mask/flow and the two
    context pyramids into a sigmoid-activated residual image."""

    def __init__(self, arch_ver="4.0"):
        super(Unet, self).__init__()
        c = 16
        self.down0 = Conv2(17, 2 * c, arch_ver=arch_ver)
        self.down1 = Conv2(4 * c, 4 * c, arch_ver=arch_ver)
        self.down2 = Conv2(8 * c, 8 * c, arch_ver=arch_ver)
        self.down3 = Conv2(16 * c, 16 * c, arch_ver=arch_ver)
        self.up0 = deconv(32 * c, 8 * c, arch_ver=arch_ver)
        self.up1 = deconv(16 * c, 4 * c, arch_ver=arch_ver)
        self.up2 = deconv(8 * c, 2 * c, arch_ver=arch_ver)
        self.up3 = deconv(4 * c, c, arch_ver=arch_ver)
        self.conv = nn.Conv2d(c, 3, 3, 1, 1)

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        # Encoder: concatenate the matching context-pyramid level at each step.
        enc0 = self.down0(
            torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)
        )
        enc1 = self.down1(torch.cat((enc0, c0[0], c1[0]), 1))
        enc2 = self.down2(torch.cat((enc1, c0[1], c1[1]), 1))
        enc3 = self.down3(torch.cat((enc2, c0[2], c1[2]), 1))
        # Decoder with skip connections back to the encoder outputs.
        out = self.up0(torch.cat((enc3, c0[3], c1[3]), 1))
        out = self.up1(torch.cat((out, enc2), 1))
        out = self.up2(torch.cat((out, enc1), 1))
        out = self.up3(torch.cat((out, enc0), 1))
        return torch.sigmoid(self.conv(out))
class IFNet(nn.Module):
    """RIFE IFNet supporting several architecture versions (4.0 - 4.10).

    Runs a 4-level coarse-to-fine cascade of ``IFBlock``s that estimate
    bidirectional flow plus a blend mask; versions 4.7+ additionally feed
    learned per-frame features from ``self.encode`` into every block.
    """

    def __init__(self, arch_ver="4.0"):
        super(IFNet, self).__init__()
        self.arch_ver = arch_ver
        if arch_ver in ["4.0", "4.2", "4.3", "4.5", "4.6"]:
            self.block0 = IFBlock(7, c=192, arch_ver=arch_ver)
            self.block1 = IFBlock(8 + 4, c=128, arch_ver=arch_ver)
            self.block2 = IFBlock(8 + 4, c=96, arch_ver=arch_ver)
            self.block3 = IFBlock(8 + 4, c=64, arch_ver=arch_ver)
        if arch_ver in ["4.7"]:
            # 4.7 adds 4 context channels per frame (two frames -> +8 inputs).
            self.block0 = IFBlock(7 + 8, c=192, arch_ver=arch_ver)
            self.block1 = IFBlock(8 + 4 + 8, c=128, arch_ver=arch_ver)
            self.block2 = IFBlock(8 + 4 + 8, c=96, arch_ver=arch_ver)
            self.block3 = IFBlock(8 + 4 + 8, c=64, arch_ver=arch_ver)
            self.encode = nn.Sequential(
                nn.Conv2d(3, 16, 3, 2, 1), nn.ConvTranspose2d(16, 4, 4, 2, 1)
            )
        if arch_ver in ["4.10"]:
            # 4.10 uses a deeper encoder emitting 8 channels per frame (+16).
            self.block0 = IFBlock(7 + 16, c=192)
            self.block1 = IFBlock(8 + 4 + 16, c=128)
            self.block2 = IFBlock(8 + 4 + 16, c=96)
            self.block3 = IFBlock(8 + 4 + 16, c=64)
            self.encode = nn.Sequential(
                nn.Conv2d(3, 32, 3, 2, 1),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(32, 32, 3, 1, 1),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(32, 32, 3, 1, 1),
                nn.LeakyReLU(0.2, True),
                nn.ConvTranspose2d(32, 8, 4, 2, 1),
            )

        if arch_ver in ["4.0", "4.2", "4.3"]:
            # Only the early versions ship the slow-mode refinement nets.
            self.contextnet = Contextnet(arch_ver=arch_ver)
            self.unet = Unet(arch_ver=arch_ver)
        self.arch_ver = arch_ver

    def forward(
        self,
        img0,
        img1,
        timestep=0.5,
        scale_list=None,
        training=True,
        fastmode=True,
        ensemble=False,
        return_flow=False,
    ):
        """Interpolate a frame between ``img0`` and ``img1`` at ``timestep``.

        ``scale_list`` gives the per-level coarse-to-fine scales; ``None``
        means the standard ``[8, 4, 2, 1]``.  ``return_flow`` is accepted
        for interface compatibility but unused in this code path.

        BUG FIX: the original declared ``scale_list=[8, 4, 2, 1]`` as a
        mutable default and mutated it in the 4.0 rescale branch
        (``scale_list[k] *= 2``), which permanently doubled the *shared
        default* for every later call.  We now copy the list locally.
        """
        scale_list = [8, 4, 2, 1] if scale_list is None else list(scale_list)

        img0 = torch.clamp(img0, 0, 1)
        img1 = torch.clamp(img1, 0, 1)

        # Pad spatial dims up to a multiple of 64 for the pyramid.
        n, c, h, w = img0.shape
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)
        x = torch.cat((img0, img1), 1)

        if not training:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        if not torch.is_tensor(timestep):
            # Broadcast the scalar timestep to a 1-channel map.
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])

        flow_list = []
        merged = []
        mask_list = []

        if self.arch_ver in ["4.7", "4.10"]:
            f0 = self.encode(img0[:, :3])
            f1 = self.encode(img1[:, :3])

        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]

        for i in range(4):
            if flow is None:
                # First level: estimate flow from the raw frames.
                if self.arch_ver in ["4.0", "4.2", "4.3", "4.5", "4.6"]:
                    flow, mask = block[i](
                        torch.cat((img0[:, :3], img1[:, :3], timestep), 1),
                        None,
                        scale=scale_list[i],
                    )
                    if ensemble:
                        # Average with the time-reversed estimate.
                        f1, m1 = block[i](
                            torch.cat((img1[:, :3], img0[:, :3], 1 - timestep), 1),
                            None,
                            scale=scale_list[i],
                        )
                        flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                        mask = (mask + (-m1)) / 2

                if self.arch_ver in ["4.7", "4.10"]:
                    flow, mask = block[i](
                        torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
                        None,
                        scale=scale_list[i],
                    )
                    if ensemble:
                        f_, m_ = block[i](
                            torch.cat(
                                (img1[:, :3], img0[:, :3], f1, f0, 1 - timestep), 1
                            ),
                            None,
                            scale=scale_list[i],
                        )
                        flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                        mask = (mask + (-m_)) / 2

            else:
                # Refinement levels: predict a residual on warped frames.
                if self.arch_ver in ["4.0", "4.2", "4.3", "4.5", "4.6"]:
                    f0, m0 = block[i](
                        torch.cat(
                            (warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1
                        ),
                        flow,
                        scale=scale_list[i],
                    )

                    if self.arch_ver in ["4.0"]:
                        # Large-motion fallback: restart at doubled scales
                        # when the residual flow saturates at inference time.
                        if (
                            i == 1
                            and f0[:, :2].abs().max() > 32
                            and f0[:, 2:4].abs().max() > 32
                            and not training
                        ):
                            for k in range(4):
                                scale_list[k] *= 2
                            flow, mask = block[0](
                                torch.cat((img0[:, :3], img1[:, :3], timestep), 1),
                                None,
                                scale=scale_list[0],
                            )
                            warped_img0 = warp(img0, flow[:, :2])
                            warped_img1 = warp(img1, flow[:, 2:4])
                            f0, m0 = block[i](
                                torch.cat(
                                    (
                                        warped_img0[:, :3],
                                        warped_img1[:, :3],
                                        timestep,
                                        mask,
                                    ),
                                    1,
                                ),
                                flow,
                                scale=scale_list[i],
                            )

                if self.arch_ver in ["4.7", "4.10"]:
                    fd, m0 = block[i](
                        torch.cat(
                            (
                                warped_img0[:, :3],
                                warped_img1[:, :3],
                                warp(f0, flow[:, :2]),
                                warp(f1, flow[:, 2:4]),
                                timestep,
                                mask,
                            ),
                            1,
                        ),
                        flow,
                        scale=scale_list[i],
                    )
                    flow = flow + fd

                if ensemble and self.arch_ver in ["4.0", "4.2", "4.3", "4.5", "4.6"]:
                    f1, m1 = block[i](
                        torch.cat(
                            (
                                warped_img1[:, :3],
                                warped_img0[:, :3],
                                1 - timestep,
                                -mask,
                            ),
                            1,
                        ),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    m0 = (m0 + (-m1)) / 2

                if ensemble and self.arch_ver in ["4.7", "4.10"]:
                    wf0 = warp(f0, flow[:, :2])
                    wf1 = warp(f1, flow[:, 2:4])

                    f_, m_ = block[i](
                        torch.cat(
                            (
                                warped_img1[:, :3],
                                warped_img0[:, :3],
                                wf1,
                                wf0,
                                1 - timestep,
                                -mask,
                            ),
                            1,
                        ),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    # NOTE(review): ``flow`` was already updated with the
                    # un-averaged ``fd`` above; this averaging only feeds the
                    # mask.  Kept as-is to match upstream behavior.
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2

                if self.arch_ver in ["4.0", "4.2", "4.3", "4.5", "4.6"]:
                    flow = flow + f0
                    mask = mask + m0

                if not ensemble and self.arch_ver in ["4.7", "4.10"]:
                    mask = m0

            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))

        # Blend the finest-level warps with the sigmoid mask.
        if self.arch_ver in ["4.0", "4.1", "4.2", "4.3", "4.4", "4.5", "4.6"]:
            mask_list[3] = torch.sigmoid(mask_list[3])
            merged[3] = merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])

        if self.arch_ver in ["4.7", "4.10"]:
            mask = torch.sigmoid(mask)
            merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)

        if not fastmode and self.arch_ver in ["4.0", "4.2", "4.3"]:
            # Slow path: refine the blend with context features + UNet residual.
            c0 = self.contextnet(img0, flow[:, :2])
            c1 = self.contextnet(img1, flow[:, 2:4])
            tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
            res = tmp[:, :3] * 2 - 1
            merged[3] = torch.clamp(merged[3] + res, 0, 1)
        # Crop back to the un-padded size.
        return merged[3][:, :, :h, :w]
class SepconvVFI:
    """ComfyUI node wrapping the (revisited) SepConv interpolation network."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        # Imported lazily so the node list can load without the model code.
        from .sepconv_enhanced import Network

        ckpt_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        net = Network()
        net.load_state_dict(torch.load(ckpt_path))
        net.eval().to(get_torch_device())

        frame_batch = preprocess_frames(frames)

        def middle_frame(frame_0, frame_1, timestep, model):
            # SepConv is a fixed t=0.5 interpolator, so timestep is unused.
            return model(frame_0, frame_1)

        interpolated = generic_frame_loop(
            type(self).__name__,
            frame_batch,
            clear_cache_after_n_frames,
            multiplier,
            middle_frame,
            net,
            interpolation_states=optional_interpolation_states,
            use_timestep=False,
            dtype=torch.float32,
        )
        return (postprocess_frames(interpolated),)
class Basic(torch.nn.Module):
    """Configurable conv block built from a spec string.

    ``strType`` looks like ``"part-part-..."`` optionally followed by
    ``"+flag"`` suffixes.  Parts: ``conv(k)`` stride-1 conv, ``sconv(k)``
    stride-2 conv, ``up(mode)`` 2x upsampling (``bilinear`` default,
    ``nearest``/``pyramid``/``shuffle`` selectable), ``prelu(init)``.
    Flags: ``skip`` (residual shortcut, adding a 1x1 conv and/or bilinear
    resampling when channels or stride differ), ``nobias``, ``nopad``.

    ``intChans`` lists channel widths consumed left-to-right by the conv
    parts; exactly one entry (the output width) must remain at the end.

    BUG FIX: the ``up(nearest)`` branch passed ``align_corners=False`` to
    ``F.interpolate`` with ``mode="nearest"``, which raises ``ValueError``
    (align_corners is only valid for the linear interpolation modes), so
    nearest upsampling could never run.  ``align_corners`` is now omitted
    for that mode.
    """

    def __init__(
        self,
        strType: str,
        intChans: typing.List[int],
        objScratch: typing.Optional[typing.Dict] = None,
    ):
        super().__init__()

        self.strType = strType
        self.netEvenize = None  # never assigned here; kept for interface parity
        self.netMain = None
        self.netShortcut = None

        intIn = intChans[0]
        intOut = intChans[-1]
        netMain = []
        intChans = intChans.copy()
        fltStride = 1.0  # cumulative spatial stride of the main path

        for strPart in self.strType.split("+")[0].split("-"):
            if strPart.startswith("conv") and not strPart.startswith("sconv"):
                intKsize = 3
                intPad = 1
                strPad = "zeros"

                if "(" in strPart:
                    strOpts = strPart.split("(")[1].split(")")[0].split(",")
                    intKsize = int(strOpts[0])
                    intPad = int(math.floor(0.5 * (intKsize - 1)))
                    if "replpad" in strOpts:
                        strPad = "replicate"
                    if "reflpad" in strOpts:
                        strPad = "reflect"

                if "nopad" in self.strType.split("+"):
                    intPad = 0

                netMain += [
                    torch.nn.Conv2d(
                        in_channels=intChans[0],
                        out_channels=intChans[1],
                        kernel_size=intKsize,
                        stride=1,
                        padding=intPad,
                        padding_mode=strPad,
                        bias="nobias" not in self.strType.split("+"),
                    )
                ]
                intChans = intChans[1:]  # conv consumes one channel slot
                fltStride *= 1.0

            elif strPart.startswith("sconv"):
                intKsize = 3
                intPad = 1
                strPad = "zeros"

                if "(" in strPart:
                    strOpts = strPart.split("(")[1].split(")")[0].split(",")
                    intKsize = int(strOpts[0])
                    intPad = int(math.floor(0.5 * (intKsize - 1)))
                    if "replpad" in strOpts:
                        strPad = "replicate"
                    if "reflpad" in strOpts:
                        strPad = "reflect"

                if "nopad" in self.strType.split("+"):
                    intPad = 0

                netMain += [
                    torch.nn.Conv2d(
                        in_channels=intChans[0],
                        out_channels=intChans[1],
                        kernel_size=intKsize,
                        stride=2,
                        padding=intPad,
                        padding_mode=strPad,
                        bias="nobias" not in self.strType.split("+"),
                    )
                ]
                intChans = intChans[1:]
                fltStride *= 2.0

            elif strPart.startswith("up"):

                class Up(torch.nn.Module):
                    """2x upsampler in one of four modes."""

                    def __init__(self, strType):
                        super().__init__()
                        self.strType = strType

                    def forward(self, tenIn: torch.Tensor) -> torch.Tensor:
                        if self.strType == "nearest":
                            # align_corners must NOT be passed for nearest
                            # (F.interpolate raises ValueError otherwise).
                            return torch.nn.functional.interpolate(
                                input=tenIn, scale_factor=2.0, mode="nearest"
                            )

                        elif self.strType == "bilinear":
                            return torch.nn.functional.interpolate(
                                input=tenIn,
                                scale_factor=2.0,
                                mode="bilinear",
                                align_corners=False,
                            )

                        elif self.strType == "pyramid":
                            # NOTE(review): `pyramid` is not defined in this
                            # module; presumably provided elsewhere -- confirm
                            # before using up(pyramid).
                            return pyramid(tenIn, None, "up")

                        elif self.strType == "shuffle":
                            return torch.nn.functional.pixel_shuffle(
                                tenIn, upscale_factor=2
                            )  # https://github.com/pytorch/pytorch/issues/62854

                        assert False  # to make torchscript happy

                strUp = "bilinear"

                if "(" in strPart:
                    strOpts = strPart.split("(")[1].split(")")[0].split(",")
                    if "nearest" in strOpts:
                        strUp = "nearest"
                    if "pyramid" in strOpts:
                        strUp = "pyramid"
                    if "shuffle" in strOpts:
                        strUp = "shuffle"

                netMain += [Up(strUp)]
                fltStride *= 0.5

            elif strPart.startswith("prelu"):
                netMain += [
                    torch.nn.PReLU(
                        num_parameters=1,
                        init=float(strPart.split("(")[1].split(")")[0].split(",")[0]),
                    )
                ]
                fltStride *= 1.0

            else:
                assert False  # unrecognized part in strType

        self.netMain = torch.nn.Sequential(*netMain)

        class Down(torch.nn.Module):
            """Bilinear resampler matching the main path's overall stride."""

            def __init__(self, fltScale):
                super().__init__()
                self.fltScale = fltScale

            def forward(self, tenIn: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.interpolate(
                    input=tenIn,
                    scale_factor=self.fltScale,
                    mode="bilinear",
                    align_corners=False,
                )

        for strPart in self.strType.split("+")[1:]:
            if strPart.startswith("skip"):
                # Choose a shortcut that matches channels and stride.
                if intIn == intOut and fltStride == 1.0:
                    self.netShortcut = torch.nn.Identity()

                elif intIn != intOut and fltStride == 1.0:
                    self.netShortcut = torch.nn.Conv2d(
                        in_channels=intIn,
                        out_channels=intOut,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias="nobias" not in self.strType.split("+"),
                    )

                elif intIn == intOut and fltStride != 1.0:
                    self.netShortcut = Down(1.0 / fltStride)

                elif intIn != intOut and fltStride != 1.0:
                    self.netShortcut = torch.nn.Sequential(
                        Down(1.0 / fltStride),
                        torch.nn.Conv2d(
                            in_channels=intIn,
                            out_channels=intOut,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias="nobias" not in self.strType.split("+"),
                        ),
                    )

            elif strPart.startswith("..."):
                pass

        # All conv channel slots must have been consumed except the output.
        assert len(intChans) == 1

    def forward(self, tenIn: torch.Tensor) -> torch.Tensor:
        if self.netEvenize is not None:
            tenIn = self.netEvenize(tenIn)

        tenOut = self.netMain(tenIn)

        if self.netShortcut is not None:
            tenOut = tenOut + self.netShortcut(tenIn)

        return tenOut
class Encode(torch.nn.Module):
    """One encoder stage of a grid network.

    Holds one row of horizontal ``Basic`` blocks (same-resolution refinement)
    and one row of vertical ``Basic`` blocks (downsampling into the next
    pyramid level).  A channel count of 0 in ``intIns``/``intOuts`` disables
    that row.  Each level's output shape is recorded into the shared
    ``objScratch`` dict so ``Decode`` can later undo odd-size rounding.
    """

    # Shared scratch dict; overwritten per-instance in __init__.
    objScratch: typing.Dict[str, typing.List[int]] = None

    def __init__(
        self,
        intIns: typing.List[int],
        intOuts: typing.List[int],
        strHor: str,
        strVer: str,
        objScratch: typing.Dict[str, typing.List[int]],
    ):
        super().__init__()

        assert len(intIns) == len(intOuts)
        assert len(intOuts) == len(intIns)

        # `len(a) and len(b)` == len(b) when both are non-empty.
        self.intRows = len(intIns) and len(intOuts)
        self.intIns = intIns.copy()
        self.intOuts = intOuts.copy()
        self.strHor = strHor
        self.strVer = strVer
        self.objScratch = objScratch

        self.netHor = torch.nn.ModuleList()
        self.netVer = torch.nn.ModuleList()

        for intRow in range(self.intRows):
            netHor = torch.nn.Identity()
            netVer = torch.nn.Identity()

            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    # Horizontal block: refine this row at its resolution.
                    netHor = Basic(
                        self.strHor,
                        [
                            self.intIns[intRow],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

                if intRow != 0:
                    # Vertical block: downsample the previous (finer) row.
                    netVer = Basic(
                        self.strVer,
                        [
                            self.intOuts[intRow - 1],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

            self.netHor.append(netHor)
            self.netVer.append(netVer)

    def forward(self, tenIns: typing.List[torch.Tensor]) -> typing.List[torch.Tensor]:
        """Refine each row, then add each row's downsampled predecessor.

        NOTE: ``tenIns`` is mutated in place and returned; the two passes
        must run in this exact order (all horizontal, then vertical
        top-down) for the running sums to be correct.
        """
        intRow = 0
        for netHor in self.netHor:
            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    tenIns[intRow] = netHor(tenIns[intRow])
            intRow += 1

        intRow = 0
        for netVer in self.netVer:
            if self.intOuts[intRow] != 0:
                if intRow != 0:
                    tenIns[intRow] = tenIns[intRow] + netVer(tenIns[intRow - 1])
            intRow += 1

        # Record per-level shapes for Decode's odd-size fixups.
        for intRow, tenIn in enumerate(tenIns):
            self.objScratch["levelshape" + str(intRow)] = tenIn.shape

        return tenIns
class Decode(torch.nn.Module):
    """One decoder stage of a grid network (mirror image of ``Encode``).

    Rows are built and traversed coarse-to-fine; each row is refined
    horizontally, then receives the upsampled coarser row.  The shapes
    recorded by ``Encode`` in ``objScratch`` are used to crop off the one
    extra row/column that 2x upsampling produces for odd-sized levels.
    """

    # Shared scratch dict; overwritten per-instance in __init__.
    objScratch: typing.Dict[str, typing.List[int]] = None

    def __init__(
        self,
        intIns: typing.List[int],
        intOuts: typing.List[int],
        strHor: str,
        strVer: str,
        objScratch: typing.Dict[str, typing.List[int]],
    ):
        super().__init__()

        assert len(intIns) == len(intOuts)
        assert len(intOuts) == len(intIns)

        # `len(a) and len(b)` == len(b) when both are non-empty.
        self.intRows = len(intIns) and len(intOuts)
        self.intIns = intIns.copy()
        self.intOuts = intOuts.copy()
        self.strHor = strHor
        self.strVer = strVer
        self.objScratch = objScratch

        self.netHor = torch.nn.ModuleList()
        self.netVer = torch.nn.ModuleList()

        # Built coarse-to-fine, so the ModuleLists are in reverse row order.
        for intRow in range(self.intRows - 1, -1, -1):
            netHor = torch.nn.Identity()
            netVer = torch.nn.Identity()

            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    # Horizontal block: refine this row at its resolution.
                    netHor = Basic(
                        self.strHor,
                        [
                            self.intIns[intRow],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

                if intRow != self.intRows - 1:
                    # Vertical block: upsample the coarser row into this one.
                    netVer = Basic(
                        self.strVer,
                        [
                            self.intOuts[intRow + 1],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

            self.netHor.append(netHor)
            self.netVer.append(netVer)

    def forward(self, tenIns: typing.List[torch.Tensor]) -> typing.List[torch.Tensor]:
        """Refine rows coarse-to-fine, merging in upsampled coarser rows.

        NOTE: ``tenIns`` is mutated in place and returned; traversal order
        matters for the running sums.
        """
        intRow = self.intRows - 1
        for netHor in self.netHor:
            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    tenIns[intRow] = netHor(tenIns[intRow])
            intRow -= 1

        intRow = self.intRows - 1
        for netVer in self.netVer:
            if self.intOuts[intRow] != 0:
                if intRow != self.intRows - 1:
                    tenVer = netVer(tenIns[intRow + 1])

                    # 2x upsampling of an odd-sized level overshoots by one
                    # pixel; crop with a negative pad to the recorded shape.
                    if "levelshape" + str(intRow) in self.objScratch:
                        if (
                            tenVer.shape[2]
                            == self.objScratch["levelshape" + str(intRow)][2] + 1
                        ):
                            tenVer = torch.nn.functional.pad(
                                input=tenVer,
                                pad=[0, 0, 0, -1],
                                mode="constant",
                                value=0.0,
                            )
                        if (
                            tenVer.shape[3]
                            == self.objScratch["levelshape" + str(intRow)][3] + 1
                        ):
                            tenVer = torch.nn.functional.pad(
                                input=tenVer,
                                pad=[0, -1, 0, 0],
                                mode="constant",
                                value=0.0,
                            )

                    tenIns[intRow] = tenIns[intRow] + tenVer
            intRow -= 1

        return tenIns
class Network(torch.nn.Module):
    """Revisited SepConv interpolation network.

    Encodes two frames through a grid encoder/decoder, predicts four
    51-tap separable kernel fields (vertical/horizontal per frame), and
    produces the middle frame as the sum of the two adaptively-convolved
    inputs, renormalized by a warped all-ones channel.
    """

    def __init__(self):
        super().__init__()

        self.intEncdec = [1, 1]  # number of encoder / decoder stages
        self.intChannels = [32, 64, 128, 256, 512]  # per-level widths

        self.objScratch = {}  # shared level-shape scratch (see Encode/Decode)

        # Stem: each frame -> 16 channels (two frames concat to 32).
        self.netInput = torch.nn.Conv2d(
            in_channels=3,
            out_channels=int(round(0.5 * self.intChannels[0])),
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="zeros",
        )

        self.netEncode = torch.nn.Sequential(
            *(
                [
                    Encode(
                        [0] * len(self.intChannels),
                        self.intChannels,
                        "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
                        "prelu(0.25)-sconv(3)-prelu(0.25)-conv(3)",
                        self.objScratch,
                    )
                ]
                + [
                    Encode(
                        self.intChannels,
                        self.intChannels,
                        "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
                        "prelu(0.25)-sconv(3)-prelu(0.25)-conv(3)",
                        self.objScratch,
                    )
                    for intEncdec in range(1, self.intEncdec[0])
                ]
            )
        )

        # Decoder skips level 0 (channel count 0 disables that row).
        self.netDecode = torch.nn.Sequential(
            *(
                [
                    Decode(
                        [0] + self.intChannels[1:],
                        [0] + self.intChannels[1:],
                        "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
                        "prelu(0.25)-up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
                        self.objScratch,
                    )
                    for intEncdec in range(0, self.intEncdec[1])
                ]
            )
        )

        # Four kernel heads: 51-tap separable filters per frame/direction.
        self.netVerone = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )
        self.netVertwo = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )
        self.netHorone = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )
        self.netHortwo = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )

        # self.load_state_dict(torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/resepconv/network-' + arguments_strModel + '.pytorch', file_name='resepconv-' + arguments_strModel))

    def forward(self, x1, x2):
        """Interpolate the middle frame between frames ``x1`` and ``x2``.

        Inputs are padded to even sizes; the output is cropped back to
        the original (intHeight, intWidth).
        """
        # padding if needed
        intWidth = x1.shape[3]
        intHeight = x1.shape[2]

        intPadr = (2 - (intWidth % 2)) % 2
        intPadb = (2 - (intHeight % 2)) % 2

        tenOne = torch.nn.functional.pad(
            input=x1, pad=[0, intPadr, 0, intPadb], mode="replicate"
        )
        tenTwo = torch.nn.functional.pad(
            input=x2, pad=[0, intPadr, 0, intPadb], mode="replicate"
        )
        ####

        tenSeq = [tenOne, tenTwo]

        # Normalize both frames with shared per-sample mean/std (no grad).
        with torch.set_grad_enabled(False):
            tenStack = torch.stack(tenSeq, 1)
            tenMean = (
                tenStack.view(tenStack.shape[0], -1)
                .mean(1, True)
                .view(tenStack.shape[0], 1, 1, 1)
            )
            tenStd = (
                tenStack.view(tenStack.shape[0], -1)
                .std(1, True)
                .view(tenStack.shape[0], 1, 1, 1)
            )
            tenSeq = [
                (tenFrame - tenMean) / (tenStd + 0.0000001) for tenFrame in tenSeq
            ]
            tenSeq = [tenFrame.detach() for tenFrame in tenSeq]

        # Grid levels 1..4 receive 0.0 placeholders; take level 1's output.
        tenOut = self.netDecode(
            self.netEncode(
                [torch.cat([self.netInput(tenSeq[0]), self.netInput(tenSeq[1])], 1)]
                + ([0.0] * (len(self.intChannels) - 1))
            )
        )[1]

        # Pad by half the 51-tap kernel radius for the adaptive convolution.
        tenOne = torch.nn.functional.pad(
            input=tenOne,
            pad=[
                int(math.floor(0.5 * 51)),
                int(math.floor(0.5 * 51)),
                int(math.floor(0.5 * 51)),
                int(math.floor(0.5 * 51)),
            ],
            mode="replicate",
        )
        tenTwo = torch.nn.functional.pad(
            input=tenTwo,
            pad=[
                int(math.floor(0.5 * 51)),
                int(math.floor(0.5 * 51)),
                int(math.floor(0.5 * 51)),
                int(math.floor(0.5 * 51)),
            ],
            mode="replicate",
        )

        # Append an all-ones channel; its warped value is the normalizer.
        tenOne = torch.cat(
            [
                tenOne,
                tenOne.new_ones([tenOne.shape[0], 1, tenOne.shape[2], tenOne.shape[3]]),
            ],
            1,
        ).detach()
        tenTwo = torch.cat(
            [
                tenTwo,
                tenTwo.new_ones([tenTwo.shape[0], 1, tenTwo.shape[2], tenTwo.shape[3]]),
            ],
            1,
        ).detach()

        tenVerone = self.netVerone(tenOut)
        tenVertwo = self.netVertwo(tenOut)
        tenHorone = self.netHorone(tenOut)
        tenHortwo = self.netHortwo(tenOut)

        # Separable adaptive convolution of each frame, summed.
        tenOut = sepconv_func.apply(tenOne, tenVerone, tenHorone) + sepconv_func.apply(
            tenTwo, tenVertwo, tenHortwo
        )

        # Renormalize by the warped ones-channel (guard near-zero weights).
        tenNormalize = tenOut[:, -1:, :, :]
        tenNormalize[tenNormalize.abs() < 0.01] = 1.0
        tenOut = tenOut[:, :-1, :, :] / tenNormalize

        # crop if needed
        return tenOut[:, :, :intHeight, :intWidth]
# Lazily-constructed module-level singleton (see estimate()).
netNetwork = None


def estimate(tenOne, tenTwo):
    """Interpolate the middle frame between two 3xHxW frames.

    Instantiates the global ``Network`` on first use, moves both frames to
    the compute device, pads them to even sizes and returns the CPU result
    cropped back to the input size.

    BUG FIX: the original called ``netNetwork([a, b])``, passing a single
    list, but ``Network.forward(self, x1, x2)`` takes the two frames as
    separate positional arguments — every call raised ``TypeError``.  The
    frames are now passed individually.
    """
    global netNetwork

    if netNetwork is None:
        netNetwork = Network().to(get_torch_device()).eval()

    assert tenOne.shape[1] == tenTwo.shape[1]
    assert tenOne.shape[2] == tenTwo.shape[2]

    intWidth = tenOne.shape[2]
    intHeight = tenOne.shape[1]

    assert (
        intWidth <= 1280
    )  # while our approach works with larger images, we do not recommend it unless you are aware of the implications
    assert (
        intHeight <= 720
    )  # while our approach works with larger images, we do not recommend it unless you are aware of the implications

    tenPreprocessedOne = tenOne.to(get_torch_device()).view(1, 3, intHeight, intWidth)
    tenPreprocessedTwo = tenTwo.to(get_torch_device()).view(1, 3, intHeight, intWidth)

    # Pad to even sizes (Network also pads internally; harmless here).
    intPadr = (2 - (intWidth % 2)) % 2
    intPadb = (2 - (intHeight % 2)) % 2

    tenPreprocessedOne = torch.nn.functional.pad(
        input=tenPreprocessedOne, pad=[0, intPadr, 0, intPadb], mode="replicate"
    )
    tenPreprocessedTwo = torch.nn.functional.pad(
        input=tenPreprocessedTwo, pad=[0, intPadr, 0, intPadb], mode="replicate"
    )

    return netNetwork(tenPreprocessedOne, tenPreprocessedTwo)[
        0, :, :intHeight, :intWidth
    ].cpu()
class STMFNet_VFI:
    """ComfyUI node wrapping ST-MFNet 2x video frame interpolation.

    ST-MFNet consumes four consecutive frames and synthesizes the frame
    between the middle two, so the input batch must hold at least 4 frames
    and only 2x interpolation is supported.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (["stmfnet.pth"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 2}), #TODO: Implement recursively invoking interpolator for multi-frame interpolation
                "duplicate_first_last_frames": ("BOOLEAN", {"default": False})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    #Reference: https://github.com/danier97/ST-MFNet/blob/main/interpolate_yuv.py#L93
    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        duplicate_first_last_frames: bool = False,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Slide a 4-frame window over ``frames`` and insert one interpolated
        frame per step; returns the interleaved sequence as a single batch.

        Note: because of the 4-frame window, the first and last source frames
        have no interpolated neighbor; ``duplicate_first_last_frames``
        repeats them to keep the output rhythm even.
        """
        from .stmfnet_arch import STMFNet_Model
        if multiplier != 2:
            warnings.warn("Currently, ST-MFNet only supports 2x interpolation. The process will continue but please set multiplier=2 afterward")

        assert_batch_size(frames, batch_size=4, vfi_name="ST-MFNet")
        interpolation_states = optional_interpolation_states
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = STMFNet_Model()
        model.load_state_dict(torch.load(model_path)['state_dict'])
        model = model.eval().to(device)

        frames = preprocess_frames(frames)
        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []
        for frame_itr in range(len(frames) - 3):
            #Does skipping frame i+1 make sense in this case?
            if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr) and interpolation_states.is_frame_skipped(frame_itr + 1):
                continue

            #Ensure that input frames are in fp32 - the same dtype as model
            frame0, frame1, frame2, frame3 = (
                frames[frame_itr:frame_itr+1].float(),
                frames[frame_itr+1:frame_itr+2].float(),
                frames[frame_itr+2:frame_itr+3].float(),
                frames[frame_itr+3:frame_itr+4].float()
            )
            # Interpolates between frame1 and frame2 using all four frames.
            new_frame = model(frame0.to(device), frame1.to(device), frame2.to(device), frame3.to(device)).detach().cpu()
            number_of_frames_processed_since_last_cleared_cuda_cache += 2

            if frame_itr == 0:
                output_frames.append(frame0)
                if duplicate_first_last_frames:
                    output_frames.append(frame0) # repeat the first frame
                output_frames.append(frame1)
            output_frames.append(new_frame)
            output_frames.append(frame2)
            if frame_itr == len(frames) - 4:
                output_frames.append(frame3)
                if duplicate_first_last_frames:
                    output_frames.append(frame3) # repeat the last frame

            # Try to avoid a memory overflow by clearing cuda cache regularly
            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")
            gc.collect()

        dtype = torch.float32
        output_frames = [frame.cpu().to(dtype=dtype) for frame in output_frames] #Ensure all frames are in cpu
        out = torch.cat(output_frames, dim=0)
        # clear cache for courtesy
        print("Comfy-VFI: Final clearing cache...", end = ' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )
# Per-flow-shape caches for the sampling grid and the ones channel used for
# occlusion masking.  Module-level so they persist across calls: the original
# re-created both dicts as empty locals on EVERY invocation, which defeated
# the `if key not in cache` logic entirely.  NOTE: entries accumulate per
# distinct flow shape for the process lifetime (matches upstream behavior).
backwarp_tenGrid = {}
backwarp_tenPartial = {}


def backwarp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` by per-pixel optical flow ``tenFlow``.

    ``tenFlow`` is in pixel units with channels (dx, dy); it is rescaled to
    the [-1, 1] grid_sample coordinate convention.  An all-ones channel is
    warped alongside the input and thresholded into a validity mask, so
    out-of-frame samples come back as exact zeros.

    BUG FIX: besides the dead caching (see above), the grid was hard-coded
    to ``.cuda()``, crashing on CPU-only runs; it is now created on
    ``tenFlow``'s device and moved if a cached entry lives elsewhere.
    """
    strKey = str(tenFlow.shape)

    if strKey not in backwarp_tenGrid:
        # Pixel-center grid for align_corners=False sampling.
        tenHor = (
            torch.linspace(
                -1.0 + (1.0 / tenFlow.shape[3]),
                1.0 - (1.0 / tenFlow.shape[3]),
                tenFlow.shape[3],
            )
            .view(1, 1, 1, -1)
            .expand(-1, -1, tenFlow.shape[2], -1)
        )
        tenVer = (
            torch.linspace(
                -1.0 + (1.0 / tenFlow.shape[2]),
                1.0 - (1.0 / tenFlow.shape[2]),
                tenFlow.shape[2],
            )
            .view(1, 1, -1, 1)
            .expand(-1, -1, -1, tenFlow.shape[3])
        )
        backwarp_tenGrid[strKey] = torch.cat([tenHor, tenVer], 1).to(tenFlow.device)

    if strKey not in backwarp_tenPartial:
        backwarp_tenPartial[strKey] = tenFlow.new_ones(
            [tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]]
        )

    # Rescale pixel displacements into normalized grid coordinates.
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )
    tenInput = torch.cat(
        [tenInput, backwarp_tenPartial[strKey].to(tenInput.device)], 1
    )

    tenOutput = torch.nn.functional.grid_sample(
        input=tenInput,
        grid=(backwarp_tenGrid[strKey].to(tenFlow.device) + tenFlow).permute(
            0, 2, 3, 1
        ),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=False,
    )

    # The warped ones channel -> hard validity mask.
    tenMask = tenOutput[:, -1:, :, :]
    tenMask[tenMask > 0.999] = 1.0
    tenMask[tenMask < 1.0] = 0.0

    return tenOutput[:, :-1, :, :] * tenMask
self.netOne = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=16, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=16, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=64, + out_channels=96, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=96, + out_channels=96, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=96, + out_channels=96, + kernel_size=3, + 
                        stride=1,
                        padding=1,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                # Pyramid level 5: 96 -> 128 channels, spatial stride 2, then
                # two refining 3x3 convolutions.
                self.netFiv = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                # Pyramid level 6 (coarsest): 128 -> 196 channels, spatial stride 2.
                self.netSix = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

            # end

            def forward(self, tenInput):
                # Run all six stages and return every intermediate feature map
                # (finest first) as the feature pyramid.
                tenOne = self.netOne(tenInput)
                tenTwo = self.netTwo(tenOne)
                tenThr = self.netThr(tenTwo)
                tenFou = self.netFou(tenThr)
                tenFiv = self.netFiv(tenFou)
                tenSix = self.netSix(tenFiv)

                return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix]

            # end

        # end

        class Decoder(torch.nn.Module):
            # Per-pyramid-level flow decoder (PWC-Net style): estimates flow at
            # level `intLevel` from the correlation cost volume plus, for
            # levels < 6, the upsampled flow/features of the coarser level.
            def __init__(self, intLevel):
                super(Decoder, self).__init__()

                # Channel count of the dense feature tensor entering the
                # previous (coarser) and current level; 81 is the size of the
                # correlation volume produced by FunctionCorrelation.
                intPrevious = [
                    None,
                    None,
                    81 + 32 + 2 + 2,
                    81 + 64 + 2 + 2,
                    81 + 96 + 2 + 2,
                    81 + 128 + 2 + 2,
                    81,
                    None,
                ][intLevel + 1]
                intCurrent = [
                    None,
                    None,
                    81 + 32 + 2 + 2,
                    81 + 64 + 2 + 2,
                    81 + 96 + 2 + 2,
                    81 + 128 + 2 + 2,
                    81,
                    None,
                ][intLevel + 0]

                if intLevel < 6:
                    # Learned 2x upsampling of the coarser level's flow.
                    self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1,)
                if intLevel < 6:
                    # Learned 2x upsampling of the coarser level's dense features.
                    self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1,)
                if intLevel < 6:
                    # Scale factor applied to the upsampled flow before warping.
                    self.fltBackwarp = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][
                        intLevel + 1
                    ]

                # DenseNet-style stack: each stage's input is the concatenation
                # of all previous stage outputs with the original features.
                self.netOne = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netTwo = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netThr = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netFou = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netFiv = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                # Final head predicting the 2-channel flow (no activation).
                self.netSix = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1,)
                )

            # end

            def forward(self, tenFirst, tenSecond, objPrevious):
                tenFlow = None
                tenFeat = None

                if objPrevious is None:
                    # Coarsest level: start from the correlation volume alone.
                    tenFlow = None
                    tenFeat = None

                    tenVolume = torch.nn.functional.leaky_relu(
                        input=FunctionCorrelation(
                            tenFirst=tenFirst, tenSecond=tenSecond
                        ),
                        negative_slope=0.1,
                        inplace=False,
                    )

                    tenFeat = torch.cat([tenVolume], 1)

                elif objPrevious is not None:
                    # Finer levels: warp the second image's features by the
                    # upsampled (and rescaled) coarse flow before correlating.
                    tenFlow = self.netUpflow(objPrevious["tenFlow"])
                    tenFeat = self.netUpfeat(objPrevious["tenFeat"])

                    tenVolume = torch.nn.functional.leaky_relu(
                        input=FunctionCorrelation(
                            tenFirst=tenFirst,
                            tenSecond=backwarp(
                                tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp
                            ),
                        ),
                        negative_slope=0.1,
                        inplace=False,
                    )

                    tenFeat = torch.cat([tenVolume, tenFirst, tenFlow, tenFeat], 1)

                # end

                # Dense connectivity: grow the feature tensor stage by stage.
                tenFeat = torch.cat([self.netOne(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netTwo(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netThr(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netFou(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netFiv(tenFeat), tenFeat], 1)

                tenFlow = self.netSix(tenFeat)

                return {"tenFlow": tenFlow, "tenFeat": tenFeat}

            # end

        # end

        class Refiner(torch.nn.Module):
            # Context network: refines the finest-level flow with a stack of
            # dilated convolutions (dilations 1, 2, 4, 8, 16, 1, 1) over the
            # decoder's dense features.
            def __init__(self):
                super(Refiner, self).__init__()

                self.netMain = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1,),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1,),
                )

            # end

            def forward(self, tenInput):
                # Residual flow correction computed from the dense features.
                return self.netMain(tenInput)

            # end

        # end

        self.netExtractor = Extractor()

        # One decoder per pyramid level (levels 2..6).
        self.netTwo = Decoder(2)
        self.netThr = Decoder(3)
        self.netFou = Decoder(4)
        self.netFiv = Decoder(5)
        self.netSix = Decoder(6)

        self.netRefiner = Refiner()

        # Load the official pretrained PWC-Net weights into the local
        # checkpoint directory, renaming "module*" keys to the local "net*"
        # attribute names.
        self.load_state_dict(
            {
                strKey.replace("module", "net"): tenWeight
                for strKey, tenWeight in torch.hub.load_state_dict_from_url(
                    url="http://content.sniklaus.com/github/pytorch-pwc/network-"
                    + "default"
                    + ".pytorch",
                    model_dir=get_ckpt_container_path(MODEL_TYPE)
                ).items()
            }
        )

    # end

    def forward(self, tenFirst, tenSecond, *args):
        # optionally pass pre-extracted feature pyramid in as args
        if len(args) == 0:
            tenFirst = self.netExtractor(tenFirst)
            tenSecond = self.netExtractor(tenSecond)
        else:
            tenFirst, tenSecond = args

        # Coarse-to-fine estimation; the pyramids are ordered finest first,
        # so index [-1] is the coarsest level.
        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)

        return objEstimate["tenFlow"] + self.netRefiner(objEstimate["tenFeat"])

    # end

    def extract_pyramid(self, tenFirst, tenSecond):
        # Feature pyramids for both frames, for reuse across forward calls.
        return self.netExtractor(tenFirst), self.netExtractor(tenSecond)

    def extract_pyramid_single(self, tenFirst):
        return self.netExtractor(tenFirst)


# end

# Lazily-instantiated module-level singleton used by estimate().
netNetwork = None

##########################################################


def estimate(tenFirst, tenSecond):
    # Estimate optical flow between two single (3, H, W) images on CUDA and
    # return a (2, H, W) CPU flow tensor.
    global netNetwork

    if netNetwork is None:
        netNetwork = Network().cuda().eval()
    # end

    assert tenFirst.shape[1] == tenSecond.shape[1]
    assert tenFirst.shape[2] == tenSecond.shape[2]

    intWidth = tenFirst.shape[2]
    intHeight = tenFirst.shape[1]

    assert (
        intWidth == 1024
    )  # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
    assert (
        intHeight == 436
    )  # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue

    tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
    tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)

    # Round the working resolution up to the next multiple of 64.
    intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
    intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))

    tenPreprocessedFirst = torch.nn.functional.interpolate(
        input=tenPreprocessedFirst,
        size=(intPreprocessedHeight, intPreprocessedWidth),
        mode="bilinear",
        align_corners=False,
    )
    tenPreprocessedSecond = torch.nn.functional.interpolate(
        input=tenPreprocessedSecond,
        size=(intPreprocessedHeight, intPreprocessedWidth),
        mode="bilinear",
        align_corners=False,
    )

    # Resize the network output back to input resolution; the 20.0 factor
    # presumably undoes the flow scaling used in PWC-Net training — confirm
    # against the upstream pytorch-pwc implementation.
    tenFlow = 20.0 * torch.nn.functional.interpolate(
        input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond),
        size=(intHeight, intWidth),
        mode="bilinear",
        align_corners=False,
    )

    # Rescale the flow vectors to compensate for the resolution change.
    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)

    return tenFlow[0, :, :, :].cpu()


# end


class Upsampler_8tap(nn.Module):
    # 2x image upsampler using the fixed 8-tap half-pel interpolation filter
    # [-1, 4, -11, 40, 40, -11, 4, -1] / 64.
    def __init__(self):
        super(Upsampler_8tap, self).__init__()
        filt_8tap = torch.tensor([[-1, 4, -11, 40, 40, -11, 4, -1]]).div(64)
        # Fixed (non-trainable) depthwise filter, one copy per RGB channel.
        self.filter = nn.Parameter(filt_8tap.repeat(3, 1, 1, 1), requires_grad=False)

    def forward(self, im):
        b, c, h, w = im.shape
        im_up = torch.zeros(b, c, h * 2, w * 2).to(im.device)
        im_up[:, :, ::2, ::2] = im  # original samples at even positions

        p = (8 - 1) // 2
        # Horizontal half-pel samples.
        im_up_row = F.conv2d(
            F.pad(im, pad=(p, p + 1, 0, 0), mode="reflect"), self.filter, groups=3
        )
        im_up[:, :, 0::2, 1::2] = im_up_row
        # Vertical half-pel samples (filter applied on the transposed image).
        im_up_col = torch.transpose(
            F.conv2d(
                F.pad(torch.transpose(im, 2, 3), pad=(p, p + 1, 0, 0), mode="reflect"),
                self.filter,
                groups=3,
            ),
            2,
            3,
        )
        im_up[:, :, 1::2, 0::2] = im_up_col
        # Diagonal half-pel samples from the vertically interpolated rows.
        im_up_cross = F.conv2d(
            F.pad(im_up[:, :, 1::2, ::2], pad=(p, p + 1, 0, 0), mode="reflect"),
            self.filter,
            groups=3,
        )
        im_up[:, :, 1::2, 1::2] = im_up_cross
        return im_up

# end


# Torchvision pretrained video-ResNet checkpoint URLs.
model_urls = {
    "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
    "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
    "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
}


class Conv3DSimple(nn.Conv3d):
    # Plain 3x3x3 3D convolution builder.
    def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1):
        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride, temporal_stride):
        # Spatial stride always applies; temporal stride only when given.
        if temporal_stride:
            return (temporal_stride, stride, stride)
        else:
            return (stride, stride, stride)


class Conv2Plus1D(nn.Sequential):
    # Factorized (2+1)D convolution: spatial 1x3x3 then temporal 3x1x1 with
    # `midplanes` channels in between. `batchnorm` is a module-level global
    # chosen by r3d_18/mc3_18/r2plus1d_18 (BatchNorm3d or a no-op layer).
    def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1):
        super(Conv2Plus1D, self).__init__(
            nn.Conv3d(
                in_planes,
                midplanes,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
            ),
            batchnorm(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                midplanes,
                out_planes,
                kernel_size=(3, 1, 1),
                stride=(stride, 1, 1),
                padding=(padding, 0, 0),
                bias=False,
            ),
        )

    @staticmethod
    def get_downsample_stride(stride):
        return stride, stride, stride


class Conv3DNoTemporal(nn.Conv3d):
    # 1x3x3 convolution: spatial only, no temporal mixing.
    def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1):
        super(Conv3DNoTemporal, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride):
        return 1, stride, stride


class SEGating(nn.Module):
    # Squeeze-and-excitation style channel gating (from FLAVR).
    def __init__(self, inplanes, reduction=16):
        super().__init__()

        # NOTE(review): `reduction` is accepted but unused — the 1x1x1 conv
        # keeps `inplanes` channels throughout.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Global pooling -> per-channel sigmoid weights -> rescale input.
        out = self.pool(x)
        y = self.attn_layer(out)
        return x * y


class BasicBlock(nn.Module):
    # Residual block: two conv stages with SEGating applied before the skip add.
    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        # Intermediate width used by the (2+1)D factorization.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            batchnorm(planes),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes), batchnorm(planes)
        )
        self.fg = SEGating(planes)  ## Feature Gating, from FLAVR
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.fg(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Standard bottleneck residual block: 1x1x1 -> conv_builder -> 1x1x1.
    expansion = 4

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            batchnorm(planes),
            nn.ReLU(inplace=True),
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            batchnorm(planes),
            nn.ReLU(inplace=True),
        )

        # 1x1x1
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            batchnorm(planes * self.expansion),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem"""

    def __init__(self, outplanes=32):
        super(BasicStem, self).__init__(
            nn.Conv3d(
                3,
                outplanes,
                kernel_size=(3, 7, 7),
                stride=(1, 2, 2),
                padding=(1, 3, 3),
                bias=False,
            ),
            batchnorm(outplanes),
            nn.ReLU(inplace=True),
        )


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution"""

    def __init__(self):
        super(R2Plus1dStem, self).__init__(
            nn.Conv3d(
                3,
                45,
                kernel_size=(1, 7, 7),
                stride=(1, 2, 2),
                padding=(0, 3, 3),
                bias=False,
            ),
            batchnorm(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                45,
                64,
                kernel_size=(3, 1, 1),
                stride=(1, 1, 1),
                padding=(1, 0, 0),
                bias=False,
            ),
            batchnorm(64),
            nn.ReLU(inplace=True),
        )


class VideoResNet(nn.Module):
    # NOTE(review): `channels` uses a mutable default list; it is only read
    # here, but consider a tuple.
    def __init__(
        self,
        block,
        conv_makers,
        layers,
        stem,
        zero_init_residual=False,
        channels=[32, 64, 96, 128],
    ):
        """Generic resnet video generator.

        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer
            layers (List[int]): number of blocks per layer
            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super(VideoResNet, self).__init__()
        self.inplanes = channels[0]  # output channel of first stem

        self.stem = stem()

        # Four residual stages; stages 2-4 stride spatially only
        # (temporal_stride=1 keeps the time dimension).
        self.layer1 = self._make_layer(
            block, conv_makers[0], channels[0], layers[0], stride=1
        )
        self.layer2 = self._make_layer(
            block, conv_makers[1], channels[1], layers[1], stride=2, temporal_stride=1
        )
        self.layer3 = self._make_layer(
            block, conv_makers[2], channels[2], layers[2], stride=2, temporal_stride=1
        )
        self.layer4 = self._make_layer(
            block, conv_makers[3], channels[3], layers[3], stride=1, temporal_stride=1
        )

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            # NOTE(review): Bottleneck defines no `bn3` attribute (its norm
            # layers live inside the `conv3` Sequential), so this branch
            # would raise AttributeError if taken with zero_init_residual=True.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        # Encoder-style forward: return the stem output and all four stage
        # outputs as a multi-scale feature tuple (no classification head).
        tensorConv0 = self.stem(x)
        tensorConv1 = self.layer1(tensorConv0)
        tensorConv2 = self.layer2(tensorConv1)
        tensorConv3 = self.layer3(tensorConv2)
        tensorConv4 = self.layer4(tensorConv3)
        return tensorConv0, tensorConv1, tensorConv2, tensorConv3, tensorConv4

    def _make_layer(
        self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None
    ):
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            # NOTE(review): only Conv3DSimple.get_downsample_stride accepts
            # (stride, temporal_stride); the other builders declare a single
            # argument and would raise here — confirm which builders are used
            # with downsampling.
            ds_stride = conv_builder.get_downsample_stride(stride, temporal_stride)
            downsample = nn.Sequential(
                nn.Conv3d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=ds_stride,
                    bias=False,
                ),
                batchnorm(planes * block.expansion),
            )
            stride = ds_stride

        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        # Kaiming init for convs; standard BN/Linear init below.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    # Build a VideoResNet and optionally load the torchvision checkpoint for
    # `arch` into the local checkpoint cache directory.
    model = VideoResNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress, model_dir=get_ckpt_container_path(MODEL_TYPE))
        model.load_state_dict(state_dict)
    return model


def r3d_18(bn=False, pretrained=False, progress=True, **kwargs):
    """Construct 18 layer Resnet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R3D-18 network
    """

    # Choose the module-global norm-layer factory used by all blocks/stems;
    # `identity` (defined elsewhere in this file) disables normalization.
    global batchnorm
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        batchnorm = identity

    return _video_resnet(
        "r3d_18",
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv3DSimple] * 4,
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        **kwargs,
    )


def mc3_18(bn=False, pretrained=False, progress=True, **kwargs):
    """Constructor for 18 layer Mixed Convolution network as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: MC3 Network definition
    """
    # Select the module-global norm layer (see r3d_18).
    global batchnorm
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        batchnorm = identity

    return _video_resnet(
        "mc3_18",
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        **kwargs,
    )


def r2plus1d_18(bn=False, pretrained=False, progress=True, **kwargs):
    """Constructor for the 18 layer deep R(2+1)D network as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R(2+1)D-18 network
    """

    # Select the module-global norm layer (see r3d_18).
    global batchnorm
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        batchnorm = identity

    return _video_resnet(
        "r2plus1d_18",
        pretrained,
        progress,
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[2, 2, 2, 2],
        stem=R2Plus1dStem,
        **kwargs,
    )


class upConv3D(nn.Module):
    # 2x spatial upsampling block: either a transposed 3D conv or trilinear
    # upsample + 1x1x1 conv, followed by SEGating and the global norm layer.
    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose"):
        super().__init__()
        self.upmode = upmode
        if self.upmode == "transpose":
            self.upconv = nn.ModuleList(
                [
                    nn.ConvTranspose3d(
                        in_ch,
                        out_ch,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                    ),
                    SEGating(out_ch),
                    batchnorm(out_ch),
                ]
            )
        else:
            self.upconv = nn.ModuleList(
                [
                    nn.Upsample(
                        mode="trilinear", scale_factor=(1, 2, 2), align_corners=False
                    ),
                    nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                    SEGating(out_ch),
                    batchnorm(out_ch),
                ]
            )
        self.upconv = nn.Sequential(*self.upconv)

    def forward(self, x):
        return self.upconv(x)


class Conv_3d(nn.Module):
    # 3D conv followed by SEGating and the global norm layer.
    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            ),
            SEGating(out_ch),
            batchnorm(out_ch),
        )

    def forward(self, x):
        return self.conv(x)


def make_optimizer(args, my_model):
    # Build the optimizer named by args.optimizer over trainable parameters.
    trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.optimizer == "SGD":
        optimizer_function = optim.SGD
        kwargs = {"momentum": 0.9}
    elif args.optimizer == "ADAM":
        optimizer_function = optim.Adam
        kwargs = {"betas": (0.9, 0.999), "eps": 1e-08}
    elif args.optimizer == "ADAMax":
        optimizer_function = optim.Adamax
        kwargs = {"betas": (0.9, 0.999), "eps": 1e-08}
    elif args.optimizer == "RMSprop":
        optimizer_function = optim.RMSprop
        kwargs = {"eps": 1e-08}
    # NOTE(review): an unrecognized args.optimizer falls through and raises
    # UnboundLocalError on `optimizer_function` below — confirm intended.

    kwargs["lr"] = args.lr
    kwargs["weight_decay"] = args.weight_decay

    return optimizer_function(trainable, **kwargs)


def make_scheduler(args, my_optimizer):
    # Build an LR scheduler from args.decay_type:
    #   "step"                  -> StepLR
    #   "...step_<m1>_<m2>..."  -> MultiStepLR with the listed milestones
    #   "plateau"               -> ReduceLROnPlateau (maximizing mode)
    if args.decay_type == "step":
        scheduler = lrs.StepLR(my_optimizer, step_size=args.lr_decay, gamma=args.gamma)
    elif args.decay_type.find("step") >= 0:
        milestones = args.decay_type.split("_")
        milestones.pop(0)  # drop the leading "...step" token
        milestones = list(map(lambda x: int(x), milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer, milestones=milestones, gamma=args.gamma
        )
    elif args.decay_type == "plateau":
        scheduler = lrs.ReduceLROnPlateau(
            my_optimizer,
            mode="max",
            factor=args.gamma,
            patience=args.patience,
            threshold=0.01,  # metric to be used is psnr
            threshold_mode="abs",
            verbose=True,
        )
    # NOTE(review): an unrecognized decay_type raises UnboundLocalError below.

    return scheduler


def gaussian_kernel(sz, sigma):
    # Normalized sz x sz 2-D Gaussian kernel built via outer product.
    k = torch.arange(-(sz - 1) / 2, (sz + 1) / 2)
    k = torch.exp(-1.0 / (2 * sigma**2) * k**2)
    k = k.reshape(-1, 1) * k.reshape(1, -1)
    k = k / torch.sum(k)
    return k


def moduleNormalize(frame):
    # Subtract fixed per-channel RGB means from a (B,3,H,W) tensor.
    return torch.cat(
        [
            (frame[:, 0:1, :, :] - 0.4631),
            (frame[:, 1:2, :, :] - 0.4352),
            (frame[:, 2:3, :, :] - 0.3990),
        ],
        1,
    )


class FoldUnfold:
    """
    Class to handle folding tensor frame into batch of patches and back to frame again
    Thanks to Charlie Tan (charlie.tan.2019@bristol.ac.uk) for the earlier version.
    """

    def __init__(self, height, width, patch_size, overlap):
        # NOTE(review): on odd inputs this prints a warning and returns early,
        # leaving the attributes unset — later method calls will raise
        # AttributeError.
        if height % 2 or width % 2 or patch_size % 2 or overlap % 2:
            print(
                "only defined for even values of height, width, patch_size size and overlap, odd values will reconstruct incorrectly"
            )
            return

        self.height = height
        self.width = width

        self.patch_size = patch_size
        self.overlap = overlap
        self.stride = patch_size - overlap

    def fold_to_patches(self, *frames):
        """
        args: frames -- list of (1,3,H,W) tensors
        returns: list of (B,3,h,w) image patches
        """

        # number of blocks in each direction
        n_blocks_h = (self.height // (self.stride)) + 1
        n_blocks_w = (self.width // (self.stride)) + 1

        # how much to pad each edge by
        self.pad_h = (self.stride * n_blocks_h + self.overlap - self.height) // 2
        self.pad_w = (self.stride * n_blocks_w + self.overlap - self.width) // 2
        self.height_pad = self.height + 2 * self.pad_h
        self.width_pad = self.width + 2 * self.pad_w

        # pad the frames and unfold into patches
        patches_list = []
        for i in range(len(frames)):
            padded = F.pad(
                frames[i],
                (self.pad_w, self.pad_w, self.pad_h, self.pad_h),
                mode="reflect",
            )
            unfolded = F.unfold(padded, self.patch_size, stride=self.stride)
            patches = unfolded.permute(2, 1, 0).reshape(
                -1, 3, self.patch_size, self.patch_size
            )
            patches_list.append(patches)

        return patches_list

    def unfold_to_frame(self, patches):
        """
        args: patches -- tensor of shape (B,3,h,w)
        returns: frame -- tensor of shape (1,3,H,W)
        """

        # reshape and permute back into [frames, chans * patch_size ** 2, num_patches] as expected by fold
        frame_unfold = patches.reshape(-1, 3 * self.patch_size**2, 1).permute(2, 1, 0)

        # fold into tensor of shape pad_shape
        frame_fold = F.fold(
            frame_unfold,
            (self.height_pad, self.width_pad),
            self.patch_size,
            stride=self.stride,
        )

        # unfold sums overlaps instead of averaging so tensor of ones unfolded and
        # folded to track overlaps and take the mean of overlapping pixels
        ones = torch.ones_like(frame_fold)
        ones_unfold = F.unfold(ones, self.patch_size, stride=self.stride)

        # divisor is tensor of shape pad_shape where each element is the number of values that have overlapped
        # 1 = no overlaps
        divisor = F.fold(
            ones_unfold,
            (self.height_pad, self.width_pad),
            self.patch_size,
            stride=self.stride,
        )

        # divide reconstructed frame by divisor
        frame_div = frame_fold / divisor

        # crop frame to remove the padded areas
        # NOTE(review): if pad_h or pad_w is 0 the `-0` end index makes the
        # slice empty — confirm the padding is always positive here.
        frame_crop = frame_div[
            :, :, self.pad_h : -self.pad_h, self.pad_w : -self.pad_w
        ].clone()

        return frame_crop


def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt="420"):
    # Read one raw YUV frame (4:2:0 or 4:4:4, 8-bit or higher bit depth) from
    # an open file object and convert it to an RGB uint8 array.
    if pix_fmt == "420":
        multiplier = 1
        uv_factor = 2
    elif pix_fmt == "444":
        multiplier = 2
        uv_factor = 1
    else:
        print("Pixel format {} is not supported".format(pix_fmt))
        return

    if bit_depth == 8:
        datatype = np.uint8
        # NOTE(review): `1.5 * ...` produces a float offset; file.seek()
        # requires an int — wrap in int() / confirm this path works.
        stream.seek(iFrame * 1.5 * width * height * multiplier)
        Y = np.fromfile(stream, dtype=datatype, count=width * height).reshape(
            (height, width)
        )

        # read chroma samples and upsample since original is 4:2:0 sampling
        U = np.fromfile(
            stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor)
        ).reshape((height // uv_factor, width // uv_factor))
        V = np.fromfile(
            stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor)
        ).reshape((height // uv_factor, width // uv_factor))

    else:
        # Higher bit depths are stored as 16-bit samples (2 bytes each).
        datatype = np.uint16
        stream.seek(iFrame * 3 * width * height * multiplier)
        Y = np.fromfile(stream, dtype=datatype, count=width * height).reshape(
            (height, width)
        )

        U = np.fromfile(
            stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor)
        ).reshape((height // uv_factor, width // uv_factor))
        V = np.fromfile(
            stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor)
        ).reshape((height // uv_factor, width // uv_factor))

    if pix_fmt == "420":
        # Pack planar I420 layout for OpenCV's YUV2RGB_I420 conversion.
        yuv = np.empty((height * 3 // 2, width), dtype=datatype)
        yuv[0:height, :] = Y

        yuv[height : height + height // 4, :] = U.reshape(-1, width)
        yuv[height + height // 4 :, :] = V.reshape(-1, width)

        if bit_depth != 8:
            # Scale high-bit-depth samples down to 8 bits for OpenCV.
            yuv = (yuv / (2**bit_depth - 1) * 255).astype(np.uint8)

        # convert to rgb
        rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)

    else:
        yvu = np.stack([Y, V, U], axis=2)
        if bit_depth != 8:
            yvu = (yvu / (2**bit_depth - 1) * 255).astype(np.uint8)
        rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)

    return rgb


def quantize(imTensor):
    # [0,1] float -> rounded [0,255] values (still a float tensor).
    return imTensor.clamp(0.0, 1.0).mul(255).round()


def tensor2rgb(tensor):
    """
    Convert GPU Tensor to RGB image (numpy array)
    """
    out = []
    for b in range(tensor.shape[0]):
        out.append(
            np.moveaxis(quantize(tensor[b]).cpu().detach().numpy(), 0, 2).astype(
                np.uint8
            )
        )
    return np.array(out)  # (B,H,W,C)


class Identity(nn.Module):
    # No-op module usable wherever a norm layer is expected.
    def __init__(self, *args):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class SEBlock(nn.Module):
    # 2-D squeeze-and-excitation channel attention.
    def __init__(self, input_dim, reduction=16):
        super(SEBlock, self).__init__()
        mid = int(input_dim / reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(input_dim, mid),
            nn.ReLU(inplace=True),
            nn.Linear(mid, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ResNextBlock(nn.Module):
    # Grouped (ResNeXt-style) bottleneck; `down` selects Conv2d (encoding)
    # vs ConvTranspose2d (decoding) for the middle convolution.
    def __init__(
        self, down, cin, cout, ks, stride=1, groups=32, base_width=4, norm_layer=None
    ):
        super(ResNextBlock, self).__init__()
        if norm_layer is None or norm_layer == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm_layer == "identity":
            norm_layer = Identity
        width = int(cout * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(cin, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        if down:
            self.conv2 = nn.Conv2d(
                width,
                width,
                kernel_size=ks,
                stride=stride,
                padding=(ks - 1) // 2,
                groups=groups,
                bias=False,
            )
        else:
            # Transposed conv; padding chosen so output size = input * stride.
            self.conv2 = nn.ConvTranspose2d(
                width,
                width,
                kernel_size=ks,
                stride=stride,
                padding=(ks - stride) // 2,
                groups=groups,
                bias=False,
            )
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, cout, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(cout)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if stride != 1 or cin != cout:
            if down:
                self.downsample = nn.Sequential(
                    nn.Conv2d(cin, cout, kernel_size=1, stride=stride, bias=False),
                    norm_layer(cout),
                )
            else:
                self.downsample = nn.Sequential(
                    # ks = stride here s.t. resolution can be kept
                    nn.ConvTranspose2d(
                        cin, cout, kernel_size=2, stride=stride, bias=False
                    ),
                    norm_layer(cout),
                )
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class MultiScaleResNextBlock(nn.Module):
    # Two parallel ResNext blocks with small/large kernels; outputs are
    # concatenated and re-weighted by SE attention.
    def __init__(self, down, cin, cout, ks_s, ks_l, stride, norm_layer):
        super(MultiScaleResNextBlock, self).__init__()
        self.resnext_small = ResNextBlock(
            down, cin, cout // 2, ks_s, stride, norm_layer=norm_layer
        )
        self.resnext_large = ResNextBlock(
            down, cin, cout // 2, ks_l, stride, norm_layer=norm_layer
        )
        self.attention = SEBlock(cout)

    def forward(self, tensorCombine):
        out_small = self.resnext_small(tensorCombine)
        out_large = self.resnext_large(tensorCombine)
        out = torch.cat([out_small, out_large], 1)
        out = self.attention(out)
        return out


class UMultiScaleResNext(nn.Module):
    # U-shaped encoder/decoder of MultiScaleResNextBlocks operating on the
    # channel-concatenation of two input frames.
    def __init__(
        self, channels=[64, 128, 256, 512], norm_layer="batch", inplanes=6, **kwargs
    ):
        super(UMultiScaleResNext, self).__init__()
        self.conv1 = MultiScaleResNextBlock(
            True, inplanes, channels[0], ks_s=3, ks_l=7, stride=2, norm_layer=norm_layer
        )
        self.conv2 = MultiScaleResNextBlock(
            True,
            channels[0],
            channels[1],
            ks_s=3,
            ks_l=7,
            stride=2,
            norm_layer=norm_layer,
        )
        self.conv3 = MultiScaleResNextBlock(
            True,
            channels[1],
            channels[2],
            ks_s=3,
            ks_l=5,
            stride=2,
            norm_layer=norm_layer,
        )
        self.conv4 = MultiScaleResNextBlock(
            True,
            channels[2],
            channels[3],
            ks_s=3,
            ks_l=5,
            stride=2,
            norm_layer=norm_layer,
        )

        # deconv4 keeps resolution (down=True, stride=1); the rest upsample 2x.
        self.deconv4 = MultiScaleResNextBlock(
            True,
            channels[3],
            channels[3],
            ks_s=3,
            ks_l=5,
            stride=1,
            norm_layer=norm_layer,
        )
        self.deconv3 = MultiScaleResNextBlock(
            False,
            channels[3],
            channels[2],
            ks_s=4,
            ks_l=6,
            stride=2,
            norm_layer=norm_layer,
        )
        self.deconv2 = MultiScaleResNextBlock(
            False,
            channels[2],
            channels[1],
            ks_s=4,
            ks_l=8,
            stride=2,
            norm_layer=norm_layer,
        )
        self.deconv1 = MultiScaleResNextBlock(
            False,
            channels[1],
            channels[0],
            ks_s=4,
            ks_l=8,
            stride=2,
            norm_layer=norm_layer,
        )

    def forward(self, im0, im2):
        tensorJoin = torch.cat([im0, im2], 1)  # (B,6,H,W)

        tensorConv1 = self.conv1(tensorJoin)
        tensorConv2 = self.conv2(tensorConv1)
        tensorConv3 = self.conv3(tensorConv2)
        tensorConv4 = self.conv4(tensorConv3)

        # Decoder with additive skip connections from the encoder.
        tensorDeconv4 = self.deconv4(tensorConv4)
        tensorDeconv3 = self.deconv3(tensorDeconv4 + tensorConv4)
        tensorDeconv2 = self.deconv2(tensorDeconv3 + tensorConv3)
        tensorDeconv1 = self.deconv1(tensorDeconv2 + tensorConv2)

        return tensorDeconv1


class MultiInputGridNet(nn.Module):
    # GridNet with one input per row (scale stream) and a single output.
    def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6):
        super(MultiInputGridNet, self).__init__()

        self.n_row = n_row
        self.n_col = n_col
        self.n_chs = grid_chs
        assert (
            len(grid_chs) == self.n_row
        ), "should give num channels for each row (scale stream)"
        assert (
            len(in_chs) == self.n_row
        ), "should give input channels for each row (scale stream)"

        # Lateral (same-scale) blocks; column 0 adapts the input channels.
        for r, n_ch in enumerate(self.n_chs):
            setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch))
            for c in range(1, self.n_col):
                setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch))

        # Downsampling blocks for the first half of the columns...
        for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])):
            for c in range(int(self.n_col / 2)):
                setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch))

        # ...and upsampling blocks for the second half.
        for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])):
            for c in range(int(self.n_col / 2)):
                setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch))

        self.lateral_final = LateralBlock(self.n_chs[0], out_chs)

    def forward(self, *args):
        assert len(args) == self.n_row

        # extensible, memory-efficient
        cur_col = list(args)
        # First half of the grid: lateral + downsample from the finer row.
        for c in range(int(self.n_col / 2)):
            for r in range(self.n_row):
                cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r])
                if r != 0:
                    cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1])

        # Second half: lateral + upsample from the coarser row.
        for c in range(int(self.n_col / 2), self.n_col):
            for r in range(self.n_row - 1, -1, -1):
                cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r])
                if r != self.n_row - 1:
                    cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")(
                        cur_col[r + 1]
                    )

        return self.lateral_final(cur_col[0])


class MIMOGridNet(nn.Module):
    # GridNet with one input per row and one output head per row listed in
    # `outrow` (multi-input, multi-output).
    def __init__(
        self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6, outrow=(0, 1, 2)
    ):
        super(MIMOGridNet, self).__init__()

        self.n_row = n_row
        self.n_col = n_col
        self.n_chs = grid_chs
        self.outrow = outrow
        assert (
            len(grid_chs) == self.n_row
        ), "should give num channels for each row (scale stream)"
        assert (
            len(in_chs) == self.n_row
        ), "should give input channels for each row (scale stream)"
        assert len(out_chs) == len(
            self.outrow
        ), "should give out channels for each output row (scale stream)"

        for r, n_ch in enumerate(self.n_chs):
            setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch))
            for c in range(1, self.n_col):
                setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch))

        for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])):
            for c in range(int(self.n_col / 2)):
                setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch))

        for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])):
            for c in range(int(self.n_col / 2)):
                setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch))

        # One output head per requested row.
        for i, r in enumerate(outrow):
            setattr(self, f"lateral_final_{r}", LateralBlock(self.n_chs[r], out_chs[i]))

    def forward(self, *args):
        assert len(args) == self.n_row

        # extensible, memory-efficient
        cur_col = list(args)
        for c in range(int(self.n_col / 2)):
            for r in range(self.n_row):
                cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r])
                if r != 0:
                    cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1])

        for c in range(int(self.n_col / 2), self.n_col):
            for r in range(self.n_row - 1, -1, -1):
                cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r])
                if r != self.n_row - 1:
                    cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")(
                        cur_col[r + 1]
                    )

        out = []
        for r in self.outrow:
            out.append(getattr(self, f"lateral_final_{r}")(cur_col[r]))

        return out


class GeneralGridNet(nn.Module):
    # GridNet with a single input on row 0 and a single output; deeper rows
    # are populated on the fly during the downsampling pass.
    def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6):
        super(GeneralGridNet, self).__init__()

        self.n_row = n_row
        self.n_col = n_col
        self.n_chs = grid_chs
        assert (
            len(grid_chs) == self.n_row
        ), "should give num channels for each row (scale stream)"

        # Only row 0 gets a column-0 lateral block (it is the only row with
        # an input at column 0).
        for r, n_ch in enumerate(self.n_chs):
            if r == 0:
                setattr(self, f"lateral_{r}_0", LateralBlock(in_chs, n_ch))
            for c in range(1, self.n_col):
                setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch))

        for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])):
in range(int(self.n_col / 2)): + setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) + + for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): + for c in range(int(self.n_col / 2)): + setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) + + self.lateral_final = LateralBlock(self.n_chs[0], out_chs) + + def forward(self, x): + cur_col = [x] + [None] * (self.n_row - 1) + for c in range(int(self.n_col / 2)): + for r in range(self.n_row): + if cur_col[r] != None: + cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) + else: + cur_col[r] = 0.0 + if r != 0: + cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) + + for c in range(int(self.n_col / 2), self.n_col): + for r in range(self.n_row - 1, -1, -1): + cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) + if r != self.n_row - 1: + cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( + cur_col[r + 1] + ) + + return self.lateral_final(cur_col[0]) + + +class GridNet(nn.Module): + def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96)): + super(GridNet, self).__init__() + + self.n_row = 3 + self.n_col = 6 + self.n_chs = grid_chs + assert ( + len(grid_chs) == self.n_row + ), "should give num channels for each row (scale stream)" + + self.lateral_init = LateralBlock(in_chs, self.n_chs[0]) + + for r, n_ch in enumerate(self.n_chs): + for c in range(self.n_col - 1): + setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) + + for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): + for c in range(int(self.n_col / 2)): + setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) + + for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): + for c in range(int(self.n_col / 2)): + setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) + + self.lateral_final = LateralBlock(self.n_chs[0], out_chs) + + def forward(self, x): + state_00 = self.lateral_init(x) + state_10 = self.down_0_0(state_00) + 
state_20 = self.down_1_0(state_10) + + state_01 = self.lateral_0_0(state_00) + state_11 = self.down_0_1(state_01) + self.lateral_1_0(state_10) + state_21 = self.down_1_1(state_11) + self.lateral_2_0(state_20) + + state_02 = self.lateral_0_1(state_01) + state_12 = self.down_0_2(state_02) + self.lateral_1_1(state_11) + state_22 = self.down_1_2(state_12) + self.lateral_2_1(state_21) + + state_23 = self.lateral_2_2(state_22) + state_13 = self.up_1_0(state_23) + self.lateral_1_2(state_12) + state_03 = self.up_0_0(state_13) + self.lateral_0_2(state_02) + + state_24 = self.lateral_2_3(state_23) + state_14 = self.up_1_1(state_24) + self.lateral_1_3(state_13) + state_04 = self.up_0_1(state_14) + self.lateral_0_3(state_03) + + state_25 = self.lateral_2_4(state_24) + state_15 = self.up_1_2(state_25) + self.lateral_1_4(state_14) + state_05 = self.up_0_2(state_15) + self.lateral_0_4(state_04) + + return self.lateral_final(state_05) + + +class LateralBlock(nn.Module): + def __init__(self, ch_in, ch_out): + super(LateralBlock, self).__init__() + self.f = nn.Sequential( + nn.PReLU(), + nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), + nn.PReLU(), + nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), + ) + if ch_in != ch_out: + self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1) + + def forward(self, x): + fx = self.f(x) + if fx.shape[1] != x.shape[1]: + x = self.conv(x) + return fx + x + + +class DownSamplingBlock(nn.Module): + def __init__(self, ch_in, ch_out): + super(DownSamplingBlock, self).__init__() + self.f = nn.Sequential( + nn.PReLU(), + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1), + nn.PReLU(), + nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), + ) + + def forward(self, x): + return self.f(x) + + +class UpSamplingBlock(nn.Module): + def __init__(self, ch_in, ch_out): + super(UpSamplingBlock, self).__init__() + self.f = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + nn.PReLU(), + nn.Conv2d(ch_in, 
# PWC-Net optical-flow estimator, adapted from sniklaus/pytorch-pwc.
# Builds a 6-level siamese feature pyramid per frame, decodes flow
# coarse-to-fine with per-level cost-volume decoders, then adds a dilated
# context-network ("Refiner") correction. Pretrained weights are downloaded
# at construction time. NOTE(review): relies on out-of-view helpers
# FunctionCorrelation, backwarp, get_ckpt_container_path and MODEL_TYPE.
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        # Feature pyramid: six stride-2 stages, 16 -> 196 channels.
        # (196, not 192, matches the original pytorch-pwc checkpoint.)
        class Extractor(torch.nn.Module):
            def __init__(self):
                super(Extractor, self).__init__()

                self.netOne = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netTwo = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netThr = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netFou = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netFiv = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netSix = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

            def forward(self, tenInput):
                # Returns the 6-level pyramid, finest (1/2 res) first.
                tenOne = self.netOne(tenInput)
                tenTwo = self.netTwo(tenOne)
                tenThr = self.netThr(tenTwo)
                tenFou = self.netFou(tenThr)
                tenFiv = self.netFiv(tenFou)
                tenSix = self.netSix(tenFiv)

                return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix]

        # Per-level flow decoder with DenseNet-style feature concatenation.
        # The 81 in the channel tables is the cost-volume channel count
        # (presumably a 9x9 correlation window — confirm against
        # FunctionCorrelation); the +2+2 are the upsampled flow and the
        # 2-channel netUpfeat projection of the previous features.
        class Decoder(torch.nn.Module):
            def __init__(self, intLevel):
                super(Decoder, self).__init__()

                intPrevious = [None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None][intLevel + 1]
                intCurrent = [None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None][intLevel + 0]

                # Level 6 (coarsest) has no previous estimate to upsample.
                if intLevel < 6:
                    self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
                if intLevel < 6:
                    self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
                if intLevel < 6:
                    # Per-level scale applied to the flow before warping the
                    # second feature map.
                    self.fltBackwarp = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][intLevel + 1]

                self.netOne = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netTwo = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netThr = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netFou = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netFiv = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                )

                self.netSix = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
                )

            def forward(self, tenFirst, tenSecond, objPrevious):
                tenFlow = None
                tenFeat = None

                if objPrevious is None:
                    # Coarsest level: cost volume only, no warping.
                    tenFlow = None
                    tenFeat = None

                    tenVolume = torch.nn.functional.leaky_relu(
                        input=FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond),
                        negative_slope=0.1,
                        inplace=False,
                    )

                    tenFeat = torch.cat([tenVolume], 1)

                elif objPrevious is not None:
                    # Upsample previous flow/features, warp the second
                    # pyramid level toward the first, correlate.
                    tenFlow = self.netUpflow(objPrevious["tenFlow"])
                    tenFeat = self.netUpfeat(objPrevious["tenFeat"])

                    tenVolume = torch.nn.functional.leaky_relu(
                        input=FunctionCorrelation(
                            tenFirst=tenFirst,
                            tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp),
                        ),
                        negative_slope=0.1,
                        inplace=False,
                    )

                    tenFeat = torch.cat([tenVolume, tenFirst, tenFlow, tenFeat], 1)

                # Dense connectivity: each stage sees all previous outputs.
                tenFeat = torch.cat([self.netOne(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netTwo(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netThr(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netFou(tenFeat), tenFeat], 1)
                tenFeat = torch.cat([self.netFiv(tenFeat), tenFeat], 1)

                tenFlow = self.netSix(tenFeat)

                return {"tenFlow": tenFlow, "tenFeat": tenFeat}

        # Context network: stacked dilated convolutions (d = 1,2,4,8,16,1,1)
        # producing a residual 2-channel flow correction.
        class Refiner(torch.nn.Module):
            def __init__(self):
                super(Refiner, self).__init__()

                self.netMain = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1),
                )

            def forward(self, tenInput):
                return self.netMain(tenInput)

        self.netExtractor = Extractor()

        # One decoder per pyramid level; names match checkpoint keys.
        self.netTwo = Decoder(2)
        self.netThr = Decoder(3)
        self.netFou = Decoder(4)
        self.netFiv = Decoder(5)
        self.netSix = Decoder(6)

        self.netRefiner = Refiner()

        # Download the pretrained pytorch-pwc weights; checkpoint keys use a
        # "module..." prefix which is remapped to this class's "net..." names.
        # NOTE(review): network access happens at construction time.
        self.load_state_dict(
            {
                strKey.replace("module", "net"): tenWeight
                for strKey, tenWeight in torch.hub.load_state_dict_from_url(
                    url="http://content.sniklaus.com/github/pytorch-pwc/network-"
                    + "default"
                    + ".pytorch",
                    model_dir=get_ckpt_container_path(MODEL_TYPE)
                ).items()
            }
        )

    def forward(self, tenFirst, tenSecond, *args):
        """Estimate flow from tenFirst to tenSecond (1/4-res, unscaled).

        If two extra args are given they are taken to be pre-extracted
        feature pyramids (see extract_pyramid), skipping the extractor.
        """
        # optionally pass pre-extracted feature pyramid in as args
        if len(args) == 0:
            tenFirst = self.netExtractor(tenFirst)
            tenSecond = self.netExtractor(tenSecond)
        else:
            tenFirst, tenSecond = args

        # Coarse-to-fine: level 6 down to level 2.
        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)

        return objEstimate["tenFlow"] + self.netRefiner(objEstimate["tenFeat"])

    def extract_pyramid(self, tenFirst, tenSecond):
        # Convenience for callers that reuse pyramids across both flow
        # directions (see STMFNet_Model.forward).
        return self.netExtractor(tenFirst), self.netExtractor(tenSecond)

    def extract_pyramid_single(self, tenFirst):
        return self.netExtractor(tenFirst)
# Lazily-constructed global PWC-Net instance used by estimate().
netNetwork = None

##########################################################


def estimate(tenFirst, tenSecond):
    """Run PWC-Net on a single [3,H,W] frame pair and return a [2,H,W] flow.

    NOTE(review): requires CUDA and hard-asserts a 1024x436 (Sintel-sized)
    input; frames are resized to multiples of 64 for the network and the
    resulting flow is rescaled back to the original resolution.
    """
    global netNetwork

    if netNetwork is None:
        netNetwork = Network().cuda().eval()

    assert tenFirst.shape[1] == tenSecond.shape[1]
    assert tenFirst.shape[2] == tenSecond.shape[2]

    intWidth = tenFirst.shape[2]
    intHeight = tenFirst.shape[1]

    assert (
        intWidth == 1024
    )  # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
    assert (
        intHeight == 436
    )  # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue

    tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
    tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)

    # PWC-Net requires spatial dims divisible by 64 (6 stride-2 stages).
    intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
    intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))

    tenPreprocessedFirst = torch.nn.functional.interpolate(
        input=tenPreprocessedFirst,
        size=(intPreprocessedHeight, intPreprocessedWidth),
        mode="bilinear",
        align_corners=False,
    )
    tenPreprocessedSecond = torch.nn.functional.interpolate(
        input=tenPreprocessedSecond,
        size=(intPreprocessedHeight, intPreprocessedWidth),
        mode="bilinear",
        align_corners=False,
    )

    # Network output is 1/4 resolution and scaled down by 20 (PWC convention).
    tenFlow = 20.0 * torch.nn.functional.interpolate(
        input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond),
        size=(intHeight, intWidth),
        mode="bilinear",
        align_corners=False,
    )

    # Undo the resize: rescale u/v components to original pixel units.
    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)

    return tenFlow[0, :, :, :].cpu()


class UNet3d_18(nn.Module):
    """3D U-Net over 5 stacked frames that synthesizes a dynamic-texture
    residual (see STMFNet_Model.forward). Encoder is an r3d_18 backbone
    (defined elsewhere in this file).

    NOTE(review): `channels` is a mutable default argument — shared across
    instances if ever mutated; consider a tuple or None-sentinel.
    NOTE(review): in the `bn=False` branch below, `Identity` is passed as a
    class, not an instance — nn.Sequential would reject it; the branch is
    unused here because callers pass a truthy `bn` ("batch").
    """

    def __init__(self, channels=[32, 64, 96, 128], bn=True):
        super(UNet3d_18, self).__init__()
        growth = 2  # since concatenating previous outputs
        upmode = "transpose"  # use transposeConv to upsample

        self.channels = channels

        self.lrelu = nn.LeakyReLU(0.2, True)

        self.encoder = r3d_18(bn=bn, channels=channels)

        # Decoder stages; channels[::-1] walks the encoder widths in reverse.
        self.decoder = nn.Sequential(
            Conv_3d(channels[::-1][0], channels[::-1][1], kernel_size=3, padding=1, bias=True),
            upConv3D(channels[::-1][1] * growth, channels[::-1][2], kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1), upmode=upmode),
            upConv3D(channels[::-1][2] * growth, channels[::-1][3], kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1), upmode=upmode),
            Conv_3d(channels[::-1][3] * growth, channels[::-1][3], kernel_size=3, padding=1, bias=True),
            upConv3D(channels[::-1][3] * growth, channels[::-1][3], kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=(1, 1, 1), upmode=upmode),
        )

        # Fuses the 5 unbound temporal slices (channels * 5) back to 2D.
        self.feature_fuse = nn.Sequential(
            *(
                [
                    nn.Conv2d(channels[::-1][3] * 5, channels[::-1][3], kernel_size=1, stride=1, bias=False)
                ]
                + [nn.BatchNorm2d(channels[::-1][3]) if bn else Identity]
            )
        )

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels[::-1][3], 3, kernel_size=7, stride=1, padding=0),
        )

    def forward(self, im1, im3, im5, im7, im4_tilde):
        # Stack the four context frames around the synthesized middle frame
        # along a new temporal axis (dim=2): [B, C, 5, H, W].
        images = torch.stack((im1, im3, im4_tilde, im5, im7), dim=2)

        x_0, x_1, x_2, x_3, x_4 = self.encoder(images)

        # U-Net skip connections: decode and concatenate encoder features.
        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = torch.cat([dx_3, x_3], dim=1)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = torch.cat([dx_2, x_2], dim=1)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = torch.cat([dx_1, x_1], dim=1)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = torch.cat([dx_0, x_0], dim=1)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        # Fold the temporal axis into channels before the 2D fuse/out convs.
        dx_out = torch.cat(torch.unbind(dx_out, 2), 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        return out
class KernelEstimation(torch.nn.Module):
    """Predicts per-pixel AdaCoF kernels (weights + alpha/beta offsets) for
    both input frames at three scales: 1/2x ("_ds"), 1x, and 2x ("_us").

    Shared 64-channel features go through small conv heads producing
    ``kernel_size**2`` channels each; weight heads end in a channel softmax
    so the per-pixel kernel sums to 1.
    """

    def __init__(self, kernel_size):
        super(KernelEstimation, self).__init__()
        self.kernel_size = kernel_size

        def make_subnet(ks, scale_factor=None, softmax=False):
            # Single builder replacing the six near-identical Subnet_* /
            # Subnet_*_ds / Subnet_*_us helpers. Layer order and Sequential
            # indices are identical to the originals, so pretrained
            # state_dicts still load:
            #   ds: Conv,ReLU,Conv,ReLU,Conv[,Softmax]
            #   1x/us: Conv,ReLU,Conv,ReLU,Conv,ReLU,Upsample,Conv[,Softmax]
            layers = [
                torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1),
            ]
            if scale_factor is not None:
                layers += [
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True),
                    torch.nn.Conv2d(in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1),
                ]
            if softmax:
                # Normalizes each pixel's kernel weights to sum to 1.
                layers.append(torch.nn.Softmax(dim=1))
            return torch.nn.Sequential(*layers)

        ks = self.kernel_size**2

        # 1/2x heads: feature resolution, no upsampling.
        self.moduleWeight1_ds = make_subnet(ks, softmax=True)
        self.moduleAlpha1_ds = make_subnet(ks)
        self.moduleBeta1_ds = make_subnet(ks)
        self.moduleWeight2_ds = make_subnet(ks, softmax=True)
        self.moduleAlpha2_ds = make_subnet(ks)
        self.moduleBeta2_ds = make_subnet(ks)

        # 1x heads: bilinear x2 upsample back to input resolution.
        self.moduleWeight1 = make_subnet(ks, scale_factor=2, softmax=True)
        self.moduleAlpha1 = make_subnet(ks, scale_factor=2)
        self.moduleBeta1 = make_subnet(ks, scale_factor=2)
        self.moduleWeight2 = make_subnet(ks, scale_factor=2, softmax=True)
        self.moduleAlpha2 = make_subnet(ks, scale_factor=2)
        self.moduleBeta2 = make_subnet(ks, scale_factor=2)

        # 2x heads: bilinear x4 upsample for the super-resolved branch.
        self.moduleWeight1_us = make_subnet(ks, scale_factor=4, softmax=True)
        self.moduleAlpha1_us = make_subnet(ks, scale_factor=4)
        self.moduleBeta1_us = make_subnet(ks, scale_factor=4)
        self.moduleWeight2_us = make_subnet(ks, scale_factor=4, softmax=True)
        self.moduleAlpha2_us = make_subnet(ks, scale_factor=4)
        self.moduleBeta2_us = make_subnet(ks, scale_factor=4)

    def forward(self, tensorCombine):
        """Return the 18 kernel maps, grouped (ds x6, 1x x6, us x6),
        frame-1 triple before frame-2 triple within each group."""
        # Frame 0
        Weight1_ds = self.moduleWeight1_ds(tensorCombine)
        Weight1 = self.moduleWeight1(tensorCombine)
        Weight1_us = self.moduleWeight1_us(tensorCombine)
        Alpha1_ds = self.moduleAlpha1_ds(tensorCombine)
        Alpha1 = self.moduleAlpha1(tensorCombine)
        Alpha1_us = self.moduleAlpha1_us(tensorCombine)
        Beta1_ds = self.moduleBeta1_ds(tensorCombine)
        Beta1 = self.moduleBeta1(tensorCombine)
        Beta1_us = self.moduleBeta1_us(tensorCombine)

        # Frame 2
        Weight2_ds = self.moduleWeight2_ds(tensorCombine)
        Weight2 = self.moduleWeight2(tensorCombine)
        Weight2_us = self.moduleWeight2_us(tensorCombine)
        Alpha2_ds = self.moduleAlpha2_ds(tensorCombine)
        Alpha2 = self.moduleAlpha2(tensorCombine)
        Alpha2_us = self.moduleAlpha2_us(tensorCombine)
        Beta2_ds = self.moduleBeta2_ds(tensorCombine)
        Beta2 = self.moduleBeta2(tensorCombine)
        Beta2_us = self.moduleBeta2_us(tensorCombine)

        return (
            Weight1_ds,
            Alpha1_ds,
            Beta1_ds,
            Weight2_ds,
            Alpha2_ds,
            Beta2_ds,
            Weight1,
            Alpha1,
            Beta1,
            Weight2,
            Alpha2,
            Beta2,
            Weight1_us,
            Alpha1_us,
            Beta1_us,
            Weight2_us,
            Alpha2_us,
            Beta2_us,
        )
class STMFNet_Model(torch.nn.Module):
    """ST-MFNet frame interpolation: synthesizes the frame between I1 and I2
    using I0..I3 as temporal context.

    Pipeline (see forward): multi-scale AdaCoF warping + PWC-Net flow with
    softmax splatting, fused by a MIMO GridNet, plus a 3D-UNet
    dynamic-texture residual. Relies on out-of-view helpers:
    UMultiScaleResNext, FunctionAdaCoF, gaussian_kernel, Upsampler_8tap,
    MIMOGridNet, PWCNet, ModuleSoftsplat, backwarp, moduleNormalize.
    """

    def __init__(self):
        super(STMFNet_Model, self).__init__()

        # Learned occlusion/importance metric for softmax splatting:
        # scaled negative L1 photometric error of the backward-warped frame.
        class Metric(torch.nn.Module):
            def __init__(self):
                super(Metric, self).__init__()
                self.paramScale = torch.nn.Parameter(-torch.ones(1, 1, 1, 1))

            def forward(self, tenFirst, tenSecond, tenFlow):
                return self.paramScale * F.l1_loss(
                    input=tenFirst,
                    target=backwarp(tenSecond, tenFlow),
                    reduction="none",
                ).mean(1, True)

        self.kernel_size = 5
        self.dilation = 1
        self.featc = [64, 128, 256, 512]
        self.featnorm = "batch"
        self.finetune_pwc = False

        self.kernel_pad = int(((self.kernel_size - 1) * self.dilation) / 2.0)

        self.feature_extractor = UMultiScaleResNext(
            self.featc, norm_layer=self.featnorm
        )

        self.get_kernel = KernelEstimation(self.kernel_size)

        self.modulePad = torch.nn.ReplicationPad2d(
            [self.kernel_pad, self.kernel_pad, self.kernel_pad, self.kernel_pad]
        )

        self.moduleAdaCoF = FunctionAdaCoF.apply

        # Fixed 5x5 Gaussian (sigma 0.5) used to anti-alias before the 1/2x
        # branch; repeated per RGB channel for a grouped conv.
        self.gauss_kernel = torch.nn.Parameter(
            gaussian_kernel(5, 0.5).repeat(3, 1, 1, 1), requires_grad=False
        )

        self.upsampler = Upsampler_8tap()

        # Fuses 2x / 1x / 1/2x candidate frames; input channel counts are
        # 6 (two RGB frames), 12 (two AdaCoF + two softsplat), 6.
        self.scale_synthesis = MIMOGridNet(
            (6, 6 + 6, 6), (3,), grid_chs=(32, 64, 96), n_row=3, n_col=4, outrow=(1,)
        )

        self.flow_estimator = PWCNet()

        self.softsplat = ModuleSoftsplat(strType="softmax")

        self.metric = Metric()

        # NOTE(review): passes the string "batch" as `bn`; UNet3d_18 only
        # tests truthiness, so this enables batch norm — confirm intent.
        self.dyntex_generator = UNet3d_18(bn=self.featnorm)

        # freeze weights of PWCNet if not finetuning it
        if not self.finetune_pwc:
            for param in self.flow_estimator.parameters():
                param.requires_grad = False

    def forward(self, I0, I1, I2, I3):
        """Interpolate the midpoint of I1/I2 given context frames I0 and I3.

        Returns the frame tensor in eval mode, or {"frame1": frame} when
        training. NOTE(review): size mismatch triggers sys.exit — consider
        raising ValueError instead of killing the process.
        """
        h0 = int(list(I1.size())[2])
        w0 = int(list(I1.size())[3])
        h2 = int(list(I2.size())[2])
        w2 = int(list(I2.size())[3])
        if h0 != h2 or w0 != w2:
            sys.exit("Frame sizes do not match")

        # Reflect-pad all four frames up to multiples of 128 (network's
        # total downsampling factor); cropped back off at the end.
        h_padded = False
        w_padded = False
        if h0 % 128 != 0:
            pad_h = 128 - (h0 % 128)
            I0 = F.pad(I0, (0, 0, 0, pad_h), mode="reflect")
            I1 = F.pad(I1, (0, 0, 0, pad_h), mode="reflect")
            I2 = F.pad(I2, (0, 0, 0, pad_h), mode="reflect")
            I3 = F.pad(I3, (0, 0, 0, pad_h), mode="reflect")
            h_padded = True

        if w0 % 128 != 0:
            pad_w = 128 - (w0 % 128)
            I0 = F.pad(I0, (0, pad_w, 0, 0), mode="reflect")
            I1 = F.pad(I1, (0, pad_w, 0, 0), mode="reflect")
            I2 = F.pad(I2, (0, pad_w, 0, 0), mode="reflect")
            I3 = F.pad(I3, (0, pad_w, 0, 0), mode="reflect")
            w_padded = True

        # Estimate AdaCoF kernels for both frames at all three scales.
        feats = self.feature_extractor(moduleNormalize(I1), moduleNormalize(I2))
        kernelest = self.get_kernel(feats)
        Weight1_ds, Alpha1_ds, Beta1_ds, Weight2_ds, Alpha2_ds, Beta2_ds = kernelest[:6]
        Weight1, Alpha1, Beta1, Weight2, Alpha2, Beta2 = kernelest[6:12]
        Weight1_us, Alpha1_us, Beta1_us, Weight2_us, Alpha2_us, Beta2_us = kernelest[
            12:
        ]

        # Original scale
        tensorAdaCoF1 = (
            self.moduleAdaCoF(self.modulePad(I1), Weight1, Alpha1, Beta1, self.dilation)
            * 1.0
        )
        tensorAdaCoF2 = (
            self.moduleAdaCoF(self.modulePad(I2), Weight2, Alpha2, Beta2, self.dilation)
            * 1.0
        )

        # 1/2 downsampled version
        c, h, w = I1.shape[1:]
        p = (self.gauss_kernel.shape[-1] - 1) // 2
        I1_blur = F.conv2d(
            F.pad(I1, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c
        )
        I2_blur = F.conv2d(
            F.pad(I2, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c
        )
        I1_ds = F.interpolate(
            I1_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False
        )
        I2_ds = F.interpolate(
            I2_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False
        )
        tensorAdaCoF1_ds = (
            self.moduleAdaCoF(
                self.modulePad(I1_ds), Weight1_ds, Alpha1_ds, Beta1_ds, self.dilation
            )
            * 1.0
        )
        tensorAdaCoF2_ds = (
            self.moduleAdaCoF(
                self.modulePad(I2_ds), Weight2_ds, Alpha2_ds, Beta2_ds, self.dilation
            )
            * 1.0
        )

        # x2 upsampled version
        I1_us = self.upsampler(I1)
        I2_us = self.upsampler(I2)
        tensorAdaCoF1_us = (
            self.moduleAdaCoF(
                self.modulePad(I1_us), Weight1_us, Alpha1_us, Beta1_us, self.dilation
            )
            * 1.0
        )
        tensorAdaCoF2_us = (
            self.moduleAdaCoF(
                self.modulePad(I2_us), Weight2_us, Alpha2_us, Beta2_us, self.dilation
            )
            * 1.0
        )

        # use softsplat for refinement
        # Pyramids are extracted once and reused for both flow directions;
        # the 20x factor matches PWC-Net's flow scaling convention.
        pyramid0, pyramid2 = self.flow_estimator.extract_pyramid(I1, I2)
        flow_0_2 = 20 * self.flow_estimator(I1, I2, pyramid0, pyramid2)
        flow_0_2 = F.interpolate(
            flow_0_2, size=(h, w), mode="bilinear", align_corners=False
        )
        flow_2_0 = 20 * self.flow_estimator(I2, I1, pyramid2, pyramid0)
        flow_2_0 = F.interpolate(
            flow_2_0, size=(h, w), mode="bilinear", align_corners=False
        )
        metric_0_2 = self.metric(I1, I2, flow_0_2)
        metric_2_0 = self.metric(I2, I1, flow_2_0)
        # Forward-splat each frame halfway along its flow (midpoint in time).
        tensorSoftsplat0 = self.softsplat(I1, 0.5 * flow_0_2, metric_0_2)
        tensorSoftsplat2 = self.softsplat(I2, 0.5 * flow_2_0, metric_2_0)

        # synthesize multiple scales
        tensorCombine_us = torch.cat([tensorAdaCoF1_us, tensorAdaCoF2_us], dim=1)
        tensorCombine = torch.cat(
            [tensorAdaCoF1, tensorAdaCoF2, tensorSoftsplat0, tensorSoftsplat2], dim=1
        )
        tensorCombine_ds = torch.cat([tensorAdaCoF1_ds, tensorAdaCoF2_ds], dim=1)
        output_tilde = self.scale_synthesis(
            tensorCombine_us, tensorCombine, tensorCombine_ds
        )[0]

        # generate dynamic texture
        dyntex = self.dyntex_generator(I0, I1, I2, I3, output_tilde)
        output = output_tilde + dyntex

        # Crop back to the caller's original resolution.
        if h_padded:
            output = output[:, :, 0:h0, :]
        if w_padded:
            output = output[:, :, :, 0:w0]

        if self.training:
            return {"frame1": output}
        else:
            return output
class XVFI_Inference(nn.Module):
    """Thin inference wrapper around the XVFInet architecture: loads a
    checkpoint and exposes a (I0, I1, timestep) -> middle-frame forward."""

    def __init__(self, model_path, model_config) -> None:
        super(XVFI_Inference, self).__init__()
        from .xvfi_arch import XVFInet, weights_init
        # XVFInet reads its hyper-parameters from an argparse-style
        # namespace; model_config supplies module_scale_factor/S_trn/S_tst.
        args = argparse.Namespace(
            gpu=get_torch_device(),
            nf=64,
            **model_config,
            img_ch=3,
        )
        self.model = XVFInet(args).apply(weights_init).to(get_torch_device())
        self.model.load_state_dict(
            torch.load(model_path, map_location=get_torch_device())["state_dict_Model"]
        )

    def forward(self, I0, I1, timestep):
        # "Real" inference is called "test_custom" in the original repo
        # https://github.com/JihyongOh/XVFI/blob/main/utils.py#L434
        # https://github.com/JihyongOh/XVFI/blob/main/main.py#L336
        # XVFInet wants [B, C, T, H, W] with the two frames on the T axis.
        x = torch.stack([I0, I1], dim=0)
        x = einops.rearrange(x, "t b c h w -> b c t h w")
        return self.model(x, timestep, is_training=False)

MODEL_TYPE = pathlib.Path(__file__).parent.name

class XVFI:
    """ComfyUI VFI node: interpolates `multipler - 1` frames between every
    enabled pair of consecutive input frames using XVFInet."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (list(CKPT_CONFIGS.keys()), ),
                "frames": ("IMAGE", ),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 100}),
                "multipler": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", ),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        batch_size: typing.SupportsInt = 1,
        multipler: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None
    ):
        """Interpolate `frames`; returns a 1-tuple with the output IMAGE.

        NOTE: the `multipler` parameter name (sic) is part of the node's
        public interface and must not be renamed.
        """
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        ckpt_config = CKPT_CONFIGS[ckpt_name]
        global model
        model = XVFI_Inference(model_path, ckpt_config)

        frames = preprocess_frames(frames)
        # Pad so H/W divide the network's total downscale factor.
        # https://github.com/JihyongOh/XVFI/blob/main/main.py#L314
        divide = 2 ** (ckpt_config["S_tst"]) * ckpt_config["module_scale_factor"] * 4
        B, C, H, W = frames.size()
        H_padding = (divide - H % divide) % divide
        W_padding = (divide - W % divide) % divide
        if H_padding != 0 or W_padding != 0:
            frames = F.pad(frames, (0, W_padding, 0, H_padding), "constant")

        # Keys are "<frame>" for originals and "<frame>.<middle>" for
        # interpolated frames.
        frame_dict = {
            str(i): frames[i].unsqueeze(0) for i in range(frames.shape[0])
        }

        if optional_interpolation_states is None:
            interpolation_states = [True] * (frames.shape[0] - 1)
        else:
            interpolation_states = optional_interpolation_states

        enabled_former_idxs = [i for i, state in enumerate(interpolation_states) if state]
        former_idxs_loader = DataLoader(enabled_former_idxs, batch_size=batch_size)

        for former_idxs_batch in former_idxs_loader:
            for middle_i in range(1, multipler):
                _middle_frames = model(
                    frames[former_idxs_batch],
                    frames[former_idxs_batch + 1],
                    timestep=torch.tensor([middle_i/multipler]).repeat(len(former_idxs_batch)).unsqueeze(1).to(get_torch_device())
                )
                for i, former_idx in enumerate(former_idxs_batch):
                    frame_dict[f'{former_idx}.{middle_i}'] = _middle_frames[i].unsqueeze(0)

        # BUGFIX: keys must be ordered numerically, not lexicographically —
        # plain sorted() puts "10" before "2" and "0.10" before "0.2",
        # scrambling the output order for clips with >= 10 frames or
        # multipler > 10. Sort by (frame_index, middle_index) tuples.
        out_frames = torch.cat(
            [
                frame_dict[key]
                for key in sorted(
                    frame_dict.keys(),
                    key=lambda k: tuple(int(part) for part in k.split(".")),
                )
            ],
            dim=0,
        )[:, :, :H, :W]
        return (postprocess_frames(out_frames), )
[1, 3, 3], [1, 1, 1], [0, 1, 1]),
            nn.ReLU())

        # Shared feature extractor: one strided [1,3,3] conv per log2(scale)
        # step, followed by a recursive residual block (T kept intact).
        self.rec_ext_ds_module = [self.channel_converter]
        self.rec_ext_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
        for _ in range(int(np.log2(self.scale))):
            self.rec_ext_ds_module.append(self.rec_ext_ds)
            self.rec_ext_ds_module.append(nn.ReLU())
        self.rec_ext_ds_module.append(nn.Conv3d(self.nf, self.nf, [1, 3, 3], 1, [0, 1, 1]))
        self.rec_ext_ds_module.append(RResBlock2D_3D(args, T_reduce_flag=False))
        self.rec_ext_ds_module = nn.Sequential(*self.rec_ext_ds_module)

        # Context down-scaler reused at every pyramid level (shared weights).
        self.rec_ctx_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])

        print("The lowest scale depth for training (S_trn): ", self.args.S_trn)
        print("The lowest scale depth for test (S_tst): ", self.args.S_tst)

    def forward(self, x, t_value, is_training=True):
        '''
        x shape : [B,C,T,H,W]
        t_value shape : [B,1] ###############
        Coarse-to-fine: build a feature pyramid, then run VFInet from the
        lowest scale (S_trn / S_tst) up to level 0 (full resolution).
        '''
        B, C, T, H, W = x.size()
        B2, C2 = t_value.size()
        assert C2 == 1, "t_value shape is [B,]"
        assert T % 2 == 0, "T must be an even number"
        t_value = t_value.view(B, 1, 1, 1)

        flow_l = None
        feat_x = self.rec_ext_ds_module(x)
        feat_x_list = [feat_x]
        self.lowest_depth_level = self.args.S_trn if is_training else self.args.S_tst
        for level in range(1, self.lowest_depth_level+1):
            feat_x = self.rec_ctx_ds(feat_x)
            feat_x_list.append(feat_x)

        if is_training:
            out_l_list = []
            flow_refine_l_list = []
            # Lowest (coarsest) level: no previous flow to refine.
            out_l, flow_l, flow_refine_l = self.vfinet(x, feat_x_list[self.args.S_trn], flow_l, t_value, level=self.args.S_trn, is_training=True)
            out_l_list.append(out_l)
            flow_refine_l_list.append(flow_refine_l)
            for level in range(self.args.S_trn-1, 0, -1): ## self.args.S_trn, self.args.S_trn-1, ..., 1. level 0 is not included
                out_l, flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=True)
                out_l_list.append(out_l)
            # Level 0 additionally returns the occlusion map.
            out_l, flow_l, flow_refine_l, occ_0_l0 = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=True)
            out_l_list.append(out_l)
            flow_refine_l_list.append(flow_refine_l)
            return out_l_list[::-1], flow_refine_l_list[::-1], occ_0_l0, torch.mean(x, dim=2) # out_l_list should be reversed. [out_l0, out_l1, ...]

        else: # Testing
            # Levels > 0 only propagate/refine the flow; level 0 synthesizes
            # the output frame.
            for level in range(self.args.S_tst, 0, -1): ## self.args.S_tst, self.args.S_tst-1, ..., 1. level 0 is not included
                flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=False)
            out_l = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=False)
            return out_l


class VFInet(nn.Module):

    def __init__(self, args):
        super(VFInet, self).__init__()
        self.args = args
        self.device = get_torch_device()
        self.nf = args.nf
        self.scale = args.module_scale_factor
        self.in_channels = 3

        # Flow estimator used at the coarsest level (no prior flow);
        # outputs 6 channels: 4 flow (0->1, 1->0) + 2 z confidence logits.
        self.conv_flow_bottom = nn.Sequential(
            nn.Conv2d(2*self.nf, 2*self.nf, [4,4], 2, [1,1]),
            nn.ReLU(),
            nn.Conv2d(2*self.nf, 4*self.nf, [4,4], 2, [1,1]),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
            nn.ReLU(),
            nn.Conv2d(self.nf, 6, [3,3], 1, [1,1]),
        )

        self.conv_flow1 = nn.Conv2d(2*self.nf, self.nf, [3, 3], 1, [1, 1])

        # Flow updater for finer levels (takes upsampled previous flow).
        self.conv_flow2 = nn.Sequential(
            nn.Conv2d(2*self.nf + 4, 2 * self.nf, [4, 4], 2, [1, 1]),
            nn.ReLU(),
            nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
            nn.ReLU(),
            nn.Conv2d(self.nf, 6, [3, 3], 1, [1,
1]),
        )

        # Flow refinement head: outputs residual 4-channel flow (t->0, t->1).
        self.conv_flow3 = nn.Sequential(
            nn.Conv2d(4 + self.nf * 4, self.nf, [1, 1], 1, [0, 0]),
            nn.ReLU(),
            nn.Conv2d(self.nf, 2 * self.nf, [4, 4], 2, [1, 1]),
            nn.ReLU(),
            nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
            nn.ReLU(),
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
            nn.ReLU(),
            nn.Conv2d(self.nf, 4, [3, 3], 1, [1, 1]),
        )

        self.refine_unet = RefineUNet(args)
        self.lrelu = nn.ReLU()

    def forward(self, x, feat_x, flow_l_prev, t_value, level, is_training):
        '''
        x shape : [B,C,T,H,W]
        t_value shape : [B,1] ###############
        Single pyramid level: estimate bidirectional flow, reverse it to the
        target time t (CFR), refine, and (at level 0) synthesize the frame.
        '''
        B, C, T, H, W = x.size()
        assert T % 2 == 0, "T must be an even number"

        ####################### For a single level
        l = 2 ** level
        x_l = x.permute(0,2,1,3,4)
        x_l = x_l.contiguous().view(B * T, C, H, W)

        if level == 0:
            pass
        else:
            # Bicubic-downscale the raw frames to this level's resolution.
            x_l = F.interpolate(x_l, scale_factor=(1.0 / l, 1.0 / l), mode='bicubic', align_corners=False)
        '''
        Down pixel-shuffle
        '''
        x_l = x_l.view(B, T, C, H//l, W//l)
        x_l = x_l.permute(0,2,1,3,4)

        B, C, T, H, W = x_l.size()

        ## Feature extraction
        feat0_l = feat_x[:,:,0,:,:]
        feat1_l = feat_x[:,:,1,:,:]

        ## Flow estimation
        if flow_l_prev is None:
            # Coarsest level: estimate flow from scratch.
            flow_l_tmp = self.conv_flow_bottom(torch.cat((feat0_l, feat1_l), dim=1))
            flow_l = flow_l_tmp[:,:4,:,:]
        else:
            # Finer level: upsample previous flow (x2, values scaled by 2)
            # and predict a residual update from warped features.
            up_flow_l_prev = 2.0*F.interpolate(flow_l_prev.detach(), scale_factor=(2,2), mode='bilinear', align_corners=False)
            warped_feat1_l = self.bwarp(feat1_l, up_flow_l_prev[:,:2,:,:])
            warped_feat0_l = self.bwarp(feat0_l, up_flow_l_prev[:,2:,:,:])
            flow_l_tmp = self.conv_flow2(torch.cat([self.conv_flow1(torch.cat([feat0_l, warped_feat1_l],dim=1)), self.conv_flow1(torch.cat([feat1_l, warped_feat0_l],dim=1)), up_flow_l_prev],dim=1))
            flow_l = flow_l_tmp[:,:4,:,:] + up_flow_l_prev

        if not is_training and level!=0:
            return flow_l

        flow_01_l = flow_l[:,:2,:,:]
        flow_10_l = flow_l[:,2:,:,:]
        z_01_l = torch.sigmoid(flow_l_tmp[:,4:5,:,:])
        z_10_l = torch.sigmoid(flow_l_tmp[:,5:6,:,:])

        ## Complementary Flow Reversal (CFR)
        flow_forward, norm0_l = self.z_fwarp(flow_01_l, t_value * flow_01_l, z_01_l) ## Actually, F (t) -> (t+1). Translation only. Not normalized yet
        flow_backward, norm1_l = self.z_fwarp(flow_10_l, (1-t_value) * flow_10_l, z_10_l) ## Actually, F (1-t) -> (-t). Translation only. Not normalized yet

        flow_t0_l = -(1-t_value) * ((t_value)*flow_forward) + (t_value) * ((t_value)*flow_backward) # The numerator of Eq.(1) in the paper.
        flow_t1_l = (1-t_value) * ((1-t_value)*flow_forward) - (t_value) * ((1-t_value)*flow_backward) # The numerator of Eq.(2) in the paper.

        norm_l = (1-t_value)*norm0_l + t_value*norm1_l
        mask_ = (norm_l.detach() > 0).type(norm_l.type())
        flow_t0_l = (1-mask_) * flow_t0_l + mask_ * (flow_t0_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(1)
        flow_t1_l = (1-mask_) * flow_t1_l + mask_ * (flow_t1_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(2)

        ## Feature warping
        warped0_l = self.bwarp(feat0_l, flow_t0_l)
        warped1_l = self.bwarp(feat1_l, flow_t1_l)

        ## Flow refinement
        flow_refine_l = torch.cat([feat0_l, warped0_l, warped1_l, feat1_l, flow_t0_l, flow_t1_l], dim=1)
        flow_refine_l = self.conv_flow3(flow_refine_l) + torch.cat([flow_t0_l, flow_t1_l], dim=1)
        flow_t0_l = flow_refine_l[:, :2, :, :]
        flow_t1_l = flow_refine_l[:, 2:4, :, :]

        warped0_l = self.bwarp(feat0_l, flow_t0_l)
        warped1_l = self.bwarp(feat1_l, flow_t1_l)

        ## Flow upscale
        # Bring the flow back to image resolution (features were downscaled
        # by module_scale_factor in the extractor).
        flow_t0_l = self.scale * F.interpolate(flow_t0_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
        flow_t1_l = self.scale * F.interpolate(flow_t1_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)

        ## Image warping and blending
        warped_img0_l = self.bwarp(x_l[:,:,0,:,:], flow_t0_l)
        warped_img1_l = self.bwarp(x_l[:,:,1,:,:], flow_t1_l)

        # Refinement U-Net predicts 1 occlusion logit + img_ch residual.
        refine_out = self.refine_unet(torch.cat([F.pixel_shuffle(torch.cat([feat0_l, feat1_l, warped0_l, warped1_l],dim=1), self.scale), x_l[:,:,0,:,:], x_l[:,:,1,:,:], warped_img0_l, warped_img1_l, flow_t0_l, flow_t1_l],dim=1))
        occ_0_l = torch.sigmoid(refine_out[:, 0:1, :, :])
        occ_1_l = 1-occ_0_l

        # Occlusion-weighted blend of the two warped frames plus residual.
        out_l = (1-t_value)*occ_0_l*warped_img0_l + t_value*occ_1_l*warped_img1_l
        out_l = out_l / ( (1-t_value)*occ_0_l + t_value*occ_1_l ) + refine_out[:, 1:4, :, :]

        if not is_training and level==0:
            return out_l

        if is_training:
            if flow_l_prev is None:
                # if level == self.args.S_trn:
                return out_l, flow_l, flow_refine_l[:, 0:4, :, :]
            elif level != 0:
                return out_l, flow_l
            else: # level==0
                return out_l, flow_l, flow_refine_l[:, 0:4, :, :], occ_0_l

    def bwarp(self, x, flo):
        '''
        Backward warp: sample x at positions displaced by flo.
        x: [B, C, H, W] (im2)
        flo: [B, 2, H, W] flow
        Out-of-bounds samples are zeroed via the warped ones-mask.
        '''
        B, C, H, W = x.size()
        # mesh grid
        xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
        yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
        grid = torch.cat((xx, yy), 1).float()

        grid = grid.to(self.device)
        vgrid = torch.autograd.Variable(grid) + flo

        # scale grid to [-1,1] as required by grid_sample
        vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
        vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

        vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2]
        output = nn.functional.grid_sample(x, vgrid, align_corners=True)
        mask = torch.autograd.Variable(torch.ones(x.size())).to(self.device)
        mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)

        # mask[mask<0.9999] = 0
        # mask[mask>0] = 1
        mask = mask.masked_fill_(mask < 0.999, 0)
        mask = mask.masked_fill_(mask > 0, 1)

        return output * mask

    def fwarp(self, img, flo):

        """
        Forward warp with Gaussian splatting of the four neighbor pixels.
        -img: image (N, C, H, W)
        -flo: optical flow (N, 2, H, W)
        elements of flo is in [0, H] and [0, W] for dx, dy
        https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
        """

        # (x1, y1)		(x1, y2)
        # +---------------+
        # |				  |
        # |	o(x, y)		  |
        # |				  |
        # |				  |
        # |				  |
        # |				  |
        # +---------------+
        # (x2, y1)		(x2, y2)

        N, C, _, _ = img.size()

        # translate start-point optical flow to end-point optical flow
        # NOTE(review): the slice `flo[:, 0:1:, :]` is equivalent to
        # `flo[:, 0:1, :, :]` (trailing dim implicit) — confirm intent.
        y = flo[:, 0:1:, :]
        x = flo[:, 1:2, :, :]

        x = x.repeat(1, C, 1, 1)
        y = y.repeat(1, C, 1, 1)

        # Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2)
        x1 = torch.floor(x)
        x2 = x1 + 1
        y1 = torch.floor(y)
        y2 = y1 + 1

        # firstly, get gaussian weights
        w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)

        # secondly, sample each weighted corner
        img11, o11 = self.sample_one(img, x1, y1, w11)
        img12, o12 = self.sample_one(img, x1, y2, w12)
        img21, o21 = self.sample_one(img, x2, y1, w21)
        img22, o22 = self.sample_one(img, x2, y2, w22)

        imgw = img11 + img12 + img21 + img22
        o = o11 + o12 + o21 + o22

        return imgw, o


    def z_fwarp(self, img, flo, z):
        """
        Forward warp weighted by a per-pixel confidence z (used by CFR).
        -img: image (N, C, H, W)
        -flo: optical flow (N, 2, H, W)
        elements of flo is in [0, H] and [0, W] for dx, dy
        modified from https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
        """

        # (x1, y1)		(x1, y2)
        # +---------------+
        # |				  |
        # |	o(x, y)		  |
        # |				  |
        # |				  |
        # |				  |
        # |				  |
        # +---------------+
        # (x2, y1)		(x2, y2)

        N, C, _, _ = img.size()

        # translate start-point optical flow to end-point optical flow
        y = flo[:, 0:1:, :]
        x = flo[:, 1:2, :, :]

        x = x.repeat(1, C, 1, 1)
        y = y.repeat(1, C, 1, 1)

        # Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2)
        x1 = torch.floor(x)
        x2 = x1 + 1
        y1 = torch.floor(y)
        y2 = y1 + 1

        # firstly, get gaussian weights (epsilon keeps weights non-zero)
        w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2, z+1e-5)

        # secondly, sample each weighted corner
        img11, o11 = self.sample_one(img, x1, y1, w11)
        img12, o12 = self.sample_one(img, x1, y2, w12)
        img21, o21 = self.sample_one(img, x2, y1, w21)
        img22, o22 = self.sample_one(img, x2, y2, w22)

        imgw = img11 + img12 + img21 + img22
        o = o11 + o12 + o21 + o22

        return imgw, o


    def get_gaussian_weights(self, x, y, x1, x2, y1, y2, z=1.0):
        """Gaussian splat weights of (x, y) towards its four integer corners."""
        # z 0.0 ~ 1.0
        w11 = z * torch.exp(-((x - x1) ** 2 + (y - y1) ** 2))
        w12 = z * torch.exp(-((x - x1) ** 2 + (y - y2) ** 2))
        w21 = z * torch.exp(-((x - x2) ** 2 + (y - y1) ** 2))
        w22 = z * torch.exp(-((x - x2) ** 2 + (y - y2) ** 2))

        return w11, w12, w21, w22

    def sample_one(self, img, shiftx, shifty, weight):
        """
        Scatter-add img * weight at integer positions shifted by (shiftx, shifty).
        Input:
            -img (N, C, H, W)
            -shiftx, shifty (N, c, H, W)
        Returns the accumulated image and the accumulated weights (for later
        normalization). Out-of-bounds targets are masked out.
        """

        N, C, H, W = img.size()

        # flatten all (all restored as Tensors)
        flat_shiftx = shiftx.view(-1)
        flat_shifty = shifty.view(-1)
        flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].to(self.device).long().repeat(N, C,1,W).view(-1)
        flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].to(self.device).long().repeat(N, C,H,1).view(-1)
        flat_weight = weight.view(-1)
        flat_img = img.contiguous().view(-1)

        # The corresponding positions in I1
        idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).to(self.device).long().repeat(1, C, H, W).view(-1)
        idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).to(self.device).long().repeat(N, 1, H, W).view(-1)
        idxx = flat_shiftx.long() + flat_basex
        idxy = flat_shifty.long() + flat_basey

        # recording the inside part the shifted
        mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)

        # Mask off points out of boundaries
        ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
        ids_mask = torch.masked_select(ids, mask).clone().to(self.device)

        # NOTE: the accumulate flag must be True for proper backprop
        img_warp = torch.zeros([N * C * H * W, ]).to(self.device)
        img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True)

        one_warp = torch.zeros([N * C * H * W, ]).to(self.device)
        one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)

        return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)

class RefineUNet(nn.Module):
    # Small 3-level U-Net producing the occlusion logit + image residual.
    def __init__(self, args):
        super(RefineUNet, self).__init__()
        self.args = args
        self.scale = args.module_scale_factor
        self.nf = args.nf
        self.conv1 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
        self.conv2 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
        self.lrelu = nn.ReLU()
        self.NN = nn.UpsamplingNearest2d(scale_factor=2)

        self.enc1 = nn.Conv2d((4*self.nf)//self.scale//self.scale + 4*args.img_ch + 4, self.nf, [4, 4], 2, [1, 1])
        self.enc2 = nn.Conv2d(self.nf, 2*self.nf, [4, 4], 2, [1, 1])
        self.enc3 = nn.Conv2d(2*self.nf, 4*self.nf, [4, 4], 2, [1, 1])
        self.dec0 = nn.Conv2d(4*self.nf, 4*self.nf, [3, 3], 1, [1, 1])
        self.dec1 = nn.Conv2d(4*self.nf + 2*self.nf, 2*self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc2
        self.dec2 = nn.Conv2d(2*self.nf + self.nf, self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc1
        self.dec3 = nn.Conv2d(self.nf, 1+args.img_ch, [3, 3], 1, [1, 1]) ## input added with warped image

    def forward(self, concat):
        enc1 = self.lrelu(self.enc1(concat))
        enc2 = self.lrelu(self.enc2(enc1))
        out = self.lrelu(self.enc3(enc2))

        out = self.lrelu(self.dec0(out))
        out = self.NN(out)

        out = torch.cat((out,enc2),dim=1)
        out = self.lrelu(self.dec1(out))

        out = self.NN(out)
        out = torch.cat((out,enc1),dim=1)
        out = self.lrelu(self.dec2(out))

        out = self.NN(out)
        out = self.dec3(out)
        return out

class ResBlock2D_3D(nn.Module):
    ## Shape of input [B,C,T,H,W]
    ## Shape of output [B,C,T,H,W]
    def __init__(self, args):
        super(ResBlock2D_3D, self).__init__()
self.args = args + self.nf = args.nf + + self.conv3x3_1 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1]) + self.conv3x3_2 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1]) + self.lrelu = nn.ReLU() + + def forward(self, x): + ''' + x shape : [B,C,T,H,W] + ''' + B, C, T, H, W = x.size() + + out = self.conv3x3_2(self.lrelu(self.conv3x3_1(x))) + + return x + out + +class RResBlock2D_3D(nn.Module): + + def __init__(self, args, T_reduce_flag=False): + super(RResBlock2D_3D, self).__init__() + self.args = args + self.nf = args.nf + self.T_reduce_flag = T_reduce_flag + self.resblock1 = ResBlock2D_3D(self.args) + self.resblock2 = ResBlock2D_3D(self.args) + if T_reduce_flag: + self.reduceT_conv = nn.Conv3d(self.nf, self.nf, [3,1,1], 1, [0,0,0]) + + def forward(self, x): + ''' + x shape : [B,C,T,H,W] + ''' + out = self.resblock1(x) + out = self.resblock2(out) + if self.T_reduce_flag: + return self.reduceT_conv(out + x) + else: + return out + x + +def weights_init(m): + classname = m.__class__.__name__ + if (classname.find('Conv2d') != -1) or (classname.find('Conv3d') != -1): + init.xavier_normal_(m.weight) + # init.kaiming_normal_(m.weight, nonlinearity='relu') + if hasattr(m, 'bias') and m.bias is not None: + init.zeros_(m.bias) \ No newline at end of file diff --git a/vfi_utils.py b/vfi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7504637e63abfc8b826a36f11456e8a47d2dbf5d --- /dev/null +++ b/vfi_utils.py @@ -0,0 +1,295 @@ +import yaml +import os +from torch.hub import download_url_to_file, get_dir +from urllib.parse import urlparse +import torch +import typing +import traceback +import einops +import gc +import torchvision.transforms.functional as transform +from comfy.model_management import soft_empty_cache, get_torch_device +import numpy as np + +BASE_MODEL_DOWNLOAD_URLS = [ + "https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/", + 
"https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/", + "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/" +] + +config_path = os.path.join(os.path.dirname(__file__), "./config.yaml") +if os.path.exists(config_path): + config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) +else: + raise Exception("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation") +DEVICE = get_torch_device() + +class InterpolationStateList(): + + def __init__(self, frame_indices: typing.List[int], is_skip_list: bool): + self.frame_indices = frame_indices + self.is_skip_list = is_skip_list + + def is_frame_skipped(self, frame_index): + is_frame_in_list = frame_index in self.frame_indices + return self.is_skip_list and is_frame_in_list or not self.is_skip_list and not is_frame_in_list + + +class MakeInterpolationStateList: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}), + "is_skip_list": ("BOOLEAN", {"default": True},), + }, + } + + RETURN_TYPES = ("INTERPOLATION_STATES",) + FUNCTION = "create_options" + CATEGORY = "ComfyUI-Frame-Interpolation/VFI" + + def create_options(self, frame_indices: str, is_skip_list: bool): + frame_indices_list = [int(item) for item in frame_indices.split(',')] + + interpolation_state_list = InterpolationStateList( + frame_indices=frame_indices_list, + is_skip_list=is_skip_list, + ) + return (interpolation_state_list,) + + +def get_ckpt_container_path(model_type): + return os.path.abspath(os.path.join(os.path.dirname(__file__), config["ckpts_path"], model_type)) + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. 
+ + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + file_name = os.path.basename(parts.path) + if file_name is not None: + file_name = file_name + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + +def load_file_from_github_release(model_type, ckpt_name): + error_strs = [] + for i, base_model_download_url in enumerate(BASE_MODEL_DOWNLOAD_URLS): + try: + return load_file_from_url(base_model_download_url + ckpt_name, get_ckpt_container_path(model_type)) + except Exception: + traceback_str = traceback.format_exc() + if i < len(BASE_MODEL_DOWNLOAD_URLS) - 1: + print("Failed! Trying another endpoint.") + error_strs.append(f"Error when downloading from: {base_model_download_url + ckpt_name}\n\n{traceback_str}") + + error_str = '\n\n'.join(error_strs) + raise Exception(f"Tried all GitHub base urls to download {ckpt_name} but no suceess. 
Below is the error log:\n\n{error_str}") + + +def load_file_from_direct_url(model_type, url): + return load_file_from_url(url, get_ckpt_container_path(model_type)) + +def preprocess_frames(frames): + return einops.rearrange(frames[..., :3], "n h w c -> n c h w") + +def postprocess_frames(frames): + return einops.rearrange(frames, "n c h w -> n h w c")[..., :3].cpu() + +def assert_batch_size(frames, batch_size=2, vfi_name=None): + subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires" + assert len(frames) >= batch_size, f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage." + +def _generic_frame_loop( + frames, + clear_cache_after_n_frames, + multiplier: typing.Union[typing.SupportsInt, typing.List], + return_middle_frame_function, + *return_middle_frame_function_args, + interpolation_states: InterpolationStateList = None, + use_timestep=True, + dtype=torch.float16, + final_logging=True): + + #https://github.com/hzwer/Practical-RIFE/blob/main/inference_video.py#L169 + def non_timestep_inference(frame0, frame1, n): + middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args) + if n == 1: + return [middle] + first_half = non_timestep_inference(frame0, middle, n=n//2) + second_half = non_timestep_inference(middle, frame1, n=n//2) + if n%2: + return [*first_half, middle, *second_half] + else: + return [*first_half, *second_half] + + output_frames = torch.zeros(multiplier*frames.shape[0], *frames.shape[1:], dtype=dtype, device="cpu") + out_len = 0 + + number_of_frames_processed_since_last_cleared_cuda_cache = 0 + + for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it + frame0 = frames[frame_itr:frame_itr+1] + output_frames[out_len] = frame0 # Start with first frame + out_len += 1 + # Ensure that input frames are in fp32 - the same dtype as model + frame0 = 
            frame0.to(dtype=torch.float32)
        frame1 = frames[frame_itr+1:frame_itr+2].to(dtype=torch.float32)

        # Skip AFTER the former frame was appended so originals are kept.
        if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
            continue

        # Generate and append a batch of middle frames
        middle_frame_batches = []

        if use_timestep:
            for middle_i in range(1, multiplier):
                timestep = middle_i/multiplier

                middle_frame = return_middle_frame_function(
                    frame0.to(DEVICE),
                    frame1.to(DEVICE),
                    timestep,
                    *return_middle_frame_function_args
                ).detach().cpu()
                middle_frame_batches.append(middle_frame.to(dtype=dtype))
        else:
            middle_frames = non_timestep_inference(frame0.to(DEVICE), frame1.to(DEVICE), multiplier - 1)
            middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype))

        # Copy middle frames to output
        for middle_frame in middle_frame_batches:
            output_frames[out_len] = middle_frame
            out_len += 1

        number_of_frames_processed_since_last_cleared_cuda_cache += 1
        # Try to avoid a memory overflow by clearing cuda cache regularly
        if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
            print("Comfy-VFI: Clearing cache...", end=' ')
            soft_empty_cache()
            number_of_frames_processed_since_last_cleared_cuda_cache = 0
            print("Done cache clearing")

        gc.collect()

    if final_logging:
        print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
    # Append final frame
    output_frames[out_len] = frames[-1:]
    out_len += 1
    # clear cache for courtesy
    if final_logging:
        print("Comfy-VFI: Final clearing cache...", end = ' ')
    soft_empty_cache()
    if final_logging:
        print("Done cache clearing")
    # Buffer was sized for the no-skip case; trim to the frames written.
    return output_frames[:out_len]

def generic_frame_loop(
        model_name,
        frames,
        clear_cache_after_n_frames,
        multiplier: typing.Union[typing.SupportsInt, typing.List],
        return_middle_frame_function,
        *return_middle_frame_function_args,
        interpolation_states: InterpolationStateList = None,
        use_timestep=True,
        dtype=torch.float32):
    """Public driver: int multiplier -> single pass; list multiplier -> one
    per-pair pass with its own factor (0 drops the pair entirely)."""

    assert_batch_size(frames, vfi_name=model_name.replace('_', ' ').replace('VFI', ''))
    if type(multiplier) == int:
        return _generic_frame_loop(
            frames,
            clear_cache_after_n_frames,
            multiplier,
            return_middle_frame_function,
            *return_middle_frame_function_args,
            interpolation_states=interpolation_states,
            use_timestep=use_timestep,
            dtype=dtype
        )
    if type(multiplier) == list:
        multipliers = list(map(int, multiplier))
        # Pad the list with the default factor 2 for the remaining pairs.
        multipliers += [2] * (len(frames) - len(multipliers) - 1)
        frame_batches = []
        for frame_itr in range(len(frames) - 1):
            multiplier = multipliers[frame_itr]
            if multiplier == 0: continue
            # NOTE(review): the 2-frame slice is passed together with the
            # global interpolation_states, so is_frame_skipped is always
            # queried with index 0 here instead of frame_itr — looks like a
            # latent bug when skip states are combined with list multipliers;
            # verify against callers before relying on it.
            frame_batch = _generic_frame_loop(
                frames[frame_itr:frame_itr+2],
                clear_cache_after_n_frames,
                multiplier,
                return_middle_frame_function,
                *return_middle_frame_function_args,
                interpolation_states=interpolation_states,
                use_timestep=use_timestep,
                dtype=dtype,
                final_logging=False
            )
            if frame_itr != len(frames) - 2: # Not append last frame unless this batch is the last one
                frame_batch = frame_batch[:-1]
            frame_batches.append(frame_batch)
        output_frames = torch.cat(frame_batches)
        print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}")
        return output_frames
    raise NotImplementedError(f"multipiler of {type(multiplier)}")

class FloatToInt:
    """Utility node: truncate a FLOAT (or a sequence of them) to INT."""
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "float": ("FLOAT", {"default": 0, 'min': 0, 'step': 0.01})
            }
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "convert"
    CATEGORY = "ComfyUI-Frame-Interpolation"

    def convert(self, float):
        # Accept both a scalar and an iterable of floats.
        if hasattr(float, "__iter__"):
            return (list(map(int, float)),)
        return (int(float),)

""" def generic_4frame_loop(
    frames,
    clear_cache_after_n_frames,
    multiplier: typing.SupportsInt,
    return_middle_frame_function,
    *return_middle_frame_function_args,
    interpolation_states: InterpolationStateList = None,
    use_timestep=False):

    if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model")
    def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n):
        middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args)
        if n == 1:
            return [middle]
        first_half = non_timestep_inference(frame_0, middle, n=n//2)
        second_half = non_timestep_inference(middle, frame_1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half] """