Mirror from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +15 -0
- .github/workflows/publish.yml +25 -0
- .gitignore +3 -0
- All_in_one_v1_3.png +3 -0
- LICENSE +21 -0
- README.md +194 -0
- __init__.py +42 -0
- config.yaml +3 -0
- demo_frames/anime0.png +3 -0
- demo_frames/anime1.png +3 -0
- demo_frames/bocchi0.jpg +3 -0
- demo_frames/bocchi1.jpg +3 -0
- demo_frames/real0.png +3 -0
- demo_frames/real1.png +3 -0
- demo_frames/rick/00003.png +3 -0
- demo_frames/rick/00004.png +3 -0
- demo_frames/rick/00005.png +3 -0
- demo_frames/violet0.png +3 -0
- demo_frames/violet1.png +3 -0
- example.png +3 -0
- install-taichi.bat +11 -0
- install.bat +16 -0
- install.py +59 -0
- interpolation_schedule.png +3 -0
- other_nodes.py +88 -0
- pyproject.toml +13 -0
- requirements-no-cupy.txt +9 -0
- requirements-with-cupy.txt +10 -0
- test.py +38 -0
- test_vfi_schedule.gif +3 -0
- vfi_models/amt/__init__.py +87 -0
- vfi_models/amt/amt_arch.py +1590 -0
- vfi_models/cain/__init__.py +64 -0
- vfi_models/cain/cain_arch.py +74 -0
- vfi_models/cain/cain_encdec_arch.py +95 -0
- vfi_models/cain/cain_noca_arch.py +73 -0
- vfi_models/cain/common.py +361 -0
- vfi_models/eisai/__init__.py +84 -0
- vfi_models/eisai/eisai_arch.py +2586 -0
- vfi_models/film/__init__.py +113 -0
- vfi_models/film/film_arch.py +798 -0
- vfi_models/flavr/__init__.py +115 -0
- vfi_models/flavr/flavr_arch.py +217 -0
- vfi_models/flavr/resnet_3D.py +288 -0
- vfi_models/gmfss_fortuna/GMFSS_Fortuna.py +24 -0
- vfi_models/gmfss_fortuna/GMFSS_Fortuna_arch.py +1850 -0
- vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py +23 -0
- vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py +1857 -0
- vfi_models/gmfss_fortuna/__init__.py +143 -0
- vfi_models/ifrnet/IFRNet_L_arch.py +293 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
All_in_one_v1_3.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
demo_frames/anime0.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
demo_frames/anime1.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
demo_frames/bocchi0.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
demo_frames/bocchi1.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
demo_frames/real0.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
demo_frames/real1.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
demo_frames/rick/00003.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
demo_frames/rick/00004.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
demo_frames/rick/00005.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
demo_frames/violet0.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
demo_frames/violet1.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
example.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
interpolation_schedule.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
test_vfi_schedule.gif filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/publish.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Publish to Comfy registry
|
| 2 |
+
on:
|
| 3 |
+
workflow_dispatch:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
paths:
|
| 8 |
+
- "pyproject.toml"
|
| 9 |
+
|
| 10 |
+
permissions:
|
| 11 |
+
issues: write
|
| 12 |
+
|
| 13 |
+
jobs:
|
| 14 |
+
publish-node:
|
| 15 |
+
name: Publish Custom Node to registry
|
| 16 |
+
runs-on: ubuntu-latest
|
| 17 |
+
if: ${{ github.repository_owner == 'Fannovel16' }}
|
| 18 |
+
steps:
|
| 19 |
+
- name: Check out code
|
| 20 |
+
uses: actions/checkout@v4
|
| 21 |
+
- name: Publish Custom Node
|
| 22 |
+
uses: Comfy-Org/publish-node-action@v1
|
| 23 |
+
with:
|
| 24 |
+
## Add your own personal access token to your Github Repository secrets and reference it here.
|
| 25 |
+
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ckpts
|
| 2 |
+
__pycache__
|
| 3 |
+
test_result
|
All_in_one_v1_3.png
ADDED
|
Git LFS Details
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Fannovel16
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI Frame Interpolation (ComfyUI VFI) (WIP)
|
| 2 |
+
|
| 3 |
+
A custom node set for Video Frame Interpolation in ComfyUI.
|
| 4 |
+
**UPDATE** Memory management is improved. Now this extension takes less RAM and VRAM than before.
|
| 5 |
+
|
| 6 |
+
**UPDATE 2** VFI nodes now accept scheduling multiplier values
|
| 7 |
+
|
| 8 |
+

|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
## Nodes
|
| 12 |
+
* KSampler Gradually Adding More Denoise (efficient)
|
| 13 |
+
* GMFSS Fortuna VFI
|
| 14 |
+
* IFRNet VFI
|
| 15 |
+
* IFUnet VFI
|
| 16 |
+
* M2M VFI
|
| 17 |
+
* RIFE VFI (4.0 - 4.9) (Note that option `fast_mode` won't do anything from v4.5+ as `contextnet` is removed)
|
| 18 |
+
* FILM VFI
|
| 19 |
+
* Sepconv VFI
|
| 20 |
+
* AMT VFI
|
| 21 |
+
* Make Interpolation State List
|
| 22 |
+
* STMFNet VFI (requires at least 4 frames, can only do 2x interpolation for now)
|
| 23 |
+
* FLAVR VFI (same conditions as STMFNet)
|
| 24 |
+
|
| 25 |
+
## Install
|
| 26 |
+
### ComfyUI Manager
|
| 27 |
+
The incompatibility issue with it is now fixed
|
| 28 |
+

|
| 29 |
+
Follow this guide to install this extension
|
| 30 |
+
|
| 31 |
+
https://github.com/ltdrdata/ComfyUI-Manager#how-to-use
|
| 32 |
+
### Command-line
|
| 33 |
+
#### Windows
|
| 34 |
+
Run install.bat
|
| 35 |
+
|
| 36 |
+
For Windows users, if you are having trouble with cupy, please run `install.bat` instead of `install-cupy.py` or `python install.py`.
|
| 37 |
+
#### Linux
|
| 38 |
+
Open your shell app and start venv if it is used for ComfyUI. Run:
|
| 39 |
+
```
|
| 40 |
+
python install.py
|
| 41 |
+
```
|
| 42 |
+
## Support for non-CUDA device (experimental)
|
| 43 |
+
If you don't have a NVidia card, you can try `taichi` ops backend powered by [Taichi Lang](https://www.taichi-lang.org/)
|
| 44 |
+
|
| 45 |
+
On Windows, you can install it by running `install.bat`; on Linux, run `pip install taichi`
|
| 46 |
+
|
| 47 |
+
Then change value of `ops_backend` from `cupy` to `taichi` in `config.yaml`
|
| 48 |
+
|
| 49 |
+
If `NotImplementedError` appears, a VFI node in the workflow isn't supported by taichi
|
| 50 |
+
|
| 51 |
+
## Usage
|
| 52 |
+
All VFI nodes can be accessed in **category** `ComfyUI-Frame-Interpolation/VFI` if the installation is successful and require a `IMAGE` containing frames (at least 2, or at least 4 for STMF-Net/FLAVR).
|
| 53 |
+
|
| 54 |
+
Regarding STMFNet and FLAVR, if you only have two or three frames, you should use: Load Images -> Other VFI node (FILM is recommended in this case) with `multiplier=4` -> STMFNet VFI/FLAVR VFI
|
| 55 |
+
|
| 56 |
+
`clear_cache_after_n_frames` is used to avoid out-of-memory. Decreasing it makes the chance lower but also increases processing time.
|
| 57 |
+
|
| 58 |
+
It is recommended to use LoadImages (LoadImagesFromDirectory) from [ComfyUI-Advanced-ControlNet](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/) and [ComfyUI-VideoHelperSuite](https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite) along side with this extension.
|
| 59 |
+
|
| 60 |
+
## Example
|
| 61 |
+
### Simple workflow
|
| 62 |
+
Workflow metadata isn't embedded
|
| 63 |
+
Download these two images [anime0.png](./demo_frames/anime0.png) and [anime1.png](./demo_frames/anime1.png) and put them into a folder like `E:\test` as shown in this image.
|
| 64 |
+

|
| 65 |
+
|
| 66 |
+
### Complex workflow
|
| 67 |
+
It's used in AnimationDiff (can load workflow metadata)
|
| 68 |
+

|
| 69 |
+
|
| 70 |
+
## Credit
|
| 71 |
+
Big thanks for styler00dollar for making [VSGAN-tensorrt-docker](https://github.com/styler00dollar/VSGAN-tensorrt-docker). About 99% the code of this repo comes from it.
|
| 72 |
+
|
| 73 |
+
Citation for each VFI node:
|
| 74 |
+
### GMFSS Fortuna
|
| 75 |
+
The All-In-One GMFSS: Dedicated for Anime Video Frame Interpolation
|
| 76 |
+
|
| 77 |
+
https://github.com/98mxr/GMFSS_Fortuna
|
| 78 |
+
|
| 79 |
+
### IFRNet
|
| 80 |
+
```bibtex
|
| 81 |
+
@InProceedings{Kong_2022_CVPR,
|
| 82 |
+
author = {Kong, Lingtong and Jiang, Boyuan and Luo, Donghao and Chu, Wenqing and Huang, Xiaoming and Tai, Ying and Wang, Chengjie and Yang, Jie},
|
| 83 |
+
title = {IFRNet: Intermediate Feature Refine Network for Efficient Frame Interpolation},
|
| 84 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 85 |
+
year = {2022}
|
| 86 |
+
}
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### IFUnet
|
| 90 |
+
RIFE with IFUNet, FusionNet and RefineNet
|
| 91 |
+
|
| 92 |
+
https://github.com/98mxr/IFUNet
|
| 93 |
+
### M2M
|
| 94 |
+
```bibtex
|
| 95 |
+
@InProceedings{hu2022m2m,
|
| 96 |
+
title={Many-to-many Splatting for Efficient Video Frame Interpolation},
|
| 97 |
+
author={Hu, Ping and Niklaus, Simon and Sclaroff, Stan and Saenko, Kate},
|
| 98 |
+
journal={CVPR},
|
| 99 |
+
year={2022}
|
| 100 |
+
}
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
### RIFE
|
| 104 |
+
```bibtex
|
| 105 |
+
@inproceedings{huang2022rife,
|
| 106 |
+
title={Real-Time Intermediate Flow Estimation for Video Frame Interpolation},
|
| 107 |
+
author={Huang, Zhewei and Zhang, Tianyuan and Heng, Wen and Shi, Boxin and Zhou, Shuchang},
|
| 108 |
+
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
|
| 109 |
+
year={2022}
|
| 110 |
+
}
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
### FILM
|
| 114 |
+
[Frame interpolation in PyTorch](https://github.com/dajes/frame-interpolation-pytorch)
|
| 115 |
+
|
| 116 |
+
```bibtex
|
| 117 |
+
@inproceedings{reda2022film,
|
| 118 |
+
title = {FILM: Frame Interpolation for Large Motion},
|
| 119 |
+
author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
|
| 120 |
+
booktitle = {European Conference on Computer Vision (ECCV)},
|
| 121 |
+
year = {2022}
|
| 122 |
+
}
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
```bibtex
|
| 126 |
+
@misc{film-tf,
|
| 127 |
+
title = {Tensorflow 2 Implementation of "FILM: Frame Interpolation for Large Motion"},
|
| 128 |
+
author = {Fitsum Reda and Janne Kontkanen and Eric Tabellion and Deqing Sun and Caroline Pantofaru and Brian Curless},
|
| 129 |
+
year = {2022},
|
| 130 |
+
publisher = {GitHub},
|
| 131 |
+
journal = {GitHub repository},
|
| 132 |
+
howpublished = {\url{https://github.com/google-research/frame-interpolation}}
|
| 133 |
+
}
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### Sepconv
|
| 137 |
+
```bibtex
|
| 138 |
+
[1] @inproceedings{Niklaus_WACV_2021,
|
| 139 |
+
author = {Simon Niklaus and Long Mai and Oliver Wang},
|
| 140 |
+
title = {Revisiting Adaptive Convolutions for Video Frame Interpolation},
|
| 141 |
+
booktitle = {IEEE Winter Conference on Applications of Computer Vision},
|
| 142 |
+
year = {2021}
|
| 143 |
+
}
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
```bibtex
|
| 147 |
+
[2] @inproceedings{Niklaus_ICCV_2017,
|
| 148 |
+
author = {Simon Niklaus and Long Mai and Feng Liu},
|
| 149 |
+
title = {Video Frame Interpolation via Adaptive Separable Convolution},
|
| 150 |
+
booktitle = {IEEE International Conference on Computer Vision},
|
| 151 |
+
year = {2017}
|
| 152 |
+
}
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
```bibtex
|
| 156 |
+
[3] @inproceedings{Niklaus_CVPR_2017,
|
| 157 |
+
author = {Simon Niklaus and Long Mai and Feng Liu},
|
| 158 |
+
title = {Video Frame Interpolation via Adaptive Convolution},
|
| 159 |
+
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
|
| 160 |
+
year = {2017}
|
| 161 |
+
}
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
### AMT
|
| 165 |
+
```bibtex
|
| 166 |
+
@inproceedings{licvpr23amt,
|
| 167 |
+
title={AMT: All-Pairs Multi-Field Transforms for Efficient Frame Interpolation},
|
| 168 |
+
author={Li, Zhen and Zhu, Zuo-Liang and Han, Ling-Hao and Hou, Qibin and Guo, Chun-Le and Cheng, Ming-Ming},
|
| 169 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 170 |
+
year={2023}
|
| 171 |
+
}
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### ST-MFNet
|
| 175 |
+
```bibtex
|
| 176 |
+
@InProceedings{Danier_2022_CVPR,
|
| 177 |
+
author = {Danier, Duolikun and Zhang, Fan and Bull, David},
|
| 178 |
+
title = {ST-MFNet: A Spatio-Temporal Multi-Flow Network for Frame Interpolation},
|
| 179 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 180 |
+
month = {June},
|
| 181 |
+
year = {2022},
|
| 182 |
+
pages = {3521-3531}
|
| 183 |
+
}
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
### FLAVR
|
| 187 |
+
```bibtex
|
| 188 |
+
@article{kalluri2021flavr,
|
| 189 |
+
title={FLAVR: Flow-Agnostic Video Representations for Fast Frame Interpolation},
|
| 190 |
+
author={Kalluri, Tarun and Pathak, Deepak and Chandraker, Manmohan and Tran, Du},
|
| 191 |
+
booktitle={arxiv},
|
| 192 |
+
year={2021}
|
| 193 |
+
}
|
| 194 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry point that registers ComfyUI-Frame-Interpolation's nodes with ComfyUI."""
import os
import sys

# The package uses a flat layout (vfi_models, vfi_utils live beside this file),
# so make this directory importable regardless of ComfyUI's working directory.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

from .other_nodes import Gradually_More_Denoise_KSampler

# Some models are commented out because their code is not completed.
# from vfi_models.eisai import EISAI_VFI
from vfi_models.gmfss_fortuna import GMFSS_Fortuna_VFI
from vfi_models.ifrnet import IFRNet_VFI
from vfi_models.ifunet import IFUnet_VFI
from vfi_models.m2m import M2M_VFI
from vfi_models.rife import RIFE_VFI
from vfi_models.sepconv import SepconvVFI
from vfi_models.amt import AMT_VFI
from vfi_models.film import FILM_VFI
from vfi_models.stmfnet import STMFNet_VFI
from vfi_models.flavr import FLAVR_VFI
from vfi_models.cain import CAIN_VFI
from vfi_utils import MakeInterpolationStateList, FloatToInt

# Node name (as shown in ComfyUI) -> implementing class.
NODE_CLASS_MAPPINGS = {
    "KSampler Gradually Adding More Denoise (efficient)": Gradually_More_Denoise_KSampler,
    # "EISAI VFI": EISAI_VFI,
    "GMFSS Fortuna VFI": GMFSS_Fortuna_VFI,
    "IFRNet VFI": IFRNet_VFI,
    "IFUnet VFI": IFUnet_VFI,
    "M2M VFI": M2M_VFI,
    "RIFE VFI": RIFE_VFI,
    "Sepconv VFI": SepconvVFI,
    "AMT VFI": AMT_VFI,
    "FILM VFI": FILM_VFI,
    "Make Interpolation State List": MakeInterpolationStateList,
    "STMFNet VFI": STMFNet_VFI,
    "FLAVR VFI": FLAVR_VFI,
    "CAIN VFI": CAIN_VFI,
    "VFI FloatToInt": FloatToInt,
}

# Optional prettier display names; nodes not listed here keep their key above.
NODE_DISPLAY_NAME_MAPPINGS = {
    "RIFE VFI": "RIFE VFI (recommend rife47 and rife49)"
}
|
config.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Please don't delete this file; just edit it when necessary.
|
| 2 |
+
ckpts_path: "./ckpts"
|
| 3 |
+
ops_backend: "cupy" #Either "taichi" or "cupy"
|
demo_frames/anime0.png
ADDED
|
Git LFS Details
|
demo_frames/anime1.png
ADDED
|
Git LFS Details
|
demo_frames/bocchi0.jpg
ADDED
|
Git LFS Details
|
demo_frames/bocchi1.jpg
ADDED
|
Git LFS Details
|
demo_frames/real0.png
ADDED
|
Git LFS Details
|
demo_frames/real1.png
ADDED
|
Git LFS Details
|
demo_frames/rick/00003.png
ADDED
|
Git LFS Details
|
demo_frames/rick/00004.png
ADDED
|
Git LFS Details
|
demo_frames/rick/00005.png
ADDED
|
Git LFS Details
|
demo_frames/violet0.png
ADDED
|
Git LFS Details
|
demo_frames/violet1.png
ADDED
|
Git LFS Details
|
example.png
ADDED
|
Git LFS Details
|
install-taichi.bat
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
echo Installing Taichi lang backend...
|
| 3 |
+
|
| 4 |
+
if exist "%python_exec%" (
|
| 5 |
+
%python_exec% -s -m pip install taichi
|
| 6 |
+
) else (
|
| 7 |
+
echo Installing with system Python
|
| 8 |
+
pip install taichi
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
pause
|
install.bat
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
|
| 3 |
+
set "requirements_txt=%~dp0\requirements-no-cupy.txt"
|
| 4 |
+
set "python_exec=..\..\..\python_embeded\python.exe"
|
| 5 |
+
|
| 6 |
+
echo Installing ComfyUI Frame Interpolation..
|
| 7 |
+
|
| 8 |
+
if exist "%python_exec%" (
|
| 9 |
+
echo Installing with ComfyUI Portable
|
| 10 |
+
%python_exec% -s install.py
|
| 11 |
+
) else (
|
| 12 |
+
echo Installing with system Python
|
| 13 |
+
python install.py
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
pause
|
install.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import sys
|
| 4 |
+
import platform
|
| 5 |
+
|
| 6 |
+
def get_cuda_ver_from_dir(cuda_home):
|
| 7 |
+
nvrtc = filter(lambda lib_file: "nvrtc-builtins" in lib_file, os.listdir(cuda_home))
|
| 8 |
+
nvrtc = list(nvrtc)
|
| 9 |
+
if len(nvrtc) == 0:
|
| 10 |
+
return
|
| 11 |
+
nvrtc = nvrtc[0]
|
| 12 |
+
if ('102' in nvrtc) or ('10.2' in nvrtc):
|
| 13 |
+
return '102'
|
| 14 |
+
if '110' in nvrtc or ('11.0' in nvrtc):
|
| 15 |
+
return '110'
|
| 16 |
+
if '111' in nvrtc or ('11.1' in nvrtc):
|
| 17 |
+
return '111'
|
| 18 |
+
if '11' in nvrtc:
|
| 19 |
+
return '11x'
|
| 20 |
+
if '12' in nvrtc:
|
| 21 |
+
return '12x'
|
| 22 |
+
|
| 23 |
+
# ComfyUI's portable build bundles a "python_embeded" interpreter; pip there
# needs -s so packages go into the embedded runtime, not the user site-packages.
s_param = '-s' if "python_embeded" in sys.executable else ''
|
| 24 |
+
|
| 25 |
+
def get_cuda_home_path():
    """Locate a directory holding CUDA runtime libraries.

    Prefers the CUDA_HOME environment variable. Otherwise falls back to the
    'lib' directory shipped inside the installed torch package, but only when
    it actually contains an nvrtc-builtins library. Returns None when no
    suitable directory can be found.
    """
    if "CUDA_HOME" in os.environ:
        return os.environ["CUDA_HOME"]
    import torch
    lib_dir = str((Path(torch.__file__).parent / "lib").resolve())
    if not os.path.exists(lib_dir):
        return None
    has_nvrtc = any("nvrtc-builtins" in name for name in os.listdir(lib_dir))
    return lib_dir if has_nvrtc else None
|
| 35 |
+
|
| 36 |
+
def install_cupy():
    """Ensure a CUDA-enabled cupy build is importable, installing one if needed.

    Exports CUDA_HOME/CUDA_PATH when a CUDA directory is found so cupy can
    locate the toolkit both at import and at install time. If `import cupy`
    fails, removes any stale cupy wheels and installs the variant matching the
    detected CUDA version (falling back to the self-detecting `cupy-wheel`).
    """
    cuda_home = get_cuda_home_path()
    if cuda_home is not None:
        os.environ["CUDA_HOME"] = cuda_home
        os.environ["CUDA_PATH"] = cuda_home
    try:
        import cupy  # noqa: F401 -- imported only to probe availability
        print("CuPy is already installed.")
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit aren't swallowed.
    except ImportError:
        print("Uninstall cupy if existed...")
        os.system(f'"{sys.executable}" {s_param} -m pip uninstall -y cupy-wheel cupy-cuda102 cupy-cuda110 cupy-cuda111 cupy-cuda11x cupy-cuda12x')
        print("Installing cupy...")
        # Guard against cuda_home being None: os.listdir(None) would silently
        # scan the current directory and could mis-detect a version.
        cuda_ver = get_cuda_ver_from_dir(cuda_home) if cuda_home is not None else None
        cupy_package = f"cupy-cuda{cuda_ver}" if cuda_ver is not None else "cupy-wheel"
        os.system(f'"{sys.executable}" {s_param} -m pip install {cupy_package}')
|
| 51 |
+
|
| 52 |
+
# Install every non-cupy dependency, one package at a time so that a single
# failing package does not abort the remaining installs.
requirements_path = Path(__file__).parent / "requirements-no-cupy.txt"
with open(requirements_path, 'r') as f:
    for line in f:
        package = line.strip()
        print(f"Installing {package}...")
        os.system(f'"{sys.executable}" {s_param} -m pip install {package}')

print("Checking cupy...")
install_cupy()
|
interpolation_schedule.png
ADDED
|
Git LFS Details
|
other_nodes.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import latent_preview
|
| 2 |
+
import comfy
|
| 3 |
+
import einops
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    """Run one KSampler pass over *latent* and return the sampled latent dict.

    Mirrors ComfyUI's built-in KSampler: prepares (or zeroes) the initial
    noise, wires up a latent previewer for UI progress, then delegates the
    actual sampling to comfy.sample.sample. Returns a one-element tuple
    containing a copy of *latent* with its "samples" replaced.
    """
    device = comfy.model_management.get_torch_device()
    latent_image = latent["samples"]

    if disable_noise:
        # Deterministic start: an all-zero noise tensor kept on the CPU.
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        noise = comfy.sample.prepare_noise(latent_image, seed, latent.get("batch_index"))

    noise_mask = latent.get("noise_mask")

    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = latent_preview.get_previewer(device, model.model.latent_format)
    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        # Decode a cheap preview of the current latent for the UI, when a
        # previewer is available for this model's latent format.
        preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) if previewer else None
        pbar.update_absolute(step + 1, total_steps, preview_bytes)

    samples = comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
        denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
        force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed,
    )
    out = latent.copy()
    out["samples"] = samples
    return (out, )
|
| 39 |
+
|
| 40 |
+
class Gradually_More_Denoise_KSampler:
    """Re-samples each input latent at gradually increasing denoise strengths.

    For every latent in the incoming batch, runs `common_ksampler` with
    denoise = start_denoise + denoise_increment * i for
    i in [0, denoise_increment_steps), concatenating all results into one
    latent batch. Model, conditionings and the (optional) VAE are passed
    through unchanged so downstream nodes can reuse them.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"model": ("MODEL",),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),

                     "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
                     "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                     "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),

                     "start_denoise": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                     "denoise_increment": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.1}),
                     "denoise_increment_steps": ("INT", {"default": 20, "min": 1, "max": 10000})
                    },
                "optional": { "optional_vae": ("VAE",) }
                }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "LATENT", "VAE", )
    RETURN_NAMES = ("MODEL", "CONDITIONING+", "CONDITIONING-", "LATENT", "VAE", )
    OUTPUT_NODE = True
    FUNCTION = "sample"
    CATEGORY = "ComfyUI-Frame-Interpolation/others"

    # optional_vae now defaults to None: it is declared as an *optional* input
    # above, so ComfyUI omits it from kwargs when unconnected — previously that
    # crashed with a missing-argument TypeError. The trailing parameters get
    # defaults matching INPUT_TYPES so the signature stays valid; ComfyUI always
    # supplies them by keyword.
    def sample(self, model, positive, negative, latent_image, optional_vae=None,
               seed=0, steps=20, cfg=8.0, sampler_name=None, scheduler=None,
               start_denoise=0.0, denoise_increment=0.1, denoise_increment_steps=20):
        """Sample each latent at every scheduled denoise strength.

        Raises:
            ValueError: when the final denoise strength would exceed 1.0.
        """
        if start_denoise + denoise_increment * denoise_increment_steps > 1.0:
            # Fixed: the original message was missing its closing parenthesis.
            raise ValueError(
                f"Max denoise strength can't be over 1.0 "
                f"(start_denoise={start_denoise}, denoise_increment={denoise_increment}, "
                f"denoise_increment_steps={denoise_increment_steps})"
            )

        copied_latent = latent_image.copy()
        out_samples = []

        for latent_sample in copied_latent["samples"]:
            # Each sample is CHW; common_ksampler expects a NCHW batch of one.
            latent = {"samples": einops.rearrange(latent_sample, "c h w -> 1 c h w")}
            gradually_denoising_samples = [
                common_ksampler(
                    model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent,
                    denoise=start_denoise + denoise_increment * i
                )[0]["samples"]
                for i in range(denoise_increment_steps)
            ]
            out_samples.extend(gradually_denoising_samples)

        copied_latent["samples"] = torch.cat(out_samples, dim=0)
        return (model, positive, negative, copied_latent, optional_vae)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "comfyui-frame-interpolation"
|
| 3 |
+
description = "A custom node suite for Video Frame Interpolation in ComfyUI"
|
| 4 |
+
version = "1.0.7"
|
| 5 |
+
license = { file = "LICENSE" }
|
| 6 |
+
|
| 7 |
+
[project.urls]
|
| 8 |
+
Repository = "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation"
|
| 9 |
+
|
| 10 |
+
[tool.comfy]
|
| 11 |
+
PublisherId = "fannovel16"
|
| 12 |
+
DisplayName = "ComfyUI-Frame-Interpolation"
|
| 13 |
+
Icon = ""
|
requirements-no-cupy.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
numpy
|
| 3 |
+
einops
|
| 4 |
+
opencv-contrib-python
|
| 5 |
+
kornia
|
| 6 |
+
scipy
|
| 7 |
+
Pillow
|
| 8 |
+
torchvision
|
| 9 |
+
tqdm
|
requirements-with-cupy.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
numpy
|
| 3 |
+
einops
|
| 4 |
+
opencv-contrib-python
|
| 5 |
+
kornia
|
| 6 |
+
scipy
|
| 7 |
+
Pillow
|
| 8 |
+
torchvision
|
| 9 |
+
tqdm
|
| 10 |
+
cupy-wheel
|
test.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke test: interpolate two demo frames with GMFSS Fortuna and preview the GIFs."""
import os
import sys

# Make the flat-layout modules (vfi_utils, vfi_models) importable directly.
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

import shutil
import torch
import torch.nn.functional as F
import PIL
import torchvision.transforms.functional as transform
from vfi_utils import load_file_from_github_release
from vfi_models import gmfss_fortuna, ifrnet, ifunet, m2m, rife, sepconv, amt, xvfi, cain, flavr
import numpy as np


def _load_frame(path):
    # Load an image as a 1xHxWxC float32 tensor in [0, 1].
    rgb = np.array(PIL.Image.open(path).convert("RGB")).astype(np.float32) / 255.0
    return torch.from_numpy(rgb).unsqueeze(0)


frame_0 = _load_frame("demo_frames/anime0.png")
frame_1 = _load_frame("demo_frames/anime1.png")

# Start from a clean output directory each run.
if os.path.exists("test_result"):
    shutil.rmtree("test_result")

vfi_node_class = gmfss_fortuna.GMFSS_Fortuna_VFI()
for i, ckpt_name in enumerate(vfi_node_class.INPUT_TYPES()["required"]["ckpt_name"][0][:2]):
    # NOTE: "multipler" (sic) matches the keyword actually used by the node API.
    result = vfi_node_class.vfi(
        ckpt_name,
        torch.cat([frame_0, frame_1, frame_0, frame_1], dim=0).cuda(),
        multipler=4,
        batch_size=2,
    )[0]
    print(result.shape)
    print(f"Generated {result.size(0)} frames")
    frames = [PIL.Image.fromarray(np.clip((frame * 255).numpy(), 0, 255).astype(np.uint8)) for frame in result]
    print(result[0].shape)
    os.makedirs(f"test_result/video{i}", exist_ok=True)
    for j, frame in enumerate(frames):
        frame.save(f"test_result/video{i}/{j}.jpg")
    frames[0].save(f"test_result/video{i}.gif", save_all=True, append_images=frames[1:], optimize=True, duration=1/3, loop=0)
    # NOTE(review): os.startfile exists only on Windows — this script assumes a
    # Windows environment; confirm before running elsewhere.
    os.startfile(f"test_result{os.path.sep}video{i}.gif")
#torchvision.io.video.write_video("test.mp4", einops.rearrange(result, "n c h w -> n h w c").cpu(), fps=1)
|
test_vfi_schedule.gif
ADDED
|
Git LFS Details
|
vfi_models/amt/__init__.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
import pathlib
|
| 5 |
+
from vfi_utils import load_file_from_direct_url, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
|
| 6 |
+
import typing
|
| 7 |
+
from comfy.model_management import get_torch_device
|
| 8 |
+
from .amt_arch import AMT_S, AMT_L, AMT_G, InputPadder
|
| 9 |
+
|
| 10 |
+
#https://github.com/MCG-NKU/AMT/tree/main/cfgs
# Maps each downloadable checkpoint filename to the network class it loads
# into and the constructor kwargs taken from the upstream AMT config files.
CKPT_CONFIGS = {
    "amt-s.pth": {
        "network": AMT_S,
        "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 3 }
    },
    "amt-l.pth": {
        "network": AMT_L,
        "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 5 }
    },
    "amt-g.pth": {
        "network": AMT_G,
        "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 5 }
    },
    "gopro_amt-s.pth": {
        "network": AMT_S,
        "params": { "corr_radius": 3, "corr_lvls": 4, "num_flows": 3 }
    }
}


# Name of this package's directory ("amt") — used as the subfolder that
# downloaded checkpoints are stored under.
MODEL_TYPE = pathlib.Path(__file__).parent.name
|
| 32 |
+
|
| 33 |
+
class AMT_VFI:
    """ComfyUI node wrapping the AMT video frame interpolation networks."""

    @classmethod
    def INPUT_TYPES(s):
        required_inputs = {
            "ckpt_name": (list(CKPT_CONFIGS.keys()), ),
            "frames": ("IMAGE", ),
            "clear_cache_after_n_frames": ("INT", {"default": 1, "min": 1, "max": 100}),
            "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
        }
        optional_inputs = {
            "optional_interpolation_states": ("INTERPOLATION_STATES", )
        }
        return {"required": required_inputs, "optional": optional_inputs}

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        # Fetch (downloading if necessary) the checkpoint and build the
        # matching network with its upstream config.
        ckpt_path = load_file_from_direct_url(MODEL_TYPE, f"https://huggingface.co/lalala125/AMT/resolve/main/{ckpt_name}")
        config = CKPT_CONFIGS[ckpt_name]
        model = config["network"](**config["params"])
        model.load_state_dict(torch.load(ckpt_path)["state_dict"])
        model.eval().to(get_torch_device())

        # AMT needs spatial dimensions divisible by 16.
        frames = preprocess_frames(frames)
        padder = InputPadder(frames.shape, 16)
        frames = padder.pad(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model):
            # embt is the per-sample interpolation time, broadcast to NCHW.
            batch = frame_0.shape[0]
            embt = torch.FloatTensor([timestep] * batch).view(batch, 1, 1, 1).to(get_torch_device())
            return model(frame_0, frame_1, embt=embt, scale_factor=1.0, eval=True)["imgt_pred"]

        result = generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, model,
                                    interpolation_states=optional_interpolation_states, dtype=torch.float32)
        return (postprocess_frames(padder.unpad(result)),)
|
| 87 |
+
|
vfi_models/amt/amt_arch.py
ADDED
|
@@ -0,0 +1,1590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/MCG-NKU/AMT/blob/main/utils/dist_utils.py
|
| 3 |
+
https://github.com/MCG-NKU/AMT/blob/main/utils/flow_utils.py
|
| 4 |
+
https://github.com/MCG-NKU/AMT/blob/main/utils/utils.py
|
| 5 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/feat_enc.py
|
| 6 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/ifrnet.py
|
| 7 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/multi_flow.py
|
| 8 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/blocks/raft.py
|
| 9 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-S.py
|
| 10 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-L.py
|
| 11 |
+
https://github.com/MCG-NKU/AMT/blob/main/networks/AMT-G.py
|
| 12 |
+
"""
|
| 13 |
+
#Removed imageio by removing readImage, writeImage
|
| 14 |
+
#The model will receive image tensors from other ComfyUI's nodes so they are unneccessary
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import numpy as np
|
| 19 |
+
from PIL import ImageFile
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 22 |
+
import re
|
| 23 |
+
import sys
|
| 24 |
+
import random
|
| 25 |
+
|
| 26 |
+
def warp(img, flow):
    """Backward-warp `img` by the per-pixel displacement `flow` (B, 2, H, W)."""
    batch, _, height, width = flow.shape
    # Base sampling grid in normalized [-1, 1] coordinates.
    xs = torch.linspace(-1.0, 1.0, width).view(1, 1, 1, width).expand(batch, -1, height, -1)
    ys = torch.linspace(-1.0, 1.0, height).view(1, 1, height, 1).expand(batch, -1, -1, width)
    base_grid = torch.cat([xs, ys], 1).to(img)
    # Convert the pixel-space flow into the same normalized coordinate system.
    norm_flow = torch.cat(
        [flow[:, 0:1, :, :] / ((width - 1.0) / 2.0),
         flow[:, 1:2, :, :] / ((height - 1.0) / 2.0)], 1)
    sample_grid = (base_grid + norm_flow).permute(0, 2, 3, 1)
    return F.grid_sample(input=img, grid=sample_grid, mode='bilinear', padding_mode='border', align_corners=True)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
    Code follows the original C++ source code of Daniel Scharstein and the
    Matlab source code of Deqing Sun.
    Returns:
        np.ndarray: Color wheel of shape (55, 3)
    """
    RY, YG, GC, CB, BM, MR = 15, 6, 4, 11, 13, 6
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))

    # Each hue segment holds one channel at 255 while ramping another channel
    # up or down: (length, held channel, ramped channel, ramp falls?).
    segments = [
        (RY, 0, 1, False),  # red -> yellow: green rises
        (YG, 1, 0, True),   # yellow -> green: red falls
        (GC, 1, 2, False),  # green -> cyan: blue rises
        (CB, 2, 1, True),   # cyan -> blue: green falls
        (BM, 2, 0, False),  # blue -> magenta: red rises
        (MR, 0, 2, True),   # magenta -> red: blue falls
    ]
    col = 0
    for length, hold_ch, ramp_ch, falling in segments:
        colorwheel[col:col + length, hold_ch] = 255
        ramp = np.floor(255 * np.arange(0, length) / length)
        colorwheel[col:col + length, ramp_ch] = 255 - ramp if falling else ramp
        col += length
    return colorwheel
|
| 83 |
+
|
| 84 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.
    According to the C++ source code of Daniel Scharstein and the Matlab
    source code of Deqing Sun.
    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi
    # Fractional position of every pixel's angle around the color wheel.
    wheel_pos = (angle + 1) / 2 * (ncols - 1)
    lo = np.floor(wheel_pos).astype(np.int32)
    hi = lo + 1
    hi[hi == ncols] = 0  # wrap around the wheel
    frac = wheel_pos - lo
    in_range = (magnitude <= 1)
    for ch in range(colorwheel.shape[1]):
        wheel_col = colorwheel[:, ch]
        # Linear blend between the two neighboring wheel entries.
        blended = (1 - frac) * (wheel_col[lo] / 255.0) + frac * (wheel_col[hi] / 255.0)
        blended[in_range] = 1 - magnitude[in_range] * (1 - blended[in_range])
        blended[~in_range] = blended[~in_range] * 0.75  # out of range
        # Note the 2-ch => BGR instead of RGB
        out_ch = 2 - ch if convert_to_bgr else ch
        flow_image[:, :, out_ch] = np.floor(255 * blended)
    return flow_image
|
| 118 |
+
|
| 119 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape.
    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    # Normalize by the largest magnitude so colors saturate at the max flow.
    magnitude = np.sqrt(np.square(u) + np.square(v))
    normalizer = np.max(magnitude) + 1e-5
    return flow_uv_to_colors(u / normalizer, v / normalizer, convert_to_bgr)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class AverageMeter():
    """Tracks the most recent value plus a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear every running statistic back to zero.
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        # `n` lets one call account for a batch of identical observations.
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class AverageMeterGroups:
    """A name-keyed collection of AverageMeter instances."""

    def __init__(self) -> None:
        self.meter_dict = dict()

    def update(self, dict, n=1):
        # NOTE: the parameter shadows the builtin `dict`; name kept for API
        # compatibility. Meters are created lazily on first sight of a key.
        for name, val in dict.items():
            meter = self.meter_dict.get(name)
            if meter is None:
                meter = AverageMeter()
                self.meter_dict[name] = meter
            meter.update(val, n)

    def reset(self, name=None):
        # Reset one named meter, or every meter when no name is given.
        if name is not None:
            meter = self.meter_dict.get(name)
            if meter is not None:
                meter.reset()
        else:
            for meter in self.meter_dict.values():
                meter.reset()

    def avg(self, name):
        # Returns None for unknown names, matching the original behavior.
        meter = self.meter_dict.get(name)
        return meter.avg if meter is not None else None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class InputPadder:
    """ Pads images such that dimensions are divisible by divisor """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Amount needed to round each spatial dim up to a multiple of divisor.
        pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
        pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
        # F.pad order: [left, right, top, bottom], split as evenly as possible.
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, input_tensor):
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        return self._unpad(input_tensor)

    def _unpad(self, x):
        # Slice away exactly the margins that pad() added.
        height, width = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:height - bottom, left:width - right]
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def img2tensor(img):
    """Convert an HWC uint8 image array to a 1CHW float tensor in [0, 1].

    A fourth (alpha) channel, if present, is dropped.
    """
    rgb = img[:, :, :3] if img.shape[-1] > 3 else img
    chw = torch.tensor(rgb).permute(2, 0, 1)
    return chw.unsqueeze(0) / 255.0
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def tensor2img(img_t):
    """Convert a 1CHW float tensor in [0, 1] back to an HWC uint8 array."""
    scaled = (img_t * 255.).detach().squeeze(0)
    hwc = scaled.permute(1, 2, 0).cpu().numpy()
    return hwc.clip(0, 255).astype(np.uint8)
|
| 224 |
+
|
| 225 |
+
def seed_all(seed):
    """Seed every RNG in use: Python, NumPy, torch CPU and all CUDA devices."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def readPFM(file):
    """Read a PFM image file.

    Args:
        file: path to the .pfm file.

    Returns:
        tuple: (data, scale) — data is an np.ndarray of shape (H, W, 3) for
        color ('PF') files or (H, W) for grayscale ('Pf'), rows flipped back
        to top-to-bottom order; scale is the positive PFM scale factor.

    Raises:
        Exception: if the file is not a PFM file or its header is malformed.
    """
    # BUG FIX: the file handle was never closed; use a context manager.
    with open(file, 'rb') as f:
        header = f.readline().rstrip()
        if header.decode("ascii") == 'PF':
            color = True
        elif header.decode("ascii") == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')

        scale = float(f.readline().decode("ascii").rstrip())
        if scale < 0:  # a negative scale marks little-endian data
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)  # PFM stores rows bottom-to-top
    return data, scale
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def writePFM(file, image, scale=1):
    """Write a float32 image to a PFM file.

    Args:
        file: destination path.
        image: np.ndarray of shape (H, W, 3), (H, W, 1) or (H, W), dtype float32.
        scale: PFM scale factor; written negative for little-endian data.

    Raises:
        Exception: on wrong dtype or unsupported shape.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    image = np.flipud(image)  # PFM stores rows bottom-to-top

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # A negative scale in the header marks little-endian payload bytes.
    endian = image.dtype.byteorder
    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    # BUG FIX: previously `'PF\n' if color else 'Pf\n'.encode()` applied
    # .encode() only to 'Pf\n', so color images wrote a str to a binary
    # file and raised TypeError. Also: the handle was never closed.
    with open(file, 'wb') as f:
        f.write(('PF\n' if color else 'Pf\n').encode())
        f.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode())
        f.write(('%f\n' % scale).encode())
        image.tofile(f)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def readFlow(name):
    """Read an optical flow file — Middlebury .flo or .pfm.

    Args:
        name: path to the flow file.

    Returns:
        np.ndarray of shape (H, W, 2), dtype float32.

    Raises:
        Exception: if a .flo file lacks the 'PIEH' magic header.
    """
    if name.endswith('.pfm') or name.endswith('.PFM'):
        # PFM flow keeps only the first two channels (u, v).
        return readPFM(name)[0][:, :, 0:2]

    # BUG FIX: the file handle was never closed; use a context manager.
    with open(name, 'rb') as f:
        header = f.read(4)
        if header.decode("utf-8") != 'PIEH':
            raise Exception('Flow file header does not contain PIEH')

        width = np.fromfile(f, np.int32, 1).squeeze()
        height = np.fromfile(f, np.int32, 1).squeeze()
        flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))

    return flow.astype(np.float32)
|
| 316 |
+
|
| 317 |
+
def writeFlow(name, flow):
    """Write an (H, W, 2) flow array to a Middlebury .flo file.

    Layout: b'PIEH' magic, int32 width, int32 height, float32 u/v pairs.
    """
    # BUG FIX: the file handle was never closed; use a context manager.
    with open(name, 'wb') as f:
        f.write('PIEH'.encode('utf-8'))
        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
        flow.astype(np.float32).tofile(f)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def readFloat(name):
    """Read a .float-format file: b'float\\n', ndim, one size per line, raw float32.

    Args:
        name: path to the file.

    Returns:
        np.ndarray with the stored dimensions (axes reordered for 3-D data).

    Raises:
        Exception: if the 'float' keyword header is missing.
    """
    # BUG FIX: the file handle was never closed; use a context manager.
    with open(name, 'rb') as f:
        if f.readline().decode("utf-8") != 'float\n':
            raise Exception('float file %s did not contain <float> keyword' % name)

        dim = int(f.readline())

        dims = []
        count = 1
        for _ in range(0, dim):
            d = int(f.readline())
            dims.append(d)
            count *= d

        # Sizes are stored innermost-first; reverse for numpy's C order.
        dims = list(reversed(dims))
        data = np.fromfile(f, np.float32, count).reshape(dims)

    if dim > 2:
        # Stored channel-major; bring back to (H, W, C).
        data = np.transpose(data, (2, 1, 0))
        data = np.transpose(data, (1, 0, 2))
    return data
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def writeFloat(name, data):
    """Write an up-to-3D float array in the format readFloat expects.

    Args:
        name: destination path.
        data: np.ndarray with 1, 2 or 3 dimensions.

    Raises:
        Exception: if data has more than 3 dimensions.
    """
    dim = len(data.shape)
    if dim > 3:
        raise Exception('bad float file dimension: %d' % dim)

    # BUG FIX: the file handle was never closed; use a context manager.
    with open(name, 'wb') as f:
        f.write(('float\n').encode('ascii'))
        f.write(('%d\n' % dim).encode('ascii'))

        # Sizes are written width, height, then any remaining axes.
        if dim == 1:
            f.write(('%d\n' % data.shape[0]).encode('ascii'))
        else:
            f.write(('%d\n' % data.shape[1]).encode('ascii'))
            f.write(('%d\n' % data.shape[0]).encode('ascii'))
            for i in range(2, dim):
                f.write(('%d\n' % data.shape[i]).encode('ascii'))

        data = data.astype(np.float32)
        if dim < 3:
            # BUG FIX: 1-D input previously fell into the transpose branch
            # and crashed on np.transpose(data, (2, 0, 1)).
            data.tofile(f)
        else:
            # 3-D data is stored channel-major.
            np.transpose(data, (2, 0, 1)).tofile(f)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def check_dim_and_resize(tensor_list):
    """Ensure all frame tensors share one spatial size.

    If the (H, W) of the tensors differ, every tensor is bilinearly resized
    to the first tensor's spatial size; otherwise the list is returned as-is.
    """
    shape_list = [t.shape[2:] for t in tensor_list]

    if len(set(shape_list)) > 1:
        desired_shape = shape_list[0]
        print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}')
        tensor_list = [
            torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')
            for t in tensor_list
        ]

    return tensor_list
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    norm_fn selects the normalization layer: 'group' | 'batch' | 'instance'
    | 'none'. When stride != 1 the identity path is downsampled by a strided
    1x1 convolution followed by norm4 so both branches match in shape.
    """
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        # Main path: compress to planes//4 channels, 3x3 conv carries the
        # stride, then expand back to `planes` channels.
        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        # Group count for GroupNorm; only used when norm_fn == 'group'.
        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            # norm4 normalizes the downsampled identity branch and only
            # exists when the block actually downsamples.
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            # Identity placeholders keep the forward code uniform.
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        # Residual main path: three conv/norm/relu stages.
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        # Bring the identity branch to the main path's shape when strided.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 convolutions with a skip connection.

    norm_fn selects the normalization layer: 'group' | 'batch' | 'instance'
    | 'none'. When stride != 1 the identity path is downsampled by a strided
    1x1 convolution followed by norm3 so both branches match in shape.
    """
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        # Main path: the first 3x3 conv carries the stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        # Group count for GroupNorm; only used when norm_fn == 'group'.
        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            # norm3 normalizes the downsampled identity branch and only
            # exists when the block actually downsamples.
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            # Identity placeholders keep the forward code uniform.
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        # Residual main path: two conv/norm/relu stages.
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        # Bring the identity branch to the main path's shape when strided.
        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class SmallEncoder(nn.Module):
    """RAFT-style small feature encoder: 1/8-resolution features.

    Stem conv (stride 2) + three BottleneckBlock stages (strides 1, 2, 2),
    then a 1x1 projection to ``output_dim`` channels.  ``forward`` accepts
    either a single tensor or a pair of tensors (processed as one batch).
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer advances self.in_planes as stages are built.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; affine norms start as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a stage of two BottleneckBlocks (first carries the stride)."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode one image tensor, or a pair (returned as a pair)."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # NOTE(review): split assumes exactly two equally sized inputs.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 588 |
+
|
| 589 |
+
class BasicEncoder(nn.Module):
    """RAFT-style basic feature encoder: 1/8-resolution features.

    Same layout as SmallEncoder but built from ResidualBlocks with wider
    channel counts (64 -> 72 -> 128) before the 1x1 output projection.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer advances self.in_planes as stages are built.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(72, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; affine norms start as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a stage of two ResidualBlocks (first carries the stride)."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode one image tensor, or a pair (returned as a pair)."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # NOTE(review): split assumes exactly two equally sized inputs.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 664 |
+
|
| 665 |
+
class LargeEncoder(nn.Module):
    """RAFT-style large feature encoder: 1/8-resolution features.

    Like BasicEncoder but with wider stages (64 -> 112 -> 160) and an
    extra stride-1 stage (``layer3_2``) before the output projection.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(LargeEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer advances self.in_planes as stages are built.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(112, stride=2)
        self.layer3 = self._make_layer(160, stride=2)
        self.layer3_2 = self._make_layer(160, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; affine norms start as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build a stage of two ResidualBlocks (first carries the stride)."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):
        """Encode one image tensor, or a pair (returned as a pair)."""

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer3_2(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # NOTE(review): split assumes exactly two equally sized inputs.
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def resize(x, scale_factor):
    """Bilinearly rescale a NCHW tensor by ``scale_factor`` (corners not aligned)."""
    return F.interpolate(
        x,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
|
| 755 |
+
|
| 756 |
+
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    """Return a Conv2d followed by a channel-wise PReLU, as one nn.Sequential."""
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                     padding, dilation, groups, bias=bias)
    activation = nn.PReLU(out_channels)
    return nn.Sequential(conv, activation)
|
| 761 |
+
|
| 762 |
+
class ResBlock(nn.Module):
    """IFRNet-style residual block with a narrow "side" channel group.

    The last ``side_channels`` channels are refined twice by their own
    small convs (conv2/conv4) while the full tensor passes through
    conv1/conv3/conv5; a PReLU closes the residual connection.
    Attribute names are checkpoint-sensitive; do not rename.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        # conv2/conv4 operate on the side-channel slice only.
        self.conv2 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        # Refine only the trailing side channels, then re-fuse.
        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, :, :]
        side_feat = self.conv2(side_feat)
        out = self.conv3(torch.cat([res_feat, side_feat], 1))

        # Second refine/fuse round.
        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, :, :]
        side_feat = self.conv4(side_feat)
        out = self.conv5(torch.cat([res_feat, side_feat], 1))

        # Residual connection closed by PReLU.
        out = self.prelu(x + out)
        return out
|
| 800 |
+
|
| 801 |
+
class Encoder(nn.Module):
    """Strided convolutional feature pyramid.

    One ``pyramid{i}`` stage is registered per entry in ``channels``;
    every stage halves the spatial resolution.  With ``large=True`` the
    first stage uses a 7x7 kernel instead of 3x3.  The dynamic module
    names are checkpoint-sensitive; do not rename.
    """

    def __init__(self, channels, large=False):
        super(Encoder, self).__init__()
        self.channels = channels
        prev_ch = 3
        for idx, ch in enumerate(channels, 1):
            # Only the very first stage of the "large" variant gets 7x7.
            k = 7 if large and idx == 1 else 3
            p = 3 if k ==7 else 1
            self.register_module(f'pyramid{idx}',
                nn.Sequential(
                    convrelu(prev_ch, ch, k, 2, p),
                    convrelu(ch, ch, 3, 1, 1)
                ))
            prev_ch = ch

    def forward(self, in_x):
        """Return the list of per-stage features, finest first."""
        fs = []
        for idx in range(len(self.channels)):
            out_x = getattr(self, f'pyramid{idx+1}')(in_x)
            fs.append(out_x)
            in_x = out_x
        return fs
|
| 823 |
+
|
| 824 |
+
class InitDecoder(nn.Module):
    """Coarsest-level decoder: predicts the initial bidirectional flows.

    Consumes the two coarsest feature maps and the time embedding, and
    emits (flow t->0, flow t->1, intermediate feature) at 2x resolution
    via the transposed convolution.
    """

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*2+1, in_ch*2),
            ResBlock(in_ch*2, skip_ch),
            nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True)
        )
    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        # Broadcast the per-sample time embedding over the spatial grid.
        embt = embt.repeat(1, 1, h, w)
        out = self.convblock(torch.cat([f0, f1, embt], 1))
        # First 4 channels: two 2-channel flow fields; the rest is the feature.
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        return flow0, flow1, ft_
|
| 839 |
+
|
| 840 |
+
class IntermediateDecoder(nn.Module):
    """Mid-pyramid decoder: refines flows one level up.

    Warps both feature maps by the incoming flows, predicts flow deltas
    plus a new intermediate feature at 2x resolution, and upsamples the
    incoming flows (scaling their magnitude by 2 to match the new grid).
    """

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True)
        )
    def forward(self, ft_, f0, f1, flow0_in, flow1_in):
        # Backward-warp each source feature by its incoming flow.
        f0_warp = warp(f0, flow0_in)
        f1_warp = warp(f1, flow1_in)
        f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1)
        out = self.convblock(f_in)
        # First 4 channels: two flow deltas; the rest is the feature.
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        # Flow magnitudes double when resolution doubles.
        flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0)
        flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0)
        return flow0, flow1, ft_
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
def multi_flow_combine(comb_block, img0, img1, flow0, flow1,
                       mask=None, img_res=None, mean=None):
    '''
    A parallel implementation of multiple flow field warping.

    comb_block: An nn.Sequential object fusing all candidate frames
                (input [b, 3*num_flows, h, w], output [b, 3, h, w]).
    img shape: [b, c, h, w]
    flow shape: [b, 2*num_flows, h, w]
    mask (opt):
        If 'mask' is None, the function conducts a simple average of the
        two warped images (blend weight 0.5).
    img_res (opt):
        If 'img_res' is None, the function adds zero instead.
    mean (opt):
        If 'mean' is None, the function adds zero instead.
    '''
    b, c, h, w = flow0.shape
    num_flows = c // 2
    # Fold the per-flow dimension into the batch so every flow candidate
    # is warped in a single call.
    flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)
    flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w)

    # BUGFIX: the original multiplied a None mask below, crashing when no
    # mask was supplied; a scalar 0.5 realises the documented average.
    mask = mask.reshape(b, num_flows, 1, h, w
                        ).reshape(-1, 1, h, w) if mask is not None else 0.5
    img_res = img_res.reshape(b, num_flows, 3, h, w
                              ).reshape(-1, 3, h, w) if img_res is not None else 0
    img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w)
    img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w)
    mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1
                                                      ) if mean is not None else 0

    # Backward-warp both inputs, blend by the occlusion mask, then restore
    # the mean removed upstream and add the predicted residual.
    img0_warp = warp(img0, flow0)
    img1_warp = warp(img1, flow1)
    img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res
    img_warps = img_warps.reshape(b, num_flows, 3, h, w)
    # Final frame: mean of candidates plus a learned combination term.
    imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w))
    return imgt_pred
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class MultiFlowDecoder(nn.Module):
    """Finest-level decoder: predicts ``num_flows`` flow pairs at once.

    Output channels (8 per flow) split into: 2n flow-t->0 deltas,
    2n flow-t->1 deltas, n occlusion masks, and 3n RGB residuals.
    """

    def __init__(self, in_ch, skip_ch, num_flows=3):
        super(MultiFlowDecoder, self).__init__()
        self.num_flows = num_flows
        self.convblock = nn.Sequential(
            convrelu(in_ch*3+4, in_ch*3),
            ResBlock(in_ch*3, skip_ch),
            nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True)
        )

    def forward(self, ft_, f0, f1, flow0, flow1):
        n = self.num_flows
        # Backward-warp each source feature by its incoming flow.
        f0_warp = warp(f0, flow0)
        f1_warp = warp(f1, flow1)
        out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1))
        delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1)
        mask = torch.sigmoid(mask)

        # Upsample the single incoming flow pair (x2 magnitude for the new
        # grid) and replicate it across the n flow candidates.
        flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0
                                           ).repeat(1, self.num_flows, 1, 1)
        flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0
                                           ).repeat(1, self.num_flows, 1, 1)

        return flow0, flow1, mask, img_res
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def resize(x, scale_factor):
    """Bilinear NCHW rescale by ``scale_factor``.

    NOTE(review): duplicates the identical ``resize`` defined earlier in
    this file; kept so later definitions keep shadowing as before.
    """
    return F.interpolate(x, scale_factor=scale_factor,
                         mode="bilinear", align_corners=False)
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def bilinear_sampler(img, coords, mask=False):
    """Wrapper for grid_sample that accepts absolute pixel coordinates.

    coords has shape (..., 2) with (x, y) in pixel units; they are mapped
    to grid_sample's [-1, 1] range with corner alignment.  When ``mask``
    is True, also return a float mask of strictly in-bounds samples.
    """
    height, width = img.shape[-2:]
    x_px, y_px = coords.split([1, 1], dim=-1)
    # Normalize pixel coordinates into [-1, 1] (align_corners convention).
    x_norm = 2 * x_px / (width - 1) - 1
    y_norm = 2 * y_px / (height - 1) - 1

    sample_grid = torch.cat([x_norm, y_norm], dim=-1)
    sampled = F.grid_sample(img, sample_grid, align_corners=True)

    if not mask:
        return sampled

    inside = (x_norm > -1) & (y_norm > -1) & (x_norm < 1) & (y_norm < 1)
    return sampled, inside.float()
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def coords_grid(batch, ht, wd, device):
    """Return a (batch, 2, ht, wd) grid of pixel coordinates.

    Channel 0 holds x (column index), channel 1 holds y (row index).
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    # Stack x before y so channel order is (x, y).
    grid = torch.stack((grid_x, grid_y), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
class SmallUpdateBlock(nn.Module):
    """Lightweight recurrent update block (AMT-S variant).

    Encodes bidirectional correlation + current flows, runs a small
    conv "GRU", and predicts residuals for the feature map and the flow
    pair.  When ``scale_factor`` is set, the feature is processed at a
    lower resolution and the outputs are resized (and the flow delta
    rescaled) back up.
    """

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim,
                 corr_levels=4, radius=3, scale_factor=None):
        super(SmallUpdateBlock, self).__init__()
        # Channels of one direction's correlation lookup volume.
        cor_planes = corr_levels * (2 * radius + 1) **2
        self.scale_factor = scale_factor

        # 2 * cor_planes: correlations of both directions are concatenated.
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return (delta_net, delta_flow) residual updates."""
        # Downscale the feature to the correlation resolution if needed.
        net = resize(net, 1 / self.scale_factor
                     ) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            # Flow magnitudes scale with resolution.
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)

        return delta_net, delta_flow
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
class BasicUpdateBlock(nn.Module):
    """Larger recurrent update block (AMT-L/G variant).

    Like SmallUpdateBlock but with a second correlation conv and a
    configurable number of output flow pairs (``out_num``).
    """

    def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2,
                 fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1):
        super(BasicUpdateBlock, self).__init__()
        # Channels of one direction's correlation lookup volume.
        cor_planes = corr_levels * (2 * radius + 1) **2

        self.scale_factor = scale_factor
        # 2 * cor_planes: correlations of both directions are concatenated.
        self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0)
        self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1)
        self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3)
        self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1)
        self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1)

        self.gru = nn.Sequential(
            nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
        )

        self.feat_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, cdim, 3, padding=1),
        )

        self.flow_head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1),
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, net, flow, corr):
        """Return (delta_net, delta_flow) residual updates."""
        # Downscale the feature to the correlation resolution if needed.
        net = resize(net, 1 / self.scale_factor
                     ) if self.scale_factor is not None else net
        cor = self.lrelu(self.convc1(corr))
        cor = self.lrelu(self.convc2(cor))
        flo = self.lrelu(self.convf1(flow))
        flo = self.lrelu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        inp = self.lrelu(self.conv(cor_flo))
        inp = torch.cat([inp, flow, net], dim=1)

        out = self.gru(inp)
        delta_net = self.feat_head(out)
        delta_flow = self.flow_head(out)

        if self.scale_factor is not None:
            delta_net = resize(delta_net, scale_factor=self.scale_factor)
            # Flow magnitudes scale with resolution.
            delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor)
        return delta_net, delta_flow
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
class BidirCorrBlock:
    """Bidirectional multi-scale correlation volume with windowed lookup.

    Builds the all-pairs correlation between fmap1 and fmap2 once, plus
    its transpose (the reverse direction), then average-pools both into
    ``num_levels`` pyramid levels.  ``__call__`` samples a
    (2*radius+1)^2 window around given coordinates at every level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.corr_pyramid_T = []

        corr = BidirCorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = corr.shape
        # Transposed volume = correlation in the opposite direction.
        corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2)

        corr = corr.reshape(batch*h1*w1, dim, h2, w2)
        corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1)

        self.corr_pyramid.append(corr)
        self.corr_pyramid_T.append(corr_T)

        # Coarser levels via 2x average pooling of the target dims.
        for _ in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            corr_T = F.avg_pool2d(corr_T, 2, stride=2)
            self.corr_pyramid.append(corr)
            self.corr_pyramid_T.append(corr_T)

    def __call__(self, coords0, coords1):
        """Look up both volumes around coords0/coords1 (each (b, 2, h, w))."""
        r = self.radius
        coords0 = coords0.permute(0, 2, 3, 1)
        coords1 = coords1.permute(0, 2, 3, 1)
        assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]"
        batch, h1, w1, _ = coords0.shape

        out_pyramid = []
        out_pyramid_T = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            corr_T = self.corr_pyramid_T[i]

            # (2r+1)^2 window of integer offsets around each center.
            dx = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords0.device)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1)
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)

            # Centers shrink by 2 per pyramid level.
            centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            coords_lvl_0 = centroid_lvl_0 + delta_lvl
            coords_lvl_1 = centroid_lvl_1 + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl_0)
            corr_T = bilinear_sampler(corr_T, coords_lvl_1)
            corr = corr.view(batch, h1, w1, -1)
            corr_T = corr_T.view(batch, h1, w1, -1)
            out_pyramid.append(corr)
            out_pyramid_T.append(corr_T)

        out = torch.cat(out_pyramid, dim=-1)
        out_T = torch.cat(out_pyramid_T, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs dot-product correlation, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
|
| 1150 |
+
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
class AMT_S(nn.Module):
    """AMT-S frame interpolation network (small variant).

    Given two frames and a time embedding ``embt`` in (0, 1), predicts
    the intermediate frame via a coarse-to-fine decoder cascade with
    bidirectional-correlation-guided residual updates at each level.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=3,
                 channels=[20, 32, 44, 56],
                 skip_channels=20):
        super(AMT_S, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows
        self.channels = channels
        self.skip_channels = skip_channels

        # Correlation features at 1/8 resolution; pyramid features per level.
        self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels)

        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Update blocks run at progressively finer levels; the scale factor
        # maps their features down to the fixed 1/8 correlation grid.
        self.update4 = self._get_updateblock(44)
        self.update3 = self._get_updateblock(32, 2)
        self.update2 = self._get_updateblock(20, 4)

        self.comb_block = nn.Sequential(
            nn.Conv2d(3*num_flows, 6*num_flows, 3, 1, 1),
            nn.PReLU(6*num_flows),
            nn.Conv2d(6*num_flows, 3, 3, 1, 1),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build a SmallUpdateBlock for a level with ``cdim`` feature channels."""
        return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
                                fc_dim=68, scale_factor=scale_factor,
                                corr_levels=self.corr_levels, radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation for flows at a given level.

        Rescales the level's flows to the 1/8 correlation grid and converts
        the t->0 / t->1 flows to full 0->1 / 1->0 displacements.
        """
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        # NOTE(review): the `eval` parameter shadows the builtin; kept for
        # checkpoint/caller compatibility.
        # Remove the joint mean so the network works on zero-centered input.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Undo the working-resolution scaling; flow magnitudes rescale too.
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return  { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
|
| 1293 |
+
|
| 1294 |
+
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class AMT_L(nn.Module):
    """AMT-L: 'large' variant of the AMT video frame interpolation network.

    Estimates bidirectional flows over a 4-level decoder pyramid, refining
    each level with a RAFT-style residual update driven by bidirectional
    correlation lookups, then combines ``num_flows`` candidate warps into
    the intermediate frame.

    Args:
        corr_radius: lookup radius of the correlation volume.
        corr_lvls: number of correlation pyramid levels.
        num_flows: number of candidate flow pairs predicted at level 1.
        channels: per-level context-encoder widths, level 1 -> level 4.
        skip_channels: width of the skip connections between decoders.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=5,
                 channels=[48, 64, 72, 128],
                 skip_channels=48
                 ):
        super(AMT_L, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows

        self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        # Fix: honor the `channels` argument (previously hard-coded to the
        # default list, silently ignoring caller-provided widths; AMT_G
        # already passes `channels` through).
        self.encoder = Encoder(channels, large=True)

        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Update-block width follows the decoder output width of each level
        # (previously hard-coded 72/64/48 == channels[2]/[1]/[0]).
        self.update4 = self._get_updateblock(channels[2], None)
        self.update3 = self._get_updateblock(channels[1], 2.0)
        self.update2 = self._get_updateblock(channels[0], 4.0)

        self.comb_block = nn.Sequential(
            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
            nn.PReLU(6*self.num_flows),
            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build a residual update block; `scale_factor` rescales its output
        back to the flow resolution (None = operate at same resolution)."""
        return BasicUpdateBlock(cdim=cdim, hidden_dim=128, flow_dim=48,
                                corr_dim=256, corr_dim2=160, fc_dim=124,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up the bidirectional correlation volume at flow-displaced
        coordinates.

        Converts the t->0 / t->1 flows into full 0->1 / 1->0 displacements
        under a linear-motion assumption before the lookup. Returns
        (corr, flow) concatenated on the channel dim; `flow` is downscaled
        to the lookup resolution when `downsample` != 1.
        """
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Predict the frame at time `embt` between `img0` and `img1`.

        Returns a dict with 'imgt_pred'; when `eval` is False it also
        contains the per-level flow and feature predictions used by the
        training losses.
        """
        # Joint mean normalization; multi_flow_combine adds `mean_` back.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Undo the input rescale; flow magnitudes scale with resolution.
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
|
| 1430 |
+
|
| 1431 |
+
|
| 1432 |
+
|
| 1433 |
+
|
| 1434 |
+
|
| 1435 |
+
|
| 1436 |
+
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
|
| 1440 |
+
|
| 1441 |
+
class AMT_G(nn.Module):
    """AMT-G: 'giant' variant of the AMT video frame interpolation network.

    Same 4-level coarse-to-fine pipeline as AMT-L, but with a LargeEncoder
    feature backbone, wider update blocks, and an additional high-resolution
    refinement pass (update3_high / update2_high) at levels 3 and 2.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=5,
                 channels=[84, 96, 112, 128],
                 skip_channels=84):
        super(AMT_G, self).__init__()
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows

        self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels, large=True)
        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Low-resolution update blocks (cdim matches channels[2]/[1]/[0]).
        self.update4 = self._get_updateblock(112, None)
        self.update3_low = self._get_updateblock(96, 2.0)
        self.update2_low = self._get_updateblock(84, 4.0)

        # Extra refinement passes that operate at the flow's own resolution.
        self.update3_high = self._get_updateblock(96, None)
        self.update2_high = self._get_updateblock(84, None)

        self.comb_block = nn.Sequential(
            nn.Conv2d(3*self.num_flows, 6*self.num_flows, 7, 1, 3),
            nn.PReLU(6*self.num_flows),
            nn.Conv2d(6*self.num_flows, 3, 7, 1, 3),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        # Wider than AMT-L: hidden_dim 192 vs 128, flow_dim 64 vs 48.
        return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64,
                                corr_dim=256, corr_dim2=192, fc_dim=188,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up the bidirectional correlation volume at flow-displaced
        coordinates; returns (corr, flow) concatenated on the channel dim."""
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            # Bring the flows down to the correlation-volume resolution.
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Predict the frame at time `embt` between `img0` and `img1`.

        Returns a dict with 'imgt_pred'; when `eval` is False it also
        contains per-level flow and feature predictions for training.
        """
        # Joint mean normalization; multi_flow_combine adds `mean_` back.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])  # [1, 128, H//8, W//8]
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord,
                                                 up_flow0_4, up_flow1_4,
                                                 embt, downsample=1)

        # residue update with lookup corr
        delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4)
        delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1)
        up_flow0_4 = up_flow0_4 + delta_flow0_4
        up_flow1_4 = up_flow1_4 + delta_flow1_4
        ft_3_ = ft_3_ + delta_ft_3_

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        corr_3, flow_3 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_3, up_flow1_3,
                                                 embt, downsample=2)

        # residue update with lookup corr
        delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3)
        delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1)
        up_flow0_3 = up_flow0_3 + delta_flow0_3
        up_flow1_3 = up_flow1_3 + delta_flow1_3
        ft_2_ = ft_2_ + delta_ft_2_

        # residue update with lookup corr (hr)
        # Second refinement at the level-3 flow's own resolution (in-place
        # += accumulation; channels 0:2 are flow0, 2:4 are flow1).
        corr_3 = resize(corr_3, scale_factor=2.0)
        up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1)
        delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3)
        ft_2_ += delta_ft_2_
        up_flow0_3 += delta_up_flow_3[:, 0:2]
        up_flow1_3 += delta_up_flow_3[:, 2:4]

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        corr_2, flow_2 = self._corr_scale_lookup(corr_fn,
                                                 coord, up_flow0_2, up_flow1_2,
                                                 embt, downsample=4)

        # residue update with lookup corr
        delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2)
        delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1)
        up_flow0_2 = up_flow0_2 + delta_flow0_2
        up_flow1_2 = up_flow1_2 + delta_flow1_2
        ft_1_ = ft_1_ + delta_ft_1_

        # residue update with lookup corr (hr)
        # Second refinement at the level-2 flow's own resolution.
        corr_2 = resize(corr_2, scale_factor=4.0)
        up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1)
        delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2)
        ft_1_ += delta_ft_1_
        up_flow0_2 += delta_up_flow_2[:, 0:2]
        up_flow1_2 += delta_up_flow_2[:, 2:4]

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Undo the input rescale; flow magnitudes scale with resolution.
            up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor)
            mask = resize(mask, scale_factor=(1.0/scale_factor))
            img_res = resize(img_res, scale_factor=(1.0/scale_factor))

        # Merge multiple predictions
        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return { 'imgt_pred': imgt_pred, }
        else:
            up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, h, w)
            up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, h, w)
            return {
                'imgt_pred': imgt_pred,
                'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
                'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
                'ft_pred': [ft_1_, ft_2_, ft_3_],
            }
|
vfi_models/cain/__init__.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
import pathlib
|
| 4 |
+
from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
|
| 5 |
+
import typing
|
| 6 |
+
from comfy.model_management import get_torch_device
|
| 7 |
+
|
| 8 |
+
MODEL_TYPE = pathlib.Path(__file__).parent.name
|
| 9 |
+
CKPT_NAMES = ["pretrained_cain.pth"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CAIN_VFI:
    """ComfyUI node wrapping the CAIN frame-interpolation model."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: str,
        frames: torch.Tensor,
        clear_cache_after_n_frames: int = 1,
        multiplier: int = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames` to `multiplier`x the frame count with CAIN.

        Returns a one-tuple holding the interpolated IMAGE batch.
        """
        from .cain_arch import CAIN
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        # map_location="cpu": checkpoints saved from a CUDA session would
        # otherwise fail to load on CPU-only hosts; the model is moved to
        # the torch device below anyway.
        sd = torch.load(model_path, map_location="cpu")["state_dict"]
        # Strip DataParallel's 'module.' prefix from the checkpoint keys.
        sd = {key.replace('module.', ''): value for key, value in sd.items()}

        # Local variable is sufficient; the previous `global` needlessly
        # leaked the model into module scope.
        interpolation_model = CAIN(depth=3)
        interpolation_model.load_state_dict(sd)
        interpolation_model.eval().to(get_torch_device())
        del sd

        frames = preprocess_frames(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model):
            # CAIN does some direct modifications to input frame tensors
            # (in-place mean subtraction), so we need to clone them.
            return model(frame_0.detach().clone(), frame_1.detach().clone())[0]

        args = [interpolation_model]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, use_timestep=False, dtype=torch.float32)
        )
        return (out,)
|
vfi_models/cain/cain_arch.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .common import *
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Encoder(nn.Module):
    """Shuffle both frames into the channel dimension and fuse them.

    A single fractional PixelShuffle(1 / 2**depth) trades spatial
    resolution for channels (equivalent to chaining `depth` half-shuffles),
    after which a channel-attention fusion block merges the two frames.
    """

    def __init__(self, in_channels=3, depth=3):
        super(Encoder, self).__init__()

        # One fractional shuffle replaces a stack of depth PixelShuffle(0.5)
        # stages: H,W shrink by 2**depth while channels grow by 4**depth.
        self.shuffler = PixelShuffle(1 / 2**depth)

        # FF_RCAN or FF_Resblocks
        self.interpolate = Interpolation(
            5, 12, in_channels * (4**depth), act=nn.LeakyReLU(0.2, True)
        )

    def forward(self, x1, x2):
        """Shuffle-spread each input, then return the fused feature map."""
        return self.interpolate(self.shuffler(x1), self.shuffler(x2))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Decoder(nn.Module):
    """Invert the encoder's channel-to-space shuffle.

    A single PixelShuffle(2**depth) rearranges the fused channel-heavy
    features back to image resolution.
    """

    def __init__(self, depth=3):
        super(Decoder, self).__init__()
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        """Upscale `feats` back to full image resolution."""
        return self.shuffler(feats)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class CAIN(nn.Module):
    """CAIN frame interpolator: PixelShuffle encoder/decoder around a
    channel-attention fusion block.

    forward(x1, x2) returns (interpolated_frame, fused_features).
    """

    def __init__(self, depth=3):
        super(CAIN, self).__init__()

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        # NOTE: sub_mean subtracts the per-image mean IN-PLACE, mutating the
        # caller's tensors (see common.sub_mean) — callers must clone first
        # if they still need the originals.
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # At inference, reflection-pad H/W up to multiples of 128 so the
        # depth-level shuffle chain divides evenly; cropped back below.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Restore brightness: add back the average of the two input means.
        mi = (m1 + m2) / 2
        out += mi

        return out, feats
|
vfi_models/cain/cain_encdec_arch.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .common import *
|
| 8 |
+
from comfy.model_management import get_torch_device
|
| 9 |
+
|
| 10 |
+
class Encoder(nn.Module):
    """Strided-conv feature extractor followed by channel-attention fusion.

    Three stride-2 convolutions downsample each frame by 8x while widening
    from `nf_start` to `nf_start * 6` channels; the two feature maps are
    then merged by the Interpolation block.
    """

    def __init__(self, in_channels=3, depth=3, nf_start=32, norm=False):
        super(Encoder, self).__init__()
        self.device = get_torch_device()

        width = nf_start
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.body = nn.Sequential(
            ConvNorm(in_channels, width * 1, 7, stride=1, norm=norm),
            act,
            ConvNorm(width * 1, width * 2, 5, stride=2, norm=norm),
            act,
            ConvNorm(width * 2, width * 4, 5, stride=2, norm=norm),
            act,
            ConvNorm(width * 4, width * 6, 5, stride=2, norm=norm)
        )

        self.interpolate = Interpolation(5, 12, width * 6, reduction=16, act=act)

    def forward(self, x1, x2):
        """Extract features from both frames, then return the fused map."""
        return self.interpolate(self.body(x1), self.body(x2))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Decoder(nn.Module):
    """Three upsample+residual stages that bring fused features back to
    image resolution, finished by a 7x7 projection to `out_channels`."""

    def __init__(self, in_channels=192, out_channels=3, depth=3, norm=False, up_mode='shuffle'):
        super(Decoder, self).__init__()
        self.device = get_torch_device()

        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Channel schedule, e.g. 192 -> 128 -> 64 -> 32 for in_channels=192.
        widths = [in_channels, (in_channels * 2) // 3, in_channels // 3, in_channels // 6]

        stages = []
        for cur, nxt in zip(widths, widths[1:]):
            stages.append(UpConvNorm(cur, nxt, mode=up_mode, norm=norm))
            stages.append(ResBlock(nxt, nxt, norm=norm, act=act))
        stages.append(conv7x7(widths[-1], out_channels))
        self.body = nn.Sequential(*stages)

    def forward(self, feats):
        """Decode fused features into an image-space tensor."""
        return self.body(feats)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class CAIN_EncDec(nn.Module):
    """CAIN variant with a strided-conv encoder and learned upsampling
    decoder (instead of pure PixelShuffle resampling).

    forward(x1, x2) returns (interpolated_frame, fused_features).
    """

    def __init__(self, depth=3, n_resblocks=3, start_filts=32, up_mode='shuffle'):
        # NOTE(review): `n_resblocks` is accepted but never used here.
        super(CAIN_EncDec, self).__init__()
        self.depth = depth

        self.encoder = Encoder(in_channels=3, depth=depth, norm=False)
        self.decoder = Decoder(in_channels=start_filts*6, depth=depth, norm=False, up_mode=up_mode)

    def forward(self, x1, x2):
        # NOTE: sub_mean mutates x1/x2 IN-PLACE (see common.sub_mean).
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # Reflection-pad H/W to multiples of 128 at inference; cropped back
        # after decoding.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Add back the average input brightness.
        mi = (m1 + m2)/2
        out += mi

        return out, feats
|
vfi_models/cain/cain_noca_arch.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .common import *
|
| 8 |
+
from comfy.model_management import get_torch_device
|
| 9 |
+
|
| 10 |
+
class Encoder(nn.Module):
    """Shuffle-spread encoder whose fusion uses plain residual blocks
    (Interpolation_res) rather than channel attention."""

    def __init__(self, in_channels=3, depth=3):
        super(Encoder, self).__init__()
        self.device = get_torch_device()

        # One fractional shuffle is equivalent to chaining `depth`
        # PixelShuffle(1/2) stages.
        self.shuffler = PixelShuffle(1/2**depth)
        self.interpolate = Interpolation_res(5, 12, in_channels * (4**depth))

    def forward(self, x1, x2):
        """Shuffle-spread each frame and return the fused feature map."""
        return self.interpolate(self.shuffler(x1), self.shuffler(x2))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Decoder(nn.Module):
    """Upscale fused features back to image space with one big
    PixelShuffle (equivalent to `depth` chained 2x shuffles)."""

    def __init__(self, depth=3):
        super(Decoder, self).__init__()
        self.device = get_torch_device()
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        """Return `feats` rearranged to full image resolution."""
        return self.shuffler(feats)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class CAIN_NoCA(nn.Module):
    """CAIN ablation without channel attention: same shuffle
    encoder/decoder, but fusion via plain residual blocks.

    forward(x1, x2) returns (interpolated_frame, fused_features).
    """

    def __init__(self, depth=3):
        super(CAIN_NoCA, self).__init__()
        self.depth = depth

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        # NOTE: sub_mean mutates x1/x2 IN-PLACE (see common.sub_mean).
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # Reflection-pad H/W to multiples of 128 at inference; cropped back
        # after decoding.
        if not self.training:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if not self.training:
            out = paddingOutput(out)

        # Add back the average input brightness.
        mi = (m1 + m2) / 2
        out += mi

        return out, feats
|
vfi_models/cain/common.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
def sub_mean(x):
    """Subtract the per-sample, per-channel spatial mean from `x` IN-PLACE.

    Returns (x, mean) where `mean` keeps its singleton H/W dims so it can
    be broadcast back later. WARNING: the input tensor is mutated.
    """
    mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
    x.sub_(mean)
    return x, mean
|
| 11 |
+
|
| 12 |
+
def InOutPaddings(x):
    """Build a reflection pad that rounds H and W up to multiples of 128,
    plus the inverse (negative-padding) crop.

    Returns (paddingInput, paddingOutput); applying one after the other
    restores the original spatial size.
    """
    h, w = x.size(2), x.size(3)
    # Amount needed to reach the next multiple of 128 (0 if already there).
    pad_w = (-w) % 128
    pad_h = (-h) % 128
    left, top = pad_w // 2, pad_h // 2
    paddingInput = nn.ReflectionPad2d(padding=[left, pad_w - left,
                                               top, pad_h - top])
    paddingOutput = nn.ReflectionPad2d(padding=[-left, left - pad_w,
                                                -top, top - pad_h])
    return paddingInput, paddingOutput
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ConvNorm(nn.Module):
    """Reflection-padded convolution with optional normalization.

    `norm` may be False (no norm), 'IN' (InstanceNorm2d) or 'BN'
    (BatchNorm2d). Padding keeps the spatial size for stride=1.
    """

    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super(ConvNorm, self).__init__()

        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(in_feat, out_feat, stride=stride,
                              kernel_size=kernel_size, bias=True)

        if norm == 'IN':
            self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
        elif norm == 'BN':
            self.norm = nn.BatchNorm2d(out_feat)
        else:
            # Falsy (e.g. False) disables normalization in forward().
            self.norm = norm

    def forward(self, x):
        out = self.conv(self.reflection_pad(x))
        return self.norm(out) if self.norm else out
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class UpConvNorm(nn.Module):
    """2x spatial upsampling block with three interchangeable strategies:
    'transpose' (learned deconv), 'shuffle' (conv + PixelShuffle), or the
    default bilinear upsample followed by a 1x1 conv.
    """

    def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
        super(UpConvNorm, self).__init__()

        if mode == 'transpose':
            self.upconv = nn.ConvTranspose2d(in_channels, out_channels,
                                             kernel_size=4, stride=2, padding=1)
        elif mode == 'shuffle':
            self.upconv = nn.Sequential(
                ConvNorm(in_channels, 4*out_channels, kernel_size=3, stride=1, norm=norm),
                PixelShuffle(2))
        else:
            # out_channels is always going to be the same as in_channels
            self.upconv = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))

    def forward(self, x):
        """Return `x` upsampled 2x along both spatial dims."""
        return self.upconv(x)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class meanShift(nn.Module):
    """Frozen 1x1 identity convolution that adds a constant channel bias.

    The bias is ``rgbMean[c] * rgbRange * sign`` per channel, so with
    sign=-1 the layer subtracts the dataset mean and with sign=+1 it adds
    it back. Supports 1 channel (grayscale), 3 (RGB) or any other value,
    which is treated as a 6-channel stacked RGB pair.
    """

    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super(meanShift, self).__init__()

        scale = rgbRange * float(sign)
        if nChannel == 1:
            n = 1
            bias = [rgbMean[0] * scale]
        else:
            rgb = [rgbMean[c] * scale for c in range(3)]
            # any nChannel other than 1 or 3 is treated as two stacked RGB images
            n = 3 if nChannel == 3 else 6
            bias = rgb if nChannel == 3 else rgb + rgb

        self.shifter = nn.Conv2d(n, n, kernel_size=1, stride=1, padding=0)
        self.shifter.weight.data = torch.eye(n).view(n, n, 1, 1)
        self.shifter.bias.data = torch.Tensor(bias)

        # Freeze: this layer is a fixed affine shift, never trained.
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        return self.shifter(x)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
""" CONV - (BN) - RELU - CONV - (BN) """
|
| 106 |
+
class ResBlock(nn.Module):
    """Two-conv residual block: out = trunk(x) + skip(x).

    ``reduction`` is accepted only for signature parity with RCAB and is
    unused. With downscale=True the trunk's first conv uses stride 2 and
    the skip path is matched by a strided 1x1 projection.
    """

    def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True,  # 'reduction' is just for placeholder
                 norm=False, act=nn.ReLU(True), downscale=False):
        super(ResBlock, self).__init__()

        first_stride = 2 if downscale else 1
        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=first_stride),
            act,
            ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1),
        )

        # 1x1 strided projection so the skip connection matches the trunk shape.
        self.downscale = (nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2)
                          if downscale else None)

    def forward(self, x):
        out = self.body(x)
        skip = x if self.downscale is None else self.downscale(x)
        return out + skip
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
## Channel Attention (CA) Layer
|
| 132 |
+
class CALayer(nn.Module):
    """Channel Attention: squeeze (global average pool) then excite.

    forward() returns both the attended features ``x * y`` and the
    per-channel attention map ``y`` (shape N, C, 1, 1).
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # squeeze each channel's feature map to a single value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # bottleneck MLP (as 1x1 convs) producing per-channel weights in (0, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attn = self.conv_du(self.avg_pool(x))
        return x * attn, attn
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Residual Channel Attention Block (RCAB)
|
| 152 |
+
class RCAB(nn.Module):
    """Residual Channel Attention Block: ConvNorm-act-ConvNorm-CALayer + skip.

    CALayer sits last in the Sequential trunk and returns a
    (features, attention) tuple, which is unpacked in forward(). If
    ``return_ca`` is set, the attention map is returned alongside the output.
    """

    def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
                 norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
        super(RCAB, self).__init__()

        self.body = nn.Sequential(
            ConvNorm(in_feat, out_feat, kernel_size, stride=2 if downscale else 1, norm=norm),
            act,
            ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
            CALayer(out_feat, reduction),
        )
        self.downscale = downscale
        if downscale:
            # strided 3x3 so the skip path matches the downscaled trunk
            self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
        self.return_ca = return_ca

    def forward(self, x):
        out, ca = self.body(x)
        skip = self.downConv(x) if self.downscale else x
        out = out + skip
        if self.return_ca:
            return out, ca
        return out
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
## Residual Group (RG)
|
| 182 |
+
class ResidualGroup(nn.Module):
    """Stack of ``n_resblocks`` blocks plus a trailing ConvNorm, with a
    long skip connection around the whole group.

    ``Block`` is the block class to instantiate (e.g. RCAB or ResBlock).
    """

    def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
        super(ResidualGroup, self).__init__()

        layers = [Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)
                  for _ in range(n_resblocks)]
        layers.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x) + x
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def pixel_shuffle(input, scale_factor):
    """Depth-to-space (scale_factor >= 1) or space-to-depth (scale_factor < 1).

    For integer factors this matches torch.pixel_shuffle /
    torch.pixel_unshuffle. Channel count must be divisible by
    scale_factor**2 (upscale) and spatial dims by 1/scale_factor (downscale).
    """
    b, c, h, w = input.size()

    oc = int(int(c / scale_factor) / scale_factor)
    oh = int(h * scale_factor)
    ow = int(w * scale_factor)

    if scale_factor >= 1:
        r = scale_factor
        grouped = input.contiguous().view(b, oc, r, r, h, w)
        shuffled = grouped.permute(0, 1, 4, 2, 5, 3).contiguous()
    else:
        r = int(1 / scale_factor)
        grouped = input.contiguous().view(b, c, oh, r, ow, r)
        shuffled = grouped.permute(0, 1, 3, 5, 2, 4).contiguous()

    return shuffled.view(b, oc, oh, ow)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class PixelShuffle(nn.Module):
    """Module wrapper around pixel_shuffle; supports fractional factors
    (< 1 means space-to-depth), unlike nn.PixelShuffle."""

    def __init__(self, scale_factor):
        super(PixelShuffle, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return pixel_shuffle(x, self.scale_factor)

    def extra_repr(self):
        return f'scale_factor={self.scale_factor}'
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def conv(in_channels, out_channels, kernel_size,
         stride=1, bias=True, groups=1):
    """'Same'-padded 2D convolution (padding = kernel_size // 2).

    Bug fix: the original hard-coded ``stride=1`` in the nn.Conv2d call,
    silently ignoring the ``stride`` argument; it is now forwarded.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        stride=stride,
        bias=bias,
        groups=groups)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
    """1x1 (pointwise) convolution with no padding."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, bias=bias, groups=groups)
|
| 246 |
+
|
| 247 |
+
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """3x3 convolution; default padding=1 preserves spatial size at stride 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding, bias=bias, groups=groups)
|
| 257 |
+
|
| 258 |
+
def conv5x5(in_channels, out_channels, stride=1,
            padding=2, bias=True, groups=1):
    """5x5 convolution; default padding=2 preserves spatial size at stride 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=5,
                     stride=stride, padding=padding, bias=bias, groups=groups)
|
| 268 |
+
|
| 269 |
+
def conv7x7(in_channels, out_channels, stride=1,
            padding=3, bias=True, groups=1):
    """7x7 convolution; default padding=3 preserves spatial size at stride 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=7,
                     stride=stride, padding=padding, bias=bias, groups=groups)
|
| 279 |
+
|
| 280 |
+
def upconv2x2(in_channels, out_channels, mode='shuffle'):
    """Build a 2x spatial upsampler.

    - 'transpose': 4x4 stride-2 transposed convolution
    - 'shuffle':   3x3 conv to 4*out_channels + pixel shuffle (default)
    - anything else: bilinear upsample + 1x1 conv
      (out_channels is expected to equal in_channels in that case)
    """
    if mode == 'transpose':
        return nn.ConvTranspose2d(in_channels, out_channels,
                                  kernel_size=4, stride=2, padding=1)
    if mode == 'shuffle':
        return nn.Sequential(
            conv3x3(in_channels, 4 * out_channels),
            PixelShuffle(2))
    return nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
        conv1x1(in_channels, out_channels))
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class Interpolation(nn.Module):
    """CAIN interpolation trunk built from RCAB residual groups.

    Concatenates the two input feature maps, reduces back to n_feats with
    a head conv, runs ``n_resgroups`` channel-attention residual groups
    with a long skip connection, and finishes with a tail conv.
    """

    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation, self).__init__()

        # head: reduce the concatenated pair (2 * n_feats) back to n_feats
        self.headConv = conv3x3(n_feats * 2, n_feats)

        groups = [
            ResidualGroup(
                RCAB,
                n_resblocks=n_resblocks,
                n_feat=n_feats,
                kernel_size=3,
                reduction=reduction,
                act=act,
                norm=norm)
            for _ in range(n_resgroups)
        ]
        self.body = nn.Sequential(*groups)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        fused = self.headConv(torch.cat([x0, x1], dim=1))
        res = self.body(fused) + fused  # long skip over all groups
        return self.tailConv(res)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class Interpolation_res(nn.Module):
    """Variant of Interpolation using plain ResBlocks (no channel attention).

    Same head / residual-groups / tail layout; the groups are applied one
    by one and wrapped in a long skip connection.
    """

    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 act=nn.LeakyReLU(0.2, True), norm=False):
        super(Interpolation_res, self).__init__()

        # head: reduce the concatenated pair (2 * n_feats) back to n_feats
        self.headConv = conv3x3(n_feats * 2, n_feats)

        groups = [
            ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats,
                          kernel_size=3, reduction=0, act=act, norm=norm)
            for _ in range(n_resgroups)
        ]
        self.body = nn.Sequential(*groups)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        fused = self.headConv(torch.cat([x0, x1], dim=1))

        res = fused
        for group in self.body:
            res = group(res)
        res = res + fused  # long skip over all groups

        return self.tailConv(res)
|
vfi_models/eisai/__init__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
|
| 3 |
+
import typing
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from comfy.model_management import soft_empty_cache, get_torch_device
|
| 7 |
+
|
| 8 |
+
MODEL_TYPE = pathlib.Path(__file__).parent.name
|
| 9 |
+
MODEL_FILE_NAMES = {
|
| 10 |
+
"ssl": "eisai_ssl.pt",
|
| 11 |
+
"dtm": "eisai_dtm.pt",
|
| 12 |
+
"raft": "eisai_anime_interp_full.ckpt"
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
class EISAI(nn.Module):
    # Wrapper bundling the three EISAI sub-networks (RAFT optical flow,
    # SoftsplatLite warper, DTM refiner), each loaded from a released
    # checkpoint and moved to the ComfyUI torch device in eval mode.
    def __init__(self, model_file_names) -> None:
        # Imported lazily so the heavy arch module is only loaded when the
        # node is actually instantiated.
        from .eisai_arch import SoftsplatLite, DTM, RAFT
        super(EISAI, self).__init__()
        # RAFT takes its checkpoint path directly in the constructor.
        self.raft = RAFT(load_file_from_github_release(MODEL_TYPE, model_file_names["raft"]))
        self.raft.to(get_torch_device()).eval()

        self.ssl = SoftsplatLite()
        self.ssl.load_state_dict(torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["ssl"])))
        self.ssl.to(get_torch_device()).eval()

        self.dtm = DTM()
        self.dtm.load_state_dict(torch.load(load_file_from_github_release(MODEL_TYPE, model_file_names["dtm"])))
        self.dtm.to(get_torch_device()).eval()

    def forward(self, img0, img1, t):
        # Interpolate a frame at timestep t in (0, 1) between img0 and img1.
        with torch.no_grad():
            # Bidirectional optical flow; the second return value of RAFT
            # is discarded here.
            flow0, _ = self.raft(img0, img1)
            flow1, _ = self.raft(img1, img0)
            x = {
                "images": torch.stack([img0, img1], dim=1),
                "flows": torch.stack([flow0, flow1], dim=1),
            }
            # NOTE: `_` is deliberately reused — SSL's auxiliary output
            # (return_more=True) is forwarded to DTM as its third argument.
            out_ssl, _ = self.ssl(x, t=t, return_more=True)
            out_dtm, _ = self.dtm(x, out_ssl, _, return_more=False)
            # DTM output carries extra channels; keep only RGB.
            return out_dtm[:, :3]
|
| 41 |
+
|
| 42 |
+
class EISAI_VFI:
    # ComfyUI node exposing EISAI video frame interpolation.
    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-input schema: (type, options) pairs per socket.
        return {
            "required": {
                "ckpt_name": (["eisai"], ),
                "frames": ("IMAGE", ),
                # Frames processed between cache flushes (VRAM control).
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                # Output frame-rate multiplier (2 = one in-between frame).
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        # Build the model fresh per invocation; weights are cached on disk
        # by load_file_from_github_release inside EISAI.__init__.
        interpolation_model = EISAI(MODEL_FILE_NAMES)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        # Callback invoked by generic_frame_loop for each in-between frame.
        def return_middle_frame(frame_0, frame_1, timestep, model):
            return model(frame_0, frame_1, t=timestep)

        scale = 1

        # Extra positional args forwarded to return_middle_frame; only the
        # model is consumed there (scale is part of the shared loop contract).
        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        # ComfyUI expects outputs as a tuple matching RETURN_TYPES.
        return (out,)
|
vfi_models/eisai/eisai_arch.py
ADDED
|
@@ -0,0 +1,2586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_scripts/interpolate.py
|
| 3 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/models/ssldtm.py
|
| 4 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/util_v0.py
|
| 5 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/twodee_v0.py
|
| 6 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/pytorch_v0.py
|
| 7 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/distance_transform_v0.py
|
| 8 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/sketchers_v1.py
|
| 9 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/interpolator_v0.py
|
| 10 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/gridnet_v1.py
|
| 11 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/flow_v0.py
|
| 12 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_util/softsplat_v0.py
|
| 13 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/rfr_new.py
|
| 14 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/extractor.py
|
| 15 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/update.py
|
| 16 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/corr.py
|
| 17 |
+
https://github.com/ShuhongChen/eisai-anime-interpolator/blob/master/_train/frame_interpolation/helpers/raft_v1/utils.py
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import copy
|
| 21 |
+
import cv2
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import torchvision.transforms.functional as F
|
| 24 |
+
import gc
|
| 25 |
+
from PIL import Image, ImageFile, ImageFont, ImageDraw
|
| 26 |
+
import inspect
|
| 27 |
+
from scipy import interpolate
|
| 28 |
+
import kornia
|
| 29 |
+
import math
|
| 30 |
+
from argparse import Namespace
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import numpy as np
|
| 33 |
+
import os
|
| 34 |
+
from functools import partial
|
| 35 |
+
import pathlib
|
| 36 |
+
import PIL
|
| 37 |
+
import re
|
| 38 |
+
import requests
|
| 39 |
+
from scipy.spatial.transform import Rotation
|
| 40 |
+
import scipy
|
| 41 |
+
import shutil
|
| 42 |
+
import torchvision.transforms as T
|
| 43 |
+
import time
|
| 44 |
+
import torch
|
| 45 |
+
import torchvision as tv
|
| 46 |
+
import zlib
|
| 47 |
+
import numpy as np
|
| 48 |
+
import torch
|
| 49 |
+
import torch.nn as nn
|
| 50 |
+
import torch.nn.functional as F
|
| 51 |
+
from tqdm.auto import tqdm as std_tqdm
|
| 52 |
+
from tqdm.auto import trange as std_trange
|
| 53 |
+
from vfi_models.ops import FunctionSoftsplat, batch_edt
|
| 54 |
+
from comfy.model_management import get_torch_device
|
| 55 |
+
|
| 56 |
+
# Device chosen by ComfyUI's model management (CUDA / CPU / MPS as available).
device = get_torch_device()
# Alias used throughout for mixed-precision autocast contexts.
autocast = torch.autocast
# Progress bars that adapt their width to the terminal.
tqdm = partial(std_tqdm, dynamic_ncols=True)
trange = partial(std_trange, dynamic_ncols=True)
# Let PIL load partially written / truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def pixel_ij(x, rounding=True):
    """Coerce *x* into an (i, j) pixel pair, rounding each component.

    A scalar is duplicated into both components; numpy arrays are first
    converted to plain Python values.  Each component is passed through
    ``pixel_rounder`` with *rounding* as the mode.
    """
    if isinstance(x, np.ndarray):
        x = x.tolist()
    pair = x if isinstance(x, (tuple, list)) else (x, x)
    return tuple(pixel_rounder(component, rounding) for component in pair)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def rescale_dry(x, factor):
    """Compute the (h * factor, w * factor) size a rescale would produce,
    without touching any pixels ("dry run").

    *x* may be a (..., h, w) tuple/list, or anything the project's ``I``
    image wrapper accepts (its ``.size`` is used then — NOTE(review): PIL
    sizes are (w, h); confirm ``I.size`` really yields (h, w) here).
    """
    if isinstance(x, (tuple, list)):
        h, w = x[-2:]
    else:
        h, w = I(x).size
    return (h * factor, w * factor)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def pixel_rounder(n, mode):
    """Round a pixel coordinate according to *mode*.

    ``"ceil"`` / ``"floor"`` round in the respective direction; ``True`` or
    ``"round"`` use Python's ``round`` (banker's rounding); anything else
    returns *n* unchanged.
    """
    if mode == "ceil":
        return math.ceil(n)
    if mode == "floor":
        return math.floor(n)
    if mode == "round" or mode == True:  # noqa: E712 — 1/True both accepted, as before
        return round(n)
    return n
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def diam(x):
    """Diagonal length (in pixels) of *x*'s spatial extent.

    Accepts a (..., h, w) tuple/list, an ``I`` image wrapper (project
    type), or any tensor/array whose last two dims are (h, w).
    """
    if isinstance(x, (tuple, list)):
        h, w = x[-2:]
    elif isinstance(x, I):
        h, w = x.size
    else:
        h, w = x.shape[-2:]
    return np.sqrt(h * h + w * w)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def pixel_logit(x, pixel_margin=1):
    """Logit of an image tensor in [0, 1], first squeezed away from the
    endpoints (onto [margin/255, (255-margin)/255]) so the log stays finite."""
    span = 255 - 2 * pixel_margin
    squeezed = (x * span + pixel_margin) / 255
    return torch.log(squeezed / (1 - squeezed))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class InputPadder:
    """Pads NCHW tensors so height and width become divisible by 8.

    Width padding is split between left and right; all height padding is
    appended at the bottom.  ``unpad`` crops a padded tensor back.
    """

    def __init__(self, dims):
        self.ht, self.wd = dims[-2:]
        pad_ht = -self.ht % 8
        pad_wd = -self.wd % 8
        left = pad_wd // 2
        # [left, right, top, bottom] as expected by F.pad for 2D.
        self._pad = [left, pad_wd - left, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every tensor in *inputs* with the stored padding."""
        return [F.pad(t, self._pad, mode="replicate") for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original height/width."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def forward_interpolate(flow):
    """Forward-splat a (2, H, W) flow field onto the regular pixel grid.

    Each source pixel's flow vector is carried to its landing point
    (x + dx, y + dy); the scattered samples are then re-interpolated back
    onto the grid with cubic ``griddata`` (0 outside the convex hull).
    Returns a float CPU tensor of the same shape.
    """
    np_flow = flow.detach().cpu().numpy()
    dx, dy = np_flow[0], np_flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # Keep only samples that land strictly inside the frame.
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata((x1, y1), dx, (x0, y0), method="cubic", fill_value=0)
    flow_y = interpolate.griddata((x1, y1), dy, (x0, y0), method="cubic", fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
    """``grid_sample`` wrapper taking absolute pixel coordinates.

    *coords* is (N, H', W', 2) in pixel units (x, y last); they are mapped
    to the [-1, 1] range ``grid_sample`` expects (align_corners=True).
    With *mask* set, also return a float mask marking in-bounds samples.
    (*mode* is kept for interface compatibility; sampling is bilinear.)
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if not mask:
        return sampled
    in_bounds = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, in_bounds.float()
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def coords_grid(batch, ht, wd):
    """(batch, 2, ht, wd) tensor of absolute pixel coordinates; channel 0
    holds x and channel 1 holds y."""
    ys, xs = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    grid = torch.stack((xs, ys), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def upflow8(flow, mode="bilinear"):
    """Upsample a flow field 8x spatially and scale its vectors by 8."""
    target = (flow.shape[2] * 8, flow.shape[3] * 8)
    return 8 * F.interpolate(flow, size=target, mode=mode, align_corners=True)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class CorrBlock:
    """All-pairs correlation volume with a pooled pyramid, as in RAFT.

    The full fmap1 x fmap2 correlation is computed once at construction;
    ``__call__`` then looks up a (2r+1)^2 neighborhood around the given
    coordinates at every pyramid level and concatenates the results.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        volume = CorrBlock.corr(fmap1, fmap2)
        batch, h1, w1, dim, h2, w2 = volume.shape
        volume = volume.reshape(batch * h1 * w1, dim, h2, w2)

        # Level 0 is full resolution; each further level halves fmap2's grid.
        self.corr_pyramid = [volume]
        for _ in range(self.num_levels - 1):
            volume = F.avg_pool2d(volume, 2, stride=2)
            self.corr_pyramid.append(volume)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        offsets = torch.linspace(-r, r, 2 * r + 1)
        lookups = []
        for level, corr in enumerate(self.corr_pyramid):
            delta = torch.stack(torch.meshgrid(offsets, offsets), dim=-1).to(coords.device)
            centroid = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**level
            window = centroid + delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            sampled = bilinear_sampler(corr, window)
            lookups.append(sampled.view(batch, h1, w1, -1))

        out = torch.cat(lookups, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Normalized dot-product correlation: (B, H, W, 1, H, W)."""
        batch, dim, ht, wd = fmap1.shape
        flat1 = fmap1.view(batch, dim, ht * wd)
        flat2 = fmap2.view(batch, dim, ht * wd)
        volume = torch.matmul(flat1.transpose(1, 2), flat2)
        volume = volume.view(batch, ht, wd, 1, ht, wd)
        return volume / torch.sqrt(torch.tensor(dim).float())
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class FlowHead(nn.Module):
    """Two 3x3 conv layers mapping a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class ConvGRU(nn.Module):
    """Single convolutional GRU cell operating on NCHW feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(combined, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(combined, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(hx))
        reset = torch.sigmoid(self.convr(hx))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        return (1 - update) * h + update * candidate
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class SepConvGRU(nn.Module):
    """GRU cell whose gates are separated into a horizontal (1x5) update
    followed by a vertical (5x1) update."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def _step(self, h, x, convz, convr, convq):
        """One GRU update using the given gate convolutions."""
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        h = self._step(h, x, self.convz1, self.convr1, self.convq1)  # horizontal pass
        h = self._step(h, x, self.convz2, self.convr2, self.convq2)  # vertical pass
        return h
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class SmallMotionEncoder(nn.Module):
    """Compact encoder fusing correlation features and the current flow into
    an 82-channel motion feature (80 learned channels + the 2 raw flow
    channels appended at the end)."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # One correlation plane per pyramid level per window position.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class BasicMotionEncoder(nn.Module):
    """Encoder fusing correlation features and the current flow into a
    128-channel motion feature (126 learned channels + the 2 raw flow
    channels appended at the end)."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # One correlation plane per pyramid level per window position.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class SmallUpdateBlock(nn.Module):
    """Small recurrent update step: encode motion, advance the ConvGRU, and
    predict a flow delta.

    Returns (net, None, delta_flow); the ``None`` keeps the return shape
    aligned with BasicUpdateBlock, which also yields an upsampling mask.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        return net, None, self.flow_head(net)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class BasicUpdateBlock(nn.Module):
    """Recurrent update step used by RFR: encodes motion features, advances
    the separable ConvGRU, and predicts both a flow delta and a convex
    upsampling mask (64 positions x 9 neighbors)."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0),
        )

    def forward(self, net, inp, corr, flow, upsample=True):
        motion = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion], dim=1))
        delta_flow = self.flow_head(net)
        # The 0.25 factor balances the mask gradients against the flow head's.
        up_mask = 0.25 * self.mask(net)
        return net, up_mask, delta_flow
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class ResidualBlock(nn.Module):
    """3x3 residual block with a configurable normalization layer and an
    optional strided 1x1 downsample on the skip path.

    *norm_fn* must be one of "group", "batch", "instance", "none".
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm():
            # Fresh normalization layer of the configured kind.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == "batch":
                return nn.BatchNorm2d(planes)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(planes)
            return nn.Sequential()  # "none"

        self.norm1 = make_norm()
        self.norm2 = make_norm()
        if stride != 1:
            self.norm3 = make_norm()
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )
        else:
            self.downsample = None

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        skip = x if self.downsample is None else self.downsample(x)
        return self.relu(skip + y)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (hidden width planes//4)
    with configurable normalization and an optional strided skip path.

    *norm_fn* must be one of "group", "batch", "instance", "none".
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
        )
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm(channels):
            # Fresh normalization layer of the configured kind for `channels`.
            if norm_fn == "group":
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            if norm_fn == "batch":
                return nn.BatchNorm2d(channels)
            if norm_fn == "instance":
                return nn.InstanceNorm2d(channels)
            return nn.Sequential()  # "none"

        self.norm1 = make_norm(planes // 4)
        self.norm2 = make_norm(planes // 4)
        self.norm3 = make_norm(planes)
        if stride != 1:
            self.norm4 = make_norm(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
            )
        else:
            self.downsample = None

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))
        skip = x if self.downsample is None else self.downsample(x)
        return self.relu(skip + y)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
class BasicEncoder(nn.Module):
    """RAFT-style feature encoder: a stride-2 7x7 stem, three 2-block
    residual stages (64 -> 96 -> 128 channels; overall stride 8), and a
    1x1 projection to ``output_dim`` channels.

    When given a (img1, img2) tuple/list, both are concatenated along the
    batch dimension, encoded together, and split back into two tensors.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; the per-stage norms are built inside ResidualBlock.
        # NOTE(review): an unrecognized norm_fn leaves self.norm1 unset and
        # fails later with AttributeError.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and then advances self.in_planes, so these three
        # calls must stay in this order.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; identity (weight=1, bias=0) for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two ResidualBlocks: the first may downsample, the second never does."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            # Assumes exactly two same-batch-size tensors (see split below).
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class BasicEncoder1(nn.Module):
    """Variant of BasicEncoder whose stem takes 2-channel input (e.g. a flow
    field) instead of a 3-channel image; otherwise identical: stride-2 7x7
    stem, three residual stages (64 -> 96 -> 128), 1x1 projection.

    When given a tuple/list of two tensors, both are concatenated along the
    batch dimension, encoded together, and split back into two tensors.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(BasicEncoder1, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; per-stage norms are built inside ResidualBlock.
        # NOTE(review): an unrecognized norm_fn leaves self.norm1 unset and
        # fails later with AttributeError.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        # 2 input channels — the only difference from BasicEncoder.
        self.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and then advances self.in_planes; keep the order.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; identity (weight=1, bias=0) for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two ResidualBlocks: the first may downsample, the second never does."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            # Assumes exactly two same-batch-size tensors (see split below).
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
class SmallEncoder(nn.Module):
    """Lightweight encoder built from BottleneckBlocks: stride-2 7x7 stem
    (32 channels), three stages (32 -> 64 -> 96; overall stride 8), and a
    1x1 projection to ``output_dim`` channels.

    When given a tuple/list of two tensors, both are concatenated along the
    batch dimension, encoded together, and split back into two tensors.
    """

    def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        # Stem normalization; per-stage norms are built inside BottleneckBlock.
        # NOTE(review): an unrecognized norm_fn leaves self.norm1 unset and
        # fails later with AttributeError.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and then advances self.in_planes; keep the order.
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Output projection from the 96-channel final stage.
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; identity (weight=1, bias=0) for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Two BottleneckBlocks: the first may downsample, the second never does."""
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            # Assumes exactly two same-batch-size tensors (see split below).
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
##################################################
|
| 721 |
+
# RFR is implemented based on RAFT optical flow #
|
| 722 |
+
##################################################
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def backwarp(img, flow):
    """Backward-warp *img* by *flow* using bilinear sampling.

    For every output pixel (x, y) the value is sampled from
    img[:, :, y + v, x + u], where (u, v) = flow[:, 0] / flow[:, 1] in
    pixels.  Out-of-range locations are handled by ``grid_sample``'s
    default zero padding.

    Fix: the sampling grid is built on ``img.device`` instead of being
    hard-coded to ``.cuda()``, so CPU/MPS execution (consistent with this
    file's ``get_torch_device()`` usage) works; behavior on CUDA inputs is
    unchanged.

    Args:
        img:  (N, C, H, W) tensor.
        flow: (N, 2, H, W) tensor; channel 0 is the horizontal (u)
              component, channel 1 the vertical (v) component, in pixels.

    Returns:
        (N, C, H, W) tensor of warped values.
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    # Base integer grid on the input's device (first output varies along
    # rows -> y, second along columns -> x, matching np.meshgrid(W, H)).
    gridY, gridX = torch.meshgrid(
        torch.arange(H, device=img.device),
        torch.arange(W, device=img.device),
    )
    x = gridX.unsqueeze(0).expand_as(u).float() + u
    y = gridY.unsqueeze(0).expand_as(v).float() + v
    # Normalize to grid_sample's [-1, 1] coordinate range.
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)
    # stacking X and Y
    grid = torch.stack((x, y), dim=3)
    # Sample pixels using bilinear interpolation.
    imgOut = torch.nn.functional.grid_sample(img, grid, align_corners=True)

    return imgOut
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
class ErrorAttention(nn.Module):
    """Three-layer conv network for predicting a mask/attention map.

    The second stage concatenates its 32 feature channels with the raw
    input before the final conv — hence conv3's fixed 38 input channels
    (32 + 6), which ties this module to ``input=6``.
    """

    def __init__(self, input, output):
        super(ErrorAttention, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(38, output, 3, padding=1)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()

    def forward(self, x1):
        features = self.prelu1(self.conv1(x1))
        features = self.prelu2(torch.cat([self.conv2(features), x1], dim=1))
        return self.conv3(features)
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
class RFR(nn.Module):
|
| 773 |
+
def __init__(self, args):
|
| 774 |
+
super(RFR, self).__init__()
|
| 775 |
+
self.attention2 = ErrorAttention(6, 1)
|
| 776 |
+
self.hidden_dim = hdim = 128
|
| 777 |
+
self.context_dim = cdim = 128
|
| 778 |
+
args.corr_levels = 4
|
| 779 |
+
args.corr_radius = 4
|
| 780 |
+
args.dropout = 0
|
| 781 |
+
self.args = args
|
| 782 |
+
|
| 783 |
+
# feature network, context network, and update block
|
| 784 |
+
self.fnet = BasicEncoder(output_dim=256, norm_fn="none", dropout=args.dropout)
|
| 785 |
+
# self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
|
| 786 |
+
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
| 787 |
+
|
| 788 |
+
def freeze_bn(self):
|
| 789 |
+
for m in self.modules():
|
| 790 |
+
if isinstance(m, nn.BatchNorm2d):
|
| 791 |
+
m.eval()
|
| 792 |
+
|
| 793 |
+
def initialize_flow(self, img):
|
| 794 |
+
"""Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
| 795 |
+
N, C, H, W = img.shape
|
| 796 |
+
coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
|
| 797 |
+
coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
|
| 798 |
+
|
| 799 |
+
# optical flow computed as difference: flow = coords1 - coords0
|
| 800 |
+
return coords0, coords1
|
| 801 |
+
|
| 802 |
+
def upsample_flow(self, flow, mask):
|
| 803 |
+
"""Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
|
| 804 |
+
N, _, H, W = flow.shape
|
| 805 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
| 806 |
+
mask = torch.softmax(mask, dim=2)
|
| 807 |
+
|
| 808 |
+
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
|
| 809 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
| 810 |
+
|
| 811 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
| 812 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
| 813 |
+
return up_flow.reshape(N, 2, 8 * H, 8 * W)
|
| 814 |
+
|
| 815 |
+
def forward(
|
| 816 |
+
self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
|
| 817 |
+
):
|
| 818 |
+
H, W = image1.size()[2:4]
|
| 819 |
+
H8 = H // 8 * 8
|
| 820 |
+
W8 = W // 8 * 8
|
| 821 |
+
|
| 822 |
+
if flow_init is not None:
|
| 823 |
+
flow_init_resize = F.interpolate(
|
| 824 |
+
flow_init, size=(H8 // 8, W8 // 8), mode="nearest"
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
flow_init_resize[:, :1] = (
|
| 828 |
+
flow_init_resize[:, :1].clone() * (W8 // 8 * 1.0) / flow_init.size()[3]
|
| 829 |
+
)
|
| 830 |
+
flow_init_resize[:, 1:] = (
|
| 831 |
+
flow_init_resize[:, 1:].clone() * (H8 // 8 * 1.0) / flow_init.size()[2]
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
if not hasattr(self.args, "not_use_rfr_mask") or (
|
| 835 |
+
hasattr(self.args, "not_use_rfr_mask")
|
| 836 |
+
and (not self.args.not_use_rfr_mask)
|
| 837 |
+
):
|
| 838 |
+
im18 = F.interpolate(image1, size=(H8 // 8, W8 // 8), mode="bilinear")
|
| 839 |
+
im28 = F.interpolate(image2, size=(H8 // 8, W8 // 8), mode="bilinear")
|
| 840 |
+
|
| 841 |
+
warp21 = backwarp(im28, flow_init_resize)
|
| 842 |
+
error21 = torch.sum(torch.abs(warp21 - im18), dim=1, keepdim=True)
|
| 843 |
+
# print('errormin', error21.min(), error21.max())
|
| 844 |
+
f12init = (
|
| 845 |
+
torch.exp(
|
| 846 |
+
-self.attention2(
|
| 847 |
+
torch.cat([im18, error21, flow_init_resize], dim=1)
|
| 848 |
+
)
|
| 849 |
+
** 2
|
| 850 |
+
)
|
| 851 |
+
* flow_init_resize
|
| 852 |
+
)
|
| 853 |
+
else:
|
| 854 |
+
flow_init_resize = None
|
| 855 |
+
flow_init = torch.zeros(
|
| 856 |
+
image1.size()[0], 2, image1.size()[2] // 8, image1.size()[3] // 8
|
| 857 |
+
).cuda()
|
| 858 |
+
error21 = torch.zeros(
|
| 859 |
+
image1.size()[0], 1, image1.size()[2] // 8, image1.size()[3] // 8
|
| 860 |
+
).cuda()
|
| 861 |
+
|
| 862 |
+
f12_init = flow_init
|
| 863 |
+
# print('None inital flow!')
|
| 864 |
+
|
| 865 |
+
image1 = F.interpolate(image1, size=(H8, W8), mode="bilinear")
|
| 866 |
+
image2 = F.interpolate(image2, size=(H8, W8), mode="bilinear")
|
| 867 |
+
|
| 868 |
+
f12s, f12, f12_init = self.forward_pred(
|
| 869 |
+
image1, image2, iters, flow_init_resize, upsample, test_mode
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
if hasattr(self.args, "requires_sq_flow") and self.args.requires_sq_flow:
|
| 873 |
+
for ii in range(len(f12s)):
|
| 874 |
+
f12s[ii] = F.interpolate(f12s[ii], size=(H, W), mode="bilinear")
|
| 875 |
+
f12s[ii][:, :1] = f12s[ii][:, :1].clone() / (1.0 * W8) * W
|
| 876 |
+
f12s[ii][:, 1:] = f12s[ii][:, 1:].clone() / (1.0 * H8) * H
|
| 877 |
+
if self.training:
|
| 878 |
+
return f12s
|
| 879 |
+
else:
|
| 880 |
+
return [f12s[-1]], f12_init
|
| 881 |
+
else:
|
| 882 |
+
f12[:, :1] = f12[:, :1].clone() / (1.0 * W8) * W
|
| 883 |
+
f12[:, 1:] = f12[:, 1:].clone() / (1.0 * H8) * H
|
| 884 |
+
|
| 885 |
+
f12 = F.interpolate(f12, size=(H, W), mode="bilinear")
|
| 886 |
+
# print('wo!!')
|
| 887 |
+
return (
|
| 888 |
+
f12,
|
| 889 |
+
f12_init,
|
| 890 |
+
error21,
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
def forward_pred(
    self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
):
    """Estimate optical flow between pair of frames.

    RAFT-style iterative refinement: correlation volume lookup + GRU update
    block, repeated `iters` times. Returns (flow_predictions, flow_up,
    flow_init) — the list of per-iteration upsampled flows, the last one,
    and the (possibly None) initial flow that was passed in.
    """

    image1 = image1.contiguous()
    image2 = image2.contiguous()

    hdim = self.hidden_dim
    cdim = self.context_dim

    # run the feature network
    # NOTE: `autocast` and `device` come from module scope, not this block
    with autocast(device.type, enabled=self.args.mixed_precision):
        fmap1, fmap2 = self.fnet([image1, image2])
    fmap1 = fmap1.float()
    fmap2 = fmap2.float()
    corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

    # run the context network (reuses fnet; split into hidden + context parts)
    with autocast(device.type, enabled=self.args.mixed_precision):
        cnet = self.fnet(image1)
        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)

    coords0, coords1 = self.initialize_flow(image1)

    if flow_init is not None:
        coords1 = coords1 + flow_init

    flow_predictions = []
    for itr in range(iters):
        coords1 = coords1.detach()
        if itr == 0:
            if flow_init is not None:
                # NOTE(review): flow_init was already added to coords1 above,
                # so on the first iteration it is applied twice — looks
                # suspicious; confirm against the upstream RAFT implementation
                coords1 = coords1 + flow_init
        corr = corr_fn(coords1)  # index correlation volume

        flow = coords1 - coords0
        with autocast(device.type, enabled=self.args.mixed_precision):
            net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

        # F(t+1) = F(t) + \Delta(t)
        coords1 = coords1 + delta_flow

        # upsample predictions (learned convex upsampling when a mask is given)
        if up_mask is None:
            flow_up = upflow8(coords1 - coords0)
        else:
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)

        flow_predictions.append(flow_up)

    return flow_predictions, flow_up, flow_init
|
| 947 |
+
|
| 948 |
+
####################### WARPING #######################
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
# expects batched tensors, considered low-level operation
|
| 952 |
+
# img: bs, ch, h, w
|
| 953 |
+
# flow: bs, xy (pix displace), h, w
|
| 954 |
+
def flow_backwarp(
|
| 955 |
+
img, flow, resample="bilinear", padding_mode="border", align_corners=False
|
| 956 |
+
):
|
| 957 |
+
if len(img.shape) != 4:
|
| 958 |
+
img = img[None,]
|
| 959 |
+
if len(flow.shape) != 4:
|
| 960 |
+
flow = flow[None,]
|
| 961 |
+
q = (
|
| 962 |
+
2
|
| 963 |
+
* flow
|
| 964 |
+
/ torch.tensor(
|
| 965 |
+
[
|
| 966 |
+
flow.shape[-2],
|
| 967 |
+
flow.shape[-1],
|
| 968 |
+
],
|
| 969 |
+
device=flow.device,
|
| 970 |
+
dtype=torch.float,
|
| 971 |
+
)[None, :, None, None]
|
| 972 |
+
)
|
| 973 |
+
q = q + torch.stack(
|
| 974 |
+
torch.meshgrid(
|
| 975 |
+
torch.linspace(-1, 1, flow.shape[-2]),
|
| 976 |
+
torch.linspace(-1, 1, flow.shape[-1]),
|
| 977 |
+
)
|
| 978 |
+
)[
|
| 979 |
+
None,
|
| 980 |
+
].to(
|
| 981 |
+
flow.device
|
| 982 |
+
)
|
| 983 |
+
if img.dtype != q.dtype:
|
| 984 |
+
img = img.type(q.dtype)
|
| 985 |
+
|
| 986 |
+
return nn.functional.grid_sample(
|
| 987 |
+
img,
|
| 988 |
+
q.flip(dims=(1,)).permute(0, 2, 3, 1),
|
| 989 |
+
mode=resample, # nearest, bicubic, bilinear
|
| 990 |
+
padding_mode=padding_mode, # border, zeros, reflection
|
| 991 |
+
align_corners=align_corners,
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
backwarp = flow_warp = flow_backwarp
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
# mode: sum, avg, lin, softmax
|
| 999 |
+
# lin/softmax w/out metric defaults to avg
|
| 1000 |
+
# must use gpu, move back to cpu if retain_device
|
| 1001 |
+
# typical metric: -20 * | img0 - backwarp(img1,flow) |
|
| 1002 |
+
# From Fannovel16: Changed mode params for common ops.
|
| 1003 |
+
def flow_forewarp(
    img, flow, mode="average", metric=None, mask=False, retain_device=True
):
    """Forward-warp (splat) `img` along `flow` via softmax splatting.

    mode: sum, avg, lin, softmax ("lin"/"softmax" without a metric fall back
    to "avg"). Computation requires CUDA: CPU inputs are moved to the GPU and
    moved back afterwards when retain_device is True.
    Typical metric: -20 * | img0 - backwarp(img1, flow) |.
    From Fannovel16: Changed mode params for common ops.
    """
    # normalize mode names for the underlying softsplat op
    #if mode == "sum":
    #    mode = "summation"
    #elif mode == "avg":
    #    mode = "average"
    if mode in ["lin", "linear"]:
        #mode = "linear" if metric is not None else "average"
        mode = "linear" if metric is not None else "avg"
    elif mode in ["sm", "softmax"]:
        #mode = "softmax" if metric is not None else "average"
        mode = "soft" if metric is not None else "avg"

    # promote everything to batched 4-D tensors
    if len(img.shape) != 4:
        img = img[None,]
    if len(flow.shape) != 4:
        flow = flow[None,]
    if metric is not None and len(metric.shape) != 4:
        metric = metric[None,]

    # softsplat expects (dx, dy) channel order
    flow = flow.flip(dims=(1,))

    # force float32 throughout
    if img.dtype != torch.float32:
        img = img.type(torch.float32)
    if flow.dtype != torch.float32:
        flow = flow.type(torch.float32)
    if metric is not None and metric.dtype != torch.float32:
        metric = metric.type(torch.float32)

    # move to gpu if necessary
    assert img.device == flow.device
    if metric is not None:
        assert img.device == metric.device
    came_from_cpu = img.device.type == "cpu"
    if came_from_cpu:
        img = img.to("cuda")
        flow = flow.to("cuda")
        if metric is not None:
            metric = metric.to("cuda")

    # optionally append an all-ones channel so coverage can be recovered
    if mask:
        bs, ch, h, w = img.shape
        ones = torch.ones(bs, 1, h, w, dtype=img.dtype, device=img.device)
        img = torch.cat([img, ones], dim=1)

    # forward, move back to cpu if desired
    out = FunctionSoftsplat(img, flow, metric, mode)
    if came_from_cpu and retain_device:
        out = out.cpu()
    return out
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
forewarp = flow_forewarp
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
# resizing utility
|
| 1060 |
+
def flow_resize(flow, size, mode="nearest", align_corners=False):
    """Resize a (bs, 2(yx), h, w) flow field to `size`.

    Displacement values are rescaled by the size ratio so that pixel
    magnitudes stay consistent at the new resolution.
    """
    size = pixel_ij(size, rounding=True)
    if flow.dtype != torch.float:
        flow = flow.float()
    if len(flow.shape) == 3:
        flow = flow[None,]
    if flow.shape[-2:] == size:
        return flow
    resized = nn.functional.interpolate(
        flow,
        size=size,
        mode=mode,
        align_corners=None if mode == "nearest" else align_corners,
    )
    scale = torch.tensor(
        [new / old for old, new in zip(flow.shape[-2:], size)],
        device=flow.device,
    )
    return resized * scale[None, :, None, None]
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
####################### TRADITIONAL #######################
|
| 1084 |
+
|
| 1085 |
+
# dense
|
| 1086 |
+
# Classical dense optical-flow estimators wrapped to a common interface:
# each callable takes two pre-converted cv2 images (a, b) and returns a
# numpy array of shape (1, 2, h, w) — np.moveaxis brings the two flow
# channels in front of (h, w) and [None,] adds a batch dimension.
_lucaskanade = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowSparseToDense(
        a,
        b,  # grid_step=5, sigma=0.5,
    ),
    2,
    0,
)[
    None,
]
# Farneback polynomial-expansion flow; positional args are
# (prev, next, flow, pyr_scale, levels, winsize, iterations, poly_n,
#  poly_sigma, flags)
_farneback = lambda a, b: np.moveaxis(
    cv2.calcOpticalFlowFarneback(
        a,
        b,
        None,
        0.6,
        3,
        25,
        7,
        5,
        1.2,
        cv2.OPTFLOW_FARNEBACK_GAUSSIAN,
    ),
    2,
    0,
)[
    None,
]
# Dual TV-L1: the estimator object is stateful, created once at module load
_dtvl1_ = cv2.optflow.createOptFlow_DualTVL1()
_dtvl1 = lambda a, b: np.moveaxis(
    _dtvl1_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
# SimpleFlow; args are (from, to, layers, averaging_block_size, max_flow)
_simple = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowSF(
        a,
        b,
        3,
        5,
        5,
    ),
    2,
    0,
)[
    None,
]
# PCAFlow: stateful estimator, created once at module load
_pca_ = cv2.optflow.createOptFlow_PCAFlow()
_pca = lambda a, b: np.moveaxis(
    _pca_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
# Dense RLOF (requires color input)
_drlof = lambda a, b: np.moveaxis(
    cv2.optflow.calcOpticalFlowDenseRLOF(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
# DeepFlow: stateful estimator, created once at module load
_deepflow_ = cv2.optflow.createOptFlow_DeepFlow()
_deepflow = lambda a, b: np.moveaxis(
    _deepflow_.calc(
        a,
        b,
        None,
    ),
    2,
    0,
)[
    None,
]
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
def cv2flow(a, b, method="lucaskanade", back=False):
    """Dense optical flow from frame `a` to `b` using a classical cv2 method.

    Returns a torch tensor shaped (1, 2(yx), h, w); with back=True the
    reverse-direction flow is concatenated along the batch dimension.
    """
    # estimator callable and the color space it expects
    dispatch = {
        "lucaskanade": (_lucaskanade, "L"),
        "farneback": (_farneback, "L"),
        "dtvl1": (_dtvl1, "L"),
        "simple": (_simple, "RGB"),
        "pca": (_pca, "L"),
        "drlof": (_drlof, "RGB"),
        "deepflow": (_deepflow, "L"),
    }
    assert method in dispatch
    f, space = dispatch[method]
    a = a.convert(space).cv2()
    b = b.convert(space).cv2()
    ans = f(b, a)
    if back:
        ans = np.concatenate(
            [
                ans,
                f(a, b),
            ]
        )
    # (dx, dy) -> (dy, dx) channel order
    return torch.tensor(ans).flip(dims=(1,))
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
####################### FLOWNET2 #######################
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
+
def flownet2(img_a, img_b, mode="shm", back=False):
    """Query a local flownet2 service for flow between two images.

    mode="shm" exchanges images through /dev/shm files; "net" is not
    implemented. Returns a dict with the raw response and, on success,
    the service timing and the flow tensor.
    """
    # package
    url = "http://localhost:8109/get-flow"
    if mode == "shm":
        stamp = time.time()
        fn_a = img_a.save(mkfile(f"/dev/shm/_flownet2/{stamp}/img_a.png"))
        fn_b = img_b.save(mkfile(f"/dev/shm/_flownet2/{stamp}/img_b.png"))
    elif mode == "net":
        assert False, "not impl"
        q = u2d.img2uri(img.pil("RGB"))
        q.decode()
    resp = requests.get(
        url,
        params={
            "img_a": fn_a,
            "img_b": fn_b,
            "mode": mode,
            "back": back,
            # 'vis': vis,
        },
    )

    # return
    ans = {"response": resp}
    if resp.status_code == 200:
        payload = resp.json()
        ans["time"] = payload["time"]
        ans["output"] = {
            "flow": torch.tensor(load(payload["fn_flow"])),
        }
        # if vis:
        #     ans['output']['vis'] = I(j['fn_vis'])
    if mode == "shm":
        shutil.rmtree(f"/dev/shm/_flownet2/{stamp}")
    return ans
|
| 1256 |
+
|
| 1257 |
+
|
| 1258 |
+
####################### VISUALIZATION #######################
|
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
class Gridnet(nn.Module):
    """Grid-shaped network over three feature scales.

    `depth` encoder columns followed by `depth` decoder columns; whole-sample
    dropout is applied between columns (but not before the first encoder).
    """

    def __init__(self, channels_0, channels_1, channels_2, total_dropout_p, depth):
        super().__init__()
        self.channels_0 = ch0 = channels_0
        self.channels_1 = ch1 = channels_1
        self.channels_2 = ch2 = channels_2
        self.total_dropout_p = p = total_dropout_p
        self.depth = depth
        self.encoders = nn.ModuleList(
            GridnetEncoder(ch0, ch1, ch2) for _ in range(depth)
        )
        self.decoders = nn.ModuleList(
            GridnetDecoder(ch0, ch1, ch2) for _ in range(depth)
        )
        self.total_dropout = GridnetTotalDropout(p)

    def forward(self, x):
        # x: list of three feature maps, one per scale
        t = x
        for idx, enc in enumerate(self.encoders):
            if idx != 0:
                t = [self.total_dropout(feat) for feat in t]
            t = enc(t)
        for dec in self.decoders:
            t = dec([self.total_dropout(feat) for feat in t])
        return t
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
class GridnetEncoder(nn.Module):
    """One encoder column: per-scale residual blocks plus top-down downsampling."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = channels_0
        self.channels_1 = channels_1
        self.channels_2 = channels_2
        self.resnet_0 = GridnetResnet(channels_0)
        self.resnet_1 = GridnetResnet(channels_1)
        self.resnet_2 = GridnetResnet(channels_2)
        self.downsample_01 = GridnetDownsample(channels_0, channels_1)
        self.downsample_12 = GridnetDownsample(channels_1, channels_2)

    def forward(self, x):
        # each finer scale feeds the next coarser one
        y0 = self.resnet_0(x[0])
        y1 = self.resnet_1(x[1]) + self.downsample_01(y0)
        y2 = self.resnet_2(x[2]) + self.downsample_12(y1)
        return [y0, y1, y2]
|
| 1309 |
+
|
| 1310 |
+
|
| 1311 |
+
class GridnetDecoder(nn.Module):
    """One decoder column: per-scale residual blocks plus bottom-up upsampling."""

    def __init__(self, channels_0, channels_1, channels_2):
        super().__init__()
        self.channels_0 = channels_0
        self.channels_1 = channels_1
        self.channels_2 = channels_2
        self.resnet_0 = GridnetResnet(channels_0)
        self.resnet_1 = GridnetResnet(channels_1)
        self.resnet_2 = GridnetResnet(channels_2)
        self.upsample_10 = GridnetUpsample(channels_1, channels_0)
        self.upsample_21 = GridnetUpsample(channels_2, channels_1)

    def forward(self, x):
        # each coarser scale feeds the next finer one
        y2 = self.resnet_2(x[2])
        y1 = self.resnet_1(x[1]) + self.upsample_21(y2)
        y0 = self.resnet_0(x[0]) + self.upsample_10(y1)
        return [y0, y1, y2]
|
| 1332 |
+
|
| 1333 |
+
|
| 1334 |
+
class GridnetConverter(nn.Module):
    """Per-scale 1x1 conv adapters mapping channels_in[i] -> channels_out[i]."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        adapters = []
        for cin_i, cout_i in zip(channels_in, channels_out):
            adapters.append(
                nn.Sequential(
                    nn.PReLU(cin_i),
                    nn.Conv2d(cin_i, cout_i, kernel_size=1, padding=0),
                    nn.BatchNorm2d(cout_i),
                )
            )
        self.nets = nn.ModuleList(adapters)

    def forward(self, x):
        # x: list of per-scale feature maps, one per adapter
        return [net(feat) for net, feat in zip(self.nets, x)]
|
| 1353 |
+
|
| 1354 |
+
|
| 1355 |
+
class GridnetResnet(nn.Module):
    """Two-conv residual block; channel count and spatial size are preserved."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        layers = []
        for _ in range(2):
            layers += [
                nn.PReLU(channels),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(channels),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # residual connection around the conv stack
        return x + self.net(x)
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
class GridnetDownsample(nn.Module):
    """Halve spatial resolution (stride-2 conv) and change channel count."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        stages = [
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_in, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels_in),
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
| 1390 |
+
|
| 1391 |
+
|
| 1392 |
+
class GridnetUpsample(nn.Module):
    """Double spatial resolution (nearest upsample) and change channel count."""

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out
        stages = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.PReLU(channels_in),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
            nn.PReLU(channels_out),
            nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels_out),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
| 1410 |
+
|
| 1411 |
+
|
| 1412 |
+
class GridnetTotalDropout(nn.Module):
    """Randomly zero out whole samples in a batch, rescaling survivors.

    Each sample is kept with probability (1 - p); kept samples are scaled
    by 1 / (1 - p) so expectation is preserved. Active only in training
    mode unless overridden by force_drop.
    """

    def __init__(self, p):
        super().__init__()
        assert 0 <= p < 1
        self.p = p
        self.weight = 1 / (1 - p)

    def get_drop(self, x):
        """Per-sample keep mask (scaled), broadcastable over (bs, ch, h, w)."""
        keep = (torch.rand(len(x))[:, None, None, None] >= self.p).float()
        return keep.to(x.device) * self.weight

    def forward(self, x, force_drop=None):
        if force_drop is None:
            force_drop = self.training
        return x * self.get_drop(x) if force_drop else x
|
| 1436 |
+
|
| 1437 |
+
|
| 1438 |
+
class Interpolator(nn.Module):
    """Resize (bs, ch, h, w) or (bs, k, ch, h, w) tensors to a fixed size.

    With is_flow=True the resized values are rescaled by the size ratio so
    displacement magnitudes match the new resolution.
    """

    def __init__(self, size, mode="bilinear"):
        super().__init__()
        self.size = size
        self.mode = mode

    def forward(self, x, is_flow=False):
        if x.shape[-2] == self.size:
            return x
        ndim = len(x.shape)
        if ndim == 5:
            # bs,k,ch,h,w: merge bs and k, recurse, then split back
            bs, k, ch, h, w = x.shape
            flat = self.forward(x.view(bs * k, ch, h, w), is_flow=is_flow)
            return flat.view(bs, k, ch, *self.size)
        assert ndim == 4
        bs, ch, h, w = x.shape
        out = nn.functional.interpolate(
            x,
            size=self.size,
            mode=self.mode,
            align_corners=None if self.mode == "nearest" else False,
        )
        if is_flow:
            scale = torch.tensor(
                [new / old for old, new in zip((h, w), self.size)],
                device=out.device,
            )[None, :, None, None]
            out = out * scale
        return out
|
| 1475 |
+
|
| 1476 |
+
|
| 1477 |
+
###################### CANNY ######################
|
| 1478 |
+
|
| 1479 |
+
|
| 1480 |
+
def canny(img, a=100, b=200):
    """Canny edge detection with fixed low/high thresholds; returns an I image."""
    grey = I(img).convert("L")
    return I(cv2.Canny(grey.cv2(), a, b))
|
| 1483 |
+
|
| 1484 |
+
|
| 1485 |
+
# https://www.pyimagesearch.com/2015/04/06/zero-parameter-automatic-canny-edge-detection-with-python-and-opencv/
|
| 1486 |
+
# https://www.pyimagesearch.com/2015/04/06/zero-parameter-automatic-canny-edge-detection-with-python-and-opencv/
def canny_pis(img, sigma=0.33):
    """Zero-parameter Canny: thresholds derived from the median intensity."""
    img = I(img).convert("L").uint8(ch_last=False)
    med = np.median(img)
    # thresholds bracket the median by +/- sigma, clamped to [0, 255]
    lower = int(max(0, (1.0 - sigma) * med))
    upper = int(min(255, (1.0 + sigma) * med))
    return I(cv2.Canny(img[0], lower, upper))
|
| 1496 |
+
|
| 1497 |
+
|
| 1498 |
+
# https://en.wikipedia.org/wiki/Otsu%27s_method
|
| 1499 |
+
# https://en.wikipedia.org/wiki/Otsu%27s_method
def canny_otsu(img):
    """Canny with the high threshold chosen by Otsu's method (low = high / 2)."""
    img = I(img).convert("L").uint8(ch_last=False)
    high, _ = cv2.threshold(img[0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return I(cv2.Canny(img[0], 0.5 * high, high))
|
| 1504 |
+
|
| 1505 |
+
|
| 1506 |
+
def xdog(img, t=1.0, epsilon=0.04, phi=100, sigma=3, k=1.6):
    """Extended difference-of-Gaussians stylization; returns a float array."""
    img = I(img).convert("L").uint8(ch_last=False)
    grey = np.asarray(img, dtype=np.float32)
    narrow = scipy.ndimage.gaussian_filter(grey, sigma)
    wide = scipy.ndimage.gaussian_filter(grey, sigma * k)

    # ans = ((1+p) * g0 - p * g1) / 255
    diff = (narrow - t * wide) / 255
    # soft-threshold below epsilon; values above epsilon stay at 1
    return 1 + np.tanh(phi * (diff - epsilon)) * (diff < epsilon)
|
| 1516 |
+
|
| 1517 |
+
|
| 1518 |
+
def dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
    """Difference-of-Gaussians on a single image; returns a numpy array."""
    img = I(img).convert("L").tensor()[None]
    # odd kernel sizes large enough to cover each blur radius
    kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
    kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
    g0 = kornia.filters.gaussian_blur2d(
        img, (kern0, kern0), (sigma, sigma), border_type="replicate"
    )
    g1 = kornia.filters.gaussian_blur2d(
        img, (kern1, kern1), (sigma * k, sigma * k), border_type="replicate"
    )
    out = 0.5 + t * (g1 - g0) - epsilon
    if clip:
        out = out.clip(0, 1)
    return out[0].numpy()
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
# input: (bs,rgb(a),h,w) or (bs,1,h,w)
|
| 1540 |
+
# returns: (bs,1,h,w)
|
| 1541 |
+
# input: (bs,rgb(a),h,w) or (bs,1,h,w)
# returns: (bs,1,h,w)
def batch_dog(img, t=1.0, sigma=1.0, k=1.6, epsilon=0.01, kernel_factor=4, clip=True):
    """Batched difference-of-Gaussians edge response."""
    # collapse to grayscale when given color input (alpha is dropped)
    ch = img.shape[1]
    if ch in (3, 4):
        img = kornia.color.rgb_to_grayscale(img[:, :3])
    else:
        assert ch == 1

    # odd kernel sizes large enough to cover each blur radius
    kern0 = max(2 * int(sigma * kernel_factor) + 1, 3)
    kern1 = max(2 * int(sigma * k * kernel_factor) + 1, 3)
    g0 = kornia.filters.gaussian_blur2d(
        img, (kern0, kern0), (sigma, sigma), border_type="replicate"
    )
    g1 = kornia.filters.gaussian_blur2d(
        img, (kern1, kern1), (sigma * k, sigma * k), border_type="replicate"
    )
    out = 0.5 + t * (g1 - g0) - epsilon
    return out.clip(0, 1) if clip else out
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
############### DERIVED DISTANCES ###############
|
| 1570 |
+
|
| 1571 |
+
# input: (bs,h,w) or (bs,1,h,w)
|
| 1572 |
+
# returns: (bs,)
|
| 1573 |
+
# normalized s.t. metric is same across proportional image scales
|
| 1574 |
+
|
| 1575 |
+
|
| 1576 |
+
# average of two asymmetric distances
|
| 1577 |
+
# normalized by diameter and area
|
| 1578 |
+
# average of two asymmetric distances
# normalized by diameter and area
def batch_chamfer_distance(gt, pred, block=1024, return_more=False):
    """Symmetric chamfer distance: mean of the two one-sided distances."""
    gt_to_pred = batch_chamfer_distance_t(gt, pred, block=block)
    pred_to_gt = batch_chamfer_distance_p(gt, pred, block=block)
    return (gt_to_pred + pred_to_gt) / 2
|
| 1583 |
+
|
| 1584 |
+
|
| 1585 |
+
def batch_chamfer_distance_t(gt, pred, block=1024, return_more=False):
    """One-sided chamfer: gt pixels weighted by the distance transform of pred."""
    assert gt.device == pred.device and gt.shape == pred.shape
    h, w = gt.shape[-2], gt.shape[-1]
    dpred = batch_edt(pred, block=block)
    # normalized by the image diagonal so scale is comparable across sizes
    cd = (gt * dpred).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if cd.ndim == 2:
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd
|
| 1594 |
+
|
| 1595 |
+
|
| 1596 |
+
def batch_chamfer_distance_p(gt, pred, block=1024, return_more=False):
    """One-sided chamfer: pred pixels weighted by the distance transform of gt."""
    assert gt.device == pred.device and gt.shape == pred.shape
    h, w = gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    # normalized by the image diagonal so scale is comparable across sizes
    cd = (pred * dgt).float().mean((-2, -1)) / np.sqrt(h**2 + w**2)
    if cd.ndim == 2:
        assert cd.shape[1] == 1
        cd = cd.squeeze(1)
    return cd
|
| 1605 |
+
|
| 1606 |
+
|
| 1607 |
+
# normalized by diameter
|
| 1608 |
+
# always between [0,1]
|
| 1609 |
+
# normalized by diameter
# always between [0,1]
def batch_hausdorff_distance(gt, pred, block=1024, return_more=False):
    """Symmetric Hausdorff distance normalized by the image diagonal."""
    assert gt.device == pred.device and gt.shape == pred.shape
    h, w = gt.shape[-2], gt.shape[-1]
    dgt = batch_edt(gt, block=block)
    dpred = batch_edt(pred, block=block)
    one_sided = [
        (dgt * pred).amax(dim=(-2, -1)),
        (dpred * gt).amax(dim=(-2, -1)),
    ]
    hd = torch.stack(one_sided).amax(dim=0).float() / np.sqrt(h**2 + w**2)
    if hd.ndim == 2:
        assert hd.shape[1] == 1
        hd = hd.squeeze(1)
    return hd
|
| 1624 |
+
|
| 1625 |
+
|
| 1626 |
+
#################### UTILITIES ####################
|
| 1627 |
+
|
| 1628 |
+
|
| 1629 |
+
def reset_parameters(model):
    """Re-initialize every direct child layer that supports reset_parameters."""
    for child in model.children():
        reset = getattr(child, "reset_parameters", None)
        if reset is not None:
            reset()
    return model
|
| 1634 |
+
|
| 1635 |
+
|
| 1636 |
+
def channel_squeeze(x, dim=1):
    """Merge dimensions dim and dim+1 of x into a single dimension."""
    lead = x.shape[:dim]
    tail = x.shape[dim + 2 :]
    return x.reshape(*lead, -1, *tail)
|
| 1640 |
+
|
| 1641 |
+
|
| 1642 |
+
def channel_unsqueeze(x, shape, dim=1):
    """Expand dimension dim of x into the given multi-dimensional shape."""
    lead = x.shape[:dim]
    tail = x.shape[dim + 1 :]
    return x.reshape(*lead, *shape, *tail)
|
| 1646 |
+
|
| 1647 |
+
|
| 1648 |
+
def default_collate(items, device=None):
    """Collate sample dicts with torch's default collation, then move to device."""
    batch = dict(torch.utils.data.dataloader.default_collate(items))
    return to(batch, device)
|
| 1650 |
+
|
| 1651 |
+
|
| 1652 |
+
def to(x, device):
    """Move tensors (possibly inside a dict) to a device; device=None is a no-op.

    numpy arrays are converted to tensors first; non-tensor dict values
    pass through unchanged.
    """
    if device is None:
        return x
    if issubclass(x.__class__, dict):
        return {
            key: (val.to(device) if isinstance(val, torch.Tensor) else val)
            for key, val in x.items()
        }
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, np.ndarray):
        return torch.tensor(x).to(device)
    raise AssertionError("data not understood")
|
| 1667 |
+
|
| 1668 |
+
|
| 1669 |
+
################ PARSING ################
|
| 1670 |
+
|
| 1671 |
+
from argparse import Namespace
|
| 1672 |
+
|
| 1673 |
+
# args: all args
|
| 1674 |
+
# bargs: base args
|
| 1675 |
+
# pargs: data processing args
|
| 1676 |
+
# largs: data loading args
|
| 1677 |
+
# margs: model args
|
| 1678 |
+
# targs: training args
|
| 1679 |
+
|
| 1680 |
+
|
| 1681 |
+
# typically used to read dataset filters
|
| 1682 |
+
# typically used to read dataset filters
def read_filter(fn, cast=None, sort=True, sort_key=None):
    """Read newline-separated entries from fn, optionally casting and sorting."""
    if cast is None:
        cast = lambda value: value
    entries = [cast(line) for line in read(fn).split("\n") if line != ""]
    return sorted(entries, key=sort_key) if sort else entries
|
| 1690 |
+
|
| 1691 |
+
|
| 1692 |
+
################ FILE MANAGEMENT ################
|
| 1693 |
+
|
| 1694 |
+
|
| 1695 |
+
def mkfile(fn, parents=True, exist_ok=True):
    """Create the parent directory of fn (if any) and return fn unchanged."""
    parent = "/".join(fn.split("/")[:-1])
    mkdir(parent, parents=parents, exist_ok=exist_ok)
    return fn
|
| 1699 |
+
|
| 1700 |
+
|
| 1701 |
+
def mkdir(dn, parents=True, exist_ok=True):
    """Create directory dn (mkdir -p semantics by default).

    Returns dn with any trailing slash stripped (the root "/" is kept).
    An empty dn means the current directory and is treated as a no-op —
    the original indexed dn[-1] and crashed with IndexError when mkfile()
    was called with a bare filename.
    """
    if dn == "":
        return dn
    pathlib.Path(dn).mkdir(parents=parents, exist_ok=exist_ok)
    if dn != "/" and dn.endswith("/"):
        return dn[:-1]
    return dn
|
| 1704 |
+
|
| 1705 |
+
|
| 1706 |
+
def fstrip(fn, return_more=False):
    """Strip directory and extension from a path.

    Returns the basename without its extension, or (with return_more=True)
    a Namespace carrying dn/fn/path/bn_path/bn/ext fields.
    """
    parts = fn.split("/")
    dn = "/".join(parts[:-1]) if len(parts) > 1 else "."
    fn = parts[-1]
    # rpartition splits on the LAST dot, matching ".".join(fspl[:-1]) semantics
    bn, dot, ext = fn.rpartition(".")
    if dot == "":
        bn, ext = fn, ""
    if not return_more:
        return bn
    return Namespace(
        dn=dn,
        fn=fn,
        path=f"{dn}/{fn}",
        bn_path=f"{dn}/{bn}",
        bn=bn,
        ext=ext,
    )
|
| 1728 |
+
|
| 1729 |
+
|
| 1730 |
+
def read(fn, mode="r"):
    """Read and return the whole contents of a file."""
    with open(fn, mode) as fh:
        return fh.read()
|
| 1733 |
+
|
| 1734 |
+
|
| 1735 |
+
def write(text, fn, mode="w"):
    """Write text to fn, creating parent directories as needed."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as fh:
        return fh.write(text)
|
| 1739 |
+
|
| 1740 |
+
|
| 1741 |
+
import pickle
|
| 1742 |
+
|
| 1743 |
+
|
| 1744 |
+
def dump(obj, fn, mode="wb"):
    """Pickle obj to fn, creating parent directories as needed."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as fh:
        return pickle.dump(obj, fh)
|
| 1748 |
+
|
| 1749 |
+
|
| 1750 |
+
def load(fn, mode="rb"):
    """Unpickle and return the object stored in fn."""
    with open(fn, mode) as fh:
        return pickle.load(fh)
|
| 1753 |
+
|
| 1754 |
+
|
| 1755 |
+
import json
|
| 1756 |
+
|
| 1757 |
+
|
| 1758 |
+
def jwrite(x, fn, mode="w", indent="\t", sort_keys=False):
    """Write x to fn as JSON, creating parent directories as needed."""
    mkfile(fn, parents=True, exist_ok=True)
    with open(fn, mode) as fh:
        return json.dump(x, fh, indent=indent, sort_keys=sort_keys)
|
| 1762 |
+
|
| 1763 |
+
|
| 1764 |
+
def jread(fn, mode="r"):
    """Load and return the JSON contents of fn."""
    with open(fn, mode) as fh:
        return json.load(fh)
|
| 1767 |
+
|
| 1768 |
+
|
| 1769 |
+
try:
    import yaml

    def ywrite(x, fn, mode="w", default_flow_style=False):
        """Write x to fn as YAML, creating parent directories as needed."""
        mkfile(fn, parents=True, exist_ok=True)
        with open(fn, mode) as handle:
            return yaml.dump(x, handle, default_flow_style=default_flow_style)

    def yread(fn, mode="r"):
        """Load and return the YAML contents of fn (safe loader)."""
        with open(fn, mode) as handle:
            return yaml.safe_load(handle)

# only guard against yaml being absent; the original bare `except:` also
# swallowed real errors (even KeyboardInterrupt) raised while defining these
except ImportError:
    pass
|
| 1783 |
+
|
| 1784 |
+
# Optional third-party dependencies; features using them degrade gracefully.
# Narrowed from bare `except:` so only a missing package is silenced, not
# arbitrary errors raised during import.
try:
    import pyunpack
except ImportError:
    pass

try:
    import mysql
    import mysql.connector
except ImportError:
    pass
|
| 1794 |
+
|
| 1795 |
+
|
| 1796 |
+
################ MISC ################
|
| 1797 |
+
|
| 1798 |
+
# mascot image path; fall back to the alternate env directory if missing
_hakase_primary = "./env/__hakase__.jpg"
hakase = (
    _hakase_primary
    if os.path.isfile(_hakase_primary)
    else "./__env__/__hakase__.jpg"
)
|
| 1801 |
+
|
| 1802 |
+
|
| 1803 |
+
def mem(units="m"):
    """Resident memory of this process in the given unit (b/k/m/g/t)."""
    scale = {"b": 1, "k": 1e3, "m": 1e6, "g": 1e9, "t": 1e12}[units[0].lower()]
    return psProcess(os.getpid()).memory_info().rss / scale
|
| 1814 |
+
|
| 1815 |
+
|
| 1816 |
+
def chunk(array, length, colwise=True):
    """Split array into consecutive pieces.

    colwise=True: pieces of `length` items each (the last may be shorter).
    colwise=False: `length` is the desired number of pieces instead.
    """
    if not colwise:
        return chunk(array, int(math.ceil(len(array) / length)), colwise=True)
    return [array[start : start + length] for start in range(0, len(array), length)]
|
| 1821 |
+
|
| 1822 |
+
|
| 1823 |
+
def classtree(x):
    """Return inspect's class tree for the full MRO of x."""
    mro = inspect.getmro(x)
    return inspect.getclasstree(mro)
|
| 1825 |
+
|
| 1826 |
+
|
| 1827 |
+
################ AESTHETIC ################
|
| 1828 |
+
|
| 1829 |
+
|
| 1830 |
+
class Table:
    """Plain-text table renderer driven by per-cell "spec" strings.

    Each cell is stored as `(value, spec)` where `spec` is the 12-tuple
    produced by `Table._spec` (justification, delimiter flags, subtable
    flags, fill flag, and an optional f-string format suffix — see the
    index comments inside `_spec`). Cells may carry an inline spec via a
    `"value::spec"` string when `double_colon` is on, or as a
    `(value, "spec")` tuple. Parsing interleaves delimiter columns between
    content columns and recursively expands nested subtables.
    """

    def __init__(
        self,
        table,
        delimiter=" ",
        orientation="br",
        double_colon=True,
    ):
        # `orientation` is the default spec applied to cells without their own.
        self.delimiter = delimiter
        self.orientation = orientation
        self.t = Table.parse(table, delimiter, orientation, double_colon)
        return

    # rendering
    def __str__(self):
        return self.render()

    def __repr__(self):
        return self.render()

    def render(self):
        """Render the parsed cell grid to a single newline-joined string."""
        # set up empty entry
        empty = ("", Table._spec(self.orientation, transpose=False))

        # calculate table size (parse() guarantees a rectangular grid)
        t = copy.deepcopy(self.t)
        totalrows = len(t)
        totalcols = [len(r) for r in t]
        assert min(totalcols) == max(totalcols)
        totalcols = totalcols[0]

        # string-ify; spec slot 11 is an f-string format suffix like ":.2f"
        for i in range(totalrows):
            for j in range(totalcols):
                x, s = t[i][j]
                sp = s[11]
                if sp:
                    # NOTE(review): eval of an f-string built from cell
                    # content — only safe for trusted table data.
                    x = eval(f'f"{{{x}{sp}}}"')
                Table._put((str(x), s), t, (i, j), empty)

        # expand delimiters: _repl rewrites a spec into a plain filler spec
        # (clears delim/direction slots 2-6, sets fill slot 10)
        _repl = (
            lambda s: s[:2] + (1, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:]
            if s[2]
            else s[:2] + (0, 0, 0, 0, 0) + s[7:10] + (1,) + s[11:]
        )
        for i, row in enumerate(t):
            for j, (x, s_own) in enumerate(row):
                # expand delim_up(^): copy this cell upward until another
                # non-fill delimiter is hit
                if s_own[3]:
                    u, v = i, j
                    while 0 <= u:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        u -= 1

                # expand delim_down(v)
                if s_own[4]:
                    u, v = i, j
                    while u < totalrows:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        u += 1

                # expand delim_right(>)
                if s_own[5]:
                    u, v = i, j
                    while v < totalcols:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        v += 1

                # expand delim_left(<)
                if s_own[6]:
                    u, v = i, j
                    while 0 <= v:
                        _, s = t[u][v]
                        if (i, j) != (u, v) and (s[2] and not s[10]):
                            break
                        Table._put((x, _repl(s)), t, (u, v), empty)
                        v -= 1

        # justification calculation: column widths and row heights
        widths = [
            0,
        ] * totalcols  # j
        heights = [
            0,
        ] * totalrows  # i
        for i, row in enumerate(t):
            for j, (x, s) in enumerate(row):
                # height calculation (newline count; +1 applied below)
                heights[i] = max(heights[i], x.count("\n"))

                # width calculation; non-delim fillers no contribution
                if s[2] or not s[10]:
                    w = max(len(q) for q in x.split("\n"))
                    widths[j] = max(widths[j], w)
        # no newline ==> height=1
        heights = [h + 1 for h in heights]

        # render table row-band by row-band into a grid of line fragments
        rend = []
        roff = 0
        for i, row in enumerate(t):
            for j, (x, s) in enumerate(row):
                w, h = widths[j], heights[i]

                # expand fillers and delimiters to exactly (w x h) text
                if s[2] or s[10]:
                    xs = x.split("\n")
                    xw0 = min(len(l) for l in xs)
                    xw1 = max(len(l) for l in xs)
                    xh = len(xs)
                    if (xw0 == xw1 == w) and (xh == h):
                        pass
                    elif xw0 == xw1 == w:
                        # width matches: repeat first line vertically
                        x = "\n".join(
                            [
                                xs[0],
                            ]
                            * h
                        )
                    elif xh == h:
                        # height matches: tile each line's first char across
                        x = "\n".join([(l[0] if l else "") * w for l in xs])
                    else:
                        # neither matches: tile a single char to (w x h)
                        x = x[0] if x else " "
                        x = "\n".join(
                            [
                                x * w,
                            ]
                            * h
                        )

                # justify horizontally (spec slot 0: 0=left, 1=right)
                x = [l.rjust(w) if s[0] else l.ljust(w) for l in x.split("\n")]

                # justify vertically (spec slot 1: 0=bottom, 1=top)
                plus = [
                    " " * w,
                ] * (h - len(x))
                x = plus + x if not s[1] else x + plus

                # input to table
                for r, xline in enumerate(x):
                    Table._put(xline, rend, (roff + r, j), None)
            roff += h

        # return rendered string
        return "\n".join(["".join(r) for r in rend])

    # parsing
    # NOTE: used as a static helper (called as Table._spec with no self).
    def _spec(s, transpose=False):
        """Parse a spec string into the 12-slot tuple documented below."""
        if ":" in s:
            i = s.index(":")
            sp = s[i:]
            s = s[:i]
        else:
            sp = ""
        s = s.lower()
        return (
            int("r" in s),  # 0:: 0:left(l) 1:right(r)
            int("t" in s),  # 1:: 0:bottom(b) 1:top(t)
            int(any([i in s for i in [".", "<", ">", "^", "v"]])),  # 2:: delim_here(.)
            int("^" in s if not transpose else "<" in s),  # 3:: delim_up(^)
            int("v" in s if not transpose else ">" in s),  # 4:: delim_down(v)
            int(">" in s if not transpose else "v" in s),  # 5:: delim_right(>)
            int("<" in s if not transpose else "^" in s),  # 6:: delim_left(<)
            int("+" in s),  # 7:: subtable(+)
            int("-" in s if not transpose else "|" in s),  # 8:: subtable_horiz(-)
            int("|" in s if not transpose else "-" in s),  # 9:: subtable_vert(|)
            int("_" in s),  # 10:: fill(_); if delim, overwrite; else fit
            sp,  # 11:: special(:) f-string for numbers
        )

    # NOTE: static helper; grows the ragged list `t` so that t[i][j] exists,
    # padding with `empty`, then writes `obj` there.
    def _put(obj, t, ij, empty):
        i, j = ij
        while i >= len(t):
            t.append([])
        while j >= len(t[i]):
            t[i].append(empty)
        t[i][j] = obj
        return

    # NOTE: static helper; builds the normalized (value, spec) grid.
    def parse(
        table,
        delimiter=" ",
        orientation="br",
        double_colon=True,
    ):
        # disabling transpose (kept as a constant; _spec still honors it)
        transpose = False

        # set up empty entry
        empty = ("", Table._spec(orientation, transpose))

        # transpose: normalize every cell to a (value, spec) pair
        t = []
        for i, row in enumerate(table):
            for j, item in enumerate(row):
                ij = (i, j) if not transpose else (j, i)
                if type(item) == tuple and len(item) == 2 and type(item[1]) == str:
                    item = (item[0], Table._spec(item[1], transpose))
                elif double_colon and type(item) == str and "::" in item:
                    x, s = item.split("::")
                    item = (x, Table._spec(s, transpose))
                else:
                    item = (item, Table._spec(orientation, transpose))
                Table._put(item, t, ij, empty)

        # normalization: compute the full grid extent
        maxcol = 0
        maxrow = len(t)
        for i, row in enumerate(t):
            # take element number into account (delims excluded via slot 2)
            maxcol = max(maxcol, len([i for i in row if not i[1][2]]))

            # take subtables into account
            for j, (x, s) in enumerate(row):
                if s[7]:
                    r = len(x)
                    maxrow = max(maxrow, i + r)
                    c = max(len(q) for q in x)
                    maxcol = max(maxcol, j + c)
                elif s[8]:
                    c = len(x)
                    maxcol = max(maxcol, j + c)
                elif s[9]:
                    r = len(x)
                    maxrow = max(maxrow, i + r)
        # interleave a delimiter column between (and around) content columns
        totalcols = 2 * maxcol + 1
        totalrows = maxrow
        t += [[]] * (totalrows - len(t))
        newt = []
        delim = (delimiter, Table._spec("._" + orientation, transpose))
        for i, row in enumerate(t):
            # alternate delimiter/content cells; `wasd` = last cell was delim
            wasd = False
            tcount = 0
            for j in range(totalcols):
                item = t[i][tcount] if tcount < len(t[i]) else empty
                isd = item[1][2]
                if wasd and isd:
                    Table._put(empty, newt, (i, j), empty)
                    wasd = False
                elif wasd and not isd:
                    Table._put(item, newt, (i, j), empty)
                    tcount += 1
                    wasd = False
                elif not wasd and isd:
                    Table._put(item, newt, (i, j), empty)
                    tcount += 1
                    wasd = True
                elif not wasd and not isd:
                    Table._put(delim, newt, (i, j), empty)
                    wasd = True
        t = newt

        # normalization: add dummy last column for delimiter
        for row in t:
            row.append(empty)

        # expand subtables until none remain
        delim_cols = [i for i in range(totalcols) if i % 2 == 0]
        while True:
            # find a table
            ij = None
            for i, row in enumerate(t):
                for j, item in enumerate(row):
                    st, s = item
                    if s[7]:
                        ij = i, j, 7, st, s
                        break
                    elif s[8]:
                        ij = i, j, 8, st, s
                        break
                    elif s[9]:
                        ij = i, j, 9, st, s
                        break
                if ij is not None:
                    break
            if ij is None:
                break

            # replace its specs (clear subtable flags so it isn't found again)
            i, j, k, st, s = ij
            s = list(s)
            s[7] = s[8] = s[9] = 0
            s = tuple(s)

            # expand it; column stride 2 skips the interleaved delim columns
            if k == 7:  # 2d table
                for x, row in enumerate(st):
                    for y, obj in enumerate(row):
                        a = i + x if not transpose else i + y
                        b = j + 2 * y if not transpose else j + 2 * x
                        Table._put((obj, s), t, (a, b), None)
            if k == 8:  # subtable_horiz
                for y, obj in enumerate(st):
                    Table._put((obj, s), t, (i, j + 2 * y), None)
            if k == 9:  # subtable_vert
                for x, obj in enumerate(st):
                    Table._put((obj, s), t, (i + x, j), None)

        # return, finally
        return t
|
| 2141 |
+
|
| 2142 |
+
|
| 2143 |
+
class Resnet(nn.Module):
    """Pre-activation residual block: x + (PReLU-Conv-BN) applied twice."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        layers = [
            nn.PReLU(channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(channels),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Residual forward pass; output has the same shape as the input."""
        residual = self.net(x)
        return x + residual
|
| 2159 |
+
|
| 2160 |
+
|
| 2161 |
+
class Synthesizer(nn.Module):
    """Final frame synthesizer: refines an averaged/warped guess with a CNN.

    Inputs are interpolated to `size` (presumably (H, W) — the external
    `Interpolator`/`diam` helpers define the convention; confirm), converted
    to logit space, concatenated with flow magnitudes, masks and features,
    and passed through a small Resnet stack whose residual is added to the
    first image's logits before a sigmoid back to pixel space.
    """

    def __init__(
        self, size, channels_image, channels_flow, channels_mask, channels_feature
    ):
        super().__init__()
        self.size = size
        self.diam = diam(self.size)  # normalizer for flow magnitudes
        self.channels_image = cimg = channels_image
        self.channels_flow = cflow = channels_flow
        self.channels_mask = cmask = channels_mask
        self.channels_feature = cfeat = channels_feature
        # cflow // 2 because each 2-channel flow contributes one magnitude map
        self.channels = ch = cimg + cflow // 2 + cmask + cfeat
        self.interpolator = Interpolator(self.size, mode="bilinear")
        # ch + 3 input channels: the extra 3 are the averaged image prepended
        # in forward()
        self.net = nn.Sequential(
            nn.Conv2d(ch + 3, 64, kernel_size=1, padding=0),
            Resnet(64),
            nn.Sequential(
                nn.PReLU(64),
                nn.Conv2d(64, 32, kernel_size=3, padding=1),
                nn.BatchNorm2d(32),
            ),
            Resnet(32),
            nn.Sequential(
                nn.PReLU(32),
                nn.Conv2d(32, 16, kernel_size=3, padding=1),
                nn.BatchNorm2d(16),
            ),
            Resnet(16),
            nn.Sequential(
                nn.PReLU(16),
                nn.Conv2d(16, 3, kernel_size=3, padding=1),
            ),
        )
        return

    def forward(self, images, flows, masks, features, return_more=False):
        """Fuse warped images/flows/masks/features into one RGB prediction.

        Returns (prediction, locals()-or-None); the second slot exposes
        intermediates for debugging when return_more is True.
        """
        itp = self.interpolator
        # prepend the average of the first two images as the base guess
        images = [
            (images[0] + images[1]) / 2,
        ] + images
        logimgs = [itp(pixel_logit(i[:, :3])) for i in images]
        cat = torch.cat(
            [
                *logimgs,
                # per-flow magnitude, normalized by the frame diameter
                *[itp(f).norm(dim=1, keepdim=True) / self.diam for f in flows],
                *[itp(m) for m in masks],
                *[itp(f) for f in features],
            ],
            dim=1,
        )
        residual = self.net(cat)
        # residual is added in logit space to the averaged base guess
        return torch.sigmoid(logimgs[0] + 0.5 * residual), (
            locals() if return_more else None
        )
|
| 2215 |
+
|
| 2216 |
+
|
| 2217 |
+
class FlowZMetric(nn.Module):
    """Softmax-splatting importance metric from backwarp photoconsistency.

    The z value for each flow is the (negated, scaled) LAB-space error of
    backwarping one image onto the other — lower error means higher
    importance when forward-splatting.
    """

    def __init__(self):
        super().__init__()
        return

    def forward(self, img0, img1, flow0, flow1, return_more=False):
        # B(i0,f0) = i1
        # B(i1,f1) = i0
        # F(x,f0,z0)
        # F(x,f1,z1)
        # Compare in LAB space for a more perceptual error measure.
        img0 = kornia.color.rgb_to_lab(img0[:, :3])
        img1 = kornia.color.rgb_to_lab(img1[:, :3])
        # -0.1 scaling turns the error into a splatting weight exponent.
        return [
            -0.1 * (img1 - flow_backwarp(img0, flow0)).norm(dim=1, keepdim=True),  # z0
            -0.1 * (img0 - flow_backwarp(img1, flow1)).norm(dim=1, keepdim=True),  # z1
        ], (locals() if return_more else None)
|
| 2233 |
+
|
| 2234 |
+
|
| 2235 |
+
class NEDT(nn.Module):
    """Normalized Euclidean Distance Transform of difference-of-Gaussians edges.

    Produces a soft line/edge proximity map in [0, 1): 0 on detected edges,
    approaching 1 far from them. Runs entirely under no_grad; it is a fixed
    (non-learned) transform.
    """

    def __init__(self):
        super().__init__()
        return

    def forward(
        self,
        img,
        t=2.0,
        sigma_factor=1 / 540,
        k=1.6,
        epsilon=0.01,
        kernel_factor=4,
        exp_factor=540 / 15,
        return_more=False,
    ):
        # sigma scales with image height so the edge detector is
        # resolution-invariant (constants tuned for 540p — see defaults).
        with torch.no_grad():
            dog = batch_dog(
                img,
                t=t,
                sigma=img.shape[-2] * sigma_factor,
                k=k,
                epsilon=epsilon,
                kernel_factor=kernel_factor,
                clip=False,
            )
            # binarize DoG response, then distance-transform the edge mask
            edt = batch_edt((dog > 0.5).float())
            # exponential falloff, normalized by the larger spatial dim
            ans = 1 - (-edt * exp_factor / max(edt.shape[-2:])).exp()
        return ans, (locals() if return_more else None)
|
| 2264 |
+
|
| 2265 |
+
|
| 2266 |
+
class HalfWarper(nn.Module):
    """Forward-splat both endpoint frames to an intermediate time t.

    Produces two complementary "base guess" composites (each preferring one
    splat and filling holes from the other), the scaled half-flows, and the
    morphologically-opened validity masks. Channel counts assume 4-channel
    inputs (RGB + edge map appended by the caller).
    """

    def __init__(self):
        super().__init__()
        self.channels_image = 4 * 3
        self.channels_flow = 2 * 2
        self.channels_mask = 2 * 1
        self.channels = self.channels_image + self.channels_flow + self.channels_mask

    def morph_open(self, x, k):
        """Morphological opening with a k x k ones kernel; identity for k=0."""
        if k == 0:
            return x
        else:
            with torch.no_grad():
                return kornia.morphology.opening(x, torch.ones(k, k, device=x.device))

    def forward(self, img0, img1, flow0, flow1, z0, z1, k, t=0.5, return_more=False):
        # forewarps: scale each flow to the intermediate time t
        flow0_ = (1 - t) * flow0
        flow1_ = t * flow1
        # NOTE(review): img0 is splatted with flow1_/z1 and img1 with
        # flow0_/z0 — the crossed pairing appears intentional (matches the
        # f01/f10 naming) but confirm against the trained weights.
        f01 = forewarp(img0, flow1_, mode="sm", metric=z1, mask=True)
        f10 = forewarp(img1, flow0_, mode="sm", metric=z0, mask=True)
        # last channel of each splat is the coverage mask; open it to
        # suppress speckle holes
        f01i, f01m = f01[:, :-1], self.morph_open(f01[:, -1:], k=k)
        f10i, f10m = f10[:, :-1], self.morph_open(f10[:, -1:], k=k)

        # base guess: each composite trusts one splat where its mask is
        # valid and falls back to the other elsewhere
        base0 = f01m * f01i + (1 - f01m) * f10i
        base1 = f10m * f10i + (1 - f10m) * f01i
        ans = [
            [  # images
                base0,
                base1,
                f01i,
                f10i,
            ],
            [  # flows
                flow0_,
                flow1_,
            ],
            [  # masks
                f01m,
                f10m,
            ],
        ]
        return ans, (locals() if return_more else None)
|
| 2310 |
+
|
| 2311 |
+
|
| 2312 |
+
class ResnetFeatureExtractor(nn.Module):
    """Multi-scale ResNet-50 feature pyramid (conv1 / layer1 / layer2 taps).

    `inferserve_query` selects the backbone: ("torchvision", ...) loads the
    ImageNet-pretrained torchvision resnet50; anything else is resolved
    through the project's `userving` model server. `size_in` (H, W), when
    given, precomputes the three output sizes so dependent modules can be
    built before the first forward pass.
    """

    def __init__(self, inferserve_query, size_in=None):
        super().__init__()
        self.inferserve_query = iq = inferserve_query
        self.size_in = si = size_in
        if iq[0] == "torchvision":
            # use pytorch pretrained resnet50
            # NOTE(review): `pretrained=True` is deprecated in newer
            # torchvision (use `weights=`); left as-is for compatibility.
            self.base_hparams = None
            resnet = tv.models.resnet50(pretrained=True)

            self.resize = T.Resize(256)
            # standard ImageNet normalization
            self.resnet_preprocess = T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
            self.conv1 = resnet.conv1
            self.bn1 = resnet.bn1
            self.relu = resnet.relu  # 64ch, 128p (assuming 256p input)
            self.maxpool = resnet.maxpool
            self.layer1 = resnet.layer1  # 256ch, 64p
            self.layer2 = resnet.layer2  # 512ch, 32p
        else:
            base = userving.infer_model_load(*iq).eval()
            self.base_hparams = base.hparams

            self.resize = T.Resize(base.hparams.largs.size)
            self.resnet_preprocess = base.resnet_preprocess
            self.conv1 = base.resnet.conv1
            self.bn1 = base.resnet.bn1
            self.relu = base.resnet.relu  # 64ch, 128p (assuming 256p input)
            self.maxpool = base.resnet.maxpool
            self.layer1 = base.resnet.layer1  # 256ch, 64p
            self.layer2 = base.resnet.layer2  # 512ch, 32p
        if self.size_in is None:
            # unknown until the first forward pass fills it in
            self.sizes_out = None
        else:
            s = self.resize.size
            # per-tap output sizes at 1/2, 1/4, 1/8 of the resized height
            self.sizes_out = [
                pixel_ij(
                    rescale_dry(si, (s // 2) / si[0]), rounding="ceil"
                ),  # conv1, 128p
                pixel_ij(
                    rescale_dry(si, (s // 4) / si[0]), rounding="ceil"
                ),  # layer1, 64p
                pixel_ij(
                    rescale_dry(si, (s // 8) / si[0]), rounding="ceil"
                ),  # layer2, 32p
            ]
        # channel counts of the three taps, in order
        self.channels = [
            64,
            256,
            512,
        ]
        return

    def forward(self, x, force_sizes_out=False, return_more=False):
        """Return [conv1, layer1, layer2] feature maps for RGB input x."""
        ans = []
        x = x[:, :3]  # drop any extra channels (e.g. appended edge map)
        x = self.resize(x)
        x = self.resnet_preprocess(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        ans.append(x)  # conv1
        x = self.maxpool(x)
        x = self.layer1(x)
        ans.append(x)  # layer1
        x = self.layer2(x)
        ans.append(x)  # layer2
        # lazily record (or refresh) actual output sizes
        if force_sizes_out or (self.sizes_out is None):
            self.sizes_out = [tuple(q.shape[-2:]) for q in ans]
        return ans, (locals() if return_more else None)
|
| 2384 |
+
|
| 2385 |
+
|
| 2386 |
+
class NetNedt(nn.Module):
    """Predict the intermediate frame's edge-distance (NEDT) map.

    Small conv net over the concatenated base prediction, its NEDT, the two
    half-warped 4-channel images, and their masks; operates in logit space
    and squashes back with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # input channels: base RGB(3) + base NEDT(1) + 2 warped imgs(4+4)
        # + 2 masks(1+1)
        chin = 3 + 1 + 4 + 4 + 1 + 1
        ch = 16
        chout = 1
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, chout, kernel_size=3, padding=1),
        )
        return

    def forward(self, out_base, out_base_nedt, hw_imgs, hw_masks, return_more=False):
        cat = torch.cat(
            [
                out_base,  # 3
                out_base_nedt,  # 1
                hw_imgs[0],  # 4
                hw_imgs[1],  # 4
                hw_masks[0],  # 1
                hw_masks[1],  # 1
            ],
            dim=1,
        )
        # clip to [0,1] before logit to avoid infinities at the extremes
        log = pixel_logit(cat.clip(0, 1))
        ans = torch.sigmoid(self.net(log))
        return ans, (locals() if return_more else None)
|
| 2419 |
+
|
| 2420 |
+
|
| 2421 |
+
class NetTail(nn.Module):
    """Refine the base RGB prediction using the predicted NEDT map.

    Learns a logit-space residual on top of the base frame, conditioned on
    the base's own NEDT and the NEDT predicted by NetNedt.
    """

    def __init__(self):
        super().__init__()
        # input channels: base RGB(3) + base NEDT(1) + predicted NEDT(1)
        chin = 3 + 1 + 1
        ch = 16
        chout = 3
        self.net = nn.Sequential(
            nn.PReLU(chin),
            nn.Conv2d(chin, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch),
            nn.PReLU(ch),
            nn.Conv2d(ch, chout, kernel_size=3, padding=1),
        )
        return

    def forward(self, out_base, out_base_nedt, pred_nedt, return_more=False):
        cat = torch.cat(
            [
                out_base,  # 3
                out_base_nedt,  # 1
                pred_nedt,  # 1
            ],
            dim=1,
        )
        # clip to [0,1] before logit to avoid infinities at the extremes
        log = pixel_logit(cat.clip(0, 1))
        # residual refinement of the base RGB (first 3 logit channels)
        ans = torch.sigmoid(log[:, :3] + self.net(log))
        return ans, (locals() if return_more else None)
|
| 2454 |
+
|
| 2455 |
+
|
| 2456 |
+
class SoftsplatLite(nn.Module):
    """EISAI base interpolator: softmax-splatting + gridnet feature fusion.

    Pipeline (hard-coded for 540x960 frames): compute z-importance from
    photoconsistency, append NEDT edge maps to both frames, half-warp
    images and multi-scale ResNet features to time t, fuse features through
    a gridnet, and synthesize the output frame.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = ResnetFeatureExtractor(
            ("torchvision", "resnet50"),
            (540, 960),
        )
        self.z_metric = FlowZMetric()
        # one bilinear downsampler per feature-pyramid level
        # NOTE(review): a plain list (not nn.ModuleList) — fine here since
        # Interpolator presumably has no parameters; confirm.
        self.flow_downsamplers = [
            Interpolator(s, mode="bilinear") for s in self.feature_extractor.sizes_out
        ]
        self.gridnet_converter = GridnetConverter(
            self.feature_extractor.channels,
            [32, 64, 128],
        )
        self.gridnet = Gridnet(
            *[32, 64, 128],
            total_dropout_p=0.0,
            depth=1,  # equivalent to u-net
        )
        self.nedt = NEDT()
        self.half_warper = HalfWarper()
        self.synthesizer = Synthesizer(
            (540, 960),
            self.half_warper.channels_image,
            self.half_warper.channels_flow,
            self.half_warper.channels_mask,
            self.gridnet.channels_0,
        )
        return

    def forward(self, x, t=0.5, k=5, return_more=False):
        """Interpolate to time t.

        x is a dict with "images" (endpoints at indices 0 and -1) and
        "flows" (both directions stacked on dim 1). k is the morphological
        opening kernel for splat masks.
        """
        rm = return_more
        flow0, flow1 = x["flows"].swapaxes(0, 1)
        img0, img1 = x["images"][:, 0], x["images"][:, -1]
        # splatting importance from backwarp photoconsistency
        (z0, z1), locs_z = self.z_metric(img0, img1, flow0, flow1, return_more=rm)
        # append NEDT edge channel to each frame (RGB -> RGBE)
        img0 = torch.cat([img0, self.nedt(img0)[0]], dim=1)
        img1 = torch.cat([img1, self.nedt(img1)[0]], dim=1)

        # images and flows
        (hw_imgs, hw_flows, hw_masks), locs_hw = self.half_warper(
            img0,
            img1,
            flow0,
            flow1,
            z0,
            z1,
            k,
            t=t,
            return_more=rm,
        )

        # features: warp each pyramid level with correspondingly
        # downsampled flows/metrics, then average the two composites
        feats0, locs_fe0 = self.feature_extractor(img0, return_more=rm)
        feats1, locs_fe1 = self.feature_extractor(img1, return_more=rm)
        warps = []
        for ft0, ft1, ds in zip(feats0, feats1, self.flow_downsamplers):
            (w, _, _), _ = self.half_warper(
                ft0,
                ft1,
                ds(flow0, 1),
                ds(flow1, 1),
                ds(z0),
                ds(z1),
                k,
                t=t,
            )
            warps.append((w[0] + w[1]) / 2)
        feats = self.gridnet(self.gridnet_converter(warps))

        # synthesis: only the finest gridnet level feeds the synthesizer
        pred, locs_synth = self.synthesizer(
            hw_imgs,
            hw_flows,
            hw_masks,
            [
                feats[0],
            ],
            return_more=rm,
        )
        return pred, (locals() if rm else None)
|
| 2537 |
+
|
| 2538 |
+
|
| 2539 |
+
class DTM(nn.Module):
    """Distance-Transform Module: post-refinement stage over SoftsplatLite.

    Predicts an NEDT map for the intermediate frame, then refines the base
    RGB with it; returns RGB and predicted NEDT concatenated on dim 1.
    """

    def __init__(self):
        super().__init__()
        self.net_nedt = NetNedt()
        self.net_tail = NetTail()
        self.nedt = NEDT()
        return

    def forward(self, x, out_base, locs_base, return_more=False):
        """Refine out_base using intermediates captured in locs_base.

        locs_base must be the locals() dict returned by SoftsplatLite's
        forward (needs "hw_imgs" and "hw_masks").
        """
        rm = return_more
        # NEDT of the base prediction is a fixed transform; no grad needed
        with torch.no_grad():
            out_base_nedt, locs_base_nedt = self.nedt(out_base, return_more=rm)
        hw_imgs, hw_masks = locs_base["hw_imgs"], locs_base["hw_masks"]
        pred_nedt, locs_nedt = self.net_nedt(
            out_base, out_base_nedt, hw_imgs, hw_masks, return_more=rm
        )
        # detach so the tail does not backprop into the NEDT predictor
        pred, locs_tail = self.net_tail(
            out_base, out_base_nedt, pred_nedt.clone().detach(), return_more=rm
        )
        return torch.cat([pred, pred_nedt], dim=1), (locals() if rm else None)
|
| 2559 |
+
|
| 2560 |
+
|
| 2561 |
+
class RAFT(nn.Module):
    """Wrapper around the RFR optical-flow network with AnimeInterp weights.

    Loads only the `module.flownet.*` subset of the checkpoint's state dict
    (strict=False, since the checkpoint also contains unrelated modules).
    """

    def __init__(self, path="/workspace/tensorrt/models/anime_interp_full.ckpt"):
        super().__init__()
        self.raft = RFR(
            Namespace(
                small=False,
                mixed_precision=False,
            )
        )
        if path is not None:
            # map_location="cpu" so a checkpoint saved on CUDA also loads on
            # CPU-only machines; load_state_dict copies the weights onto the
            # module's own device afterwards, so this is always safe.
            sd = torch.load(path, map_location="cpu")["model_state_dict"]
            self.raft.load_state_dict(
                {
                    k[len("module.flownet.") :]: v
                    for k, v in sd.items()
                    if k.startswith("module.flownet.")
                },
                strict=False,
            )
        return

    def forward(self, img0, img1, flow0=None, iters=12, return_more=False):
        """Estimate flow; note img1/img0 are passed to RFR in swapped order.

        Flows are flipped along the channel dim on the way in and out —
        NOTE(review): presumably converting between (x, y) and (y, x)
        channel conventions; confirm against RFR's documentation.
        """
        if flow0 is not None:
            flow0 = flow0.flip(dims=(1,))
        out = self.raft(img1, img0, iters=iters, flow_init=flow0)
        return out[0].flip(dims=(1,)), (locals() if return_more else None)
|
vfi_models/film/__init__.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from comfy.model_management import get_torch_device, soft_empty_cache
|
| 3 |
+
import bisect
|
| 4 |
+
import numpy as np
|
| 5 |
+
import typing
|
| 6 |
+
from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames
|
| 7 |
+
import pathlib
|
| 8 |
+
import gc
|
| 9 |
+
|
| 10 |
+
MODEL_TYPE = pathlib.Path(__file__).parent.name
|
| 11 |
+
DEVICE = get_torch_device()
|
| 12 |
+
def inference(model, img_batch_1, img_batch_2, inter_frames):
    """Generate `inter_frames` FILM-interpolated frames between two batches.

    Midpoint-first recursion: each step picks the pending timestamp whose
    normalized position within its current bracketing interval is closest
    to 0.5, runs the model on that interval's endpoints, and inserts the
    prediction, splitting the interval. Returns inter_frames + 2 tensors
    (endpoints included), in temporal order.

    NOTE(review): the final `tensor.flip(0)` reverses the batch dimension;
    for the batch-size-1 tensors this is called with it is a no-op —
    confirm the intended axis.
    """
    results = [
        img_batch_1,
        img_batch_2
    ]

    # idxes: timestamps already materialized; remains: still to generate
    idxes = [0, inter_frames + 1]
    remains = list(range(1, inter_frames + 1))

    # uniform timestamps in [0, 1] for all frames
    splits = torch.linspace(0, 1, inter_frames + 2)

    for _ in range(len(remains)):
        starts = splits[idxes[:-1]]
        ends = splits[idxes[1:]]
        # distance of every pending timestamp from the midpoint of every
        # current interval, normalized per-interval
        distances = ((splits[None, remains] - starts[:, None]) / (ends[:, None] - starts[:, None]) - .5).abs()
        # flattened argmin -> (interval index, pending-timestamp index)
        matrix = torch.argmin(distances).item()
        start_i, step = np.unravel_index(matrix, distances.shape)
        end_i = start_i + 1

        x0 = results[start_i].to(DEVICE)
        x1 = results[end_i].to(DEVICE)
        # fractional time of the new frame within its bracketing interval
        dt = x0.new_full((1, 1), (splits[remains[step]] - splits[idxes[start_i]])) / (splits[idxes[end_i]] - splits[idxes[start_i]])

        with torch.no_grad():
            prediction = model(x0, x1, dt)
        # insert prediction keeping idxes/results sorted by timestamp
        insert_position = bisect.bisect_left(idxes, remains[step])
        idxes.insert(insert_position, remains[step])
        results.insert(insert_position, prediction.clamp(0, 1).float())
        del remains[step]

    return [tensor.flip(0) for tensor in results]
|
| 43 |
+
|
| 44 |
+
class FILM_VFI:
    """ComfyUI node wrapping Google's FILM (TorchScript) frame interpolator."""

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI node-input schema
        return {
            "required": {
                "ckpt_name": (["film_net_fp32.pt"], ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Interpolate `frames` so each original pair yields `multiplier` frames.

        `multiplier` may also be a per-pair sequence of ints; pairs flagged
        by `optional_interpolation_states` are skipped. Returns a 1-tuple
        with the interpolated IMAGE batch, as ComfyUI expects.
        """
        interpolation_states = optional_interpolation_states
        # download-on-demand TorchScript checkpoint, loaded on CPU first
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        model = model.to(DEVICE)
        dtype = torch.float32

        frames = preprocess_frames(frames)
        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []

        if type(multiplier) == int:
            multipliers = [multiplier] * len(frames)
        else:
            # per-pair multipliers; pad missing entries with the default 2
            multipliers = list(map(int, multiplier))
            multipliers += [2] * (len(frames) - len(multipliers) - 1)
        for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it
            if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr):
                continue
            #Ensure that input frames are in fp32 - the same dtype as model
            frame_0 = frames[frame_itr:frame_itr+1].to(DEVICE).float()
            frame_1 = frames[frame_itr+1:frame_itr+2].to(DEVICE).float()
            # multiplier - 1 in-between frames; drop the pair's trailing
            # endpoint since the next iteration re-emits it
            relust = inference(model, frame_0, frame_1, multipliers[frame_itr] - 1)
            output_frames.extend([frame.detach().cpu().to(dtype=dtype) for frame in relust[:-1]])

            number_of_frames_processed_since_last_cleared_cuda_cache += 1
            # Try to avoid a memory overflow by clearing cuda cache regularly
            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")
            gc.collect()

        output_frames.append(frames[-1:].to(dtype=dtype)) # Append final frame
        output_frames = [frame.cpu() for frame in output_frames] #Ensure all frames are in cpu
        out = torch.cat(output_frames, dim=0)
        # clear cache for courtesy
        print("Comfy-VFI: Final clearing cache...", end = ' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )
|
vfi_models/film/film_arch.py
ADDED
|
@@ -0,0 +1,798 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/dajes/frame-interpolation-pytorch/blob/main/feature_extractor.py
|
| 3 |
+
https://github.com/dajes/frame-interpolation-pytorch/blob/main/fusion.py
|
| 4 |
+
https://github.com/dajes/frame-interpolation-pytorch/blob/main/interpolator.py
|
| 5 |
+
https://github.com/dajes/frame-interpolation-pytorch/blob/main/pyramid_flow_estimator.py
|
| 6 |
+
https://github.com/dajes/frame-interpolation-pytorch/blob/main/util.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
"""PyTorch layer for extracting image features for the film_net interpolator.
|
| 10 |
+
|
| 11 |
+
The feature extractor implemented here converts an image pyramid into a pyramid
|
| 12 |
+
of deep features. The feature pyramid serves a similar purpose as U-Net
|
| 13 |
+
architecture's encoder, but we use a special cascaded architecture described in
|
| 14 |
+
Multi-view Image Fusion [1].
|
| 15 |
+
|
| 16 |
+
For comprehensiveness, below is a short description of the idea. While the
|
| 17 |
+
description is a bit involved, the cascaded feature pyramid can be used just
|
| 18 |
+
like any image feature pyramid.
|
| 19 |
+
|
| 20 |
+
Why cascaded architecture?
|
| 21 |
+
=========================
|
| 22 |
+
To understand the concept it is worth reviewing a traditional feature pyramid
|
| 23 |
+
first: *A traditional feature pyramid* as in U-net or in many optical flow
|
| 24 |
+
networks is built by alternating between convolutions and pooling, starting
|
| 25 |
+
from the input image.
|
| 26 |
+
|
| 27 |
+
It is well known that early features of such architecture correspond to low
|
| 28 |
+
level concepts such as edges in the image whereas later layers extract
|
| 29 |
+
semantically higher level concepts such as object classes etc. In other words,
|
| 30 |
+
the meaning of the filters in each resolution level is different. For problems
|
| 31 |
+
such as semantic segmentation and many others this is a desirable property.
|
| 32 |
+
|
| 33 |
+
However, the asymmetric features preclude sharing weights across resolution
|
| 34 |
+
levels in the feature extractor itself and in any subsequent neural networks
|
| 35 |
+
that follow. This can be a downside, since optical flow prediction, for
|
| 36 |
+
instance is symmetric across resolution levels. The cascaded feature
|
| 37 |
+
architecture addresses this shortcoming.
|
| 38 |
+
|
| 39 |
+
How is it built?
|
| 40 |
+
================
|
| 41 |
+
The *cascaded* feature pyramid contains feature vectors that have constant
|
| 42 |
+
length and meaning on each resolution level, except few of the finest ones. The
|
| 43 |
+
advantage of this is that the subsequent optical flow layer can learn
|
| 44 |
+
synergically from many resolutions. This means that coarse level prediction can
|
| 45 |
+
benefit from finer resolution training examples, which can be useful with
|
| 46 |
+
moderately sized datasets to avoid overfitting.
|
| 47 |
+
|
| 48 |
+
The cascaded feature pyramid is built by extracting shallower subtree pyramids,
|
| 49 |
+
each one of them similar to the traditional architecture. Each subtree
|
| 50 |
+
pyramid S_i is extracted starting from each resolution level:
|
| 51 |
+
|
| 52 |
+
image resolution 0 -> S_0
|
| 53 |
+
image resolution 1 -> S_1
|
| 54 |
+
image resolution 2 -> S_2
|
| 55 |
+
...
|
| 56 |
+
|
| 57 |
+
If we denote the features at level j of subtree i as S_i_j, the cascaded pyramid
|
| 58 |
+
is constructed by concatenating features as follows (assuming subtree depth=3):
|
| 59 |
+
|
| 60 |
+
lvl
|
| 61 |
+
feat_0 = concat( S_0_0 )
|
| 62 |
+
feat_1 = concat( S_1_0 S_0_1 )
|
| 63 |
+
feat_2 = concat( S_2_0 S_1_1 S_0_2 )
|
| 64 |
+
feat_3 = concat( S_3_0 S_2_1 S_1_2 )
|
| 65 |
+
feat_4 = concat( S_4_0 S_3_1 S_2_2 )
|
| 66 |
+
feat_5 = concat( S_5_0 S_4_1 S_3_2 )
|
| 67 |
+
....
|
| 68 |
+
|
| 69 |
+
In above, all levels except feat_0 and feat_1 have the same number of features
|
| 70 |
+
with similar semantic meaning. This enables training a single optical flow
|
| 71 |
+
predictor module shared by levels 2,3,4,5... . For more details and evaluation
|
| 72 |
+
see [1].
|
| 73 |
+
|
| 74 |
+
[1] Multi-view Image Fusion, Trinidad et al. 2019
|
| 75 |
+
"""
|
| 76 |
+
from typing import List
|
| 77 |
+
|
| 78 |
+
import torch
|
| 79 |
+
from torch import nn
|
| 80 |
+
from torch.nn import functional as F
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SubTreeExtractor(nn.Module):
    """Conventional hierarchical feature extractor for a single pyramid image.

    Runs ``n_layers`` stages of two 3x3 convolutions each, doubling the
    channel count per stage ([k, 2k, 4k, ...] for k=channels) with 2x2
    average pooling between stages.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4):
        super().__init__()
        stages = []
        width_in = in_channels
        for level in range(n_layers):
            width_out = channels << level
            stages.append(nn.Sequential(
                conv(width_in, width_out, 3),
                conv(width_out, width_out, 3),
            ))
            width_in = width_out
        # Attribute name kept as 'convs' so checkpoint keys stay compatible.
        self.convs = nn.ModuleList(stages)

    def forward(self, image: torch.Tensor, n: int) -> List[torch.Tensor]:
        """Extracts a pyramid of features from the image.

        Args:
            image: tensor of shape (B, C, H, W) (channels-first; pooling below
                uses ``F.avg_pool2d``).
            n: number of levels after which downsampling stops. Note that all
                stages still run and contribute an output; only the pooling
                between stages is capped at level n-1.

        Returns:
            The per-stage feature outputs, finest level first. Each element is
            the result of the last convolution on that pyramid level.
        """
        outputs: List[torch.Tensor] = []
        current = image
        last_pooled = n - 1
        for index, stage in enumerate(self.convs):
            current = stage(current)
            outputs.append(current)
            # Halve the resolution between stages, but never beyond level n-1.
            if index < last_pooled:
                current = F.avg_pool2d(current, kernel_size=2, stride=2)
        return outputs
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class FeatureExtractor(nn.Module):
    """Builds a cascaded feature pyramid from an image pyramid.

    Every level of the image pyramid is fed through one shared
    SubTreeExtractor (weights are shared across resolutions), and the
    resulting sub-pyramids are concatenated diagonally so that beyond the
    first few levels each output carries feature vectors with identical
    length and semantics (Multi-view Image Fusion, Trinidad et al. 2019).
    """

    def __init__(self, in_channels=3, channels=64, sub_levels=4):
        super().__init__()
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
        """Extracts a cascaded feature pyramid.

        Args:
            image_pyramid: image pyramid as a list, finest level first.

        Returns:
            The cascaded feature pyramid, finest level first.
        """
        n_levels = len(image_pyramid)

        # One shared-weight sub-pyramid per image resolution. The depth is
        # capped so no sub-level reaches past the coarsest level of the
        # cascaded pyramid we are going to produce.
        sub_pyramids: List[List[torch.Tensor]] = [
            self.extract_sublevels(image_pyramid[level],
                                   min(n_levels - level, self.sub_levels))
            for level in range(n_levels)
        ]

        # Diagonal concatenation: output level i gathers sub-pyramid S_{i-j}
        # at depth j (see the layout example in the module docstring).
        feature_pyramid: List[torch.Tensor] = []
        for level in range(n_levels):
            parts = [sub_pyramids[level][0]]
            for depth in range(1, self.sub_levels):
                if depth <= level:
                    parts.append(sub_pyramids[level - depth][depth])
            feature_pyramid.append(
                parts[0] if len(parts) == 1 else torch.cat(parts, dim=1))
        return feature_pyramid
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
"""The final fusion stage for the film_net frame interpolator.
|
| 175 |
+
|
| 176 |
+
The inputs to this module are the warped input images, image features and
|
| 177 |
+
flow fields, all aligned to the target frame (often midway point between the
|
| 178 |
+
two original inputs). The output is the final image. FILM has no explicit
|
| 179 |
+
occlusion handling -- instead using the abovementioned information this module
|
| 180 |
+
automatically decides how to best blend the inputs together to produce content
|
| 181 |
+
in areas where the pixels can only be borrowed from one of the inputs.
|
| 182 |
+
|
| 183 |
+
Similarly, this module also decides on how much to blend in each input in case
|
| 184 |
+
of fractional timestep that is not at the halfway point. For example, if the two
|
| 185 |
+
inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1,
|
| 186 |
+
it often makes most sense to favor the first input. However, this is not
|
| 187 |
+
always the case -- in particular in occluded pixels.
|
| 188 |
+
|
| 189 |
+
The architecture of the Fusion module follows U-net [1] architecture's decoder
|
| 190 |
+
side, e.g. each pyramid level consists of concatenation with upsampled coarser
|
| 191 |
+
level output, and two 3x3 convolutions.
|
| 192 |
+
|
| 193 |
+
The upsampling is implemented as 'resize convolution', e.g. nearest neighbor
|
| 194 |
+
upsampling followed by 2x2 convolution as explained in [2]. The classic U-net
|
| 195 |
+
uses max-pooling which has a tendency to create checkerboard artifacts.
|
| 196 |
+
|
| 197 |
+
[1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image
|
| 198 |
+
Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf
|
| 199 |
+
[2] https://distill.pub/2016/deconv-checkerboard/
|
| 200 |
+
"""
|
| 201 |
+
from typing import List
|
| 202 |
+
|
| 203 |
+
import torch
|
| 204 |
+
from torch import nn
|
| 205 |
+
from torch.nn import functional as F
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# The warped inputs are RGB images.
_NUMBER_OF_COLOR_CHANNELS = 3


def get_channels_at_level(level, filters):
    """Returns the channel count of the aligned feature tensor at a pyramid level.

    Each of the two warped inputs contributes its cascaded features
    (filters * (2**level - 1) channels in total), its RGB image, and a
    two-channel optical-flow field.
    """
    per_image = _NUMBER_OF_COLOR_CHANNELS + 2  # RGB + (dx, dy) flow
    for i in range(level):
        per_image += filters << i  # cascaded feature widths: k, 2k, 4k, ...
    return 2 * per_image  # two warped inputs
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class Fusion(nn.Module):
    """The decoder: fuses the aligned feature pyramid into one RGB image.

    Follows a U-Net decoder layout: at each level the coarser result is
    nearest-neighbor upsampled, passed through a 2x2 'resize convolution',
    concatenated with that level's pyramid features, and refined by two 3x3
    convolutions. A final 1x1 convolution maps to RGB.
    """

    def __init__(self, n_layers=4, specialized_layers=3, filters=64):
        """
        Args:
            n_layers: number of decoder pyramid levels to build convolutions for.
            specialized_layers: levels below this index double their filter
                count per level (filters << i); at and above it the count is
                capped at filters << specialized_layers.
            filters: base filter count at the finest level.
        """
        super().__init__()

        # The final convolution that outputs RGB:
        self.output_conv = nn.Conv2d(filters, 3, kernel_size=1)

        # Each item 'convs[i]' will contain the list of convolutions to be applied
        # for pyramid level 'i'.
        self.convs = nn.ModuleList()

        # Create the convolutions. Roughly following the feature extractor, we
        # double the number of filters when the resolution halves, but only up to
        # the specialized_levels, after which we use the same number of filters on
        # all levels.
        #
        # We create the convs in fine-to-coarse order, so that the array index
        # for the convs will correspond to our normal indexing (0=finest level).
        # in_channels: tuple = (128, 202, 256, 522, 512, 1162, 1930, 2442)

        in_channels = get_channels_at_level(n_layers, filters)
        increase = 0
        for i in range(n_layers)[::-1]:
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            # convs[0]: the 2x2 'resize convolution' applied after the nearest
            #   upsampling in forward().
            # convs[1]: runs on the concatenation of this level's pyramid
            #   features with the upsampled result. On the very first (coarsest)
            #   iteration 'increase' is still 0, so `increase or num_filters`
            #   falls back to num_filters; afterwards it carries the channel
            #   delta computed below for the coarser level.
            convs = nn.ModuleList([
                conv(in_channels, num_filters, size=2, activation=None),
                conv(in_channels + (increase or num_filters), num_filters, size=3),
                conv(num_filters, num_filters, size=3)]
            )
            self.convs.append(convs)
            in_channels = num_filters
            # NOTE(review): channel bookkeeping for the next (finer) iteration;
            # presumably matches the widths of the aligned pyramid produced by
            # the interpolator -- verify against that construction.
            increase = get_channels_at_level(i, filters) - num_filters // 2

    def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
        """Runs the fusion module.

        Args:
            pyramid: the input feature pyramid as a list of channels-first
                (B, C, H, W) tensors, finest level first (concatenation below
                is on dim=1 and spatial size is read from shape[2:4]).

        Returns:
            A batch of RGB images.

        NOTE(review): there is no explicit length check here; a pyramid whose
        length does not match the constructed decoder surfaces as a shape
        error inside the convolutions rather than a ValueError.
        """

        # As a slight difference to a conventional decoder (e.g. U-net), we don't
        # apply any extra convolutions to the coarsest level, but just pass it
        # to finer levels for concatenation. This choice has not been thoroughly
        # evaluated, but is motivated by the educated guess that the fusion part
        # probably does not need large spatial context, because at this point the
        # features are spatially aligned by the preceding warp.
        net = pyramid[-1]

        # Loop starting from the 2nd coarsest level:
        # for i in reversed(range(0, len(pyramid) - 1)):
        for k, layers in enumerate(self.convs):
            # self.convs is stored fine-to-coarse, so map the enumeration index
            # onto pyramid index i (coarse-to-fine traversal).
            i = len(self.convs) - 1 - k
            # Resize the tensor from coarser level to match for concatenation.
            level_size = pyramid[i].shape[2:4]
            net = F.interpolate(net, size=level_size, mode='nearest')
            net = layers[0](net)
            net = torch.cat([pyramid[i], net], dim=1)
            net = layers[1](net)
            net = layers[2](net)
        net = self.output_conv(net)
        return net
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
"""The film_net frame interpolator main model code.
|
| 305 |
+
|
| 306 |
+
Basics
|
| 307 |
+
======
|
| 308 |
+
The film_net is an end-to-end learned neural frame interpolator implemented as
|
| 309 |
+
a PyTorch model. It has the following inputs and outputs:
|
| 310 |
+
|
| 311 |
+
Inputs:
|
| 312 |
+
x0: image A.
|
| 313 |
+
x1: image B.
|
| 314 |
+
time: desired sub-frame time.
|
| 315 |
+
|
| 316 |
+
Outputs:
|
| 317 |
+
image: the predicted in-between image at the chosen time in range [0, 1].
|
| 318 |
+
|
| 319 |
+
Additional outputs include forward and backward warped image pyramids, flow
|
| 320 |
+
pyramids, etc., that can be visualized for debugging and analysis.
|
| 321 |
+
|
| 322 |
+
Note that many training sets only contain triplets with ground truth at
|
| 323 |
+
time=0.5. If a model has been trained with such training set, it will only work
|
| 324 |
+
well for synthesizing frames at time=0.5. Such models can only generate more
|
| 325 |
+
in-between frames using recursion.
|
| 326 |
+
|
| 327 |
+
Architecture
|
| 328 |
+
============
|
| 329 |
+
The inference consists of three main stages: 1) feature extraction 2) warping
|
| 330 |
+
3) fusion. On high-level, the architecture has similarities to Context-aware
|
| 331 |
+
Synthesis for Video Frame Interpolation [1], but the exact architecture is
|
| 332 |
+
closer to Multi-view Image Fusion [2] with some modifications for the frame
|
| 333 |
+
interpolation use-case.
|
| 334 |
+
|
| 335 |
+
Feature extraction stage employs the cascaded multi-scale architecture described
|
| 336 |
+
in [2]. The advantage of this architecture is that coarse level flow prediction
|
| 337 |
+
can be learned from finer resolution image samples. This is especially useful
|
| 338 |
+
to avoid overfitting with moderately sized datasets.
|
| 339 |
+
|
| 340 |
+
The warping stage uses a residual flow prediction idea that is similar to
|
| 341 |
+
PWC-Net [3], Multi-view Image Fusion [2] and many others.
|
| 342 |
+
|
| 343 |
+
The fusion stage is similar to U-Net's decoder where the skip connections are
|
| 344 |
+
connected to warped image and feature pyramids. This is described in [2].
|
| 345 |
+
|
| 346 |
+
Implementation Conventions
|
| 347 |
+
====================
|
| 348 |
+
Pyramids
|
| 349 |
+
--------
|
| 350 |
+
Throughout the model, all image and feature pyramids are stored as python lists
|
| 351 |
+
with finest level first followed by downscaled versions obtained by successively
|
| 352 |
+
halving the resolution. The depths of all pyramids are determined by
|
| 353 |
+
options.pyramid_levels. The only exception to this is internal to the feature
|
| 354 |
+
extractor, where smaller feature pyramids are temporarily constructed with depth
|
| 355 |
+
options.sub_levels.
|
| 356 |
+
|
| 357 |
+
Color ranges & gamma
|
| 358 |
+
--------------------
|
| 359 |
+
The model code makes no assumptions on whether the images are in gamma or
|
| 360 |
+
linearized space or what is the range of RGB color values. So a model can be
|
| 361 |
+
trained with different choices. This does not mean that all the choices lead to
|
| 362 |
+
similar results. In practice the model has been proven to work well with RGB
|
| 363 |
+
scale = [0,1] with gamma-space images (i.e. not linearized).
|
| 364 |
+
|
| 365 |
+
[1] Context-aware Synthesis for Video Frame Interpolation, Niklaus and Liu, 2018
|
| 366 |
+
[2] Multi-view Image Fusion, Trinidad et al, 2019
|
| 367 |
+
[3] PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
|
| 368 |
+
"""
|
| 369 |
+
from typing import Dict, List
|
| 370 |
+
|
| 371 |
+
import torch
|
| 372 |
+
from torch import nn
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class Interpolator(nn.Module):
    """End-to-end film_net frame interpolator.

    Pipeline (see module docstring): 1) cascaded feature extraction of both
    inputs, 2) bidirectional residual-flow prediction and warping of images
    plus features toward the midpoint, 3) U-Net-style fusion of the aligned
    pyramid into the output RGB frame.
    """

    def __init__(
        self,
        pyramid_levels=7,
        fusion_pyramid_levels=5,
        specialized_levels=3,
        sub_levels=4,
        filters=64,
        flow_convs=(3, 3, 3, 3),
        flow_filters=(32, 64, 128, 256),
    ):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels

        self.extract = FeatureExtractor(3, filters, sub_levels)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters)
        self.fuse = Fusion(sub_levels, specialized_levels, filters)

    def shuffle_images(self, x0, x1):
        """Builds the two image pyramids (finest level first) for x0 and x1."""
        return [
            build_image_pyramid(x0, self.pyramid_levels),
            build_image_pyramid(x1, self.pyramid_levels)
        ]

    def debug_forward(self, x0, x1, batch_dt) -> Dict[str, List[torch.Tensor]]:
        """Runs the full pipeline and returns intermediate pyramids for debugging.

        Args:
            x0: first input image batch.
            x1: second input image batch.
            batch_dt: per-sample timestep tensor; only its shape/dtype are used
                below (the value is overwritten with 0.5).
                # assumes batch_dt is at least 2-D, e.g. (B, 1), since it is
                # indexed with [:, 0] -- TODO confirm against callers.

        Returns:
            Dict with the predicted 'image' plus forward/backward residual-flow
            and flow pyramids.
        """
        image_pyramids = self.shuffle_images(x0, x1)

        # Siamese feature pyramids:
        feature_pyramids = [self.extract(image_pyramids[0]), self.extract(image_pyramids[1])]

        # Predict forward flow.
        forward_residual_flow_pyramid = self.predict_flow(feature_pyramids[0], feature_pyramids[1])

        # Predict backward flow.
        backward_residual_flow_pyramid = self.predict_flow(feature_pyramids[1], feature_pyramids[0])

        # Concatenate features and images:

        # Note that we keep up to 'fusion_pyramid_levels' levels as only those
        # are used by the fusion module.

        forward_flow_pyramid = flow_pyramid_synthesis(forward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        backward_flow_pyramid = flow_pyramid_synthesis(backward_residual_flow_pyramid)[:self.fusion_pyramid_levels]

        # We multiply the flows with t and 1-t to warp to the desired fractional time.
        #
        # Note: In film_net we fix time to be 0.5, and recursively invoke the interpo-
        # lator for multi-frame interpolation. Below, we create a constant tensor of
        # shape [B]. We use the `time` tensor to infer the batch size.
        mid_time = torch.full_like(batch_dt, .5)
        backward_flow = multiply_pyramid(backward_flow_pyramid, mid_time[:, 0])
        forward_flow = multiply_pyramid(forward_flow_pyramid, 1 - mid_time[:, 0])

        pyramids_to_warp = [
            concatenate_pyramids(image_pyramids[0][:self.fusion_pyramid_levels],
                                 feature_pyramids[0][:self.fusion_pyramid_levels]),
            concatenate_pyramids(image_pyramids[1][:self.fusion_pyramid_levels],
                                 feature_pyramids[1][:self.fusion_pyramid_levels])
        ]

        # Warp features and images using the flow. Note that we use backward warping
        # and backward flow is used to read from image 0 and forward flow from
        # image 1.
        forward_warped_pyramid = pyramid_warp(pyramids_to_warp[0], backward_flow)
        backward_warped_pyramid = pyramid_warp(pyramids_to_warp[1], forward_flow)

        aligned_pyramid = concatenate_pyramids(forward_warped_pyramid,
                                               backward_warped_pyramid)
        aligned_pyramid = concatenate_pyramids(aligned_pyramid, backward_flow)
        aligned_pyramid = concatenate_pyramids(aligned_pyramid, forward_flow)

        return {
            'image': [self.fuse(aligned_pyramid)],
            'forward_residual_flow_pyramid': forward_residual_flow_pyramid,
            'backward_residual_flow_pyramid': backward_residual_flow_pyramid,
            'forward_flow_pyramid': forward_flow_pyramid,
            'backward_flow_pyramid': backward_flow_pyramid,
        }

    def forward(self, x0, x1, batch_dt) -> torch.Tensor:
        """Returns only the interpolated frame from debug_forward()."""
        return self.debug_forward(x0, x1, batch_dt)['image'][0]
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
"""PyTorch layer for estimating optical flow by a residual flow pyramid.
|
| 471 |
+
|
| 472 |
+
This approach of estimating optical flow between two images can be traced back
|
| 473 |
+
to [1], but is also used by later neural optical flow computation methods such
|
| 474 |
+
as SpyNet [2] and PWC-Net [3].
|
| 475 |
+
|
| 476 |
+
The basic idea is that the optical flow is first estimated in a coarse
|
| 477 |
+
resolution, then the flow is upsampled to warp the higher resolution image and
|
| 478 |
+
then a residual correction is computed and added to the estimated flow. This
|
| 479 |
+
process is repeated in a pyramid on coarse to fine order to successively
|
| 480 |
+
increase the resolution of both optical flow and the warped image.
|
| 481 |
+
|
| 482 |
+
In here, the optical flow predictor is used as an internal component for the
|
| 483 |
+
film_net frame interpolator, to warp the two input images into the inbetween,
|
| 484 |
+
target frame.
|
| 485 |
+
|
| 486 |
+
[1] F. Glazer, Hierarchical motion detection. PhD thesis, 1987.
|
| 487 |
+
[2] A. Ranjan and M. J. Black, Optical Flow Estimation using a Spatial Pyramid
|
| 488 |
+
Network. 2016
|
| 489 |
+
[3] D. Sun X. Yang, M-Y. Liu and J. Kautz, PWC-Net: CNNs for Optical Flow Using
|
| 490 |
+
Pyramid, Warping, and Cost Volume, 2017
|
| 491 |
+
"""
|
| 492 |
+
from typing import List
|
| 493 |
+
|
| 494 |
+
import torch
|
| 495 |
+
from torch import nn
|
| 496 |
+
from torch.nn import functional as F
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class FlowEstimator(nn.Module):
    """Small-receptive-field predictor of flow between two feature maps.

    Used by PyramidFlowEstimator to compute the per-level residual flow
    fields. The configurable stack of 3x3 convolutions is followed by two
    fixed 1x1 convolutions that funnel down to the 2-channel flow; the last
    one has no activation so the flow components are unbounded (explicit
    bounding via sigmoid was tried upstream and performed worse).
    """

    def __init__(self, in_channels: int, num_convs: int, num_filters: int):
        super().__init__()
        stages = nn.ModuleList()
        width = in_channels
        for _ in range(num_convs):
            stages.append(conv(in_channels=width, out_channels=num_filters, size=3))
            width = num_filters
        stages.append(conv(width, num_filters // 2, size=1))
        stages.append(conv(num_filters // 2, 2, size=1, activation=None))
        # Attribute name kept as '_convs' so checkpoint keys stay compatible.
        self._convs = stages

    def forward(self, features_a: torch.Tensor, features_b: torch.Tensor) -> torch.Tensor:
        """Estimates optical flow between two feature images.

        Args:
            features_a: per-pixel feature vectors for image A.
            features_b: per-pixel feature vectors for image B.
                # channels-first (B, C, H, W) assumed -- the inputs are
                # concatenated on dim=1 below.

        Returns:
            A 2-channel tensor with the optical flow from A to B.
        """
        result = torch.cat([features_a, features_b], dim=1)
        for stage in self._convs:
            result = stage(result)
        return result
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class PyramidFlowEstimator(nn.Module):
    """Predicts optical flow by coarse-to-fine refinement.

    One FlowEstimator is built per entry of ``flow_convs``/``flow_filters``.
    The last (widest) one is kept in ``self._predictor`` and reused for every
    pyramid level deeper than the per-level estimators in ``self._predictors``.
    """

    def __init__(self, filters: int = 64,
                 flow_convs: tuple = (3, 3, 3, 3),
                 flow_filters: tuple = (32, 64, 128, 256)):
        super(PyramidFlowEstimator, self).__init__()

        # Each estimator sees the concatenation of two feature maps, hence the
        # doubling; deeper pyramid levels carry more channels, which the
        # `filters << (i + 2)` increment accounts for.
        in_channels = filters << 1
        predictors = []
        for i in range(len(flow_convs)):
            predictors.append(
                FlowEstimator(
                    in_channels=in_channels,
                    num_convs=flow_convs[i],
                    num_filters=flow_filters[i]))
            in_channels += filters << (i + 2)
        # Shared predictor for all levels coarser than the per-level list.
        self._predictor = predictors[-1]
        # Remaining predictors, stored reversed so that during the fine-ward
        # sweep in forward() they are consumed coarse-to-fine.
        self._predictors = nn.ModuleList(predictors[:-1][::-1])

    def forward(self, feature_pyramid_a: List[torch.Tensor],
                feature_pyramid_b: List[torch.Tensor]) -> List[torch.Tensor]:
        """Estimates residual flow pyramids between two image pyramids.

        Each image pyramid is represented as a list of tensors in fine-to-coarse
        order. Each individual image is represented as a tensor where each pixel is
        a vector of image features.

        flow_pyramid_synthesis can be used to convert the residual flow
        pyramid returned by this method into a flow pyramid, where each level
        encodes the flow instead of a residual correction.

        Args:
            feature_pyramid_a: image pyramid as a list in fine-to-coarse order
            feature_pyramid_b: image pyramid as a list in fine-to-coarse order

        Returns:
            List of flow tensors, in fine-to-coarse order, each level encoding the
            difference against the bilinearly upsampled version from the coarser
            level. The coarsest flow tensor, e.g. the last element in the array is the
            'DC-term', e.g. not a residual (alternatively you can think of it being a
            residual against zero).
        """
        levels = len(feature_pyramid_a)
        # Coarsest level: a direct (non-residual) estimate from the shared predictor.
        v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
        residuals = [v]
        # First sweep: intermediate coarse levels that still use the shared predictor.
        for i in range(levels - 2, len(self._predictors) - 1, -1):
            # Upsamples the flow to match the current pyramid level. Also, scales the
            # magnitude by two to reflect the new size.
            level_size = feature_pyramid_a[i].shape[2:4]
            v = F.interpolate(2 * v, size=level_size, mode='bilinear')
            # Warp feature_pyramid_b[i] image based on the current flow estimate.
            warped = warp(feature_pyramid_b[i], v)
            # Estimate the residual flow between pyramid_a[i] and warped image:
            v_residual = self._predictor(feature_pyramid_a[i], warped)
            residuals.insert(0, v_residual)
            v = v_residual + v

        # Second sweep: finest levels, each with its own dedicated predictor.
        for k, predictor in enumerate(self._predictors):
            i = len(self._predictors) - 1 - k
            # Upsamples the flow to match the current pyramid level. Also, scales the
            # magnitude by two to reflect the new size.
            level_size = feature_pyramid_a[i].shape[2:4]
            v = F.interpolate(2 * v, size=level_size, mode='bilinear')
            # Warp feature_pyramid_b[i] image based on the current flow estimate.
            warped = warp(feature_pyramid_b[i], v)
            # Estimate the residual flow between pyramid_a[i] and warped image:
            v_residual = predictor(feature_pyramid_a[i], warped)
            residuals.insert(0, v_residual)
            v = v_residual + v
        return residuals
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
"""Various utilities used in the film_net frame interpolator model."""
|
| 629 |
+
from typing import List, Optional
|
| 630 |
+
|
| 631 |
+
import cv2
|
| 632 |
+
import numpy as np
|
| 633 |
+
import torch
|
| 634 |
+
from torch import nn
|
| 635 |
+
from torch.nn import functional as F
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def pad_batch(batch, align):
    """Zero-pads a BxHxWxC batch so height and width are multiples of ``align``.

    Returns:
        The padded batch and a ``[top, left, bottom, right]`` crop region that
        recovers the original content from the padded array.
    """
    height, width = batch.shape[1:3]
    # (-x) % align is the amount needed to reach the next multiple (0 if aligned).
    pad_h = (-height) % align
    pad_w = (-width) % align

    top, left = pad_h // 2, pad_w // 2
    crop_region = [top, left, height + top, width + left]
    padded = np.pad(
        batch,
        ((0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode='constant')
    return padded, crop_region
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def load_image(path, align=64):
    """Reads an image from disk as a float32 RGB batch of one, padded to ``align``.

    Returns the padded batch and the crop region produced by ``pad_batch``.
    """
    bgr = cv2.imread(path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    image = rgb.astype(np.float32) / np.float32(255)
    return pad_batch(np.expand_dims(image, axis=0), align)
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Builds an image pyramid from a given image.

    The original image is included in the pyramid and the rest are generated by
    successively halving the resolution with 2x2 average pooling.

    Args:
        image: the input image.
        pyramid_levels: number of levels to generate.

    Returns:
        A list of images starting from the finest, with ``pyramid_levels`` items.
    """
    pyramid = []
    current = image
    for level in range(pyramid_levels):
        pyramid.append(current)
        if level != pyramid_levels - 1:
            current = F.avg_pool2d(current, 2, 2)
    return pyramid
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward warps the image using the given flow.

    Specifically, the output pixel in batch b, at position x, y will be computed
    as follows:
      (flowed_y, flowed_x) = (y+flow[b, y, x, 1], x+flow[b, y, x, 0])
      output[b, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)

    Note that the flow vectors are expected as [x, y], e.g. x in position 0 and
    y in position 1.

    NOTE(review): the docstring above (kept from the original) describes a
    channels-last BxHxWx2 flow, but the code below permutes with
    ``flow.permute(0, 2, 3, 1)`` and indexes ``flow.shape[2:4]`` as (H, W),
    i.e. it actually operates on a Bx2xHxW flow — confirm against callers.

    Args:
        image: An image with shape BxHxWxC.
        flow: A flow with shape BxHxWx2, with the two channels denoting the relative
            offset in order: (dx, dy).
    Returns:
        A warped image.
    """
    # Negate and swap the two flow channels (backward warp, channel reorder).
    flow = -flow.flip(1)

    dtype = flow.dtype
    device = flow.device

    # warped = tfa_image.dense_image_warp(image, flow)
    # Same as above but with pytorch
    # Endpoints of the align_corners=False sampling grid: +/-(1 - 1/size).
    ls1 = 1 - 1 / flow.shape[3]
    ls2 = 1 - 1 / flow.shape[2]

    # Scale pixel offsets into grid_sample's [-1, 1] coordinate range
    # (half-size divisor because the range spans 2 units).
    normalized_flow2 = flow.permute(0, 2, 3, 1) / torch.tensor(
        [flow.shape[2] * .5, flow.shape[3] * .5], dtype=dtype, device=device)[None, None, None]
    # Base sampling grid minus the normalized flow, stacked as (x, y) pairs.
    normalized_flow2 = torch.stack([
        torch.linspace(-ls1, ls1, flow.shape[3], dtype=dtype, device=device)[None, None, :] - normalized_flow2[..., 1],
        torch.linspace(-ls2, ls2, flow.shape[2], dtype=dtype, device=device)[None, :, None] - normalized_flow2[..., 0],
    ], dim=3)

    padding_mode = "border"
    if device.type == "mps":
        # MPS backend lacks border padding support; clamp coordinates instead.
        # https://github.com/pytorch/pytorch/issues/125098
        padding_mode = "zeros"
        normalized_flow2 = normalized_flow2.clamp(-1, 1)
    warped = F.grid_sample(
        input=image,
        grid=normalized_flow2,
        mode='bilinear',
        padding_mode=padding_mode,
        align_corners=False,
    )
    return warped.reshape(image.shape)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def multiply_pyramid(pyramid: List[torch.Tensor],
                     scalar: torch.Tensor) -> List[torch.Tensor]:
    """Scales every image batch in the pyramid by the given batch of scalars.

    Args:
        pyramid: Pyramid of image batches.
        scalar: Batch of scalars.

    Returns:
        A new pyramid whose levels are the element-wise products.
    """
    scaled = []
    for level in pyramid:
        scaled.append(level * scalar)
    return scaled
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def flow_pyramid_synthesis(
        residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Converts a residual flow pyramid into an absolute flow pyramid.

    The coarsest entry is taken as-is; each finer residual is added to the 2x
    upsampled (and magnitude-doubled) flow from the level below.
    """
    flow = residual_pyramid[-1]
    coarse_to_fine = [flow]
    for residual in reversed(residual_pyramid[:-1]):
        target_size = residual.shape[2:4]
        upsampled = F.interpolate(flow * 2, size=target_size, mode='bilinear')
        flow = residual + upsampled
        coarse_to_fine.append(flow)
    # Return in fine-to-coarse order, matching the input ordering.
    return coarse_to_fine[::-1]
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def pyramid_warp(feature_pyramid: List[torch.Tensor],
                 flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Backward-warps each feature level with its matching flow level.

    Args:
        feature_pyramid: feature pyramid starting from the finest level.
        flow_pyramid: flow fields, starting from the finest level.

    Returns:
        Reverse warped feature pyramid.
    """
    return [warp(features, flow)
            for features, flow in zip(feature_pyramid, flow_pyramid)]
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def concatenate_pyramids(pyramid1: List[torch.Tensor],
                         pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concatenates matching levels of two pyramids along the channel dimension."""
    return [torch.cat([lvl_a, lvl_b], dim=1)
            for lvl_a, lvl_b in zip(pyramid1, pyramid2)]
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def conv(in_channels, out_channels, size, activation: Optional[str] = 'relu'):
    """Builds a 'same'-padded Conv2d, optionally followed by an activation.

    Note: PyTorch's Conv2d has no built-in activation, so the activated variant
    is returned as a Sequential. Despite the 'relu' name, the activation
    applied is LeakyReLU with negative slope 0.2.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=size,
        padding='same')
    if activation is not None:
        assert activation == 'relu'
        layer = nn.Sequential(layer, nn.LeakyReLU(.2))
    return layer
|
vfi_models/flavr/__init__.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from comfy.model_management import get_torch_device, soft_empty_cache
|
| 3 |
+
import numpy as np
|
| 4 |
+
import typing
|
| 5 |
+
from vfi_utils import InterpolationStateList, load_file_from_github_release, preprocess_frames, postprocess_frames, assert_batch_size
|
| 6 |
+
import pathlib
|
| 7 |
+
import warnings
|
| 8 |
+
from .flavr_arch import UNet_3D_3D, InputPadder
|
| 9 |
+
import gc
|
| 10 |
+
|
| 11 |
+
device = get_torch_device()
|
| 12 |
+
NBR_FRAME = 4
|
| 13 |
+
|
| 14 |
+
def build_flavr(model_path):
    """Loads a FLAVR checkpoint and returns the model on the global device, in eval mode.

    The checkpoint's DataParallel "module." prefixes are stripped, and the
    number of output frames is recovered from the final conv's weight shape.
    """
    state_dict = torch.load(model_path)['state_dict']
    state_dict = {key.partition("module.")[-1]: value for key, value in state_dict.items()}

    # Ref: Class UNet_3D_3D
    model = UNet_3D_3D(
        "unet_18",
        n_inputs=NBR_FRAME,
        n_outputs=state_dict["outconv.1.weight"].shape[0] // 3,
        joinType="concat",
        upmode="transpose")
    model.load_state_dict(state_dict)
    model.to(device).eval()
    del state_dict
    return model
|
| 24 |
+
|
| 25 |
+
MODEL_TYPE = pathlib.Path(__file__).parent.name
|
| 26 |
+
CKPT_NAMES = ["FLAVR_2x.pth", "FLAVR_4x.pth", "FLAVR_8x.pth"]
|
| 27 |
+
|
| 28 |
+
class FLAVR_VFI:
    """ComfyUI node that doubles the frame rate of a batch using FLAVR.

    FLAVR consumes four consecutive frames and synthesizes the frame between
    the middle pair, so interpolation only covers the interior of the clip
    unless ``duplicate_first_last_frames`` is enabled.
    """
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (CKPT_NAMES, ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 2}), #TODO: Implement recursively invoking interpolator for multi-frame interpolation
                "duplicate_first_last_frames": ("BOOLEAN", {"default": False})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    #Reference: https://github.com/danier97/ST-MFNet/blob/main/interpolate_yuv.py#L93
    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames = 10,
        multiplier: typing.SupportsInt = 2,
        duplicate_first_last_frames: bool = False,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        """Runs FLAVR 2x interpolation over ``frames``.

        Args:
            ckpt_name: one of CKPT_NAMES; downloaded on demand.
            frames: input frame batch (at least 4 frames required).
            clear_cache_after_n_frames: how many produced frames between cache flushes.
            multiplier: only 2 is supported; other values emit a warning.
            duplicate_first_last_frames: repeat the first/last frame in the output.
            optional_interpolation_states: per-frame skip flags.

        Returns:
            A one-tuple with the interpolated frame batch.
        """
        if multiplier != 2:
            warnings.warn("Currently, FLAVR only supports 2x interpolation. The process will continue but please set multiplier=2 afterward")

        # Bug fix: the failure message previously named "ST-MFNet" (copy-paste);
        # this node is FLAVR.
        assert_batch_size(frames, batch_size=4, vfi_name="FLAVR")
        interpolation_states = optional_interpolation_states
        model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
        model = build_flavr(model_path)
        frames = preprocess_frames(frames)
        # FLAVR's 3D UNet needs spatial dims divisible by 16.
        padder = InputPadder(frames.shape, 16)
        frames = padder.pad(frames)

        number_of_frames_processed_since_last_cleared_cuda_cache = 0
        output_frames = []
        # Slide a 4-frame window; each step interpolates between frames 1 and 2
        # of the window.
        for frame_itr in range(len(frames) - 3):
            # Does skipping frame i+1 make sense in this case?
            if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr) and interpolation_states.is_frame_skipped(frame_itr + 1):
                continue

            # Ensure that input frames are in fp32 - the same dtype as model
            frame0, frame1, frame2, frame3 = (
                frames[frame_itr:frame_itr+1].float(),
                frames[frame_itr+1:frame_itr+2].float(),
                frames[frame_itr+2:frame_itr+3].float(),
                frames[frame_itr+3:frame_itr+4].float()
            )
            new_frame = model([frame0.to(device), frame1.to(device), frame2.to(device), frame3.to(device)])[0].detach().cpu()
            number_of_frames_processed_since_last_cleared_cuda_cache += 2

            if frame_itr == 0:
                output_frames.append(frame0)
                if duplicate_first_last_frames:
                    output_frames.append(frame0) # repeat the first frame
                output_frames.append(frame1)
            output_frames.append(new_frame)
            output_frames.append(frame2)
            if frame_itr == len(frames) - 4:
                output_frames.append(frame3)
                if duplicate_first_last_frames:
                    output_frames.append(frame3) # repeat the last frame

            # Try to avoid a memory overflow by clearing cuda cache regularly
            if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames:
                print("Comfy-VFI: Clearing cache...", end = ' ')
                soft_empty_cache()
                number_of_frames_processed_since_last_cleared_cuda_cache = 0
                print("Done cache clearing")
            gc.collect()

        dtype = torch.float32
        output_frames = [frame.cpu().to(dtype=dtype) for frame in output_frames] #Ensure all frames are in cpu
        out = torch.cat(output_frames, dim=0)
        out = padder.unpad(out)
        # clear cache for courtesy
        print("Comfy-VFI: Final clearing cache...", end=' ')
        soft_empty_cache()
        print("Done cache clearing")
        return (postprocess_frames(out), )
|
vfi_models/flavr/flavr_arch.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/tarun005/FLAVR/blob/main/model/FLAVR_arch.py
|
| 3 |
+
https://github.com/tarun005/FLAVR/blob/main/model/resnet_3D.py (only SEGating)
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
import numpy as np
|
| 7 |
+
import importlib
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
class SEGating(nn.Module):
    """Squeeze-and-excite style gate: sigmoid channel weights from global pooling.

    Note: ``reduction`` is accepted but unused — the 1x1x1 conv keeps the
    channel count, with no bottleneck.
    """

    def __init__(self, inplanes, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        pooled = self.pool(x)
        gate = self.attn_layer(pooled)
        return x * gate
|
| 30 |
+
|
| 31 |
+
def joinTensors(X1, X2, type="concat"):
    """Fuses two tensors: channel-concat, element-wise add, or X1 unchanged.

    Any ``type`` other than "concat" or "add" falls through to returning X1.
    (The parameter name shadows the builtin but is kept for caller compatibility.)
    """
    if type == "concat":
        return torch.cat([X1, X2], dim=1)
    if type == "add":
        return X1 + X2
    return X1
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Conv_2d(nn.Module):
    """2D convolution with optional BatchNorm, wrapped in a Sequential."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
        super().__init__()
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                            stride=stride, padding=padding, bias=bias)]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
| 56 |
+
|
| 57 |
+
class upConv3D(nn.Module):
    """3D upsampling block gated by SEGating.

    "transpose" mode uses a ConvTranspose3d; any other mode uses trilinear
    upsampling (2x spatially, 1x temporally) followed by a 1x1x1 conv.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()
        self.upmode = upmode
        if upmode == "transpose":
            layers = [
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size,
                                   stride=stride, padding=padding),
                SEGating(out_ch),
            ]
        else:
            layers = [
                nn.Upsample(mode='trilinear', scale_factor=(1, 2, 2), align_corners=False),
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                SEGating(out_ch),
            ]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))
        self.upconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.upconv(x)
|
| 88 |
+
|
| 89 |
+
class Conv_3d(nn.Module):
    """3D convolution followed by SEGating, with optional BatchNorm."""

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
        super().__init__()
        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            SEGating(out_ch),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
| 106 |
+
|
| 107 |
+
class upConv2D(nn.Module):
    """2D upsampling block: transposed conv, or bilinear upsample plus 1x1 conv."""

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()
        self.upmode = upmode
        if upmode == "transpose":
            layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size,
                                         stride=stride, padding=padding)]
        else:
            layers = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))
        self.upconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.upconv(x)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class UNet_3D_3D(nn.Module):
    """FLAVR 3D UNet: encodes a stack of frames and decodes interpolated frames.

    ``block`` names an encoder factory inside vfi_models.flavr.resnet_3D
    (e.g. "unet_18"); ``n_inputs`` is the number of input frames and
    ``n_outputs`` the number of synthesized frames (3 channels each).
    """
    def __init__(self, block , n_inputs, n_outputs, batchnorm=False , joinType="concat" , upmode="transpose"):
        super().__init__()

        # Channel widths from the deepest encoder stage up to the finest.
        nf = [512 , 256 , 128 , 64]
        out_channels = 3*n_outputs
        self.joinType = joinType
        self.n_outputs = n_outputs

        # Skip-concatenation doubles decoder input channels; "add" does not.
        growth = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        unet_3D = importlib.import_module(".resnet_3D", "vfi_models.flavr")
        # NOTE(review): mutates a module-level flag in resnet_3D before
        # constructing the encoder — order-sensitive side effect.
        if n_outputs > 1:
            unet_3D.useBias = True
        self.encoder = getattr(unet_3D , block)(pretrained=False , bn=batchnorm)

        self.decoder = nn.Sequential(
            Conv_3d(nf[0], nf[1] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[1]*growth, nf[2], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
            upConv3D(nf[2]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
            Conv_3d(nf[3]*growth, nf[3] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[3]*growth , nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm)
        )

        # Collapses the per-frame features (concatenated along channels) to nf[3].
        self.feature_fuse = Conv_2d(nf[3]*n_inputs , nf[3] , kernel_size=1 , stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(nf[3], out_channels , kernel_size=7 , stride=1, padding=0)
        )

    def forward(self, images):
        """Maps a list of frame tensors to a list of ``n_outputs`` output frames."""
        # Stack frames along a new temporal axis: B x C x T x H x W.
        images = torch.stack(images , dim=2)

        ## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
        mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4,keepdim=True)
        images = images-mean_

        x_0 , x_1 , x_2 , x_3 , x_4 = self.encoder(images)

        # Decoder with skip connections from the encoder stages.
        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = joinTensors(dx_3 , x_3 , type=self.joinType)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = joinTensors(dx_2 , x_2 , type=self.joinType)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = joinTensors(dx_1 , x_1 , type=self.joinType)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = joinTensors(dx_0 , x_0 , type=self.joinType)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        # Fold the temporal axis into channels for the 2D head.
        dx_out = torch.cat(torch.unbind(dx_out , 2) , 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        # Split into per-frame RGB images and undo the mean normalization.
        out = torch.split(out, dim=1, split_size_or_sections=3)
        mean_ = mean_.squeeze(2)
        out = [o+mean_ for o in out]

        return out
|
| 199 |
+
|
| 200 |
+
class InputPadder:
    """ Pads images such that dimensions are divisible by divisor """

    def __init__(self, dims, divisor=16):
        self.ht, self.wd = dims[-2:]
        # Amount needed to reach the next multiple of divisor (0 if aligned).
        pad_ht = (-self.ht) % divisor
        pad_wd = (-self.wd) % divisor
        # F.pad ordering: [left, right, top, bottom].
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
                     pad_ht // 2, pad_ht - pad_ht // 2]

    def pad(self, input_tensor):
        return F.pad(input_tensor, self._pad, mode='replicate')

    def unpad(self, input_tensor):
        return self._unpad(input_tensor)

    def _unpad(self, x):
        ht, wd = x.shape[-2:]
        return x[..., self._pad[2]:ht - self._pad[3],
                 self._pad[0]:wd - self._pad[1]]
|
vfi_models/flavr/resnet_3D.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from https://github.com/pytorch/vision/tree/master/torchvision/models/video
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
__all__ = ['unet_18', 'unet_34']
|
| 7 |
+
|
| 8 |
+
useBias = False
|
| 9 |
+
|
| 10 |
+
class identity(nn.Module):
    """No-op module that accepts and ignores any constructor arguments,
    so it can be dropped in where a configurable layer is expected."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x
|
| 18 |
+
|
| 19 |
+
class Conv3DSimple(nn.Conv3d):
    """Plain 3x3x3 3D convolution used as a block builder.

    ``midplanes`` is unused here; it is accepted so the signature matches the
    other conv builders (e.g. Conv2Plus1D). The bias flag comes from the
    module-level ``useBias`` global, read at construction time.
    """
    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=useBias)

    @staticmethod
    def get_downsample_stride(stride , temporal_stride):
        # A falsy temporal_stride (0/None) means "stride uniformly in all dims".
        if temporal_stride:
            return (temporal_stride, stride, stride)
        else:
            return (stride , stride , stride)
|
| 41 |
+
|
| 42 |
+
class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem
    """
    def __init__(self):
        # ``useBias`` and ``batchnorm`` are module-level globals; ``batchnorm``
        # is expected to be a layer factory (BatchNorm3d or an identity stand-in).
        super().__init__(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                      padding=(1, 3, 3), bias=useBias),
            batchnorm(64),
            nn.ReLU(inplace=False))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Conv2Plus1D(nn.Sequential):
    """Factorized (2+1)D convolution: spatial 1x3x3 conv, ReLU, temporal 3x1x1 conv."""

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes,
                 stride=1,
                 padding=1):
        # A tuple stride is (temporal, spatial, spatial); an int applies to both.
        if isinstance(stride, int):
            temporal_stride = stride
        else:
            temporal_stride, stride, stride = stride

        super(Conv2Plus1D, self).__init__(
            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                      stride=(1, stride, stride), padding=(0, padding, padding),
                      bias=False),
            # batchnorm(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
                      stride=(temporal_stride, 1, 1), padding=(padding, 0, 0),
                      bias=False))

    @staticmethod
    def get_downsample_stride(stride, temporal_stride):
        # A falsy temporal_stride means "stride uniformly in all dims".
        if temporal_stride:
            return (temporal_stride, stride, stride)
        return (stride, stride, stride)
|
| 82 |
+
|
| 83 |
+
class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution
    """
    def __init__(self):
        # ``batchnorm`` is a module-level layer factory (e.g. BatchNorm3d or a
        # no-op stand-in). Spatial 1x7x7 conv followed by a temporal 3x1x1 conv.
        super().__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            batchnorm(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            batchnorm(64),
            nn.ReLU(inplace=True))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class SEGating(nn.Module):
    """Channel attention gate (squeeze-and-excite flavor) for 5D feature maps.

    ``reduction`` is accepted for signature compatibility but unused: the
    1x1x1 conv maps inplanes -> inplanes with no bottleneck.
    """

    def __init__(self, inplanes, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.attn_layer = nn.Sequential(
            nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        weights = self.attn_layer(self.pool(x))
        return x * weights
|
| 117 |
+
|
| 118 |
+
class BasicBlock(nn.Module):
    """Residual block: two conv stages, SE feature gating, then skip-add + ReLU.

    ``conv_builder`` is one of the conv factories in this module
    (Conv3DSimple / Conv2Plus1D); ``batchnorm`` is a module-level layer factory.
    """

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        # Channel count for the (2+1)D factorization's intermediate conv,
        # chosen so the factorized block matches the 3D conv's parameter budget.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            batchnorm(planes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            batchnorm(planes)
        )
        self.fg = SEGating(planes) ## Feature Gating
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.fg(out)
        # Match the residual's shape/channels when the block strides or widens.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
| 152 |
+
|
| 153 |
+
class VideoResNet(nn.Module):
    """3D ResNet backbone that returns the activations of every stage,
    making it usable as a U-Net style encoder."""

    def __init__(self, block, conv_makers, layers,
                 stem, zero_init_residual=False):
        """Generic resnet video generator.

        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer
            layers (List[int]): number of blocks per layer
            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
        """
        super(VideoResNet, self).__init__()
        self.inplanes = 64

        self.stem = stem()

        # Spatial strides downsample H/W; temporal_stride is pinned to 1 so
        # the time dimension is preserved through the encoder.
        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1 )
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2 , temporal_stride=1)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2 , temporal_stride=1)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=1, temporal_stride=1)

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            # Zero-init the last BN of each bottleneck so residual branches
            # start as identity mappings.
            # NOTE(review): `Bottleneck` must be defined elsewhere in this
            # module; only BasicBlock is used by the unet_* constructors here.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        # Return every intermediate stage for decoder skip connections.
        x_0 = self.stem(x)
        x_1 = self.layer1(x_0)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)
        return x_0 , x_1 , x_2 , x_3 , x_4

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None):
        # Build one stage of `blocks` residual blocks; only the first block
        # may downsample (spatially and/or temporally).
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride , temporal_stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                batchnorm(planes * block.expansion)
            )
            # NOTE(review): from here on `stride` is a (T, H, W) tuple and is
            # forwarded as-is to the first block's conv_builder.
            stride = ds_stride

        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample ))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder ))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        # He init for convs, unit gain for batch norms, small normal for
        # linear heads.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    """Instantiate a VideoResNet; optionally load pretrained weights for `arch`.

    TODO: other 3D resnet variants, like S3D, r(2+1)D.
    """
    model = VideoResNet(**kwargs)
    if not pretrained:
        return model

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    model.load_state_dict(weights)
    return model
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def unet_18(pretrained=False, bn=False, progress=True, **kwargs):
    """
    Construct 18 layer Unet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        bn (bool): If True, use BatchNorm3d; otherwise an identity (no-op) norm
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R3D-18 encoder
    """
    # The whole module family reads the module-level `batchnorm` factory, so
    # it must be switched before any layer is constructed.
    global batchnorm
    batchnorm = nn.BatchNorm3d if bn else identity

    return _video_resnet('r3d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)
|
| 262 |
+
|
| 263 |
+
def unet_34(pretrained=False, bn=False, progress=True, **kwargs):
    """
    Construct 34 layer Unet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        bn (bool): If True, use BatchNorm3d; otherwise an identity (no-op) norm
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R3D-34 encoder
    """
    # Switch the module-level `batchnorm` factory before building any layers.
    global batchnorm
    batchnorm = nn.BatchNorm3d if bn else identity

    return _video_resnet('r3d_34',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[3, 4, 6, 3],
                         stem=BasicStem, **kwargs)
|
vfi_models/gmfss_fortuna/GMFSS_Fortuna.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import numpy as np
|
| 3 |
+
import vapoursynth as vs
|
| 4 |
+
from .GMFSS_Fortuna_arch import Model_inference
|
| 5 |
+
import torch
|
| 6 |
+
import traceback
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class GMFSS_Fortuna:
    """Thin inference driver around the GMFSS-Fortuna frame-interpolation model."""

    def __init__(self):
        # Driver metadata: no internal frame caching; the model consumes
        # exactly two input frames per call.
        self.cache = False
        self.amount_input_img = 2

        # Inference-only global torch configuration.
        torch.set_grad_enabled(False)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        self.model = Model_inference()
        self.model.eval()

    def execute(self, I0, I1, timestep):
        """Interpolate one frame between I0 and I1 at `timestep`, returned on CPU."""
        with torch.inference_mode():
            return self.model(I0, I1, timestep).cpu()
|
vfi_models/gmfss_fortuna/GMFSS_Fortuna_arch.py
ADDED
|
@@ -0,0 +1,1850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/GMFSS_infer_b.py
|
| 3 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/softsplat.py
|
| 4 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FusionNet_b.py
|
| 5 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FeatureNet.py
|
| 6 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/MetricNet.py
|
| 7 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/IFNet_HDv3.py
|
| 8 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/gmflow.py
|
| 9 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/utils.py
|
| 10 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/position.py
|
| 11 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/geometry.py
|
| 12 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/matching.py
|
| 13 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/transformer.py
|
| 14 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/backbone.py
|
| 15 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/trident_conv.py
|
| 16 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/warplayer.py
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from torch.nn.modules.utils import _pair
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
import torch
|
| 27 |
+
import math
|
| 28 |
+
from vfi_models.rife.rife_arch import IFNet
|
| 29 |
+
from vfi_models.ops import softsplat
|
| 30 |
+
from comfy.model_management import get_torch_device
|
| 31 |
+
|
| 32 |
+
device = get_torch_device()
|
| 33 |
+
backwarp_tenGrid = {}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` with optical flow `tenFlow` via grid_sample.

    `tenFlow` is in pixel units (channel 0 = x, channel 1 = y) and is
    normalized to the [-1, 1] sampling coordinates grid_sample expects. The
    identity base grid is cached per (device, flow-size) key to avoid
    rebuilding it on every call.

    NOTE(review): the cache key includes `tenFlow.device`, but the grid is
    always built on the module-level global `device` — this assumes inputs
    live on that device; confirm before mixing devices.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Identity sampling grid spanning [-1, 1] along both spatial axes.
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
            .view(1, 1, 1, tenFlow.shape[3])
            .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
            .view(1, 1, tenFlow.shape[2], 1)
            .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        )
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)

    # Convert flow from pixel offsets to normalized grid offsets.
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MultiScaleTridentConv(nn.Module):
    """Weight-shared convolution applied to several scales (branches) at once.

    A single weight tensor is convolved over each input branch with a
    per-branch stride/padding (TridentNet style). In eval mode with
    `test_branch_idx` set, only one branch is computed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        # Scalars are broadcast to one value per branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.strides = [_pair(stride) for stride in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # Every per-branch list must have exactly num_branch entries.
        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        # One weight (and optional bias) shared by all branches.
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        # `inputs` is a list with one tensor per active branch.
        num_branch = (
            self.num_branch if self.training or self.test_branch_idx == -1 else 1
        )
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            # Run every branch with its own stride/padding but shared weights.
            outputs = [
                F.conv2d(
                    input,
                    self.weight,
                    self.bias,
                    stride,
                    padding,
                    self.dilation,
                    self.groups,
                )
                for input, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Single-branch eval path.
            # NOTE(review): `test_branch_idx == -1` cannot be true in this
            # branch, so the conditionals below always select the *last*
            # branch's stride/padding — kept as-is to match the upstream
            # GMFlow/detectron2 trident conv.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.strides[-1],
                    self.paddings[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class ResidualBlock_class(nn.Module):
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
in_planes,
|
| 170 |
+
planes,
|
| 171 |
+
norm_layer=nn.InstanceNorm2d,
|
| 172 |
+
stride=1,
|
| 173 |
+
dilation=1,
|
| 174 |
+
):
|
| 175 |
+
super(ResidualBlock_class, self).__init__()
|
| 176 |
+
|
| 177 |
+
self.conv1 = nn.Conv2d(
|
| 178 |
+
in_planes,
|
| 179 |
+
planes,
|
| 180 |
+
kernel_size=3,
|
| 181 |
+
dilation=dilation,
|
| 182 |
+
padding=dilation,
|
| 183 |
+
stride=stride,
|
| 184 |
+
bias=False,
|
| 185 |
+
)
|
| 186 |
+
self.conv2 = nn.Conv2d(
|
| 187 |
+
planes,
|
| 188 |
+
planes,
|
| 189 |
+
kernel_size=3,
|
| 190 |
+
dilation=dilation,
|
| 191 |
+
padding=dilation,
|
| 192 |
+
bias=False,
|
| 193 |
+
)
|
| 194 |
+
self.relu = nn.ReLU(inplace=True)
|
| 195 |
+
|
| 196 |
+
self.norm1 = norm_layer(planes)
|
| 197 |
+
self.norm2 = norm_layer(planes)
|
| 198 |
+
if not stride == 1 or in_planes != planes:
|
| 199 |
+
self.norm3 = norm_layer(planes)
|
| 200 |
+
|
| 201 |
+
if stride == 1 and in_planes == planes:
|
| 202 |
+
self.downsample = None
|
| 203 |
+
else:
|
| 204 |
+
self.downsample = nn.Sequential(
|
| 205 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
def forward(self, x):
|
| 209 |
+
y = x
|
| 210 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 211 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 212 |
+
|
| 213 |
+
if self.downsample is not None:
|
| 214 |
+
x = self.downsample(x)
|
| 215 |
+
|
| 216 |
+
return self.relu(x + y)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class CNNEncoder(nn.Module):
    """GMFlow CNN backbone: conv stem + three residual stages.

    Output resolution is 1/8 for a single scale, or 1/4 when multiple output
    scales are requested, in which case a MultiScaleTridentConv emits one
    feature map per scale from the same final features. `forward` always
    returns a list of tensors (highest resolution first).
    """
    def __init__(
        self,
        output_dim=128,
        norm_layer=nn.InstanceNorm2d,
        num_output_scales=1,
        **kwargs,
    ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        # Channel widths of the three residual stages.
        feature_dims = [64, 96, 128]

        self.conv1 = nn.Conv2d(
            3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False
        )  # 1/2
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(
            feature_dims[0], stride=1, norm_layer=norm_layer
        )  # 1/2
        self.layer2 = self._make_layer(
            feature_dims[1], stride=2, norm_layer=norm_layer
        )  # 1/4

        # highest resolution 1/4 or 1/8
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(
            feature_dims[2],
            stride=stride,
            norm_layer=norm_layer,
        )  # 1/4 or 1/8

        # 1x1 projection to the requested output dimensionality.
        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            # One downsampling factor per requested output scale.
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError

            # Shared-weight conv producing all scales from one feature map.
            self.trident_conv = MultiScaleTridentConv(
                output_dim,
                output_dim,
                kernel_size=3,
                strides=strides,
                paddings=1,
                num_branch=self.num_branch,
            )

        # He init for convs; unit gain / zero bias for affine norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        # Two residual blocks; only the first may change stride/width.
        layer1 = ResidualBlock_class(
            self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation
        )
        layer2 = ResidualBlock_class(
            dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation
        )

        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)  # high to low res
        else:
            out = [x]

        return out
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def single_head_full_attention(q, k, v):
    """Dense single-head scaled dot-product attention.

    q, k, v: [B, L, C]. Returns the attended values, shape [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3

    # Scale logits by sqrt(C) before the softmax over key positions.
    scale = q.size(2) ** 0.5
    logits = torch.matmul(q, k.transpose(1, 2)) / scale  # [B, L, L]
    weights = torch.softmax(logits, dim=2)               # [B, L, L]
    return torch.matmul(weights, v)                      # [B, L, C]
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=get_torch_device(),
):
    # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    # calculate attention mask for SW-MSA
    #
    # Produces an additive mask with 0 for token pairs belonging to the same
    # shifted region and -100 (≈ -inf after softmax) for pairs that are only
    # adjacent because of the cyclic shift.
    #
    # NOTE(review): the default `device` is evaluated once at import time;
    # pass `device` explicitly if the active torch device can change.
    h, w = input_resolution
    img_mask = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1
    # Label the 3x3 grid of shifted regions with distinct region ids.
    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    cnt = 0
    # NOTE(review): the loop variables shadow the `h`/`w` ints unpacked above;
    # harmless here because they are not read again after the loop.
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    # Partition the labeled image into windows; `split_feature` is defined
    # elsewhere in this module.
    mask_windows = split_feature(
        img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True
    )

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    # Pairwise region-id differences: nonzero means the pair crosses a region
    # boundary and must not attend to each other.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Swin-style windowed single-head attention over flattened [B, L, C] tokens.

    The H*W token grid is split into num_splits x num_splits windows; attention
    is computed independently inside each window. When with_shift is True the
    grid is cyclically shifted by half a window first (SW-MSA), and attn_mask
    (from generate_shift_window_attn_mask) suppresses cross-region pairs.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    """
    # q, k, v: [B, L, C]
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    # Each window becomes its own batch element.
    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c**0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        # Cyclic shift so windows straddle the previous window boundaries.
        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(
        q, num_splits=num_splits, channel_last=True
    )  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    # Scaled dot-product attention inside each window.
    scores = (
        torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1))
        / scale_factor
    )  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        # Additive mask (-100 on forbidden pairs), tiled over the batch.
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    out = merge_splits(
        out.view(b_new, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class TransformerLayer(nn.Module):
    """Single attention layer (self- or cross-attention) with optional FFN.

    With nhead=1 and attention_type="swin", attention is windowed Swin-style
    when attn_num_splits > 1, otherwise full global attention. The residual
    output is ``source + norm(attended message)`` (plus FFN when enabled).
    """

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        no_ffn=False,
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """source, target: [B, L, C]; returns updated source, same shape.

        source attends to target (self-attention when source is target).
        """
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=self.with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            # FFN input is the concatenation of the residual stream and the
            # attended message, hence the 2*d_model input width above.
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super().__init__()

        shared = dict(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )
        # Self-attention stage runs without a feed-forward tail ...
        self.self_attn = TransformerLayer(no_ffn=True, **shared)
        # ... the cross-attention stage keeps its FFN.
        self.cross_attn_ffn = TransformerLayer(**shared)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """source, target: [B, L, C]; returns the updated source tokens."""
        attn_kwargs = dict(
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )

        # Intra-frame self-attention.
        source = self.self_attn(source, source, **attn_kwargs)

        # Cross-frame attention followed by the FFN.
        return self.cross_attn_ffn(source, target, **attn_kwargs)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
class FeatureTransformer(nn.Module):
    """Stack of TransformerBlocks that cross-attends two feature maps.

    Both directions (0->1 and 1->0) are processed in one pass by concatenating
    the two features along the batch dimension. With attention_type="swin",
    every odd-indexed block uses the shifted-window variant.
    """

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        **kwargs,
    ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    # Alternate plain / shifted windows, Swin-style.
                    with_shift=True
                    if attention_type == "swin" and i % 2 == 1
                    else False,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_num_splits=None,
        **kwargs,
    ):
        """feature0, feature1: [B, C, H, W]; returns both updated, same shape."""
        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                shifted_window_attn_mask=shifted_window_attn_mask,
                attn_num_splits=attn_num_splits,
            )

            # update feature1: swap the two halves so the "target" stream
            # always holds the other frame's freshly updated features
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = (
            feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        )  # [B, C, H, W]
        feature1 = (
            feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        )  # [B, C, H, W]

        return feature0, feature1
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
class FeatureFlowAttention(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(
        self,
        in_channels,
        **kwargs,
    ):
        super(FeatureFlowAttention, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        flow,
        local_window_attn=False,
        local_window_radius=1,
        **kwargs,
    ):
        """Propagate flow by attending over feature0 similarities.

        feature0: [B, C, H, W]; flow: [B, 2, H, W]. Returns the propagated
        flow, [B, 2, H, W]. Global attention by default; set
        local_window_attn=True for the windowed variant.
        """
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(
                feature0, flow, local_window_radius=local_window_radius
            )

        b, c, h, w = feature0.size()

        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # a note: the ``correct'' implementation should be:
        # ``query = self.q_proj(query), key = self.k_proj(query)''
        # this problem is observed while cleaning up the code
        # however, this doesn't affect the performance since the projection is a linear operation,
        # thus the two projection matrices for key can be merged
        # so I just leave it as is in order to not re-train all models :)
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        """Windowed variant: each pixel attends only to a (2R+1)x(2R+1)
        neighborhood, gathered with F.unfold instead of a dense H*W x H*W map.
        """
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        # Per-pixel query vectors.
        feature0_reshape = self.q_proj(
            feature0.view(b, c, -1).permute(0, 2, 1)
        ).reshape(
            b * h * w, 1, c
        )  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        # Gather each pixel's key neighborhood.
        feature0_window = F.unfold(
            feature0_proj, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, C*(2R+1)^2), H*W]

        feature0_window = (
            feature0_window.view(b, c, kernel_size**2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size**2)
        )  # [B*H*W, C, (2R+1)^2]

        # Gather each pixel's flow (value) neighborhood the same way.
        flow_window = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, 2*(2R+1)^2), H*W]

        flow_window = (
            flow_window.view(b, 2, kernel_size**2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size**2, 2)
        )  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (
            c**0.5
        )  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = (
            torch.matmul(prob, flow_window)
            .view(b, h, w, 2)
            .permute(0, 3, 1, 2)
            .contiguous()
        )  # [B, 2, H, W]

        return out
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    """Estimate flow as the softmax-weighted correspondence over a global
    all-pairs correlation volume.

    feature0, feature1: [B, C, H, W]. Returns (flow, prob) where flow is
    [B, 2, H, W] (or [2*B, 2, H, W] with forward then backward flow when
    pred_bidir_flow=True) and prob is the matching distribution.
    """
    # global correlation
    b, c, h, w = feature0.shape
    feature0 = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feature1 = feature1.view(b, c, -1)  # [B, C, H*W]

    correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
        c**0.5
    )  # [B, H, W, H, W]

    # flow from softmax
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    correlation = correlation.view(b, h * w, h * w)  # [B, H*W, H*W]

    if pred_bidir_flow:
        # The transposed correlation is exactly the 1->0 volume, so both
        # directions are solved in one softmax by stacking along the batch.
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    # Expected (soft-argmax) correspondence location per pixel.
    correspondence = (
        torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    flow = correspondence - init_grid

    return flow, prob
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    """Estimate flow from a local correlation volume: each pixel of feature0
    is matched only against a (2R+1)x(2R+1) window of feature1 around its own
    location, and the flow is the softmax-weighted window coordinate.

    feature0, feature1: [B, C, H, W]. Returns (flow [B, 2, H, W],
    match_prob [B, H*W, (2R+1)^2]).
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    # Offsets covering the search window, shared by every pixel.
    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (
        sample_coords[:, :, :, 0] < w
    )  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (
        sample_coords[:, :, :, 1] < h
    )  # [B, H*W, (2R+1)^2]

    valid = (
        valid_x & valid_y
    )  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    # Scaled dot-product correlation between each pixel and its window.
    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    # Soft-argmax over the window coordinates gives the correspondence.
    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Return a batch of pixel-coordinate grids.

    Output is [B, 2, H, W] with channel 0 = x and channel 1 = y, or
    [B, 3, H, W] with an extra all-ones channel when homogeneous=True.
    """
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w))  # each [H, W]

    channels = [xs, ys]
    if homogeneous:
        channels.append(torch.ones_like(xs))  # homogeneous coordinate

    grid = torch.stack(channels, dim=0).float()  # [2 or 3, H, W]
    grid = grid.unsqueeze(0).repeat(b, 1, 1, 1)  # [B, 2 or 3, H, W]

    return grid if device is None else grid.to(device)
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Return a [len_h, len_w, 2] grid of (x, y) positions spanning the
    window [w_min, w_max] x [h_min, h_max]."""
    assert device is not None

    xs = torch.linspace(w_min, w_max, len_w, device=device)
    ys = torch.linspace(h_min, h_max, len_h, device=device)
    grid_x, grid_y = torch.meshgrid([xs, ys])  # each [len_w, len_h]

    # Pair up as (x, y), then swap axes to row-major [H, W, 2] layout.
    return torch.stack((grid_x, grid_y), -1).transpose(0, 1).float()
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def normalize_coords(coords, h, w):
    """Map absolute pixel coordinates into [-1, 1] for F.grid_sample.

    coords: [B, H, W, 2] with last dim (x, y).
    """
    center = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
    return (coords - center) / center  # [-1, 1]
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def bilinear_sample(
    img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False
):
    """Sample `img` [B, C, H, W] at absolute pixel coordinates.

    sample_coords: [B, 2, H, W] (or [B, H, W, 2]) holding (x, y) positions
    in image scale. When return_mask is True, also returns a boolean
    [B, H, W] mask of in-bounds sampling locations.
    """
    # Accept channel-last coordinates as well.
    if sample_coords.size(1) != 2:  # [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    b, _, h, w = sample_coords.shape

    # Rescale pixel coordinates to grid_sample's [-1, 1] convention.
    norm_x = 2 * sample_coords[:, 0] / (w - 1) - 1
    norm_y = 2 * sample_coords[:, 1] / (h - 1) - 1
    grid = torch.stack([norm_x, norm_y], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(
        img, grid, mode=mode, padding_mode=padding_mode, align_corners=True
    )

    if not return_mask:
        return sampled

    # Positions whose normalized coords fall inside the image.
    in_bounds = (norm_x >= -1) & (norm_y >= -1) & (norm_x <= 1) & (norm_y <= 1)
    return sampled, in_bounds
|
| 984 |
+
|
| 985 |
+
|
| 986 |
+
def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    """Backward-warp `feature` [B, C, H, W] along `flow` [B, 2, H, W].

    Each output pixel is sampled from feature at (pixel position + flow).
    With mask=True, also returns the in-bounds mask from bilinear_sample.
    """
    b, c, h, w = feature.size()
    assert flow.size(1) == 2

    # Sampling positions = base pixel grid displaced by the flow.
    sample_grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(feature, sample_grid, padding_mode=padding_mode, return_mask=mask)
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    """Estimate occlusion masks via forward-backward flow consistency.

    fwd_flow, bwd_flow: [B, 2, H, W]. Returns (fwd_occ, bwd_occ), float
    masks [B, H, W] with 1 where the two flows disagree (likely occlusion).
    alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837).
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]

    # Pull each flow field into the other frame before comparing.
    bwd_at_fwd = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    fwd_at_bwd = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    # Consistent flows should cancel; a large residual signals occlusion.
    diff_fwd = torch.norm(fwd_flow + bwd_at_fwd, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + fwd_at_bwd, dim=1)

    # Magnitude-dependent tolerance.
    threshold = alpha * flow_mag + beta

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()
    return fwd_occ, bwd_occ
|
| 1014 |
+
|
| 1015 |
+
|
| 1016 |
+
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        """num_pos_feats: channels per axis (output has 2*num_pos_feats);
        temperature: frequency base; normalize: rescale positions by `scale`
        (defaults to 2*pi) before encoding."""
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        """x: [B, C, H, W] (only its shape/device are used).
        Returns sine/cosine positional encodings, [B, 2*num_pos_feats, H, W]."""
        # x = tensor_list.tensors  # [B, C, H, W]
        # mask = tensor_list.mask  # [B, H, W], input with padding, valid as 0
        b, c, h, w = x.size()
        # All-ones mask: every position is valid (no padding support here).
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        # Cumulative sums turn the mask into 1-based row/column indices.
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Geometric frequency ladder shared by the sin/cos channel pairs.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even and cos on odd channels.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
|
| 1058 |
+
|
| 1059 |
+
|
| 1060 |
+
def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Partition a feature map into num_splits x num_splits non-overlapping
    windows, folding the window index into the batch dimension.

    channel_last=True:  [B, H, W, C] -> [B*K*K, H/K, W/K, C]
    channel_last=False: [B, C, H, W] -> [B*K*K, C, H/K, W/K]
    """
    k = num_splits
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % k == 0 and w % k == 0
        win_h, win_w = h // k, w // k

        # Expose the window indices, move them next to batch, then fold.
        windows = feature.view(b, k, win_h, k, win_w, c)
        feature = windows.permute(0, 1, 3, 2, 4, 5).reshape(
            b * k * k, win_h, win_w, c
        )  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % k == 0 and w % k == 0
        win_h, win_w = h // k, w // k

        windows = feature.view(b, c, k, win_h, k, win_w)
        feature = windows.permute(0, 2, 4, 1, 3, 5).reshape(
            b * k * k, c, win_h, win_w
        )  # [B*K*K, C, H/K, W/K]

    return feature
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Inverse of split_feature: reassemble K*K windows into one map.

    channel_last=True:  [B*K*K, H/K, W/K, C] -> [B, H, W, C]
    channel_last=False: [B*K*K, C, H/K, W/K] -> [B, C, H, W]
    """
    k = num_splits
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        batch = b // (k * k)

        # Unfold the window indices from the batch, then stitch them back
        # into the spatial axes.
        merge = (
            splits.view(batch, k, k, h, w, c)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(batch, k * h, k * w, c)
        )  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        batch = b // (k * k)

        merge = (
            splits.view(batch, k, k, c, h, w)
            .permute(0, 3, 1, 4, 2, 5)
            .contiguous()
            .view(batch, c, k * h, k * w)
        )  # [B, C, H, W]

    return merge
|
| 1122 |
+
|
| 1123 |
+
|
| 1124 |
+
def normalize_img(img0, img1):
    """Standardize an image pair with the ImageNet mean and std.

    Returns (img0, img1) with ``(img - mean) / std`` applied channel-wise.
    Stats are placed on img1's device, matching the original behavior.
    """
    device = img1.device
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    return (img0 - mean) / std, (img1 - mean) / std
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
+
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add sine positional encodings to both feature maps [B, C, H, W].

    When attn_splits > 1 the encoding is computed per attention window so
    positions are local to each split; otherwise it is global.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits > 1:  # add position in splited window
        f0_windows = split_feature(feature0, num_splits=attn_splits)
        f1_windows = split_feature(feature1, num_splits=attn_splits)

        # One encoding per window, shared by both feature maps.
        position = pos_enc(f0_windows)
        f0_windows = f0_windows + position
        f1_windows = f1_windows + position

        feature0 = merge_splits(f0_windows, num_splits=attn_splits)
        feature1 = merge_splits(f1_windows, num_splits=attn_splits)
    else:
        position = pos_enc(feature0)
        feature0 = feature0 + position
        feature1 = feature1 + position

    return feature0, feature1
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
class GMFlow(nn.Module):
|
| 1159 |
+
def __init__(
|
| 1160 |
+
self,
|
| 1161 |
+
num_scales=2,
|
| 1162 |
+
upsample_factor=4,
|
| 1163 |
+
feature_channels=128,
|
| 1164 |
+
attention_type="swin",
|
| 1165 |
+
num_transformer_layers=6,
|
| 1166 |
+
ffn_dim_expansion=4,
|
| 1167 |
+
num_head=1,
|
| 1168 |
+
**kwargs,
|
| 1169 |
+
):
|
| 1170 |
+
super(GMFlow, self).__init__()
|
| 1171 |
+
|
| 1172 |
+
self.num_scales = num_scales
|
| 1173 |
+
self.feature_channels = feature_channels
|
| 1174 |
+
self.upsample_factor = upsample_factor
|
| 1175 |
+
self.attention_type = attention_type
|
| 1176 |
+
self.num_transformer_layers = num_transformer_layers
|
| 1177 |
+
|
| 1178 |
+
# CNN backbone
|
| 1179 |
+
self.backbone = CNNEncoder(
|
| 1180 |
+
output_dim=feature_channels, num_output_scales=num_scales
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
# Transformer
|
| 1184 |
+
self.transformer = FeatureTransformer(
|
| 1185 |
+
num_layers=num_transformer_layers,
|
| 1186 |
+
d_model=feature_channels,
|
| 1187 |
+
nhead=num_head,
|
| 1188 |
+
attention_type=attention_type,
|
| 1189 |
+
ffn_dim_expansion=ffn_dim_expansion,
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
# flow propagation with self-attn
|
| 1193 |
+
self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)
|
| 1194 |
+
|
| 1195 |
+
# convex upsampling: concat feature0 and flow as input
|
| 1196 |
+
self.upsampler = nn.Sequential(
|
| 1197 |
+
nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
|
| 1198 |
+
nn.ReLU(inplace=True),
|
| 1199 |
+
nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
def extract_feature(self, img0, img1):
|
| 1203 |
+
concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
|
| 1204 |
+
features = self.backbone(
|
| 1205 |
+
concat
|
| 1206 |
+
) # list of [2B, C, H, W], resolution from high to low
|
| 1207 |
+
|
| 1208 |
+
# reverse: resolution from low to high
|
| 1209 |
+
features = features[::-1]
|
| 1210 |
+
|
| 1211 |
+
feature0, feature1 = [], []
|
| 1212 |
+
|
| 1213 |
+
for i in range(len(features)):
|
| 1214 |
+
feature = features[i]
|
| 1215 |
+
chunks = torch.chunk(feature, 2, 0) # tuple
|
| 1216 |
+
feature0.append(chunks[0])
|
| 1217 |
+
feature1.append(chunks[1])
|
| 1218 |
+
|
| 1219 |
+
return feature0, feature1
|
| 1220 |
+
|
| 1221 |
+
def upsample_flow(
|
| 1222 |
+
self,
|
| 1223 |
+
flow,
|
| 1224 |
+
feature,
|
| 1225 |
+
bilinear=False,
|
| 1226 |
+
upsample_factor=8,
|
| 1227 |
+
):
|
| 1228 |
+
if bilinear:
|
| 1229 |
+
up_flow = (
|
| 1230 |
+
F.interpolate(
|
| 1231 |
+
flow,
|
| 1232 |
+
scale_factor=upsample_factor,
|
| 1233 |
+
mode="bilinear",
|
| 1234 |
+
align_corners=True,
|
| 1235 |
+
)
|
| 1236 |
+
* upsample_factor
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
else:
|
| 1240 |
+
# convex upsampling
|
| 1241 |
+
concat = torch.cat((flow, feature), dim=1)
|
| 1242 |
+
|
| 1243 |
+
mask = self.upsampler(concat)
|
| 1244 |
+
b, flow_channel, h, w = flow.shape
|
| 1245 |
+
mask = mask.view(
|
| 1246 |
+
b, 1, 9, self.upsample_factor, self.upsample_factor, h, w
|
| 1247 |
+
) # [B, 1, 9, K, K, H, W]
|
| 1248 |
+
mask = torch.softmax(mask, dim=2)
|
| 1249 |
+
|
| 1250 |
+
up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
|
| 1251 |
+
up_flow = up_flow.view(
|
| 1252 |
+
b, flow_channel, 9, 1, 1, h, w
|
| 1253 |
+
) # [B, 2, 9, 1, 1, H, W]
|
| 1254 |
+
|
| 1255 |
+
up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
|
| 1256 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
|
| 1257 |
+
up_flow = up_flow.reshape(
|
| 1258 |
+
b, flow_channel, self.upsample_factor * h, self.upsample_factor * w
|
| 1259 |
+
) # [B, 2, K*H, K*W]
|
| 1260 |
+
|
| 1261 |
+
return up_flow
|
| 1262 |
+
|
| 1263 |
+
def forward(
    self,
    img0,
    img1,
    attn_splits_list=(2, 8),
    corr_radius_list=(-1, 4),
    prop_radius_list=(-1, 1),
    pred_bidir_flow=False,
    **kwargs,
):
    """Coarse-to-fine GMFlow estimation from ``img0`` to ``img1``.

    Args:
        img0, img1: [B, 3, H, W] images; normalized internally.
        attn_splits_list: per-scale attention split counts.
        corr_radius_list: per-scale correlation radius (-1 = global matching).
        prop_radius_list: per-scale propagation radius (-1 = global attention).
        pred_bidir_flow: if True, also predict the backward flow by doubling
            the batch with the swapped feature pair.

    Returns:
        The convex-upsampled flow of the finest scale, [B, 2, H, W]
        (2*B leading dim when ``pred_bidir_flow`` is True).
    """
    # NOTE: the mutable-list defaults of the original were replaced by
    # tuples; only len() and indexing are used, so callers are unaffected.
    img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

    # Per-image feature pyramids, resolution low to high.
    feature0_list, feature1_list = self.extract_feature(img0, img1)

    flow = None

    assert (
        len(attn_splits_list)
        == len(corr_radius_list)
        == len(prop_radius_list)
        == self.num_scales
    )

    for scale_idx in range(self.num_scales):
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

        if pred_bidir_flow and scale_idx > 0:
            # Refine both flow directions at once by stacking swapped pairs.
            feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                (feature1, feature0), dim=0
            )

        upsample_factor = self.upsample_factor * (
            2 ** (self.num_scales - 1 - scale_idx)
        )

        if scale_idx > 0:
            # Carry the coarser estimate up one pyramid level.
            flow = (
                F.interpolate(
                    flow, scale_factor=2, mode="bilinear", align_corners=True
                )
                * 2
            )

        if flow is not None:
            flow = flow.detach()
            feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

        attn_splits = attn_splits_list[scale_idx]
        corr_radius = corr_radius_list[scale_idx]
        prop_radius = prop_radius_list[scale_idx]

        # Add positional encodings before the attention stack.
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        # Transformer (cross/self attention between the two feature maps).
        feature0, feature1 = self.transformer(
            feature0, feature1, attn_num_splits=attn_splits
        )

        # Correlation + softmax: global matching at coarse scales, local else.
        if corr_radius == -1:
            flow_pred = global_correlation_softmax(
                feature0, feature1, pred_bidir_flow
            )[0]
        else:
            flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]

        # flow_pred is the flow itself at the first scale, a residual after.
        flow = flow + flow_pred if flow is not None else flow_pred

        # NOTE: upstream GMFlow also bilinearly upsamples the intermediate
        # flow here for training supervision; this port never consumed that
        # tensor, so the dead computation was removed.

        # Flow propagation with self-attention.
        if pred_bidir_flow and scale_idx == 0:
            feature0 = torch.cat(
                (feature0, feature1), dim=0
            )  # [2*B, C, H, W] for propagation
        flow = self.feature_flow_attn(
            feature0,
            flow.detach(),
            local_window_attn=prop_radius > 0,
            local_window_radius=prop_radius,
        )

        # Bilinear upsampling at training time except the last scale.
        if self.training and scale_idx < self.num_scales - 1:
            flow_up = self.upsample_flow(
                flow, feature0, bilinear=True, upsample_factor=upsample_factor
            )

        # Convex upsampling at the finest scale produces the returned flow.
        if scale_idx == self.num_scales - 1:
            flow_up = self.upsample_flow(flow, feature0)

    return flow_up
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
# Cache of normalized sampling grids, keyed by (device, shape) so that the
# same module can serve tensors living on different devices. The original
# keyed only on shape and forced the grid onto get_torch_device(), which
# broke when the flow tensor lived elsewhere (e.g. CPU fallback) and could
# return a stale grid for a different device. The sibling `warp` in the
# union arch file already uses a (device, size) key; this matches it.
backwarp_tenGrid = {}


def backwarp(tenIn, tenflow):
    """Backward-warp ``tenIn`` by the optical flow ``tenflow``.

    Args:
        tenIn: [B, C, H, W] tensor to sample from.
        tenflow: [B, 2, H, W] flow in pixel units (x then y channels).

    Returns:
        [B, C, H, W] warped tensor; out-of-frame samples are zero-padded.
    """
    key = (str(tenflow.device), str(tenflow.shape))
    if key not in backwarp_tenGrid:
        tenHor = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[3],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, 1, -1)
            .repeat(1, 1, tenflow.shape[2], 1)
        )
        tenVer = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[2],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, -1, 1)
            .repeat(1, 1, 1, tenflow.shape[3])
        )

        backwarp_tenGrid[key] = torch.cat([tenHor, tenVer], 1)
    # end

    # Convert pixel-unit flow to the [-1, 1] grid_sample coordinate system.
    tenflow = torch.cat(
        [
            tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0),
            tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    return torch.nn.functional.grid_sample(
        input=tenIn,
        grid=(backwarp_tenGrid[key] + tenflow).permute(0, 2, 3, 1),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )
| 1419 |
+
|
| 1420 |
+
|
| 1421 |
+
class MetricNet(nn.Module):
    """Predicts per-pixel splatting metrics for the two input frames.

    Takes both frames, their bidirectional flows, photometric warp errors
    and occlusion masks, and regresses one metric map per direction,
    bounded to (-10, 10) via tanh.
    """

    def __init__(self):
        super(MetricNet, self).__init__()
        # Input: 6 (two RGB frames) + 2 (neg. warp errors) + 4 (two
        # normalized flows) + 2 (occlusion masks) = 14 channels.
        self.metric_in = nn.Conv2d(14, 64, 3, 1, 1)
        self.metric_net1 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net2 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net3 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_out = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 2, 3, 1, 1))

    def forward(self, img0, img1, flow01, flow10):
        """Return ``(metric0, metric1)``, one [B, 1, H, W] map per frame."""
        # Photometric consistency: mean abs error between each frame and
        # the other frame backward-warped toward it.
        warp_err0 = F.l1_loss(img0, backwarp(img1, flow01), reduction="none").mean(
            [1], True
        )
        warp_err1 = F.l1_loss(img1, backwarp(img0, flow10), reduction="none").mean(
            [1], True
        )

        fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10)

        # Normalize both flows to [-1, 1] grid coordinates.
        norm01 = torch.cat(
            [
                flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0),
                flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0),
            ],
            1,
        )
        norm10 = torch.cat(
            [
                flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0),
                flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0),
            ],
            1,
        )

        # Assemble the 14-channel input: images, negated warp errors,
        # normalized flows, occlusion masks.
        net_in = torch.cat(
            (
                torch.cat((img0, img1), 1),
                torch.cat((-warp_err0, -warp_err1), 1),
                torch.cat((norm01, norm10), 1),
                torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1),
            ),
            1,
        )

        feat = self.metric_in(net_in)
        # Three residual refinement stages.
        for stage in (self.metric_net1, self.metric_net2, self.metric_net3):
            feat = stage(feat) + feat
        metric = torch.tanh(self.metric_out(feat)) * 10

        return metric[:, :1], metric[:, 1:2]
|
| 1469 |
+
|
| 1470 |
+
|
| 1471 |
+
class FeatureNet(nn.Module):
    """Three-level strided feature pyramid extractor ("the quadratic model").

    Each level halves the spatial resolution: 3 -> 64 -> 128 -> 192 channels.
    """

    def __init__(self):
        super(FeatureNet, self).__init__()
        self.block1 = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(3, 64, 3, 2, 1),
            nn.PReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
        )
        self.block2 = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.PReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),
        )
        self.block3 = nn.Sequential(
            nn.PReLU(),
            nn.Conv2d(128, 192, 3, 2, 1),
            nn.PReLU(),
            nn.Conv2d(192, 192, 3, 1, 1),
        )

    def forward(self, x):
        """Return the (1/2, 1/4, 1/8 resolution) feature maps of ``x``."""
        level1 = self.block1(x)
        level2 = self.block2(level1)
        level3 = self.block3(level2)
        return level1, level2, level3
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
# Residual Block
|
| 1504 |
+
# Residual block builder. NOTE(review): despite the name, the returned
# module has no skip connection — callers add the residual themselves
# (see GridNet.forward). The stride is applied to BOTH convolutions.
def ResidualBlock(in_channels, out_channels, stride=1):
    """Return a PReLU/Conv3x3 x2 stack mapping in_channels -> out_channels."""
    layers = [
        nn.PReLU(),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
    ]
    return torch.nn.Sequential(*layers)
|
| 1525 |
+
|
| 1526 |
+
|
| 1527 |
+
# downsample block
|
| 1528 |
+
# Downsample block builder: a strided conv followed by a stride-1 conv,
# each preceded by PReLU. Default stride halves the spatial resolution.
def DownsampleBlock(in_channels, out_channels, stride=2):
    """Return a module reducing resolution by ``stride`` and changing channels."""
    layers = [
        nn.PReLU(),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
        ),
    ]
    return torch.nn.Sequential(*layers)
|
| 1544 |
+
|
| 1545 |
+
|
| 1546 |
+
# upsample block
|
| 1547 |
+
# Upsample block builder: a transposed conv (k=4, p=1) followed by a
# stride-1 conv, each preceded by PReLU. Default stride doubles resolution.
def UpsampleBlock(in_channels, out_channels, stride=2):
    """Return a module increasing resolution by ``stride`` and changing channels."""
    layers = [
        nn.PReLU(),
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
        ),
    ]
    return torch.nn.Sequential(*layers)
|
| 1563 |
+
|
| 1564 |
+
|
| 1565 |
+
class PixelShuffleBlcok(nn.Module):
    """2x pixel-shuffle upsampling head.

    conv -> PReLU -> conv(4*feat) -> PixelShuffle(2) -> conv to output
    channels. (Class name keeps the upstream misspelling; callers depend
    on it.)
    """

    def __init__(self, in_feat, num_feat, num_out_ch):
        super(PixelShuffleBlcok, self).__init__()
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(in_feat, num_feat, 3, 1, 1), nn.PReLU()
        )
        # PixelShuffle(2) rearranges 4*num_feat channels into a 2x larger map.
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)
        )
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        """Return the 2x-upsampled projection of ``x``."""
        pre = self.conv_before_upsample(x)
        return self.conv_last(self.upsample(pre))
|
| 1580 |
+
|
| 1581 |
+
|
| 1582 |
+
# grid network
|
| 1583 |
+
# Grid network: a sparse 3-row GridNet that fuses the warped images with
# the three warped feature-pyramid levels into the interpolated frame.
class GridNet(nn.Module):
    """Fusion network over a 3-row (full/half/quarter res) grid.

    Rows carry 64 / 128 / 192 channels; information flows right along each
    row through residual stages, down via DownsampleBlocks on the encoder
    side and up via UpsampleBlocks on the decoder side. The original grid
    columns 2/3 (and the matching down/up blocks) were already disabled
    upstream and are omitted here.
    """

    def __init__(
        self,
        in_channels=12,
        in_channels1=128,
        in_channels2=256,
        in_channels3=384,
        out_channels=3,
    ):
        super(GridNet, self).__init__()

        # Entry stems: one per input (warped images + 3 feature levels).
        self.residual_model_head = ResidualBlock(in_channels, 64)
        self.residual_model_head1 = ResidualBlock(in_channels1, 64)
        self.residual_model_head2 = ResidualBlock(in_channels2, 128)
        self.residual_model_head3 = ResidualBlock(in_channels3, 192)

        # Row 0 (full resolution, 64 channels).
        self.residual_model_01 = ResidualBlock(64, 64)
        self.residual_model_04 = ResidualBlock(64, 64)
        self.residual_model_05 = ResidualBlock(64, 64)
        self.residual_model_tail = PixelShuffleBlcok(64, 64, out_channels)

        # Row 1 (half resolution, 128 channels).
        self.residual_model_11 = ResidualBlock(128, 128)
        self.residual_model_14 = ResidualBlock(128, 128)
        self.residual_model_15 = ResidualBlock(128, 128)

        # Row 2 (quarter resolution, 192 channels).
        self.residual_model_21 = ResidualBlock(192, 192)
        self.residual_model_24 = ResidualBlock(192, 192)
        self.residual_model_25 = ResidualBlock(192, 192)

        # Encoder-side downsampling between rows.
        self.downsample_model_10 = DownsampleBlock(64, 128)
        self.downsample_model_20 = DownsampleBlock(128, 192)
        self.downsample_model_11 = DownsampleBlock(64, 128)
        self.downsample_model_21 = DownsampleBlock(128, 192)

        # Decoder-side upsampling between rows.
        self.upsample_model_04 = UpsampleBlock(128, 64)
        self.upsample_model_14 = UpsampleBlock(192, 128)
        self.upsample_model_05 = UpsampleBlock(128, 64)
        self.upsample_model_15 = UpsampleBlock(192, 128)

    def forward(self, x, x1, x2, x3):
        """Fuse inputs into the output frame (2x the row-0 input resolution).

        Args:
            x:  [B, 12, H, W] warped-image stack.
            x1: [B, 128, H, W] level-1 features.
            x2: [B, 256, H/2, W/2] level-2 features.
            x3: [B, 384, H/4, W/4] level-3 features.
        """
        # Column 0: inject inputs into each row (r = row, c = column).
        r0c0 = self.residual_model_head(x) + self.residual_model_head1(x1)
        r1c0 = self.downsample_model_10(r0c0) + self.residual_model_head2(x2)
        r2c0 = self.downsample_model_20(r1c0) + self.residual_model_head3(x3)

        # Column 1: one residual stage per row + downsample from the row above.
        r0c1 = self.residual_model_01(r0c0) + r0c0
        r1c1 = (self.residual_model_11(r1c0) + r1c0) + self.downsample_model_11(r0c1)
        r2c1 = (self.residual_model_21(r2c0) + r2c0) + self.downsample_model_21(r1c1)

        # Bottom row decoder stages.
        r2c4 = self.residual_model_24(r2c1) + r2c1
        r2c5 = self.residual_model_25(r2c4) + r2c4

        # Column 4: residual stage + upsample from the row below.
        r1c4 = self.upsample_model_14(r2c4) + (self.residual_model_14(r1c1) + r1c1)
        r0c4 = self.upsample_model_04(r1c4) + (self.residual_model_04(r0c1) + r0c1)

        # Column 5: final decoder column.
        r1c5 = self.upsample_model_15(r2c5) + (self.residual_model_15(r1c4) + r1c4)
        r0c5 = self.upsample_model_05(r1c5) + (self.residual_model_05(r0c4) + r0c4)

        return self.residual_model_tail(r0c5)
|
| 1689 |
+
# end
|
| 1690 |
+
|
| 1691 |
+
class Model:
    """Bundles the four GMFSS-Fortuna sub-networks and the inference logic.

    ``reuse`` precomputes everything that depends only on the input pair
    (flows, metrics, feature pyramids); ``inference`` then synthesizes a
    frame at any ``timestep`` from those cached tensors.
    """

    def __init__(self):
        self.flownet = GMFlow()
        self.metricnet = MetricNet()
        self.feat_ext = FeatureNet()
        self.fusionnet = GridNet()
        self.version = 3.9

    def eval(self):
        """Put all sub-networks into eval mode."""
        self.flownet.eval()
        self.metricnet.eval()
        self.feat_ext.eval()
        self.fusionnet.eval()

    def device(self):
        """Move all sub-networks onto the module-level torch device."""
        self.flownet.to(device)
        self.metricnet.to(device)
        self.feat_ext.to(device)
        self.fusionnet.to(device)

    def load_model(self, path_dict):
        """Load the four checkpoints from ``path_dict``.

        Keys: "flownet", "metricnet", "feat_ext", "fusionnet".
        map_location="cpu" lets CUDA-saved checkpoints load on any host;
        ``device()`` moves the weights to the target device afterwards.
        """
        # models/GMFSS_fortuna_flownet.pkl
        self.flownet.load_state_dict(
            torch.load(path_dict["flownet"], map_location="cpu")
        )
        # models/GMFSS_fortuna_metric.pkl
        self.metricnet.load_state_dict(
            torch.load(path_dict["metricnet"], map_location="cpu")
        )
        # models/GMFSS_fortuna_feat.pkl
        self.feat_ext.load_state_dict(
            torch.load(path_dict["feat_ext"], map_location="cpu")
        )
        # models/GMFSS_fortuna_fusionnet.pkl
        self.fusionnet.load_state_dict(
            torch.load(path_dict["fusionnet"], map_location="cpu")
        )

    def reuse(self, img0, img1, scale):
        """Precompute pair-dependent tensors for ``inference``.

        Args:
            img0, img1: [B, 3, H, W] frames (full resolution).
            scale: resolution multiplier used only for the flow estimation;
                flows are rescaled back to half resolution afterwards.

        Returns:
            (flow01, flow10, metric0, metric1,
             feat11, feat12, feat13, feat21, feat22, feat23)
        """
        # Feature pyramids at full resolution.
        feat11, feat12, feat13 = self.feat_ext(img0)
        feat21, feat22, feat23 = self.feat_ext(img1)

        # Flows / metrics operate at half resolution.
        img0 = F.interpolate(
            img0, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        img1 = F.interpolate(
            img1, scale_factor=0.5, mode="bilinear", align_corners=False
        )

        if scale != 1.0:
            imgf0 = F.interpolate(
                img0, scale_factor=scale, mode="bilinear", align_corners=False
            )
            imgf1 = F.interpolate(
                img1, scale_factor=scale, mode="bilinear", align_corners=False
            )
        else:
            imgf0 = img0
            imgf1 = img1
        flow01 = self.flownet(imgf0, imgf1, return_flow=True)
        flow10 = self.flownet(imgf1, imgf0, return_flow=True)
        if scale != 1.0:
            # Undo the scale on both resolution and flow magnitude.
            flow01 = (
                F.interpolate(
                    flow01,
                    scale_factor=1.0 / scale,
                    mode="bilinear",
                    align_corners=False,
                )
                / scale
            )
            flow10 = (
                F.interpolate(
                    flow10,
                    scale_factor=1.0 / scale,
                    mode="bilinear",
                    align_corners=False,
                )
                / scale
            )

        metric0, metric1 = self.metricnet(img0, img1, flow01, flow10)

        return (
            flow01,
            flow10,
            metric0,
            metric1,
            feat11,
            feat12,
            feat13,
            feat21,
            feat22,
            feat23,
        )

    def inference(
        self,
        img0,
        img1,
        flow01,
        flow10,
        metric0,
        metric1,
        feat11,
        feat12,
        feat13,
        feat21,
        feat22,
        feat23,
        timestep,
    ):
        """Synthesize the frame at ``timestep`` in (0, 1) between the pair.

        Softmax-splats the (half-resolution) images and the feature pyramid
        toward time t, then fuses everything with GridNet. Returns a tensor
        clamped to [0, 1] at full resolution.
        """
        # Scale flows/metrics to the intermediate time t.
        F1t = timestep * flow01
        F2t = (1 - timestep) * flow10

        Z1t = timestep * metric0
        Z2t = (1 - timestep) * metric1

        # Splat the half-resolution images to time t.
        img0 = F.interpolate(
            img0, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        I1t = softsplat(img0, F1t, Z1t, strMode="soft")
        img1 = F.interpolate(
            img1, scale_factor=0.5, mode="bilinear", align_corners=False
        )
        I2t = softsplat(img1, F2t, Z2t, strMode="soft")

        # Level 1 features share the half-resolution flow.
        feat1t1 = softsplat(feat11, F1t, Z1t, strMode="soft")
        feat2t1 = softsplat(feat21, F2t, Z2t, strMode="soft")

        # Level 2: flow halved in resolution AND magnitude.
        F1td = (
            F.interpolate(F1t, scale_factor=0.5, mode="bilinear", align_corners=False)
            * 0.5
        )
        Z1d = F.interpolate(Z1t, scale_factor=0.5, mode="bilinear", align_corners=False)
        feat1t2 = softsplat(feat12, F1td, Z1d, strMode="soft")
        F2td = (
            F.interpolate(F2t, scale_factor=0.5, mode="bilinear", align_corners=False)
            * 0.5
        )
        Z2d = F.interpolate(Z2t, scale_factor=0.5, mode="bilinear", align_corners=False)
        feat2t2 = softsplat(feat22, F2td, Z2d, strMode="soft")

        # Level 3: quartered resolution and magnitude.
        F1tdd = (
            F.interpolate(F1t, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        Z1dd = F.interpolate(
            Z1t, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        feat1t3 = softsplat(feat13, F1tdd, Z1dd, strMode="soft")
        F2tdd = (
            F.interpolate(F2t, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        Z2dd = F.interpolate(
            Z2t, scale_factor=0.25, mode="bilinear", align_corners=False
        )
        feat2t3 = softsplat(feat23, F2tdd, Z2dd, strMode="soft")

        out = self.fusionnet(
            torch.cat([img0, I1t, I2t, img1], dim=1),
            torch.cat([feat1t1, feat2t1], dim=1),
            torch.cat([feat1t2, feat2t2], dim=1),
            torch.cat([feat1t3, feat2t3], dim=1),
        )

        return torch.clamp(out, 0, 1)
|
vfi_models/gmfss_fortuna/GMFSS_Fortuna_union.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import numpy as np
|
| 3 |
+
import vapoursynth as vs
|
| 4 |
+
from .GMFSS_Fortuna_union_arch import Model_inference
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GMFSS_Fortuna_union:
    """Thin inference wrapper around the union GMFSS-Fortuna model."""

    def __init__(self):
        # Frame-interpolation host contract: no caching, two input frames.
        self.cache = False
        self.amount_input_img = 2

        # Global torch setup for inference-only use.
        torch.set_grad_enabled(False)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        self.model = Model_inference()
        self.model.eval()

    def execute(self, I0, I1, timestep):
        """Return the interpolated frame at ``timestep``, moved to CPU."""
        with torch.inference_mode():
            return self.model(I0, I1, timestep).cpu()
|
vfi_models/gmfss_fortuna/GMFSS_Fortuna_union_arch.py
ADDED
|
@@ -0,0 +1,1857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/GMFSS_infer_u.py
|
| 3 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/softsplat.py
|
| 4 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FusionNet_u.py
|
| 5 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/FeatureNet.py
|
| 6 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/MetricNet.py
|
| 7 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/IFNet_HDv3.py
|
| 8 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/gmflow.py
|
| 9 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/utils.py
|
| 10 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/position.py
|
| 11 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/geometry.py
|
| 12 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/matching.py
|
| 13 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/transformer.py
|
| 14 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/backbone.py
|
| 15 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/gmflow/trident_conv.py
|
| 16 |
+
https://github.com/98mxr/GMFSS_Fortuna/blob/b5d0bd544e3f1eee6a059e49c69bcd3124c8343c/model/warplayer.py
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from torch.nn.modules.utils import _pair
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import torch
|
| 26 |
+
import math
|
| 27 |
+
from vfi_models.rife.rife_arch import IFNet
|
| 28 |
+
from vfi_models.ops import softsplat
|
| 29 |
+
from comfy.model_management import get_torch_device
|
| 30 |
+
|
| 31 |
+
# Torch device selected by ComfyUI's model management (CPU/CUDA/etc.).
device = get_torch_device()
# Cache of base sampling grids used by warp(), keyed by
# (str(flow device), str(flow size)); filled lazily on first use.
backwarp_tenGrid = {}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` by the optical flow `tenFlow` using grid_sample.

    tenInput: [B, C, H, W] tensor to sample from.
    tenFlow:  [B, 2, H, W] flow in pixels; channel 0 is x, channel 1 is y.
    Returns the warped tensor with the same shape as tenInput.
    """
    # Base [-1, 1] grids are cached per (device, flow size) in the
    # module-level `backwarp_tenGrid` dict to avoid rebuilding per call.
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
            .view(1, 1, 1, tenFlow.shape[3])
            .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
            .view(1, 1, tenFlow.shape[2], 1)
            .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        )
        # NOTE(review): the grid is created on the module-level `device`, not
        # tenFlow.device — assumes the two always agree; confirm for multi-GPU.
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)

    # Rescale the pixel-space flow into grid_sample's normalized [-1, 1] range.
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class MultiScaleTridentConv(nn.Module):
    """Weight-sharing convolution evaluated over several parallel branches.

    Every branch applies the same weight tensor with its own stride/padding
    (TridentNet-style). The CNN encoder uses this to emit multiple feature
    scales from a single convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(MultiScaleTridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation
        # Broadcast scalar per-branch settings to every branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.strides = [_pair(stride) for stride in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        # Exactly one stride/padding entry per branch.
        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        # inputs: list of tensors, one per active branch.
        num_branch = (
            self.num_branch if self.training or self.test_branch_idx == -1 else 1
        )
        assert len(inputs) == num_branch

        if self.training or self.test_branch_idx == -1:
            # Run every branch with the shared weight and its own stride/padding.
            outputs = [
                F.conv2d(
                    input,
                    self.weight,
                    self.bias,
                    stride,
                    padding,
                    self.dilation,
                    self.groups,
                )
                for input, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Single selected branch at test time.
            # NOTE(review): `self.test_branch_idx == -1` can never be true in
            # this else-branch, so the [-1] stride/padding entries are always
            # used here; kept verbatim from the upstream GMFlow/TridentNet code.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.strides[-1],
                    self.paddings[self.test_branch_idx]
                    if self.test_branch_idx == -1
                    else self.paddings[-1],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class ResidualBlock_class(nn.Module):
    """Two 3x3 convs with normalization and a (possibly projected) skip path."""

    def __init__(
        self,
        in_planes,
        planes,
        norm_layer=nn.InstanceNorm2d,
        stride=1,
        dilation=1,
    ):
        super(ResidualBlock_class, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            stride=stride,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)

        # When the block changes resolution or width, the identity path needs
        # a 1x1 projection (plus its own norm) to match the residual branch.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.norm3 = norm_layer(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )
        else:
            self.downsample = None

    def forward(self, x):
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))

        identity = x if self.downsample is None else self.downsample(x)

        return self.relu(identity + out)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class CNNEncoder(nn.Module):
    """GMFlow ResNet-style feature encoder.

    Emits `num_output_scales` feature maps of `output_dim` channels; with a
    single scale the output is at 1/8 resolution, otherwise the stack stops
    at 1/4 and a MultiScaleTridentConv produces the remaining scales.
    """

    def __init__(
        self,
        output_dim=128,
        norm_layer=nn.InstanceNorm2d,
        num_output_scales=1,
        **kwargs,
    ):
        super(CNNEncoder, self).__init__()
        self.num_branch = num_output_scales

        feature_dims = [64, 96, 128]

        self.conv1 = nn.Conv2d(
            3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False
        )  # 1/2
        self.norm1 = norm_layer(feature_dims[0])
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = feature_dims[0]
        self.layer1 = self._make_layer(
            feature_dims[0], stride=1, norm_layer=norm_layer
        )  # 1/2
        self.layer2 = self._make_layer(
            feature_dims[1], stride=2, norm_layer=norm_layer
        )  # 1/4

        # highest resolution 1/4 or 1/8
        stride = 2 if num_output_scales == 1 else 1
        self.layer3 = self._make_layer(
            feature_dims[2],
            stride=stride,
            norm_layer=norm_layer,
        )  # 1/4 or 1/8

        self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)

        if self.num_branch > 1:
            if self.num_branch == 4:
                strides = (1, 2, 4, 8)
            elif self.num_branch == 3:
                strides = (1, 2, 4)
            elif self.num_branch == 2:
                strides = (1, 2)
            else:
                raise ValueError

            # One shared-weight conv downsamples the final feature map to
            # every requested scale in a single call.
            self.trident_conv = MultiScaleTridentConv(
                output_dim,
                output_dim,
                kernel_size=3,
                strides=strides,
                paddings=1,
                num_branch=self.num_branch,
            )

        # Standard Kaiming init for convs, constant init for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        # Two residual blocks; only the first one changes stride/width.
        layer1 = ResidualBlock_class(
            self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation
        )
        layer2 = ResidualBlock_class(
            dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation
        )

        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)  # 1/2
        x = self.layer2(x)  # 1/4
        x = self.layer3(x)  # 1/8 or 1/4

        x = self.conv2(x)

        if self.num_branch > 1:
            out = self.trident_conv([x] * self.num_branch)  # high to low res
        else:
            out = [x]

        return out
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def single_head_full_attention(q, k, v):
    """Plain scaled dot-product attention with a single head.

    q, k, v: [B, L, C]. Returns the attended values, shape [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3

    channels = q.size(2)
    # [B, L, L] pairwise similarity, scaled by sqrt(C).
    logits = torch.matmul(q, k.transpose(1, 2)) / (channels ** 0.5)
    weights = torch.softmax(logits, dim=2)

    return torch.matmul(weights, v)  # [B, L, C]
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=get_torch_device(),
):
    """Build the additive attention mask for shifted-window (SW-MSA) attention.

    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    Positions belonging to different shifted regions receive -100 so that
    softmax effectively blocks attention between them; same-region pairs get 0.
    """
    h, w = input_resolution
    # Label every pixel with the index of the shifted-window region it lies in.
    region_map = torch.zeros((1, h, w, 1)).to(device)  # 1 H W 1
    row_bands = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    col_bands = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )
    region_id = 0
    for rows in row_bands:
        for cols in col_bands:
            region_map[:, rows, cols, :] = region_id
            region_id += 1

    # Cut the labelled map into windows, then compare labels pairwise within
    # each window: differing labels mark forbidden attention pairs.
    windows = split_feature(
        region_map, num_splits=input_resolution[-1] // window_size_w, channel_last=True
    )
    windows = windows.view(-1, window_size_h * window_size_w)

    attn_mask = windows.unsqueeze(1) - windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Swin-style windowed attention with a single head.

    q, k, v: [B, L, C] with L == h * w. The map is split into
    num_splits x num_splits windows and attention runs inside each window.
    With `with_shift`, windows are cyclically shifted by half their size
    (SW-MSA) and `attn_mask` blocks cross-region pairs.
    Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
    """
    assert q.dim() == k.dim() == v.dim() == 3

    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()

    # Effective batch after splitting into windows.
    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = c**0.5

    if with_shift:
        assert attn_mask is not None  # compute once
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2

        # Cyclic shift so new windows straddle the former window borders.
        q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))

    q = split_feature(
        q, num_splits=num_splits, channel_last=True
    )  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    # Scaled dot-product attention within each window.
    scores = (
        torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1))
        / scale_factor
    )  # [B*K*K, H/K*W/K, H/K*W/K]

    if with_shift:
        # Forbid attention between positions from different shifted regions.
        scores += attn_mask.repeat(b, 1, 1)

    attn = torch.softmax(scores, dim=-1)

    out = torch.matmul(attn, v.view(b_new, -1, c))  # [B*K*K, H/K*W/K, C]

    # Reassemble windows back into the full feature map.
    out = merge_splits(
        out.view(b_new, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
class TransformerLayer(nn.Module):
    """Single-head attention layer with an optional FFN.

    Acts as self-attention when called with source == target and as
    cross-attention otherwise. With no_ffn=True the MLP/norm2 stage is
    skipped (used for the self-attention half of TransformerBlock).
    """

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        no_ffn=False,
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            # The FFN sees [source, message] concatenated, hence 2 * d_model in.
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        # source, target: [B, L, C]
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            else:
                message = single_head_split_window_attention(
                    query,
                    key,
                    value,
                    num_splits=attn_num_splits,
                    with_shift=self.with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask,
                )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        # Residual connection around the whole layer.
        return source + message
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
class TransformerBlock(nn.Module):
    """self attention + cross attention + FFN"""

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        **kwargs,
    ):
        super(TransformerBlock, self).__init__()

        # Both halves share the same configuration; only the self-attention
        # half skips the FFN.
        layer_cfg = dict(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )
        self.self_attn = TransformerLayer(no_ffn=True, **layer_cfg)
        self.cross_attn_ffn = TransformerLayer(**layer_cfg)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        # source, target: [B, L, C]
        attn_kwargs = dict(
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )

        # Self-attention first, then cross-attention followed by the FFN.
        updated = self.self_attn(source, source, **attn_kwargs)
        return self.cross_attn_ffn(updated, target, **attn_kwargs)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class FeatureTransformer(nn.Module):
    """Stack of TransformerBlocks that cross-updates a pair of feature maps.

    Alternating layers use shifted windows (Swin-style) when
    attention_type == "swin". Both directions (0->1 and 1->0) are computed
    in parallel by concatenating along the batch dimension.
    """

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        **kwargs,
    ):
        super(FeatureTransformer, self).__init__()

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        # Odd-indexed layers use the shifted-window variant (Swin pattern).
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    with_shift=True
                    if attention_type == "swin" and i % 2 == 1
                    else False,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_num_splits=None,
        **kwargs,
    ):
        # feature0, feature1: [B, C, H, W]; returns the same shapes.
        b, c, h, w = feature0.shape
        assert self.d_model == c

        feature0 = feature0.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]
        feature1 = feature1.flatten(-2).permute(0, 2, 1)  # [B, H*W, C]

        if self.attention_type == "swin" and attn_num_splits > 1:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # concat feature0 and feature1 in batch dimension to compute in parallel
        concat0 = torch.cat((feature0, feature1), dim=0)  # [2B, H*W, C]
        concat1 = torch.cat((feature1, feature0), dim=0)  # [2B, H*W, C]

        for layer in self.layers:
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                shifted_window_attn_mask=shifted_window_attn_mask,
                attn_num_splits=attn_num_splits,
            )

            # update feature1: swap the two halves of the updated batch so the
            # cross-attention target stays in sync for the next layer.
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)  # [B, H*W, C]

        # reshape back
        feature0 = (
            feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        )  # [B, C, H, W]
        feature1 = (
            feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        )  # [B, C, H, W]

        return feature0, feature1
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
class FeatureFlowAttention(nn.Module):
    """
    flow propagation with self-attention on feature
    query: feature0, key: feature0, value: flow
    """

    def __init__(
        self,
        in_channels,
        **kwargs,
    ):
        super(FeatureFlowAttention, self).__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        flow,
        local_window_attn=False,
        local_window_radius=1,
        **kwargs,
    ):
        # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
        if local_window_attn:
            return self.forward_local_window_attn(
                feature0, flow, local_window_radius=local_window_radius
            )

        b, c, h, w = feature0.size()

        # Flatten spatially so attention runs over all H*W positions.
        query = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # a note: the ``correct'' implementation should be:
        # ``query = self.q_proj(query), key = self.k_proj(query)''
        # this problem is observed while cleaning up the code
        # however, this doesn't affect the performance since the projection is a linear operation,
        # thus the two projection matrices for key can be merged
        # so I just leave it as is in order to not re-train all models :)
        query = self.q_proj(query)  # [B, H*W, C]
        key = self.k_proj(query)  # [B, H*W, C]

        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5)  # [B, H*W, H*W]
        prob = torch.softmax(scores, dim=-1)

        # Each position's flow becomes the attention-weighted average of all flows.
        out = torch.matmul(prob, value)  # [B, H*W, 2]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

        return out

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        # Same propagation, but each pixel only attends to its
        # (2R+1)x(2R+1) neighborhood (gathered with F.unfold).
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        feature0_reshape = self.q_proj(
            feature0.view(b, c, -1).permute(0, 2, 1)
        ).reshape(
            b * h * w, 1, c
        )  # [B*H*W, 1, C]

        kernel_size = 2 * local_window_radius + 1

        feature0_proj = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        # Gather every pixel's neighborhood of projected keys.
        feature0_window = F.unfold(
            feature0_proj, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, C*(2R+1)^2), H*W]

        feature0_window = (
            feature0_window.view(b, c, kernel_size**2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size**2)
        )  # [B*H*W, C, (2R+1)^2]

        # Gather the matching neighborhood of flow values.
        flow_window = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )  # [B, 2*(2R+1)^2), H*W]

        flow_window = (
            flow_window.view(b, 2, kernel_size**2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size**2, 2)
        )  # [B*H*W, (2R+1)^2, 2]

        scores = torch.matmul(feature0_reshape, feature0_window) / (
            c**0.5
        )  # [B*H*W, 1, (2R+1)^2]

        prob = torch.softmax(scores, dim=-1)

        out = (
            torch.matmul(prob, flow_window)
            .view(b, h, w, 2)
            .permute(0, 3, 1, 2)
            .contiguous()
        )  # [B, 2, H, W]

        return out
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def global_correlation_softmax(
    feature0,
    feature1,
    pred_bidir_flow=False,
):
    """All-pairs correlation followed by a softmax-weighted flow estimate.

    feature0, feature1: [B, C, H, W].
    Returns (flow [B, 2, H, W], matching probability [B, H*W, H*W]).
    With pred_bidir_flow the batch dimension doubles: forward flow first,
    then backward flow.
    """
    # global correlation
    b, c, h, w = feature0.shape
    feat0_flat = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]
    feat1_flat = feature1.view(b, c, -1)  # [B, C, H*W]

    # [B, H*W, H*W] similarity between every pair of positions.
    correlation = torch.matmul(feat0_flat, feat1_flat) / (c**0.5)

    # Pixel-coordinate grid; the predicted target position is the
    # softmax-weighted average of all grid positions.
    init_grid = coords_grid(b, h, w).to(correlation.device)  # [B, 2, H, W]
    grid = init_grid.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    if pred_bidir_flow:
        # Backward correlation is the transpose of the forward one.
        correlation = torch.cat(
            (correlation, correlation.permute(0, 2, 1)), dim=0
        )  # [2*B, H*W, H*W]
        init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, 2, H, W]
        grid = grid.repeat(2, 1, 1)  # [2*B, H*W, 2]
        b = b * 2

    prob = F.softmax(correlation, dim=-1)  # [B, H*W, H*W]

    correspondence = (
        torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
    flow = correspondence - init_grid

    return flow, prob
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
def local_correlation_softmax(
    feature0,
    feature1,
    local_radius,
    padding_mode="zeros",
):
    """Correlation restricted to a (2R+1)x(2R+1) window around each pixel.

    feature0, feature1: [B, C, H, W]; local_radius: R.
    Returns (flow [B, 2, H, W], matching probability [B, H*W, (2R+1)^2]).
    """
    b, c, h, w = feature0.size()
    coords_init = coords_grid(b, h, w).to(feature0.device)  # [B, 2, H, W]
    coords = coords_init.view(b, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    local_h = 2 * local_radius + 1
    local_w = 2 * local_radius + 1

    # Relative (x, y) offsets covering the local window.
    window_grid = generate_window_grid(
        -local_radius,
        local_radius,
        -local_radius,
        local_radius,
        local_h,
        local_w,
        device=feature0.device,
    )  # [2R+1, 2R+1, 2]
    window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]
    sample_coords = coords.unsqueeze(-2) + window_grid  # [B, H*W, (2R+1)^2, 2]

    sample_coords_softmax = sample_coords

    # exclude coords that are out of image space
    valid_x = (sample_coords[:, :, :, 0] >= 0) & (
        sample_coords[:, :, :, 0] < w
    )  # [B, H*W, (2R+1)^2]
    valid_y = (sample_coords[:, :, :, 1] >= 0) & (
        sample_coords[:, :, :, 1] < h
    )  # [B, H*W, (2R+1)^2]

    valid = (
        valid_x & valid_y
    )  # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax

    # normalize coordinates to [-1, 1]
    sample_coords_norm = normalize_coords(sample_coords, h, w)  # [-1, 1]
    window_feature = F.grid_sample(
        feature1, sample_coords_norm, padding_mode=padding_mode, align_corners=True
    ).permute(
        0, 2, 1, 3
    )  # [B, H*W, C, (2R+1)^2]
    feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c)  # [B, H*W, 1, C]

    corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (
        c**0.5
    )  # [B, H*W, (2R+1)^2]

    # mask invalid locations
    corr[~valid] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    # Expected target coordinate under the local softmax distribution.
    correspondence = (
        torch.matmul(prob.unsqueeze(-2), sample_coords_softmax)
        .squeeze(-2)
        .view(b, h, w, 2)
        .permute(0, 3, 1, 2)
    )  # [B, 2, H, W]

    flow = correspondence - coords_init
    match_prob = prob

    return flow, match_prob
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
def coords_grid(b, h, w, homogeneous=False, device=None):
    """Pixel-coordinate grid, shape [B, 2, H, W] ([B, 3, H, W] if homogeneous).

    Channel 0 holds x (column) coordinates, channel 1 holds y (row)
    coordinates; the optional third channel is all ones.
    """
    xs = torch.arange(w).view(1, w).expand(h, w)  # [H, W] column indices
    ys = torch.arange(h).view(h, 1).expand(h, w)  # [H, W] row indices

    planes = [xs, ys]
    if homogeneous:
        planes.append(torch.ones_like(xs))  # [H, W]

    grid = torch.stack(planes, dim=0).float()  # [2, H, W] or [3, H, W]
    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    if device is not None:
        grid = grid.to(device)

    return grid
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Return a [len_h, len_w, 2] grid of (x, y) coordinates.

    x ranges linearly from w_min to w_max along the width axis and y from
    h_min to h_max along the height axis.
    """
    assert device is not None

    xs = torch.linspace(w_min, w_max, len_w, device=device)  # [len_w]
    ys = torch.linspace(h_min, h_max, len_h, device=device)  # [len_h]
    grid = torch.stack(
        (
            xs.view(1, len_w).expand(len_h, len_w),
            ys.view(len_h, 1).expand(len_h, len_w),
        ),
        dim=-1,
    ).float()  # [H, W, 2]

    return grid
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
def normalize_coords(coords, h, w):
    """Normalize absolute pixel coordinates into [-1, 1].

    coords: [B, H, W, 2] with (x, y) in pixel units for an image of size h x w.
    """
    half_extent = torch.Tensor([(w - 1) / 2.0, (h - 1) / 2.0]).float().to(coords.device)
    return (coords - half_extent) / half_extent  # [-1, 1]
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
def bilinear_sample(
    img, sample_coords, mode="bilinear", padding_mode="zeros", return_mask=False
):
    """Sample ``img`` at absolute pixel coordinates via ``F.grid_sample``.

    img: [B, C, H, W]; sample_coords: [B, 2, H, W] (or [B, H, W, 2]) in image
    scale. When ``return_mask`` is set, also returns a [B, H, W] bool mask of
    the locations that landed inside the image bounds.
    """
    if sample_coords.size(1) != 2:  # given channel-last as [B, H, W, 2]
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    _, _, h, w = sample_coords.shape

    # rescale absolute coordinates into grid_sample's [-1, 1] convention
    x_norm = 2 * sample_coords[:, 0] / (w - 1) - 1
    y_norm = 2 * sample_coords[:, 1] / (h - 1) - 1

    grid = torch.stack([x_norm, y_norm], dim=-1)  # [B, H, W, 2]

    sampled = F.grid_sample(
        img, grid, mode=mode, padding_mode=padding_mode, align_corners=True
    )

    if not return_mask:
        return sampled

    in_bounds = (
        (x_norm >= -1) & (y_norm >= -1) & (x_norm <= 1) & (y_norm <= 1)
    )  # [B, H, W]
    return sampled, in_bounds
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
def flow_warp(feature, flow, mask=False, padding_mode="zeros"):
    """Backward-warp ``feature`` [B, C, H, W] by an optical flow [B, 2, H, W]."""
    batch, _, height, width = feature.size()
    assert flow.size(1) == 2

    # absolute sampling positions = base pixel grid + flow displacement
    sample_points = coords_grid(batch, height, width).to(flow.device) + flow  # [B, 2, H, W]

    return bilinear_sample(
        feature, sample_points, padding_mode=padding_mode, return_mask=mask
    )
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):
    """Estimate occlusion masks via forward-backward flow consistency.

    fwd_flow, bwd_flow: [B, 2, H, W]. Returns (fwd_occ, bwd_occ) as float
    masks of shape [B, H, W], 1 where the flows disagree beyond a magnitude-
    dependent threshold. alpha/beta follow UnFlow (https://arxiv.org/abs/1711.07837).
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]
    threshold = alpha * flow_mag + beta

    warped_bwd = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]
    warped_fwd = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]

    # a consistent pair should cancel out after warping
    diff_fwd = torch.norm(fwd_flow + warped_bwd, dim=1)  # [B, H, W]
    diff_bwd = torch.norm(bwd_flow + warped_fwd, dim=1)

    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]
    bwd_occ = (diff_bwd > threshold).float()

    return fwd_occ, bwd_occ
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    Produces a fixed (non-learned) sinusoidal encoding of shape
    [B, 2*num_pos_feats, H, W]: the first half encodes y, the second half x.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats  # channels per axis (y and x each)
        self.temperature = temperature  # frequency base, as in Transformer PE
        self.normalize = normalize  # rescale positions to [0, scale]
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        # x = tensor_list.tensors # [B, C, H, W]
        # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
        b, c, h, w = x.size()
        # no padding here, so every location is valid
        mask = torch.ones((b, h, w), device=x.device)  # [B, H, W]
        # cumulative sums give 1-based row/column indices
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # divide by the last (= largest) index so positions span (0, scale]
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # per-channel frequencies: temperature^(2*floor(i/2)/num_pos_feats)
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # interleave sin on even channels and cos on odd channels
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        # [B, 2*num_pos_feats, H, W], y-encoding first then x-encoding
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Partition a feature map into a K x K grid of windows folded into batch.

    [B, C, H, W] -> [B*K*K, C, H/K, W/K] (or the channel-last equivalent).
    H and W must be divisible by ``num_splits``.
    """
    k = num_splits

    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
        assert h % k == 0 and w % k == 0
        win_h, win_w = h // k, w // k

        windows = (
            feature.view(b, k, win_h, k, win_w, c)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(b * k * k, win_h, win_w, c)
        )  # [B*K*K, H/K, W/K, C]
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()
        assert h % k == 0 and w % k == 0
        win_h, win_w = h // k, w // k

        windows = (
            feature.view(b, c, k, win_h, k, win_w)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(b * k * k, c, win_h, win_w)
        )  # [B*K*K, C, H/K, W/K]

    return windows
|
| 1092 |
+
|
| 1093 |
+
|
| 1094 |
+
def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Inverse of ``split_feature``: stitch K x K windows back into one map.

    [B*K*K, C, H/K, W/K] -> [B, C, H, W] (or the channel-last equivalent).
    """
    k = num_splits

    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        orig_b = b // k // k

        stacked = splits.view(orig_b, k, k, h, w, c)
        merged = (
            stacked.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(orig_b, k * h, k * w, c)
        )  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        orig_b = b // k // k

        stacked = splits.view(orig_b, k, k, c, h, w)
        merged = (
            stacked.permute(0, 3, 1, 4, 2, 5)
            .contiguous()
            .view(orig_b, c, k * h, k * w)
        )  # [B, C, H, W]

    return merged
|
| 1121 |
+
|
| 1122 |
+
|
| 1123 |
+
def normalize_img(img0, img1):
    """Standardize both images with ImageNet channel statistics.

    NOTE(review): the original comment claims inputs are in [0, 255], but the
    mean/std constants are on the [0, 1] scale — confirm against callers.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)

    return (img0 - imagenet_mean) / imagenet_std, (img1 - imagenet_mean) / imagenet_std
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add the same sinusoidal position encoding to both feature maps.

    When ``attn_splits > 1`` the encoding is computed per attention window so
    that it matches the windowed (swin-style) attention layout, then the
    windows are stitched back together.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    # window the features, add the per-window encoding, then merge back
    f0_windows = split_feature(feature0, num_splits=attn_splits)
    f1_windows = split_feature(feature1, num_splits=attn_splits)

    position = pos_enc(f0_windows)

    f0_windows = f0_windows + position
    f1_windows = f1_windows + position

    return (
        merge_splits(f0_windows, num_splits=attn_splits),
        merge_splits(f1_windows, num_splits=attn_splits),
    )
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
class GMFlow(nn.Module):
    """GMFlow optical-flow estimator (global matching + transformer).

    Coarse-to-fine over ``num_scales`` pyramid levels: CNN backbone features ->
    transformer feature enhancement -> correlation-softmax matching (global at
    the coarsest level, local refinement at finer ones) -> self-attention flow
    propagation -> convex upsampling to full resolution.
    """

    def __init__(
        self,
        num_scales=2,
        upsample_factor=4,
        feature_channels=128,
        attention_type="swin",
        num_transformer_layers=6,
        ffn_dim_expansion=4,
        num_head=1,
        **kwargs,
    ):
        super(GMFlow, self).__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone producing one feature level per scale
        self.backbone = CNNEncoder(
            output_dim=feature_channels, num_output_scales=num_scales
        )

        # Transformer for cross-/self-attention feature enhancement
        self.transformer = FeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        # flow propagation with self-attn
        self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)

        # convex upsampling: concat feature0 and flow as input; predicts 9
        # softmax weights per upsampled pixel (RAFT-style)
        self.upsampler = nn.Sequential(
            nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
        )

    def extract_feature(self, img0, img1):
        """Run both images through the backbone in a single batch.

        Returns two per-scale feature lists ordered low -> high resolution.
        """
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(
            concat
        )  # list of [2B, C, H, W], resolution from high to low

        # reverse: resolution from low to high
        features = features[::-1]

        feature0, feature1 = [], []

        for i in range(len(features)):
            feature = features[i]
            # split the batched result back into the img0 / img1 halves
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(
        self,
        flow,
        feature,
        bilinear=False,
        upsample_factor=8,
    ):
        """Upsample ``flow`` bilinearly or by learned convex combination.

        ``feature`` is only used (and required) by the convex branch.
        Flow values are multiplied by the factor since they are in pixels.
        """
        if bilinear:
            up_flow = (
                F.interpolate(
                    flow,
                    scale_factor=upsample_factor,
                    mode="bilinear",
                    align_corners=True,
                )
                * upsample_factor
            )

        else:
            # convex upsampling: each fine pixel is a softmax-weighted mix of
            # a 3x3 neighborhood of coarse flow vectors (as in RAFT)
            concat = torch.cat((flow, feature), dim=1)

            mask = self.upsampler(concat)
            b, flow_channel, h, w = flow.shape
            mask = mask.view(
                b, 1, 9, self.upsample_factor, self.upsample_factor, h, w
            )  # [B, 1, 9, K, K, H, W]
            mask = torch.softmax(mask, dim=2)

            up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1)
            up_flow = up_flow.view(
                b, flow_channel, 9, 1, 1, h, w
            )  # [B, 2, 9, 1, 1, H, W]

            up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
            up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
            up_flow = up_flow.reshape(
                b, flow_channel, self.upsample_factor * h, self.upsample_factor * w
            )  # [B, 2, K*H, K*W]

        return up_flow

    def forward(
        self,
        img0,
        img1,
        attn_splits_list=[2, 8],
        corr_radius_list=[-1, 4],
        prop_radius_list=[-1, 1],
        pred_bidir_flow=False,
        **kwargs,
    ):
        """Estimate flow from img0 to img1 (and the reverse if bidirectional).

        The three per-scale lists configure window splits, correlation radius
        (-1 = global matching) and propagation radius (-1 = global) and must
        each have ``num_scales`` entries. Returns the full-resolution flow.
        """
        img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # resolution low to high
        feature0_list, feature1_list = self.extract_feature(
            img0, img1
        )  # list of features

        flow = None

        assert (
            len(attn_splits_list)
            == len(corr_radius_list)
            == len(prop_radius_list)
            == self.num_scales
        )

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement: double the
                # batch with the swapped (img1, img0) direction
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                    (feature1, feature0), dim=0
                )

            upsample_factor = self.upsample_factor * (
                2 ** (self.num_scales - 1 - scale_idx)
            )

            if scale_idx > 0:
                # carry the coarser estimate up one pyramid level (x2 in both
                # resolution and magnitude)
                flow = (
                    F.interpolate(
                        flow, scale_factor=2, mode="bilinear", align_corners=True
                    )
                    * 2
                )

            if flow is not None:
                flow = flow.detach()
                # pre-warp feature1 so this level only predicts residual flow
                feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

            attn_splits = attn_splits_list[scale_idx]
            corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(
                feature0, feature1, attn_splits, self.feature_channels
            )

            # Transformer
            feature0, feature1 = self.transformer(
                feature0, feature1, attn_num_splits=attn_splits
            )

            # correlation and softmax
            if corr_radius == -1:  # global matching
                flow_pred = global_correlation_softmax(
                    feature0, feature1, pred_bidir_flow
                )[0]
            else:  # local matching
                flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[
                    0
                ]

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            # upsample to the original resolution for supervison
            # NOTE(review): flow_bilinear is computed but never collected or
            # returned — presumably a leftover from the training code path.
            if (
                self.training
            ):  # only need to upsample intermediate flow predictions at training time
                flow_bilinear = self.upsample_flow(
                    flow, None, bilinear=True, upsample_factor=upsample_factor
                )

            # flow propagation with self-attn
            if pred_bidir_flow and scale_idx == 0:
                feature0 = torch.cat(
                    (feature0, feature1), dim=0
                )  # [2*B, C, H, W] for propagation
            flow = self.feature_flow_attn(
                feature0,
                flow.detach(),
                local_window_attn=prop_radius > 0,
                local_window_radius=prop_radius,
            )

            # bilinear upsampling at training time except the last one
            if self.training and scale_idx < self.num_scales - 1:
                flow_up = self.upsample_flow(
                    flow, feature0, bilinear=True, upsample_factor=upsample_factor
                )

            if scale_idx == self.num_scales - 1:
                # final level: learned convex upsampling to full resolution
                flow_up = self.upsample_flow(flow, feature0)

        return flow_up
|
| 1370 |
+
|
| 1371 |
+
|
| 1372 |
+
# cache of normalized base grids, keyed by str(flow.shape)
backwarp_tenGrid = {}


def backwarp(tenIn, tenflow):
    """Backward-warp ``tenIn`` by ``tenflow`` (both [B, C/2, H, W]) via grid_sample.

    A base identity grid in [-1, 1] is built once per flow shape and cached in
    the module-level ``backwarp_tenGrid`` dict.
    NOTE(review): the cache key is only the shape string — dtype and device are
    not part of the key, and the grid is moved to ``get_torch_device()`` rather
    than ``tenflow.device``; fine for single-device use, verify otherwise.
    """
    if str(tenflow.shape) not in backwarp_tenGrid:
        # x coordinates, -1..1 across the width, repeated for every row
        tenHor = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[3],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, 1, -1)
            .repeat(1, 1, tenflow.shape[2], 1)
        )
        # y coordinates, -1..1 down the height, repeated for every column
        tenVer = (
            torch.linspace(
                start=-1.0,
                end=1.0,
                steps=tenflow.shape[2],
                dtype=tenflow.dtype,
                device=tenflow.device,
            )
            .view(1, 1, -1, 1)
            .repeat(1, 1, 1, tenflow.shape[3])
        )

        backwarp_tenGrid[str(tenflow.shape)] = torch.cat([tenHor, tenVer], 1).to(get_torch_device())
    # end

    # rescale the pixel-unit flow into grid_sample's [-1, 1] convention
    tenflow = torch.cat(
        [
            tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0),
            tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    return torch.nn.functional.grid_sample(
        input=tenIn,
        grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1),
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )
|
| 1418 |
+
|
| 1419 |
+
|
| 1420 |
+
class MetricNet(nn.Module):
    """Predicts per-pixel splatting metrics (importance weights) for softsplat.

    Input is a 14-channel stack: two RGB images (6), two negated photometric
    errors (2), two normalized flows (4) and two occlusion masks (2).
    """

    def __init__(self):
        super(MetricNet, self).__init__()
        self.metric_in = nn.Conv2d(14, 64, 3, 1, 1)
        # three residual refinement stages (summed with identity in forward)
        self.metric_net1 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net2 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_net3 = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 64, 3, 1, 1))
        self.metric_out = nn.Sequential(nn.PReLU(), nn.Conv2d(64, 2, 3, 1, 1))

    def forward(self, img0, img1, flow01, flow10):
        """Return (metric0, metric1), each [B, 1, H, W] in [-10, 10]."""
        # photometric consistency of each image against the warped other one
        metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction="none").mean(
            [1], True
        )
        metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction="none").mean(
            [1], True
        )

        fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10)

        # normalize flows from pixel units into [-1, 1] before feeding the net
        flow01 = torch.cat(
            [
                flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0),
                flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0),
            ],
            1,
        )
        flow10 = torch.cat(
            [
                flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0),
                flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0),
            ],
            1,
        )

        img = torch.cat((img0, img1), 1)
        # negated so that a larger photometric error lowers the metric
        metric = torch.cat((-metric0, -metric1), 1)
        flow = torch.cat((flow01, flow10), 1)
        occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1)

        feat = self.metric_in(torch.cat((img, metric, flow, occ), 1))
        feat = self.metric_net1(feat) + feat
        feat = self.metric_net2(feat) + feat
        feat = self.metric_net3(feat) + feat
        metric = self.metric_out(feat)

        # bound the metrics to [-10, 10] for numerically stable softsplat
        metric = torch.tanh(metric) * 10

        return metric[:, :1], metric[:, 1:2]
|
| 1468 |
+
|
| 1469 |
+
|
| 1470 |
+
class FeatureNet(nn.Module):
    """Three-level pyramid feature extractor ("the quadratic model").

    Yields features at 1/2, 1/4 and 1/8 resolution with 64/128/192 channels.
    """

    def __init__(self):
        super(FeatureNet, self).__init__()

        def make_stage(c_in, c_out):
            # each stage halves the resolution then refines at that scale
            return nn.Sequential(
                nn.PReLU(),
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.PReLU(),
                nn.Conv2d(c_out, c_out, 3, 1, 1),
            )

        self.block1 = make_stage(3, 64)
        self.block2 = make_stage(64, 128)
        self.block3 = make_stage(128, 192)

    def forward(self, x):
        """Return the (1/2, 1/4, 1/8) resolution feature maps."""
        feat_half = self.block1(x)
        feat_quarter = self.block2(feat_half)
        feat_eighth = self.block3(feat_quarter)
        return feat_half, feat_quarter, feat_eighth
|
| 1500 |
+
|
| 1501 |
+
|
| 1502 |
+
# Residual Block: PReLU->Conv->PReLU->Conv (callers add the skip themselves).
# NOTE(review): ``stride`` is applied to BOTH convs, so stride != 1 would
# downsample twice — all in-file callers use the default stride of 1.
def ResidualBlock(in_channels, out_channels, stride=1):
    layers = [
        nn.PReLU(),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
    ]
    return torch.nn.Sequential(*layers)
|
| 1524 |
+
|
| 1525 |
+
|
| 1526 |
+
# downsample block: strided conv halves the resolution, second conv refines
def DownsampleBlock(in_channels, out_channels, stride=2):
    layers = [
        nn.PReLU(),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
        ),
    ]
    return torch.nn.Sequential(*layers)
|
| 1543 |
+
|
| 1544 |
+
|
| 1545 |
+
# upsample block: transposed conv doubles the resolution, second conv refines
def UpsampleBlock(in_channels, out_channels, stride=2):
    layers = [
        nn.PReLU(),
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
        ),
        nn.PReLU(),
        nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
        ),
    ]
    return torch.nn.Sequential(*layers)
|
| 1562 |
+
|
| 1563 |
+
|
| 1564 |
+
class PixelShuffleBlcok(nn.Module):
    """2x upsampler: conv projection -> PixelShuffle(2) -> output conv.

    (Class name typo is kept for checkpoint / caller compatibility.)
    """

    def __init__(self, in_feat, num_feat, num_out_ch):
        super(PixelShuffleBlcok, self).__init__()
        self.conv_before_upsample = nn.Sequential(
            nn.Conv2d(in_feat, num_feat, 3, 1, 1), nn.PReLU()
        )
        # 4x channels then shuffle -> spatial 2x at the original channel count
        self.upsample = nn.Sequential(
            nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)
        )
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        """Project, upsample 2x via pixel shuffle, map to output channels."""
        projected = self.conv_before_upsample(x)
        upsampled = self.upsample(projected)
        return self.conv_last(upsampled)
|
| 1579 |
+
|
| 1580 |
+
|
| 1581 |
+
# grid network
class GridNet(nn.Module):
    """GridNet-style fusion network over a 3-row residual grid.

    Row 0/1/2 operate at 64/128/192 channels; information flows right along
    rows (residual blocks), down (downsample blocks) and up (upsample blocks).
    Heads inject the warped image stack ``x`` and pyramid features x1/x2/x3;
    the tail pixel-shuffles row 0 back up to the output image.
    """

    def __init__(
        self,
        in_channels=9,
        in_channels1=128,
        in_channels2=256,
        in_channels3=384,
        out_channels=3,
    ):
        super(GridNet, self).__init__()

        # input heads: project each input to its grid row's channel width
        self.residual_model_head0 = ResidualBlock(in_channels, 64)
        self.residual_model_head1 = ResidualBlock(in_channels1, 64)
        self.residual_model_head2 = ResidualBlock(in_channels2, 128)
        self.residual_model_head3 = ResidualBlock(in_channels3, 192)

        # row 0 (64 channels); columns 2/3 of the original grid are omitted
        self.residual_model_01 = ResidualBlock(64, 64)
        self.residual_model_04 = ResidualBlock(64, 64)
        self.residual_model_05 = ResidualBlock(64, 64)
        self.residual_model_tail = PixelShuffleBlcok(64, 64, out_channels)

        # row 1 (128 channels)
        self.residual_model_11 = ResidualBlock(128, 128)
        self.residual_model_14 = ResidualBlock(128, 128)
        self.residual_model_15 = ResidualBlock(128, 128)

        # row 2 (192 channels)
        self.residual_model_21 = ResidualBlock(192, 192)
        self.residual_model_24 = ResidualBlock(192, 192)
        self.residual_model_25 = ResidualBlock(192, 192)

        # downward (encoder-side) connections between rows
        self.downsample_model_10 = DownsampleBlock(64, 128)
        self.downsample_model_20 = DownsampleBlock(128, 192)

        self.downsample_model_11 = DownsampleBlock(64, 128)
        self.downsample_model_21 = DownsampleBlock(128, 192)

        # upward (decoder-side) connections between rows
        self.upsample_model_04 = UpsampleBlock(128, 64)
        self.upsample_model_14 = UpsampleBlock(192, 128)

        self.upsample_model_05 = UpsampleBlock(128, 64)
        self.upsample_model_15 = UpsampleBlock(192, 128)

    def forward(self, x, x1, x2, x3):
        """Fuse inputs through the grid; returns the [B, out_channels, 2H, 2W] image."""
        # column 0: inject inputs into each row
        X00 = self.residual_model_head0(x) + self.residual_model_head1(
            x1
        )  # row 0 gets both the image stack and the level-1 features

        X01 = self.residual_model_01(X00) + X00  # row 0, column 1 (residual step)

        X10 = self.downsample_model_10(X00) + self.residual_model_head2(
            x2
        )  # row 1 = downsampled row 0 + level-2 features
        X20 = self.downsample_model_20(X10) + self.residual_model_head3(
            x3
        )  # row 2 = downsampled row 1 + level-3 features

        # column 1 for rows 1 and 2: residual step plus downsample from above
        residual_11 = (
            self.residual_model_11(X10) + X10
        )
        downsample_11 = self.downsample_model_11(X01)
        X11 = residual_11 + downsample_11

        residual_21 = (
            self.residual_model_21(X20) + X20
        )
        downsample_21 = self.downsample_model_21(X11)
        X21 = residual_21 + downsample_21

        # row 2 decoder columns
        X24 = self.residual_model_24(X21) + X21
        X25 = self.residual_model_25(X24) + X24

        # decoder: each cell sums the upsampled cell below and a residual step
        upsample_14 = self.upsample_model_14(X24)
        residual_14 = self.residual_model_14(X11) + X11
        X14 = upsample_14 + residual_14

        upsample_04 = self.upsample_model_04(X14)
        residual_04 = self.residual_model_04(X01) + X01
        X04 = upsample_04 + residual_04

        upsample_15 = self.upsample_model_15(X25)
        residual_15 = self.residual_model_15(X14) + X14
        X15 = upsample_15 + residual_15

        upsample_05 = self.upsample_model_05(X15)
        residual_05 = self.residual_model_05(X04) + X04
        X05 = upsample_05 + residual_05

        # tail: pixel-shuffle row 0 up to the output resolution
        X_tail = self.residual_model_tail(X05)

        return X_tail
    # end
|
| 1689 |
+
|
| 1690 |
+
|
| 1691 |
+
class Model:
|
| 1692 |
+
def __init__(self):
|
| 1693 |
+
self.flownet = GMFlow()
|
| 1694 |
+
self.ifnet = IFNet(arch_ver="4.6")
|
| 1695 |
+
self.metricnet = MetricNet()
|
| 1696 |
+
self.feat_ext = FeatureNet()
|
| 1697 |
+
self.fusionnet = GridNet()
|
| 1698 |
+
self.version = 3.9
|
| 1699 |
+
|
| 1700 |
+
def eval(self):
|
| 1701 |
+
self.flownet.eval()
|
| 1702 |
+
self.ifnet.eval()
|
| 1703 |
+
self.metricnet.eval()
|
| 1704 |
+
self.feat_ext.eval()
|
| 1705 |
+
self.fusionnet.eval()
|
| 1706 |
+
|
| 1707 |
+
def device(self):
|
| 1708 |
+
self.flownet.to(device)
|
| 1709 |
+
self.ifnet.to(device)
|
| 1710 |
+
self.metricnet.to(device)
|
| 1711 |
+
self.feat_ext.to(device)
|
| 1712 |
+
self.fusionnet.to(device)
|
| 1713 |
+
|
| 1714 |
+
def load_model(self, path_dict):
|
| 1715 |
+
#models/rife46.pth
|
| 1716 |
+
self.ifnet.load_state_dict(torch.load(path_dict["ifnet"]))
|
| 1717 |
+
#models/GMFSS_fortuna_flownet.pkl
|
| 1718 |
+
self.flownet.load_state_dict(torch.load(path_dict["flownet"]))
|
| 1719 |
+
#models/GMFSS_fortuna_union_metric.pkl
|
| 1720 |
+
self.metricnet.load_state_dict(torch.load(path_dict["metricnet"]))
|
| 1721 |
+
#models/GMFSS_fortuna_union_feat.pkl
|
| 1722 |
+
self.feat_ext.load_state_dict(torch.load(path_dict["feat_ext"]))
|
| 1723 |
+
#models/GMFSS_fortuna_union_fusionnet.pkl
|
| 1724 |
+
self.fusionnet.load_state_dict(torch.load(path_dict["fusionnet"]))
|
| 1725 |
+
|
| 1726 |
+
def reuse(self, img0, img1, scale):
|
| 1727 |
+
feat11, feat12, feat13 = self.feat_ext(img0)
|
| 1728 |
+
feat21, feat22, feat23 = self.feat_ext(img1)
|
| 1729 |
+
|
| 1730 |
+
img0 = F.interpolate(
|
| 1731 |
+
img0, scale_factor=0.5, mode="bilinear", align_corners=False
|
| 1732 |
+
)
|
| 1733 |
+
img1 = F.interpolate(
|
| 1734 |
+
img1, scale_factor=0.5, mode="bilinear", align_corners=False
|
| 1735 |
+
)
|
| 1736 |
+
|
| 1737 |
+
if scale != 1.0:
|
| 1738 |
+
imgf0 = F.interpolate(
|
| 1739 |
+
img0, scale_factor=scale, mode="bilinear", align_corners=False
|
| 1740 |
+
)
|
| 1741 |
+
imgf1 = F.interpolate(
|
| 1742 |
+
img1, scale_factor=scale, mode="bilinear", align_corners=False
|
| 1743 |
+
)
|
| 1744 |
+
else:
|
| 1745 |
+
imgf0 = img0
|
| 1746 |
+
imgf1 = img1
|
| 1747 |
+
flow01 = self.flownet(imgf0, imgf1, return_flow=True)
|
| 1748 |
+
flow10 = self.flownet(imgf1, imgf0, return_flow=True)
|
| 1749 |
+
if scale != 1.0:
|
| 1750 |
+
flow01 = (
|
| 1751 |
+
F.interpolate(
|
| 1752 |
+
flow01,
|
| 1753 |
+
scale_factor=1.0 / scale,
|
| 1754 |
+
mode="bilinear",
|
| 1755 |
+
align_corners=False,
|
| 1756 |
+
)
|
| 1757 |
+
/ scale
|
| 1758 |
+
)
|
| 1759 |
+
flow10 = (
|
| 1760 |
+
F.interpolate(
|
| 1761 |
+
flow10,
|
| 1762 |
+
scale_factor=1.0 / scale,
|
| 1763 |
+
mode="bilinear",
|
| 1764 |
+
align_corners=False,
|
| 1765 |
+
)
|
| 1766 |
+
/ scale
|
| 1767 |
+
)
|
| 1768 |
+
|
| 1769 |
+
metric0, metric1 = self.metricnet(img0, img1, flow01, flow10)
|
| 1770 |
+
|
| 1771 |
+
return (
|
| 1772 |
+
flow01,
|
| 1773 |
+
flow10,
|
| 1774 |
+
metric0,
|
| 1775 |
+
metric1,
|
| 1776 |
+
feat11,
|
| 1777 |
+
feat12,
|
| 1778 |
+
feat13,
|
| 1779 |
+
feat21,
|
| 1780 |
+
feat22,
|
| 1781 |
+
feat23,
|
| 1782 |
+
)
|
| 1783 |
+
|
| 1784 |
+
def inference(
    self,
    img0,
    img1,
    flow01,
    flow10,
    metric0,
    metric1,
    feat11,
    feat12,
    feat13,
    feat21,
    feat22,
    feat23,
    timestep,
):
    """Synthesize the intermediate frame at `timestep` from `reuse` outputs.

    Args:
        img0, img1: full-resolution input frames (downsampled x0.5 below to
            match the flow/metric resolution).
        flow01, flow10: bidirectional flows computed by `reuse`.
        metric0, metric1: softmax-splatting importance metrics from `reuse`.
        feat11..feat13, feat21..feat23: three-level feature pyramids of
            frame 0 and frame 1 respectively, from `reuse`.
        timestep: interpolation position, expected in (0, 1).

    Returns:
        The fused interpolated frame, clamped to [0, 1].
    """
    # Time-scale the flows: move frame0 forward by t, frame1 back by (1 - t).
    F1t = timestep * flow01
    F2t = (1 - timestep) * flow10

    # Time-scale the splatting metrics the same way.
    Z1t = timestep * metric0
    Z2t = (1 - timestep) * metric1

    # Forward-splat both half-resolution frames to time t ("soft" = softmax
    # splatting weighted by the metric).
    img0 = F.interpolate(
        img0, scale_factor=0.5, mode="bilinear", align_corners=False
    )
    I1t = softsplat(img0, F1t, Z1t, strMode="soft")
    img1 = F.interpolate(
        img1, scale_factor=0.5, mode="bilinear", align_corners=False
    )
    I2t = softsplat(img1, F2t, Z2t, strMode="soft")

    # Auxiliary RIFE (IFNet) estimate, fed to the fusion net alongside the
    # splatted frames.
    rife = self.ifnet(img0, img1, timestep, scale_list=[8, 4, 2, 1])

    # Level-1 features: splat with the full-size flow/metric.
    feat1t1 = softsplat(feat11, F1t, Z1t, strMode="soft")
    feat2t1 = softsplat(feat21, F2t, Z2t, strMode="soft")

    # Level-2 features: downscale flow x0.5 (and multiply by 0.5, since flow
    # vectors scale with resolution) and metric x0.5 before splatting.
    F1td = (
        F.interpolate(F1t, scale_factor=0.5, mode="bilinear", align_corners=False)
        * 0.5
    )
    Z1d = F.interpolate(Z1t, scale_factor=0.5, mode="bilinear", align_corners=False)
    feat1t2 = softsplat(feat12, F1td, Z1d, strMode="soft")
    F2td = (
        F.interpolate(F2t, scale_factor=0.5, mode="bilinear", align_corners=False)
        * 0.5
    )
    Z2d = F.interpolate(Z2t, scale_factor=0.5, mode="bilinear", align_corners=False)
    feat2t2 = softsplat(feat22, F2td, Z2d, strMode="soft")

    # Level-3 features: same scheme at quarter scale.
    F1tdd = (
        F.interpolate(F1t, scale_factor=0.25, mode="bilinear", align_corners=False)
        * 0.25
    )
    Z1dd = F.interpolate(
        Z1t, scale_factor=0.25, mode="bilinear", align_corners=False
    )
    feat1t3 = softsplat(feat13, F1tdd, Z1dd, strMode="soft")
    F2tdd = (
        F.interpolate(F2t, scale_factor=0.25, mode="bilinear", align_corners=False)
        * 0.25
    )
    Z2dd = F.interpolate(
        Z2t, scale_factor=0.25, mode="bilinear", align_corners=False
    )
    feat2t3 = softsplat(feat23, F2tdd, Z2dd, strMode="soft")

    # Fuse the two splatted frames, the RIFE estimate and all three feature
    # levels into the final frame.
    out = self.fusionnet(
        torch.cat([I1t, rife, I2t], dim=1),
        torch.cat([feat1t1, feat2t1], dim=1),
        torch.cat([feat1t2, feat2t2], dim=1),
        torch.cat([feat1t3, feat2t3], dim=1),
    )

    return torch.clamp(out, 0, 1)
|
vfi_models/gmfss_fortuna/__init__.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pathlib
|
| 2 |
+
from vfi_utils import load_file_from_github_release, preprocess_frames, postprocess_frames, generic_frame_loop, InterpolationStateList
|
| 3 |
+
import typing
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from comfy.model_management import get_torch_device
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Model family name, taken from this package's directory name
# ("gmfss_fortuna").
GLOBAL_MODEL_TYPE = pathlib.Path(__file__).parent.name
# Checkpoint manifest per model variant: sub-network name -> argument pair
# splatted into `load_file_from_github_release`. Both variants share the
# same flownet weights; the "union" variant additionally downloads a RIFE
# IFNet checkpoint and uses union-specific metric/feature/fusion weights.
CKPTS_PATH_CONFIG = {
    "GMFSS_fortuna_union": {
        "ifnet": ("rife", "rife46.pth"),
        "flownet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_flownet.pkl"),
        "metricnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_metric.pkl"),
        "feat_ext": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_feat.pkl"),
        "fusionnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_union_fusionnet.pkl")
    },
    "GMFSS_fortuna": {
        "flownet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_flownet.pkl"),
        "metricnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_metric.pkl"),
        "feat_ext": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_feat.pkl"),
        "fusionnet": (GLOBAL_MODEL_TYPE, "GMFSS_fortuna_fusionnet.pkl")
    }
}
|
| 26 |
+
|
| 27 |
+
class CommonModelInference(nn.Module):
    """Wraps a GMFSS model (plain or "union") behind one padded forward call.

    Downloads/loads the variant's checkpoints at construction time and hides
    the reuse/inference two-step API of the underlying model.
    """

    def __init__(self, model_type):
        super(CommonModelInference, self).__init__()
        from .GMFSS_Fortuna_arch import Model as GMFSS
        from .GMFSS_Fortuna_union_arch import Model as GMFSS_Union

        # "union" variants carry an extra RIFE branch; pick the matching arch.
        arch_cls = GMFSS_Union if "union" in model_type else GMFSS
        self.model = arch_cls()
        self.model.eval()
        self.model.device()

        # Resolve every sub-network checkpoint to a local file path.
        ckpt_spec = CKPTS_PATH_CONFIG[model_type]
        resolved = {}
        for net_name, spec in ckpt_spec.items():
            resolved[net_name] = load_file_from_github_release(*spec)
        self.model.load_model(resolved)

    def forward(self, I0, I1, timestep, scale=1.0):
        """Interpolate between frames I0 and I1 at `timestep`.

        Pads both frames so their sides are multiples of max(64, 64/scale),
        runs the model, and crops the result back to the input size.
        """
        h, w = I0.shape[2], I0.shape[3]
        align = max(64, int(64 / scale))
        pad_bottom = ((h - 1) // align + 1) * align - h
        pad_right = ((w - 1) // align + 1) * align - w
        I0 = F.pad(I0, (0, pad_right, 0, pad_bottom))
        I1 = F.pad(I1, (0, pad_right, 0, pad_bottom))

        # Flows/metrics/features are timestep-independent; compute them once.
        reused = self.model.reuse(I0, I1, scale)
        output = self.model.inference(I0, I1, *reused, timestep)

        # Drop the padding again.
        return output[:, :, :h, :w]
|
| 78 |
+
|
| 79 |
+
class GMFSS_Fortuna_VFI:
    """ComfyUI node: video frame interpolation with a GMFSS Fortuna checkpoint."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (list(CKPTS_PATH_CONFIG.keys()), ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 10, "min": 1, "max": 1000}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000}),
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        # NOTE: was annotated `typing.AnyStr` / `typing.SupportsInt`;
        # `AnyStr` is a constrained TypeVar for generic functions, not a type
        # for a plain parameter — the concrete types are `str` and `int`.
        ckpt_name: str,
        frames: torch.Tensor,
        clear_cache_after_n_frames: int = 10,
        multiplier: int = 2,
        optional_interpolation_states: typing.Optional[InterpolationStateList] = None,
        **kwargs
    ):
        """
        Perform video frame interpolation using a given checkpoint model.

        Args:
            ckpt_name (str): The name of the checkpoint model to use.
            frames (torch.Tensor): A tensor containing input video frames.
            clear_cache_after_n_frames (int, optional): The number of frames to process before clearing CUDA cache
                to prevent memory overflow. Defaults to 10. Lower numbers are safer but mean more processing time.
                How high you should set it depends on how many input frames there are, input resolution (after upscaling),
                how many times you want to multiply them, and how long you're willing to wait for the process to complete.
            multiplier (int, optional): The multiplier for each input frame. 60 input frames * 2 = 120 output frames. Defaults to 2.
            optional_interpolation_states (InterpolationStateList, optional): Per-frame skip states.

        Returns:
            tuple: A tuple containing the output interpolated frames.

        Note:
            This method interpolates frames in a video sequence using a specified checkpoint model.
            It processes each frame sequentially, generating interpolated frames between them.

            To prevent memory overflow, it clears the CUDA cache after processing a specified number of frames.
        """
        interpolation_model = CommonModelInference(model_type=ckpt_name)
        interpolation_model.eval().to(get_torch_device())
        frames = preprocess_frames(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model, scale):
            # Adapter matching the callback signature generic_frame_loop expects.
            return model(frame_0, frame_1, timestep, scale)

        scale = 1

        args = [interpolation_model, scale]
        out = postprocess_frames(
            generic_frame_loop(type(self).__name__, frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                               interpolation_states=optional_interpolation_states, dtype=torch.float32)
        )
        return (out,)
|
vfi_models/ifrnet/IFRNet_L_arch.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/ltkong218/IFRNet/blob/main/models/IFRNet_L.py
|
| 2 |
+
# https://github.com/ltkong218/IFRNet/blob/main/utils.py
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from comfy.model_management import get_torch_device
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def warp(img, flow):
    """Backward-warp `img` by the pixel-displacement field `flow` (B, 2, H, W).

    Out-of-bounds samples are clamped to the border; sampling is bilinear
    with `align_corners=True`.
    """
    B, _, H, W = flow.shape
    # Base sampling grid in grid_sample's normalized [-1, 1] coordinates.
    xs = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
    ys = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
    base_grid = torch.cat([xs, ys], 1).to(img)
    # Convert the pixel-space flow into the same normalized coordinates.
    norm_flow = torch.cat(
        [
            flow[:, 0:1, :, :] / ((W - 1.0) / 2.0),
            flow[:, 1:2, :, :] / ((H - 1.0) / 2.0),
        ],
        1,
    )
    sample_grid = (base_grid + norm_flow).permute(0, 2, 3, 1)
    return F.grid_sample(
        input=img,
        grid=sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_robust_weight(flow_pred, flow_gt, beta):
    """Per-pixel robustness weight exp(-beta * EPE) against `flow_gt`.

    `flow_pred` is detached, so no gradient flows through the weight.
    """
    diff = flow_pred.detach() - flow_gt
    epe = diff.pow(2).sum(dim=1, keepdim=True).sqrt()
    return torch.exp(-beta * epe)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def resize(x, scale_factor):
    """Bilinearly rescale `x` by `scale_factor` (no corner alignment)."""
    return F.interpolate(
        input=x, scale_factor=scale_factor, mode="bilinear", align_corners=False
    )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def convrelu(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
):
    """Build a Conv2d followed by a channel-wise PReLU activation."""
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias=bias,
    )
    return nn.Sequential(conv, nn.PReLU(out_channels))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ResBlock(nn.Module):
    """Residual block that gives the trailing `side_channels` channels two
    extra refinement passes before the residual sum.
    """

    def __init__(self, in_channels, side_channels, bias=True):
        super(ResBlock, self).__init__()
        self.side_channels = side_channels

        def conv_prelu(channels):
            # 3x3 same-padding conv followed by a channel-wise PReLU.
            return nn.Sequential(
                nn.Conv2d(
                    channels, channels, kernel_size=3, stride=1, padding=1, bias=bias
                ),
                nn.PReLU(channels),
            )

        # Attribute names (conv1..conv5, prelu) are part of the checkpoint
        # state_dict layout and must not change.
        self.conv1 = conv_prelu(in_channels)
        self.conv2 = conv_prelu(side_channels)
        self.conv3 = conv_prelu(in_channels)
        self.conv4 = conv_prelu(side_channels)
        self.conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias
        )
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        sc = self.side_channels
        out = self.conv1(x)
        # Refine only the trailing `sc` channels; the rest pass through.
        out = torch.cat([out[:, :-sc, :, :], self.conv2(out[:, -sc:, :, :])], 1)
        out = self.conv3(out)
        out = torch.cat([out[:, :-sc, :, :], self.conv4(out[:, -sc:, :, :])], 1)
        return self.prelu(x + self.conv5(out))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Encoder(nn.Module):
    """Four-level feature pyramid; each level halves the spatial size.

    Channel progression: 3 -> 64 -> 96 -> 144 -> 192.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.pyramid1 = nn.Sequential(
            convrelu(3, 64, 7, 2, 3), convrelu(64, 64, 3, 1, 1)
        )
        self.pyramid2 = nn.Sequential(
            convrelu(64, 96, 3, 2, 1), convrelu(96, 96, 3, 1, 1)
        )
        self.pyramid3 = nn.Sequential(
            convrelu(96, 144, 3, 2, 1), convrelu(144, 144, 3, 1, 1)
        )
        self.pyramid4 = nn.Sequential(
            convrelu(144, 192, 3, 2, 1), convrelu(192, 192, 3, 1, 1)
        )

    def forward(self, img):
        # Run the levels in sequence, collecting every intermediate map.
        features = []
        x = img
        for level in (self.pyramid1, self.pyramid2, self.pyramid3, self.pyramid4):
            x = level(x)
            features.append(x)
        return tuple(features)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class Decoder4(nn.Module):
    """Coarsest decoder: fuses both level-4 feature maps plus the time
    embedding and upsamples x2 via a transposed conv.
    """

    def __init__(self):
        super(Decoder4, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(384 + 1, 384),
            ResBlock(384, 64),
            nn.ConvTranspose2d(384, 148, 4, 2, 1, bias=True),
        )

    def forward(self, f0, f1, embt):
        _, _, h, w = f0.shape
        # Broadcast the (B, 1, 1, 1) time embedding over the spatial grid.
        tiled_t = embt.repeat(1, 1, h, w)
        fused = torch.cat([f0, f1, tiled_t], 1)
        return self.convblock(fused)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class Decoder3(nn.Module):
    """Level-3 decoder: warps both level-3 feature maps by the upsampled
    flows, fuses them with the carried features, and upsamples x2.
    """

    def __init__(self):
        super(Decoder3, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(436, 432),
            ResBlock(432, 64),
            nn.ConvTranspose2d(432, 100, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        fused = torch.cat([ft_, warped0, warped1, up_flow0, up_flow1], 1)
        return self.convblock(fused)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class Decoder2(nn.Module):
    """Level-2 decoder: same warp-and-fuse scheme as Decoder3 at the next
    finer scale.
    """

    def __init__(self):
        super(Decoder2, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(292, 288),
            ResBlock(288, 64),
            nn.ConvTranspose2d(288, 68, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        fused = torch.cat([ft_, warped0, warped1, up_flow0, up_flow1], 1)
        return self.convblock(fused)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Decoder1(nn.Module):
    """Finest decoder. Its 8 output channels are later split by the caller
    into two flows (2 + 2), a fusion mask (1) and an RGB residual (3).
    """

    def __init__(self):
        super(Decoder1, self).__init__()
        self.convblock = nn.Sequential(
            convrelu(196, 192),
            ResBlock(192, 64),
            nn.ConvTranspose2d(192, 8, 4, 2, 1, bias=True),
        )

    def forward(self, ft_, f0, f1, up_flow0, up_flow1):
        warped0 = warp(f0, up_flow0)
        warped1 = warp(f1, up_flow1)
        fused = torch.cat([ft_, warped0, warped1, up_flow0, up_flow1], 1)
        return self.convblock(fused)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class IRFNet_L(nn.Module):
    """IFRNet-L frame-interpolation network (encoder + 4 coarse-to-fine
    decoders). The class name keeps the upstream "IRFNet" spelling.
    """

    def __init__(self):
        super(IRFNet_L, self).__init__()
        # Shared feature-pyramid encoder and four coarse-to-fine decoders.
        self.encoder = Encoder()
        self.decoder4 = Decoder4()
        self.decoder3 = Decoder3()
        self.decoder2 = Decoder2()
        self.decoder1 = Decoder1()

    def forward(self, img0, img1, scale_factor=1.0, timestep=0.5):
        """Interpolate between `img0` and `img1` at `timestep`.

        Args:
            img0, img1: input frames (B, 3, H, W), values in [0, 1].
            scale_factor: resolution multiplier used only for the pyramid
                pass; flows are rescaled back afterwards.
            timestep: interpolation position in (0, 1); 0.5 is the midpoint.

        Returns:
            Predicted intermediate frame, clamped to [0, 1] and cropped to
            the original (unpadded) size.
        """
        # emb1 = torch.tensor(1/2).view(1, 1, 1, 1).float()
        # emb2 = torch.tensor(2/2).view(1, 1, 1, 1).float()
        # embt = torch.cat([emb1, emb2], 0)
        n, c, h, w = img0.shape

        # Pad right/bottom to a multiple of 64 so the 4-level pyramid
        # (x2 downsampling per level, plus the stride-2 stem) divides evenly.
        ph = ((h - 1) // 64 + 1) * 64
        pw = ((w - 1) // 64 + 1) * 64
        padding = (0, pw - w, 0, ph - h)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)

        # Support multiple batches: one time-embedding scalar per batch item.
        embt = torch.tensor([timestep] * n).view(n, 1, 1, 1).float().to(get_torch_device())
        if "HalfTensor" in str(img0.type()):
            # Match the frames' half precision so the concat in Decoder4 works.
            embt = embt.half()

        # Joint mean intensity of both frames; subtracted here and added back
        # at the end (zero-centers the network input).
        mean_ = (
            torch.cat([img0, img1], 2)
            .mean(1, keepdim=True)
            .mean(2, keepdim=True)
            .mean(3, keepdim=True)
        )
        img0 = img0 - mean_
        img1 = img1 - mean_

        # Optionally run the pyramid at a different resolution.
        img0_ = resize(img0, scale_factor=scale_factor)
        img1_ = resize(img1, scale_factor=scale_factor)

        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        # Coarsest level: predict initial bidirectional flows and features.
        out4 = self.decoder4(f0_4, f1_4, embt)
        up_flow0_4 = out4[:, 0:2]
        up_flow1_4 = out4[:, 2:4]
        ft_3_ = out4[:, 4:]

        # Each finer level refines the x2-upsampled flows; the flow values
        # are doubled because displacements scale with resolution.
        out3 = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        up_flow0_3 = out3[:, 0:2] + 2.0 * resize(up_flow0_4, scale_factor=2.0)
        up_flow1_3 = out3[:, 2:4] + 2.0 * resize(up_flow1_4, scale_factor=2.0)
        ft_2_ = out3[:, 4:]

        out2 = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        up_flow0_2 = out2[:, 0:2] + 2.0 * resize(up_flow0_3, scale_factor=2.0)
        up_flow1_2 = out2[:, 2:4] + 2.0 * resize(up_flow1_3, scale_factor=2.0)
        ft_1_ = out2[:, 4:]

        # Finest level additionally yields a fusion mask and an RGB residual.
        out1 = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)
        up_flow0_1 = out1[:, 0:2] + 2.0 * resize(up_flow0_2, scale_factor=2.0)
        up_flow1_1 = out1[:, 2:4] + 2.0 * resize(up_flow1_2, scale_factor=2.0)
        up_mask_1 = torch.sigmoid(out1[:, 4:5])
        up_res_1 = out1[:, 5:]

        # Undo the optional scale_factor: resize everything back and divide
        # flow magnitudes accordingly.
        up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0 / scale_factor)) * (
            1.0 / scale_factor
        )
        up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0 / scale_factor)) * (
            1.0 / scale_factor
        )
        up_mask_1 = resize(up_mask_1, scale_factor=(1.0 / scale_factor))
        up_res_1 = resize(up_res_1, scale_factor=(1.0 / scale_factor))

        # Warp both inputs to time t, blend with the mask, restore the mean,
        # and add the learned residual.
        img0_warp = warp(img0, up_flow0_1)
        img1_warp = warp(img1, up_flow1_1)
        imgt_merge = up_mask_1 * img0_warp + (1 - up_mask_1) * img1_warp + mean_
        imgt_pred = imgt_merge + up_res_1
        imgt_pred = torch.clamp(imgt_pred, 0, 1)
        return imgt_pred[:, :, :h, :w]
|