aliensmn commited on
Commit
61029c7
·
verified ·
1 Parent(s): f387762

Mirror from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +15 -0
  2. .github/workflows/publish.yml +25 -0
  3. .gitignore +3 -0
  4. All_in_one_v1_3.png +3 -0
  5. LICENSE +21 -0
  6. README.md +194 -0
  7. __init__.py +42 -0
  8. config.yaml +3 -0
  9. demo_frames/anime0.png +3 -0
  10. demo_frames/anime1.png +3 -0
  11. demo_frames/bocchi0.jpg +3 -0
  12. demo_frames/bocchi1.jpg +3 -0
  13. demo_frames/real0.png +3 -0
  14. demo_frames/real1.png +3 -0
  15. demo_frames/rick/00003.png +3 -0
  16. demo_frames/rick/00004.png +3 -0
  17. demo_frames/rick/00005.png +3 -0
  18. demo_frames/violet0.png +3 -0
  19. demo_frames/violet1.png +3 -0
  20. example.png +3 -0
  21. install-taichi.bat +11 -0
  22. install.bat +16 -0
  23. install.py +59 -0
  24. interpolation_schedule.png +3 -0
  25. other_nodes.py +88 -0
  26. pyproject.toml +13 -0
  27. requirements-no-cupy.txt +9 -0
  28. requirements-with-cupy.txt +10 -0
  29. test.py +38 -0
  30. test_vfi_schedule.gif +3 -0
  31. vfi_models/amt/__init__.py +87 -0
  32. vfi_models/amt/amt_arch.py +1590 -0
  33. vfi_models/cain/__init__.py +64 -0
  34. vfi_models/cain/cain_arch.py +74 -0
  35. vfi_models/cain/cain_encdec_arch.py +95 -0
  36. vfi_models/cain/cain_noca_arch.py +73 -0
  37. vfi_models/cain/common.py +361 -0
  38. vfi_models/eisai/__init__.py +84 -0
  39. vfi_models/eisai/eisai_arch.py +2586 -0
  40. vfi_models/film/__init__.py +113 -0
  41. vfi_models/film/film_arch.py +798 -0
  42. vfi_models/flavr/__init__.py +115 -0
  43. vfi_models/flavr/flavr_arch.py +217 -0
  44. vfi_models/flavr/resnet_3D.py +288 -0
  45. vfi_models/gmfss_fortuna/GMFSS_Fortuna.py +24 -0
  46. vfi_models/gmfss_fortuna/GMFSS_Fortuna_arch.py +1850 -0
  47. vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py +23 -0
  48. vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py +1857 -0
  49. vfi_models/gmfss_fortuna/__init__.py +143 -0
  50. vfi_models/ifrnet/IFRNet_L_arch.py +293 -0
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ All_in_one_v1_3.png filter=lfs diff=lfs merge=lfs -text
37
+ demo_frames/anime0.png filter=lfs diff=lfs merge=lfs -text
38
+ demo_frames/anime1.png filter=lfs diff=lfs merge=lfs -text
39
+ demo_frames/bocchi0.jpg filter=lfs diff=lfs merge=lfs -text
40
+ demo_frames/bocchi1.jpg filter=lfs diff=lfs merge=lfs -text
41
+ demo_frames/real0.png filter=lfs diff=lfs merge=lfs -text
42
+ demo_frames/real1.png filter=lfs diff=lfs merge=lfs -text
43
+ demo_frames/rick/00003.png filter=lfs diff=lfs merge=lfs -text
44
+ demo_frames/rick/00004.png filter=lfs diff=lfs merge=lfs -text
45
+ demo_frames/rick/00005.png filter=lfs diff=lfs merge=lfs -text
46
+ demo_frames/violet0.png filter=lfs diff=lfs merge=lfs -text
47
+ demo_frames/violet1.png filter=lfs diff=lfs merge=lfs -text
48
+ example.png filter=lfs diff=lfs merge=lfs -text
49
+ interpolation_schedule.png filter=lfs diff=lfs merge=lfs -text
50
+ test_vfi_schedule.gif filter=lfs diff=lfs merge=lfs -text
.github/workflows/publish.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Publish to Comfy registry
2
+ on:
3
+ workflow_dispatch:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "pyproject.toml"
9
+
10
+ permissions:
11
+ issues: write
12
+
13
+ jobs:
14
+ publish-node:
15
+ name: Publish Custom Node to registry
16
+ runs-on: ubuntu-latest
17
+ if: ${{ github.repository_owner == 'Fannovel16' }}
18
+ steps:
19
+ - name: Check out code
20
+ uses: actions/checkout@v4
21
+ - name: Publish Custom Node
22
+ uses: Comfy-Org/publish-node-action@v1
23
+ with:
24
+ ## Add your own personal access token to your Github Repository secrets and reference it here.
25
+ personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ckpts
2
+ __pycache__
3
+ test_result
All_in_one_v1_3.png ADDED

Git LFS Details

  • SHA256: 90735b644e0c35634642b65f2a8041a9a4da380d27b9bcc4d3bbef47869bd92a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Fannovel16
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI Frame Interpolation (ComfyUI VFI) (WIP)
2
+
3
+ A custom node set for Video Frame Interpolation in ComfyUI.
4
+ **UPDATE** Memory management is improved. Now this extension takes less RAM and VRAM than before.
5
+
6
+ **UPDATE 2** VFI nodes now accept scheduling multiplier values
7
+
8
+ ![](./interpolation_schedule.png)
9
+ ![](./test_vfi_schedule.gif)
10
+
11
+ ## Nodes
12
+ * KSampler Gradually Adding More Denoise (efficient)
13
+ * GMFSS Fortuna VFI
14
+ * IFRNet VFI
15
+ * IFUnet VFI
16
+ * M2M VFI
17
+ * RIFE VFI (4.0 - 4.9) (Note that option `fast_mode` won't do anything from v4.5+ as `contextnet` is removed)
18
+ * FILM VFI
19
+ * Sepconv VFI
20
+ * AMT VFI
21
+ * Make Interpolation State List
22
+ * STMFNet VFI (requires at least 4 frames, can only do 2x interpolation for now)
23
+ * FLAVR VFI (same conditions as STMFNet)
24
+
25
+ ## Install
26
+ ### ComfyUI Manager
27
+ The incompatibility issue with it is now fixed
28
+
29
+ Follow this guide to install this extension
30
+
31
+ https://github.com/ltdrdata/ComfyUI-Manager#how-to-use
32
+ ### Command-line
33
+ #### Windows
34
+ Run install.bat
35
+
36
+ For Windows users, if you are having trouble with cupy, please run `install.bat` instead of `install-cupy.py` or `python install.py`.
37
+ #### Linux
38
+ Open your shell app and start venv if it is used for ComfyUI. Run:
39
+ ```
40
+ python install.py
41
+ ```
42
+ ## Support for non-CUDA device (experimental)
43
+ If you don't have a NVidia card, you can try `taichi` ops backend powered by [Taichi Lang](https://www.taichi-lang.org/)
44
+
45
+ On Windows, you can install it by running `install.bat` or `pip install taichi` on Linux
46
+
47
+ Then change value of `ops_backend` from `cupy` to `taichi` in `config.yaml`
48
+
49
+ If `NotImplementedError` appears, a VFI node in the workflow isn't supported by taichi
50
+
51
+ ## Usage
52
+ All VFI nodes can be accessed in **category** `ComfyUI-Frame-Interpolation/VFI` if the installation is successful and require a `IMAGE` containing frames (at least 2, or at least 4 for STMF-Net/FLAVR).
53
+
54
+ Regarding STMFNet and FLAVR, if you only have two or three frames, you should use: Load Images -> Other VFI node (FILM is recommended in this case) with `multiplier=4` -> STMFNet VFI/FLAVR VFI
55
+
56
+ `clear_cache_after_n_frames` is used to avoid out-of-memory. Decreasing it makes the chance lower but also increases processing time.
57
+
58
+ It is recommended to use LoadImages (LoadImagesFromDirectory) from [ComfyUI-Advanced-ControlNet](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/) and [ComfyUI-VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) along side with this extension.
59
+
60
+ ## Example
61
+ ### Simple workflow
62
+ Workflow metadata isn't embedded
63
+ Download these two images [anime0.png](./demo_frames/anime0.png) and [anime1.png](./demo_frames/anime1.png) and put them into a folder like `E:\test` as in this image.
64
+ ![](./example.png)
65
+
66
+ ### Complex workflow
67
+ It's used in AnimationDiff (can load workflow metadata)
68
+ ![](All_in_one_v1_3.png)
69
+
70
+ ## Credit
71
+ Big thanks to styler00dollar for making [VSGAN-tensorrt-docker](https://github.com/styler00dollar/VSGAN-tensorrt-docker). About 99% of the code of this repo comes from it.
72
+
73
+ Citation for each VFI node:
74
+ ### GMFSS Fortuna
75
+ The All-In-One GMFSS: Dedicated for Anime Video Frame Interpolation
76
+
77
+ https://github.com/98mxr/GMFSS_Fortuna
78
+
79
+ ### IFRNet
80
+ ```bibtex
81
+ @InProceedings{Kong_2022_CVPR,
82
+ author = {Kong, Lingtong and Jiang, Boyuan and Luo, Donghao and Chu, Wenqing and Huang, Xiaoming and Tai, Ying and Wang, Chengjie and Yang, Jie},
83
+ title = {IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation},
84
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
85
+ year = {2022}
86
+ }
87
+ ```
88
+
89
+ ### IFUnet
90
+ RIFE with IFUNet, FusionNet and RefineNet
91
+
92
+ https://github.com/98mxr/IFUNet
93
+ ### M2M
94
+ ```bibtex
95
+ @InProceedings{hu2022m2m,
96
+ title={Many-to-many Splatting for Efficient Video Frame Interpolation},
97
+ author={Hu, Ping and Niklaus, Simon and Sclaroff, Stan and Saenko, Kate},
98
+ journal={CVPR},
99
+ year={2022}
100
+ }
101
+ ```
102
+
103
+ ### RIFE
104
+ ```bibtex
105
+ @inproceedings{huang2022rife,
106
+ title={Real-Time Intermediate Flow Estimation for Video Frame Interpolation},
107
+ author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang},
108
+ booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
109
+ year={2022}
110
+ }
111
+ ```
112
+
113
+ ### FILM
114
+ [Frame interpolation in PyTorch](https://github.com/dajes/frame-interpolation-pytorch)
115
+
116
+ ```bibtex
117
+ @inproceedings{reda2022film,
118
+ title = {FILM: Frame Interpolation for Large Motion},
119
+ author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
120
+ booktitle = {European Conference on Computer Vision (ECCV)},
121
+ year = {2022}
122
+ }
123
+ ```
124
+
125
+ ```bibtex
126
+ @misc{film-tf,
127
+ title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
128
+ author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
129
+ year = {2022},
130
+ publisher = {GitHub},
131
+ journal = {GitHub repository},
132
+ howpublished = {\url{https://github.com/google-research/frame-interpolation}}
133
+ }
134
+ ```
135
+
136
+ ### Sepconv
137
+ ```bibtex
138
+ [1] @inproceedings{Niklaus_WACV_2021,
139
+ author = {Simon Niklaus and Long Mai and Oliver Wang},
140
+ title = {Revisiting Adaptive Convolutions for Video Frame Interpolation},
141
+ booktitle = {IEEE Winter Conference on Applications of Computer Vision},
142
+ year = {2021}
143
+ }
144
+ ```
145
+
146
+ ```bibtex
147
+ [2] @inproceedings{Niklaus_ICCV_2017,
148
+ author = {Simon Niklaus and Long Mai and Feng Liu},
149
+ title = {Video Frame Interpolation via Adaptive Separable Convolution},
150
+ booktitle = {IEEE International Conference on Computer Vision},
151
+ year = {2017}
152
+ }
153
+ ```
154
+
155
+ ```bibtex
156
+ [3] @inproceedings{Niklaus_CVPR_2017,
157
+ author = {Simon Niklaus and Long Mai and Feng Liu},
158
+ title = {Video Frame Interpolation via Adaptive Convolution},
159
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
160
+ year = {2017}
161
+ }
162
+ ```
163
+
164
+ ### AMT
165
+ ```bibtex
166
+ @inproceedings{licvpr23amt,
167
+ title={AMT: All-Pairs Multi-Field Transforms for Efficient Frame Interpolation},
168
+ author={Li, Zhen and Zhu, Zuo-Liang and Han, Ling-Hao and Hou, Qibin and Guo, Chun-Le and Cheng, Ming-Ming},
169
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
170
+ year={2023}
171
+ }
172
+ ```
173
+
174
+ ### ST-MFNet
175
+ ```bibtex
176
+ @InProceedings{Danier_2022_CVPR,
177
+ author = {Danier, Duolikun and Zhang, Fan and Bull, David},
178
+ title = {ST-MFNet: A Spatio-Temporal Multi-Flow Network for Frame Interpolation},
179
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
180
+ month = {June},
181
+ year = {2022},
182
+ pages = {3521-3531}
183
+ }
184
+ ```
185
+
186
+ ### FLAVR
187
+ ```bibtex
188
+ @article{kalluri2021flavr,
189
+ title={FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation},
190
+ author={Kalluri, Tarun and Pathak, Deepak and Chandraker, Manmohan and Tran, Du},
191
+ booktitle={arxiv},
192
+ year={2021}
193
+ }
194
+ ```
__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
4
+
5
+ from .other_nodes import Gradually_More_Denoise_KSampler
6
+
7
+ #Some models are commented out because the code is not completed
8
+ #from vfi_models.eisai import EISAI_VFI
9
+ from vfi_models.gmfss_fortuna import GMFSS_Fortuna_VFI
10
+ from vfi_models.ifrnet import IFRNet_VFI
11
+ from vfi_models.ifunet import IFUnet_VFI
12
+ from vfi_models.m2m import M2M_VFI
13
+ from vfi_models.rife import RIFE_VFI
14
+ from vfi_models.sepconv import SepconvVFI
15
+ from vfi_models.amt import AMT_VFI
16
+ from vfi_models.film import FILM_VFI
17
+ from vfi_models.stmfnet import STMFNet_VFI
18
+ from vfi_models.flavr import FLAVR_VFI
19
+ from vfi_models.cain import CAIN_VFI
20
+ from vfi_utils import MakeInterpolationStateList, FloatToInt
21
+
22
+ NODE_CLASS_MAPPINGS = {
23
+ "KSampler Gradually Adding More Denoise (efficient)": Gradually_More_Denoise_KSampler,
24
+ # "EISAI VFI": EISAI_VFI,
25
+ "GMFSS Fortuna VFI": GMFSS_Fortuna_VFI,
26
+ "IFRNet VFI": IFRNet_VFI,
27
+ "IFUnet VFI": IFUnet_VFI,
28
+ "M2M VFI": M2M_VFI,
29
+ "RIFE VFI": RIFE_VFI,
30
+ "Sepconv VFI": SepconvVFI,
31
+ "AMT VFI": AMT_VFI,
32
+ "FILM VFI": FILM_VFI,
33
+ "Make Interpolation State List": MakeInterpolationStateList,
34
+ "STMFNet VFI": STMFNet_VFI,
35
+ "FLAVR VFI": FLAVR_VFI,
36
+ "CAIN VFI": CAIN_VFI,
37
+ "VFI FloatToInt": FloatToInt
38
+ }
39
+
40
+ NODE_DISPLAY_NAME_MAPPINGS = {
41
+ "RIFE VFI": "RIFE VFI (recommend rife47 and rife49)"
42
+ }
config.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Please don't delete this file; just edit it when necessary.
2
+ ckpts_path: "./ckpts"
3
+ ops_backend: "cupy" #Either "taichi" or "cupy"
demo_frames/anime0.png ADDED

Git LFS Details

  • SHA256: 734039ac77a89cf8d52fed8989bd4335392a1d246b099979d1c58a145c629ace
  • Pointer size: 131 Bytes
  • Size of remote file: 341 kB
demo_frames/anime1.png ADDED

Git LFS Details

  • SHA256: dd24bdafe9a0cfc82eada33c40962e9977ed5b6711ae6d89bf28b07cbded712a
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB
demo_frames/bocchi0.jpg ADDED

Git LFS Details

  • SHA256: c607fae213b83d4c15fa10d6939b612f7f2242afd0b8716b203ace51774f6718
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
demo_frames/bocchi1.jpg ADDED

Git LFS Details

  • SHA256: f03f067142490d4353d3f5af8bd51b0f9f4bdd3d2094dde6a28f4fec062fbe16
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
demo_frames/real0.png ADDED

Git LFS Details

  • SHA256: 4792023ccf17c8231c6eb5ee40de528d515e2f8c419b3949985411a122a4de4f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
demo_frames/real1.png ADDED

Git LFS Details

  • SHA256: 37c8e6ec527c81895e5a66ea49cdd18b85045f9fed6fdfb75b45f438649235bf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
demo_frames/rick/00003.png ADDED

Git LFS Details

  • SHA256: 98f5dba7557ba55d13f494425d340ca84af8b56e35f929fab5df39e54015e265
  • Pointer size: 131 Bytes
  • Size of remote file: 456 kB
demo_frames/rick/00004.png ADDED

Git LFS Details

  • SHA256: 61bcf7933b192d84870b80910f7f983371c642d5c7100b34e8cc6dbd01cba7e6
  • Pointer size: 131 Bytes
  • Size of remote file: 355 kB
demo_frames/rick/00005.png ADDED

Git LFS Details

  • SHA256: f795d06e93ad4f9c19db578e9378a48b6008cc3df81fb2cd9fbbd5ed91bd8cf7
  • Pointer size: 131 Bytes
  • Size of remote file: 357 kB
demo_frames/violet0.png ADDED

Git LFS Details

  • SHA256: c6844899b551801ee22d4f57993ab66fd4b6fbe00eab916d6b987bdf083eadfe
  • Pointer size: 131 Bytes
  • Size of remote file: 889 kB
demo_frames/violet1.png ADDED

Git LFS Details

  • SHA256: 66ee9a9a486f57eb80ba5d41140eaca4ca46f0d946a3cff93eabb0ee3b1e29d0
  • Pointer size: 131 Bytes
  • Size of remote file: 951 kB
example.png ADDED

Git LFS Details

  • SHA256: 9a5e9310bfba63b109990b326402d42477688682858bc64f146ef546e6662ead
  • Pointer size: 131 Bytes
  • Size of remote file: 182 kB
install-taichi.bat ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ echo Installing Taichi lang backend...
3
+
4
+ if exist "%python_exec%" (
5
+ %python_exec% -s -m pip install taichi
6
+ ) else (
7
+ echo Installing with system Python
8
+ pip install taichi
9
+ )
10
+
11
+ pause
install.bat ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ set "requirements_txt=%~dp0\requirements-no-cupy.txt"
4
+ set "python_exec=..\..\..\python_embeded\python.exe"
5
+
6
+ echo Installing ComfyUI Frame Interpolation..
7
+
8
+ if exist "%python_exec%" (
9
+ echo Installing with ComfyUI Portable
10
+ %python_exec% -s install.py
11
+ ) else (
12
+ echo Installing with system Python
13
+ python install.py
14
+ )
15
+
16
+ pause
install.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ import sys
4
+ import platform
5
+
6
def get_cuda_ver_from_dir(cuda_home):
    """Guess the CUDA toolkit version from the nvrtc-builtins library in *cuda_home*.

    Returns a cupy wheel version suffix ('102', '110', '111', '11x', '12x'),
    or None when the directory is unusable or no version can be determined
    (the caller then falls back to the generic 'cupy-wheel' package).
    """
    # Guard: os.listdir(None) silently lists the current working directory,
    # and a nonexistent path would raise — either way we can't detect CUDA.
    if cuda_home is None or not os.path.isdir(cuda_home):
        return None
    nvrtc = [lib_file for lib_file in os.listdir(cuda_home) if "nvrtc-builtins" in lib_file]
    if len(nvrtc) == 0:
        return None
    name = nvrtc[0]
    # Most specific versions must be checked first: the bare '11' test below
    # would also match '110'/'111' file names.
    if ('102' in name) or ('10.2' in name):
        return '102'
    if ('110' in name) or ('11.0' in name):
        return '110'
    if ('111' in name) or ('11.1' in name):
        return '111'
    if '11' in name:
        return '11x'
    if '12' in name:
        return '12x'
    return None
22
+
23
+ s_param = '-s' if "python_embeded" in sys.executable else ''
24
+
25
def get_cuda_home_path():
    """Locate a usable CUDA runtime directory.

    An explicit CUDA_HOME environment variable always wins. Otherwise fall
    back to the lib directory bundled with the installed torch package,
    provided it actually ships an nvrtc-builtins library; return None when
    neither source is usable.
    """
    env_home = os.environ.get("CUDA_HOME")
    if env_home is not None:
        return env_home
    import torch
    lib_dir = str((Path(torch.__file__).parent / "lib").resolve())
    if not os.path.exists(lib_dir):
        return None
    has_nvrtc = any("nvrtc-builtins" in name for name in os.listdir(lib_dir))
    return lib_dir if has_nvrtc else None
35
+
36
def install_cupy():
    """Ensure a CuPy build matching the local CUDA runtime is installed.

    If CuPy already imports cleanly, nothing is reinstalled. Otherwise any
    stale cupy wheels are removed and the wheel matching the detected CUDA
    version — or the source 'cupy-wheel' fallback — is installed via pip.
    """
    cuda_home = get_cuda_home_path()
    # Export for cupy's build/runtime detection (harmless if already set).
    if cuda_home is not None:
        os.environ["CUDA_HOME"] = cuda_home
        os.environ["CUDA_PATH"] = cuda_home
    try:
        import cupy  # noqa: F401
        print("CuPy is already installed.")
    except ImportError:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # unrelated errors; only a failed import should trigger a reinstall.
        print("Uninstall cupy if existed...")
        os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x')
        print("Installing cupy...")
        cuda_ver = get_cuda_ver_from_dir(cuda_home)
        cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel"
        os.system(f'"{sys.executable}" {s_param} -m pip install {cupy_package}')
51
+
52
+ with open(Path(__file__).parent / "requirements-no-cupy.txt", 'r') as f:
53
+ for package in f.readlines():
54
+ package = package.strip()
55
+ print(f"Installing {package}...")
56
+ os.system(f'"{sys.executable}" {s_param} -m pip install {package}')
57
+
58
+ print("Checking cupy...")
59
+ install_cupy()
interpolation_schedule.png ADDED

Git LFS Details

  • SHA256: c6999ee4a5fd6222b7b05adb8afa4994053bfe8e0f9c6b5cccf25992638b586c
  • Pointer size: 131 Bytes
  • Size of remote file: 378 kB
other_nodes.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import latent_preview
2
+ import comfy
3
+ import einops
4
+ import torch
5
+
6
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
7
+ device = comfy.model_management.get_torch_device()
8
+ latent_image = latent["samples"]
9
+
10
+ if disable_noise:
11
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
12
+ else:
13
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
14
+ noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
15
+
16
+ noise_mask = None
17
+ if "noise_mask" in latent:
18
+ noise_mask = latent["noise_mask"]
19
+
20
+ preview_format = "JPEG"
21
+ if preview_format not in ["JPEG", "PNG"]:
22
+ preview_format = "JPEG"
23
+
24
+ previewer = latent_preview.get_previewer(device, model.model.latent_format)
25
+
26
+ pbar = comfy.utils.ProgressBar(steps)
27
+ def callback(step, x0, x, total_steps):
28
+ preview_bytes = None
29
+ if previewer:
30
+ preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
31
+ pbar.update_absolute(step + 1, total_steps, preview_bytes)
32
+
33
+ samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
34
+ denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
35
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
36
+ out = latent.copy()
37
+ out["samples"] = samples
38
+ return (out, )
39
+
40
class Gradually_More_Denoise_KSampler:
    """Sample each input latent repeatedly with linearly increasing denoise.

    For every latent in the input batch, runs `denoise_increment_steps`
    sampling passes where pass i uses
    `denoise = start_denoise + denoise_increment * i`, then concatenates all
    results into one latent batch (useful for denoise-schedule videos).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),

                     "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
                     "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                     "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),

                     "start_denoise": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     "denoise_increment": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.1}),
                     "denoise_increment_steps": ("INT", {"default": 20, "min": 1, "max": 10000})
                     },
                "optional": {"optional_vae": ("VAE",)}
                }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT", "VAE", )
    RETURN_NAMES = ("MODEL", "CONDITIONING+", "CONDITIONING-", "LATENT", "VAE", )
    OUTPUT_NODE = True
    FUNCTION = "sample"
    CATEGORY = "ComfyUI-Frame-Interpolation/others"

    def sample(self, model, positive, negative, latent_image, optional_vae,
               seed, steps, cfg, sampler_name, scheduler, start_denoise, denoise_increment, denoise_increment_steps):
        """Run the denoise schedule; returns (model, positive, negative, latent, vae).

        Raises ValueError when the strongest scheduled denoise would exceed 1.0.
        """
        # Pass i uses start_denoise + denoise_increment * i for
        # i in range(denoise_increment_steps), so the maximum strength applied
        # is start + increment * (steps - 1). The previous check over-counted
        # by one increment and rejected valid schedules (e.g. 0.1 + 0.1 * 9).
        max_denoise = start_denoise + denoise_increment * (denoise_increment_steps - 1)
        if max_denoise > 1.0:
            raise ValueError(
                f"Max denoise strength can't be over 1.0 (start_denoise={start_denoise}, "
                f"denoise_increment={denoise_increment}, denoise_increment_steps={denoise_increment_steps})")

        copied_latent = latent_image.copy()
        out_samples = []

        for latent_sample in copied_latent["samples"]:
            # Latent's shape is NCHW; re-add the batch dim for the sampler.
            latent = {"samples": einops.rearrange(latent_sample, "c h w -> 1 c h w")}
            gradually_denoising_samples = [
                common_ksampler(
                    model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent,
                    denoise=start_denoise + denoise_increment * i
                )[0]["samples"]
                for i in range(denoise_increment_steps)
            ]
            out_samples.extend(gradually_denoising_samples)

        copied_latent["samples"] = torch.cat(out_samples, dim=0)
        return (model, positive, negative, copied_latent, optional_vae)
pyproject.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "comfyui-frame-interpolation"
3
+ description = "A custom node suite for Video Frame Interpolation in ComfyUI"
4
+ version = "1.0.7"
5
+ license = { file = "LICENSE" }
6
+
7
+ [project.urls]
8
+ Repository = "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation"
9
+
10
+ [tool.comfy]
11
+ PublisherId = "fannovel16"
12
+ DisplayName = "ComfyUI-Frame-Interpolation"
13
+ Icon = ""
requirements-no-cupy.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ einops
4
+ opencv-contrib-python
5
+ kornia
6
+ scipy
7
+ Pillow
8
+ torchvision
9
+ tqdm
requirements-with-cupy.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ einops
4
+ opencv-contrib-python
5
+ kornia
6
+ scipy
7
+ Pillow
8
+ torchvision
9
+ tqdm
10
+ cupy-wheel
test.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
4
+
5
+ import shutil
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import PIL
9
+ import torchvision.transforms.functional as transform
10
+ from vfi_utils import load_file_from_github_release
11
+ from vfi_models import gmfss_fortuna, ifrnet, ifunet, m2m, rife, sepconv, amt, xvfi, cain, flavr
12
+ import numpy as np
13
+
14
+ frame_0 = torch.from_numpy(np.array(PIL.Image.open("demo_frames/anime0.png").convert("RGB")).astype(np.float32) / 255.0).unsqueeze(0)
15
+ frame_1 = torch.from_numpy(np.array(PIL.Image.open("demo_frames/anime1.png").convert("RGB")).astype(np.float32) / 255.0).unsqueeze(0)
16
+
17
+
18
+ if os.path.exists("test_result"):
19
+ shutil.rmtree("test_result")
20
+
21
+ vfi_node_class = gmfss_fortuna.GMFSS_Fortuna_VFI()
22
+ for i, ckpt_name in enumerate(vfi_node_class.INPUT_TYPES()["required"]["ckpt_name"][0][:2]):
23
+ result = vfi_node_class.vfi(ckpt_name, torch.cat([
24
+ frame_0,
25
+ frame_1,
26
+ frame_0,
27
+ frame_1
28
+ ], dim=0).cuda(), multipler=4, batch_size=2)[0]
29
+ print(result.shape)
30
+ print(f"Generated {result.size(0)} frames")
31
+ frames = [PIL.Image.fromarray(np.clip((frame * 255).numpy(), 0, 255).astype(np.uint8)) for frame in result]
32
+ print(result[0].shape)
33
+ os.makedirs(f"test_result/video{i}", exist_ok=True)
34
+ for j, frame in enumerate(frames):
35
+ frame.save(f"test_result/video{i}/{j}.jpg")
36
+ frames[0].save(f"test_result/video{i}.gif", save_all=True, append_images=frames[1:], optimize=True, duration=1/3, loop=0)
37
+ os.startfile(f"test_result{os.path.sep}video{i}.gif")
38
+ #torchvision.io.video.write_video("test.mp4", einops.rearrange(result, "n c h w -> n h w c").cpu(), fps=1)
test_vfi_schedule.gif ADDED

Git LFS Details

  • SHA256: 931fcd4c2cc84b457cbc1b1c3b8745a2bf292ff7dc43d4f733a2c510ad90353d
  • Pointer size: 132 Bytes
  • Size of remote file: 8.41 MB
vfi_models/amt/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ import pathlib
5
+ from vfi_utils import load_file_from_direct_url, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
6
+ import typing
7
+ from comfy.model_management import get_torch_device
8
+ from .amt_arch import AMT_S, AMT_L, AMT_G, InputPadder
9
+
10
+ #https://github.com/MCG-NKU/AMT/tree/main/cfgs
11
+ CKPT_CONFIGS = {
12
+ "amt-s.pth": {
13
+ "network": AMT_S,
14
+ "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 3 }
15
+ },
16
+ "amt-l.pth": {
17
+ "network": AMT_L,
18
+ "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 5 }
19
+ },
20
+ "amt-g.pth": {
21
+ "network": AMT_G,
22
+ "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 5 }
23
+ },
24
+ "gopro_amt-s.pth": {
25
+ "network": AMT_S,
26
+ "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 3 }
27
+ }
28
+ }
29
+
30
+
31
+ MODEL_TYPE = pathlib.Path(__file__).parent.name
32
+
33
class AMT_VFI:
    """ComfyUI node wrapping the AMT family of frame-interpolation networks."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (list(CKPT_CONFIGS.keys()), ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 1, "min": 1, "max": 100}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames` by `multiplier` using the selected AMT checkpoint."""
        # Fetch (or reuse a cached) checkpoint and build the matching network.
        ckpt_path = load_file_from_direct_url(MODEL_TYPE, f"https://huggingface.co/lalala125/AMT/resolve/main/{ckpt_name}")
        config = CKPT_CONFIGS[ckpt_name]
        net = config["network"](**config["params"])
        net.load_state_dict(torch.load(ckpt_path)["state_dict"])
        net.eval().to(get_torch_device())

        # Pad spatial dims to a multiple of 16 before inference, unpad after.
        frames = preprocess_frames(frames)
        padder = InputPadder(frames.shape, 16)
        frames = padder.pad(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model):
            # Per-sample timestep embedding of shape (N, 1, 1, 1).
            batch = frame_0.shape[0]
            embt = torch.FloatTensor([timestep] * batch).view(batch, 1, 1, 1).to(get_torch_device())
            return model(frame_0, frame_1, embt=embt, scale_factor=1.0, eval=True)["imgt_pred"]

        out = generic_frame_loop(
            type(self).__name__, frames, clear_cache_after_n_frames, multiplier,
            return_middle_frame, net,
            interpolation_states=optional_interpolation_states, dtype=torch.float32)
        return (postprocess_frames(padder.unpad(out)),)
87
+
vfi_models/amt/amt_arch.py ADDED
@@ -0,0 +1,1590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/MCG-NKU/AMT/blob/main/utils/dist_utils.py
3
+ https://github.com/MCG-NKU/AMT/blob/main/utils/flow_utils.py
4
+ https://github.com/MCG-NKU/AMT/blob/main/utils/utils.py
5
+ https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/feat_enc.py
6
+ https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/ifrnet.py
7
+ https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/multi_flow.py
8
+ https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/raft.py
9
+ https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-S.py
10
+ https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-L.py
11
+ https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-G.py
12
+ """
13
+ # Removed the imageio dependency by removing readImage and writeImage.
14
+ # The model receives image tensors from other ComfyUI nodes, so those helpers are unnecessary.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import numpy as np
19
+ from PIL import ImageFile
20
+ import torch.nn.functional as F
21
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
22
+ import re
23
+ import sys
24
+ import random
25
+
26
def warp(img, flow):
    """Backward-warp `img` with a dense optical-flow field (in pixels) via grid_sample."""
    n, _, h, w = flow.shape
    # Base sampling grid in grid_sample's normalized [-1, 1] coordinates.
    xs = torch.linspace(-1.0, 1.0, w).view(1, 1, 1, w).expand(n, -1, h, -1)
    ys = torch.linspace(-1.0, 1.0, h).view(1, 1, h, 1).expand(n, -1, -1, w)
    base_grid = torch.cat([xs, ys], 1).to(img)
    # Convert the pixel-space flow into the same normalized space.
    norm_flow = torch.cat(
        [flow[:, 0:1, :, :] / ((w - 1.0) / 2.0),
         flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], 1)
    sample_grid = (base_grid + norm_flow).permute(0, 2, 3, 1)
    return F.grid_sample(input=img, grid=sample_grid, mode='bilinear',
                         padding_mode='border', align_corners=True)
35
+
36
+
37
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
    Code follows the original C++ source code of Daniel Scharstein
    and the Matlab source code of Deqing Sun.
    Returns:
        np.ndarray: (55, 3) color wheel
    """
    # Segment lengths between the six anchor hues.
    segments = [('RY', 15), ('YG', 6), ('GC', 4), ('CB', 11), ('BM', 13), ('MR', 6)]
    total = sum(n for _, n in segments)
    wheel = np.zeros((total, 3))

    col = 0
    for name, n in segments:
        ramp = np.floor(255 * np.arange(n) / n)
        if name == 'RY':
            wheel[col:col + n, 0] = 255
            wheel[col:col + n, 1] = ramp
        elif name == 'YG':
            wheel[col:col + n, 0] = 255 - ramp
            wheel[col:col + n, 1] = 255
        elif name == 'GC':
            wheel[col:col + n, 1] = 255
            wheel[col:col + n, 2] = ramp
        elif name == 'CB':
            wheel[col:col + n, 1] = 255 - ramp
            wheel[col:col + n, 2] = 255
        elif name == 'BM':
            wheel[col:col + n, 2] = 255
            wheel[col:col + n, 0] = ramp
        else:  # 'MR'
            wheel[col:col + n, 2] = 255 - ramp
            wheel[col:col + n, 0] = 255
        col += n
    return wheel
83
+
84
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.
    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    wheel = make_colorwheel()  # (55, 3)
    ncols = wheel.shape[0]
    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi
    # Fractional position of each flow angle around the wheel.
    fk = (angle + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    frac = fk - k0

    out = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    for ch in range(wheel.shape[1]):
        column = wheel[:, ch]
        blended = (1 - frac) * (column[k0] / 255.0) + frac * (column[k1] / 255.0)
        in_range = magnitude <= 1
        # Saturation scales with magnitude inside the unit circle.
        blended[in_range] = 1 - magnitude[in_range] * (1 - blended[in_range])
        blended[~in_range] = blended[~in_range] * 0.75  # out of range
        # BGR output simply reverses the channel index.
        out[:, :, 2 - ch if convert_to_bgr else ch] = np.floor(255 * blended)
    return out
118
+
119
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Render a two-channel flow field as a color image.
    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u, v = flow_uv[:, :, 0], flow_uv[:, :, 1]
    # Normalize by the largest magnitude so colors span the whole wheel.
    largest = np.max(np.sqrt(np.square(u) + np.square(v)))
    eps = 1e-5
    return flow_uv_to_colors(u / (largest + eps), v / (largest + eps), convert_to_bgr)
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
class AverageMeter():
    """Tracks the latest value plus a running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
167
+
168
+
169
class AverageMeterGroups:
    """A name-keyed collection of AverageMeter instances."""

    def __init__(self) -> None:
        self.meter_dict = dict()

    def update(self, dict, n=1):
        """Fold each name -> value pair into its meter, creating meters lazily."""
        for name, val in dict.items():
            meter = self.meter_dict.get(name)
            if meter is None:
                meter = AverageMeter()
                self.meter_dict[name] = meter
            meter.update(val, n)

    def reset(self, name=None):
        """Reset one named meter, or every meter when `name` is None."""
        if name is not None:
            meter = self.meter_dict.get(name)
            if meter is not None:
                meter.reset()
        else:
            for meter in self.meter_dict.values():
                meter.reset()

    def avg(self, name):
        """Return the running average for `name`, or None if unknown."""
        meter = self.meter_dict.get(name)
        return meter.avg if meter is not None else None
192
+
193
+
194
class InputPadder:
    """ Pads images such that dimensions are divisible by divisor """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Amount needed to round each spatial dim up to a multiple of divisor.
        extra_h = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        extra_w = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # Split between both sides: [left, right, top, bottom] (F.pad order).
        self._pad = [extra_w // 2, extra_w - extra_w // 2,
                     extra_h // 2, extra_h - extra_h // 2]

    def pad(self, input_tensor):
        """Replicate-pad the tensor up to the divisible size."""
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        """Crop a padded tensor back to the original spatial size."""
        return self._unpad(input_tensor)

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
212
+
213
+
214
def img2tensor(img):
    """Convert an HWC image (extra channels beyond 3 dropped) into a 1xCxHxW tensor scaled by 1/255."""
    rgb = img[:, :, :3] if img.shape[-1] > 3 else img
    return torch.tensor(rgb).permute(2, 0, 1).unsqueeze(0) / 255.0
218
+
219
+
220
def tensor2img(img_t):
    """Convert a 1xCxHxW tensor in [0, 1] back to an HWC uint8 array."""
    arr = img_t.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
    return (arr * 255.).clip(0, 255).astype(np.uint8)
224
+
225
def seed_all(seed):
    """Seed Python, NumPy, and all torch (incl. CUDA) RNGs for reproducibility."""
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
230
+
231
+
232
def readPFM(file):
    """
    Read a PFM image.
    Args:
        file: path to a .pfm file.
    Returns:
        (np.ndarray, float): pixel data (HxWx3 for color 'PF', HxW for 'Pf'),
        rows flipped to top-down order, plus the header scale factor.
    Raises:
        Exception: on a bad magic number or malformed header.
    """
    # `with` ensures the handle is closed even on parse errors (the original
    # leaked the file descriptor).
    with open(file, 'rb') as f:
        header = f.readline().rstrip()
        if header.decode("ascii") == 'PF':
            color = True
        elif header.decode("ascii") == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')

        scale = float(f.readline().decode("ascii").rstrip())
        if scale < 0:
            # A negative scale marks little-endian data.
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)  # PFM stores rows bottom-up
    return data, scale
268
+
269
+
270
def writePFM(file, image, scale=1):
    """
    Write a float32 image in PFM format.
    Args:
        file: destination path.
        image (np.ndarray): HxWx3 (color), HxWx1 or HxW (grayscale) float32 data.
        scale (float): PFM scale; its sign is flipped to mark little-endian data.
    Raises:
        Exception: on wrong dtype or unsupported shape.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    image = np.flipud(image)  # PFM stores rows bottom-up

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    with open(file, 'wb') as f:
        # Fixes a precedence bug in the original ("'PF\n' if color else
        # 'Pf\n'.encode()") which wrote an unencoded str for color images.
        f.write(('PF\n' if color else 'Pf\n').encode())
        f.write(b'%d %d\n' % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder
        # '=' means native byte order; treat native little-endian as '<'.
        if endian == '<' or endian == '=' and sys.byteorder == 'little':
            scale = -scale
        f.write(b'%f\n' % scale)

        image.tofile(f)
298
+
299
+
300
def readFlow(name):
    """
    Read an optical-flow file: Middlebury .flo ('PIEH' magic) or .pfm.
    Returns:
        np.ndarray: HxWx2 float32 flow field.
    Raises:
        Exception: if the .flo magic header is missing.
    """
    if name.endswith(('.pfm', '.PFM')):
        # PFM-stored flow keeps the two flow channels in the first two planes.
        return readPFM(name)[0][:, :, 0:2]

    # `with` closes the handle on all paths (the original leaked it).
    with open(name, 'rb') as f:
        header = f.read(4)
        if header.decode("utf-8") != 'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        width = np.fromfile(f, np.int32, 1).squeeze()
        height = np.fromfile(f, np.int32, 1).squeeze()
        flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))

    return flow.astype(np.float32)
316
+
317
def writeFlow(name, flow):
    """Write an HxWx2 flow field in Middlebury .flo format ('PIEH' magic, width, height, float32 data)."""
    # `with` flushes and closes the handle (the original leaked it).
    with open(name, 'wb') as f:
        f.write('PIEH'.encode('utf-8'))
        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)
323
+
324
+
325
def readFloat(name):
    """
    Read a float-file: b'float\\n' keyword, a dimension count, one size per
    line, then raw float32 data.
    Returns:
        np.ndarray: the data; >2-D data is transposed back to H x W x C order.
    Raises:
        Exception: if the 'float' keyword header is missing.
    """
    # `with` closes the handle on all paths (the original leaked it).
    with open(name, 'rb') as f:
        if (f.readline().decode("utf-8")) != 'float\n':
            raise Exception('float file %s did not contain <float> keyword' % name)

        dim = int(f.readline())

        dims = []
        count = 1
        for _ in range(dim):
            d = int(f.readline())
            dims.append(d)
            count *= d

        # Sizes are stored innermost-first; reverse for numpy's row-major order.
        dims = list(reversed(dims))
        data = np.fromfile(f, np.float32, count).reshape(dims)

    if dim > 2:
        # Stored channel-first; undo to H x W x C.
        data = np.transpose(data, (2, 1, 0))
        data = np.transpose(data, (1, 0, 2))
    return data
348
+
349
+
350
def writeFloat(name, data):
    """
    Write an array in the float-file format read by readFloat:
    b'float\\n' keyword, dimension count, sizes, then raw float32 data
    (3-D data is stored channel-first).
    Raises:
        Exception: if data has more than 3 dimensions.
    """
    dim = len(data.shape)
    if dim > 3:
        raise Exception('bad float file dimension: %d' % dim)

    # `with` closes the handle (the original leaked it).
    with open(name, 'wb') as f:
        f.write(('float\n').encode('ascii'))
        f.write(('%d\n' % dim).encode('ascii'))

        if dim == 1:
            f.write(('%d\n' % data.shape[0]).encode('ascii'))
        else:
            f.write(('%d\n' % data.shape[1]).encode('ascii'))
            f.write(('%d\n' % data.shape[0]).encode('ascii'))
            for i in range(2, dim):
                f.write(('%d\n' % data.shape[i]).encode('ascii'))

        data = data.astype(np.float32)
        if dim < 3:
            # Fixes the original, which sent 1-D data through
            # np.transpose(data, (2, 0, 1)) and crashed.
            data.tofile(f)
        else:
            np.transpose(data, (2, 0, 1)).tofile(f)
374
+
375
+
376
def check_dim_and_resize(tensor_list):
    """Ensure all frame tensors share one spatial size; mismatches are resized to the first frame's size."""
    shapes = [t.shape[2:] for t in tensor_list]
    if len(set(shapes)) > 1:
        desired_shape = shapes[0]
        print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
        tensor_list = [
            torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')
            for t in tensor_list
        ]
    return tensor_list
392
+
393
+
394
+
395
+
396
+
397
+
398
+
399
+
400
+
401
+
402
+
403
class BottleneckBlock(nn.Module):
    """Residual bottleneck (1x1 reduce -> 3x3 -> 1x1 expand) with a configurable norm layer."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm(channels):
            # One normalization layer of the requested flavor ('none' -> identity).
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(channels)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(channels)
            return nn.Sequential()

        self.norm1 = make_norm(planes//4)
        self.norm2 = make_norm(planes//4)
        self.norm3 = make_norm(planes)

        if stride == 1:
            self.downsample = None
        else:
            # Strided 1x1 shortcut so the skip connection matches the main path.
            self.norm4 = make_norm(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
460
+
461
+
462
class ResidualBlock(nn.Module):
    """Two 3x3 convs with a residual shortcut; norm layer is configurable."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm():
            # One normalization layer of the requested flavor ('none' -> identity).
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(planes)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(planes)
            return nn.Sequential()

        self.norm1 = make_norm()
        self.norm2 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            # Strided 1x1 shortcut so the skip connection matches the main path.
            self.norm3 = make_norm()
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
513
+
514
+
515
class SmallEncoder(nn.Module):
    """Small feature encoder: 7x7 stem conv + three bottleneck stages; output at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs, unit-weight/zero-bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two stacked bottleneck blocks; the first may downsample."""
        blocks = (
            BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            BottleneckBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A tuple/list input is concatenated along batch and split again at the end.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
588
+
589
class BasicEncoder(nn.Module):
    """Feature encoder: 7x7 stem conv + three residual stages; output at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(72, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convs, unit-weight/zero-bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two stacked residual blocks; the first may downsample."""
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A tuple/list input is concatenated along batch and split again at the end.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3(self.layer2(self.layer1(x))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
664
+
665
class LargeEncoder(nn.Module):
    """Larger feature encoder: stem conv + three residual stages (last one repeated); output at 1/8 resolution."""

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(LargeEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(112, stride=2)
        self.layer3 = self._make_layer(160, stride=2)
        self.layer3_2 = self._make_layer(160, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convs, unit-weight/zero-bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two stacked residual blocks; the first may downsample."""
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A tuple/list input is concatenated along batch and split again at the end.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.conv2(self.layer3_2(self.layer3(self.layer2(self.layer1(x)))))

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)
        return x
742
+
743
+
744
+
745
+
746
+
747
+
748
+
749
+
750
+
751
+
752
+
753
def resize(x, scale_factor):
    """Bilinearly rescale an NCHW tensor by `scale_factor` (align_corners=False)."""
    return F.interpolate(input=x, scale_factor=scale_factor,
                         mode="bilinear", align_corners=False)
755
+
756
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    """Conv2d followed by a channel-wise PReLU, packaged as one Sequential."""
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                     padding, dilation, groups, bias=bias)
    return nn.Sequential(conv, nn.PReLU(out_channels))
761
+
762
class ResBlock(nn.Module):
    """Residual block that routes the trailing `side_channels` channels through extra convs before merging."""

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels

        def full_conv():
            # 3x3 conv + PReLU over all channels.
            return nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
                nn.PReLU(in_channels)
            )

        def side_conv():
            # 3x3 conv + PReLU over the side channels only.
            return nn.Sequential(
                nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
                nn.PReLU(side_channels)
            )

        self.conv1 = full_conv()
        self.conv2 = side_conv()
        self.conv3 = full_conv()
        self.conv4 = side_conv()
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        # Refine only the trailing side channels, then re-merge with the rest.
        main, side = out[:, :-self.side_channels], out[:, -self.side_channels:]
        out = self.conv3(torch.cat([main, self.conv2(side)], 1))

        main, side = out[:, :-self.side_channels], out[:, -self.side_channels:]
        out = self.conv5(torch.cat([main, self.conv4(side)], 1))

        return self.prelu(x + out)
800
+
801
class Encoder(nn.Module):
    """Pyramid feature extractor: one stride-2 stage per entry in `channels`."""

    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
        self.channels = channels
        prev_ch = 3
        for idx, ch in enumerate(channels, 1):
            # The 'large' variant uses a 7x7 kernel for the first stage.
            kernel = 7 if large and idx == 1 else 3
            pad = 3 if kernel == 7 else 1
            stage = nn.Sequential(
                convrelu(prev_ch, ch, kernel, 2, pad),
                convrelu(ch, ch, 3, 1, 1)
            )
            self.register_module(f'pyramid{idx}', stage)
            prev_ch = ch

    def forward(self, in_x):
        feats = []
        for idx in range(1, len(self.channels) + 1):
            in_x = getattr(self, f'pyramid{idx}')(in_x)
            feats.append(in_x)
        return feats
823
+
824
class InitDecoder(nn.Module):
    """Coarsest-level decoder: predicts initial bidirectional flows and an intermediate feature map."""

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*2+1, in_ch*2),
            ResBlock(in_ch*2, skip_ch),
            nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        # Broadcast the scalar time embedding over the spatial grid.
        timemap = embt.repeat(1, 1, h, w)
        out = self.convblock(torch.cat([f0, f1, timemap], 1))
        # First 4 channels: two 2-channel flows; the rest: features.
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        return flow0, flow1, out[:, 4:, ...]
839
+
840
class IntermediateDecoder(nn.Module):
    """Mid-level decoder: refines and upsamples flows using backward-warped pyramid features."""

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        # Align both side features to the intermediate frame before fusing.
        warped0 = warp(f0, flow0_in)
        warped1 = warp(f1, flow1_in)
        out = self.convblock(torch.cat([ft_, warped0, warped1, flow0_in, flow1_in], 1))
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        # Residual flow update; factor 2 because the spatial resolution doubled.
        flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
        flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
        return flow0, flow1, ft_
858
+
859
+
860
+
861
+
862
+
863
+
864
+
865
+
866
+
867
+
868
+
869
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
    A parallel implementation of multiple flow field warping
    comb_block: An nn.Seqential object.
    img shape: [b, c, h, w]
    flow shape: [b, 2*num_flows, h, w]
    mask (opt):
        If 'mask' is None, the function conduct a simple average.
    img_res (opt):
        If 'img_res' is None, the function adds zero instead.
    mean (opt):
        If 'mean' is None, the function adds zero instead.
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the per-flow dimension into the batch so all flows warp in one call.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    mask = mask.reshape(b, num_flows, 1, h, w
                        ).reshape(-1, 1, h, w) if mask is not None else None
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    # Duplicate each input image once per flow hypothesis.
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    # Mask-blend the two warps, then add back the mean and the image residual.
    # NOTE(review): when mask is None this line raises (None * tensor) despite
    # the docstring's "simple average" claim -- confirm callers always pass one.
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # Average the candidate frames and add a learned combination correction.
    imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
    return imgt_pred
903
+
904
+
905
class MultiFlowDecoder(nn.Module):
    """Finest-level decoder: predicts `num_flows` flow pairs plus per-flow masks and image residuals."""

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(MultiFlowDecoder, self).__init__()
        self.num_flows = num_flows
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0, flow1):
        n = self.num_flows
        warped0 = warp(f0, flow0)
        warped1 = warp(f1, flow1)
        out = self.convblock(torch.cat([ft_, warped0, warped1, flow0, flow1], 1))
        # Channel layout: 2n + 2n flow deltas, n occlusion masks, 3n image residuals.
        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
        mask = torch.sigmoid(mask)

        # Upsampled incoming flow (x2 for the doubled resolution), one copy per flow.
        up0 = 2.0 * resize(flow0, scale_factor=2.0).repeat(1, self.num_flows, 1, 1)
        up1 = 2.0 * resize(flow1, scale_factor=2.0).repeat(1, self.num_flows, 1, 1)
        return delta_flow0 + up0, delta_flow1 + up1, mask, img_res
929
+
930
+
931
+
932
+
933
+
934
+
935
+
936
+
937
+
938
+
939
+
940
def resize(x, scale_factor):
    """Bilinearly rescale a 4-D tensor by ``scale_factor`` (align_corners=False)."""
    return F.interpolate(
        input=x,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
942
+
943
+
944
def bilinear_sampler(img, coords, mask=False):
    """Wrapper for grid_sample that takes pixel coordinates.

    img: [N, C, H, W]; coords: [..., 2] with (x, y) in pixel units.
    If ``mask`` is truthy, also return a float mask marking coordinates that
    fall strictly inside the image.
    """
    height, width = img.shape[-2:]
    x_px, y_px = coords.split([1, 1], dim=-1)
    # Normalize pixel coordinates to [-1, 1] as expected by grid_sample
    # (align_corners=True: -1 maps to pixel 0, +1 to pixel W-1 / H-1).
    x_norm = 2 * x_px / (width - 1) - 1
    y_norm = 2 * y_px / (height - 1) - 1

    sample_grid = torch.cat([x_norm, y_norm], dim=-1)
    sampled = F.grid_sample(img, sample_grid, align_corners=True)

    if not mask:
        return sampled

    inside = (x_norm > -1) & (x_norm < 1) & (y_norm > -1) & (y_norm < 1)
    return sampled, inside.float()
959
+
960
+
961
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids, shape [batch, 2, ht, wd].

    Channel 0 holds the x (column) coordinate, channel 1 the y (row) coordinate.
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    yy, xx = torch.meshgrid(rows, cols, indexing='ij')
    # Stack (x, y) in that order and broadcast over the batch dimension.
    grid = torch.stack((xx, yy), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
967
+
968
+
969
class SmallUpdateBlock(nn.Module):
    """Correlation-lookup residual update block (small variant, used by AMT-S).

    Given the bidirectional correlation lookup, the current flow pair and the
    context features, predicts residual corrections for both the context
    features and the flows.  When ``scale_factor`` is set, the update runs at
    a reduced resolution and the residuals are resized back up.
    """

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        # Channels of one direction's correlation lookup across all pyramid levels.
        cor_planes = corr_levels * (2 * radius + 1) **2
        self.scale_factor = scale_factor

        # `2 *` because the correlation input stacks both directions (0->1, 1->0).
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)  # 4 = two 2-channel flows
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)

        # Convolutional fusion of motion features, raw flow and context features.
        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        # Head predicting the residual context features.
        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        # Head predicting the residual flows (2 channels per direction).
        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return ``(delta_net, delta_flow)`` residuals.

        net: context features [b, cdim, h, w]; flow: concatenated flow pair
        [b, 4, h, w]; corr: stacked bidirectional correlation lookup.
        """
        # Optionally downsample the context stream to the lookup resolution.
        net = resize(net, 1 / self.scale_factor
                     ) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            # Upsample residuals back; flow magnitudes scale with resolution.
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)

        return delta_net, delta_flow
1020
+
1021
+
1022
class BasicUpdateBlock(nn.Module):
    """Correlation-lookup residual update block (large variant, AMT-L / AMT-G).

    Same role as ``SmallUpdateBlock`` but with a deeper correlation branch
    (two convs) and a configurable number of output flow pairs (``out_num``).
    """

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        # Channels of one direction's correlation lookup across all pyramid levels.
        cor_planes = corr_levels * (2 * radius + 1) **2

        self.scale_factor = scale_factor
        # `2 *` because the correlation input stacks both directions (0->1, 1->0).
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)  # 4 = two 2-channel flows
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)

        # Convolutional fusion of motion features, raw flow and context features.
        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        # Head predicting the residual context features.
        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        # Head predicting residual flows: 4 channels (two flows) per output pair.
        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return ``(delta_net, delta_flow)`` residuals for features and flows."""
        # Optionally downsample the context stream to the lookup resolution.
        net = resize(net, 1 / self.scale_factor
                     ) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        cor = self.lrelu(self.convc2(cor))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            # Upsample residuals back; flow magnitudes scale with resolution.
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
        return delta_net, delta_flow
1074
+
1075
+
1076
class BidirCorrBlock:
    """Bidirectional all-pairs correlation volume with pyramid lookup.

    Precomputes the 4-D correlation between two feature maps and its transpose
    (covering both the 0->1 and 1->0 directions), average-pools both into a
    pyramid, and supports windowed bilinear lookups around coordinate grids.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # The transposed volume encodes the reverse-direction correlation.
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)

        self.corr_pyramid.append(corr)
        self.corr_pyramid_T.append(corr_T)

        # Coarser levels: 2x average pooling over the *target* spatial dims.
        for _ in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        """Look up (2r+1)^2 windows around the coords at every pyramid level.

        coords0/coords1: [batch, 2, h1, w1] pixel coordinates for each
        direction.  Returns two tensors of shape
        [batch, num_levels*(2r+1)^2, h1, w1].
        """
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            # (2r+1)x(2r+1) grid of sampling offsets around each centroid.
            dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)

            # Coordinates shrink by a factor of 2 per pyramid level.
            centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            coords_lvl_0 = centroid_lvl_0 + delta_lvl
            coords_lvl_1 = centroid_lvl_1 + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
            corr = corr.view(batch, h1, w1, -1)
            corr_T = corr_T.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
            out_pyramid_T.append(corr_T)

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs dot-product correlation, scaled by 1/sqrt(dim).

        Returns shape [batch, ht, wd, 1, ht, wd].
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
1142
+
1143
+
1144
+
1145
+
1146
+
1147
+
1148
+
1149
+
1150
+
1151
+
1152
+
1153
class AMT_S(nn.Module):
    """AMT-S frame-interpolation network (small variant).

    Estimates bilateral flows coarse-to-fine with correlation-lookup residual
    updates (RAFT-style), then fuses ``num_flows`` warped candidates into the
    interpolated frame via ``multi_flow_combine``.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=3,
                 channels=[20, 32, 44, 56],
                 skip_channels=20):
        super(AMT_S, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows
        self.channels = channels
        self.skip_channels = skip_channels

        # Feature encoder for the correlation volume; context encoder pyramid.
        self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels)

        # Decoders run coarse (4) to fine (1); decoder1 emits multi-flow outputs.
        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Residual update blocks; the scale factor matches each level's resolution.
        self.update4 = self._get_updateblock(44)
        self.update3 = self._get_updateblock(32, 2)
        self.update2 = self._get_updateblock(20, 4)

        # Learns the residual on top of the mean of the warped candidates.
        self.comb_block = nn.Sequential(
            nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1),
            nn.PReLU(6*num_flows),
            nn.Conv2d(6*num_flows, 3, 3, 1, 1),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        # One-line helper so all levels share identical update hyperparameters.
        return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
                                fc_dim=68, scale_factor=scale_factor,
                                corr_levels=self.corr_levels, radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            # Bring the flows down to the correlation-volume resolution.
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Interpolate between img0 and img1 at timestep ``embt``.

        embt: interpolation timestep tensor with values in (0, 1)
        (presumably broadcastable to the flow maps — confirm against callers).
        Returns a dict with 'imgt_pred' (and flow/feature predictions when
        ``eval`` is False).
        """
        # Normalize by the joint mean of both frames; added back at the end.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Undo the processing-resolution rescale; flow magnitudes rescale too.
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
1286
+
1287
+
1288
+
1289
+
1290
+
1291
+
1292
+
1293
+
1294
+
1295
+
1296
+
1297
class AMT_L(nn.Module):
    """AMT-L frame-interpolation network (large variant).

    Same coarse-to-fine pipeline as ``AMT_S`` but with wider channels, a
    ``BasicEncoder``/``BasicUpdateBlock`` backbone and 5 candidate flows.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=5,
                 channels=[48, 64, 72, 128],
                 skip_channels=48
                 ):
        super(AMT_L, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows

        # Feature encoder for the correlation volume; context encoder pyramid.
        self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        self.encoder = Encoder([48, 64, 72, 128], large=True)

        # Decoders run coarse (4) to fine (1); decoder1 emits multi-flow outputs.
        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Residual update blocks; the scale factor matches each level's resolution.
        self.update4 = self._get_updateblock(72, None)
        self.update3 = self._get_updateblock(64, 2.0)
        self.update2 = self._get_updateblock(48, 4.0)

        # Learns the residual on top of the mean of the warped candidates.
        self.comb_block = nn.Sequential(
            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
            nn.PReLU(6*self.num_flows),
            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        # One-line helper so all levels share identical update hyperparameters.
        return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48,
                                corr_dim=256, corr_dim2=160, fc_dim=124,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            # Bring the flows down to the correlation-volume resolution.
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Interpolate between img0 and img1 at timestep ``embt`` (values in (0, 1)).

        Returns a dict with 'imgt_pred' (and flow/feature predictions when
        ``eval`` is False).
        """
        # Normalize by the joint mean of both frames; added back at the end.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Undo the processing-resolution rescale; flow magnitudes rescale too.
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
1430
+
1431
+
1432
+
1433
+
1434
+
1435
+
1436
+
1437
+
1438
+
1439
+
1440
+
1441
class AMT_G(nn.Module):
    """AMT-G frame-interpolation network (giant variant).

    Like ``AMT_L`` but with even wider channels and an extra full-resolution
    ("high") residual update pass at the 3rd and 2nd levels in addition to the
    downsampled ("low") pass.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=5,
                 channels=[84, 96, 112, 128],
                 skip_channels=84):
        super(AMT_G, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows

        # Feature encoder for the correlation volume; context encoder pyramid.
        self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels, large=True)
        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # "low" blocks run at the downsampled lookup resolution, "high" blocks
        # run a second refinement at the level's native resolution.
        self.update4 = self._get_updateblock(112, None)
        self.update3_low = self._get_updateblock(96, 2.0)
        self.update2_low = self._get_updateblock(84, 4.0)

        self.update3_high = self._get_updateblock(96, None)
        self.update2_high = self._get_updateblock(84, None)

        # Learns the residual on top of the mean of the warped candidates.
        self.comb_block = nn.Sequential(
            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
            nn.PReLU(6*self.num_flows),
            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        # One-line helper so all levels share identical update hyperparameters.
        return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64,
                                corr_dim=256, corr_dim2=192, fc_dim=188,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            # Bring the flows down to the correlation-volume resolution.
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Interpolate between img0 and img1 at timestep ``embt`` (values in (0, 1)).

        Returns a dict with 'imgt_pred' (and flow/feature predictions when
        ``eval`` is False).
        """
        # Normalize by the joint mean of both frames; added back at the end.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        # residue update with lookup corr (hr)
        corr_3 = resize(corr_3, scale_factor=2.0)
        up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1)
        delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3)
        ft_2_ += delta_ft_2_
        up_flow0_3 += delta_up_flow_3[:, 0:2]
        up_flow1_3 += delta_up_flow_3[:, 2:4]

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        # residue update with lookup corr (hr)
        corr_2 = resize(corr_2, scale_factor=4.0)
        up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1)
        delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2)
        ft_1_ += delta_ft_1_
        up_flow0_2 += delta_up_flow_2[:, 0:2]
        up_flow1_2 += delta_up_flow_2[:, 2:4]

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Undo the processing-resolution rescale; flow magnitudes rescale too.
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
vfi_models/cain/__init__.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ import pathlib
4
+ from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
5
+ import typing
6
+ from comfy.model_management import get_torch_device
7
+
8
+ MODEL_TYPE = pathlib.Path(__file__).parent.name
9
+ CKPT_NAMES = ["pretrained_cain.pth"]
10
+
11
+
12
class CAIN_VFI:
    """ComfyUI node that interpolates an image batch with the CAIN model."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate ``frames`` to ``multiplier``x frame count with CAIN.

        Downloads/loads the checkpoint named by ``ckpt_name``, runs the
        generic frame loop (clearing the cache every
        ``clear_cache_after_n_frames`` pairs) and returns a one-tuple with the
        interpolated IMAGE batch.
        """
        from .cain_arch import CAIN
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        sd = torch.load(model_path)["state_dict"]
        # Checkpoint keys carry a DataParallel-style 'module.' prefix; strip it.
        sd = {key.replace('module.', ''): value for key, value in sd.items()}


        # NOTE(review): module-level global keeps the last loaded model alive
        # between node executions — confirm this caching is intentional.
        global interpolation_model
        interpolation_model = CAIN(depth=3)
        interpolation_model.load_state_dict(sd)
        interpolation_model.eval().to(get_torch_device())
        del sd

        frames = preprocess_frames(frames)


        def return_middle_frame(frame_0, frame_1, timestep, model):
            # CAIN does some direct modifications to input frame tensors so we need to clone them.
            # CAIN has no timestep input; it always predicts the middle frame
            # (use_timestep=False below).
            return model(frame_0.detach().clone(), frame_1.detach().clone())[0]

        args = [interpolation_model]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, use_timestep=False, dtype=torch.float32)
        )
        return (out,)
vfi_models/cain/cain_arch.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .common import *
8
+
9
+
10
class Encoder(nn.Module):
    """CAIN encoder: pixel-shuffles both frames into channels, then fuses them.

    The inverse PixelShuffle trades spatial resolution for channels
    (H,W shrink by 2**depth; channels grow by 4**depth) so the fusion module
    sees a wide, low-resolution representation.
    """

    def __init__(self, in_channels=3, depth=3):
        super(Encoder, self).__init__()

        # Shuffle pixels to expand in channel dimension
        # shuffler_list = [PixelShuffle(0.5) for i in range(depth)]
        # self.shuffler = nn.Sequential(*shuffler_list)
        self.shuffler = PixelShuffle(1 / 2**depth)

        relu = nn.LeakyReLU(0.2, True)

        # FF_RCAN or FF_Resblocks
        self.interpolate = Interpolation(5, 12, in_channels * (4**depth), act=relu)

    def forward(self, x1, x2):
        """
        Encoder: Shuffle-spread --> Feature Fusion --> Return fused features
        """
        feats1 = self.shuffler(x1)
        feats2 = self.shuffler(x2)

        feats = self.interpolate(feats1, feats2)

        return feats
34
+
35
+
36
class Decoder(nn.Module):
    """CAIN decoder: undo the encoder's channel-to-space shuffle.

    PixelShuffle(2**depth) restores the original spatial resolution from the
    fused feature map (channels shrink by 4**depth).
    """

    def __init__(self, depth=3):
        super(Decoder, self).__init__()

        # shuffler_list = [PixelShuffle(2) for i in range(depth)]
        # self.shuffler = nn.Sequential(*shuffler_list)
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        """Return the spatially restored output frame."""
        out = self.shuffler(feats)
        return out
47
+
48
+
49
class CAIN(nn.Module):
    """Channel Attention Is All You Need (CAIN) interpolation model.

    Predicts the middle frame between two inputs: subtract per-image means,
    encode/fuse, decode, then add back the averaged mean.
    Returns ``(out, feats)`` — the predicted frame and the fused features.
    """

    def __init__(self, depth=3):
        super(CAIN, self).__init__()

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        # Remove each frame's mean so the network works on zero-centered data.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        if not self.training:
            # Pad at inference so sizes divide evenly for the pixel shuffles.
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Add back the average of the two input means.
        mi = (m1 + m2) / 2
        out += mi

        return out, feats
vfi_models/cain/cain_encdec_arch.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .common import *
8
+ from comfy.model_management import get_torch_device
9
+
10
class Encoder(nn.Module):
    """Strided-conv feature extractor followed by channel-attention fusion."""

    def __init__(self, in_channels=3, depth=3, nf_start=32, norm=False):
        super(Encoder, self).__init__()
        self.device = get_torch_device()

        nf = nf_start
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Three stride-2 stages: full resolution -> 1/8, channels nf -> 6*nf.
        self.body = nn.Sequential(
            ConvNorm(in_channels, nf, 7, stride=1, norm=norm),
            act,
            ConvNorm(nf, nf * 2, 5, stride=2, norm=norm),
            act,
            ConvNorm(nf * 2, nf * 4, 5, stride=2, norm=norm),
            act,
            ConvNorm(nf * 4, nf * 6, 5, stride=2, norm=norm),
        )
        self.interpolate = Interpolation(5, 12, nf * 6, reduction=16, act=act)

    def forward(self, x1, x2):
        """Extract features from both frames and return their fusion."""
        return self.interpolate(self.body(x1), self.body(x2))
40
+
41
+
42
class Decoder(nn.Module):
    """Upsampling decoder: three (upconv + resblock) stages plus a 7x7 head."""

    def __init__(self, in_channels=192, out_channels=3, depth=3, norm=False, up_mode='shuffle'):
        super(Decoder, self).__init__()
        self.device = get_torch_device()

        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Channel plan, e.g. 192 -> 128 -> 64 -> 32 for the default width.
        chans = [in_channels, (in_channels * 2) // 3, in_channels // 3, in_channels // 6]
        self.body = nn.Sequential(
            UpConvNorm(chans[0], chans[1], mode=up_mode, norm=norm),
            ResBlock(chans[1], chans[1], norm=norm, act=act),
            UpConvNorm(chans[1], chans[2], mode=up_mode, norm=norm),
            ResBlock(chans[2], chans[2], norm=norm, act=act),
            UpConvNorm(chans[2], chans[3], mode=up_mode, norm=norm),
            ResBlock(chans[3], chans[3], norm=norm, act=act),
            conv7x7(chans[3], out_channels),
        )

    def forward(self, feats):
        return self.body(feats)
67
+
68
+
69
class CAIN_EncDec(nn.Module):
    """CAIN variant using a conv encoder/decoder instead of pixel shuffles.

    Note: `n_resblocks` is accepted for signature compatibility but unused.
    """

    def __init__(self, depth=3, n_resblocks=3, start_filts=32, up_mode='shuffle'):
        super(CAIN_EncDec, self).__init__()
        self.depth = depth

        self.encoder = Encoder(in_channels=3, depth=depth, norm=False)
        self.decoder = Decoder(in_channels=start_filts * 6, depth=depth,
                               norm=False, up_mode=up_mode)

    def forward(self, x1, x2):
        # Mean-center both inputs; means are restored at the end.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # Inference-time padding to multiples of 128, cropped afterwards.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        out += (m1 + m2) / 2
        return out, feats
vfi_models/cain/cain_noca_arch.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .common import *
8
+ from comfy.model_management import get_torch_device
9
+
10
class Encoder(nn.Module):
    """Pixel-shuffle spread followed by plain residual fusion (no attention)."""

    def __init__(self, in_channels=3, depth=3):
        super(Encoder, self).__init__()
        self.device = get_torch_device()

        # Fold 2**depth x 2**depth patches into channels in one shot.
        self.shuffler = PixelShuffle(1 / 2**depth)
        # ResBlock-based fusion (the "no channel attention" variant).
        self.interpolate = Interpolation_res(5, 12, in_channels * (4**depth))

    def forward(self, x1, x2):
        return self.interpolate(self.shuffler(x1), self.shuffler(x2))
29
+
30
+
31
class Decoder(nn.Module):
    """Pixel-shuffle fused features back up to image resolution."""

    def __init__(self, depth=3):
        super(Decoder, self).__init__()
        self.device = get_torch_device()
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        return self.shuffler(feats)
45
+
46
+
47
class CAIN_NoCA(nn.Module):
    """CAIN ablation without channel attention in the fusion stage."""

    def __init__(self, depth=3):
        super(CAIN_NoCA, self).__init__()
        self.depth = depth

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        # Mean-center both frames; the averaged means are added back below.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # Inference-time padding to multiples of 128, cropped afterwards.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        out += (m1 + m2) / 2
        return out, feats
vfi_models/cain/common.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
def sub_mean(x):
    """Subtract the per-image, per-channel spatial mean from `x`.

    Returns (centered, mean) where `mean` has shape (N, C, 1, 1).

    Bug fix: the original used in-place `x -= mean`, silently mutating the
    caller's tensor (the input frames). The returned values are unchanged;
    only the side effect on the argument is removed.
    """
    mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
    return x - mean, mean
11
+
12
def InOutPaddings(x):
    """Build reflection pads that round (H, W) up to multiples of 128.

    Returns (pad_in, pad_out): `pad_in` grows the tensor symmetrically to
    the next multiple of 128 per dimension; `pad_out` uses negative padding
    to crop a padded-size output back to the original (H, W).
    """
    h, w = x.size(2), x.size(3)
    pad_w = (-w) % 128
    pad_h = (-h) % 128
    left, top = pad_w // 2, pad_h // 2
    right, bottom = pad_w - left, pad_h - top
    pad_in = nn.ReflectionPad2d(padding=[left, right, top, bottom])
    pad_out = nn.ReflectionPad2d(padding=[-left, -right, -top, -bottom])
    return pad_in, pad_out
24
+
25
+
26
class ConvNorm(nn.Module):
    """Reflection-padded same-size convolution with optional normalization.

    `norm` may be False (no norm), 'IN' (InstanceNorm2d) or 'BN' (BatchNorm2d).
    """

    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super(ConvNorm, self).__init__()

        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(in_feat, out_feat, stride=stride,
                              kernel_size=kernel_size, bias=True)

        if norm == 'IN':
            self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
        elif norm == 'BN':
            self.norm = nn.BatchNorm2d(out_feat)
        else:
            self.norm = norm

    def forward(self, x):
        out = self.conv(self.reflection_pad(x))
        return self.norm(out) if self.norm else out
46
+
47
+
48
class UpConvNorm(nn.Module):
    """2x upsampling block.

    mode 'transpose': 4x4 stride-2 transposed convolution.
    mode 'shuffle':   3x3 conv to 4x channels + pixel shuffle.
    otherwise:        bilinear upsample + 1x1 projection.
    """

    def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
        super(UpConvNorm, self).__init__()

        if mode == 'transpose':
            self.upconv = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        elif mode == 'shuffle':
            self.upconv = nn.Sequential(
                ConvNorm(in_channels, 4 * out_channels, kernel_size=3, stride=1, norm=norm),
                PixelShuffle(2))
        else:
            # out_channels is always going to be the same as in_channels
            self.upconv = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))

    def forward(self, x):
        return self.upconv(x)
67
+
68
+
69
+
70
class meanShift(nn.Module):
    """Frozen 1x1 identity conv whose bias adds/subtracts a fixed RGB mean.

    `sign` selects add (+1) or subtract (-1); `nChannel` may be 1, 3, or 6
    (the 6-channel case applies the same RGB mean to two stacked images).
    """

    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super(meanShift, self).__init__()
        # Per-channel offsets scaled to the pixel value range.
        offsets = [m * rgbRange * float(sign) for m in rgbMean[:3]]

        if nChannel == 1:
            self.shifter = nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(1).view(1, 1, 1, 1)
            self.shifter.bias.data = torch.Tensor(offsets[:1])
        elif nChannel == 3:
            self.shifter = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
            self.shifter.bias.data = torch.Tensor(offsets)
        else:
            self.shifter = nn.Conv2d(6, 6, kernel_size=1, stride=1, padding=0)
            self.shifter.weight.data = torch.eye(6).view(6, 6, 1, 1)
            self.shifter.bias.data = torch.Tensor(offsets + offsets)

        # Identity weights plus fixed bias; keep out of the optimizer.
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        return self.shifter(x)
103
+
104
+
105
+ """ CONV - (BN) - RELU - CONV - (BN) """
106
class ResBlock(nn.Module):
    """Residual block: conv -> act -> conv, with optional stride-2 downscale.

    `reduction` and `bias` exist only for signature compatibility with RCAB.
    """

    def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True,
                 norm=False, act=nn.ReLU(True), downscale=False):
        super(ResBlock, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size=kernel_size,
                     stride=2 if downscale else 1),
            act,
            ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1),
        )
        # 1x1 stride-2 projection so the shortcut matches the main path.
        self.downscale = (nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2)
                          if downscale else None)

    def forward(self, x):
        shortcut = x if self.downscale is None else self.downscale(x)
        return self.body(x) + shortcut
129
+
130
+
131
+ ## Channel Attention (CA) Layer
132
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Returns both the re-weighted features and the attention map itself.
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # Squeeze each feature map to a single value: (N, C, 1, 1).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Bottleneck MLP (as 1x1 convs) producing per-channel weights in (0, 1).
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention = self.conv_du(self.avg_pool(x))
        return x * attention, attention
149
+
150
+
151
+ ## Residual Channel Attention Block (RCAB)
152
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual block whose second conv feeds a channel-attention layer.

    NOTE: `self.body` is an nn.Sequential ending in CALayer, which returns
    a *tuple* (features, attention map); Sequential forwards it unchanged,
    so the unpacking in `forward` depends on CALayer being the last module.
    """
    def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
        norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
        super(RCAB, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm),
            act,
            ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
            CALayer(out_feat, reduction)
        )
        self.downscale = downscale
        if downscale:
            # 3x3 stride-2 projection so the shortcut matches the main path.
            self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
        self.return_ca = return_ca

    def forward(self, x):
        res = x
        # `ca` is the channel-attention map produced by the final CALayer.
        out, ca = self.body(x)
        if self.downscale:
            res = self.downConv(res)
        out += res

        if self.return_ca:
            return out, ca
        else:
            return out
179
+
180
+
181
+ ## Residual Group (RG)
182
class ResidualGroup(nn.Module):
    """Stack of `n_resblocks` blocks plus a trailing conv, with a long skip.

    `Block` is the block factory (e.g. RCAB or ResBlock) — it must accept
    (in_feat, out_feat, kernel_size, reduction, bias=, norm=, act=).
    """

    def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
        super(ResidualGroup, self).__init__()

        layers = [
            Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)
            for _ in range(n_resblocks)
        ]
        layers.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        # Long residual connection around the whole group.
        return self.body(x) + x
195
+
196
+
197
def pixel_shuffle(input, scale_factor):
    """Pixel (un)shuffle supporting fractional factors.

    scale_factor >= 1 moves channel blocks into space (upscale, matching
    F.pixel_shuffle); a factor below 1 is the exact inverse, folding
    spatial blocks back into channels (downscale).
    """
    b, c, h, w = input.size()

    c_out = int(int(c / scale_factor) / scale_factor)
    h_out = int(h * scale_factor)
    w_out = int(w * scale_factor)

    if scale_factor >= 1:
        r = scale_factor
        x = input.contiguous().view(b, c_out, r, r, h, w)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    else:
        r = int(1 / scale_factor)
        x = input.contiguous().view(b, c, h_out, r, w_out, r)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()

    return x.view(b, c_out, h_out, w_out)
213
+
214
+
215
class PixelShuffle(nn.Module):
    """Module wrapper around `pixel_shuffle`, allowing fractional factors."""

    def __init__(self, scale_factor):
        super(PixelShuffle, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return pixel_shuffle(x, self.scale_factor)

    def extra_repr(self):
        # Shown in printed module summaries.
        return 'scale_factor={}'.format(self.scale_factor)
224
+
225
+
226
def conv(in_channels, out_channels, kernel_size,
         stride=1, bias=True, groups=1):
    """Same-padding Conv2d helper for arbitrary odd kernel sizes.

    Bug fix: the `stride` argument was previously accepted but silently
    ignored (the layer was always built with stride=1); it is now honoured.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        stride=stride,
        bias=bias,
        groups=groups)
236
+
237
+
238
def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
    """1x1 convolution (channel projection)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, bias=bias, groups=groups)


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """3x3 convolution, same-padded by default."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding, bias=bias, groups=groups)


def conv5x5(in_channels, out_channels, stride=1,
            padding=2, bias=True, groups=1):
    """5x5 convolution, same-padded by default."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=5,
                     stride=stride, padding=padding, bias=bias, groups=groups)


def conv7x7(in_channels, out_channels, stride=1,
            padding=3, bias=True, groups=1):
    """7x7 convolution, same-padded by default."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=7,
                     stride=stride, padding=padding, bias=bias, groups=groups)


def upconv2x2(in_channels, out_channels, mode='shuffle'):
    """2x upsampler: transposed conv, conv3x3 + pixel shuffle, or bilinear + 1x1."""
    if mode == 'transpose':
        return nn.ConvTranspose2d(in_channels, out_channels,
                                  kernel_size=4, stride=2, padding=1)
    if mode == 'shuffle':
        return nn.Sequential(
            conv3x3(in_channels, 4 * out_channels),
            PixelShuffle(2))
    # out_channels is always going to be the same as in_channels
    return nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
        conv1x1(in_channels, out_channels))
297
+
298
+
299
+
300
class Interpolation(nn.Module):
    """RCAN-style fusion head: concat two feature maps, run residual groups.

    head (3x3 conv halving channels) -> n_resgroups of RCAB groups with a
    long skip -> tail (3x3 conv).
    """

    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation, self).__init__()

        # Reduce the concatenated 2*n_feats input back to n_feats.
        self.headConv = conv3x3(n_feats * 2, n_feats)

        self.body = nn.Sequential(*[
            ResidualGroup(
                RCAB,
                n_resblocks=n_resblocks,
                n_feat=n_feats,
                kernel_size=3,
                reduction=reduction,
                act=act,
                norm=norm)
            for _ in range(n_resgroups)
        ])

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        fused = self.headConv(torch.cat([x0, x1], dim=1))
        # Long residual connection around all groups.
        return self.tailConv(self.body(fused) + fused)
332
+
333
+
334
class Interpolation_res(nn.Module):
    """Like `Interpolation` but built from plain ResBlocks (no attention)."""

    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation_res, self).__init__()

        # Reduce the concatenated 2*n_feats input back to n_feats.
        self.headConv = conv3x3(n_feats * 2, n_feats)

        self.body = nn.Sequential(*[
            ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats,
                          kernel_size=3, reduction=0, act=act, norm=norm)
            for _ in range(n_resgroups)
        ])

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        fused = self.headConv(torch.cat([x0, x1], dim=1))
        out = self.body(fused)
        # Long residual connection around all groups.
        return self.tailConv(out + fused)
vfi_models/eisai/__init__.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
3
+ import typing
4
+ import torch
5
+ import torch.nn as nn
6
+ from comfy.model_management import soft_empty_cache, get_torch_device
7
+
8
# Model-type key (this package's directory name) used to locate checkpoints.
MODEL_TYPE = pathlib.Path(__file__).parent.name
# Checkpoint file names fetched from the project's GitHub releases.
MODEL_FILE_NAMES = {
    "ssl": "eisai_ssl.pt",
    "dtm": "eisai_dtm.pt",
    "raft": "eisai_anime_interp_full.ckpt"
}
14
+
15
class EISAI(nn.Module):
    """Bundle of the three EISAI sub-networks: RAFT flow, SSL, and DTM.

    All checkpoints are downloaded on demand and the sub-models are moved
    to the comfy-managed device in eval mode.
    """

    def __init__(self, model_file_names) -> None:
        # Local import: eisai_arch pulls in heavy deps (cv2, scipy, kornia).
        from .eisai_arch import SoftsplatLite, DTM, RAFT
        super(EISAI, self).__init__()
        # Optical-flow backbone.
        self.raft = RAFT(load_file_from_github_release(MODEL_TYPE, model_file_names["raft"]))
        self.raft.to(get_torch_device()).eval()

        # Softsplat-Lite forward-warping synthesis network.
        self.ssl = SoftsplatLite()
        self.ssl.load_state_dict(torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["ssl"])))
        self.ssl.to(get_torch_device()).eval()

        # DTM refinement network applied on top of the SSL output.
        self.dtm = DTM()
        self.dtm.load_state_dict(torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["dtm"])))
        self.dtm.to(get_torch_device()).eval()

    def forward(self, img0, img1, t):
        """Synthesize the frame at time `t` (in [0, 1]) between img0 and img1."""
        with torch.no_grad():
            # Bidirectional flow between the two input frames.
            flow0, _ = self.raft(img0, img1)
            flow1, _ = self.raft(img1, img0)
            x = {
                "images": torch.stack([img0, img1], dim=1),
                "flows": torch.stack([flow0, flow1], dim=1),
            }
            out_ssl, _ = self.ssl(x, t=t, return_more=True)
            # NOTE(review): `_` from the SSL call above is forwarded to DTM —
            # presumably extra SSL context; confirm against eisai_arch.
            out_dtm, _ = self.dtm(x, out_ssl, _, return_more=False)
        # Keep only the RGB channels of the DTM output.
        return out_dtm[:, :3]
41
+
42
class EISAI_VFI:
    """ComfyUI node: video frame interpolation using the EISAI model."""

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-input schema.
        return {
            "required": {
                "ckpt_name": (["eisai"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames` by `multiplier`, flushing cache periodically.

        Returns a 1-tuple with the interpolated IMAGE batch (ComfyUI style).
        """
        interpolation_model = EISAI(MODEL_FILE_NAMES)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        # Invoked once per generated in-between frame by the generic loop.
        def return_middle_frame(frame_0, frame_1, timestep, model):
            return model(frame_0, frame_1, t=timestep)

        # EISAI runs at native resolution; no internal rescale.
        scale = 1

        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)
vfi_models/eisai/eisai_arch.py ADDED
@@ -0,0 +1,2586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_scripts/interpolate.py
3
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/models/ssldtm.py
4
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/util_v0.py
5
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/twodee_v0.py
6
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/pytorch_v0.py
7
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/distance_transform_v0.py
8
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/sketchers_v1.py
9
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/interpolator_v0.py
10
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/gridnet_v1.py
11
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/flow_v0.py
12
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/softsplat_v0.py
13
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/rfr_new.py
14
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/extractor.py
15
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/update.py
16
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/corr.py
17
+ https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/utils.py
18
+ """
19
+
20
+ import copy
21
+ import cv2
22
+ import torch.nn.functional as F
23
+ import torchvision.transforms.functional as F
24
+ import gc
25
+ from PIL import Image, ImageFile, ImageFont, ImageDraw
26
+ import inspect
27
+ from scipy import interpolate
28
+ import kornia
29
+ import math
30
+ from argparse import Namespace
31
+ import torch.nn as nn
32
+ import numpy as np
33
+ import os
34
+ from functools import partial
35
+ import pathlib
36
+ import PIL
37
+ import re
38
+ import requests
39
+ from scipy.spatial.transform import Rotation
40
+ import scipy
41
+ import shutil
42
+ import torchvision.transforms as T
43
+ import time
44
+ import torch
45
+ import torchvision as tv
46
+ import zlib
47
+ import numpy as np
48
+ import torch
49
+ import torch.nn as nn
50
+ import torch.nn.functional as F
51
+ from tqdm.auto import tqdm as std_tqdm
52
+ from tqdm.auto import trange as std_trange
53
+ from vfi_models.ops import FunctionSoftsplat, batch_edt
54
+ from comfy.model_management import get_torch_device
55
+
56
device = get_torch_device()  # comfy-managed compute device
autocast = torch.autocast  # alias kept from the upstream EISAI code
tqdm = partial(std_tqdm, dynamic_ncols=True)  # progress bars that resize with the terminal
trange = partial(std_trange, dynamic_ncols=True)
ImageFile.LOAD_TRUNCATED_IMAGES = True  # tolerate truncated images when PIL decodes
61
+
62
+
63
def pixel_ij(x, rounding=True):
    """Coerce a scalar or (i, j) pair into a tuple rounded per `rounding`.

    Numpy arrays are converted to lists first; a scalar is duplicated into
    both positions.
    """
    if isinstance(x, np.ndarray):
        x = x.tolist()
    pair = x if isinstance(x, tuple) or isinstance(x, list) else (x, x)
    return tuple(pixel_rounder(v, rounding) for v in pair)
70
+
71
+
72
def rescale_dry(x, factor):
    """Compute the (h, w) a rescale by `factor` would give, without resizing.

    NOTE(review): for non-tuple/list input this evaluates `I(x).size`, but
    `I` is not defined in this file (it is the upstream twodee image
    wrapper) — confirm that branch is never reached in this repo.
    """
    h, w = x[-2:] if isinstance(x, tuple) or isinstance(x, list) else I(x).size
    return (h * factor, w * factor)
75
+
76
+
77
def pixel_rounder(n, mode):
    """Round `n` per `mode`: True/'round', 'ceil', 'floor'; anything else is a no-op."""
    if mode == True or mode == "round":
        return round(n)
    if mode == "ceil":
        return math.ceil(n)
    if mode == "floor":
        return math.floor(n)
    return n
86
+
87
+
88
def diam(x):
    """Diagonal length (in pixels) of an image-like object or (h, w) shape.

    Accepts a tuple/list whose last two entries are (h, w), an `I` image
    wrapper (its `.size`), or anything with a trailing (h, w) in `.shape`.
    """
    if isinstance(x, tuple) or isinstance(x, list):
        h, w = x[-2:]
    elif isinstance(x, I):
        # NOTE(review): `I` is not defined in this file (upstream twodee
        # wrapper); reaching this branch raises NameError here — confirm
        # it is unused in this repo.
        h, w = x.size
    else:
        h, w = x.shape[-2:]
    return np.sqrt(h**2 + w**2)
96
+
97
+
98
def pixel_logit(x, pixel_margin=1):
    """Map pixel values in [0, 1] to logits, clamped away from 0 and 1.

    Values are squeezed into [margin/255, 1 - margin/255] first so the
    log-odds stay finite at the extremes.
    """
    squeezed = (x * (255 - 2 * pixel_margin) + pixel_margin) / 255
    return torch.log(squeezed / (1 - squeezed))
101
+
102
+
103
class InputPadder:
    """Pads images such that dimensions are divisible by 8."""

    def __init__(self, dims):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        # Width padding is split across both sides; height padding goes on
        # the bottom only: [left, right, top, bottom].
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every given tensor to the divisible-by-8 size."""
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original (ht, wd)."""
        h, w = x.shape[-2:]
        top, bottom = self._pad[2], h - self._pad[3]
        left, right = self._pad[0], w - self._pad[1]
        return x[..., top:bottom, left:right]
119
+
120
+
121
def forward_interpolate(flow):
    """Forward-warp a (2, H, W) flow field back onto the regular grid.

    Every source pixel is displaced by its own flow vector; the scattered
    flow values are then re-gridded with cubic interpolation (zero fill
    outside the convex hull). Returns a float CPU tensor of the same shape.
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    grid_x, grid_y = np.meshgrid(np.arange(wd), np.arange(ht))

    # Landing positions of every pixel after applying its flow vector.
    tgt_x = (grid_x + dx).reshape(-1)
    tgt_y = (grid_y + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # Keep only samples that land strictly inside the frame.
    inside = (tgt_x > 0) & (tgt_x < wd) & (tgt_y > 0) & (tgt_y < ht)
    tgt_x = tgt_x[inside]
    tgt_y = tgt_y[inside]
    dx = dx[inside]
    dy = dy[inside]

    flow_x = interpolate.griddata((tgt_x, tgt_y), dx, (grid_x, grid_y),
                                  method="cubic", fill_value=0)
    flow_y = interpolate.griddata((tgt_x, tgt_y), dy, (grid_x, grid_y),
                                  method="cubic", fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
148
+
149
+
150
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """Wrapper for grid_sample, uses pixel coordinates.

    `coords` is (N, H_out, W_out, 2) in (x, y) pixel units; it is rescaled
    to the [-1, 1] range grid_sample expects. If `mask` is true, also
    return a float mask of coordinates that fell inside the image.
    (`mode` is accepted for API compatibility but unused, as upstream.)
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Normalize pixel coordinates to [-1, 1] (align_corners convention).
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if mask:
        inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return sampled, inside.float()
    return sampled
166
+
167
+
168
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) grid of (x, y) pixel coordinates.

    Channel 0 holds x (column index), channel 1 holds y (row index).
    `indexing="ij"` is passed explicitly: it matches torch.meshgrid's
    historical default, so behavior is identical, and it silences the
    deprecation warning emitted by modern torch.
    """
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd), indexing="ij")
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
172
+
173
+
174
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially, scaling its vectors to match."""
    ht, wd = 8 * flow.shape[2], 8 * flow.shape[3]
    return 8 * F.interpolate(flow, size=(ht, wd), mode=mode, align_corners=True)
177
+
178
+
179
class CorrBlock:
    """RAFT-style all-pairs correlation volume with a pooled pyramid.

    The full 4D correlation between two feature maps is computed once in
    __init__; __call__ then samples a (2r+1)x(2r+1) window around the given
    coordinates at every pyramid level and concatenates the results.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        # Flatten source positions into the batch dim so each source pixel
        # has its own (h2, w2) correlation map.
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        # Level i of the pyramid is the raw volume average-pooled i times (2x each).
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        # coords: (batch, 2, h1, w1) pixel coordinates into fmap2's space.
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # (2r+1)^2 integer offsets around each lookup centroid.
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device)

            # Centroids are rescaled to the current pyramid level.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        # -> (batch, num_levels * (2r+1)^2, h1, w1)
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # Dot-product correlation of every position in fmap1 against every
        # position in fmap2, scaled by 1/sqrt(dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
228
+
229
+
230
class FlowHead(nn.Module):
    """Two-layer conv head mapping a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
239
+
240
+
241
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """One GRU step: h is the hidden state, x the input features."""
        combined = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(combined))  # z gate
        reset = torch.sigmoid(self.convr(combined))   # r gate
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
257
+
258
+
259
class SepConvGRU(nn.Module):
    """GRU cell with separable convolutions: one horizontal (1x5) update
    pass followed by one vertical (5x1) pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        ch = hidden_dim + input_dim
        # Horizontal (1x5) gate convolutions.
        self.convz1 = nn.Conv2d(ch, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(ch, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(ch, hidden_dim, (1, 5), padding=(0, 2))
        # Vertical (5x1) gate convolutions.
        self.convz2 = nn.Conv2d(ch, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(ch, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(ch, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _step(h, x, convz, convr, convq):
        # Single GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass, then vertical pass
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)
        return h
298
+
299
+
300
class SmallMotionEncoder(nn.Module):
    """Fuses correlation features and current flow into 82-channel motion features."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so later layers keep direct access to it.
        return torch.cat([fused, flow], dim=1)
316
+
317
+
318
class BasicMotionEncoder(nn.Module):
    """Fuses correlation features and current flow into 128-channel motion features."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 fused channels + 2 raw flow channels = 128 output channels.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so later layers keep direct access to it.
        return torch.cat([fused, flow], dim=1)
337
+
338
+
339
class SmallUpdateBlock(nn.Module):
    """Small RAFT update cell: motion encoder + ConvGRU + flow head."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, features], dim=1))
        delta_flow = self.flow_head(net)
        # The small variant predicts no convex-upsampling mask, hence None.
        return net, None, delta_flow
353
+
354
+
355
class BasicUpdateBlock(nn.Module):
    """Full RAFT update cell: motion encoder + SepConvGRU + flow head,
    plus a predictor for the 8x convex-upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, features], dim=1))
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = 0.25 * self.mask(net)
        return net, mask, delta_flow
379
+
380
+
381
class ResidualBlock(nn.Module):
    """Two 3x3 convs with selectable normalization and a strided 1x1 shortcut."""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm():
            # One fresh norm layer of the requested flavour.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == "batch":
                return nn.BatchNorm2d(planes)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(planes)
            if norm_fn == "none":
                return nn.Sequential()

        self.norm1 = make_norm()
        self.norm2 = make_norm()
        # norm3 only exists when the shortcut needs a projection.
        if stride != 1:
            self.norm3 = make_norm()
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )
        else:
            self.downsample = None

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + y)
434
+
435
+
436
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with selectable normalization and a
    strided 1x1 shortcut."""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm(channels):
            # One fresh norm layer of the requested flavour for `channels`.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            if norm_fn == "batch":
                return nn.BatchNorm2d(channels)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(channels)
            if norm_fn == "none":
                return nn.Sequential()

        self.norm1 = make_norm(planes // 4)
        self.norm2 = make_norm(planes // 4)
        self.norm3 = make_norm(planes)
        # norm4 only exists when the shortcut needs a projection.
        if stride != 1:
            self.norm4 = make_norm(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )
        else:
            self.downsample = None

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + y)
495
+
496
+
497
class BasicEncoder(nn.Module):
    """Feature encoder: 7x7 stem (stride 2), three residual stages
    (strides 1/2/2, i.e. 8x total downsampling) and a 1x1 projection.

    Accepts either a single tensor or a list/tuple of two tensors; in the
    latter case both are run in one batch and the output is split back.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization (64 channels after conv1).
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may downsample.
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # If the input is a pair, fold it into a single batch.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
570
+
571
+
572
class BasicEncoder1(nn.Module):
    """Variant of BasicEncoder whose stem takes 2-channel input (e.g. a flow
    field) instead of 3-channel images; otherwise identical.

    NOTE(review): this duplicates BasicEncoder except for conv1's
    in_channels — consider sharing code if either is touched again.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder1, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization (64 channels after conv1).
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        # 2 input channels: the only difference from BasicEncoder.
        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; the first carries the stride.
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
645
+
646
+
647
class SmallEncoder(nn.Module):
    """Compact feature encoder built from BottleneckBlocks (32/64/96
    channels, 8x total downsampling); same single-or-pair input convention
    as BasicEncoder."""

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization (32 channels after conv1).
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; the first carries the stride.
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
718
+
719
+
720
+ ##################################################
721
+ # RFR is implemented based on RAFT optical flow #
722
+ ##################################################
723
+
724
+
725
def backwarp(img, flow):
    """Backward-warp `img` with a per-pixel displacement field.

    img: (B, C, H, W) tensor.
    flow: (B, 2, H, W) tensor; channel 0 is the horizontal (x) displacement
        and channel 1 the vertical (y) displacement, in pixels.

    Returns the bilinearly sampled image, shape (B, C, H, W).

    Fix: the sampling grids were hard-coded to `.cuda()`, which broke any
    CPU use; they are now created on `img.device` (identical behavior for
    CUDA inputs, and CPU inputs simply work).
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    # Base sampling grid in pixel coordinates.
    gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))

    gridX = torch.tensor(gridX, requires_grad=False, device=img.device)
    gridY = torch.tensor(gridY, requires_grad=False, device=img.device)
    x = gridX.unsqueeze(0).expand_as(u).float() + u
    y = gridY.unsqueeze(0).expand_as(v).float() + v
    # range -1 to 1
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)
    # stacking X and Y
    grid = torch.stack((x, y), dim=3)
    # Sample pixels using bilinear interpolation.
    imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True)

    return imgOut
752
+
753
+
754
class ErrorAttention(nn.Module):
    """A three-layer network for predicting mask.

    conv2's 32 output channels are concatenated with the raw input
    (input + 32 = 38 channels) before the final conv3.
    """

    def __init__(self, input, output):
        super(ErrorAttention, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(38, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        feat = self.prelu1(self.conv1(x1))
        feat = self.prelu2(torch.cat([self.conv2(feat), x1], dim=1))
        return self.conv3(feat)
770
+
771
+
772
class RFR(nn.Module):
    """RFR flow network: RAFT-style recurrent refinement with an optional
    attention-based reweighting of an externally supplied initial flow.

    NOTE(review): relies on module-level `autocast` and `device` and on the
    RFR-section `backwarp` defined earlier in this file (which is later
    shadowed by the warping-section alias).
    """

    def __init__(self, args):
        super(RFR, self).__init__()
        self.attention2 = ErrorAttention(6, 1)
        self.hidden_dim = hdim = 128
        self.context_dim = cdim = 128
        # Correlation/dropout settings are forced here regardless of caller input.
        args.corr_levels = 4
        args.corr_radius = 4
        args.dropout = 0
        self.args = args

        # feature network, context network, and update block
        self.fnet = BasicEncoder(output_dim=256, norm_fn="none", dropout=args.dropout)
        # self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
        self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        # Put every BatchNorm layer in eval mode (freeze running stats).
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
        coords1 = coords_grid(N, H // 8, W // 8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
        N, _, H, W = flow.shape
        # 9 mask weights per output pixel, softmax-normalized.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Each coarse flow vector (scaled by 8) contributes via its 3x3 neighborhood.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(
        self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
    ):
        H, W = image1.size()[2:4]
        # Round the working resolution down to a multiple of 8.
        H8 = H // 8 * 8
        W8 = W // 8 * 8

        if flow_init is not None:
            # Resize the initial flow to 1/8 working resolution and rescale
            # its displacement magnitudes to match.
            flow_init_resize = F.interpolate(
                flow_init, size=(H8 // 8, W8 // 8), mode="nearest"
            )

            flow_init_resize[:, :1] = (
                flow_init_resize[:, :1].clone() * (W8 // 8 * 1.0) / flow_init.size()[3]
            )
            flow_init_resize[:, 1:] = (
                flow_init_resize[:, 1:].clone() * (H8 // 8 * 1.0) / flow_init.size()[2]
            )

            if not hasattr(self.args, "not_use_rfr_mask") or (
                hasattr(self.args, "not_use_rfr_mask")
                and (not self.args.not_use_rfr_mask)
            ):
                # Downweight initial-flow vectors by a learned confidence
                # computed from the photometric backward-warp error.
                im18 = F.interpolate(image1, size=(H8 // 8, W8 // 8), mode="bilinear")
                im28 = F.interpolate(image2, size=(H8 // 8, W8 // 8), mode="bilinear")

                warp21 = backwarp(im28, flow_init_resize)
                error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
                # print('errormin', error21.min(), error21.max())
                f12init = (
                    torch.exp(
                        -self.attention2(
                            torch.cat([im18, error21, flow_init_resize], dim=1)
                        )
                        ** 2
                    )
                    * flow_init_resize
                )
            # NOTE(review): when not_use_rfr_mask is True, `error21` is never
            # assigned on this path, so the non-requires_sq_flow return below
            # would raise NameError — confirm intended usage.
        else:
            flow_init_resize = None
            # CUDA-only zero initialization (hard-coded .cuda()).
            flow_init = torch.zeros(
                image1.size()[0], 2, image1.size()[2] // 8, image1.size()[3] // 8
            ).cuda()
            error21 = torch.zeros(
                image1.size()[0], 1, image1.size()[2] // 8, image1.size()[3] // 8
            ).cuda()

            f12_init = flow_init
            # print('None inital flow!')

        image1 = F.interpolate(image1, size=(H8, W8), mode="bilinear")
        image2 = F.interpolate(image2, size=(H8, W8), mode="bilinear")

        f12s, f12, f12_init = self.forward_pred(
            image1, image2, iters, flow_init_resize, upsample, test_mode
        )

        if hasattr(self.args, "requires_sq_flow") and self.args.requires_sq_flow:
            # Return the whole refinement sequence, rescaled to input size.
            for ii in range(len(f12s)):
                f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode="bilinear")
                f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0 * W8) * W
                f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0 * H8) * H
            if self.training:
                return f12s
            else:
                return [f12s[-1]], f12_init
        else:
            # Rescale the final flow back to the original resolution.
            f12[:, :1] = f12[:, :1].clone() / (1.0 * W8) * W
            f12[:, 1:] = f12[:, 1:].clone() / (1.0 * H8) * H

            f12 = F.interpolate(f12, size=(H, W), mode="bilinear")
            # print('wo!!')
            return (
                f12,
                f12_init,
                error21,
            )

    def forward_pred(
        self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
    ):
        """Estimate optical flow between pair of frames"""

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(device.type, enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])
        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        # NOTE(review): fnet doubles as the context network here (cnet is
        # commented out in __init__); its 256 channels split into net/inp.
        with autocast(device.type, enabled=self.args.mixed_precision):
            cnet = self.fnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            if itr == 0:
                if flow_init is not None:
                    # NOTE(review): flow_init was already added above, so the
                    # first iteration starts from twice the initial flow —
                    # kept as-is; confirm against the reference implementation.
                    coords1 = coords1 + flow_init
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(device.type, enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        return flow_predictions, flow_up, flow_init
947
+
948
+ ####################### WARPING #######################
949
+
950
+
951
# expects batched tensors, considered low-level operation
# img: bs, ch, h, w
# flow: bs, xy (pix displace), h, w
def flow_backwarp(
    img, flow, resample="bilinear", padding_mode="border", align_corners=False
):
    """Backward-warp img by sampling it at positions displaced by flow.

    NOTE(review): flow channel 0 is normalized by shape[-2] (height) and
    channel 1 by shape[-1] (width), then flipped into grid_sample's (x, y)
    order — so the incoming channel order appears to be (vertical,
    horizontal) despite the "xy" note above; confirm against callers.
    """
    # Promote unbatched inputs to 4D.
    if len(img.shape) != 4:
        img = img[None,]
    if len(flow.shape) != 4:
        flow = flow[None,]
    # Displacements scaled into normalized [-1, 1] coordinate units.
    q = (
        2
        * flow
        / torch.tensor(
            [
                flow.shape[-2],
                flow.shape[-1],
            ],
            device=flow.device,
            dtype=torch.float,
        )[None, :, None, None]
    )
    # Add the identity sampling grid.
    q = q + torch.stack(
        torch.meshgrid(
            torch.linspace(-1, 1, flow.shape[-2]),
            torch.linspace(-1, 1, flow.shape[-1]),
        )
    )[
        None,
    ].to(
        flow.device
    )
    if img.dtype != q.dtype:
        img = img.type(q.dtype)

    return nn.functional.grid_sample(
        img,
        q.flip(dims=(1,)).permute(0, 2, 3, 1),
        mode=resample,  # nearest, bicubic, bilinear
        padding_mode=padding_mode,  # border, zeros, reflection
        align_corners=align_corners,
    )
993
+
994
+
995
# NOTE(review): this rebinds `backwarp`, shadowing the RFR-section backwarp
# defined earlier in this file for all code that runs after import time.
backwarp = flow_warp = flow_backwarp
996
+
997
+
998
# mode: sum, avg, lin, softmax
# lin/softmax w/out metric defaults to avg
# must use gpu, move back to cpu if retain_device
# typical metric: -20 * | img0 - backwarp(img1,flow) |
# From Fannovel16: Changed mode params for common ops.
def flow_forewarp(
    img, flow, mode="average", metric=None, mask=False, retain_device=True
):
    """Forward-warp (splat) img along flow via FunctionSoftsplat.

    Inputs are promoted to 4D float32 and moved to CUDA when needed;
    FunctionSoftsplat is defined elsewhere in this file.
    NOTE(review): flow is flipped along dim 1 before splatting, so the
    channel order handed to FunctionSoftsplat is the reverse of the
    caller's — confirm against the softsplat implementation.
    """
    # setup
    # if mode == "sum":
    #     mode = "summation"
    # elif mode == "avg":
    #     mode = "average"
    if mode in ["lin", "linear"]:
        # mode = "linear" if metric is not None else "average"
        mode = "linear" if metric is not None else "avg"
    elif mode in ["sm", "softmax"]:
        # mode = "softmax" if metric is not None else "average"
        mode = "soft" if metric is not None else "avg"
    # Promote unbatched inputs to 4D and force float32 throughout.
    if len(img.shape) != 4:
        img = img[None,]
    if len(flow.shape) != 4:
        flow = flow[None,]
    if metric is not None and len(metric.shape) != 4:
        metric = metric[None,]
    flow = flow.flip(dims=(1,))
    if img.dtype != torch.float32:
        img = img.type(torch.float32)
    if flow.dtype != torch.float32:
        flow = flow.type(torch.float32)
    if metric is not None and metric.dtype != torch.float32:
        metric = metric.type(torch.float32)

    # move to gpu if necessary
    assert img.device == flow.device
    if metric is not None:
        assert img.device == metric.device
    was_cpu = img.device.type == "cpu"
    if was_cpu:
        img = img.to("cuda")
        flow = flow.to("cuda")
        if metric is not None:
            metric = metric.to("cuda")

    # add mask
    if mask:
        # Extra all-ones channel lets the caller recover splat coverage.
        bs, ch, h, w = img.shape
        img = torch.cat(
            [img, torch.ones(bs, 1, h, w, dtype=img.dtype, device=img.device)], dim=1
        )

    # forward, move back to cpu if desired
    ans = FunctionSoftsplat(img, flow, metric, mode)
    if was_cpu and retain_device:
        ans = ans.cpu()
    return ans
1054
+
1055
+
1056
# Short alias for the forward-warp entry point.
forewarp = flow_forewarp
1057
+
1058
+
1059
# resizing utility
def flow_resize(flow, size, mode="nearest", align_corners=False):
    """Resize a (bs, xy, h, w) flow field to `size`, rescaling the
    displacement magnitudes to match the new resolution."""
    target = pixel_ij(size, rounding=True)
    if flow.dtype != torch.float:
        flow = flow.float()
    if len(flow.shape) == 3:
        flow = flow[None,]
    if flow.shape[-2:] == target:
        return flow

    resized = nn.functional.interpolate(
        flow,
        size=target,
        mode=mode,
        align_corners=align_corners if mode != "nearest" else None,
    )
    # Scale each displacement channel by its axis growth factor.
    scale = torch.tensor(
        [new / old for old, new in zip(flow.shape[-2:], target)],
        device=flow.device,
    )
    return resized * scale[None, :, None, None]
1081
+
1082
+
1083
+ ####################### TRADITIONAL #######################
1084
+
1085
# dense
# Classical dense optical-flow backends wrapped to a common signature:
# each lambda takes two cv2 images (a, b) and returns a numpy array whose
# 2-channel flow is moved to the front and given a batch axis via [None,].
# NOTE(review): the createOptFlow_* factory calls below execute at import
# time and require the opencv-contrib build (cv2.optflow / DenseRLOF).
_lucaskanade = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowSparseToDense(
        a,
        b,  # grid_step=5, sigma=0.5,
    ),
    2,
    0,
)[
    None,
]
_farneback = lambda a, b: np.moveaxis(
    cv2.calcOpticalFlowFarneback(
        a,
        b,
        None,
        0.6,
        3,
        25,
        7,
        5,
        1.2,
        cv2.OPTFLOW_FARNEBACK_GAUSSIAN,
    ),
    2,
    0,
)[
    None,
]
_dtvl1_ = cv2.optflow.createOptFlow_DualTVL1()
_dtvl1 = lambda a, b: np.moveaxis(
    _dtvl1_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
_simple = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowSF(
        a,
        b,
        3,
        5,
        5,
    ),
    2,
    0,
)[
    None,
]
_pca_ = cv2.optflow.createOptFlow_PCAFlow()
_pca = lambda a, b: np.moveaxis(
    _pca_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
_drlof = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowDenseRLOF(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
_deepflow_ = cv2.optflow.createOptFlow_DeepFlow()
_deepflow = lambda a, b: np.moveaxis(
    _deepflow_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
1174
+
1175
+
1176
def cv2flow(a, b, method="lucaskanade", back=False):
    """Dense optical flow between two project images via classical OpenCV methods.

    a, b: image objects exposing .convert(mode).cv2() (project image type).
    method: one of "lucaskanade", "farneback", "dtvl1", "simple", "pca",
        "drlof", "deepflow".
    back: when True, additionally compute the reverse flow and concatenate
        it along the batch axis.

    Returns a torch tensor of shape (1 or 2, 2, H, W) with the flow
    channels flipped on dim 1.

    Raises:
        ValueError: if `method` is not one of the supported names.
    """
    if method == "lucaskanade":
        f = _lucaskanade
        a = a.convert("L").cv2()
        b = b.convert("L").cv2()
    elif method == "farneback":
        f = _farneback
        a = a.convert("L").cv2()
        b = b.convert("L").cv2()
    elif method == "dtvl1":
        f = _dtvl1
        a = a.convert("L").cv2()
        b = b.convert("L").cv2()
    elif method == "simple":
        f = _simple
        a = a.convert("RGB").cv2()
        b = b.convert("RGB").cv2()
    elif method == "pca":
        f = _pca
        a = a.convert("L").cv2()
        b = b.convert("L").cv2()
    elif method == "drlof":
        f = _drlof
        a = a.convert("RGB").cv2()
        b = b.convert("RGB").cv2()
    elif method == "deepflow":
        f = _deepflow
        a = a.convert("L").cv2()
        b = b.convert("L").cv2()
    else:
        # Was `assert 0`: asserts vanish under `python -O` and carry no
        # message; raise an explicit, informative error instead.
        raise ValueError(f"cv2flow: unknown method {method!r}")
    # NOTE(review): forward flow is computed as f(b, a) — argument order
    # kept from the original; confirm the direction convention.
    ans = f(b, a)
    if back:
        ans = np.concatenate(
            [
                ans,
                f(a, b),
            ]
        )
    return torch.tensor(ans).flip(dims=(1,))
1216
+
1217
+
1218
####################### FLOWNET2 #######################


def flownet2(img_a, img_b, mode="shm", back=False):
    """Query a local FlowNet2 HTTP service (port 8109) for the flow between
    two images, exchanging files through /dev/shm.

    Returns {"response": resp} plus, on HTTP 200, "time" and
    "output": {"flow": tensor} (the flow unpickled from the path the
    service reports).
    """
    # package
    url = f"http://localhost:8109/get-flow"
    if mode == "shm":
        t = time.time()
        fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{t}/img_a.png"))
        fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{t}/img_b.png"))
    elif mode == "net":
        assert False, "not impl"
        # NOTE(review): dead code below the assert — `img` and `u2d` are not
        # defined in this scope, and fn_a/fn_b would be unbound for this mode.
        q = u2d.img2uri(img.pil("RGB"))
        q.decode()
    resp = requests.get(
        url,
        params={
            "img_a": fn_a,
            "img_b": fn_b,
            "mode": mode,
            "back": back,
            # 'vis': vis,
        },
    )

    # return
    ans = {"response": resp}
    if resp.status_code == 200:
        j = resp.json()
        ans["time"] = j["time"]
        ans["output"] = {
            "flow": torch.tensor(load(j["fn_flow"])),
        }
        # if vis:
        #     ans['output']['vis'] = I(j['fn_vis'])
    if mode == "shm":
        # clean up the scratch directory regardless of response status
        shutil.rmtree(f"/dev/shm/_flownet2/{t}")
    return ans
####################### VISUALIZATION #######################


class Gridnet(nn.Module):
    """Three-scale grid network: `depth` encoder columns followed by
    `depth` decoder columns over feature maps at channel counts
    channels_0/1/2, with per-sample "total dropout" applied to features
    flowing between columns (not into the first one)."""

    def __init__(self, channels_0, channels_1, channels_2, total_dropout_p, depth):
        super().__init__()
        self.channels_0 = ch0 = channels_0
        self.channels_1 = ch1 = channels_1
        self.channels_2 = ch2 = channels_2
        self.total_dropout_p = p = total_dropout_p
        self.depth = depth
        self.encoders = nn.ModuleList(
            [GridnetEncoder(ch0, ch1, ch2) for i in range(self.depth)]
        )
        self.decoders = nn.ModuleList(
            [GridnetDecoder(ch0, ch1, ch2) for i in range(self.depth)]
        )
        self.total_dropout = GridnetTotalDropout(p)
        return

    def forward(self, x):
        # x: list of three feature maps, one per scale
        for e, enc in enumerate(self.encoders):
            # first column reads x directly; `t` only exists from the
            # second iteration on (the conditional avoids evaluating it)
            t = [self.total_dropout(i) for i in t] if e != 0 else x
            t = enc(t)
        for d, dec in enumerate(self.decoders):
            t = [self.total_dropout(i) for i in t]
            t = dec(t)
        return t
class GridnetEncoder(nn.Module):
    """One encoder column: a residual block per scale, with downsampling
    connections feeding each finer scale's output into the next coarser
    scale."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = ch0 = channels_0
        self.channels_1 = ch1 = channels_1
        self.channels_2 = ch2 = channels_2
        self.resnet_0 = GridnetResnet(ch0)
        self.resnet_1 = GridnetResnet(ch1)
        self.resnet_2 = GridnetResnet(ch2)
        self.downsample_01 = GridnetDownsample(ch0, ch1)
        self.downsample_12 = GridnetDownsample(ch1, ch2)
        return

    def forward(self, x):
        # x: [fine, mid, coarse]; finer outputs are downsampled and added
        out = [
            None,
        ] * 3
        out[0] = self.resnet_0(x[0])
        out[1] = self.resnet_1(x[1]) + self.downsample_01(out[0])
        out[2] = self.resnet_2(x[2]) + self.downsample_12(out[1])
        return out
class GridnetDecoder(nn.Module):
    """One decoder column: the mirror of GridnetEncoder — a residual block
    per scale, with upsampling connections feeding each coarser scale's
    output into the next finer scale."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = ch0 = channels_0
        self.channels_1 = ch1 = channels_1
        self.channels_2 = ch2 = channels_2
        self.resnet_0 = GridnetResnet(ch0)
        self.resnet_1 = GridnetResnet(ch1)
        self.resnet_2 = GridnetResnet(ch2)
        self.upsample_10 = GridnetUpsample(ch1, ch0)
        self.upsample_21 = GridnetUpsample(ch2, ch1)
        return

    def forward(self, x):
        # processed coarse-to-fine; coarser outputs are upsampled and added
        out = [
            None,
        ] * 3
        out[2] = self.resnet_2(x[2])
        out[1] = self.resnet_1(x[1]) + self.upsample_21(out[2])
        out[0] = self.resnet_0(x[0]) + self.upsample_10(out[1])
        return out
class GridnetConverter(nn.Module):
    """Per-scale 1x1-conv adapters: maps a list of feature maps whose
    channel counts are `channels_in` to counts `channels_out`
    (PReLU -> 1x1 conv -> BatchNorm at each scale)."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        adapters = []
        for n_in, n_out in zip(channels_in, channels_out):
            adapters.append(
                nn.Sequential(
                    nn.PReLU(n_in),
                    nn.Conv2d(n_in, n_out, kernel_size=1, padding=0),
                    nn.BatchNorm2d(n_out),
                )
            )
        self.nets = nn.ModuleList(adapters)

    def forward(self, x):
        out = []
        for net, feat in zip(self.nets, x):
            out.append(net(feat))
        return out
class GridnetResnet(nn.Module):
    """Residual block at constant channel count: x + f(x), where f is
    PReLU -> 3x3 conv -> BatchNorm, applied twice."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        stages = []
        for _ in range(2):
            stages.extend(
                [
                    nn.PReLU(channels),
                    nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(channels),
                ]
            )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
class GridnetDownsample(nn.Module):
    """Halve spatial resolution (stride-2 3x3 conv) and change channel
    count from channels_in to channels_out."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.net = nn.Sequential(
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_in, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels_in),
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        )

    def forward(self, x):
        return self.net(x)
class GridnetUpsample(nn.Module):
    """Double spatial resolution (nearest-neighbor upsample) and change
    channel count from channels_in to channels_out."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
            nn.PReLU(channels_out),
            nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        )

    def forward(self, x):
        return self.net(x)
class GridnetTotalDropout(nn.Module):
    """Drops entire samples (whole feature maps) with probability p, and
    rescales survivors by 1/(1-p) — inverted dropout per batch item."""

    def __init__(self, p):
        super().__init__()
        assert 0 <= p < 1
        self.p = p
        self.weight = 1 / (1 - p)

    def get_drop(self, x):
        # per-sample keep mask, broadcast over (ch, h, w), scaled by 1/(1-p)
        keep = 1 - (torch.rand(len(x))[:, None, None, None] < self.p).float()
        return keep.to(x.device) * self.weight

    def forward(self, x, force_drop=None):
        # force_drop is True / is False overrides; anything else (incl. the
        # default None) defers to self.training
        if force_drop is True:
            return x * self.get_drop(x)
        if force_drop is False:
            return x
        return x * self.get_drop(x) if self.training else x
class Interpolator(nn.Module):
    """Resize (bs, ch, h, w) or (bs, k, ch, h, w) tensors to a fixed (H, W).

    When `is_flow` is set, flow values are additionally rescaled by the
    per-axis size ratio so vectors stay consistent with the resized grid.
    """

    def __init__(self, size, mode="bilinear"):
        super().__init__()
        self.size = size  # target (H, W) pair
        self.mode = mode
        return

    def forward(self, x, is_flow=False):
        # Early-out when already at the target size.
        # BUGFIX: the original compared x.shape[-2] (an int) against the
        # (H, W) tuple, which never matched, so same-size inputs were
        # always redundantly re-interpolated.
        if tuple(x.shape[-2:]) == tuple(self.size):
            return x
        if len(x.shape) == 4:
            # bs,ch,h,w
            bs, ch, h, w = x.shape
            ans = nn.functional.interpolate(
                x,
                size=self.size,
                mode=self.mode,
                # align_corners must be None for "nearest" mode
                align_corners=(False, None)[self.mode == "nearest"],
            )
            if is_flow:
                # scale flow magnitudes by (new / old) per spatial axis
                ans = (
                    ans
                    * torch.tensor(
                        [b / a for a, b in zip((h, w), self.size)],
                        device=ans.device,
                    )[None, :, None, None]
                )
            return ans
        elif len(x.shape) == 5:
            # bs,k,ch,h,w: merge bs and k, resize, then split back
            bs, k, ch, h, w = x.shape
            return self.forward(
                x.view(bs * k, ch, h, w),
                is_flow=is_flow,
            ).view(bs, k, ch, *self.size)
        else:
            assert 0
###################### CANNY ######################


def canny(img, a=100, b=200):
    # Canny edge map with fixed low/high thresholds; `I` is the project's
    # image wrapper class.
    img = I(img).convert("L")
    return I(cv2.Canny(img.cv2(), a, b))
# https://www.pyimagesearch.com/2015/04/06/zero-parameter-automatic-canny-edge-detection-with-python-and-opencv/
def canny_pis(img, sigma=0.33):
    """Auto-thresholded Canny: thresholds at (1 +/- sigma) * median."""
    # compute the median of the single channel pixel intensities
    img = I(img).convert("L").uint8(ch_last=False)
    v = np.median(img)
    # apply automatic Canny edge detection using the computed median
    lower = int(max(0, (1.0 - sigma) * v))
    upper = int(min(255, (1.0 + sigma) * v))
    edged = cv2.Canny(img[0], lower, upper)
    # return the edged image
    return I(edged)
# https://en.wikipedia.org/wiki/Otsu%27s_method
def canny_otsu(img):
    # Canny with the high threshold selected by Otsu's method; the low
    # threshold is half of it.
    img = I(img).convert("L").uint8(ch_last=False)
    high, _ = cv2.threshold(img[0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    low = 0.5 * high
    return I(cv2.Canny(img[0], low, high))
def xdog(img, t=1.0, epsilon=0.04, phi=100, sigma=3, k=1.6):
    """Extended difference-of-gaussians (XDoG) stylization.

    Blurs at sigma and sigma*k, takes the scaled difference, and softens
    below-epsilon regions with a tanh ramp.  Returns a float numpy array.
    """
    img = I(img).convert("L").uint8(ch_last=False)
    grey = np.asarray(img, dtype=np.float32)
    g0 = scipy.ndimage.gaussian_filter(grey, sigma)
    g1 = scipy.ndimage.gaussian_filter(grey, sigma * k)

    # ans = ((1+p) * g0 - p * g1) / 255
    ans = (g0 - t * g1) / 255
    # tanh ramp is only applied where the response is below epsilon
    ans = 1 + np.tanh(phi * (ans - epsilon)) * (ans < epsilon)
    return ans
def dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
    """Difference-of-gaussians edge response of a single image.

    Returns a numpy array (1, H, W): 0.5 + t*(g1 - g0) - epsilon, clipped
    to [0, 1] unless `clip` is False.  Kernel sizes are odd, at least 3,
    and scale with sigma via kernel_factor.
    """
    img = I(img).convert("L").tensor()[None]
    kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
    kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
    g0 = kornia.filters.gaussian_blur2d(
        img,
        (kern0, kern0),
        (sigma, sigma),
        border_type="replicate",
    )
    g1 = kornia.filters.gaussian_blur2d(
        img,
        (kern1, kern1),
        (sigma * k, sigma * k),
        border_type="replicate",
    )
    ans = 0.5 + t * (g1 - g0) - epsilon
    ans = ans.clip(0, 1) if clip else ans
    return ans[0].numpy()
# input: (bs,rgb(a),h,w) or (bs,1,h,w)
# returns: (bs,1,h,w)
def batch_dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
    """Batched difference-of-gaussians: grayscale if needed, blur at sigma
    and sigma*k, return 0.5 + t*(g1 - g0) - epsilon (optionally clipped)."""
    # to grayscale if needed
    bs, ch, h, w = img.shape
    if ch in [3, 4]:
        # drop alpha (if any) before the grayscale conversion
        img = kornia.color.rgb_to_grayscale(img[:, :3])
    else:
        assert ch == 1

    # calculate dog
    kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
    kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
    g0 = kornia.filters.gaussian_blur2d(
        img,
        (kern0, kern0),
        (sigma, sigma),
        border_type="replicate",
    )
    g1 = kornia.filters.gaussian_blur2d(
        img,
        (kern1, kern1),
        (sigma * k, sigma * k),
        border_type="replicate",
    )
    ans = 0.5 + t * (g1 - g0) - epsilon
    ans = ans.clip(0, 1) if clip else ans
    return ans
############### DERIVED DISTANCES ###############

# input: (bs,h,w) or (bs,1,h,w)
# returns: (bs,)
# normalized s.t. metric is same across proportional image scales


# average of two asymmetric distances
# normalized by diameter and area
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    """Symmetric chamfer distance: the mean of the two directed terms."""
    d_gt_to_pred = batch_chamfer_distance_t(gt, pred, block=block)
    d_pred_to_gt = batch_chamfer_distance_p(gt, pred, block=block)
    return (d_gt_to_pred + d_pred_to_gt) / 2
def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    """Directed chamfer term: mean over gt pixels of the EDT of pred,
    normalized by the image diagonal.  `return_more` is accepted but
    unused.  Returns shape (bs,)."""
    assert gt.device == pred.device and gt.shape == pred.shape
    bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dpred = batch_edt(pred, block=block)
    cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if len(cd.shape) == 2:
        # squeeze a singleton channel axis when input was (bs,1,h,w)
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd
def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    """Directed chamfer term: mean over pred pixels of the EDT of gt,
    normalized by the image diagonal (mirror of the _t variant).
    `return_more` is accepted but unused.  Returns shape (bs,)."""
    assert gt.device == pred.device and gt.shape == pred.shape
    bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if len(cd.shape) == 2:
        # squeeze a singleton channel axis when input was (bs,1,h,w)
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd
# normalized by diameter
# always between [0,1]
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    """Symmetric Hausdorff distance between binary maps via their EDTs,
    normalized by the image diagonal.  `return_more` is accepted but
    unused.  Returns shape (bs,)."""
    assert gt.device == pred.device and gt.shape == pred.shape
    bs, h, w = gt.shape[0], gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    dpred = batch_edt(pred, block=block)
    # max over both directions of the farthest on-pixel from the other map
    hd = torch.stack(
        [
            (dgt * pred).amax(dim=(-2, -1)),
            (dpred * gt).amax(dim=(-2, -1)),
        ]
    ).amax(dim=0).float() / np.sqrt(h**2 + w**2)
    if len(hd.shape) == 2:
        # squeeze a singleton channel axis when input was (bs,1,h,w)
        assert hd.shape[1] == 1
        hd = hd.squeeze(1)
    return hd
#################### UTILITIES ####################


def reset_parameters(model):
    """Re-initialize the weights of model's immediate children.

    NOTE: iterates .children() only — parameters of deeper, nested
    submodules are left untouched.
    """
    for child in model.children():
        reset = getattr(child, "reset_parameters", None)
        if reset is not None:
            reset()
    return model
def channel_squeeze(x, dim=1):
    """Merge dimensions dim and dim+1 of x into a single dimension."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 2 :]
    return x.reshape(*lead, -1, *trail)
def channel_unsqueeze(x, shape, dim=1):
    """Split dimension dim of x into the given shape (inverse of
    channel_squeeze)."""
    lead = x.shape[:dim]
    trail = x.shape[dim + 1 :]
    return x.reshape(*lead, *shape, *trail)
def default_collate(items, device=None):
    """Collate samples with torch's default collator, then move any
    tensors in the result to `device` (no-op when device is None)."""
    batch = torch.utils.data.dataloader.default_collate(items)
    return to(dict(batch), device)
def to(x, device):
    """Move x to `device`: dicts have their tensor values moved (other
    values pass through), tensors are moved directly, numpy arrays are
    converted to tensors first.  device=None returns x unchanged."""
    if device is None:
        return x
    if issubclass(x.__class__, dict):
        moved = {}
        for key, val in x.items():
            moved[key] = val.to(device) if isinstance(val, torch.Tensor) else val
        return moved
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, np.ndarray):
        return torch.tensor(x).to(device)
    assert 0, "data not understood"
################ PARSING ################

from argparse import Namespace

# args: all args
# bargs: base args
# pargs: data processing args
# largs: data loading args
# margs: model args
# targs: training args


# typically used to read dataset filters
def read_filter(fn, cast=None, sort=True, sort_key=None):
    """Read non-empty lines from fn, optionally casting each through
    `cast` and sorting the result with `sort_key`."""
    caster = cast if cast is not None else (lambda value: value)
    entries = [caster(line) for line in read(fn).split("\n") if line != ""]
    return sorted(entries, key=sort_key) if sort else entries
################ FILE MANAGEMENT ################


def mkfile(fn, parents=True, exist_ok=True):
    # Ensure the parent directory of fn exists, then return fn unchanged —
    # convenience for open(mkfile(path), "w")-style usage.  Paths are
    # "/"-separated throughout this module.
    dn = "/".join(fn.split("/")[:-1])
    mkdir(dn, parents=parents, exist_ok=exist_ok)
    return fn
def mkdir(dn, parents=True, exist_ok=True):
    """Create directory dn (with parents) and return it, minus any single
    trailing slash (the root "/" is returned as-is).

    BUGFIX: the original indexed dn[-1] unconditionally, raising
    IndexError for an empty dirname — e.g. via mkfile("name.png") when
    the path has no directory component.
    """
    pathlib.Path(dn).mkdir(parents=parents, exist_ok=exist_ok)
    if dn != "/" and dn.endswith("/"):
        return dn[:-1]
    return dn
def fstrip(fn, return_more=False):
    """Decompose a "/"-separated path into directory, filename, basename,
    and extension.

    Returns just the basename unless `return_more` is set, in which case
    a Namespace with all the parts is returned.  A path with no "/" gets
    directory "."; a filename with no "." gets extension "".
    """
    if "/" in fn:
        dn, fn = fn.rsplit("/", 1)
    else:
        dn = "."
    if "." in fn:
        bn, ext = fn.rsplit(".", 1)
    else:
        bn, ext = fn, ""
    if not return_more:
        return bn
    return Namespace(
        dn=dn,
        fn=fn,
        path=f"{dn}/{fn}",
        bn_path=f"{dn}/{bn}",
        bn=bn,
        ext=ext,
    )
def read(fn, mode="r"):
    """Return the full contents of file fn (text by default; pass "rb"
    for bytes)."""
    with open(fn, mode) as handle:
        data = handle.read()
    return data
def write(text, fn, mode="w"):
    """Write text to fn (creating parent directories first); returns the
    number of characters written."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as handle:
        count = handle.write(text)
    return count
import pickle


def dump(obj, fn, mode="wb"):
    """Pickle obj to fn, creating parent directories first."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as handle:
        result = pickle.dump(obj, handle)
    return result
def load(fn, mode="rb"):
    """Unpickle and return the object stored at fn.

    NOTE(review): pickle.load can execute arbitrary code — only use on
    trusted files.
    """
    with open(fn, mode) as handle:
        obj = pickle.load(handle)
    return obj
import json


def jwrite(x, fn, mode="w", indent="\t", sort_keys=False):
    """Serialize x as JSON to fn, creating parent directories first."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as handle:
        result = json.dump(x, handle, indent=indent, sort_keys=sort_keys)
    return result
def jread(fn, mode="r"):
    """Load and return the JSON document stored at fn."""
    with open(fn, mode) as handle:
        doc = json.load(handle)
    return doc
try:
    import yaml

    def ywrite(x, fn, mode="w", default_flow_style=False):
        """Serialize x as YAML to fn, creating parent directories first."""
        mkfile(fn, parents=True, exist_ok=True)
        with open(fn, mode) as handle:
            return yaml.dump(x, handle, default_flow_style=default_flow_style)

    def yread(fn, mode="r"):
        """Load and return the YAML document at fn (uses safe_load)."""
        with open(fn, mode) as handle:
            return yaml.safe_load(handle)

# yaml is optional: silently skip the helpers if it is missing/broken.
# BUGFIX: narrowed from a bare `except:`, which also swallowed
# KeyboardInterrupt and SystemExit.
except Exception:
    pass
# pyunpack is optional; ignore a missing or broken install.
# BUGFIX: narrowed from a bare `except:` (which caught KeyboardInterrupt too).
try:
    import pyunpack
except Exception:
    pass
# mysql-connector is optional; ignore a missing or broken install.
# BUGFIX: narrowed from a bare `except:` (which caught KeyboardInterrupt too).
try:
    import mysql
    import mysql.connector
except Exception:
    pass
################ MISC ################

# Default sample-image path used elsewhere; falls back to the alternate
# env directory when the first location does not exist.
hakase = "./env/__hakase__.jpg"
if not os.path.isfile(hakase):
    hakase = "./__env__/__hakase__.jpg"
def mem(units="m"):
    """Resident set size of the current process, scaled by the first
    letter of `units` (b/k/m/g/t, case-insensitive; decimal multiples)."""
    divisor = {
        "b": 1,
        "k": 1e3,
        "m": 1e6,
        "g": 1e9,
        "t": 1e12,
    }[units[0].lower()]
    return psProcess(os.getpid()).memory_info().rss / divisor
def chunk(array, length, colwise=True):
    """Split array into consecutive slices.

    colwise=True: pieces of size `length` (last may be shorter).
    colwise=False: approximately `length` pieces of equal size.
    """
    if not colwise:
        return chunk(array, int(math.ceil(len(array) / length)), colwise=True)
    return [array[start : start + length] for start in range(0, len(array), length)]
def classtree(x):
    """Nested class-hierarchy structure for x's MRO (see
    inspect.getclasstree)."""
    mro = inspect.getmro(x)
    return inspect.getclasstree(mro)
################ AESTHETIC ################


class Table:
    """Plain-text table renderer with a small per-cell spec language.

    Cells are (value, spec-string) pairs or "value::spec" strings; the
    spec controls justification (l/r, t/b), delimiter expansion
    (./^/v/</>), subtables (+/-/|), fill (_), and an f-string format
    suffix after ":".  Specs are parsed into 12-tuples by Table._spec
    (see its index comments).  `_spec`, `_put`, and `parse` are helper
    functions used via the class (Table._spec(...)), not via instances.
    """

    def __init__(
        self,
        table,
        delimiter=" ",
        orientation="br",
        double_colon=True,
    ):
        self.delimiter = delimiter
        self.orientation = orientation
        self.t = Table.parse(table, delimiter, orientation, double_colon)
        return

    # rendering
    def __str__(self):
        return self.render()

    def __repr__(self):
        return self.render()

    def render(self):
        # set up empty entry
        empty = ("", Table._spec(self.orientation, transpose=False))

        # calculate table size
        t = copy.deepcopy(self.t)
        totalrows = len(t)
        totalcols = [len(r) for r in t]
        assert min(totalcols) == max(totalcols)
        totalcols = totalcols[0]

        # string-ify
        for i in range(totalrows):
            for j in range(totalcols):
                x, s = t[i][j]
                sp = s[11]
                if sp:
                    # NOTE(review): applies the ":" format suffix by
                    # eval-ing an f-string — cell values/specs must be
                    # trusted input.
                    x = eval(f'f"{{{x}{sp}}}"')
                Table._put((str(x), s), t, (i, j), empty)

        # expand delimiters
        # _repl rewrites a spec for a cell overwritten by delimiter
        # expansion: clears the directional bits and sets the fill bit
        _repl = (
            lambda s: s[:2] + (1, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:]
            if s[2]
            else s[:2] + (0, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:]
        )
        for i, row in enumerate(t):
            for j, (x, s_own) in enumerate(row):
                # expand delim_up(^)
                if s_own[3]:
                    u, v = i, j
                    while 0 <= u:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        u -= 1

                # expand delim_down(v)
                if s_own[4]:
                    u, v = i, j
                    while u < totalrows:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        u += 1

                # expand delim_right(>)
                if s_own[5]:
                    u, v = i, j
                    while v < totalcols:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        v += 1

                # expand delim_left(<)
                if s_own[6]:
                    u, v = i, j
                    while 0 <= v:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        v -= 1

        # justification calculation
        widths = [
            0,
        ] * totalcols  # j
        heights = [
            0,
        ] * totalrows  # i
        for i, row in enumerate(t):
            for j, (x, s) in enumerate(row):
                # height caclulation
                heights[i] = max(heights[i], x.count("\n"))

                # width calculation; non-delim fillers no contribution
                if s[2] or not s[10]:
                    w = max(len(q) for q in x.split("\n"))
                    widths[j] = max(widths[j], w)
        # no newline ==> height=1
        heights = [h + 1 for h in heights]

        # render table
        rend = []
        roff = 0
        for i, row in enumerate(t):
            for j, (x, s) in enumerate(row):
                w, h = widths[j], heights[i]

                # expand fillers and delimiters
                if s[2] or s[10]:
                    xs = x.split("\n")
                    xw0 = min(len(l) for l in xs)
                    xw1 = max(len(l) for l in xs)
                    xh = len(xs)
                    if (xw0 == xw1 == w) and (xh == h):
                        pass
                    elif xw0 == xw1 == w:
                        # right width, wrong height: repeat first line
                        x = "\n".join(
                            [
                                xs[0],
                            ]
                            * h
                        )
                    elif xh == h:
                        # right height, wrong width: tile each line's
                        # first character
                        x = "\n".join([(l[0] if l else "") * w for l in xs])
                    else:
                        # wrong both ways: tile first char into a w x h block
                        x = x[0] if x else " "
                        x = "\n".join(
                            [
                                x * w,
                            ]
                            * h
                        )

                # justify horizontally
                x = [l.rjust(w) if s[0] else l.ljust(w) for l in x.split("\n")]

                # justify vertically
                plus = [
                    " " * w,
                ] * (h - len(x))
                x = plus + x if not s[1] else x + plus

                # input to table
                for r, xline in enumerate(x):
                    Table._put(xline, rend, (roff + r, j), None)
            roff += h

        # return rendered string
        return "\n".join(["".join(r) for r in rend])

    # parsing
    def _spec(s, transpose=False):
        # Parse a spec string (everything after ":" is kept verbatim as an
        # f-string format suffix) into the 12-tuple documented below.
        if ":" in s:
            i = s.index(":")
            sp = s[i:]
            s = s[:i]
        else:
            sp = ""
        s = s.lower()
        return (
            int("r" in s),  # 0:: 0:left(l) 1:right(r)
            int("t" in s),  # 1:: 0:bottom(b) 1:top(t)
            int(any([i in s for i in [".", "<", ">", "^", "v"]])),  # 2:: delim_here(.)
            int("^" in s if not transpose else "<" in s),  # 3:: delim_up(^)
            int("v" in s if not transpose else ">" in s),  # 4:: delim_down(v)
            int(">" in s if not transpose else "v" in s),  # 5:: delim_right(>)
            int("<" in s if not transpose else "^" in s),  # 6:: delim_left(<)
            int("+" in s),  # 7:: subtable(+)
            int("-" in s if not transpose else "|" in s),  # 8:: subtable_horiz(-)
            int("|" in s if not transpose else "-" in s),  # 9:: subtable_vert(|)
            int("_" in s),  # 10:: fill(_); if delim, overwrite; else fit
            sp,  # 11:: special(:) f-string for numbers
        )

    def _put(obj, t, ij, empty):
        # Place obj at position ij in the nested list t, growing rows and
        # padding short rows with `empty` as needed.
        i, j = ij
        while i >= len(t):
            t.append([])
        while j >= len(t[i]):
            t[i].append(empty)
        t[i][j] = obj
        return

    def parse(
        table,
        delimiter=" ",
        orientation="br",
        double_colon=True,
    ):
        # Normalize raw cell data into a rectangular grid of
        # (string, spec-tuple) pairs with delimiter columns interleaved,
        # and expand any subtable cells in place.
        # disabling transpose
        transpose = False

        # set up empty entry
        empty = ("", Table._spec(orientation, transpose))

        # transpose
        t = []
        for i, row in enumerate(table):
            for j, item in enumerate(row):
                ij = (i, j) if not transpose else (j, i)
                if type(item) == tuple and len(item) == 2 and type(item[1]) == str:
                    item = (item[0], Table._spec(item[1], transpose))
                elif double_colon and type(item) == str and "::" in item:
                    x, s = item.split("::")
                    item = (x, Table._spec(s, transpose))
                else:
                    item = (item, Table._spec(orientation, transpose))
                Table._put(item, t, ij, empty)

        # normalization
        maxcol = 0
        maxrow = len(t)
        for i, row in enumerate(t):
            # take element number into account
            maxcol = max(maxcol, len([i for i in row if not i[1][2]]))

            # take subtables into account
            for j, (x, s) in enumerate(row):
                if s[7]:
                    r = len(x)
                    maxrow = max(maxrow, i + r)
                    c = max(len(q) for q in x)
                    maxcol = max(maxcol, j + c)
                elif s[8]:
                    c = len(x)
                    maxcol = max(maxcol, j + c)
                elif s[9]:
                    r = len(x)
                    maxrow = max(maxrow, i + r)
        totalcols = 2 * maxcol + 1
        totalrows = maxrow
        t += [[]] * (totalrows - len(t))
        newt = []
        delim = (delimiter, Table._spec("._" + orientation, transpose))
        # interleave delimiter columns between content columns
        for i, row in enumerate(t):
            wasd = False
            tcount = 0
            for j in range(totalcols):
                item = t[i][tcount] if tcount < len(t[i]) else empty
                isd = item[1][2]
                if wasd and isd:
                    Table._put(empty, newt, (i, j), empty)
                    wasd = False
                elif wasd and not isd:
                    Table._put(item, newt, (i, j), empty)
                    tcount += 1
                    wasd = False
                elif not wasd and isd:
                    Table._put(item, newt, (i, j), empty)
                    tcount += 1
                    wasd = True
                elif not wasd and not isd:
                    Table._put(delim, newt, (i, j), empty)
                    wasd = True
        t = newt

        # normalization: add dummy last column for delimiter
        for row in t:
            row.append(empty)

        # expand subtables
        delim_cols = [i for i in range(totalcols) if i % 2 == 0]
        while True:
            # find a table
            ij = None
            for i, row in enumerate(t):
                for j, item in enumerate(row):
                    st, s = item
                    if s[7]:
                        ij = i, j, 7, st, s
                        break
                    elif s[8]:
                        ij = i, j, 8, st, s
                        break
                    elif s[9]:
                        ij = i, j, 9, st, s
                        break
                if ij is not None:
                    break
            if ij is None:
                break

            # replace its specs
            i, j, k, st, s = ij
            s = list(s)
            s[7] = s[8] = s[9] = 0
            s = tuple(s)

            # expand it
            if k == 7:  # 2d table
                for x, row in enumerate(st):
                    for y, obj in enumerate(row):
                        a = i + x if not transpose else i + y
                        b = j + 2 * y if not transpose else j + 2 * x
                        Table._put((obj, s), t, (a, b), None)
            if k == 8:  # subtable_horiz
                for y, obj in enumerate(st):
                    Table._put((obj, s), t, (i, j + 2 * y), None)
            if k == 9:  # subtable_vert
                for x, obj in enumerate(st):
                    Table._put((obj, s), t, (i + x, j), None)

        # return, finally
        return t
class Resnet(nn.Module):
    """Residual block: x + f(x), where f is PReLU -> 3x3 conv -> BatchNorm
    repeated twice at a constant channel count."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        stages = []
        for _ in range(2):
            stages.extend(
                [
                    nn.PReLU(channels),
                    nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(channels),
                ]
            )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
class Synthesizer(nn.Module):
    """Fuse warped images, flows, masks, and features into an RGB frame.

    Everything is resized to `size`, concatenated, run through a small
    conv/resnet stack, and the predicted residual is added (in logit
    space) to the average of the first two input images.
    """

    def __init__(
        self, size, channels_image, channels_flow, channels_mask, channels_feature
    ):
        super().__init__()
        self.size = size
        # image diagonal via `diam` (defined elsewhere in this file),
        # used to normalize flow magnitudes
        self.diam = diam(self.size)
        self.channels_image = cimg = channels_image
        self.channels_flow = cflow = channels_flow
        self.channels_mask = cmask = channels_mask
        self.channels_feature = cfeat = channels_feature
        # cflow // 2 because each 2-channel flow is reduced to a single
        # magnitude channel in forward()
        self.channels = ch = cimg + cflow // 2 + cmask + cfeat
        self.interpolator = Interpolator(self.size, mode="bilinear")
        # ch + 3 input channels: the extra 3 are the averaged image
        # prepended in forward()
        self.net = nn.Sequential(
            nn.Conv2d(ch + 3, 64, kernel_size=1, padding=0),
            Resnet(64),
            nn.Sequential(
                nn.PReLU(64),
                nn.Conv2d(64, 32, kernel_size=3, padding=1),
                nn.BatchNorm2d(32),
            ),
            Resnet(32),
            nn.Sequential(
                nn.PReLU(32),
                nn.Conv2d(32, 16, kernel_size=3, padding=1),
                nn.BatchNorm2d(16),
            ),
            Resnet(16),
            nn.Sequential(
                nn.PReLU(16),
                nn.Conv2d(16, 3, kernel_size=3, padding=1),
            ),
        )
        return

    def forward(self, images, flows, masks, features, return_more=False):
        itp = self.interpolator
        # prepend the mean of the first two images; its logit becomes the
        # base onto which the residual is added
        images = [
            (images[0] + images[1]) / 2,
        ] + images
        logimgs = [itp(pixel_logit(i[:, :3])) for i in images]
        cat = torch.cat(
            [
                *logimgs,
                # flows enter as per-pixel magnitudes, normalized by the
                # image diagonal
                *[itp(f).norm(dim=1, keepdim=True) / self.diam for f in flows],
                *[itp(m) for m in masks],
                *[itp(f) for f in features],
            ],
            dim=1,
        )
        residual = self.net(cat)
        return torch.sigmoid(logimgs[0] + 0.5 * residual), (
            locals() if return_more else None
        )
class FlowZMetric(nn.Module):
    """Per-pixel z (confidence) maps for splatting: the LAB-space backward-
    warping error, negated and scaled by 0.1 so that better matches get
    higher z."""

    def __init__(self):
        super().__init__()
        return

    def forward(self, img0, img1, flow0, flow1, return_more=False):
        # B(i0,f0) = i1
        # B(i1,f1) = i0
        # F(x,f0,z0)
        # F(x,f1,z1)
        img0 = kornia.color.rgb_to_lab(img0[:, :3])
        img1 = kornia.color.rgb_to_lab(img1[:, :3])
        return [
            -0.1 * (img1 - flow_backwarp(img0, flow0)).norm(dim=1, keepdim=True),  # z0
            -0.1 * (img0 - flow_backwarp(img1, flow1)).norm(dim=1, keepdim=True),  # z1
        ], (locals() if return_more else None)
class NEDT(nn.Module):
    """Normalized euclidean distance transform of an image's DoG edges:
    1 - exp(-edt * scale), rising from 0 at an edge toward 1 far from any
    edge.  Computed under no_grad (not differentiable)."""

    def __init__(self):
        super().__init__()
        return

    def forward(
        self,
        img,
        t=2.0,
        sigma_factor=1 / 540,
        k=1.6,
        epsilon=0.01,
        kernel_factor=4,
        exp_factor=540 / 15,
        return_more=False,
    ):
        with torch.no_grad():
            # sigma scales with image height so the edge response is
            # resolution-independent
            dog = batch_dog(
                img,
                t=t,
                sigma=img.shape[-2] * sigma_factor,
                k=k,
                epsilon=epsilon,
                kernel_factor=kernel_factor,
                clip=False,
            )
            # binarize the DoG response, then distance-transform it
            edt = batch_edt((dog > 0.5).float())
            ans = 1 - (-edt * exp_factor / max(edt.shape[-2:])).exp()
        return ans, (locals() if return_more else None)
class HalfWarper(nn.Module):
    """Forward-warp both endpoint frames to an intermediate time t and
    combine them, producing candidate images, the time-scaled flows, and
    (morphologically opened) validity masks."""

    def __init__(self):
        super().__init__()
        # channel counts of the outputs returned by forward()
        self.channels_image = 4 * 3
        self.channels_flow = 2 * 2
        self.channels_mask = 2 * 1
        self.channels = self.channels_image + self.channels_flow + self.channels_mask

    def morph_open(self, x, k):
        # morphological opening with a k x k kernel (k == 0: no-op);
        # used to clean up the splatting validity masks
        if k == 0:
            return x
        else:
            with torch.no_grad():
                return kornia.morphology.opening(x, torch.ones(k, k, device=x.device))

    def forward(self, img0, img1, flow0, flow1, z0, z1, k, t=0.5, return_more=False):
        # forewarps
        # scale each flow to the intermediate time t; `forewarp` is
        # defined elsewhere in this file (mode="sm" with metric z —
        # presumably softmax splatting; TODO confirm)
        flow0_ = (1 - t) * flow0
        flow1_ = t * flow1
        f01 = forewarp(img0, flow1_, mode="sm", metric=z1, mask=True)
        f10 = forewarp(img1, flow0_, mode="sm", metric=z0, mask=True)
        # split warped result into image channels and the trailing mask
        f01i, f01m = f01[:, :-1], self.morph_open(f01[:, -1:], k=k)
        f10i, f10m = f10[:, :-1], self.morph_open(f10[:, -1:], k=k)

        # base guess
        # fill each warp's holes with the other warp's pixels
        base0 = f01m * f01i + (1 - f01m) * f10i
        base1 = f10m * f10i + (1 - f10m) * f01i
        ans = [
            [  # images
                base0,
                base1,
                f01i,
                f10i,
            ],
            [  # flows
                flow0_,
                flow1_,
            ],
            [  # masks
                f01m,
                f10m,
            ],
        ]
        return ans, (locals() if return_more else None)
class ResnetFeatureExtractor(nn.Module):
    """Truncated ResNet-50 feature pyramid (conv1 / layer1 / layer2 outputs).

    The backbone is either a torchvision-pretrained resnet50 or a model loaded
    through the project's ``userving`` inference-serving helper, depending on
    ``inferserve_query[0]``.

    NOTE(review): ``forward`` returns ``locals()`` when ``return_more`` is
    True, so local names here may be relied upon by callers.
    """

    def __init__(self, inferserve_query, size_in=None):
        # inferserve_query: ("torchvision", ...) selects the pretrained
        #     torchvision backbone; anything else is forwarded to
        #     userving.infer_model_load(*inferserve_query).
        # size_in: optional (H, W) of the expected input, used to precompute
        #     the output sizes of each pyramid level.
        super().__init__()
        self.inferserve_query = iq = inferserve_query
        self.size_in = si = size_in
        if iq[0] == "torchvision":
            # use pytorch pretrained resnet50
            self.base_hparams = None
            resnet = tv.models.resnet50(pretrained=True)

            self.resize = T.Resize(256)
            # Standard ImageNet normalization for torchvision models.
            self.resnet_preprocess = T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
            self.conv1 = resnet.conv1
            self.bn1 = resnet.bn1
            self.relu = resnet.relu  # 64ch, 128p (assuming 256p input)
            self.maxpool = resnet.maxpool
            self.layer1 = resnet.layer1  # 256ch, 64p
            self.layer2 = resnet.layer2  # 512ch, 32p
        else:
            # Project-served backbone; mirrors the torchvision branch layout.
            base = userving.infer_model_load(*iq).eval()
            self.base_hparams = base.hparams

            self.resize = T.Resize(base.hparams.largs.size)
            self.resnet_preprocess = base.resnet_preprocess
            self.conv1 = base.resnet.conv1
            self.bn1 = base.resnet.bn1
            self.relu = base.resnet.relu  # 64ch, 128p (assuming 256p input)
            self.maxpool = base.resnet.maxpool
            self.layer1 = base.resnet.layer1  # 256ch, 64p
            self.layer2 = base.resnet.layer2  # 512ch, 32p
        if self.size_in is None:
            # Sizes will be captured lazily from the first forward pass.
            self.sizes_out = None
        else:
            # Precompute per-level output sizes: resize target s, then the
            # backbone halves resolution at strides 2 / 4 / 8.
            s = self.resize.size
            self.sizes_out = [
                pixel_ij(
                    rescale_dry(si, (s // 2) / si[0]), rounding="ceil"
                ),  # conv1, 128p
                pixel_ij(
                    rescale_dry(si, (s // 4) / si[0]), rounding="ceil"
                ),  # layer1, 64p
                pixel_ij(
                    rescale_dry(si, (s // 8) / si[0]), rounding="ceil"
                ),  # layer2, 32p
            ]
        # Channel counts of the three returned feature maps.
        self.channels = [
            64,
            256,
            512,
        ]
        return

    def forward(self, x, force_sizes_out=False, return_more=False):
        """Return [conv1, layer1, layer2] feature maps for ``x`` (RGB only).

        Args:
            x: input batch; only the first 3 channels are used.
            force_sizes_out: re-measure ``self.sizes_out`` from this pass.
            return_more: when True, also return ``locals()`` for debugging.
        """
        ans = []
        x = x[:, :3]  # drop any extra channels (e.g. an appended NEDT map)
        x = self.resize(x)
        x = self.resnet_preprocess(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        ans.append(x)  # conv1
        x = self.maxpool(x)
        x = self.layer1(x)
        ans.append(x)  # layer1
        x = self.layer2(x)
        ans.append(x)  # layer2
        if force_sizes_out or (self.sizes_out is None):
            # Cache the actual spatial sizes of each level (side effect).
            self.sizes_out = [tuple(q.shape[-2:]) for q in ans]
        return ans, (locals() if return_more else None)
2384
+
2385
+
2386
class NetNedt(nn.Module):
    """Small CNN predicting a single-channel map in (0, 1) for the base output.

    Conditions on the base synthesis, its NEDT map, and the half-warped
    images/masks. Operates in logit space (``pixel_logit``) and squashes the
    result back with a sigmoid.

    NOTE(review): ``forward`` returns ``locals()`` when ``return_more`` is
    True, so local names here may be relied upon by callers.
    """

    def __init__(self):
        super().__init__()
        # Input channels: out_base (3) + its NEDT (1) + two half-warped
        # 4-channel images + two 1-channel coverage masks.
        chin = 3 + 1 + 4 + 4 + 1 + 1
        ch = 16
        chout = 1
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, chout, kernel_size=3, padding=1),
        )
        return

    def forward(self, out_base, out_base_nedt, hw_imgs, hw_masks, return_more=False):
        """Predict the refined map from the base output and warp by-products."""
        cat = torch.cat(
            [
                out_base,  # 3
                out_base_nedt,  # 1
                hw_imgs[0],  # 4
                hw_imgs[1],  # 4
                hw_masks[0],  # 1
                hw_masks[1],  # 1
            ],
            dim=1,
        )
        # Clip to [0, 1] first so pixel_logit receives valid probabilities.
        log = pixel_logit(cat.clip(0, 1))
        ans = torch.sigmoid(self.net(log))
        return ans, (locals() if return_more else None)
2419
+
2420
+
2421
class NetTail(nn.Module):
    """Residual RGB refinement head.

    Works in logit space and adds a learned correction to the base RGB before
    squashing back through a sigmoid, conditioned on the measured and
    predicted NEDT maps.

    NOTE(review): ``forward`` returns ``locals()`` when ``return_more`` is
    True, so local names here may be relied upon by callers.
    """

    def __init__(self):
        super().__init__()
        # Input channels: base RGB (3) + measured NEDT (1) + predicted NEDT (1).
        chin = 3 + 1 + 1
        ch = 16
        chout = 3
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, chout, kernel_size=3, padding=1),
        )
        return

    def forward(self, out_base, out_base_nedt, pred_nedt, return_more=False):
        """Refine ``out_base`` RGB; returns a tensor in (0, 1)."""
        cat = torch.cat(
            [
                out_base,  # 3
                out_base_nedt,  # 1
                pred_nedt,  # 1
            ],
            dim=1,
        )
        log = pixel_logit(cat.clip(0, 1))
        # Residual in logit space: the net only corrects the base RGB logits.
        ans = torch.sigmoid(log[:, :3] + self.net(log))
        return ans, (locals() if return_more else None)
2454
+
2455
+
2456
class SoftsplatLite(nn.Module):
    """EISAI base interpolator: softmax-splatting warp + GridNet synthesis.

    Pipeline: compute z-metrics, half-warp the images (with appended NEDT
    channel), half-warp ResNet features at three scales, fuse the warped
    features through a GridNet, and synthesize the output frame.

    NOTE(review): ``forward`` returns ``locals()`` when ``return_more`` is
    True; ``DTM`` downstream reads ``hw_imgs``/``hw_masks`` from that dict by
    name — do not rename locals here.
    """

    def __init__(self):
        super().__init__()
        # Backbone assumes 540x960 inputs for its precomputed level sizes.
        self.feature_extractor = ResnetFeatureExtractor(
            ("torchvision", "resnet50"),
            (540, 960),
        )
        self.z_metric = FlowZMetric()
        # One bilinear resizer per backbone level, to bring the full-res
        # flows/metrics down to each feature map's size.
        # NOTE: plain list (not nn.ModuleList) — these have no parameters.
        self.flow_downsamplers = [
            Interpolator(s, mode="bilinear") for s in self.feature_extractor.sizes_out
        ]
        self.gridnet_converter = GridnetConverter(
            self.feature_extractor.channels,
            [32, 64, 128],
        )
        self.gridnet = Gridnet(
            *[32, 64, 128],
            total_dropout_p=0.0,
            depth=1,  # equivalent to u-net
        )
        self.nedt = NEDT()
        self.half_warper = HalfWarper()
        self.synthesizer = Synthesizer(
            (540, 960),
            self.half_warper.channels_image,
            self.half_warper.channels_flow,
            self.half_warper.channels_mask,
            self.gridnet.channels_0,
        )
        return

    def forward(self, x, t=0.5, k=5, return_more=False):
        """Interpolate a frame at time ``t``.

        Args:
            x: dict with "images" (endpoint frames along dim 1) and "flows".
            t: interpolation time in [0, 1].
            k: mask-opening kernel size passed to the half-warper.
            return_more: when True, also return ``locals()`` for debugging.
        """
        rm = return_more
        flow0, flow1 = x["flows"].swapaxes(0, 1)
        img0, img1 = x["images"][:, 0], x["images"][:, -1]
        (z0, z1), locs_z = self.z_metric(img0, img1, flow0, flow1, return_more=rm)
        # Append the NEDT (edge distance transform) as a 4th channel.
        img0 = torch.cat([img0, self.nedt(img0)[0]], dim=1)
        img1 = torch.cat([img1, self.nedt(img1)[0]], dim=1)

        # images and flows
        (hw_imgs, hw_flows, hw_masks), locs_hw = self.half_warper(
            img0,
            img1,
            flow0,
            flow1,
            z0,
            z1,
            k,
            t=t,
            return_more=rm,
        )

        # features: warp each backbone level with appropriately resized flows,
        # then average the two directional warps.
        feats0, locs_fe0 = self.feature_extractor(img0, return_more=rm)
        feats1, locs_fe1 = self.feature_extractor(img1, return_more=rm)
        warps = []
        for ft0, ft1, ds in zip(feats0, feats1, self.flow_downsamplers):
            (w, _, _), _ = self.half_warper(
                ft0,
                ft1,
                ds(flow0, 1),
                ds(flow1, 1),
                ds(z0),
                ds(z1),
                k,
                t=t,
            )
            warps.append((w[0] + w[1]) / 2)
        feats = self.gridnet(self.gridnet_converter(warps))

        # synthesis
        pred, locs_synth = self.synthesizer(
            hw_imgs,
            hw_flows,
            hw_masks,
            [
                feats[0],  # finest GridNet level only
            ],
            return_more=rm,
        )
        return pred, (locals() if rm else None)
2537
+
2538
+
2539
class DTM(nn.Module):
    """Distance-transform-guided refinement head over a base interpolation.

    Measures the NEDT of the base output (no grad), predicts a refined NEDT,
    then refines the RGB with a tail network conditioned on both maps.
    """

    def __init__(self):
        super().__init__()
        self.net_nedt = NetNedt()
        self.net_tail = NetTail()
        self.nedt = NEDT()
        return

    def forward(self, x, out_base, locs_base, return_more=False):
        """Refine ``out_base``.

        Args:
            x: accepted for interface symmetry with the base model but unused
               in this method.
            out_base: RGB output of the base model (e.g. SoftsplatLite).
            locs_base: the base model's ``locals()`` dict; "hw_imgs" and
               "hw_masks" are read from it by name.
            return_more: when True, also return ``locals()`` for debugging.

        Returns:
            (refined RGB concatenated with predicted NEDT, locals() or None).
        """
        rm = return_more
        with torch.no_grad():
            # Measured NEDT of the base output is treated as a fixed input.
            out_base_nedt, locs_base_nedt = self.nedt(out_base, return_more=rm)
        hw_imgs, hw_masks = locs_base["hw_imgs"], locs_base["hw_masks"]
        pred_nedt, locs_nedt = self.net_nedt(
            out_base, out_base_nedt, hw_imgs, hw_masks, return_more=rm
        )
        # Detach so the tail's gradients don't flow into the NEDT predictor.
        pred, locs_tail = self.net_tail(
            out_base, out_base_nedt, pred_nedt.clone().detach(), return_more=rm
        )
        return torch.cat([pred, pred_nedt], dim=1), (locals() if rm else None)
2559
+
2560
+
2561
class RAFT(nn.Module):
    """RFR/RAFT optical-flow wrapper initialized from an AnimeInterp checkpoint.

    NOTE(review): ``forward`` returns ``locals()`` when ``return_more`` is
    True, so local names here may be relied upon by callers.
    """

    def __init__(self, path="/workspace/tensorrt/models/anime_interp_full.ckpt"):
        # path: checkpoint to load flow weights from; None skips loading.
        super().__init__()
        self.raft = RFR(
            Namespace(
                small=False,
                mixed_precision=False,
            )
        )
        if path is not None:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoint files.
            sd = torch.load(path)["model_state_dict"]
            # Keep only the flownet weights, stripping the "module.flownet."
            # DataParallel prefix; strict=False tolerates missing extras.
            self.raft.load_state_dict(
                {
                    k[len("module.flownet.") :]: v
                    for k, v in sd.items()
                    if k.startswith("module.flownet.")
                },
                strict=False,
            )
        return

    def forward(self, img0, img1, flow0=None, iters=12, return_more=False):
        """Estimate flow between ``img0`` and ``img1``.

        The flip on dim 1 reverses the two flow channels, converting between
        this module's channel convention and the underlying RFR model's
        (applied to the optional init flow on the way in and to the result on
        the way out). Note the images are passed to RFR in (img1, img0) order.
        """
        if flow0 is not None:
            flow0 = flow0.flip(dims=(1,))
        out = self.raft(img1, img0, iters=iters, flow_init=flow0)
        return out[0].flip(dims=(1,)), (locals() if return_more else None)
vfi_models/film/__init__.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from comfy.model_management import get_torch_device, soft_empty_cache
3
+ import bisect
4
+ import numpy as np
5
+ import typing
6
+ from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames
7
+ import pathlib
8
+ import gc
9
+
10
+ MODEL_TYPE = pathlib.Path(__file__).parent.name
11
+ DEVICE = get_torch_device()
12
def inference(model, img_batch_1, img_batch_2, inter_frames):
    """Generate ``inter_frames`` frames between two images via recursive midpointing.

    Repeatedly picks the pending timestamp closest to the midpoint of its
    surrounding already-placed frames, synthesizes it with ``model``, and
    inserts it in order. Returns all frames (endpoints included), each flipped
    along dim 0 (a no-op for batch-size-1 tensors).
    """
    frames = [
        img_batch_1,
        img_batch_2,
    ]
    placed = [0, inter_frames + 1]
    pending = list(range(1, inter_frames + 1))
    timestamps = torch.linspace(0, 1, inter_frames + 2)

    while pending:
        seg_starts = timestamps[placed[:-1]]
        seg_ends = timestamps[placed[1:]]
        # Relative position of each pending timestamp inside each placed
        # segment; pick the (segment, timestamp) pair closest to the middle.
        rel = (timestamps[None, pending] - seg_starts[:, None]) / (seg_ends[:, None] - seg_starts[:, None])
        flat_best = torch.argmin((rel - .5).abs()).item()
        seg_i, pend_i = np.unravel_index(flat_best, rel.shape)

        left = frames[seg_i].to(DEVICE)
        right = frames[seg_i + 1].to(DEVICE)
        # Fractional time of the new frame within its segment, as a (1, 1)
        # tensor on the same device/dtype as the inputs.
        dt = left.new_full((1, 1), (timestamps[pending[pend_i]] - timestamps[placed[seg_i]])) / (timestamps[placed[seg_i + 1]] - timestamps[placed[seg_i]])

        with torch.no_grad():
            midframe = model(left, right, dt)

        slot = bisect.bisect_left(placed, pending[pend_i])
        placed.insert(slot, pending[pend_i])
        frames.insert(slot, midframe.clamp(0, 1).float())
        pending.pop(pend_i)

    return [frame.flip(0) for frame in frames]
43
+
44
class FILM_VFI:
    """ComfyUI node running FILM (TorchScript) frame interpolation."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (["film_net_fp32.pt"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }
    
    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate between consecutive frames with the FILM model.

        Args:
            ckpt_name: TorchScript checkpoint to fetch/load.
            frames: input image batch (ComfyUI IMAGE tensor).
            clear_cache_after_n_frames: how many pairs to process between
                CUDA-cache clears (guards against memory overflow).
            multiplier: frames-per-input-frame; an int applies uniformly, a
                sequence gives a per-frame schedule (padded with 2).
            optional_interpolation_states: per-frame skip flags.

        Returns:
            Single-element tuple with the interpolated IMAGE batch.
        """
        interpolation_states = optional_interpolation_states
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        model = model.to(DEVICE)
        dtype = torch.float32

        frames = preprocess_frames(frames)
        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []

        # Normalize `multiplier` to a per-frame schedule.
        if type(multiplier) == int:
            multipliers = [multiplier] * len(frames)
        else:
            multipliers = list(map(int, multiplier))
            multipliers += [2] * (len(frames) - len(multipliers) - 1)
        for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it
            if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
                continue
            #Ensure that input frames are in fp32 - the same dtype as model
            frame_0 = frames[frame_itr:frame_itr+1].to(DEVICE).float()
            frame_1 = frames[frame_itr+1:frame_itr+2].to(DEVICE).float()
            # (sic) "relust": interpolated frames for this pair; the last one
            # is dropped since it duplicates the next pair's first frame.
            relust = inference(model, frame_0, frame_1, multipliers[frame_itr] - 1)
            output_frames.extend([frame.detach().cpu().to(dtype=dtype) for frame in relust[:-1]])

            number_of_frames_processed_since_last_cleared_cuda_cache += 1
            # Try to avoid a memory overflow by clearing cuda cache regularly
            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")
            gc.collect()

        output_frames.append(frames[-1:].to(dtype=dtype)) # Append final frame
        output_frames = [frame.cpu() for frame in output_frames] #Ensure all frames are in cpu
        out = torch.cat(output_frames, dim=0)
        # clear cache for courtesy
        print("Comfy-VFI: Final clearing cache...", end = ' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )
vfi_models/film/film_arch.py ADDED
@@ -0,0 +1,798 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/dajes/frame-interpolation-pytorch/blob/main/feature_extractor.py
3
+ https://github.com/dajes/frame-interpolation-pytorch/blob/main/fusion.py
4
+ https://github.com/dajes/frame-interpolation-pytorch/blob/main/interpolator.py
5
+ https://github.com/dajes/frame-interpolation-pytorch/blob/main/pyramid_flow_estimator.py
6
+ https://github.com/dajes/frame-interpolation-pytorch/blob/main/util.py
7
+ """
8
+
9
+ """PyTorch layer for extracting image features for the film_net interpolator.
10
+
11
+ The feature extractor implemented here converts an image pyramid into a pyramid
12
+ of deep features. The feature pyramid serves a similar purpose as U-Net
13
+ architecture's encoder, but we use a special cascaded architecture described in
14
+ Multi-view Image Fusion [1].
15
+
16
+ For comprehensiveness, below is a short description of the idea. While the
17
+ description is a bit involved, the cascaded feature pyramid can be used just
18
+ like any image feature pyramid.
19
+
20
+ Why cascaded architeture?
21
+ =========================
22
+ To understand the concept it is worth reviewing a traditional feature pyramid
23
+ first: *A traditional feature pyramid* as in U-net or in many optical flow
24
+ networks is built by alternating between convolutions and pooling, starting
25
+ from the input image.
26
+
27
+ It is well known that early features of such architecture correspond to low
28
+ level concepts such as edges in the image whereas later layers extract
29
+ semantically higher level concepts such as object classes etc. In other words,
30
+ the meaning of the filters in each resolution level is different. For problems
31
+ such as semantic segmentation and many others this is a desirable property.
32
+
33
+ However, the asymmetric features preclude sharing weights across resolution
34
+ levels in the feature extractor itself and in any subsequent neural networks
35
+ that follow. This can be a downside, since optical flow prediction, for
36
+ instance is symmetric across resolution levels. The cascaded feature
37
+ architecture addresses this shortcoming.
38
+
39
+ How is it built?
40
+ ================
41
+ The *cascaded* feature pyramid contains feature vectors that have constant
42
+ length and meaning on each resolution level, except few of the finest ones. The
43
+ advantage of this is that the subsequent optical flow layer can learn
44
+ synergically from many resolutions. This means that coarse level prediction can
45
+ benefit from finer resolution training examples, which can be useful with
46
+ moderately sized datasets to avoid overfitting.
47
+
48
+ The cascaded feature pyramid is built by extracting shallower subtree pyramids,
49
+ each one of them similar to the traditional architecture. Each subtree
50
+ pyramid S_i is extracted starting from each resolution level:
51
+
52
+ image resolution 0 -> S_0
53
+ image resolution 1 -> S_1
54
+ image resolution 2 -> S_2
55
+ ...
56
+
57
+ If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
58
+ is constructed by concatenating features as follows (assuming subtree depth=3):
59
+
60
+ lvl
61
+ feat_0 = concat( S_0_0 )
62
+ feat_1 = concat( S_1_0 S_0_1 )
63
+ feat_2 = concat( S_2_0 S_1_1 S_0_2 )
64
+ feat_3 = concat( S_3_0 S_2_1 S_1_2 )
65
+ feat_4 = concat( S_4_0 S_3_1 S_2_2 )
66
+ feat_5 = concat( S_5_0 S_4_1 S_3_2 )
67
+ ....
68
+
69
+ In above, all levels except feat_0 and feat_1 have the same number of features
70
+ with similar semantic meaning. This enables training a single optical flow
71
+ predictor module shared by levels 2,3,4,5... . For more details and evaluation
72
+ see [1].
73
+
74
+ [1] Multi-view Image Fusion, Trinidad et al. 2019
75
+ """
76
+ from typing import List
77
+
78
+ import torch
79
+ from torch import nn
80
+ from torch.nn import functional as F
81
+
82
+
83
class SubTreeExtractor(nn.Module):
    """Plain hierarchical feature extractor, shared across pyramid levels.

    Stage i produces ``channels << i`` feature maps; stages are separated by
    2x average pooling, so together they form one "sub-pyramid" of features.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4):
        super().__init__()
        stages = []
        ch_in = in_channels
        for i in range(n_layers):
            ch_out = channels << i
            stages.append(
                nn.Sequential(
                    conv(ch_in, ch_out, 3),
                    conv(ch_out, ch_out, 3),
                )
            )
            ch_in = ch_out
        self.convs = nn.ModuleList(stages)

    def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
        """Extract a pyramid of features from ``image``.

        Every stage is run and contributes one entry; the inter-stage average
        pooling is applied only before the first ``n`` levels, so ``n`` caps
        how many distinct resolutions are produced.

        Args:
            image: tensor of shape BATCH x CHANNELS x HEIGHT x WIDTH.
            n: number of pyramid levels to downscale for; at most the
                ``n_layers`` given in ``__init__``.

        Returns:
            Feature maps, finest level first (output of each stage's last
            convolution).
        """
        pyramid = []
        head = image
        for i, stage in enumerate(self.convs):
            head = stage(head)
            pyramid.append(head)
            if i + 1 < n:
                head = F.avg_pool2d(head, kernel_size=2, stride=2)
        return pyramid
122
+
123
+
124
class FeatureExtractor(nn.Module):
    """Cascaded feature-pyramid extractor (Multi-view Image Fusion style).

    One weight-shared SubTreeExtractor is run from every image-pyramid level,
    and the resulting sub-pyramids are concatenated diagonally so that all but
    the finest few cascade levels carry features of identical length and
    meaning.
    """

    def __init__(self, in_channels=3, channels=64, sub_levels=4):
        super().__init__()
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
        """Build the cascaded feature pyramid.

        Args:
            image_pyramid: images as a list, finest level first.

        Returns:
            One feature tensor per image-pyramid level, finest first.
        """
        n_levels = len(image_pyramid)

        # One sub-pyramid per image level, all sharing the same extractor
        # weights. Depth is capped so no sub-pyramid reaches below the
        # coarsest level of the final cascade.
        sub_pyramids: List[List[torch.Tensor]] = []
        for lvl in range(n_levels):
            depth = min(n_levels - lvl, self.sub_levels)
            sub_pyramids.append(self.extract_sublevels(image_pyramid[lvl], depth))

        # Diagonal concatenation: cascade level `lvl` gathers sub_pyramids
        # [lvl][0], [lvl-1][1], [lvl-2][2], ... (see module docstring layout).
        feature_pyramid: List[torch.Tensor] = []
        for lvl in range(n_levels):
            feats = sub_pyramids[lvl][0]
            for j in range(1, self.sub_levels):
                if j <= lvl:
                    feats = torch.cat([feats, sub_pyramids[lvl - j][j]], dim=1)
            feature_pyramid.append(feats)
        return feature_pyramid
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+ """The final fusion stage for the film_net frame interpolator.
175
+
176
+ The inputs to this module are the warped input images, image features and
177
+ flow fields, all aligned to the target frame (often midway point between the
178
+ two original inputs). The output is the final image. FILM has no explicit
179
+ occlusion handling -- instead using the abovementioned information this module
180
+ automatically decides how to best blend the inputs together to produce content
181
+ in areas where the pixels can only be borrowed from one of the inputs.
182
+
183
+ Similarly, this module also decides on how much to blend in each input in case
184
+ of fractional timestep that is not at the halfway point. For example, if the two
185
+ inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
186
+ it often makes most sense to favor the first input. However, this is not
187
+ always the case -- in particular in occluded pixels.
188
+
189
+ The architecture of the Fusion module follows U-net [1] architecture's decoder
190
+ side, e.g. each pyramid level consists of concatenation with upsampled coarser
191
+ level output, and two 3x3 convolutions.
192
+
193
+ The upsampling is implemented as 'resize convolution', e.g. nearest neighbor
194
+ upsampling followed by 2x2 convolution as explained in [2]. The classic U-net
195
+ uses max-pooling which has a tendency to create checkerboard artifacts.
196
+
197
+ [1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
198
+ Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
199
+ [2] https://distill.pub/2016/deconv-checkerboard/
200
+ """
201
+ from typing import List
202
+
203
+ import torch
204
+ from torch import nn
205
+ from torch.nn import functional as F
206
+
207
+
208
+ _NUMBER_OF_COLOR_CHANNELS = 3
209
+
210
+
211
def get_channels_at_level(level, filters):
    """Channel count of the aligned (fusion) pyramid at ``level``.

    Each of the two warped inputs contributes its color image, a 2-channel
    flow, and the cascaded features accumulated over the finer levels
    (``filters << i`` for each ``i < level``).
    """
    n_images = 2
    flows = 2
    cascade_features = sum(filters << i for i in range(level))

    return n_images * (cascade_features + _NUMBER_OF_COLOR_CHANNELS + flows)
217
+
218
+
219
class Fusion(nn.Module):
    """The decoder: U-Net-style fusion of the aligned feature pyramid to RGB.

    Upsampling is done as "resize convolution" (nearest-neighbor interpolate
    followed by a 2x2 conv) to avoid checkerboard artifacts.
    """

    def __init__(self, n_layers=4, specialized_layers=3, filters=64):
        """
        Args:
            n_layers: number of decoder pyramid levels.
            specialized_layers: levels below which the filter count keeps
                doubling; coarser levels all reuse the same (max) count.
            filters: base filter count at the finest level.
        """
        super().__init__()

        # The final convolution that outputs RGB:
        self.output_conv = nn.Conv2d(filters, 3, kernel_size=1)

        # Each item 'convs[i]' will contain the list of convolutions to be applied
        # for pyramid level 'i'.
        self.convs = nn.ModuleList()

        # Create the convolutions. Roughly following the feature extractor, we
        # double the number of filters when the resolution halves, but only up to
        # the specialized_levels, after which we use the same number of filters on
        # all levels.
        #
        # We create the convs in fine-to-coarse order, so that the array index
        # for the convs will correspond to our normal indexing (0=finest level).
        # in_channels: tuple = (128, 202, 256, 522, 512, 1162, 1930, 2442)

        in_channels = get_channels_at_level(n_layers, filters)
        increase = 0
        for i in range(n_layers)[::-1]:
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            # Per level: [resize-conv after upsampling, post-concat 3x3, 3x3].
            # NOTE(review): `increase or num_filters` falls back to
            # num_filters only on the first (coarsest) iteration, when
            # increase == 0; afterwards it accounts for the skip-connection
            # channels from get_channels_at_level — confirm before changing.
            convs = nn.ModuleList([
                conv(in_channels, num_filters, size=2, activation=None),
                conv(in_channels + (increase or num_filters), num_filters, size=3),
                conv(num_filters, num_filters, size=3)]
            )
            self.convs.append(convs)
            in_channels = num_filters
            increase = get_channels_at_level(i, filters) - num_filters // 2

    def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
        """Runs the fusion module.

        Args:
            pyramid: The input feature pyramid as list of tensors. Each tensor being
            in (B x H x W x C) format, with finest level tensor first.

        Returns:
            A batch of RGB images.
        Raises:
            ValueError, if len(pyramid) != config.fusion_pyramid_levels as provided in
            the constructor.
        """

        # As a slight difference to a conventional decoder (e.g. U-net), we don't
        # apply any extra convolutions to the coarsest level, but just pass it
        # to finer levels for concatenation. This choice has not been thoroughly
        # evaluated, but is motivated by the educated guess that the fusion part
        # probably does not need large spatial context, because at this point the
        # features are spatially aligned by the preceding warp.
        net = pyramid[-1]

        # Loop starting from the 2nd coarsest level:
        # for i in reversed(range(0, len(pyramid) - 1)):
        for k, layers in enumerate(self.convs):
            # self.convs was built coarse-to-fine, so map its index back onto
            # the pyramid's fine-first indexing.
            i = len(self.convs) - 1 - k
            # Resize the tensor from coarser level to match for concatenation.
            level_size = pyramid[i].shape[2:4]
            net = F.interpolate(net, size=level_size, mode='nearest')
            net = layers[0](net)
            net = torch.cat([pyramid[i], net], dim=1)
            net = layers[1](net)
            net = layers[2](net)
        net = self.output_conv(net)
        return net
293
+
294
+
295
+
296
+
297
+
298
+
299
+
300
+
301
+
302
+
303
+
304
+ """The film_net frame interpolator main model code.
305
+
306
+ Basics
307
+ ======
308
+ The film_net is an end-to-end learned neural frame interpolator implemented as
309
+ a PyTorch model. It has the following inputs and outputs:
310
+
311
+ Inputs:
312
+ x0: image A.
313
+ x1: image B.
314
+ time: desired sub-frame time.
315
+
316
+ Outputs:
317
+ image: the predicted in-between image at the chosen time in range [0, 1].
318
+
319
+ Additional outputs include forward and backward warped image pyramids, flow
320
+ pyramids, etc., that can be visualized for debugging and analysis.
321
+
322
+ Note that many training sets only contain triplets with ground truth at
323
+ time=0.5. If a model has been trained with such training set, it will only work
324
+ well for synthesizing frames at time=0.5. Such models can only generate more
325
+ in-between frames using recursion.
326
+
327
+ Architecture
328
+ ============
329
+ The inference consists of three main stages: 1) feature extraction 2) warping
330
+ 3) fusion. On high-level, the architecture has similarities to Context-aware
331
+ Synthesis for Video Frame Interpolation [1], but the exact architecture is
332
+ closer to Multi-view Image Fusion [2] with some modifications for the frame
333
+ interpolation use-case.
334
+
335
+ Feature extraction stage employs the cascaded multi-scale architecture described
336
+ in [2]. The advantage of this architecture is that coarse level flow prediction
337
+ can be learned from finer resolution image samples. This is especially useful
338
+ to avoid overfitting with moderately sized datasets.
339
+
340
+ The warping stage uses a residual flow prediction idea that is similar to
341
+ PWC-Net [3], Multi-view Image Fusion [2] and many others.
342
+
343
+ The fusion stage is similar to U-Net's decoder where the skip connections are
344
+ connected to warped image and feature pyramids. This is described in [2].
345
+
346
+ Implementation Conventions
347
+ ====================
348
+ Pyramids
349
+ --------
350
+ Throughtout the model, all image and feature pyramids are stored as python lists
351
+ with finest level first followed by downscaled versions obtained by successively
352
+ halving the resolution. The depths of all pyramids are determined by
353
+ options.pyramid_levels. The only exception to this is internal to the feature
354
+ extractor, where smaller feature pyramids are temporarily constructed with depth
355
+ options.sub_levels.
356
+
357
+ Color ranges & gamma
358
+ --------------------
359
+ The model code makes no assumptions on whether the images are in gamma or
360
+ linearized space or what is the range of RGB color values. So a model can be
361
+ trained with different choices. This does not mean that all the choices lead to
362
+ similar results. In practice the model has been proven to work well with RGB
363
+ scale = [0,1] with gamma-space images (i.e. not linearized).
364
+
365
+ [1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
366
+ [2] Multi-view Image Fusion, Trinidad et al, 2019
367
+ [3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
368
+ """
369
+ from typing import Dict, List
370
+
371
+ import torch
372
+ from torch import nn
373
+
374
+
375
+
376
class Interpolator(nn.Module):
    """End-to-end FILM frame interpolator: extract features, warp, fuse.

    Given two frames and a sub-frame time, predicts the in-between image.
    NOTE(review): the forward path substitutes a constant t=0.5 regardless of
    ``batch_dt`` (see mid_time below) — multi-frame interpolation is done by
    recursion, not by varying t.
    """

    def __init__(
        self,
        pyramid_levels=7,
        fusion_pyramid_levels=5,
        specialized_levels=3,
        sub_levels=4,
        filters=64,
        flow_convs=(3, 3, 3, 3),
        flow_filters=(32, 64, 128, 256),
    ):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels

        self.extract = FeatureExtractor(3, filters, sub_levels)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters)
        self.fuse = Fusion(sub_levels, specialized_levels, filters)

    def shuffle_images(self, x0, x1):
        # Build one image pyramid per input frame (finest level first).
        return [
            build_image_pyramid(x0, self.pyramid_levels),
            build_image_pyramid(x1, self.pyramid_levels)
        ]

    def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
        """Run the full pipeline, returning intermediates for inspection.

        Returns a dict with the predicted 'image' plus forward/backward
        residual-flow and flow pyramids.
        """
        image_pyramids = self.shuffle_images(x0, x1)

        # Siamese feature pyramids:
        feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])]

        # Predict forward flow.
        forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1])

        # Predict backward flow.
        backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0])

        # Concatenate features and images:

        # Note that we keep up to 'fusion_pyramid_levels' levels as only those
        # are used by the fusion module.

        forward_flow_pyramid = flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        backward_flow_pyramid = flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        # We multiply the flows with t and 1-t to warp to the desired fractional time.
        #
        # Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
        # lator for multi-frame interpolation. Below, we create a constant tensor of
        # shape [B]. We use the `time` tensor to infer the batch size.
        mid_time = torch.full_like(batch_dt, .5)
        backward_flow = multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
        forward_flow = multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])

        pyramids_to_warp = [
            concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
                                 feature_pyramids[0][:self.fusion_pyramid_levels]),
            concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels],
                                 feature_pyramids[1][:self.fusion_pyramid_levels])
        ]

        # Warp features and images using the flow. Note that we use backward warping
        # and backward flow is used to read from image 0 and forward flow from
        # image 1.
        forward_warped_pyramid = pyramid_warp(pyramids_to_warp[0], backward_flow)
        backward_warped_pyramid = pyramid_warp(pyramids_to_warp[1], forward_flow)

        # Aligned pyramid layout per level: warped images+features from both
        # directions, then both flows (matches Fusion's expected channels).
        aligned_pyramid = concatenate_pyramids(forward_warped_pyramid,
                                               backward_warped_pyramid)
        aligned_pyramid = concatenate_pyramids(aligned_pyramid, backward_flow)
        aligned_pyramid = concatenate_pyramids(aligned_pyramid, forward_flow)

        return {
            'image': [self.fuse(aligned_pyramid)],
            'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
            'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
            'forward_flow_pyramid': forward_flow_pyramid,
            'backward_flow_pyramid': backward_flow_pyramid,
        }


    def forward(self, x0, x1, batch_dt) -> torch.Tensor:
        # Thin wrapper returning only the predicted frame.
        return self.debug_forward(x0, x1, batch_dt)['image'][0]
460
+
461
+
462
+
463
+
464
+
465
+
466
+
467
+
468
+
469
+
470
+ """PyTorch layer for estimating optical flow by a residual flow pyramid.
471
+
472
+ This approach of estimating optical flow between two images can be traced back
473
+ to [1], but is also used by later neural optical flow computation methods such
474
+ as SpyNet [2] and PWC-Net [3].
475
+
476
+ The basic idea is that the optical flow is first estimated in a coarse
477
+ resolution, then the flow is upsampled to warp the higher resolution image and
478
+ then a residual correction is computed and added to the estimated flow. This
479
+ process is repeated in a pyramid on coarse to fine order to successively
480
+ increase the resolution of both optical flow and the warped image.
481
+
482
+ In here, the optical flow predictor is used as an internal component for the
483
+ film_net frame interpolator, to warp the two input images into the inbetween,
484
+ target frame.
485
+
486
+ [1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
487
+ [2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
488
+ Network. 2016
489
+ [3] D. Sun X. Yang, M-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
490
+ Pyramid, Warping, and Cost Volume, 2017
491
+ """
492
+ from typing import List
493
+
494
+ import torch
495
+ from torch import nn
496
+ from torch.nn import functional as F
497
+
498
+
499
+
500
class FlowEstimator(nn.Module):
    """Small receptive-field predictor of flow between two feature maps.

    Used by PyramidFlowEstimator to compute the per-level residual flow.
    A configurable stack of 3x3 convolutions is followed by two 1x1
    convolutions that reduce the features down to the 2-channel flow.

    Args:
        in_channels: channel count of the concatenated feature pair.
        num_convs: number of 3x3 convolutions to apply.
        num_filters: number of filters in each 3x3 convolution.
    """

    def __init__(self, in_channels: int, num_convs: int, num_filters: int):
        super(FlowEstimator, self).__init__()

        self._convs = nn.ModuleList()
        channels = in_channels
        for _ in range(num_convs):
            self._convs.append(conv(in_channels=channels, out_channels=num_filters, size=3))
            channels = num_filters
        self._convs.append(conv(channels, num_filters // 2, size=1))
        channels = num_filters // 2
        # For the final convolution, we want no activation at all to predict the
        # optical flow vector values. We have done extensive testing on explicitly
        # bounding these values using sigmoid, but it turned out that having no
        # activation gives better results.
        self._convs.append(conv(channels, 2, size=1, activation=None))

    def forward(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor:
        """Estimates optical flow between two feature maps.

        Args:
            features_a: per-pixel feature map for image A (B x C x H x W).
            features_b: per-pixel feature map for image B (B x C x H x W).

        Returns:
            A (B x 2 x H x W) tensor with the optical flow from A to B.
        """
        net = torch.cat([features_a, features_b], dim=1)
        for layer in self._convs:
            net = layer(net)
        return net
544
+
545
+
546
class PyramidFlowEstimator(nn.Module):
    """Predicts optical flow by coarse-to-fine refinement.

    The finest pyramid levels each get their own specialized FlowEstimator
    (``self._predictors``); all coarser levels share a single estimator
    (``self._predictor``).
    """

    def __init__(self, filters: int = 64,
                 flow_convs: tuple = (3, 3, 3, 3),
                 flow_filters: tuple = (32, 64, 128, 256)):
        super(PyramidFlowEstimator, self).__init__()

        # Each estimator sees the concatenation of both images' features,
        # hence twice the base filter count at the finest level.
        in_channels = filters << 1
        predictors = []
        for i in range(len(flow_convs)):
            predictors.append(
                FlowEstimator(
                    in_channels=in_channels,
                    num_convs=flow_convs[i],
                    num_filters=flow_filters[i]))
            # Feature width grows at each coarser level of the cascaded
            # feature extractor, so the input channel count grows with it.
            in_channels += filters << (i + 2)
        # Shared estimator used for every level coarser than the
        # specialized ones.
        self._predictor = predictors[-1]
        # Specialized estimators, reversed so index 0 corresponds to the
        # coarsest of the specialized levels (matches forward()'s indexing).
        self._predictors = nn.ModuleList(predictors[:-1][::-1])

    def forward(self, feature_pyramid_a: List[torch.Tensor],
                feature_pyramid_b: List[torch.Tensor]) -> List[torch.Tensor]:
        """Estimates residual flow pyramids between two image pyramids.

        Each image pyramid is represented as a list of tensors in fine-to-coarse
        order. Each individual image is represented as a tensor where each pixel is
        a vector of image features.

        flow_pyramid_synthesis can be used to convert the residual flow
        pyramid returned by this method into a flow pyramid, where each level
        encodes the flow instead of a residual correction.

        Args:
          feature_pyramid_a: image pyramid as a list in fine-to-coarse order
          feature_pyramid_b: image pyramid as a list in fine-to-coarse order

        Returns:
          List of flow tensors, in fine-to-coarse order, each level encoding the
          difference against the bilinearly upsampled version from the coarser
          level. The coarsest flow tensor, e.g. the last element in the array is the
          'DC-term', e.g. not a residual (alternatively you can think of it being a
          residual against zero).
        """
        levels = len(feature_pyramid_a)
        # Coarsest level: direct flow prediction (no upsampled prior to add).
        v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
        residuals = [v]
        # Coarse levels beyond the specialized ones: refine with the shared
        # predictor.
        for i in range(levels - 2, len(self._predictors) - 1, -1):
            # Upsamples the flow to match the current pyramid level. Also, scales the
            # magnitude by two to reflect the new size.
            level_size = feature_pyramid_a[i].shape[2:4]
            v = F.interpolate(2 * v, size=level_size, mode='bilinear')
            # Warp feature_pyramid_b[i] image based on the current flow estimate.
            warped = warp(feature_pyramid_b[i], v)
            # Estimate the residual flow between pyramid_a[i] and warped image:
            v_residual = self._predictor(feature_pyramid_a[i], warped)
            residuals.insert(0, v_residual)
            v = v_residual + v

        # Finest levels: same refinement loop, but each level uses its own
        # specialized predictor.
        for k, predictor in enumerate(self._predictors):
            i = len(self._predictors) - 1 - k
            # Upsamples the flow to match the current pyramid level. Also, scales the
            # magnitude by two to reflect the new size.
            level_size = feature_pyramid_a[i].shape[2:4]
            v = F.interpolate(2 * v, size=level_size, mode='bilinear')
            # Warp feature_pyramid_b[i] image based on the current flow estimate.
            warped = warp(feature_pyramid_b[i], v)
            # Estimate the residual flow between pyramid_a[i] and warped image:
            v_residual = predictor(feature_pyramid_a[i], warped)
            residuals.insert(0, v_residual)
            v = v_residual + v
        return residuals
618
+
619
+
620
+
621
+
622
+
623
+
624
+
625
+
626
+
627
+
628
+ """Various utilities used in the film_net frame interpolator model."""
629
+ from typing import List, Optional
630
+
631
+ import cv2
632
+ import numpy as np
633
+ import torch
634
+ from torch import nn
635
+ from torch.nn import functional as F
636
+
637
+
638
def pad_batch(batch, align):
    """Zero-pad an NHWC numpy batch so H and W are multiples of `align`.

    Padding is split as evenly as possible (extra pixel goes to the
    bottom/right).

    Args:
        batch: numpy array of shape (N, H, W, C).
        align: required divisor for the spatial dimensions.

    Returns:
        (padded_batch, crop_region) where crop_region is
        [top, left, bottom, right] locating the original image inside the
        padded batch.
    """
    height, width = batch.shape[1:3]
    pad_h = (-height) % align
    pad_w = (-width) % align
    top, left = pad_h // 2, pad_w // 2

    crop_region = [top, left, height + top, width + left]
    padded = np.pad(
        batch,
        ((0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode='constant')
    return padded, crop_region
647
+
648
+
649
def load_image(path, align=64):
    """Load an image as a float32 RGB batch in [0, 1], padded to `align`.

    Returns:
        (image_batch, crop_region): batch of shape (1, H', W', 3) and the
        crop region of the original image inside it (see pad_batch).
    """
    bgr = cv2.imread(path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    image = rgb.astype(np.float32) / np.float32(255)
    return pad_batch(image[np.newaxis, ...], align)
653
+
654
+
655
def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Build an image pyramid by repeated 2x average pooling.

    The original image is level 0; each subsequent level halves the spatial
    resolution.

    Args:
        image: the input image batch (B x C x H x W).
        pyramid_levels: number of pyramid levels to produce.

    Returns:
        A list of `pyramid_levels` images, finest first.
    """
    if pyramid_levels <= 0:
        return []
    pyramid = [image]
    for _ in range(pyramid_levels - 1):
        image = F.avg_pool2d(image, 2, 2)
        pyramid.append(image)
    return pyramid
675
+
676
+
677
def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward warps the image using the given flow.

    Specifically, the output pixel in batch b, at position x, y will be computed
    as follows:
      (flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0])
      output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)

    Note that the flow vectors are expected as [x, y], e.g. x in position 0 and
    y in position 1.

    NOTE(review): despite the BxHxWxC wording above (kept from the TF
    original), the code below is channel-first: `image` is (B, C, H, W) and
    `flow` is (B, 2, H, W) — it flips dim 1 and permutes to build the
    grid_sample grid. Confirm against callers before relying on the layout
    described in the docstring.

    Args:
      image: An image with shape BxHxWxC.
      flow: A flow with shape BxHxWx2, with the two channels denoting the relative
        offset in order: (dx, dy).
    Returns:
      A warped image.
    """
    # Negate and swap the two flow channels: grid_sample reads *from* the
    # sampled location, and the grid below is assembled in (x, y) order.
    flow = -flow.flip(1)

    dtype = flow.dtype
    device = flow.device

    # warped = tfa_image.dense_image_warp(image, flow)
    # Same as above but with pytorch
    # Endpoints of an align_corners=False identity grid in [-1, 1]:
    # pixel centers sit at +-(1 - 1/size).
    ls1 = 1 - 1 / flow.shape[3]
    ls2 = 1 - 1 / flow.shape[2]

    # Convert pixel-unit offsets into normalized [-1, 1] grid offsets.
    normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor(
        [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None]
    # Identity sampling grid minus the normalized flow, stacked as (x, y).
    normalized_flow2 = torch.stack([
        torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1],
        torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0],
    ], dim=3)

    padding_mode = "border"
    if device.type == "mps":
        # https://github.com/pytorch/pytorch/issues/125098
        # MPS lacks 'border' padding; clamping the grid to [-1, 1] with
        # 'zeros' padding approximates it.
        padding_mode = "zeros"
        normalized_flow2 = normalized_flow2.clamp(-1, 1)
    warped = F.grid_sample(
        input=image,
        grid=normalized_flow2,
        mode='bilinear',
        padding_mode=padding_mode,
        align_corners=False,
    )
    return warped.reshape(image.shape)
725
+
726
+
727
def multiply_pyramid(pyramid: List[torch.Tensor],
                     scalar: torch.Tensor) -> List[torch.Tensor]:
    """Scale every level of the pyramid by `scalar`.

    Args:
        pyramid: pyramid of image batches.
        scalar: batch of scalars, broadcast against each level via the
            normal tensor broadcasting rules.

    Returns:
        A new pyramid with every level multiplied by `scalar`.
    """
    return [level * scalar for level in pyramid]
743
+
744
+
745
def flow_pyramid_synthesis(
    residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Accumulate a residual flow pyramid into an absolute flow pyramid.

    The coarsest residual is taken as the flow directly; each finer level
    adds its residual to the upsampled (and magnitude-doubled) coarser flow.
    """
    coarse_to_fine = residual_pyramid[::-1]
    flow = coarse_to_fine[0]
    flow_pyramid: List[torch.Tensor] = [flow]
    for residual_flow in coarse_to_fine[1:]:
        target_size = residual_flow.shape[2:4]
        upsampled = F.interpolate(2 * flow, size=target_size, mode='bilinear')
        flow = residual_flow + upsampled
        flow_pyramid.insert(0, flow)
    return flow_pyramid
756
+
757
+
758
def pyramid_warp(feature_pyramid: List[torch.Tensor],
                 flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Backward-warp each pyramid level with its matching flow field.

    Args:
        feature_pyramid: feature pyramid starting from the finest level.
        flow_pyramid: flow fields, starting from the finest level.

    Returns:
        The backward-warped feature pyramid.
    """
    return [warp(features, flow)
            for features, flow in zip(feature_pyramid, flow_pyramid)]
773
+
774
+
775
def concatenate_pyramids(pyramid1: List[torch.Tensor],
                         pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concatenate matching levels of two pyramids along channels (dim 1)."""
    return [torch.cat(pair, dim=1) for pair in zip(pyramid1, pyramid2)]
782
+
783
+
784
def conv(in_channels, out_channels, size, activation: Optional[str] = 'relu'):
    """Create a 'same'-padded Conv2d, optionally followed by LeakyReLU(0.2).

    Since PyTorch has no built-in activation inside Conv2d, the activated
    variant is returned as an nn.Sequential of the conv and LeakyReLU.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        size: square kernel size.
        activation: 'relu' (actually LeakyReLU, matching the reference
            implementation) or None for a linear convolution.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=size,
        padding='same')
    if activation is None:
        return layer
    assert activation == 'relu'
    return nn.Sequential(layer, nn.LeakyReLU(.2))
vfi_models/flavr/__init__.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from comfy.model_management import get_torch_device, soft_empty_cache
3
+ import numpy as np
4
+ import typing
5
+ from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames, assert_batch_size
6
+ import pathlib
7
+ import warnings
8
+ from .flavr_arch import UNet_3D_3D, InputPadder
9
+ import gc
10
+
11
+ device = get_torch_device()
12
+ NBR_FRAME = 4
13
+
14
def build_flavr(model_path):
    """Load a FLAVR checkpoint and return the model on `device`, in eval mode.

    The checkpoint stores a DataParallel state dict, so the "module." prefix
    is stripped from every key before loading. The interpolation factor
    (n_outputs) is recovered from the final conv layer's weight shape.
    """
    # map_location='cpu' avoids a hard failure when the checkpoint was saved
    # on a CUDA device but this machine is CPU/MPS-only; the model is moved
    # to the target device right after loading anyway.
    sd = torch.load(model_path, map_location='cpu')['state_dict']
    sd = {k.partition("module.")[-1]: v for k, v in sd.items()}

    # Ref: class UNet_3D_3D
    model = UNet_3D_3D("unet_18", n_inputs=NBR_FRAME,
                       n_outputs=sd["outconv.1.weight"].shape[0] // 3,
                       joinType="concat", upmode="transpose")
    model.load_state_dict(sd)
    model.to(device).eval()
    del sd
    return model
24
+
25
+ MODEL_TYPE = pathlib.Path(__file__).parent.name
26
+ CKPT_NAMES = ["FLAVR_2x.pth", "FLAVR_4x.pth", "FLAVR_8x.pth"]
27
+
28
class FLAVR_VFI:
    """ComfyUI node: FLAVR 2x video frame interpolation.

    FLAVR consumes a sliding window of 4 frames and synthesizes the frame
    between the two middle ones, so the input batch must hold at least 4
    frames.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 2}), #TODO: Implement recursively invoking interpolator for multi-frame interpolation
                "duplicate_first_last_frames": ("BOOLEAN", {"default": False})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    #Reference: https://github.com/danier97/ST-MFNet/blob/main/interpolate_yuv.py#L93
    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        duplicate_first_last_frames: bool = False,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Insert one interpolated frame between every adjacent frame pair.

        Returns:
            Single-element tuple holding the interleaved frame batch
            (original frames with synthesized frames in between).
        """
        if multiplier != 2:
            warnings.warn("Currently, FLAVR only supports 2x interpolation. The process will continue but please set multiplier=2 afterward")

        # Bugfix: the error message previously claimed "ST-MFNet" (copy-paste
        # from the ST-MFNet node); this node is FLAVR.
        assert_batch_size(frames, batch_size=4, vfi_name="FLAVR")
        interpolation_states = optional_interpolation_states
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = build_flavr(model_path)
        frames = preprocess_frames(frames)
        # Pad spatial dims to a multiple of 16 for the encoder's downsampling.
        padder = InputPadder(frames.shape, 16)
        frames = padder.pad(frames)

        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []
        for frame_itr in range(len(frames) - 3):
            #Does skipping frame i+1 make sense in this case?
            if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr) and interpolation_states.is_frame_skipped(frame_itr + 1):
                continue

            #Ensure that input frames are in fp32 - the same dtype as model
            frame0, frame1, frame2, frame3 = (
                frames[frame_itr:frame_itr+1].float(),
                frames[frame_itr+1:frame_itr+2].float(),
                frames[frame_itr+2:frame_itr+3].float(),
                frames[frame_itr+3:frame_itr+4].float()
            )
            # FLAVR predicts the frame between frame1 and frame2.
            new_frame = model([frame0.to(device), frame1.to(device), frame2.to(device), frame3.to(device)])[0].detach().cpu()
            number_of_frames_processed_since_last_cleared_cuda_cache += 2

            if frame_itr == 0:
                output_frames.append(frame0)
                if duplicate_first_last_frames:
                    output_frames.append(frame0) # repeat the first frame
                output_frames.append(frame1)
            output_frames.append(new_frame)
            output_frames.append(frame2)
            if frame_itr == len(frames) - 4:
                output_frames.append(frame3)
                if duplicate_first_last_frames:
                    output_frames.append(frame3) # repeat the last frame

            # Try to avoid a memory overflow by clearing cuda cache regularly
            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")
            gc.collect()

        dtype = torch.float32
        output_frames = [frame.cpu().to(dtype=dtype) for frame in output_frames] #Ensure all frames are in cpu
        out = torch.cat(output_frames, dim=0)
        out = padder.unpad(out)
        # clear cache for courtesy
        print("Comfy-VFI: Final clearing cache...", end=' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )
vfi_models/flavr/flavr_arch.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/tarun005/FLAVR/blob/main/model/FLAVR_arch.py
3
+ https://github.com/tarun005/FLAVR/blob/main/model/resnet_3D.py (only SEGating)
4
+ """
5
+ import math
6
+ import numpy as np
7
+ import importlib
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
class SEGating(nn.Module):
    """Squeeze-and-excitation style gating over 3D feature maps.

    Globally average-pools the input, derives per-channel gates with a
    1x1x1 convolution followed by a sigmoid, and rescales the input
    channel-wise.
    """

    def __init__(self, inplanes, reduction=16):
        super().__init__()
        # `reduction` is kept for interface compatibility; the gate here is
        # a single full-width convolution rather than a bottleneck.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gates = self.attn_layer(self.pool(x))
        return x * gates
30
+
31
def joinTensors(X1, X2, type="concat"):
    """Combine two tensors: channel-concat, elementwise add, or pass-through.

    `type` selects the operation: "concat" concatenates along dim 1, "add"
    sums elementwise, and any other value returns X1 unchanged. (The
    parameter name shadows the builtin `type` but is kept because callers
    pass it as a keyword argument.)
    """
    if type == "concat":
        return torch.cat([X1, X2], dim=1)
    if type == "add":
        return X1 + X2
    return X1
39
+
40
+
41
class Conv_2d(nn.Module):
    """2D convolution with an optional trailing BatchNorm2d."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
        super().__init__()
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                            stride=stride, padding=padding, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
56
+
57
class upConv3D(nn.Module):
    """3D 2x spatial upsampling block followed by SEGating.

    Two variants selected by `upmode`: "transpose" uses a single
    ConvTranspose3d; anything else uses trilinear upsampling (spatial only)
    followed by a 1x1x1 convolution. An optional BatchNorm3d is appended.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()

        self.upmode = upmode

        if self.upmode == "transpose":
            layers = [
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size,
                                   stride=stride, padding=padding),
                SEGating(out_ch),
            ]
        else:
            layers = [
                nn.Upsample(mode='trilinear', scale_factor=(1, 2, 2), align_corners=False),
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                SEGating(out_ch),
            ]

        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))

        self.upconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.upconv(x)
88
+
89
class Conv_3d(nn.Module):
    """3D convolution followed by SEGating, with an optional BatchNorm3d."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
        super().__init__()
        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            SEGating(out_ch),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
106
+
107
class upConv2D(nn.Module):
    """2D 2x upsampling block.

    `upmode="transpose"` uses a single ConvTranspose2d; anything else uses
    bilinear 2x upsampling followed by a 1x1 convolution. An optional
    BatchNorm2d is appended.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()

        self.upmode = upmode

        if self.upmode == "transpose":
            layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size,
                                         stride=stride, padding=padding)]
        else:
            layers = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            ]

        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))

        self.upconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.upconv(x)
132
+
133
+
134
class UNet_3D_3D(nn.Module):
    """3D U-Net used by FLAVR: encodes a clip of frames, decodes n_outputs frames.

    Args:
        block: name of an encoder builder in resnet_3D (e.g. "unet_18").
        n_inputs: number of input frames (stacked along the temporal axis).
        n_outputs: number of frames to synthesize (2x -> 1, 4x -> 3, ...).
        batchnorm: whether to insert batch-norm layers throughout.
        joinType: skip-connection fusion mode, see joinTensors ("concat"/"add").
        upmode: decoder upsampling mode, see upConv3D ("transpose" or other).
    """
    def __init__(self, block , n_inputs, n_outputs, batchnorm=False , joinType="concat" , upmode="transpose"):
        super().__init__()

        # Decoder channel plan, coarse to fine.
        nf = [512 , 256 , 128 , 64]
        out_channels = 3*n_outputs
        self.joinType = joinType
        self.n_outputs = n_outputs

        # "concat" skips double the channel count at each decoder stage.
        growth = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        unet_3D = importlib.import_module(".resnet_3D", "vfi_models.flavr")
        # NOTE(review): mutates the sibling module's global `useBias` before
        # building the encoder — multi-output checkpoints expect biased convs.
        if n_outputs > 1:
            unet_3D.useBias = True
        self.encoder = getattr(unet_3D , block)(pretrained=False , bn=batchnorm)

        self.decoder = nn.Sequential(
            Conv_3d(nf[0], nf[1] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[1]*growth, nf[2], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
            upConv3D(nf[2]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
            Conv_3d(nf[3]*growth, nf[3] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[3]*growth , nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm)
        )

        # Fuses the per-frame features after the temporal axis is unbound.
        self.feature_fuse = Conv_2d(nf[3]*n_inputs , nf[3] , kernel_size=1 , stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(nf[3], out_channels , kernel_size=7 , stride=1, padding=0)
        )

    def forward(self, images):
        """Synthesize n_outputs frames from a list of input frame batches.

        Args:
            images: list of n_inputs tensors, each (B, 3, H, W).

        Returns:
            List of n_outputs tensors, each (B, 3, H, W).
        """
        # Stack along a new temporal dim: (B, 3, T, H, W).
        images = torch.stack(images , dim=2)

        ## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
        mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4,keepdim=True)
        images = images-mean_

        x_0 , x_1 , x_2 , x_3 , x_4 = self.encoder(images)

        # Decoder with skip connections from each encoder stage.
        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = joinTensors(dx_3 , x_3 , type=self.joinType)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = joinTensors(dx_2 , x_2 , type=self.joinType)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = joinTensors(dx_1 , x_1 , type=self.joinType)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = joinTensors(dx_0 , x_0 , type=self.joinType)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        # Fold the temporal axis into channels for the 2D head.
        dx_out = torch.cat(torch.unbind(dx_out , 2) , 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        # Split the 3*n_outputs channels into per-frame RGB images and
        # restore the subtracted mean.
        out = torch.split(out, dim=1, split_size_or_sections=3)
        mean_ = mean_.squeeze(2)
        out = [o+mean_ for o in out]

        return out
199
+
200
class InputPadder:
    """Pads images so that their spatial dimensions are divisible by `divisor`."""

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # Stored as [left, right, top, bottom] — the order F.pad expects
        # for the last two dimensions.
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
                     pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, input_tensor):
        """Replicate-pad the tensor up to the divisible size."""
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        """Crop a padded tensor back to the original spatial size."""
        return self._unpad(input_tensor)

    def _unpad(self, x):
        left, right, top, bottom = self._pad
        ht, wd = x.shape[-2:]
        return x[..., top:ht - bottom, left:wd - right]
vfi_models/flavr/resnet_3D.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/pytorch/vision/tree/master/torchvision/models/video
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ __all__ = ['unet_18', 'unet_34']
7
+
8
+ useBias = False
9
+
10
class identity(nn.Module):
    """No-op module: returns its input unchanged.

    Accepts (and ignores) arbitrary constructor arguments so it can stand
    in for any other layer, e.g. as a batch-norm replacement.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
18
+
19
class Conv3DSimple(nn.Conv3d):
    """Plain 3x3x3 convolution builder for the 3D ResNet blocks.

    Bias is controlled by the module-level `useBias` flag (read at
    construction time). `midplanes` is unused here; it exists only so the
    signature matches Conv2Plus1D.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):
        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=useBias)

    @staticmethod
    def get_downsample_stride(stride, temporal_stride):
        """Stride tuple for the downsample path.

        Falsy `temporal_stride` falls back to an isotropic stride.
        """
        if temporal_stride:
            return (temporal_stride, stride, stride)
        return (stride, stride, stride)
41
+
42
class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem

    Spatially downsamples by 2 (stride (1, 2, 2)) while keeping the
    temporal resolution.
    """
    def __init__(self):
        # NOTE(review): `batchnorm` is a module-level factory defined
        # elsewhere in this file (presumably BatchNorm3d or the `identity`
        # no-op depending on configuration) — confirm before relying on it.
        super().__init__(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                      padding=(1, 3, 3), bias=useBias),
            batchnorm(64),
            nn.ReLU(inplace=False))
51
+
52
+
53
class Conv2Plus1D(nn.Sequential):
    """Factorized (2+1)D convolution: spatial 1x3x3 conv, ReLU, temporal 3x1x1 conv.

    `midplanes` sets the width of the intermediate spatial features. A tuple
    `stride` is unpacked as (temporal, spatial, spatial); an int is applied
    to both.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes,
                 stride=1,
                 padding=1):
        # Unpack a (t, s, s) stride tuple; the second spatial component
        # overwrites the first, so only the last value is used spatially.
        if not isinstance(stride , int):
            temporal_stride , stride , stride = stride
        else:
            temporal_stride = stride

        super(Conv2Plus1D, self).__init__(
            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                      stride=(1, stride, stride), padding=(0, padding, padding),
                      bias=False),
            # batchnorm(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
                      stride=(temporal_stride, 1, 1), padding=(padding, 0, 0),
                      bias=False))

    @staticmethod
    def get_downsample_stride(stride , temporal_stride):
        # Falsy temporal_stride falls back to an isotropic stride.
        if temporal_stride:
            return (temporal_stride, stride, stride)
        else:
            return (stride , stride , stride)
82
+
83
class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution

    A spatial 1x7x7 conv (stride 2 spatially) into 45 channels, then a
    temporal 3x1x1 conv up to 64 channels, each followed by norm + ReLU.
    """
    def __init__(self):
        # NOTE(review): `batchnorm` is a module-level factory defined
        # elsewhere in this file — confirm its definition before relying
        # on normalization behavior here.
        super().__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            batchnorm(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            batchnorm(64),
            nn.ReLU(inplace=True))
98
+
99
+
100
class SEGating(nn.Module):
    """Squeeze-and-excitation style gating over 3D feature maps.

    Globally average-pools the input, derives per-channel gates with a
    1x1x1 convolution + sigmoid, and rescales the input channel-wise.
    (Duplicated in flavr_arch.py.)
    """

    def __init__(self, inplanes, reduction=16):
        super().__init__()
        # `reduction` is kept for interface compatibility; the gate uses a
        # single full-width convolution rather than a bottleneck.
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gates = self.attn_layer(self.pool(x))
        return x * gates
117
+
118
class BasicBlock(nn.Module):
    """Gated residual block: two convs, SEGating, skip connection, ReLU.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        conv_builder: conv factory (Conv3DSimple or Conv2Plus1D).
        stride: stride for the first conv (and downsample path).
        downsample: optional module projecting the residual to `planes`.
    """

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        # Intermediate width used by the (2+1)D factorization so the
        # parameter count matches a full 3D conv.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        # NOTE(review): `batchnorm` is a module-level factory defined
        # elsewhere in this file — confirm before relying on it.
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            batchnorm(planes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            batchnorm(planes)
        )
        self.fg = SEGating(planes) ## Feature Gating
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        # Channel-wise gating before the residual add.
        out = self.fg(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
152
+
153
class VideoResNet(nn.Module):
    """Generic 3D-ResNet encoder that returns features at every stage.

    Args:
        block (nn.Module): resnet building block class.
        conv_makers (list): one conv-builder factory per stage.
        layers (List[int]): number of blocks per stage.
        stem (callable): factory for the stem module (conv-bn-relu).
        zero_init_residual (bool): zero-init the last BN of Bottleneck blocks.
    """

    def __init__(self, block, conv_makers, layers, stem, zero_init_residual=False):
        super(VideoResNet, self).__init__()
        self.inplanes = 64
        self.stem = stem()

        # Spatial downsampling only in stages 2-3; temporal stride kept at 1.
        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2, temporal_stride=1)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2, temporal_stride=1)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=1, temporal_stride=1)

        self._initialize_weights()

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        x_0 = self.stem(x)
        x_1 = self.layer1(x_0)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)
        # All intermediate features are returned for U-Net style decoders.
        return x_0, x_1, x_2, x_3, x_4

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None):
        """Build one residual stage of `blocks` blocks; the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride, temporal_stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                batchnorm(planes * block.expansion),
            )
            stride = ds_stride

        stage = [block(self.inplanes, planes, conv_builder, stride, downsample)]
        self.inplanes = planes * block.expansion
        stage.extend(
            block(self.inplanes, planes, conv_builder) for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Kaiming init for convs, standard init for norm/linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
225
+
226
+
227
def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    """Instantiate a VideoResNet and optionally load pretrained weights.

    NOTE(review): `load_state_dict_from_url` and `model_urls` must be in
    scope at module level for the pretrained path to work — confirm.
    """
    model = VideoResNet(**kwargs)
    ## TODO: Other 3D resnet models, like S3D, r(2+1)D.
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
236
+
237
+
238
def unet_18(pretrained=False, bn=False, progress=True, **kwargs):
    """Construct the 18-layer Unet3D encoder as in https://arxiv.org/abs/1711.11248.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400.
        bn (bool): use BatchNorm3d; otherwise the identity norm layer.
        progress (bool): If True, displays a download progress bar.

    Returns:
        nn.Module: R3D-18 encoder.
    """
    # The norm layer is chosen via a module-level global that is consumed by
    # BasicBlock/VideoResNet at construction time.
    global batchnorm
    batchnorm = nn.BatchNorm3d if bn else identity

    return _video_resnet('r3d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)
262
+
263
def unet_34(pretrained=False, bn=False, progress=True, **kwargs):
    """Construct the 34-layer Unet3D encoder as in https://arxiv.org/abs/1711.11248.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400.
        bn (bool): use BatchNorm3d; otherwise the identity norm layer.
        progress (bool): If True, displays a download progress bar.

    Returns:
        nn.Module: R3D-34 encoder.
    """
    # Same global norm-layer selection mechanism as unet_18.
    global batchnorm
    batchnorm = nn.BatchNorm3d if bn else identity

    return _video_resnet('r3d_34',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[3, 4, 6, 3],
                         stem=BasicStem, **kwargs)
vfi_models/gmfss_fortuna/GMFSS_Fortuna.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ import vapoursynth as vs
4
+ from .GMFSS_Fortuna_arch import Model_inference
5
+ import torch
6
+ import traceback
7
+
8
+
9
class GMFSS_Fortuna:
    """Inference wrapper around the GMFSS_Fortuna frame-interpolation model."""

    def __init__(self):
        # Two input frames; no inter-call caching.
        self.cache = False
        self.amount_input_img = 2

        # Inference-only setup: disable autograd, enable cuDNN autotuning.
        torch.set_grad_enabled(False)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        self.model = Model_inference()
        self.model.eval()

    def execute(self, I0, I1, timestep):
        """Interpolate between frames I0 and I1 at `timestep`; returns a CPU tensor."""
        with torch.inference_mode():
            return self.model(I0, I1, timestep).cpu()
vfi_models/gmfss_fortuna/GMFSS_Fortuna_arch.py ADDED
@@ -0,0 +1,1850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/GMFSS_infer_b.py
3
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/softsplat.py
4
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FusionNet_b.py
5
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FeatureNet.py
6
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/MetricNet.py
7
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/IFNet_HDv3.py
8
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/gmflow.py
9
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/utils.py
10
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/position.py
11
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/geometry.py
12
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/matching.py
13
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/transformer.py
14
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/backbone.py
15
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/trident_conv.py
16
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/warplayer.py
17
+ """
18
+
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ from torch.nn.modules.utils import _pair
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ import torch
27
+ import math
28
+ from vfi_models.rife.rife_arch import IFNet
29
+ from vfi_models.ops import softsplat
30
+ from comfy.model_management import get_torch_device
31
+
32
+ device = get_torch_device()
33
+ backwarp_tenGrid = {}
34
+
35
+
36
def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` with pixel-space flow `tenFlow` via grid_sample.

    A normalized base grid is built once per (device, size) and memoized in
    the module-level `backwarp_tenGrid` dict.
    """
    key = (str(tenFlow.device), str(tenFlow.size()))
    if key not in backwarp_tenGrid:
        batch, _, height, width = tenFlow.shape
        xs = (
            torch.linspace(-1.0, 1.0, width, device=device)
            .view(1, 1, 1, width)
            .expand(batch, -1, height, -1)
        )
        ys = (
            torch.linspace(-1.0, 1.0, height, device=device)
            .view(1, 1, height, 1)
            .expand(batch, -1, -1, width)
        )
        backwarp_tenGrid[key] = torch.cat([xs, ys], 1).to(device)

    # Convert the pixel-space flow into the [-1, 1] range grid_sample expects.
    norm_flow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    grid = (backwarp_tenGrid[key] + norm_flow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
+
68
+
69
class MultiScaleTridentConv(nn.Module):
    """Weight-shared convolution applied over several scales (TridentNet-style).

    A single weight tensor is convolved against each input branch with that
    branch's own stride/padding, so branches share parameters while running
    at different resolutions.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation

        # Broadcast scalar per-branch settings to every branch.
        if isinstance(paddings, int):
            paddings = [paddings] * num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * num_branch
        if isinstance(strides, int):
            strides = [strides] * num_branch
        self.paddings = [_pair(p) for p in paddings]
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        self.bias = nn.Parameter(torch.Tensor(out_channels)) if bias else None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        # During training (or when no test branch is selected) every branch runs.
        run_all = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if run_all else 1
        assert len(inputs) == num_branch

        if run_all:
            outputs = [
                F.conv2d(inp, self.weight, self.bias, s, p, self.dilation, self.groups)
                for inp, s, p in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Single selected branch at test time. Upstream behavior kept as-is:
            # this path is only reached when test_branch_idx != -1, so the
            # conditional always falls through to the last stride/padding.
            stride = (
                self.strides[self.test_branch_idx]
                if self.test_branch_idx == -1
                else self.strides[-1]
            )
            padding = (
                self.paddings[self.test_branch_idx]
                if self.test_branch_idx == -1
                else self.paddings[-1]
            )
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    stride,
                    padding,
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(o) for o in outputs]
        if self.activation is not None:
            outputs = [self.activation(o) for o in outputs]
        return outputs
164
+
165
+
166
class ResidualBlock_class(nn.Module):
    """Standard two-conv residual block with configurable norm/stride/dilation."""

    def __init__(
        self,
        in_planes,
        planes,
        norm_layer=nn.InstanceNorm2d,
        stride=1,
        dilation=1,
    ):
        super(ResidualBlock_class, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            stride=stride,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)

        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            # 1x1 projection so the shortcut matches the main path's shape.
            self.norm3 = norm_layer(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + out)
217
+
218
+
219
class CNNEncoder(nn.Module):
    """Residual CNN backbone producing 1/8 (or multi-scale) image features."""

    def __init__(
        self,
        output_dim=128,
        norm_layer=nn.InstanceNorm2d,
        num_output_scales=1,
        **kwargs,
    ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        # Stem: 7x7 stride-2 conv -> norm -> relu (1/2 resolution).
        self.conv1 = nn.Conv2d(
            3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer)  # 1/4

        # Highest resolution is 1/8 (single scale) or 1/4 (multi-scale).
        self.layer3 = self._make_layer(
            feature_dims[2],
            stride=2 if num_output_scales == 1 else 1,
            norm_layer=norm_layer,
        )

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            branch_strides = {4: (1, 2, 4, 8), 3: (1, 2, 4), 2: (1, 2)}.get(self.num_branch)
            if branch_strides is None:
                raise ValueError
            self.trident_conv = MultiScaleTridentConv(
                output_dim,
                output_dim,
                kernel_size=3,
                strides=branch_strides,
                paddings=1,
                num_branch=self.num_branch,
            )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        """Two residual blocks; only the first may downsample."""
        blocks = (
            ResidualBlock_class(
                self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation
            ),
            ResidualBlock_class(
                dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation
            ),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.relu1(self.norm1(self.conv1(x)))

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            # High-to-low resolution branch outputs.
            return self.trident_conv([x] * self.num_branch)
        return [x]
314
+
315
+
316
def single_head_full_attention(q, k, v):
    """Dense single-head scaled dot-product attention.

    q, k, v: [B, L, C]; returns [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3

    # Scores over all key positions, softmax-normalized per query.
    logits = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** 0.5)  # [B, L, L]
    weights = torch.softmax(logits, dim=2)  # [B, L, L]
    return torch.matmul(weights, v)  # [B, L, C]
325
+
326
+
327
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=get_torch_device(),
):
    """Build the attention mask for shifted-window (SW-MSA) attention.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    NOTE(review): the `device` default is evaluated once at import time;
    callers normally pass the feature tensor's device explicitly.
    """
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1

    # Label the nine regions created by the cyclic shift.
    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    region = 0
    for hs in h_slices:
        for ws in w_slices:
            img_mask[:, hs, ws, :] = region
            region += 1

    mask_windows = split_feature(
        img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True
    )
    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)

    # Tokens from different regions must not attend to each other.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )
    return attn_mask
366
+
367
+
368
def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Swin-style (optionally shifted) window attention with a single head.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    q, k, v: [B, L, C] with L == h * w; returns [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    b_windows = b * num_splits * num_splits

    win_h = h // num_splits
    win_w = w // num_splits

    # Restore spatial layout before windowing.
    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale = c ** 0.5

    if with_shift:
        assert attn_mask is not None  # computed once by the caller
        shift_h = win_h // 2
        shift_w = win_w // 2
        q = torch.roll(q, shifts=(-shift_h, -shift_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_h, -shift_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_h, -shift_w), dims=(1, 2))

    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = (
        torch.matmul(
            q.view(b_windows, -1, c), k.view(b_windows, -1, c).permute(0, 2, 1)
        )
        / scale
    )  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores = scores + attn_mask.repeat(b, 1, 1)

    weights = torch.softmax(scores, dim=-1)
    out = torch.matmul(weights, v.view(b_windows, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(
        out.view(b_windows, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )  # [B, H, W, C]

    if with_shift:
        # Undo the cyclic shift.
        out = torch.roll(out, shifts=(shift_h, shift_w), dims=(1, 2))

    return out.view(b, -1, c)
438
+
439
+
440
class TransformerLayer(nn.Module):
    """Single attention layer: single-head attention (full or Swin-windowed)
    with a residual connection, followed by an optional FFN applied to the
    concatenated [source, message] features.

    NOTE(review): multi-head (`nhead > 1`) is deliberately unsupported; see
    the comment in `forward`.
    """

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        no_ffn=False,
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        # source, target: [B, L, C]
        # Query comes from `source`; key/value come from `target`
        # (self-attention passes the same tensor for both).
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=self.with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            # FFN mixes the original source with the attention message.
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        # Residual connection.
        return source + message
525
+
526
+
527
class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerBlock, self).__init__()

        # Self-attention sublayer has no FFN; the cross-attention one does.
        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )
        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """source, target: [B, L, C]; returns the updated source."""
        attn_kwargs = dict(
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )
        # Self-attention, then cross-attention + FFN.
        source = self.self_attn(source, source, **attn_kwargs)
        return self.cross_attn_ffn(source, target, **attn_kwargs)
591
+
592
+
593
class FeatureTransformer(nn.Module):
    """Stack of TransformerBlocks that jointly refines a pair of feature maps.

    Both attention directions (0->1 and 1->0) are computed in parallel by
    concatenating the pair along the batch dimension.
    """

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        **kwargs,
    ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        # Swin-style alternation: shifted windows on odd layers only.
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    with_shift=True
                    if attention_type == "swin" and i % 2 == 1
                    else False,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_num_splits=None,
        **kwargs,
    ):
        # feature0/feature1: [B, C, H, W]; returns the refined pair.
        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                shifted_window_attn_mask=shifted_window_attn_mask,
                attn_num_splits=attn_num_splits,
            )

            # update feature1: the cross-attention target is the other
            # half of the freshly updated batch, with halves swapped.
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = (
            feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        )  # [B, C, H, W]
        feature1 = (
            feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        )  # [B, C, H, W]

        return feature0, feature1
687
+
688
+
689
class FeatureFlowAttention(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(
        self,
        in_channels,
        **kwargs,
    ):
        super(FeatureFlowAttention, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        flow,
        local_window_attn=False,
        local_window_radius=1,
        **kwargs,
    ):
        """q, k: feature [B, C, H, W]; v: flow [B, 2, H, W]."""
        if local_window_attn:
            return self.forward_local_window_attn(
                feature0, flow, local_window_radius=local_window_radius
            )

        b, c, h, w = feature0.size()

        tokens = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # Upstream quirk kept for checkpoint compatibility: the key
        # projection is applied on top of the query projection. Both are
        # linear, so the composition is equivalent to a single matrix and
        # does not affect the model's expressiveness.
        query = self.q_proj(tokens)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        logits = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5)  # [B, H*W, H*W]
        weights = torch.softmax(logits, dim=-1)

        out = torch.matmul(weights, value)  # [B, H*W, 2]
        return out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        """Same propagation, restricted to a (2R+1)^2 local window per pixel."""
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()
        kernel_size = 2 * local_window_radius + 1

        queries = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)).reshape(
            b * h * w, 1, c
        )  # [B*H*W, 1, C]

        keys_map = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        key_windows = F.unfold(
            keys_map, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, C*(2R+1)^2, H*W]
        key_windows = (
            key_windows.view(b, c, kernel_size ** 2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size ** 2)
        )  # [B*H*W, C, (2R+1)^2]

        flow_windows = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, 2*(2R+1)^2, H*W]
        flow_windows = (
            flow_windows.view(b, 2, kernel_size ** 2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size ** 2, 2)
        )  # [B*H*W, (2R+1)^2, 2]

        logits = torch.matmul(queries, key_windows) / (c ** 0.5)  # [B*H*W, 1, (2R+1)^2]
        weights = torch.softmax(logits, dim=-1)

        out = (
            torch.matmul(weights, flow_windows)
            .view(b, h, w, 2)
            .permute(0, 3, 1, 2)
            .contiguous()
        )  # [B, 2, H, W]

        return out
805
+
806
+
807
def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    """Dense global matching: correlate every pixel of feature0 with every
    pixel of feature1 and regress flow as the softmax-weighted expectation
    of target coordinates.

    Args:
        feature0, feature1: [B, C, H, W] feature maps.
        pred_bidir_flow: when True, also match in the reverse direction
            (by transposing the correlation) and stack results on batch.

    Returns:
        (flow [B or 2B, 2, H, W], matching probability [B or 2B, H*W, H*W]).
    """
    # global correlation
    b, c, h, w = feature0.shape
    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    # scaled dot product between all pairs of positions
    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
        c**0.5
    )  # [B, H, W, H, W]

    # flow from softmax
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        # backward matching reuses the same correlation volume transposed
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    # expected target coordinate under the matching distribution
    correspondence = (
        torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    flow = correspondence - init_grid

    return flow, prob
845
+
846
+
847
def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    """Local matching: correlate each pixel of feature0 with a
    (2R+1)x(2R+1) window of feature1 around the same location, and
    regress flow as the softmax-weighted expectation of the window
    coordinates.

    Args:
        feature0, feature1: [B, C, H, W] feature maps.
        local_radius: R, window radius in pixels.
        padding_mode: grid_sample padding for out-of-image samples.

    Returns:
        (flow [B, 2, H, W], matching probability [B, H*W, (2R+1)^2]).
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    # relative offsets of the sampling window
    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (
        sample_coords[:, :, :, 0] < w
    )  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (
        sample_coords[:, :, :, 1] < h
    )  # [B, H*W, (2R+1)^2]

    valid = (
        valid_x & valid_y
    )  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    # scaled dot product against each window sample
    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    # expected target coordinate under the window distribution
    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob
915
+
916
+
917
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Build a batched pixel-coordinate grid in (x, y) channel order.

    Returns [B, 2, H, W] (or [B, 3, H, W] with a ones channel when
    `homogeneous` is set), as float, optionally moved to `device`.
    """
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w))  # each [H, W]

    channels = [xs, ys]
    if homogeneous:
        # append a constant 1 channel for homogeneous coordinates
        channels.append(torch.ones_like(xs))

    grid = torch.stack(channels, dim=0).float()  # [2 or 3, H, W]
    grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)  # [B, 2 or 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid
934
+
935
+
936
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Return a [len_h, len_w, 2] grid of (x, y) offsets spanning the given
    ranges, with x in [w_min, w_max] and y in [h_min, h_max]."""
    assert device is not None

    xs = torch.linspace(w_min, w_max, len_w, device=device)
    ys = torch.linspace(h_min, h_max, len_h, device=device)
    gx, gy = torch.meshgrid([xs, ys])  # each [len_w, len_h]

    # stack as (x, y) and put rows (h) first
    window = torch.stack((gx, gy), -1).transpose(0, 1).float()  # [H, W, 2]
    return window
948
+
949
+
950
def normalize_coords(coords, h, w):
    """Map pixel-space coords [B, H, W, 2] (x, y order) into [-1, 1]."""
    center = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
    return (coords - center) / center
954
+
955
+
956
+ def bilinear_sample(
957
+ img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False
958
+ ):
959
+ # img: [B, C, H, W]
960
+ # sample_coords: [B, 2, H, W] in image scale
961
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
962
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
963
+
964
+ b, _, h, w = sample_coords.shape
965
+
966
+ # Normalize to [-1, 1]
967
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
968
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
969
+
970
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
971
+
972
+ img = F.grid_sample(
973
+ img, grid, mode=mode, padding_mode=padding_mode, align_corners=True
974
+ )
975
+
976
+ if return_mask:
977
+ mask = (
978
+ (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)
979
+ ) # [B, H, W]
980
+
981
+ return img, mask
982
+
983
+ return img
984
+
985
+
986
def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    """Backward-warp `feature` [B, C, H, W] by `flow` [B, 2, H, W]:
    each output pixel samples the input at (identity grid + flow)."""
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    sample_grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]
    return bilinear_sample(
        feature, sample_grid, padding_mode=padding_mode, return_mask=mask
    )
993
+
994
+
995
def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    """Occlusion estimation via forward-backward flow consistency.

    Both flows are [B, 2, H, W]. alpha/beta follow UnFlow
    (https://arxiv.org/abs/1711.07837). Returns float occlusion masks
    (fwd_occ, bwd_occ), each [B, H, W], with 1 = inconsistent/occluded.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    # magnitude-dependent tolerance
    mag_sum = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]
    tol = alpha * mag_sum + beta

    # a consistent pair satisfies fwd_flow + warp(bwd_flow, fwd_flow) ~ 0
    diff_fwd = torch.norm(fwd_flow + flow_warp(bwd_flow, fwd_flow), dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + flow_warp(fwd_flow, bwd_flow), dim=1)

    fwd_occ = (diff_fwd > tol).float()  # [B, H, W]
    bwd_occ = (diff_bwd > tol).float()
    return fwd_occ, bwd_occ
1014
+
1015
+
1016
+ class PositionEmbeddingSine(nn.Module):
1017
+ """
1018
+ This is a more standard version of the position embedding, very similar to the one
1019
+ used by the Attention is all you need paper, generalized to work on images.
1020
+ """
1021
+
1022
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
1023
+ super().__init__()
1024
+ self.num_pos_feats = num_pos_feats
1025
+ self.temperature = temperature
1026
+ self.normalize = normalize
1027
+ if scale is not None and normalize is False:
1028
+ raise ValueError("normalize should be True if scale is passed")
1029
+ if scale is None:
1030
+ scale = 2 * math.pi
1031
+ self.scale = scale
1032
+
1033
+ def forward(self, x):
1034
+ # x = tensor_list.tensors # [B, C, H, W]
1035
+ # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
1036
+ b, c, h, w = x.size()
1037
+ mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
1038
+ y_embed = mask.cumsum(1, dtype=torch.float32)
1039
+ x_embed = mask.cumsum(2, dtype=torch.float32)
1040
+ if self.normalize:
1041
+ eps = 1e-6
1042
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
1043
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
1044
+
1045
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
1046
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
1047
+
1048
+ pos_x = x_embed[:, :, :, None] / dim_t
1049
+ pos_y = y_embed[:, :, :, None] / dim_t
1050
+ pos_x = torch.stack(
1051
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
1052
+ ).flatten(3)
1053
+ pos_y = torch.stack(
1054
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
1055
+ ).flatten(3)
1056
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
1057
+ return pos
1058
+
1059
+
1060
def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Tile a feature map into num_splits x num_splits non-overlapping
    windows, stacking the tiles along the batch dimension (row-major
    tile order). Inverse of `merge_splits`."""
    k = num_splits
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % k == 0 and w % k == 0

        tiles = feature.view(b, k, h // k, k, w // k, c)
        feature = tiles.permute(0, 1, 3, 2, 4, 5).reshape(
            b * k * k, h // k, w // k, c
        )  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % k == 0 and w % k == 0

        tiles = feature.view(b, c, k, h // k, k, w // k)
        feature = tiles.permute(0, 2, 4, 1, 3, 5).reshape(
            b * k * k, c, h // k, w // k
        )  # [B*K*K, C, H/K, W/K]

    return feature
1093
+
1094
+
1095
def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Reassemble num_splits x num_splits tiles (stacked along batch in
    row-major order) back into a full feature map. Inverse of
    `split_feature`."""
    k = num_splits
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        nb = b // k // k

        merge = (
            splits.view(nb, k, k, h, w, c)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(nb, k * h, k * w, c)
        )  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        nb = b // k // k

        merge = (
            splits.view(nb, k, k, c, h, w)
            .permute(0, 3, 1, 4, 2, 5)
            .contiguous()
            .view(nb, c, k * h, k * w)
        )  # [B, C, H, W]

    return merge
1122
+
1123
+
1124
def normalize_img(img0, img1):
    """Standardize both images with ImageNet channel statistics.

    Returns (img0, img1) with mean subtracted and divided by std,
    broadcast over [B, 3, H, W].
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
    return (img0 - imagenet_mean) / imagenet_std, (img1 - imagenet_mean) / imagenet_std
1133
+
1134
+
1135
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add the same sinusoidal position embedding to both feature maps.

    When attn_splits > 1 the embedding is computed per local window (the
    maps are split, embedded, then merged back) so it matches the windowed
    attention used downstream.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    # add position inside each split window
    f0 = split_feature(feature0, num_splits=attn_splits)
    f1 = split_feature(feature1, num_splits=attn_splits)

    position = pos_enc(f0)
    f0 = f0 + position
    f1 = f1 + position

    feature0 = merge_splits(f0, num_splits=attn_splits)
    feature1 = merge_splits(f1, num_splits=attn_splits)
    return feature0, feature1
1156
+
1157
+
1158
class GMFlow(nn.Module):
    """GMFlow optical-flow network (global matching with Transformer features).

    Coarse-to-fine over `num_scales` scales: CNN backbone features are
    enhanced by a Transformer, matched via correlation + softmax (global or
    local depending on the per-scale radius), propagated with flow
    self-attention, and finally convex-upsampled to full resolution.
    """

    def __init__(
        self,
        num_scales=2,
        upsample_factor=4,
        feature_channels=128,
        attention_type="swin",
        num_transformer_layers=6,
        ffn_dim_expansion=4,
        num_head=1,
        **kwargs,
    ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(
            output_dim=feature_channels, num_output_scales=num_scales
        )

        # Transformer
        self.transformer = FeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        # flow propagation with self-attn
        self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)

        # convex upsampling: concat feature0 and flow as input
        self.upsampler = nn.Sequential(
            nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
        )

    def extract_feature(self, img0, img1):
        """Run both images through the backbone as one batch and return two
        per-scale feature lists, ordered low -> high resolution."""
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(
            concat
        )  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(
        self,
        flow,
        feature,
        bilinear=False,
        upsample_factor=8,
    ):
        """Upsample `flow` either bilinearly (values rescaled by the factor)
        or by RAFT-style convex combination with weights predicted from
        `feature` (uses self.upsample_factor in that branch)."""
        if bilinear:
            up_flow = (
                F.interpolate(
                    flow,
                    scale_factor=upsample_factor,
                    mode="bilinear",
                    align_corners=True,
                )
                * upsample_factor
            )

        else:
            # convex upsampling
            concat = torch.cat((flow, feature), dim=1)

            mask = self.upsampler(concat)
            b, flow_channel, h, w = flow.shape
            mask = mask.view(
                b, 1, 9, self.upsample_factor, self.upsample_factor, h, w
            )  # [B, 1, 9, K, K, H, W]
            mask = torch.softmax(mask, dim=2)

            # each fine pixel = convex combination of the 3x3 coarse
            # neighborhood; flow magnitudes are rescaled by the factor
            up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
            up_flow = up_flow.view(
                b, flow_channel, 9, 1, 1, h, w
            )  # [B, 2, 9, 1, 1, H, W]

            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
            up_flow = up_flow.reshape(
                b, flow_channel, self.upsample_factor * h, self.upsample_factor * w
            )  # [B, 2, K*H, K*W]

        return up_flow

    def forward(
        self,
        img0,
        img1,
        attn_splits_list=[2, 8],
        corr_radius_list=[-1, 4],
        prop_radius_list=[-1, 1],
        pred_bidir_flow=False,
        **kwargs,
    ):
        """Estimate flow from img0 to img1.

        The per-scale lists (one entry per scale, coarse to fine) choose the
        attention window split, the correlation radius (-1 = global matching)
        and the propagation radius (-1 = global propagation). With
        `pred_bidir_flow`, forward and backward flow are stacked along the
        batch dimension. The list defaults are never mutated, so sharing them
        across calls is safe. Returns the upsampled flow of the finest scale.
        """
        img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # resolution low to high
        feature0_list, feature1_list = self.extract_feature(
            img0, img1
        )  # list of features

        flow = None

        assert (
            len(attn_splits_list)
            == len(corr_radius_list)
            == len(prop_radius_list)
            == self.num_scales
        )

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                    (feature1, feature0), dim=0
                )

            upsample_factor = self.upsample_factor * (
                2 ** (self.num_scales - 1 - scale_idx)
            )

            if scale_idx > 0:
                # carry the coarser flow up one scale (double resolution and magnitude)
                flow = (
                    F.interpolate(
                        flow, scale_factor=2, mode="bilinear", align_corners=True
                    )
                    * 2
                )

            if flow is not None:
                flow = flow.detach()
                # warp feature1 toward feature0 so matching estimates residual flow
                feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

            attn_splits = attn_splits_list[scale_idx]
            corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(
                feature0, feature1, attn_splits, self.feature_channels
            )

            # Transformer
            feature0, feature1 = self.transformer(
                feature0, feature1, attn_num_splits=attn_splits
            )

            # correlation and softmax
            if corr_radius == -1:  # global matching
                flow_pred = global_correlation_softmax(
                    feature0, feature1, pred_bidir_flow
                )[0]
            else:  # local matching
                flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[
                    0
                ]

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            # upsample to the original resolution for supervison
            if (
                self.training
            ):  # only need to upsample intermediate flow predictions at training time
                # NOTE(review): flow_bilinear is assigned but never used or
                # returned here — presumably the training loss collection was
                # stripped for inference; confirm before relying on it.
                flow_bilinear = self.upsample_flow(
                    flow, None, bilinear=True, upsample_factor=upsample_factor
                )

            # flow propagation with self-attn
            if pred_bidir_flow and scale_idx == 0:
                feature0 = torch.cat(
                    (feature0, feature1), dim=0
                )  # [2*B, C, H, W] for propagation
            flow = self.feature_flow_attn(
                feature0,
                flow.detach(),
                local_window_attn=prop_radius > 0,
                local_window_radius=prop_radius,
            )

            # bilinear upsampling at training time except the last one
            if self.training and scale_idx < self.num_scales - 1:
                flow_up = self.upsample_flow(
                    flow, feature0, bilinear=True, upsample_factor=upsample_factor
                )

            if scale_idx == self.num_scales - 1:
                # final convex upsampling; this is the value returned below
                flow_up = self.upsample_flow(flow, feature0)

        return flow_up
1371
+
1372
+
1373
# cache: str(flow.shape) -> normalized identity grid, reused across calls
backwarp_tenGrid = {}


def backwarp(tenIn, tenflow):
    """Backward-warp `tenIn` by `tenflow` using grid_sample.

    The identity grid (in [-1, 1]) is built once per flow shape and cached
    in `backwarp_tenGrid`.
    NOTE(review): the cache key is only the shape string — it ignores dtype
    and device, and the grid is moved to get_torch_device(); this assumes
    all flows live on that single device — confirm for multi-GPU use.
    """
    if str(tenflow.shape) not in backwarp_tenGrid:
        tenHor = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[3],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, 1, -1)
            .repeat(1, 1, tenflow.shape[2], 1)
        )
        tenVer = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[2],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, -1, 1)
            .repeat(1, 1, 1, tenflow.shape[3])
        )

        backwarp_tenGrid[str(tenflow.shape)] = torch.cat([tenHor, tenVer], 1).to(get_torch_device())
    # end

    # scale pixel-space flow into the grid's [-1, 1] coordinate range
    tenflow = torch.cat(
        [
            tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0),
            tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    return torch.nn.functional.grid_sample(
        input=tenIn,
        grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )
1419
+
1420
+
1421
class MetricNet(nn.Module):
    """Predicts two per-pixel splatting metrics (one per flow direction).

    Input evidence: both images (6 ch), negated photometric warping errors
    (2 ch), flows normalized to [-1, 1] range (4 ch), and forward/backward
    occlusion masks (2 ch) — 14 channels total, matching `metric_in`.
    Outputs are bounded to (-10, 10) by tanh * 10.
    """

    def __init__(self):
        super(MetricNet, self).__init__()
        self.metric_in = nn.Conv2d(14, 64, 3, 1, 1)
        # three residual refinement stages (skip added in forward)
        self.metric_net1 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net2 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net3 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_out = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 2, 3, 1, 1))

    def forward(self, img0, img1, flow01, flow10):
        # per-pixel photometric error of warping each image with its flow
        metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction="none").mean(
            [1], True
        )
        metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction="none").mean(
            [1], True
        )

        fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10)

        # normalize flows by half the image extent (same convention as backwarp)
        flow01 = torch.cat(
            [
                flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0),
                flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0),
            ],
            1,
        )
        flow10 = torch.cat(
            [
                flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0),
                flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0),
            ],
            1,
        )

        img = torch.cat((img0, img1), 1)
        metric = torch.cat((-metric0, -metric1), 1)  # negated: lower error => higher metric
        flow = torch.cat((flow01, flow10), 1)
        occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1)

        feat = self.metric_in(torch.cat((img, metric, flow, occ), 1))
        feat = self.metric_net1(feat) + feat
        feat = self.metric_net2(feat) + feat
        feat = self.metric_net3(feat) + feat
        metric = self.metric_out(feat)

        metric = torch.tanh(metric) * 10

        # channel 0: metric for img0->img1 splat; channel 1: the reverse
        return metric[:, :1], metric[:, 1:2]
1469
+
1470
+
1471
class FeatureNet(nn.Module):
    """Three-stage feature pyramid extractor.

    Each block halves resolution via a stride-2 conv, yielding features at
    1/2 (64 ch), 1/4 (128 ch) and 1/8 (192 ch) of the input resolution.
    """

    def __init__(self):
        super(FeatureNet, self).__init__()
        self.block1 = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(3, 64, 3, 2, 1),
            nn.PReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
        )
        self.block2 = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.PReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
        )
        self.block3 = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(128, 192, 3, 2, 1),
            nn.PReLU(),
            nn.Conv2d(192, 192, 3, 1, 1),
        )

    def forward(self, x):
        f1 = self.block1(x)
        f2 = self.block2(f1)
        f3 = self.block3(f2)
        return f1, f2, f3
1501
+
1502
+
1503
+ # Residual Block
1504
+ def ResidualBlock(in_channels, out_channels, stride=1):
1505
+ return torch.nn.Sequential(
1506
+ nn.PReLU(),
1507
+ nn.Conv2d(
1508
+ in_channels,
1509
+ out_channels,
1510
+ kernel_size=3,
1511
+ stride=stride,
1512
+ padding=1,
1513
+ bias=True,
1514
+ ),
1515
+ nn.PReLU(),
1516
+ nn.Conv2d(
1517
+ out_channels,
1518
+ out_channels,
1519
+ kernel_size=3,
1520
+ stride=stride,
1521
+ padding=1,
1522
+ bias=True,
1523
+ ),
1524
+ )
1525
+
1526
+
1527
+ # downsample block
1528
+ def DownsampleBlock(in_channels, out_channels, stride=2):
1529
+ return torch.nn.Sequential(
1530
+ nn.PReLU(),
1531
+ nn.Conv2d(
1532
+ in_channels,
1533
+ out_channels,
1534
+ kernel_size=3,
1535
+ stride=stride,
1536
+ padding=1,
1537
+ bias=True,
1538
+ ),
1539
+ nn.PReLU(),
1540
+ nn.Conv2d(
1541
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
1542
+ ),
1543
+ )
1544
+
1545
+
1546
+ # upsample block
1547
+ def UpsampleBlock(in_channels, out_channels, stride=2):
1548
+ return torch.nn.Sequential(
1549
+ nn.PReLU(),
1550
+ nn.ConvTranspose2d(
1551
+ in_channels,
1552
+ out_channels,
1553
+ kernel_size=4,
1554
+ stride=stride,
1555
+ padding=1,
1556
+ bias=True,
1557
+ ),
1558
+ nn.PReLU(),
1559
+ nn.Conv2d(
1560
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
1561
+ ),
1562
+ )
1563
+
1564
+
1565
class PixelShuffleBlcok(nn.Module):
    """2x upsampling head: conv -> PixelShuffle(2) -> conv.

    (Class name typo is kept intentionally — it is part of the public
    interface and of saved checkpoints.)
    """

    def __init__(self, in_feat, num_feat, num_out_ch):
        super(PixelShuffleBlcok, self).__init__()
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(in_feat, num_feat, 3, 1, 1), nn.PReLU()
        )
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)
        )
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        feat = self.conv_before_upsample(x)
        return self.conv_last(self.upsample(feat))
1580
+
1581
+
1582
+ # grid network
1583
+ class GridNet(nn.Module):
1584
+ def __init__(
1585
+ self,
1586
+ in_channels=12,
1587
+ in_channels1=128,
1588
+ in_channels2=256,
1589
+ in_channels3=384,
1590
+ out_channels=3,
1591
+ ):
1592
+ super(GridNet, self).__init__()
1593
+
1594
+ self.residual_model_head = ResidualBlock(in_channels, 64)
1595
+ self.residual_model_head1 = ResidualBlock(in_channels1, 64)
1596
+ self.residual_model_head2 = ResidualBlock(in_channels2, 128)
1597
+ self.residual_model_head3 = ResidualBlock(in_channels3, 192)
1598
+
1599
+ self.residual_model_01 = ResidualBlock(64, 64)
1600
+ # self.residual_model_02=ResidualBlock(64, 64)
1601
+ # self.residual_model_03=ResidualBlock(64, 64)
1602
+ self.residual_model_04 = ResidualBlock(64, 64)
1603
+ self.residual_model_05 = ResidualBlock(64, 64)
1604
+ self.residual_model_tail = PixelShuffleBlcok(64, 64, out_channels)
1605
+
1606
+ self.residual_model_11 = ResidualBlock(128, 128)
1607
+ # self.residual_model_12=ResidualBlock(128, 128)
1608
+ # self.residual_model_13=ResidualBlock(128, 128)
1609
+ self.residual_model_14 = ResidualBlock(128, 128)
1610
+ self.residual_model_15 = ResidualBlock(128, 128)
1611
+
1612
+ self.residual_model_21 = ResidualBlock(192, 192)
1613
+ # self.residual_model_22=ResidualBlock(192, 192)
1614
+ # self.residual_model_23=ResidualBlock(192, 192)
1615
+ self.residual_model_24 = ResidualBlock(192, 192)
1616
+ self.residual_model_25 = ResidualBlock(192, 192)
1617
+
1618
+ #
1619
+
1620
+ self.downsample_model_10 = DownsampleBlock(64, 128)
1621
+ self.downsample_model_20 = DownsampleBlock(128, 192)
1622
+
1623
+ self.downsample_model_11 = DownsampleBlock(64, 128)
1624
+ self.downsample_model_21 = DownsampleBlock(128, 192)
1625
+
1626
+ # self.downsample_model_12=DownsampleBlock(64, 128)
1627
+ # self.downsample_model_22=DownsampleBlock(128, 192)
1628
+
1629
+ #
1630
+
1631
+ # self.upsample_model_03=UpsampleBlock(128, 64)
1632
+ # self.upsample_model_13=UpsampleBlock(192, 128)
1633
+
1634
+ self.upsample_model_04 = UpsampleBlock(128, 64)
1635
+ self.upsample_model_14 = UpsampleBlock(192, 128)
1636
+
1637
+ self.upsample_model_05 = UpsampleBlock(128, 64)
1638
+ self.upsample_model_15 = UpsampleBlock(192, 128)
1639
+
1640
+ def forward(self, x, x1, x2, x3):
1641
+ X00 = self.residual_model_head(x) + self.residual_model_head1(
1642
+ x1
1643
+ ) # --- 182 ~ 185
1644
+ # X10 = self.residual_model_head1(x1)
1645
+
1646
+ X01 = self.residual_model_01(X00) + X00 # --- 208 ~ 211 ,AddBackward1213
1647
+
1648
+ X10 = self.downsample_model_10(X00) + self.residual_model_head2(
1649
+ x2
1650
+ ) # --- 186 ~ 189
1651
+ X20 = self.downsample_model_20(X10) + self.residual_model_head3(
1652
+ x3
1653
+ ) # --- 190 ~ 193
1654
+
1655
+ residual_11 = (
1656
+ self.residual_model_11(X10) + X10
1657
+ ) # 201 ~ 204 , sum AddBackward1206
1658
+ downsample_11 = self.downsample_model_11(X01) # 214 ~ 217
1659
+ X11 = residual_11 + downsample_11 # --- AddBackward1218
1660
+
1661
+ residual_21 = (
1662
+ self.residual_model_21(X20) + X20
1663
+ ) # 194 ~ 197 , sum AddBackward1199
1664
+ downsample_21 = self.downsample_model_21(X11) # 219 ~ 222
1665
+ X21 = residual_21 + downsample_21 # AddBackward1223
1666
+
1667
+ X24 = self.residual_model_24(X21) + X21 # --- 224 ~ 227 , AddBackward1229
1668
+ X25 = self.residual_model_25(X24) + X24 # --- 230 ~ 233 , AddBackward1235
1669
+
1670
+ upsample_14 = self.upsample_model_14(X24) # 242 ~ 246
1671
+ residual_14 = self.residual_model_14(X11) + X11 # 248 ~ 251, AddBackward1253
1672
+ X14 = upsample_14 + residual_14 # --- AddBackward1254
1673
+
1674
+ upsample_04 = self.upsample_model_04(X14) # 268 ~ 272
1675
+ residual_04 = self.residual_model_04(X01) + X01 # 274 ~ 277, AddBackward1279
1676
+ X04 = upsample_04 + residual_04 # --- AddBackward1280
1677
+
1678
+ upsample_15 = self.upsample_model_15(X25) # 236 ~ 240
1679
+ residual_15 = self.residual_model_15(X14) + X14 # 255 ~ 258, AddBackward1260
1680
+ X15 = upsample_15 + residual_15 # AddBackward1261
1681
+
1682
+ upsample_05 = self.upsample_model_05(X15) # 262 ~ 266
1683
+ residual_05 = self.residual_model_05(X04) + X04 # 281 ~ 284,AddBackward1286
1684
+ X05 = upsample_05 + residual_05 # AddBackward1287
1685
+
1686
+ X_tail = self.residual_model_tail(X05) # 288 ~ 291
1687
+
1688
+ return X_tail
1689
+ # end
1690
+
1691
class Model:
    """GMFSS Fortuna model bundle: flow (GMFlow) + splat metrics (MetricNet)
    + feature pyramid (FeatureNet) + fusion (GridNet).

    Not an nn.Module — it holds four sub-networks and mirrors eval/to/load
    over them. `reuse` computes everything that depends only on the frame
    pair; `inference` then synthesizes a frame for each timestep.
    """

    def __init__(self):
        self.flownet = GMFlow()
        self.metricnet = MetricNet()
        self.feat_ext = FeatureNet()
        self.fusionnet = GridNet()
        self.version = 3.9

    def eval(self):
        """Put all four sub-networks in eval mode."""
        self.flownet.eval()
        self.metricnet.eval()
        self.feat_ext.eval()
        self.fusionnet.eval()

    def device(self):
        """Move all sub-networks to the module-level `device` global
        (defined elsewhere in this file)."""
        self.flownet.to(device)
        self.metricnet.to(device)
        self.feat_ext.to(device)
        self.fusionnet.to(device)

    def load_model(self, path_dict):
        """Load the four checkpoints from `path_dict` keyed by sub-network.
        NOTE(review): torch.load without map_location — assumes checkpoints
        were saved on a compatible device; confirm."""
        #models/GMFSS_fortuna_flownet.pkl
        self.flownet.load_state_dict(torch.load(path_dict["flownet"]))
        #models/GMFSS_fortuna_metric.pkl
        self.metricnet.load_state_dict(torch.load(path_dict["metricnet"]))
        #models/GMFSS_fortuna_feat.pkl
        self.feat_ext.load_state_dict(torch.load(path_dict["feat_ext"]))
        #models/GMFSS_fortuna_fusionnet.pkl
        self.fusionnet.load_state_dict(torch.load(path_dict["fusionnet"]))

    def reuse(self, img0, img1, scale):
        """Precompute per-pair state: pyramid features at full resolution,
        then bidirectional flow and metrics at half resolution (optionally
        estimating flow at an extra `scale` and resampling back)."""
        feat11, feat12, feat13 = self.feat_ext(img0)
        feat21, feat22, feat23 = self.feat_ext(img1)

        # flow/metric branch works at half resolution
        img0 = F.interpolate(
            img0, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        img1 = F.interpolate(
            img1, scale_factor=0.5, mode="bilinear", align_corners=False
        )

        if scale != 1.0:
            imgf0 = F.interpolate(
                img0, scale_factor=scale, mode="bilinear", align_corners=False
            )
            imgf1 = F.interpolate(
                img1, scale_factor=scale, mode="bilinear", align_corners=False
            )
        else:
            imgf0 = img0
            imgf1 = img1
        # NOTE(review): return_flow is swallowed by **kwargs in
        # GMFlow.forward above — presumably kept for signature parity with
        # the union variant; confirm it is intentionally a no-op here.
        flow01 = self.flownet(imgf0, imgf1, return_flow=True)
        flow10 = self.flownet(imgf1, imgf0, return_flow=True)
        if scale != 1.0:
            # resample flow back to half resolution and rescale magnitudes
            flow01 = (
                F.interpolate(
                    flow01,
                    scale_factor=1.0 / scale,
                    mode="bilinear",
                    align_corners=False,
                )
                / scale
            )
            flow10 = (
                F.interpolate(
                    flow10,
                    scale_factor=1.0 / scale,
                    mode="bilinear",
                    align_corners=False,
                )
                / scale
            )

        metric0, metric1 = self.metricnet(img0, img1, flow01, flow10)

        return (
            flow01,
            flow10,
            metric0,
            metric1,
            feat11,
            feat12,
            feat13,
            feat21,
            feat22,
            feat23,
        )

    def inference(
        self,
        img0,
        img1,
        flow01,
        flow10,
        metric0,
        metric1,
        feat11,
        feat12,
        feat13,
        feat21,
        feat22,
        feat23,
        timestep,
    ):
        """Synthesize the frame at `timestep` in (0, 1) by softmax-splatting
        both images and their feature pyramids toward t, then fusing with
        GridNet. Returns the frame clamped to [0, 1]."""
        # time-scaled flows and metrics toward the target instant
        F1t = timestep * flow01
        F2t = (1 - timestep) * flow10

        Z1t = timestep * metric0
        Z2t = (1 - timestep) * metric1

        # splat the (half-resolution) images; softsplat is imported
        # elsewhere in this file
        img0 = F.interpolate(
            img0, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        I1t = softsplat(img0, F1t, Z1t, strMode="soft")
        img1 = F.interpolate(
            img1, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        I2t = softsplat(img1, F2t, Z2t, strMode="soft")

        # splat pyramid level 1 with the full-size flow/metric
        feat1t1 = softsplat(feat11, F1t, Z1t, strMode="soft")
        feat2t1 = softsplat(feat21, F2t, Z2t, strMode="soft")

        # level 2: flow downscaled 2x (magnitude rescaled accordingly)
        F1td = (
            F.interpolate(F1t, scale_factor=0.5, mode="bilinear", align_corners=False)
            * 0.5
        )
        Z1d = F.interpolate(Z1t, scale_factor=0.5, mode="bilinear", align_corners=False)
        feat1t2 = softsplat(feat12, F1td, Z1d, strMode="soft")
        F2td = (
            F.interpolate(F2t, scale_factor=0.5, mode="bilinear", align_corners=False)
            * 0.5
        )
        Z2d = F.interpolate(Z2t, scale_factor=0.5, mode="bilinear", align_corners=False)
        feat2t2 = softsplat(feat22, F2td, Z2d, strMode="soft")

        # level 3: flow downscaled 4x
        F1tdd = (
            F.interpolate(F1t, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        Z1dd = F.interpolate(
            Z1t, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        feat1t3 = softsplat(feat13, F1tdd, Z1dd, strMode="soft")
        F2tdd = (
            F.interpolate(F2t, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        Z2dd = F.interpolate(
            Z2t, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        feat2t3 = softsplat(feat23, F2tdd, Z2dd, strMode="soft")

        # fuse images + splatted features into the output frame
        out = self.fusionnet(
            torch.cat([img0, I1t, I2t, img1], dim=1),
            torch.cat([feat1t1, feat2t1], dim=1),
            torch.cat([feat1t2, feat2t2], dim=1),
            torch.cat([feat1t3, feat2t3], dim=1),
        )

        return torch.clamp(out, 0, 1)
vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ import vapoursynth as vs
4
+ from .GMFSS_Fortuna_union_arch import Model_inference
5
+ import torch
6
+
7
+
8
class GMFSS_Fortuna_union:
    """Inference wrapper around the union GMFSS-Fortuna interpolation model.

    Globally disables gradients and enables cuDNN (with benchmark
    autotuning), then exposes a single ``execute`` entry point that blends
    two frames at a given timestep.
    """

    def __init__(self):
        # Two input frames per call; no inter-call frame caching.
        self.cache = False
        self.amount_input_img = 2

        torch.set_grad_enabled(False)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        self.model = Model_inference()
        self.model.eval()

    def execute(self, I0, I1, timestep):
        """Return the interpolated frame between I0 and I1 at `timestep` (moved to CPU)."""
        with torch.inference_mode():
            return self.model(I0, I1, timestep).cpu()
vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py ADDED
@@ -0,0 +1,1857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/GMFSS_infer_u.py
3
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/softsplat.py
4
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FusionNet_u.py
5
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FeatureNet.py
6
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/MetricNet.py
7
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/IFNet_HDv3.py
8
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/gmflow.py
9
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/utils.py
10
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/position.py
11
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/geometry.py
12
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/matching.py
13
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/transformer.py
14
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/backbone.py
15
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/trident_conv.py
16
+ https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/warplayer.py
17
+ """
18
+
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ from torch.nn.modules.utils import _pair
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torch
26
+ import math
27
+ from vfi_models.rife.rife_arch import IFNet
28
+ from vfi_models.ops import softsplat
29
+ from comfy.model_management import get_torch_device
30
+
31
+ device = get_torch_device()
32
+ backwarp_tenGrid = {}
33
+
34
+
35
def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` with the per-pixel displacement field `tenFlow`.

    Args:
        tenInput: [B, C, H, W] tensor to sample from.
        tenFlow:  [B, 2, H, W] flow in pixels (x displacement in channel 0,
                  y displacement in channel 1).

    Returns:
        [B, C, H, W] tensor bilinearly sampled at ``base_grid + flow`` with
        border padding and align_corners=True.
    """
    # Cache the normalized base grid per (device, shape) on the function
    # itself; building it is the only work reusable across calls.
    cache = warp.__dict__.setdefault("_grid_cache", {})
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in cache:
        # Fix: build the grid on the *flow's* device. The original created it
        # on a module-level default device even though the cache key was the
        # flow's device, which breaks when tensors live elsewhere.
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device)
            .view(1, 1, 1, tenFlow.shape[3])
            .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device)
            .view(1, 1, tenFlow.shape[2], 1)
            .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        )
        cache[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Convert the pixel-space flow into the [-1, 1] range grid_sample expects.
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (cache[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
66
+
67
+
68
class MultiScaleTridentConv(nn.Module):
    """Shared-weight convolution applied in parallel to several branches.

    Every branch uses the same kernel/bias; each branch may have its own
    stride and padding so that inputs at different scales produce aligned
    outputs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation

        # Broadcast scalar per-branch settings to one entry per branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # All per-branch lists must agree with num_branch.
        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        self.bias = nn.Parameter(torch.Tensor(out_channels)) if bias else None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Apply the shared conv to each branch; `inputs` is a list of tensors."""
        # In training (or when no test branch is chosen) all branches run;
        # otherwise a single branch runs.
        multi_branch = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if multi_branch else 1
        assert len(inputs) == num_branch

        if multi_branch:
            outputs = [
                F.conv2d(inp, self.weight, self.bias, s, p, self.dilation, self.groups)
                for inp, s, p in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # NOTE(review): this index selection mirrors the upstream code;
            # the condition can only be False in this branch, so the last
            # branch's stride/padding is always the one used.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.strides[-1],
                    self.paddings[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(o) for o in outputs]
        if self.activation is not None:
            outputs = [self.activation(o) for o in outputs]
        return outputs
163
+
164
+
165
class ResidualBlock_class(nn.Module):
    """Two normalized 3x3 convolutions with a (possibly projected) skip path."""

    def __init__(
        self,
        in_planes,
        planes,
        norm_layer=nn.InstanceNorm2d,
        stride=1,
        dilation=1,
    ):
        super(ResidualBlock_class, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            stride=stride,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)

        # A 1x1 projection is needed only when the skip path changes
        # resolution or channel count.
        if stride == 1 and in_planes == planes:
            self.downsample = None
        else:
            self.norm3 = norm_layer(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + out)
216
+
217
+
218
class CNNEncoder(nn.Module):
    """Residual CNN backbone producing 1/8 (or multi-scale 1/4+) features."""

    def __init__(
        self,
        output_dim=128,
        norm_layer=nn.InstanceNorm2d,
        num_output_scales=1,
        **kwargs,
    ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        # Stem: 7x7 stride-2 conv -> 1/2 resolution.
        self.conv1 = nn.Conv2d(
            3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(
            feature_dims[0], stride=1, norm_layer=norm_layer
        )  # 1/2
        self.layer2 = self._make_layer(
            feature_dims[1], stride=2, norm_layer=norm_layer
        )  # 1/4

        # Highest resolution is 1/4 when emitting several scales, else 1/8.
        self.layer3 = self._make_layer(
            feature_dims[2],
            stride=2 if num_output_scales == 1 else 1,
            norm_layer=norm_layer,
        )

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            strides_by_branch = {4: (1, 2, 4, 8), 3: (1, 2, 4), 2: (1, 2)}
            if self.num_branch not in strides_by_branch:
                raise ValueError

            # Shared conv applied with different strides to get the pyramid.
            self.trident_conv = MultiScaleTridentConv(
                output_dim,
                output_dim,
                kernel_size=3,
                strides=strides_by_branch[self.num_branch],
                paddings=1,
                num_branch=self.num_branch,
            )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        # Two residual blocks; only the first can change stride/width.
        first = ResidualBlock_class(
            self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation
        )
        second = ResidualBlock_class(
            dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation
        )
        self.in_planes = dim
        return nn.Sequential(first, second)

    def forward(self, x):
        x = self.relu1(self.norm1(self.conv1(x)))

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4
        x = self.conv2(x)

        if self.num_branch > 1:
            # High-to-low resolution feature list.
            return self.trident_conv([x] * self.num_branch)
        return [x]
313
+
314
+
315
def single_head_full_attention(q, k, v):
    """Dense scaled dot-product attention over full token sequences.

    q, k, v: [B, L, C].  Returns [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3

    scale = q.size(2) ** 0.5
    attn = torch.softmax(torch.matmul(q, k.permute(0, 2, 1)) / scale, dim=2)  # [B, L, L]
    return torch.matmul(attn, v)  # [B, L, C]
324
+
325
+
326
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=get_torch_device(),
):
    """Build the additive attention mask for shifted-window (SW-MSA) attention.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    Returns [num_windows, win_h*win_w, win_h*win_w]: 0 within a region,
    -100 across region boundaries (so softmax zeroes those pairs).
    """
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1

    # Label each of the 3x3 shifted regions with a distinct integer id.
    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    region_id = 0
    for hs in h_slices:
        for ws in w_slices:
            img_mask[:, hs, ws, :] = region_id
            region_id += 1

    mask_windows = split_feature(
        img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True
    )
    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)

    # Token pairs from different regions get a large negative bias.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )
    return attn_mask
365
+
366
+
367
def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Swin-style windowed attention for a single head.

    q, k, v: [B, L, C] with L == h * w.  The map is split into
    num_splits x num_splits windows and attention is computed inside each,
    optionally after a half-window cyclic shift (requires `attn_mask`).
    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    """
    assert q.dim() == k.dim() == v.dim() == 3
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    b_windows = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c**0.5

    if with_shift:
        assert attn_mask is not None  # computed once by the caller
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    # Partition into windows: [B*K*K, H/K, W/K, C].
    q = split_feature(q, num_splits=num_splits, channel_last=True)
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    scores = (
        torch.matmul(
            q.view(b_windows, -1, c), k.view(b_windows, -1, c).permute(0, 2, 1)
        )
        / scale_factor
    )  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)
    out = torch.matmul(attn, v.view(b_windows, -1, c))  # [B*K*K, H/K*W/K, C]

    # Stitch windows back together: [B, H, W, C].
    out = merge_splits(
        out.view(b_windows, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )

    if with_shift:
        # Undo the cyclic shift.
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    return out.view(b, -1, c)
437
+
438
+
439
class TransformerLayer(nn.Module):
    """Single-head attention layer with an optional FFN (GMFlow transformer)."""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        no_ffn=False,
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn
        self.with_shift = with_shift

        # Q/K/V projections and the output merge, all bias-free.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # Self-attention layers skip the FFN; cross-attention layers keep it.
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )
            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """source, target: [B, L, C]; query comes from source, key/value from target."""
        query = self.q_proj(source)  # [B, L, C]
        key = self.k_proj(target)  # [B, L, C]
        value = self.v_proj(target)  # [B, L, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                # Multi-head was measured (upstream) to be slower and heavier
                # without accuracy gains, so it is intentionally unsupported.
                raise NotImplementedError
            message = single_head_split_window_attention(
                query,
                key,
                value,
                num_splits=attn_num_splits,
                with_shift=self.with_shift,
                h=height,
                w=width,
                attn_mask=shifted_window_attn_mask,
            )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.norm1(self.merge(message))

        if not self.no_ffn:
            message = self.norm2(self.mlp(torch.cat([source, message], dim=-1)))

        # Residual connection.
        return source + message
524
+
525
+
526
class TransformerBlock(nn.Module):
    """Self-attention followed by cross-attention + FFN."""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerBlock, self).__init__()

        # Self-attention has no FFN; the cross-attention layer carries it.
        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )
        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """source, target: [B, L, C]; returns the updated source."""
        attn_kwargs = dict(
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )

        source = self.self_attn(source, source, **attn_kwargs)
        return self.cross_attn_ffn(source, target, **attn_kwargs)
590
+
591
+
592
class FeatureTransformer(nn.Module):
    """Stack of transformer blocks that jointly refines two feature maps."""

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        **kwargs,
    ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type
        self.d_model = d_model
        self.nhead = nhead

        # For swin attention, odd layers use the shifted-window variant.
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    with_shift=attention_type == "swin" and i % 2 == 1,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_num_splits=None,
        **kwargs,
    ):
        b, c, h, w = feature0.shape
        assert self.d_model == c

        # Flatten to token sequences: [B, H*W, C].
        feature0 = feature0.flatten(-2).permute(0, 2, 1)
        feature1 = feature1.flatten(-2).permute(0, 2, 1)

        if self.attention_type == "swin" and attn_num_splits > 1:
            # Global and refinement stages use different split counts, so
            # the shifted-window mask is computed once per forward pass.
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # Concatenate both directions along the batch so each layer updates
        # the two feature maps in one pass (concat0 attends to concat1).
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                shifted_window_attn_mask=shifted_window_attn_mask,
                attn_num_splits=attn_num_splits,
            )
            # Swap halves so the cross-attention target stays up to date.
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # Restore [B, C, H, W].
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()

        return feature0, feature1
686
+
687
+
688
class FeatureFlowAttention(nn.Module):
    """Flow propagation with self-attention on features.

    query: feature0, key: feature0, value: flow.
    """

    def __init__(
        self,
        in_channels,
        **kwargs,
    ):
        super(FeatureFlowAttention, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        flow,
        local_window_attn=False,
        local_window_radius=1,
        **kwargs,
    ):
        """q, k: feature [B, C, H, W]; v: flow [B, 2, H, W] -> [B, 2, H, W]."""
        if local_window_attn:
            return self.forward_local_window_attn(
                feature0, flow, local_window_radius=local_window_radius
            )

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # Upstream quirk (kept for checkpoint compatibility): the key is
        # projected from the already-projected query rather than the raw
        # feature.  Both are linear maps, so the composition is equivalent
        # to a single learned projection.
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        return out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        """Same attention restricted to a (2R+1)^2 neighborhood per pixel."""
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()
        kernel_size = 2 * local_window_radius + 1

        # One query per pixel: [B*H*W, 1, C].
        feature0_reshape = self.q_proj(
            feature0.view(b, c, -1).permute(0, 2, 1)
        ).reshape(b * h * w, 1, c)

        feature0_proj = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        # Per-pixel key neighborhoods: [B*H*W, C, (2R+1)^2].
        feature0_window = F.unfold(
            feature0_proj, kernel_size=kernel_size, padding=local_window_radius
        )
        feature0_window = (
            feature0_window.view(b, c, kernel_size**2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size**2)
        )

        # Matching flow neighborhoods: [B*H*W, (2R+1)^2, 2].
        flow_window = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )
        flow_window = (
            flow_window.view(b, 2, kernel_size**2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size**2, 2)
        )

        scores = torch.matmul(feature0_reshape, feature0_window) / (
            c**0.5
        )  # [B*H*W, 1, (2R+1)^2]
        prob = torch.softmax(scores, dim=-1)

        out = (
            torch.matmul(prob, flow_window)
            .view(b, h, w, 2)
            .permute(0, 3, 1, 2)
            .contiguous()
        )  # [B, 2, H, W]

        return out
804
+
805
+
806
def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    """Estimate flow as the soft-argmax over the global correlation volume.

    feature0/feature1: [B, C, H, W].  Returns (flow, prob): flow is
    [B, 2, H, W] (or [2B, 2, H, W] — forward then backward — when
    pred_bidir_flow) and prob is the matching distribution [B, H*W, H*W].
    """
    b, c, h, w = feature0.shape
    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    # All-pairs similarity, scaled like dot-product attention.
    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c**0.5)

    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        # The transposed volume scores the reverse direction; stacking both
        # lets one softmax/matmul produce forward and backward flow at once.
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    # Expected target coordinate under the matching distribution.
    correspondence = (
        torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - init_grid
    return flow, prob
844
+
845
+
846
def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    """Soft-argmax flow restricted to a (2R+1)x(2R+1) local search window."""
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # Candidates falling outside the image are masked out of the softmax.
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w)
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h)
    valid = valid_x & valid_y  # [B, H*W, (2R+1)^2]

    # Sample feature1 at the candidate positions (normalized to [-1, 1]).
    sample_coords_norm = normalize_coords(sample_coords, h, w)
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True
    ).permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # Push invalid candidates to effectively zero probability.
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    # Expected candidate coordinate -> correspondence -> flow.
    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob
914
+
915
+
916
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Return a [B, 2, H, W] pixel-coordinate grid (x then y channels).

    With homogeneous=True a constant-one third channel is appended,
    giving [B, 3, H, W].
    """
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # each [H, W]

    channels = [x, y]
    if homogeneous:
        channels.append(torch.ones_like(x))  # homogeneous coordinate

    grid = torch.stack(channels, dim=0).float()  # [2 or 3, H, W]
    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2 or 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid
933
+
934
+
935
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Create a [len_h, len_w, 2] grid of (x, y) coordinates spanning the
    given window: x ranges over [w_min, w_max], y over [h_min, h_max]."""
    assert device is not None

    xs = torch.linspace(w_min, w_max, len_w, device=device)
    ys = torch.linspace(h_min, h_max, len_h, device=device)
    grid_x, grid_y = torch.meshgrid([xs, ys])  # each [len_w, len_h]

    # Pair up (x, y) then swap to row-major [H, W, 2] layout.
    window = torch.stack((grid_x, grid_y), dim=-1)
    return window.transpose(0, 1).float()
947
+
948
+
949
def normalize_coords(coords, h, w):
    """Map absolute pixel coordinates (last dim = (x, y)) into [-1, 1],
    the convention expected by ``F.grid_sample(align_corners=True)``."""
    cx = (w - 1) / 2.0
    cy = (h - 1) / 2.0
    # Half-extents double as both the shift and the scale.
    center = torch.Tensor([cx, cy]).float().to(coords.device)
    return (coords - center) / center
953
+
954
+
955
def bilinear_sample(
    img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False
):
    """Sample ``img`` [B, C, H, W] at absolute pixel ``sample_coords``.

    ``sample_coords`` may be [B, 2, H, W] or [B, H, W, 2]; coordinates are
    in image scale and normalized internally for ``F.grid_sample``.  When
    ``return_mask`` is True, also returns a [B, H, W] bool mask that is
    True wherever the sample location fell inside the image.
    """
    if sample_coords.size(1) != 2:
        # Accept channel-last [B, H, W, 2] layout as well.
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Rescale pixel coordinates into [-1, 1] (align_corners convention).
    x_norm = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_norm = 2 * sample_coords[:, 1] / (h - 1) - 1
    grid = torch.stack([x_norm, y_norm], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(
        img, grid, mode=mode, padding_mode=padding_mode, align_corners=True
    )

    if not return_mask:
        return sampled

    # Valid wherever the normalized coordinate stayed inside the image.
    in_bounds = (x_norm >= -1) & (x_norm <= 1) & (y_norm >= -1) & (y_norm <= 1)
    return sampled, in_bounds
983
+
984
+
985
def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    """Backward-warp ``feature`` [B, C, H, W] along ``flow`` [B, 2, H, W].

    Passing ``mask=True`` additionally returns the in-bounds sample mask
    from ``bilinear_sample``.
    """
    assert flow.size(1) == 2
    b, _, h, w = feature.size()

    # Absolute sampling positions = base pixel grid shifted by the flow.
    sample_pos = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(
        feature, sample_pos, padding_mode=padding_mode, return_mask=mask
    )
992
+
993
+
994
def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    """Estimate occlusion masks from a forward/backward flow pair.

    alpha and beta follow UnFlow (https://arxiv.org/abs/1711.07837).
    Returns ``(fwd_occ, bwd_occ)``: float masks of shape [B, H, W] where
    1.0 marks pixels whose flows are mutually inconsistent (occluded).
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    # Bring each flow into the other frame so the pair can be compared.
    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    # A consistent pair cancels out; large residuals indicate occlusion.
    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)

    # Magnitude-dependent tolerance.
    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ
1013
+
1014
+
1015
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    Produces a [B, 2*num_pos_feats, H, W] sinusoidal encoding: the first
    half of the channels encodes the y position, the second half the x
    position, each as interleaved sin/cos at geometrically spaced
    frequencies.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        # num_pos_feats: channels per spatial axis (output has twice as many).
        # temperature: base of the frequency progression.
        # normalize: rescale positions to [0, scale] before encoding.
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # x = tensor_list.tensors # [B, C, H, W]
        # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
        b, c, h, w = x.size()
        # No padding information here, so every position counts as valid.
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        # Cumulative sums turn the all-ones mask into 1-based row/col indices.
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Divide by the last (largest) index so positions span [0, scale].
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Per-channel frequency divisors: temperature^(2*floor(i/2)/d).
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels with cos on odd channels.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        # y channels first, then x; back to channel-first layout.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
1057
+
1058
+
1059
def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Partition a feature map into a K x K grid of non-overlapping windows.

    Each window becomes its own batch entry: [B, ..., H, W] maps to
    [B*K*K, ..., H/K, W/K] with K = ``num_splits``.  Layout may be
    channel-last [B, H, W, C] or channel-first [B, C, H, W].
    """
    k = num_splits
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % k == 0 and w % k == 0

        # Expose the window axes, bring them next to batch, then flatten.
        windows = feature.view(b, k, h // k, k, w // k, c)
        feature = windows.permute(0, 1, 3, 2, 4, 5).reshape(
            b * k * k, h // k, w // k, c
        )  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % k == 0 and w % k == 0

        windows = feature.view(b, c, k, h // k, k, w // k)
        feature = windows.permute(0, 2, 4, 1, 3, 5).reshape(
            b * k * k, c, h // k, w // k
        )  # [B*K*K, C, H/K, W/K]

    return feature
1092
+
1093
+
1094
def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Inverse of ``split_feature``: stitch a K x K grid of windows back
    into full-resolution maps, [B*K*K, ...] -> [B, ..., K*H, K*W]."""
    k = num_splits
    if channel_last:  # [B*K*K, H/K, W/K, C]
        bkk, h, w, c = splits.size()
        b = bkk // k // k

        tiles = splits.view(b, k, k, h, w, c)
        # Interleave the window axes back between the spatial axes.
        merge = (
            tiles.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(b, k * h, k * w, c)
        )  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        bkk, c, h, w = splits.size()
        b = bkk // k // k

        tiles = splits.view(b, k, k, c, h, w)
        merge = (
            tiles.permute(0, 3, 1, 4, 2, 5)
            .contiguous()
            .view(b, c, k * h, k * w)
        )  # [B, C, H, W]

    return merge
1121
+
1122
+
1123
def normalize_img(img0, img1):
    """Standardize an image pair with the ImageNet per-channel mean/std.

    NOTE(review): the upstream comment claimed inputs are in [0, 255], but
    the constants are the standard [0, 1]-range ImageNet statistics —
    confirm against callers.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
    return (img0 - mean) / std, (img1 - mean) / std
1132
+
1133
+
1134
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add sinusoidal position encoding to both feature maps.

    When attention runs on split windows (attn_splits > 1) the encoding is
    computed per window, so positions are local to each window; otherwise
    a single global encoding is added.  The same encoding is shared by
    both features (they have identical shapes).
    """
    # Half the channels per axis so the concatenated (y, x) encoding
    # matches feature_channels exactly.
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits > 1:  # add position in splited window
        feature0_splits = split_feature(feature0, num_splits=attn_splits)
        feature1_splits = split_feature(feature1, num_splits=attn_splits)

        position = pos_enc(feature0_splits)

        feature0_splits = feature0_splits + position
        feature1_splits = feature1_splits + position

        feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
        feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
    else:
        position = pos_enc(feature0)

        feature0 = feature0 + position
        feature1 = feature1 + position

    return feature0, feature1
1155
+
1156
+
1157
class GMFlow(nn.Module):
    """GMFlow optical-flow estimator (global matching with transformers).

    Coarse-to-fine pipeline over ``num_scales`` pyramid levels: CNN
    features are enhanced by a transformer, matched via correlation +
    softmax, refined by flow self-attention, and convex-upsampled back to
    input resolution.  Depends on ``CNNEncoder``, ``FeatureTransformer``
    and ``FeatureFlowAttention`` defined elsewhere in this file.
    """

    def __init__(
        self,
        num_scales=2,
        upsample_factor=4,
        feature_channels=128,
        attention_type="swin",
        num_transformer_layers=6,
        ffn_dim_expansion=4,
        num_head=1,
        **kwargs,
    ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(
            output_dim=feature_channels, num_output_scales=num_scales
        )

        # Transformer
        self.transformer = FeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        # flow propagation with self-attn
        self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)

        # convex upsampling: concat feature0 and flow as input
        self.upsampler = nn.Sequential(
            nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
        )

    def extract_feature(self, img0, img1):
        """Run the backbone once on the batched pair and return two
        per-image pyramids ordered from low to high resolution."""
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(
            concat
        )  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(
        self,
        flow,
        feature,
        bilinear=False,
        upsample_factor=8,
    ):
        """Upsample a flow field, either bilinearly (training supervision)
        or with RAFT-style learned convex combination over 3x3 neighbors.

        Flow values are scaled by the upsampling factor so they remain in
        pixel units at the new resolution.
        """
        if bilinear:
            up_flow = (
                F.interpolate(
                    flow,
                    scale_factor=upsample_factor,
                    mode="bilinear",
                    align_corners=True,
                )
                * upsample_factor
            )

        else:
            # convex upsampling
            concat = torch.cat((flow, feature), dim=1)

            mask = self.upsampler(concat)
            b, flow_channel, h, w = flow.shape
            mask = mask.view(
                b, 1, 9, self.upsample_factor, self.upsample_factor, h, w
            )  # [B, 1, 9, K, K, H, W]
            # Softmax over the 9 neighbors makes the combination convex.
            mask = torch.softmax(mask, dim=2)

            up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
            up_flow = up_flow.view(
                b, flow_channel, 9, 1, 1, h, w
            )  # [B, 2, 9, 1, 1, H, W]

            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
            up_flow = up_flow.reshape(
                b, flow_channel, self.upsample_factor * h, self.upsample_factor * w
            )  # [B, 2, K*H, K*W]

        return up_flow

    def forward(
        self,
        img0,
        img1,
        attn_splits_list=[2, 8],
        corr_radius_list=[-1, 4],
        prop_radius_list=[-1, 1],
        pred_bidir_flow=False,
        **kwargs,
    ):
        # NOTE: the list defaults are treated as read-only; one entry per
        # pyramid scale (radius -1 selects global matching at that scale).
        img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # resolution low to high
        feature0_list, feature1_list = self.extract_feature(
            img0, img1
        )  # list of features

        flow = None

        assert (
            len(attn_splits_list)
            == len(corr_radius_list)
            == len(prop_radius_list)
            == self.num_scales
        )

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement:
                # double the batch with the (img1, img0) ordering so both
                # directions are refined in one pass.
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                    (feature1, feature0), dim=0
                )

            # Factor needed to reach input resolution from this scale.
            upsample_factor = self.upsample_factor * (
                2 ** (self.num_scales - 1 - scale_idx)
            )

            if scale_idx > 0:
                # Carry the coarser flow up one level (x2 in size and value).
                flow = (
                    F.interpolate(
                        flow, scale_factor=2, mode="bilinear", align_corners=True
                    )
                    * 2
                )

            if flow is not None:
                flow = flow.detach()
                # Pre-warp feature1 so this level only predicts a residual.
                feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

            attn_splits = attn_splits_list[scale_idx]
            corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(
                feature0, feature1, attn_splits, self.feature_channels
            )

            # Transformer
            feature0, feature1 = self.transformer(
                feature0, feature1, attn_num_splits=attn_splits
            )

            # correlation and softmax
            if corr_radius == -1:  # global matching
                flow_pred = global_correlation_softmax(
                    feature0, feature1, pred_bidir_flow
                )[0]
            else:  # local matching
                flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[
                    0
                ]

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            # upsample to the original resolution for supervison
            if (
                self.training
            ):  # only need to upsample intermediate flow predictions at training time
                flow_bilinear = self.upsample_flow(
                    flow, None, bilinear=True, upsample_factor=upsample_factor
                )

            # flow propagation with self-attn
            if pred_bidir_flow and scale_idx == 0:
                feature0 = torch.cat(
                    (feature0, feature1), dim=0
                )  # [2*B, C, H, W] for propagation
            flow = self.feature_flow_attn(
                feature0,
                flow.detach(),
                local_window_attn=prop_radius > 0,
                local_window_radius=prop_radius,
            )

            # bilinear upsampling at training time except the last one
            if self.training and scale_idx < self.num_scales - 1:
                flow_up = self.upsample_flow(
                    flow, feature0, bilinear=True, upsample_factor=upsample_factor
                )

            if scale_idx == self.num_scales - 1:
                flow_up = self.upsample_flow(flow, feature0)

        # Only the final full-resolution flow is returned (intermediate
        # training-time upsamples are discarded here).
        return flow_up
1370
+
1371
+
1372
# Cache of normalized base sampling grids, keyed by str(flow.shape), so a
# grid is built only once per tensor shape.
backwarp_tenGrid = {}


def backwarp(tenIn, tenflow):
    """Backward-warp ``tenIn`` [B, C, H, W] by pixel-space flow ``tenflow``
    [B, 2, H, W] using bilinear grid sampling with zero padding."""
    if str(tenflow.shape) not in backwarp_tenGrid:
        # Base grid in grid_sample's normalized [-1, 1] coordinates.
        tenHor = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[3],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, 1, -1)
            .repeat(1, 1, tenflow.shape[2], 1)
        )
        tenVer = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[2],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, -1, 1)
            .repeat(1, 1, 1, tenflow.shape[3])
        )

        # NOTE(review): the cached grid is moved to get_torch_device()
        # rather than tenflow.device — this assumes all inputs live on the
        # ComfyUI torch device; confirm before reusing elsewhere.
        backwarp_tenGrid[str(tenflow.shape)] = torch.cat([tenHor, tenVer], 1).to(get_torch_device())
    # end

    # Rescale the pixel-unit flow into the same normalized range.
    tenflow = torch.cat(
        [
            tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0),
            tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    return torch.nn.functional.grid_sample(
        input=tenIn,
        grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )
1418
+
1419
+
1420
class MetricNet(nn.Module):
    """Predicts per-pixel splatting metrics (importance weights) for the
    two input frames from photometric error, flow and occlusion cues.

    Input to the trunk is a 14-channel stack: 6 (both RGB frames) +
    2 (negated warp errors) + 4 (normalized bidirectional flow) +
    2 (occlusion masks).  Outputs two single-channel metrics in [-10, 10].
    """

    def __init__(self):
        super(MetricNet, self).__init__()
        self.metric_in = nn.Conv2d(14, 64, 3, 1, 1)
        self.metric_net1 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net2 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net3 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_out = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 2, 3, 1, 1))

    def forward(self, img0, img1, flow01, flow10):
        # Photometric consistency of each frame against the warped other
        # frame, averaged over channels ([B, 1, H, W] each).
        metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction="none").mean(
            [1], True
        )
        metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction="none").mean(
            [1], True
        )

        fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10)

        # Normalize flows to the [-1, 1] coordinate convention so the
        # network input is resolution independent.
        flow01 = torch.cat(
            [
                flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0),
                flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0),
            ],
            1,
        )
        flow10 = torch.cat(
            [
                flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0),
                flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0),
            ],
            1,
        )

        img = torch.cat((img0, img1), 1)
        # Negated so lower warp error means higher importance.
        metric = torch.cat((-metric0, -metric1), 1)
        flow = torch.cat((flow01, flow10), 1)
        occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1)

        # Residual trunk over the fused 14-channel input.
        feat = self.metric_in(torch.cat((img, metric, flow, occ), 1))
        feat = self.metric_net1(feat) + feat
        feat = self.metric_net2(feat) + feat
        feat = self.metric_net3(feat) + feat
        metric = self.metric_out(feat)

        # Bound the metrics to a range softsplat handles numerically well.
        metric = torch.tanh(metric) * 10

        return metric[:, :1], metric[:, 1:2]
1468
+
1469
+
1470
class FeatureNet(nn.Module):
    """Three-stage pyramid feature extractor ("the quadratic model").

    Each stage halves the spatial resolution; returns features at 1/2,
    1/4 and 1/8 scale with 64, 128 and 192 channels respectively.
    """

    def __init__(self):
        super(FeatureNet, self).__init__()
        channel_plan = [(3, 64), (64, 128), (128, 192)]
        stages = []
        for c_in, c_out in channel_plan:
            stages.append(
                nn.Sequential(
                    nn.PReLU(),
                    nn.Conv2d(c_in, c_out, 3, 2, 1),
                    nn.PReLU(),
                    nn.Conv2d(c_out, c_out, 3, 1, 1),
                )
            )
        # Same attribute names as the checkpoint expects.
        self.block1, self.block2, self.block3 = stages

    def forward(self, x):
        feat_half = self.block1(x)
        feat_quarter = self.block2(feat_half)
        feat_eighth = self.block3(feat_quarter)

        return feat_half, feat_quarter, feat_eighth
1500
+
1501
+
1502
+ # Residual Block
1503
def ResidualBlock(in_channels, out_channels, stride=1):
    """Two PReLU + 3x3 Conv layers packaged as one Sequential.

    NOTE(review): despite the name there is no skip connection inside —
    callers add the residual themselves.  ``stride`` is applied to BOTH
    convolutions, matching the original definition.
    """
    layers = [
        nn.PReLU(),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
    ]
    return torch.nn.Sequential(*layers)
1524
+
1525
+
1526
+ # downsample block
1527
def DownsampleBlock(in_channels, out_channels, stride=2):
    """Strided PReLU+Conv pair: halves spatial resolution by default, then
    refines with a stride-1 3x3 convolution."""
    layers = [
        nn.PReLU(),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
        ),
    ]
    return torch.nn.Sequential(*layers)
1543
+
1544
+
1545
+ # upsample block
1546
def UpsampleBlock(in_channels, out_channels, stride=2):
    """PReLU + 4x4 transposed conv (x2 upsample by default) followed by a
    stride-1 3x3 refinement convolution."""
    layers = [
        nn.PReLU(),
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
        ),
    ]
    return torch.nn.Sequential(*layers)
1562
+
1563
+
1564
class PixelShuffleBlcok(nn.Module):
    """2x upsampler: conv + PReLU, pixel-shuffle upsample, final conv.

    (The class-name typo is kept intentionally — checkpoints and callers
    reference this exact name.)
    """

    def __init__(self, in_feat, num_feat, num_out_ch):
        super(PixelShuffleBlcok, self).__init__()
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(in_feat, num_feat, 3, 1, 1), nn.PReLU()
        )
        # 4x channels then PixelShuffle(2) trades channels for resolution.
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)
        )
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        pre = self.conv_before_upsample(x)
        return self.conv_last(self.upsample(pre))
1579
+
1580
+
1581
+ # grid network
1582
class GridNet(nn.Module):
    """GridNet-style fusion network over a 3-level feature grid.

    Takes the warped image stack ``x`` plus three pyramid feature maps
    (``x1``..``x3``, progressively lower resolution) and synthesizes the
    final RGB frame.  Information flows right through residual columns,
    down through downsample blocks and up through upsample blocks; the
    exact wiring below mirrors the original trace-derived comments.
    """

    def __init__(
        self,
        in_channels=9,
        in_channels1=128,
        in_channels2=256,
        in_channels3=384,
        out_channels=3,
    ):
        super(GridNet, self).__init__()

        # Entry heads: project each input stream onto its grid row.
        self.residual_model_head0 = ResidualBlock(in_channels, 64)
        self.residual_model_head1 = ResidualBlock(in_channels1, 64)
        self.residual_model_head2 = ResidualBlock(in_channels2, 128)
        self.residual_model_head3 = ResidualBlock(in_channels3, 192)

        # Row 0 (full resolution, 64 channels).
        self.residual_model_01 = ResidualBlock(64, 64)
        # self.residual_model_02=ResidualBlock(64, 64)
        # self.residual_model_03=ResidualBlock(64, 64)
        self.residual_model_04 = ResidualBlock(64, 64)
        self.residual_model_05 = ResidualBlock(64, 64)
        self.residual_model_tail = PixelShuffleBlcok(64, 64, out_channels)

        # Row 1 (half resolution, 128 channels).
        self.residual_model_11 = ResidualBlock(128, 128)
        # self.residual_model_12=ResidualBlock(128, 128)
        # self.residual_model_13=ResidualBlock(128, 128)
        self.residual_model_14 = ResidualBlock(128, 128)
        self.residual_model_15 = ResidualBlock(128, 128)

        # Row 2 (quarter resolution, 192 channels).
        self.residual_model_21 = ResidualBlock(192, 192)
        # self.residual_model_22=ResidualBlock(192, 192)
        # self.residual_model_23=ResidualBlock(192, 192)
        self.residual_model_24 = ResidualBlock(192, 192)
        self.residual_model_25 = ResidualBlock(192, 192)

        #

        # Downsample columns (encoder side of the grid).
        self.downsample_model_10 = DownsampleBlock(64, 128)
        self.downsample_model_20 = DownsampleBlock(128, 192)

        self.downsample_model_11 = DownsampleBlock(64, 128)
        self.downsample_model_21 = DownsampleBlock(128, 192)

        # self.downsample_model_12=DownsampleBlock(64, 128)
        # self.downsample_model_22=DownsampleBlock(128, 192)

        #

        # self.upsample_model_03=UpsampleBlock(128, 64)
        # self.upsample_model_13=UpsampleBlock(192, 128)

        # Upsample columns (decoder side of the grid).
        self.upsample_model_04 = UpsampleBlock(128, 64)
        self.upsample_model_14 = UpsampleBlock(192, 128)

        self.upsample_model_05 = UpsampleBlock(128, 64)
        self.upsample_model_15 = UpsampleBlock(192, 128)

    def forward(self, x, x1, x2, x3):
        # Grid node naming: X<row><col>, row = resolution level (0 finest),
        # col = processing stage.
        X00 = self.residual_model_head0(x) + self.residual_model_head1(
            x1
        )  # --- 182 ~ 185
        # X10 = self.residual_model_head1(x1)

        X01 = self.residual_model_01(X00) + X00  # --- 208 ~ 211 ,AddBackward1213

        X10 = self.downsample_model_10(X00) + self.residual_model_head2(
            x2
        )  # --- 186 ~ 189
        X20 = self.downsample_model_20(X10) + self.residual_model_head3(
            x3
        )  # --- 190 ~ 193

        residual_11 = (
            self.residual_model_11(X10) + X10
        )  # 201 ~ 204 , sum AddBackward1206
        downsample_11 = self.downsample_model_11(X01)  # 214 ~ 217
        X11 = residual_11 + downsample_11  # --- AddBackward1218

        residual_21 = (
            self.residual_model_21(X20) + X20
        )  # 194 ~ 197 , sum AddBackward1199
        downsample_21 = self.downsample_model_21(X11)  # 219 ~ 222
        X21 = residual_21 + downsample_21  # AddBackward1223

        X24 = self.residual_model_24(X21) + X21  # --- 224 ~ 227 , AddBackward1229
        X25 = self.residual_model_25(X24) + X24  # --- 230 ~ 233 , AddBackward1235

        upsample_14 = self.upsample_model_14(X24)  # 242 ~ 246
        residual_14 = self.residual_model_14(X11) + X11  # 248 ~ 251, AddBackward1253
        X14 = upsample_14 + residual_14  # --- AddBackward1254

        upsample_04 = self.upsample_model_04(X14)  # 268 ~ 272
        residual_04 = self.residual_model_04(X01) + X01  # 274 ~ 277, AddBackward1279
        X04 = upsample_04 + residual_04  # --- AddBackward1280

        upsample_15 = self.upsample_model_15(X25)  # 236 ~ 240
        residual_15 = self.residual_model_15(X14) + X14  # 255 ~ 258, AddBackward1260
        X15 = upsample_15 + residual_15  # AddBackward1261

        upsample_05 = self.upsample_model_05(X15)  # 262 ~ 266
        residual_05 = self.residual_model_05(X04) + X04  # 281 ~ 284,AddBackward1286
        X05 = upsample_05 + residual_05  # AddBackward1287

        # Final pixel-shuffle head brings the 64-channel row-0 features
        # back to full resolution RGB.
        X_tail = self.residual_model_tail(X05)  # 288 ~ 291

        return X_tail
    # end
1689
+
1690
+
1691
class Model:
    """GMFSS Fortuna "union" interpolation model (plain container, not an
    nn.Module).

    Combines GMFlow (bidirectional optical flow), RIFE's IFNet (a coarse
    intermediate-frame prior), MetricNet (splatting importance weights),
    FeatureNet (pyramid features) and GridNet (final fusion).  The
    softsplat-based forward warping happens at half input resolution.
    """

    def __init__(self):
        self.flownet = GMFlow()
        self.ifnet = IFNet(arch_ver="4.6")
        self.metricnet = MetricNet()
        self.feat_ext = FeatureNet()
        self.fusionnet = GridNet()
        # Model-format version tag.
        self.version = 3.9

    def eval(self):
        """Put every sub-network into eval mode."""
        self.flownet.eval()
        self.ifnet.eval()
        self.metricnet.eval()
        self.feat_ext.eval()
        self.fusionnet.eval()

    def device(self):
        """Move every sub-network to the target device.

        NOTE(review): relies on a module-level ``device`` defined elsewhere
        in this file — confirm it is set before calling.
        """
        self.flownet.to(device)
        self.ifnet.to(device)
        self.metricnet.to(device)
        self.feat_ext.to(device)
        self.fusionnet.to(device)

    def load_model(self, path_dict):
        """Load all checkpoints from a dict of file paths (see keys below)."""
        #models/rife46.pth
        self.ifnet.load_state_dict(torch.load(path_dict["ifnet"]))
        #models/GMFSS_fortuna_flownet.pkl
        self.flownet.load_state_dict(torch.load(path_dict["flownet"]))
        #models/GMFSS_fortuna_union_metric.pkl
        self.metricnet.load_state_dict(torch.load(path_dict["metricnet"]))
        #models/GMFSS_fortuna_union_feat.pkl
        self.feat_ext.load_state_dict(torch.load(path_dict["feat_ext"]))
        #models/GMFSS_fortuna_union_fusionnet.pkl
        self.fusionnet.load_state_dict(torch.load(path_dict["fusionnet"]))

    def reuse(self, img0, img1, scale):
        """Precompute everything timestep-independent for a frame pair.

        Returns bidirectional flows, metrics and both feature pyramids so
        ``inference`` can be called repeatedly with different timesteps.
        ``scale`` optionally resizes frames for the flow computation only;
        the resulting flow is rescaled back.
        """
        feat11, feat12, feat13 = self.feat_ext(img0)
        feat21, feat22, feat23 = self.feat_ext(img1)

        # Flow/metric run at half resolution to save memory.
        img0 = F.interpolate(
            img0, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        img1 = F.interpolate(
            img1, scale_factor=0.5, mode="bilinear", align_corners=False
        )

        if scale != 1.0:
            imgf0 = F.interpolate(
                img0, scale_factor=scale, mode="bilinear", align_corners=False
            )
            imgf1 = F.interpolate(
                img1, scale_factor=scale, mode="bilinear", align_corners=False
            )
        else:
            imgf0 = img0
            imgf1 = img1
        flow01 = self.flownet(imgf0, imgf1, return_flow=True)
        flow10 = self.flownet(imgf1, imgf0, return_flow=True)
        if scale != 1.0:
            # Undo the scale: both the resolution and the flow magnitude.
            flow01 = (
                F.interpolate(
                    flow01,
                    scale_factor=1.0 / scale,
                    mode="bilinear",
                    align_corners=False,
                )
                / scale
            )
            flow10 = (
                F.interpolate(
                    flow10,
                    scale_factor=1.0 / scale,
                    mode="bilinear",
                    align_corners=False,
                )
                / scale
            )

        metric0, metric1 = self.metricnet(img0, img1, flow01, flow10)

        return (
            flow01,
            flow10,
            metric0,
            metric1,
            feat11,
            feat12,
            feat13,
            feat21,
            feat22,
            feat23,
        )

    def inference(
        self,
        img0,
        img1,
        flow01,
        flow10,
        metric0,
        metric1,
        feat11,
        feat12,
        feat13,
        feat21,
        feat22,
        feat23,
        timestep,
    ):
        """Synthesize the frame at ``timestep`` in (0, 1) from the
        precomputed state returned by ``reuse``.  Returns RGB in [0, 1]."""
        # Scale flows/metrics linearly toward the target time.
        F1t = timestep * flow01
        F2t = (1 - timestep) * flow10

        Z1t = timestep * metric0
        Z2t = (1 - timestep) * metric1

        # Forward-splat both frames to time t (at half resolution).
        img0 = F.interpolate(
            img0, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        I1t = softsplat(img0, F1t, Z1t, strMode="soft")
        img1 = F.interpolate(
            img1, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        I2t = softsplat(img1, F2t, Z2t, strMode="soft")

        # RIFE prior for the same intermediate frame ("union" variant).
        rife = self.ifnet(img0, img1, timestep, scale_list=[8, 4, 2, 1])

        # Splat each pyramid level with correspondingly rescaled flow.
        feat1t1 = softsplat(feat11, F1t, Z1t, strMode="soft")
        feat2t1 = softsplat(feat21, F2t, Z2t, strMode="soft")

        F1td = (
            F.interpolate(F1t, scale_factor=0.5, mode="bilinear", align_corners=False)
            * 0.5
        )
        Z1d = F.interpolate(Z1t, scale_factor=0.5, mode="bilinear", align_corners=False)
        feat1t2 = softsplat(feat12, F1td, Z1d, strMode="soft")
        F2td = (
            F.interpolate(F2t, scale_factor=0.5, mode="bilinear", align_corners=False)
            * 0.5
        )
        Z2d = F.interpolate(Z2t, scale_factor=0.5, mode="bilinear", align_corners=False)
        feat2t2 = softsplat(feat22, F2td, Z2d, strMode="soft")

        F1tdd = (
            F.interpolate(F1t, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        Z1dd = F.interpolate(
            Z1t, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        feat1t3 = softsplat(feat13, F1tdd, Z1dd, strMode="soft")
        F2tdd = (
            F.interpolate(F2t, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        Z2dd = F.interpolate(
            Z2t, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        feat2t3 = softsplat(feat23, F2tdd, Z2dd, strMode="soft")

        # Fuse warped frames + RIFE prior + warped feature pyramids.
        out = self.fusionnet(
            torch.cat([I1t, rife, I2t], dim=1),
            torch.cat([feat1t1, feat2t1], dim=1),
            torch.cat([feat1t2, feat2t2], dim=1),
            torch.cat([feat1t3, feat2t3], dim=1),
        )

        return torch.clamp(out, 0, 1)
vfi_models/gmfss_fortuna/__init__.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
3
+ import typing
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from comfy.model_management import get_torch_device
8
+
9
+
10
# Model-type key derived from this package's directory name
# ("gmfss_fortuna"); used as the release tag for checkpoint downloads.
GLOBAL_MODEL_TYPE = pathlib.Path(__file__).parent.name
# Maps each selectable checkpoint bundle to its component weights as
# (release_tag, filename) pairs consumed by load_file_from_github_release.
# The "union" bundle additionally needs RIFE's IFNet weights.
CKPTS_PATH_CONFIG = {
    "GMFSS_fortuna_union": {
        "ifnet": ("rife", "rife46.pth"),
        "flownet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_flownet.pkl"),
        "metricnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_metric.pkl"),
        "feat_ext": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_feat.pkl"),
        "fusionnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_fusionnet.pkl")
    },
    "GMFSS_fortuna": {
        "flownet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_flownet.pkl"),
        "metricnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_metric.pkl"),
        "feat_ext": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_feat.pkl"),
        "fusionnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_fusionnet.pkl")
    }
}
26
+
27
class CommonModelInference(nn.Module):
    """Loads a GMFSS Fortuna (or Fortuna-union) model and exposes a single
    padded forward pass suitable for arbitrary input resolutions."""

    def __init__(self, model_type):
        super(CommonModelInference, self).__init__()
        from .GMFSS_Fortuna_arch import Model as GMFSS
        from .GMFSS_Fortuna_union_arch import Model as GMFSS_Union
        # "union" bundles use the RIFE-assisted architecture.
        self.model = GMFSS_Union() if "union" in model_type else GMFSS()
        self.model.eval()
        self.model.device()
        _model_path_config = CKPTS_PATH_CONFIG[model_type]
        # Download (or reuse cached) checkpoint files, then load them.
        self.model.load_model({
            key: load_file_from_github_release(*_model_path_config[key])
            for key in _model_path_config
        })

    def forward(self, I0, I1, timestep, scale=1.0):
        """Interpolate between frames I0 and I1 at ``timestep`` in (0, 1).

        Frames are padded so H and W are multiples of 64 (adjusted by
        ``scale``) because the flow network downsamples aggressively; the
        output is cropped back to the original size.
        """
        n, c, h, w = I0.shape
        tmp = max(64, int(64 / scale))
        ph = ((h - 1) // tmp + 1) * tmp
        pw = ((w - 1) // tmp + 1) * tmp
        padding = (0, pw - w, 0, ph - h)
        I0 = F.pad(I0, padding)
        I1 = F.pad(I1, padding)
        # Timestep-independent state (flows, metrics, feature pyramids).
        (
            flow01,
            flow10,
            metric0,
            metric1,
            feat11,
            feat12,
            feat13,
            feat21,
            feat22,
            feat23,
        ) = self.model.reuse(I0, I1, scale)

        output = self.model.inference(
            I0,
            I1,
            flow01,
            flow10,
            metric0,
            metric1,
            feat11,
            feat12,
            feat13,
            feat21,
            feat22,
            feat23,
            timestep,
        )
        # Crop the padding back off.
        return output[:, :, :h, :w]
78
+
79
class GMFSS_Fortuna_VFI:
    """ComfyUI node: video frame interpolation with GMFSS Fortuna."""

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-interface declaration.
        return {
            "required": {
                "ckpt_name": (list(CKPTS_PATH_CONFIG.keys()), ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """
        Perform video frame interpolation using a given checkpoint model.

        Args:
            ckpt_name (str): The name of the checkpoint model to use.
            frames (torch.Tensor): A tensor containing input video frames.
            clear_cache_after_n_frames (int, optional): The number of frames to process before clearing CUDA cache
                to prevent memory overflow. Defaults to 10. Lower numbers are safer but mean more processing time.
                How high you should set it depends on how many input frames there are, input resolution (after upscaling),
                how many times you want to multiply them, and how long you're willing to wait for the process to complete.
            multiplier (int, optional): The multiplier for each input frame. 60 input frames * 2 = 120 output frames. Defaults to 2.

        Returns:
            tuple: A tuple containing the output interpolated frames.

        Note:
            This method interpolates frames in a video sequence using a specified checkpoint model.
            It processes each frame sequentially, generating interpolated frames between them.

            To prevent memory overflow, it clears the CUDA cache after processing a specified number of frames.
        """

        interpolation_model = CommonModelInference(model_type=ckpt_name)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        # Callback invoked by generic_frame_loop for each in-between frame.
        def return_middle_frame(frame_0, frame_1, timestep, model, scale):
            return model(frame_0, frame_1, timestep, scale)

        # Flow is always computed at native (half-res internal) scale here.
        scale = 1

        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)
vfi_models/ifrnet/IFRNet_L_arch.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/ltkong218/IFRNet/blob/main/models/IFRNet_L.py
2
+ # https://github.com/ltkong218/IFRNet/blob/main/utils.py
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from comfy.model_management import get_torch_device
7
+
8
+
9
def warp(img, flow):
    """Backward-warp ``img`` by a per-pixel displacement field.

    Args:
        img: (B, C, H, W) source tensor to sample from.
        flow: (B, 2, H, W) displacement in pixels; channel 0 is x, channel 1 is y.

    Returns:
        (B, C, H, W) tensor bilinearly sampled from ``img`` at base_grid + flow,
        with out-of-bounds samples clamped to the border.
    """
    batch, _, height, width = flow.shape
    # Normalized identity grid in [-1, 1], shaped (B, 2, H, W): x row, then y row.
    xs = torch.linspace(-1.0, 1.0, width).view(1, 1, 1, width).expand(batch, -1, height, -1)
    ys = torch.linspace(-1.0, 1.0, height).view(1, 1, height, 1).expand(batch, -1, -1, width)
    base_grid = torch.cat((xs, ys), dim=1).to(img)
    # Convert the pixel-space flow into grid_sample's normalized coordinates
    # (align_corners=True convention: half-range is (size - 1) / 2 pixels).
    norm_flow = torch.cat(
        (
            flow[:, 0:1, :, :] / ((width - 1.0) / 2.0),
            flow[:, 1:2, :, :] / ((height - 1.0) / 2.0),
        ),
        dim=1,
    )
    sample_grid = (base_grid + norm_flow).permute(0, 2, 3, 1)
    return F.grid_sample(
        input=img,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
30
+
31
+
32
def get_robust_weight(flow_pred, flow_gt, beta):
    """Per-pixel confidence weight exp(-beta * EPE) for a predicted flow.

    EPE is the end-point error (Euclidean distance between prediction and
    ground truth over the 2 flow channels). The prediction is detached so
    no gradient flows through the weighting.
    """
    end_point_error = torch.sqrt(
        ((flow_pred.detach() - flow_gt) ** 2).sum(dim=1, keepdim=True)
    )
    return torch.exp(-beta * end_point_error)
36
+
37
+
38
def resize(x, scale_factor):
    """Bilinearly rescale a (B, C, H, W) tensor by ``scale_factor``."""
    rescaled = F.interpolate(
        x, mode="bilinear", align_corners=False, scale_factor=scale_factor
    )
    return rescaled
42
+
43
+
44
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Build a Conv2d followed by a channel-wise PReLU as one Sequential."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias=bias,
    )
    return nn.Sequential(conv, nn.PReLU(out_channels))
67
+
68
+
69
class ResBlock(nn.Module):
    """Residual block with a dedicated "side" channel group.

    The last ``side_channels`` channels receive two extra side-only conv
    stages, written back in place between the full-width stages; the block
    ends with a residual connection and a PReLU.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        # NOTE: layer creation order matches the reference implementation so
        # parameter init (RNG consumption) and state_dict keys are identical.
        self.conv1 = self._stage(in_channels, bias)
        self.conv2 = self._stage(side_channels, bias)
        self.conv3 = self._stage(in_channels, bias)
        self.conv4 = self._stage(side_channels, bias)
        self.conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.prelu = nn.PReLU(in_channels)

    @staticmethod
    def _stage(channels, bias):
        # 3x3 same-resolution conv followed by a channel-wise PReLU.
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(channels),
        )

    def forward(self, x):
        side = self.side_channels
        feat = self.conv1(x)
        # Refine only the trailing side channels, in place.
        feat[:, -side:, :, :] = self.conv2(feat[:, -side:, :, :])
        feat = self.conv3(feat)
        feat[:, -side:, :, :] = self.conv4(feat[:, -side:, :, :])
        return self.prelu(x + self.conv5(feat))
123
+
124
+
125
class Encoder(nn.Module):
    """Four-level shared feature pyramid.

    Each level is two convrelu stages; the first stage halves the spatial
    resolution (stride 2). Channel widths: 64, 96, 144, 192.
    """

    # (in_channels, out_channels, kernel, stride, padding) of each level's
    # downsampling stage; level 1 uses a 7x7 stem, the rest 3x3.
    _LEVELS = ((3, 64, 7, 2, 3), (64, 96, 3, 2, 1), (96, 144, 3, 2, 1), (144, 192, 3, 2, 1))

    def __init__(self):
        super(Encoder, self).__init__()
        for idx, (cin, cout, k, s, p) in enumerate(self._LEVELS, start=1):
            setattr(
                self,
                "pyramid{}".format(idx),
                nn.Sequential(convrelu(cin, cout, k, s, p), convrelu(cout, cout, 3, 1, 1)),
            )

    def forward(self, img):
        features = []
        out = img
        for idx in range(1, 5):
            out = getattr(self, "pyramid{}".format(idx))(out)
            features.append(out)
        # (f1, f2, f3, f4) from finest to coarsest.
        return tuple(features)
147
+
148
+
149
class Decoder4(nn.Module):
    """Coarsest decoder level: fuses both frames' level-4 features plus the
    scalar time embedding, and upsamples 2x via a transposed conv.

    Output has 148 channels: 2 + 2 bilateral flow channels and 144 features
    for the next level.
    """

    def __init__(self):
        super(Decoder4, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(384 + 1, 384),
            ResBlock(384, 64),
            nn.ConvTranspose2d(384, 148, 4, 2, 1, bias=True),
        )

    def forward(self, f0, f1, embt):
        _, _, h, w = f0.shape
        # Broadcast the (B, 1, 1, 1) time embedding over the spatial grid.
        time_map = embt.repeat(1, 1, h, w)
        return self.convblock(torch.cat((f0, f1, time_map), dim=1))
164
+
165
+
166
class Decoder3(nn.Module):
    """Level-3 decoder: refines flows/features at 1/8 scale and upsamples 2x.

    Input: 144 intermediate + 2x144 warped features + 2x2 flows = 436 ch.
    Output: 100 channels (4 flow + 96 features).
    """

    def __init__(self):
        super(Decoder3, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(436, 432),
            ResBlock(432, 64),
            nn.ConvTranspose2d(432, 100, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        # Align both frames' features to time t with the current flow estimates.
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        return self.convblock(torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1))
181
+
182
+
183
class Decoder2(nn.Module):
    """Level-2 decoder: refines flows/features at 1/4 scale and upsamples 2x.

    Input: 96 intermediate + 2x96 warped features + 2x2 flows = 292 ch.
    Output: 68 channels (4 flow + 64 features).
    """

    def __init__(self):
        super(Decoder2, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(292, 288),
            ResBlock(288, 64),
            nn.ConvTranspose2d(288, 68, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        # Align both frames' features to time t with the current flow estimates.
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        return self.convblock(torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1))
198
+
199
+
200
class Decoder1(nn.Module):
    """Finest decoder level at 1/2 scale, upsampling 2x to full resolution.

    Input: 64 intermediate + 2x64 warped features + 2x2 flows = 196 ch.
    Output: 8 channels (4 flow + 1 fusion mask + 3 RGB residual).
    """

    def __init__(self):
        super(Decoder1, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(196, 192),
            ResBlock(192, 64),
            nn.ConvTranspose2d(192, 8, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        # Align both frames' features to time t with the current flow estimates.
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        return self.convblock(torch.cat((ft_, warped0, warped1, up_flow0, up_flow1), dim=1))
215
+
216
+
217
class IRFNet_L(nn.Module):
    """IFRNet-L: single-pass intermediate frame synthesis.

    Both input frames are encoded into a shared 4-level feature pyramid,
    then decoded coarse-to-fine. Each decoder level refines a pair of
    bilateral flows (t -> 0 and t -> 1); the finest level also predicts a
    fusion mask and an RGB residual. The output frame is the mask-blended
    pair of backward-warped inputs plus that residual.
    """

    def __init__(self):
        super(IRFNet_L, self).__init__()
        self.encoder = Encoder()
        self.decoder4 = Decoder4()
        self.decoder3 = Decoder3()
        self.decoder2 = Decoder2()
        self.decoder1 = Decoder1()

    def forward(self, img0, img1, scale_factor=1.0, timestep=0.5):
        """Synthesize the frame at ``timestep`` between ``img0`` and ``img1``.

        Args:
            img0: (N, 3, H, W) first frame, values in [0, 1].
            img1: (N, 3, H, W) second frame, same shape/dtype/device as img0.
            scale_factor (float): optional rescale applied before the pyramid;
                flows are resized and rescaled back afterwards.
            timestep (float): temporal position of the output frame in (0, 1).

        Returns:
            (N, 3, H, W) predicted intermediate frame, clamped to [0, 1].
        """
        n, _, h, w = img0.shape

        # Pad H and W up to the next multiple of 64 (four stride-2 encoder
        # levels plus headroom for scale_factor); cropped back at the end.
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)

        # Per-sample scalar time embedding (supports batches). Built directly
        # on the inputs' device and dtype: the previous code pinned it to
        # get_torch_device() and only special-cased fp16 via a "HalfTensor"
        # string check, which broke when the inputs lived on another device
        # or used bfloat16 (mixed-dtype torch.cat in Decoder4).
        embt = torch.full(
            (n, 1, 1, 1), float(timestep), dtype=img0.dtype, device=img0.device
        )

        # Zero-center both frames with their joint mean; restored at the merge.
        mean_ = (
            torch.cat([img0, img1], 2)
            .mean(1, keepdim=True)
            .mean(2, keepdim=True)
            .mean(3, keepdim=True)
        )
        img0 = img0 - mean_
        img1 = img1 - mean_

        img0_ = resize(img0, scale_factor=scale_factor)
        img1_ = resize(img1, scale_factor=scale_factor)

        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        # Coarse-to-fine decoding. Each level emits [flow0, flow1, features];
        # flows carried up from the previous level are doubled because flow
        # magnitudes are in pixels and the resolution doubles per level.
        out4 = self.decoder4(f0_4, f1_4, embt)
        up_flow0_4 = out4[:, 0:2]
        up_flow1_4 = out4[:, 2:4]
        ft_3_ = out4[:, 4:]

        out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
        up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
        ft_2_ = out3[:, 4:]

        out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
        up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
        ft_1_ = out2[:, 4:]

        out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
        up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
        up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
        up_mask_1 = torch.sigmoid(out1[:, 4:5])  # fusion mask in [0, 1]
        up_res_1 = out1[:, 5:]  # RGB residual refinement

        # Undo the pre-pyramid scaling: resize the flow fields back and
        # rescale their magnitudes (flow vectors are expressed in pixels).
        up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
            1.0 / scale_factor
        )
        up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
            1.0 / scale_factor
        )
        up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor))
        up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor))

        # Blend the two warped frames, restore the mean, add the residual.
        img0_warp = warp(img0, up_flow0_1)
        img1_warp = warp(img1, up_flow1_1)
        imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
        imgt_pred = torch.clamp(imgt_merge + up_res_1, 0, 1)
        return imgt_pred[:, :, :h, :w]