Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +173 -0
- LICENSE +101 -0
- README.md +179 -12
- app.py +389 -0
- basicsr/VERSION +1 -0
- basicsr/__init__.py +11 -0
- basicsr/archs/__init__.py +25 -0
- basicsr/archs/arcface_arch.py +245 -0
- basicsr/archs/arch_util.py +318 -0
- basicsr/archs/codeformer_arch.py +276 -0
- basicsr/archs/rrdbnet_arch.py +119 -0
- basicsr/archs/vgg_arch.py +161 -0
- basicsr/archs/vqgan_arch.py +406 -0
- basicsr/data/__init__.py +100 -0
- basicsr/data/data_sampler.py +48 -0
- basicsr/data/data_util.py +305 -0
- basicsr/data/prefetch_dataloader.py +138 -0
- basicsr/data/transforms.py +165 -0
- basicsr/losses/__init__.py +26 -0
- basicsr/losses/loss_util.py +95 -0
- basicsr/losses/losses.py +455 -0
- basicsr/metrics/__init__.py +19 -0
- basicsr/metrics/metric_util.py +45 -0
- basicsr/metrics/psnr_ssim.py +128 -0
- basicsr/models/__init__.py +30 -0
- basicsr/ops/__init__.py +0 -0
- basicsr/ops/dcn/__init__.py +7 -0
- basicsr/ops/dcn/deform_conv.py +377 -0
- basicsr/ops/dcn/src/deform_conv_cuda.cpp +685 -0
- basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu +867 -0
- basicsr/ops/dcn/src/deform_conv_ext.cpp +164 -0
- basicsr/ops/fused_act/__init__.py +3 -0
- basicsr/ops/fused_act/fused_act.py +89 -0
- basicsr/ops/fused_act/src/fused_bias_act.cpp +26 -0
- basicsr/ops/fused_act/src/fused_bias_act_kernel.cu +100 -0
- basicsr/ops/upfirdn2d/__init__.py +3 -0
- basicsr/ops/upfirdn2d/src/upfirdn2d.cpp +24 -0
- basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu +370 -0
- basicsr/ops/upfirdn2d/upfirdn2d.py +186 -0
- basicsr/setup.py +165 -0
- basicsr/train.py +207 -0
- basicsr/utils/__init__.py +29 -0
- basicsr/utils/dist_util.py +82 -0
- basicsr/utils/download_util.py +95 -0
- basicsr/utils/file_client.py +167 -0
- basicsr/utils/img_util.py +170 -0
- basicsr/utils/lmdb_util.py +196 -0
- basicsr/utils/logger.py +169 -0
- basicsr/utils/matlab_functions.py +347 -0
- basicsr/utils/misc.py +134 -0
.gitignore
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
| 162 |
+
out/*
|
| 163 |
+
!out/.gitkeep
|
| 164 |
+
media
|
| 165 |
+
tests
|
| 166 |
+
*.onnx
|
| 167 |
+
|
| 168 |
+
aaa.md
|
| 169 |
+
|
| 170 |
+
*_test.py
|
| 171 |
+
img.jpg
|
| 172 |
+
test_data
|
| 173 |
+
testsrc.mp4
|
LICENSE
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
IMPORTANT NOTICE
|
| 2 |
+
|
| 3 |
+
This project is licensed under a custom MIT License, **except** for the optional
|
| 4 |
+
`codeformer` component, which is licensed under the Creative Commons
|
| 5 |
+
Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0).
|
| 6 |
+
If you require commercial use, you **must remove** `codeformer`.
|
| 7 |
+
See the bottom of this file for full details and removal instructions.
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
Custom MIT License
|
| 12 |
+
|
| 13 |
+
Copyright (c) 2023 xaviviro
|
| 14 |
+
Copyright (c) 2025 Felipe Daragon
|
| 15 |
+
|
| 16 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 17 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 18 |
+
in the Software without restriction, including without limitation the rights
|
| 19 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 20 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 21 |
+
furnished to do so, subject to the following conditions:
|
| 22 |
+
|
| 23 |
+
- The above copyright notice and this permission notice shall be included in
|
| 24 |
+
all copies or substantial portions of the Software.
|
| 25 |
+
|
| 26 |
+
- You may only use this Software with content (such as images and videos)
|
| 27 |
+
for which you have the necessary rights and permissions. Unauthorized use of
|
| 28 |
+
third-party content is strictly prohibited.
|
| 29 |
+
|
| 30 |
+
- This Software is intended for educational and research purposes only. Use
|
| 31 |
+
of this Software for malicious purposes, including but not limited to identity
|
| 32 |
+
theft, invasion of privacy, or defamation, is strictly prohibited.
|
| 33 |
+
|
| 34 |
+
- By using this Software, you agree to comply with all applicable laws and
|
| 35 |
+
to respect the rights and privacy of others. You agree to use the Software
|
| 36 |
+
responsibly and ethically.
|
| 37 |
+
|
| 38 |
+
- The Software may contain protective mechanisms intended to prevent its use
|
| 39 |
+
with illegal or unauthorized media.
|
| 40 |
+
|
| 41 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 42 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 43 |
+
FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 44 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
| 45 |
+
WHETHER IN AN ACTION OF CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF,
|
| 46 |
+
OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## Additional License Notice: Optional `codeformer` Component
|
| 51 |
+
|
| 52 |
+
This project distribution optionally includes an old version of a component
|
| 53 |
+
named `codeformer` (https://github.com/felipedaragon/codeformer),
|
| 54 |
+
developed by Shangchen Zhou.
|
| 55 |
+
The `codeformer` component is **NOT** licensed under the MIT License.
|
| 56 |
+
Instead, it is licensed under:
|
| 57 |
+
|
| 58 |
+
**Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)**
|
| 59 |
+
License details: https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 60 |
+
|
| 61 |
+
Key points about this license:
|
| 62 |
+
- **Non-commercial use only**: You may not use `codeformer` for commercial purposes.
|
| 63 |
+
- **Attribution required**: You must credit the original creators.
|
| 64 |
+
- **ShareAlike**: If you modify and share `codeformer`, you must do so under the same license.
|
| 65 |
+
|
| 66 |
+
### How to Use This Project as MIT Only
|
| 67 |
+
|
| 68 |
+
If you wish to use this project solely under the MIT License (for example,
|
| 69 |
+
for commercial purposes), you **must remove** the `codeformer` component.
|
| 70 |
+
Please follow the instructions provided below:
|
| 71 |
+
|
| 72 |
+
- Remove the subdirectories basicsr and facelib
|
| 73 |
+
- Remove the Codeformer and facelib folders within the weights subdirectory.
|
| 74 |
+
- Remove codeformer_wrapper.py
|
| 75 |
+
- Edit refacer.py and remove the import: codeformer_wrapper
|
| 76 |
+
- Adjust the code so that it doesn't call the enhance functions from the commented-out wrapper
|
| 77 |
+
- That's all!
|
| 78 |
+
|
| 79 |
+
Failure to remove `codeformer` when required may violate the terms of its license.
|
| 80 |
+
|
| 81 |
+
About User Outputs:
|
| 82 |
+
The outputs generated by this software (such as refaced images or videos) are not
|
| 83 |
+
subject to the CC BY-NC-SA license and may be used freely, including for commercial
|
| 84 |
+
purposes, regardless of whether the optional codeformer component is used.
|
| 85 |
+
|
| 86 |
+
Explanation:
|
| 87 |
+
|
| 88 |
+
Codeformer is a model that processes images for face enhancement.
|
| 89 |
+
It does not embed its own visible content into the output.
|
| 90 |
+
|
| 91 |
+
This is very different from a case where licensed assets (such as textures,
|
| 92 |
+
overlays, characters, backgrounds, or artwork) appear visibly in the output.
|
| 93 |
+
|
| 94 |
+
Codeformer simply modifies the input image without adding original copyrighted
|
| 95 |
+
material.
|
| 96 |
+
|
| 97 |
+
Therefore, output images are not derivative works of Codeformer
|
| 98 |
+
and are not bound by the NonCommercial restrictions of its license.
|
| 99 |
+
|
| 100 |
+
Only the Codeformer model code and weights themselves are under CC BY-NC-SA 4.0
|
| 101 |
+
— not the results produced through their use.
|
README.md
CHANGED
|
@@ -1,12 +1,179 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<img src="https://raw.githubusercontent.com/MechasAI/NeoRefacer/main/icon.png"/>
|
| 2 |
+
|
| 3 |
+
# NeoRefacer: Images. GIFs. TIFFs. Full-length videos.
|
| 4 |
+
|
| 5 |
+
In a future where identity flows like data and reality is just another layer, NeoRefacer gives you the power to transform.
|
| 6 |
+
|
| 7 |
+
Images. GIFs. TIFFs. Full-length videos.
|
| 8 |
+
|
| 9 |
+
All yours to reface and reimagine - with a single pulse of electricity.
|
| 10 |
+
|
| 11 |
+
Evolved from the foundations of the [Refacer](https://github.com/xaviviro/refacer) project, NeoRefacer is a next-generation, fully open-source refacer.
|
| 12 |
+
|
| 13 |
+
<img src="https://raw.githubusercontent.com/MechasAI/NeoRefacer/main/demo.jpg"/>
|
| 14 |
+
|
| 15 |
+
1. Clone the repository.
|
| 16 |
+
2. Spin up the environment.
|
| 17 |
+
3. Launch the local interface.
|
| 18 |
+
4. Control the face of tomorrow.
|
| 19 |
+
|
| 20 |
+
[OFFICIAL WEBSITE](https://www.mechas.ai/projects-neorefacer.php)
|
| 21 |
+
|
| 22 |
+
## Core DNA of NeoRefacer
|
| 23 |
+
* **Instant Identity Shift** - Swap faces in images, GIFs, multi-page TIFFs and movies faster than your neural implants can blink.
|
| 24 |
+
* **Overclocked Engine** - Optimized for CPU rebels and GPU warlords.
|
| 25 |
+
* **Feature Film Reface** - Not just TikToks. Full two-hour cinematic overthrows.
|
| 26 |
+
* **Targeted Strike Modes** - Single-face raids, multi-face takeovers, or precision-targeted matchups.
|
| 27 |
+
* **Bulk Warfare** - Mass-process entire image archives with industrial-scale automation.
|
| 28 |
+
* **Neural Enhancement Suite** - Automatic image enhancement.
|
| 29 |
+
|
| 30 |
+
## Use Cases
|
| 31 |
+
|
| 32 |
+
* **Entertainment**: Rewrite memories, remix movies, animate the past.
|
| 33 |
+
* **Education**: Step into history, speak through new faces.
|
| 34 |
+
* **Content Creation**: Craft AI doubles, weave digital alter-egos.
|
| 35 |
+
* **Business/Marketing**: Personalize ads inside the algorithmic flood.
|
| 36 |
+
* **Niche Fun**: Trace ancestral echoes, forge RPG legends, hijack fame.
|
| 37 |
+
|
| 38 |
+
## What's New (Since Refacer)
|
| 39 |
+
|
| 40 |
+
* Image, GIF, TIFF and Video reface modes
|
| 41 |
+
* Significantly faster processing
|
| 42 |
+
* Automatic image enhancing (Image mode)
|
| 43 |
+
* Improved video output quality
|
| 44 |
+
* Support for videos that have long duration
|
| 45 |
+
* Preview generation for videos and GIFs (skips 90% of frames)
|
| 46 |
+
* Multiple replacement modes:
|
| 47 |
+
* **Single Face** (Fast): all faces are replaced with a single face. Ideal for images, GIFs or videos with a single face
|
| 48 |
+
* **Multiple Faces** (Fast): faces are replaced with the faces you provide based on their order from left to right
|
| 49 |
+
* **Faces by Match** (Slower): faces are first detected and replaced with the faces you provide.
|
| 50 |
+
* Reface ratio: full face to half face.
|
| 51 |
+
* Improved GPU detection
|
| 52 |
+
* Support for multi-page TIFF
|
| 53 |
+
* Uses local Gradio cache with auto-cleanup on startup
|
| 54 |
+
* Includes a bulk image refacer utility (refacer_bulk.py)
|
| 55 |
+
* Videos and images are saved to the root of /output, GIFs are saved to the /output/gifs subdirectory, and previews are saved to the /output/preview subdirectory
|
| 56 |
+
|
| 57 |
+
NeoRefacer, just like the original Refacer project, requires no training - just one photo and you're ready to go.
|
| 58 |
+
|
| 59 |
+
:warning: Please, before using the code from this repository, make sure to read the [LICENSE](https://github.com/MechasAI/NeoRefacer/blob/main/LICENSE).
|
| 60 |
+
|
| 61 |
+
## System Compatibility
|
| 62 |
+
|
| 63 |
+
NeoRefacer has been tested on the following operating systems:
|
| 64 |
+
|
| 65 |
+
| Operating System | CPU Support | GPU Support |
|
| 66 |
+
| ---------------- | ----------- | ----------- |
|
| 67 |
+
| MacOSX | ✅ | ✅ |
|
| 68 |
+
| Windows | ✅ | ✅ |
|
| 69 |
+
| Linux | ✅ | ✅ |
|
| 70 |
+
|
| 71 |
+
The application is compatible with both CPU and GPU (Nvidia CUDA) environments, and MacOSX (CoreML)
|
| 72 |
+
|
| 73 |
+
## Installation
|
| 74 |
+
|
| 75 |
+
NeoRefacer has been tested and is known to work with Python 3.11.11, but it is likely to work with other Python versions as well. It is recommended to use a virtual environment, such as [Conda](https://www.anaconda.com/download), for setting up and running the project to avoid potential conflicts with other Python packages you may have installed.
|
| 76 |
+
|
| 77 |
+
On Windows, before continuing, ensure that you have the [Visual Studio Build Tools](https://visualstudio.microsoft.com/downloads/) installed. They are required for installing dependencies. If you skip this step, you will likely encounter an error prompting you to install them.
|
| 78 |
+
|
| 79 |
+
Follow these steps to install NeoRefacer and its dependencies:
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
# Check if ffmpeg is available (if not, you might need to download it and add it to your PATH)
|
| 83 |
+
ffmpeg
|
| 84 |
+
|
| 85 |
+
# Windows: download ffmpeg-git-essentials.7z from https://www.gyan.dev/ffmpeg/builds/
|
| 86 |
+
# MacOS: if you have brew installed:
|
| 87 |
+
# brew install ffmpeg
|
| 88 |
+
# Other systems: see a tutorial https://www.hostinger.com/tutorials/how-to-install-ffmpeg
|
| 89 |
+
|
| 90 |
+
# Clone the repository
|
| 91 |
+
git clone https://github.com/MechasAI/NeoRefacer.git
|
| 92 |
+
cd NeoRefacer
|
| 93 |
+
|
| 94 |
+
# Create the environment
|
| 95 |
+
# Windows:
|
| 96 |
+
conda create -n neorefacer-env python=3.11 conda-forge::vs2015_runtime
|
| 97 |
+
# Linux:
|
| 98 |
+
conda create -n neorefacer-env python=3.11
|
| 99 |
+
# MacOS:
|
| 100 |
+
conda create -n neorefacer-env python=3.11
|
| 101 |
+
|
| 102 |
+
# Activate the environment
|
| 103 |
+
conda activate neorefacer-env
|
| 104 |
+
|
| 105 |
+
# Install the dependencies:
|
| 106 |
+
# For CPU only (compatible with Windows, MacOSX, and Linux)
|
| 107 |
+
pip install -r requirements-CPU.txt
|
| 108 |
+
|
| 109 |
+
# For NVIDIA RTX GPU only (compatible with Windows and Linux only, requires a NVIDIA GPU with CUDA and its libraries)
|
| 110 |
+
# Install Torch with CUDA enabled:
|
| 111 |
+
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
| 112 |
+
# This should install torch 2.5.1, torchaudio 2.5.1 and torchvision 0.20.1
|
| 113 |
+
# Make sure that CUDA is returning True:
|
| 114 |
+
python -c "import torch; print('CUDA:', torch.cuda.is_available()); print(torch.version.cuda); print(torch.cuda.get_device_name(0))"
|
| 115 |
+
# Now install the rest of the dependencies
|
| 116 |
+
pip install -r requirements-GPU.txt
|
| 117 |
+
|
| 118 |
+
# For CoreML only (compatible with MacOSX, requires Silicon architecture):
|
| 119 |
+
pip install -r requirements-COREML.txt
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
For NVIDIA GPU, make sure you have both NVIDIA GPU Computing Toolkit and NVIDIA CUDNN installed. The onnxruntime-gpu version must match your version of CUDA. This example uses onnxruntime-gpu 1.21.0, which is compatible with CUDA 12.6 and CUDNN 9.4 - Refacer.py is pre-loading both libraries. Remember to update the paths if needed in refacer.py if you have different location or versions.
|
| 123 |
+
|
| 124 |
+
For more information on installing the CUDA necessary to use `onnxruntime-gpu`, please refer directly to the official [ONNX Runtime repository](https://github.com/microsoft/onnxruntime/).
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
## Usage
|
| 128 |
+
|
| 129 |
+
Once you have successfully installed NeoRefacer and its dependencies, you can run the application using the following command:
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
python app.py
|
| 133 |
+
|
| 134 |
+
# Alternatively, if you need to force CPU mode
|
| 135 |
+
python app.py --force_cpu
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
Then, open your web browser and navigate to the following address:
|
| 139 |
+
|
| 140 |
+
```
|
| 141 |
+
http://127.0.0.1:7860
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
## Bulk Refacing
|
| 145 |
+
|
| 146 |
+
There are two ways to perform bulk refacing:
|
| 147 |
+
|
| 148 |
+
1. Using the GUI (Graphical User Interface):
|
| 149 |
+
Select **TIFF mode**, and then input a multi-page TIFF file. A multi-page TIFF is a special type of TIFF file that contains multiple images (pages) inside a single .tif file — similar to how a PDF can have many pages. Instead of handling individual images separately, all the images are stored together in one file. You can create a multi-page TIFF using an image editor that supports this format. Once you input this multi-page TIFF into the application, it will automatically process and reface every image (page) inside it. After processing, the output will be a .tif file where all internal images have been refaced.
|
| 150 |
+
|
| 151 |
+
2. Using the CLI (Command Line Interface):
|
| 152 |
+
You can call the **refacer_bulk.py** script directly through the command line. This allows you to process multiple images in a batch by providing the necessary parameters via CLI command, as shown below.
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
python refacer_bulk.py --input_path ./input --dest_face myface.jpg
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
## Questions?
|
| 160 |
+
|
| 161 |
+
If you have any questions or issues, feel free to [open an issue](https://github.com/MechasAI/NeoRefacer/issues/new).
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
## Third-Party Modules
|
| 165 |
+
|
| 166 |
+
The `recognition` folder in this repository is derived from Insightface's GitHub repository. You can find the original source code here: [Insightface Recognition Source Code](https://github.com/deepinsight/insightface/tree/master/web-demos/src_recognition)
|
| 167 |
+
|
| 168 |
+
This module is used for recognizing and handling face data within the NeoRefacer application. We are grateful to Insightface for their work and for making their code available.
|
| 169 |
+
|
| 170 |
+
The image enhancing capability is based on [codeformer](https://github.com/felipedaragon/codeformer/) (by Shangchen Zhou) and [BasicSR](https://github.com/XPixelGroup/BasicSR). It also borrows some code from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). Thanks for their awesome work.
|
| 171 |
+
|
| 172 |
+
## License
|
| 173 |
+
|
| 174 |
+
Note: This project uses a Custom MIT License, not allowing commercial use of the code unless you remove the image enhancing component. The output (refaced image or video) is not restricted by CC BY-NC-SA and may be used including for commercial purposes. See [LICENSE](https://github.com/MechasAI/NeoRefacer/blob/main/LICENSE) for full terms.
|
| 175 |
+
|
| 176 |
+
The generated content (refaced images or videos) does not represent the views, beliefs, or attitudes of the authors of this Software. Please use the Software and its outputs responsibly, ethically, and with respect toward others.
|
| 177 |
+
|
| 178 |
+
## Credits
|
| 179 |
+
Special thanks to Roberto Marc for the additional testing.
|
app.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from refacer import Refacer
|
| 6 |
+
import argparse
|
| 7 |
+
import ngrok
|
| 8 |
+
import imageio
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import tempfile
|
| 12 |
+
import base64
|
| 13 |
+
import pyfiglet
|
| 14 |
+
import shutil
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
print("\033[94m" + pyfiglet.Figlet(font='slant').renderText("NeoRefacer") + "\033[0m")
|
| 18 |
+
|
| 19 |
+
def cleanup_temp(folder_path):
    """Delete the Gradio cache directory at *folder_path*.

    Best-effort: any failure (missing folder, permissions) is reported to
    stdout rather than raised, since a stale cache is not fatal at startup.
    """
    try:
        shutil.rmtree(folder_path)
        print("Gradio cache cleared successfully.")
    except Exception as err:
        print(f"Error: {err}")
|
| 25 |
+
|
| 26 |
+
# Prepare temp folder
|
| 27 |
+
os.environ["GRADIO_TEMP_DIR"] = "./tmp"
|
| 28 |
+
if os.path.exists("./tmp"):
|
| 29 |
+
cleanup_temp(os.environ['GRADIO_TEMP_DIR'])
|
| 30 |
+
if not os.path.exists("./tmp"):
|
| 31 |
+
os.makedirs("./tmp")
|
| 32 |
+
|
| 33 |
+
# Parse arguments
|
| 34 |
+
parser = argparse.ArgumentParser(description='Refacer')
|
| 35 |
+
parser.add_argument("--max_num_faces", type=int, default=8)
|
| 36 |
+
parser.add_argument("--force_cpu", default=False, action="store_true")
|
| 37 |
+
parser.add_argument("--share_gradio", default=False, action="store_true")
|
| 38 |
+
parser.add_argument("--server_name", type=str, default="127.0.0.1")
|
| 39 |
+
parser.add_argument("--server_port", type=int, default=7860)
|
| 40 |
+
parser.add_argument("--colab_performance", default=False, action="store_true")
|
| 41 |
+
parser.add_argument("--ngrok", type=str, default=None)
|
| 42 |
+
parser.add_argument("--ngrok_region", type=str, default="us")
|
| 43 |
+
args = parser.parse_args()
|
| 44 |
+
|
| 45 |
+
# Initialize
|
| 46 |
+
refacer = Refacer(force_cpu=args.force_cpu, colab_performance=args.colab_performance)
|
| 47 |
+
num_faces = args.max_num_faces
|
| 48 |
+
|
| 49 |
+
def create_dummy_image():
    """Write a 1x1 white placeholder PNG into ./tmp and return its path.

    The temp file is created with ``delete=False`` so the path stays valid
    after this function returns; Gradio (or the startup cleanup) reclaims it.
    """
    placeholder = Image.new('RGB', (1, 1), color=(255, 255, 255))
    handle = tempfile.NamedTemporaryFile(delete=False, dir="./tmp", suffix=".png")
    placeholder.save(handle.name)
    return handle.name
|
| 54 |
+
|
| 55 |
+
def run_image(*vars):
    """Gradio click handler for single-image refacing.

    Positional layout of *vars* (N = module-level ``num_faces``):
    ``[image_path, origin_1..origin_N, destination_1..destination_N,
    threshold_1..threshold_N, face_mode, partial_reface_ratio]``.
    Returns whatever ``refacer.reface_image`` produces.
    """
    image_path = vars[0]
    origins = vars[1:num_faces + 1]
    destinations = vars[num_faces + 1:2 * num_faces + 1]
    thresholds = vars[2 * num_faces + 1:-2]
    face_mode = vars[-2]
    partial_reface_ratio = vars[-1]

    # "Single Face" and "Multiple Faces" bypass similarity matching entirely;
    # only "Faces by Match" uses origins/thresholds.
    disable_similarity = face_mode in ["Single Face", "Multiple Faces"]
    multiple_faces_mode = face_mode == "Multiple Faces"

    # Only slots with a destination image contribute a reface job.
    faces = [
        {
            'origin': None if multiple_faces_mode else origins[idx],
            'destination': destinations[idx],
            'threshold': 0.0 if multiple_faces_mode else thresholds[idx],
        }
        for idx in range(num_faces)
        if destinations[idx] is not None
    ]

    return refacer.reface_image(
        image_path,
        faces,
        disable_similarity=disable_similarity,
        multiple_faces_mode=multiple_faces_mode,
        partial_reface_ratio=partial_reface_ratio,
    )
|
| 76 |
+
|
| 77 |
+
def run(*vars):
    """Reface a video/GIF from the flattened Gradio input list.

    Positional layout (n = num_faces):
      vars[0]              video path
      vars[1 : n+1]        origin face images
      vars[n+1 : 2n+1]     destination face images
      vars[2n+1 : -3]      per-face similarity thresholds
      vars[-3]             preview flag (skip most frames)
      vars[-2]             replacement-mode radio value
      vars[-1]             partial reface ratio

    Returns:
        tuple: (mp4_path, gif_path or None).
    """
    video_path = vars[0]
    origins = vars[1:(num_faces + 1)]
    destinations = vars[(num_faces + 1):(num_faces * 2) + 1]
    thresholds = vars[(num_faces * 2) + 1:-3]
    preview = vars[-3]
    face_mode = vars[-2]
    partial_reface_ratio = vars[-1]

    # Same mode semantics as run_image().
    disable_similarity = (face_mode in ["Single Face", "Multiple Faces"])
    multiple_faces_mode = (face_mode == "Multiple Faces")

    faces = []
    for k in range(num_faces):
        if destinations[k] is not None:
            faces.append({
                'origin': origins[k] if not multiple_faces_mode else None,
                'destination': destinations[k],
                'threshold': thresholds[k] if not multiple_faces_mode else 0.0,
            })

    mp4_path, gif_path = refacer.reface(video_path, faces, preview=preview, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode, partial_reface_ratio=partial_reface_ratio)
    # GIF output is optional (e.g. plain video input); normalize falsy to None.
    return mp4_path, gif_path if gif_path else None
def load_first_frame(filepath):
    """Return the first frame of an image/GIF/video file, or None for no file.

    The frame is whatever `imageio` yields for index 0 (a numpy array).
    """
    if filepath is None:
        return None
    reader = imageio.get_reader(filepath)
    try:
        return reader.get_data(0)
    finally:
        # get_reader holds the file (and possibly an ffmpeg process) open;
        # close it so large uploads are released promptly.
        reader.close()
def extract_faces_auto(filepath, refacer_instance, max_faces=5, isvideo=False):
    """Detect up to `max_faces` faces in the first frame of `filepath`.

    Args:
        filepath (str | None): Uploaded file; None clears all face slots.
        refacer_instance: Object exposing `extract_faces_from_image(path, max_faces=...)`.
        max_faces (int): Number of UI face slots to fill.
        isvideo (bool): When True, skip files over 5 MB (decoding would be slow).

    Returns:
        list: Exactly `max_faces` entries — detected face crops first, padded
        with None so Gradio can map them onto the origin-image components.
    """
    if filepath is None:
        return [None] * max_faces

    # Auto-extraction on big videos is too expensive; let the user fill slots manually.
    if isvideo and os.path.getsize(filepath) > 5 * 1024 * 1024:
        print("Video too large for auto-extract, skipping face extraction.")
        return [None] * max_faces

    frame = load_first_frame(filepath)
    if frame is None:
        return [None] * max_faces

    # Some readers return stacked frames; peel leading axes down to (H, W, C).
    while len(frame.shape) > 3:
        frame = frame[0]

    if frame.shape[-1] != 3:
        raise ValueError(f"Expected last dimension to be 3 (RGB), but got {frame.shape[-1]}")

    # The extractor works on image files, so round-trip the frame through ./tmp.
    temp_image_path = os.path.join("./tmp", f"temp_face_extract_{int(time.time() * 1000)}.png")
    Image.fromarray(frame).save(temp_image_path)

    try:
        faces = refacer_instance.extract_faces_from_image(temp_image_path, max_faces=max_faces)
        return faces + [None] * (max_faces - len(faces))
    finally:
        if os.path.exists(temp_image_path):
            try:
                os.remove(temp_image_path)
            except Exception as e:
                # Best effort cleanup; a stale temp file is not worth failing the request.
                print(f"Warning: Could not delete temp file {temp_image_path}: {e}")
def toggle_tabs_and_faces(mode, face_tabs, origin_faces):
    """Compute Gradio visibility updates for the face tabs and origin images.

    Modes:
      "Single Face"     -> only the first tab visible, origin images hidden.
      "Multiple Faces"  -> all tabs visible, origin images hidden (no matching).
      anything else     -> ("Faces By Match") all tabs and origin images visible.

    Returns:
        list: tab updates followed by origin-image updates, matching the
        `outputs=face_tabs + origin_faces` wiring at the call sites.
    """
    if mode == "Single Face":
        tab_updates = [gr.update(visible=(i == 0)) for i in range(len(face_tabs))]
        origin_updates = [gr.update(visible=False) for _ in range(len(origin_faces))]
    elif mode == "Multiple Faces":
        tab_updates = [gr.update(visible=True) for _ in range(len(face_tabs))]
        origin_updates = [gr.update(visible=False) for _ in range(len(origin_faces))]
    else:
        tab_updates = [gr.update(visible=True) for _ in range(len(face_tabs))]
        origin_updates = [gr.update(visible=True) for _ in range(len(origin_faces))]
    return tab_updates + origin_updates
def handle_tif_preview(filepath):
    """Convert an uploaded TIF to a browser-displayable JPEG preview.

    Returns the preview path under ./tmp, or None when no file was given.
    """
    if filepath is None:
        return None
    preview_path = os.path.join("./tmp", f"tif_preview_{int(time.time() * 1000)}.jpg")
    # Browsers cannot render TIF; PIL converts the first page to RGB JPEG.
    Image.open(filepath).convert('RGB').save(preview_path)
    return preview_path
# --- UI ---
theme = gr.themes.Base(primary_hue="blue", secondary_hue="cyan")

with gr.Blocks(theme=theme, title="NeoRefacer - AI Refacer") as demo:
    # Inline the app icon as base64 so the header renders without a file route.
    with open("icon.png", "rb") as f:
        icon_data = base64.b64encode(f.read()).decode()
    icon_html = f'<img src="data:image/png;base64,{icon_data}" style="width:40px;height:40px;margin-right:10px;">'

    with gr.Row():
        gr.Markdown(f"""
<div style="display: flex; align-items: center;">
{icon_html}
<span style="font-size: 2em; font-weight: bold; color:#2563eb;">NeoRefacer</span>
</div>
""")

    def build_face_tabs():
        """Create one "Face #k" tab per slot inside the current Gradio context.

        Returns:
            tuple: (origins, destinations, thresholds, tabs), each a list of
            `num_faces` components, in slot order.
        """
        origins, destinations, thresholds, tabs = [], [], [], []
        for i in range(num_faces):
            with gr.Tab(f"Face #{i+1}") as tab:
                with gr.Row():
                    origin = gr.Image(label="Face to replace")
                    destination = gr.Image(label="Destination face")
                    threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.2)
            origins.append(origin)
            destinations.append(destination)
            thresholds.append(threshold)
            tabs.append(tab)
        return origins, destinations, thresholds, tabs

    # --- IMAGE MODE ---
    with gr.Tab("Image Mode"):
        with gr.Row():
            image_input = gr.Image(label="Original image", type="filepath")
            image_output = gr.Image(label="Refaced image", interactive=False, type="filepath")

        with gr.Row():
            face_mode_image = gr.Radio(["Single Face", "Multiple Faces", "Faces By Match"], value="Single Face", label="Replacement Mode")
            partial_reface_ratio_image = gr.Slider(label="Reface Ratio (0 = Full Face, 0.5 = Half Face)", minimum=0.0, maximum=0.5, value=0.0, step=0.1)
            image_btn = gr.Button("Reface Image", variant="primary")

        origin_image, destination_image, thresholds_image, face_tabs_image = build_face_tabs()

        face_mode_image.change(fn=lambda mode: toggle_tabs_and_faces(mode, face_tabs_image, origin_image), inputs=[face_mode_image], outputs=face_tabs_image + origin_image)
        demo.load(fn=lambda: toggle_tabs_and_faces("Single Face", face_tabs_image, origin_image), inputs=None, outputs=face_tabs_image + origin_image)

        image_btn.click(fn=run_image, inputs=[image_input] + origin_image + destination_image + thresholds_image + [face_mode_image, partial_reface_ratio_image], outputs=[image_output])
        # Auto-fill the origin slots from faces detected in the upload, and
        # reset the reface ratio on every new file.
        image_input.change(fn=lambda filepath: extract_faces_auto(filepath, refacer, max_faces=num_faces), inputs=image_input, outputs=origin_image)
        image_input.change(fn=lambda _: 0.0, inputs=image_input, outputs=partial_reface_ratio_image)

    # --- GIF MODE ---
    with gr.Tab("GIF Mode"):
        with gr.Row():
            gif_input = gr.File(label="Original GIF", file_types=[".gif"])
            gif_preview = gr.Video(label="GIF Preview", interactive=False)
            gif_output = gr.Video(label="Refaced GIF (MP4)", interactive=False, format="mp4")
            gif_file_output = gr.Image(label="Refaced GIF (GIF)", type="filepath")

        with gr.Row():
            face_mode_gif = gr.Radio(["Single Face", "Multiple Faces", "Faces By Match"], value="Single Face", label="Replacement Mode")
            partial_reface_ratio_gif = gr.Slider(label="Reface Ratio (0 = Full Face, 0.5 = Half Face)", minimum=0.0, maximum=0.5, value=0.0, step=0.1)
            gif_btn = gr.Button("Reface GIF", variant="primary")
            preview_checkbox_gif = gr.Checkbox(label="Preview Generation (skip 90% of frames)", value=False)

        origin_gif, destination_gif, thresholds_gif, face_tabs_gif = build_face_tabs()

        face_mode_gif.change(fn=lambda mode: toggle_tabs_and_faces(mode, face_tabs_gif, origin_gif), inputs=[face_mode_gif], outputs=face_tabs_gif + origin_gif)
        demo.load(fn=lambda: toggle_tabs_and_faces("Single Face", face_tabs_gif, origin_gif), inputs=None, outputs=face_tabs_gif + origin_gif)

        gif_btn.click(fn=run, inputs=[gif_input] + origin_gif + destination_gif + thresholds_gif + [preview_checkbox_gif, face_mode_gif, partial_reface_ratio_gif], outputs=[gif_output, gif_file_output])

        gif_input.change(fn=lambda filepath: extract_faces_auto(filepath, refacer, max_faces=num_faces), inputs=gif_input, outputs=origin_gif)
        gif_input.change(fn=lambda file: file, inputs=gif_input, outputs=[gif_preview])
        gif_input.change(fn=lambda _: 0.0, inputs=gif_input, outputs=partial_reface_ratio_gif)

    # --- TIF MODE ---
    with gr.Tab("TIFF Mode"):
        with gr.Row():
            tif_input = gr.File(label="Original TIF", file_types=[".tif", ".tiff"])
            tif_preview = gr.Image(label="TIF Preview (Cover Page)", type="filepath")
            tif_output_preview = gr.Image(label="Refaced TIF Preview (Cover Page)", type="filepath")
            tif_output_file = gr.File(label="Refaced TIF (Download)", interactive=False)

        with gr.Row():
            face_mode_tif = gr.Radio(
                choices=["Single Face", "Multiple Faces", "Faces By Match"],
                value="Single Face",
                label="Replacement Mode"
            )
            partial_reface_ratio_tif = gr.Slider(label="Reface Ratio (0 = Full Face, 0.5 = Half Face)", minimum=0.0, maximum=0.5, value=0.0, step=0.1)
            tif_btn = gr.Button("Reface TIF", variant="primary")

        origin_tif, destination_tif, thresholds_tif, face_tabs_tif = build_face_tabs()

        face_mode_tif.change(
            fn=lambda mode: toggle_tabs_and_faces(mode, face_tabs_tif, origin_tif),
            inputs=[face_mode_tif],
            outputs=face_tabs_tif + origin_tif
        )

        demo.load(
            fn=lambda: toggle_tabs_and_faces("Single Face", face_tabs_tif, origin_tif),
            inputs=None,
            outputs=face_tabs_tif + origin_tif
        )

        def process_tif(tif_path, *vars):
            """Reface a (possibly multi-page) TIF.

            Returns:
                tuple: (original cover-page JPEG preview, refaced cover-page
                JPEG preview, path of the refaced TIF for download).
            """
            original_img = Image.open(tif_path)
            if hasattr(original_img, "n_frames") and original_img.n_frames > 1:
                original_img.seek(0)  # preview only the cover page
            temp_preview_path = os.path.join("./tmp", f"tif_preview_{int(time.time() * 1000)}.jpg")
            original_img.convert('RGB').save(temp_preview_path)

            refaced_path = run_image(tif_path, *vars)

            refaced_img = Image.open(refaced_path)
            if hasattr(refaced_img, "n_frames") and refaced_img.n_frames > 1:
                refaced_img.seek(0)
            temp_refaced_preview_path = os.path.join("./tmp", f"refaced_tif_preview_{int(time.time() * 1000)}.jpg")
            refaced_img.convert('RGB').save(temp_refaced_preview_path)

            return temp_preview_path, temp_refaced_preview_path, refaced_path

        tif_btn.click(
            fn=lambda tif_path, *args: process_tif(tif_path, *args),
            inputs=[tif_input] + origin_tif + destination_tif + thresholds_tif + [face_mode_tif, partial_reface_ratio_tif],
            outputs=[tif_preview, tif_output_preview, tif_output_file]
        )

        tif_input.change(
            fn=lambda filepath: extract_faces_auto(filepath, refacer, max_faces=num_faces),
            inputs=tif_input,
            outputs=origin_tif
        )

        tif_input.change(
            fn=handle_tif_preview,
            inputs=tif_input,
            outputs=tif_preview
        )

        tif_input.change(fn=lambda _: 0.0, inputs=tif_input, outputs=partial_reface_ratio_tif)

    # --- VIDEO MODE ---
    with gr.Tab("Video Mode"):
        with gr.Row():
            video_input = gr.Video(label="Original video", format="mp4")
            video_output = gr.Video(label="Refaced Video", interactive=False, format="mp4")

        with gr.Row():
            face_mode_video = gr.Radio(
                choices=["Single Face", "Multiple Faces", "Faces By Match"],
                value="Single Face",
                label="Replacement Mode"
            )
            partial_reface_ratio_video = gr.Slider(label="Reface Ratio (0 = Full Face, 0.5 = Half Face)", minimum=0.0, maximum=0.5, value=0.0, step=0.1)
            video_btn = gr.Button("Reface Video", variant="primary")

        preview_checkbox_video = gr.Checkbox(label="Preview Generation (skip 90% of frames)", value=False)

        origin_video, destination_video, thresholds_video, face_tabs_video = build_face_tabs()

        face_mode_video.change(
            fn=lambda mode: toggle_tabs_and_faces(mode, face_tabs_video, origin_video),
            inputs=[face_mode_video],
            outputs=face_tabs_video + origin_video
        )

        demo.load(
            fn=lambda: toggle_tabs_and_faces("Single Face", face_tabs_video, origin_video),
            inputs=None,
            outputs=face_tabs_video + origin_video
        )

        video_input.change(
            fn=lambda filepath: extract_faces_auto(filepath, refacer, max_faces=num_faces, isvideo=True),
            inputs=video_input,
            outputs=origin_video
        )

        video_input.change(fn=lambda _: 0.0, inputs=video_input, outputs=partial_reface_ratio_video)

        video_btn.click(
            fn=lambda *args: run(*args),
            inputs=[video_input] + origin_video + destination_video + thresholds_video + [preview_checkbox_video, face_mode_video, partial_reface_ratio_video],
            # run() returns (mp4, gif-or-None); the hidden File swallows the unused GIF slot.
            outputs=[video_output, gr.File(visible=False)]
        )
# --- ngrok connect (optional) ---
if args.ngrok:
    def connect(token, port, options):
        """Open an ngrok tunnel to the local server and print its public URL."""
        # NOTE(review): `token` is unused here — auth appears to come via
        # `options`; confirm against the ngrok client being used.
        try:
            public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
            print(f'ngrok URL: {public_url}')
        except Exception as e:
            # Tunnel failure should not prevent the local UI from starting.
            print(f'ngrok connection aborted: {e}')

    connect(args.ngrok, args.server_port, {'region': args.ngrok_region, 'authtoken_from_env': False})

# --- Launch app ---
demo.queue().launch(favicon_path="icon.png", show_error=True, share=args.share_gradio, server_name=args.server_name, server_port=args.server_port)
basicsr/VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1.3.2
|
basicsr/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/xinntao/BasicSR
|
| 2 |
+
# flake8: noqa
|
| 3 |
+
from .archs import *
|
| 4 |
+
from .data import *
|
| 5 |
+
from .losses import *
|
| 6 |
+
from .metrics import *
|
| 7 |
+
from .models import *
|
| 8 |
+
from .ops import *
|
| 9 |
+
from .train import *
|
| 10 |
+
from .utils import *
|
| 11 |
+
from .version import __gitsha__, __version__
|
basicsr/archs/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
from basicsr.utils import get_root_logger, scandir
|
| 6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 7 |
+
|
| 8 |
+
__all__ = ['build_network']

# automatically scan and import arch modules for registry:
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'; importing each module triggers its @ARCH_REGISTRY.register() calls.
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
def build_network(opt):
    """Instantiate a registered network architecture from an options dict.

    Args:
        opt (dict): Must contain a 'type' key naming a class registered in
            ARCH_REGISTRY; all remaining keys are passed as constructor kwargs.

    Returns:
        torch.nn.Module: The constructed network.
    """
    opt = deepcopy(opt)  # pop() below must not mutate the caller's config
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
basicsr/archs/arcface_arch.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding (no bias).

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: 3x3 convolution preserving spatial size at stride 1.
    """
    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Structure: conv3x3-BN-ReLU-conv3x3-BN, with an identity (or downsampled)
    skip connection added before the final ReLU.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module for the skip path when
            shape changes. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Match the residual to the main path when stride/channels changed.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Uses a BN-first pre-activation layout with PReLU, optionally followed by a
    squeeze-and-excitation (SE) recalibration of the output channels.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module for the skip path.
            Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block).
            Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)  # pre-activation: normalize before the first conv
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out
class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    1x1 reduce -> 3x3 -> 1x1 expand (by `expansion`), with a residual skip.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of intermediate features; outputs have
            ``planes * expansion`` channels.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module for the skip path.
            Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Global-average-pools each channel, passes the channel descriptor through a
    two-layer bottleneck MLP with sigmoid gating, and rescales the input.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio for the bottleneck MLP.
            Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)  # per-channel gate in [0, 1]
        return x * y
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Takes a single-channel (grayscale) image and produces a 512-d embedding.

    Args:
        block (str): Block used in the ArcFace architecture. Only 'IRBlock'
            is mapped here; other strings are passed through unchanged.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block).
            Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        # conv1 takes 1 input channel: the network operates on grayscale faces.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        # fc5 assumes an 8x8 final feature map (i.e. 128x128 input faces).
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack `num_blocks` blocks; the first may downsample the skip path."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512*8*8) for fc5
        x = self.fc5(x)
        x = self.bn5(x)

        return x
basicsr/archs/arch_util.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections.abc
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
import warnings
|
| 6 |
+
from distutils.version import LooseVersion
|
| 7 |
+
from itertools import repeat
|
| 8 |
+
from torch import nn as nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.nn import init as init
|
| 11 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 12 |
+
|
| 13 |
+
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
|
| 14 |
+
from basicsr.utils import get_root_logger
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for the kaiming initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale  # e.g. scale < 1 to damp residual branches
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` instances of the same block sequentially.

    Args:
        basic_block (nn.Module): Class (not an instance) of the basic block.
        num_basic_block (int): Number of blocks to stack.
        kwarg (dict): Keyword arguments forwarded to each block constructor.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    return nn.Sequential(*(basic_block(**kwarg) for _ in range(num_basic_block)))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalization.

    Structure::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features. Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If True, use pytorch default init; otherwise,
            use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        if not pytorch_init:
            # Small-scale init (0.1) is preferred for residual branches.
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        return x + residual * self.res_scale
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Upsample(nn.Sequential):
    """Upsample module built from conv + pixel-shuffle stages.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        layers = []
        if (scale & (scale - 1)) == 0:
            # Power of two: one (conv, shuffle-x2) pair per factor of 2.
            for _ in range(int(math.log2(scale))):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*layers)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2); values are pixel
            displacements.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    h, w = x.size()[-2], x.size()[-1]
    # Base sampling grid of absolute pixel coordinates, shape (h, w, 2)
    # ordered as (x, y).
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()
    grid.requires_grad = False

    # Offset the grid by the flow, then normalize to [-1, 1] as required
    # by F.grid_sample.
    vgrid = grid + flow
    norm_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    norm_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    sample_grid = torch.stack((norm_x, norm_y), dim=3)
    # TODO: what if align_corners=False
    return F.grid_sample(x, sample_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): The ratio for resizing or the final
            output shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0
            (i.e., ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h = int(flow_h * sizes[0])
        output_w = int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # Flow values are pixel displacements, so they must be rescaled by the
    # same ratio as the spatial resize.
    scaled = flow.clone()
    scaled[:, 0, :, :] *= output_w / flow_w
    scaled[:, 1, :, :] *= output_h / flow_h
    return F.interpolate(
        input=scaled, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# TODO: may write a cpp file
|
| 190 |
+
def pixel_unshuffle(x, scale):
    """Pixel unshuffle (the inverse of pixel shuffle).

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio; must evenly divide both spatial dims.

    Returns:
        Tensor: Pixel-unshuffled feature with shape
            (b, c * scale^2, hh / scale, hw / scale).
    """
    b, c, hh, hw = x.size()
    assert hh % scale == 0 and hw % scale == 0
    out_h, out_w = hh // scale, hw // scale
    # Fold each (scale x scale) spatial patch into the channel dimension.
    patches = x.view(b, c, out_h, scale, out_w, scale)
    return patches.permute(0, 1, 3, 5, 2, 4).reshape(b, c * (scale**2), out_h, out_w)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        # Predict offsets and modulation masks from `feat` rather than from
        # the input `x` itself; conv_offset comes from the parent class.
        out = self.conv_offset(feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        # Very large offsets usually signal training instability; warn but
        # continue rather than aborting.
        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        # torchvision >= 0.9 ships a native modulated deformable conv;
        # otherwise fall back to the compiled op from basicsr.ops.dcn.
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 240 |
+
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
|
| 241 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 242 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 243 |
+
def norm_cdf(x):
|
| 244 |
+
# Computes standard normal cumulative distribution function
|
| 245 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 246 |
+
|
| 247 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 248 |
+
warnings.warn(
|
| 249 |
+
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
|
| 250 |
+
'The distribution of values may be incorrect.',
|
| 251 |
+
stacklevel=2)
|
| 252 |
+
|
| 253 |
+
with torch.no_grad():
|
| 254 |
+
# Values are generated by using a truncated uniform distribution and
|
| 255 |
+
# then using the inverse CDF for the normal distribution.
|
| 256 |
+
# Get upper and lower cdf values
|
| 257 |
+
low = norm_cdf((a - mean) / std)
|
| 258 |
+
up = norm_cdf((b - mean) / std)
|
| 259 |
+
|
| 260 |
+
# Uniformly fill tensor with values from [low, up], then translate to
|
| 261 |
+
# [2l-1, 2u-1].
|
| 262 |
+
tensor.uniform_(2 * low - 1, 2 * up - 1)
|
| 263 |
+
|
| 264 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 265 |
+
# standard normal
|
| 266 |
+
tensor.erfinv_()
|
| 267 |
+
|
| 268 |
+
# Transform to proper mean, std
|
| 269 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 270 |
+
tensor.add_(mean)
|
| 271 |
+
|
| 272 |
+
# Clamp to ensure it's in the proper range
|
| 273 |
+
tensor.clamp_(min=a, max=b)
|
| 274 |
+
return tensor
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        torch.Tensor: the same tensor, filled in-place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # Thin public wrapper: all work happens in the no-grad helper.
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# From PyTorch
|
| 304 |
+
def _ntuple(n):
    """Return a parser that expands a scalar into an n-tuple (from PyTorch).

    Iterables are passed through unchanged; any other value is repeated
    ``n`` times.
    """

    def parse(x):
        return x if isinstance(x, collections.abc.Iterable) else tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
basicsr/archs/codeformer_arch.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, Tensor
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
|
| 8 |
+
from basicsr.archs.vqgan_arch import *
|
| 9 |
+
from basicsr.utils import get_root_logger
|
| 10 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 11 |
+
|
| 12 |
+
def calc_mean_std(feat, eps=1e-5):
    """Calculate per-channel mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor (b, c, h, w).
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple[Tensor, Tensor]: Mean and std, each with shape (b, c, 1, 1).
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    flat = feat.view(b, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    feat_mean = flat.mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization (AdaIN).

    Adjust the reference features to have similar color and illumination
    statistics to those of the degraded features.

    Args:
        content_feat (Tensor): The reference feature, (b, c, h, w).
        style_feat (Tensor): The degraded feature supplying target statistics.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    # Whiten the content statistics, then re-color with the style statistics.
    normalized = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized * style_std.expand(size) + style_mean.expand(size)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PositionEmbeddingSine(nn.Module):
    """Sine/cosine 2D positional embedding.

    This is a more standard version of the position embedding, very similar
    to the one used by the "Attention Is All You Need" paper, generalized to
    work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        if mask is None:
            # No padding: every spatial position is valid.
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        # Cumulative counts of valid positions give 1-based row/col indices.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Geometric frequency ladder shared by the sin/cos pairs.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even) and cos (odd) components over the feature dim.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 87 |
+
|
| 88 |
+
def _get_activation_fn(activation):
|
| 89 |
+
"""Return an activation function given a string"""
|
| 90 |
+
if activation == "relu":
|
| 91 |
+
return F.relu
|
| 92 |
+
if activation == "gelu":
|
| 93 |
+
return F.gelu
|
| 94 |
+
if activation == "glu":
|
| 95 |
+
return F.glu
|
| 96 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class TransformerSALayer(nn.Module):
    """Pre-norm transformer self-attention layer with an MLP feed-forward."""

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Feed-forward (MLP) sub-layer.
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Positional embeddings are added to queries/keys only, never values.
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # Self-attention sub-layer (pre-norm residual).
        normed = self.norm1(tgt)
        q = k = self.with_pos_embed(normed, query_pos)
        attn_out = self.self_attn(q, k, value=normed, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(attn_out)

        # Feed-forward sub-layer (pre-norm residual).
        normed = self.norm2(tgt)
        mlp_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        return tgt + self.dropout2(mlp_out)
|
| 135 |
+
|
| 136 |
+
class Fuse_sft_block(nn.Module):
    """SFT-style fusion block: modulates decoder features with encoder features.

    A spatial scale and shift are predicted from the fused encoder/decoder
    features and applied to the decoder features, weighted by ``w``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # ResBlock is provided by basicsr.archs.vqgan_arch (star import).
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        fused = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(fused)
        shift = self.shift(fused)
        # Affine modulation of the decoder features, blended by weight w.
        return dec_feat + w * (dec_feat * scale + shift)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """CodeFormer: transformer-based codebook-index prediction on a fixed VQGAN.

    Encoder features are mapped by a transformer to codebook-index logits;
    the quantized features are decoded by the generator, optionally fused
    with encoder skip features (strength controlled by ``w`` at forward time).

    Args:
        dim_embd (int): Transformer embedding dimension. Default: 512.
        n_head (int): Number of attention heads. Default: 8.
        n_layers (int): Number of transformer layers. Default: 9.
        codebook_size (int): Number of codebook entries. Default: 1024.
        latent_size (int): Number of latent token positions. Default: 256.
        connect_list (list[str]): Feature resolutions at which encoder
            features are fused into the generator.
        fix_modules (list[str]): Names of sub-modules whose parameters are
            frozen (the pretrained VQ components).
    """

    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                codebook_size=1024, latent_size=256,
                connect_list=['32', '64', '128', '256'],
                fix_modules=['quantize','generator']):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        # Freeze the pretrained VQ components so only the transformer and
        # fusion layers receive gradients.
        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        # Learnable positional embedding for the (HW) latent tokens.
        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                    for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        # Channel widths of encoder/generator features at each resolution key.
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        # Standard transformer-style init for linear/embedding/LayerNorm.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        # Record encoder features at the resolutions in connect_list so they
        # can be fused into the generator below (keyed by spatial size).
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        # Hard argmax over the predicted code distribution, then look up the
        # corresponding codebook features.
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
|
basicsr/archs/rrdbnet_arch.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 6 |
+
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, used in the RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Small-scale init (0.1) works better for residual branches.
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        # Dense connections: each conv sees the concatenation of the input
        # and all previously produced feature maps.
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        out = self.conv5(torch.cat(feats, 1))
        # Empirically, scaling the residual by 0.2 gives better performance.
        return out * 0.2 + x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RRDB(nn.Module):
    """Residual in Residual Dense Block, used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        # Chain the three dense blocks, then apply the outer residual.
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        # Empirically, scaling the residual by 0.2 gives better performance.
        return out * 0.2 + x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Network of Residual in Residual Dense Blocks, used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    Extended for scale x2 and scale x1. Note: this is one option for scale 1
    and scale 2 — inputs are first pixel-unshuffled (the inverse of pixel
    shuffle) to reduce the spatial size and enlarge the channel count before
    entering the main ESRGAN trunk.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling scale (1, 2 or 4). Default: 4.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Pixel-unshuffle multiplies channels: x4 for scale 2, x16 for scale 1.
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # Upsampling convolutions.
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        # Trunk with a long skip connection.
        feat = feat + self.conv_body(self.body(feat))
        # Two nearest-neighbor x2 upsampling stages.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))
|
basicsr/archs/vgg_arch.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from torch import nn as nn
|
| 5 |
+
from torchvision.models import vgg as vgg
|
| 6 |
+
|
| 7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 8 |
+
|
| 9 |
+
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
|
| 10 |
+
NAMES = {
|
| 11 |
+
'vgg11': [
|
| 12 |
+
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
|
| 13 |
+
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
|
| 14 |
+
'pool5'
|
| 15 |
+
],
|
| 16 |
+
'vgg13': [
|
| 17 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
| 18 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
|
| 19 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
|
| 20 |
+
],
|
| 21 |
+
'vgg16': [
|
| 22 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
| 23 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
|
| 24 |
+
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
| 25 |
+
'pool5'
|
| 26 |
+
],
|
| 27 |
+
'vgg19': [
|
| 28 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
| 29 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
|
| 30 |
+
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
|
| 31 |
+
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
|
| 32 |
+
]
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def insert_bn(names):
|
| 37 |
+
"""Insert bn layer after each conv.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
names (list): The list of layer names.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
list: The list of layer names with bn layers.
|
| 44 |
+
"""
|
| 45 |
+
names_bn = []
|
| 46 |
+
for name in names:
|
| 47 |
+
names_bn.append(name)
|
| 48 |
+
if 'conv' in name:
|
| 49 |
+
position = name.replace('conv', '')
|
| 50 |
+
names_bn.append('bn' + position)
|
| 51 |
+
return names_bn
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@ARCH_REGISTRY.register()
|
| 55 |
+
class VGGFeatureExtractor(nn.Module):
|
| 56 |
+
"""VGG network for feature extraction.
|
| 57 |
+
|
| 58 |
+
In this implementation, we allow users to choose whether use normalization
|
| 59 |
+
in the input feature and the type of vgg network. Note that the pretrained
|
| 60 |
+
path must fit the vgg type.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
layer_name_list (list[str]): Forward function returns the corresponding
|
| 64 |
+
features according to the layer_name_list.
|
| 65 |
+
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
|
| 66 |
+
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
| 67 |
+
use_input_norm (bool): If True, normalize the input image. Importantly,
|
| 68 |
+
the input feature must in the range [0, 1]. Default: True.
|
| 69 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
| 70 |
+
Default: False.
|
| 71 |
+
requires_grad (bool): If true, the parameters of VGG network will be
|
| 72 |
+
optimized. Default: False.
|
| 73 |
+
remove_pooling (bool): If true, the max pooling operations in VGG net
|
| 74 |
+
will be removed. Default: False.
|
| 75 |
+
pooling_stride (int): The stride of max pooling operation. Default: 2.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self,
|
| 79 |
+
layer_name_list,
|
| 80 |
+
vgg_type='vgg19',
|
| 81 |
+
use_input_norm=True,
|
| 82 |
+
range_norm=False,
|
| 83 |
+
requires_grad=False,
|
| 84 |
+
remove_pooling=False,
|
| 85 |
+
pooling_stride=2):
|
| 86 |
+
super(VGGFeatureExtractor, self).__init__()
|
| 87 |
+
|
| 88 |
+
self.layer_name_list = layer_name_list
|
| 89 |
+
self.use_input_norm = use_input_norm
|
| 90 |
+
self.range_norm = range_norm
|
| 91 |
+
|
| 92 |
+
self.names = NAMES[vgg_type.replace('_bn', '')]
|
| 93 |
+
if 'bn' in vgg_type:
|
| 94 |
+
self.names = insert_bn(self.names)
|
| 95 |
+
|
| 96 |
+
# only borrow layers that will be used to avoid unused params
|
| 97 |
+
max_idx = 0
|
| 98 |
+
for v in layer_name_list:
|
| 99 |
+
idx = self.names.index(v)
|
| 100 |
+
if idx > max_idx:
|
| 101 |
+
max_idx = idx
|
| 102 |
+
|
| 103 |
+
if os.path.exists(VGG_PRETRAIN_PATH):
|
| 104 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
|
| 105 |
+
state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
|
| 106 |
+
vgg_net.load_state_dict(state_dict)
|
| 107 |
+
else:
|
| 108 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
|
| 109 |
+
|
| 110 |
+
features = vgg_net.features[:max_idx + 1]
|
| 111 |
+
|
| 112 |
+
modified_net = OrderedDict()
|
| 113 |
+
for k, v in zip(self.names, features):
|
| 114 |
+
if 'pool' in k:
|
| 115 |
+
# if remove_pooling is true, pooling operation will be removed
|
| 116 |
+
if remove_pooling:
|
| 117 |
+
continue
|
| 118 |
+
else:
|
| 119 |
+
# in some cases, we may want to change the default stride
|
| 120 |
+
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
|
| 121 |
+
else:
|
| 122 |
+
modified_net[k] = v
|
| 123 |
+
|
| 124 |
+
self.vgg_net = nn.Sequential(modified_net)
|
| 125 |
+
|
| 126 |
+
if not requires_grad:
|
| 127 |
+
self.vgg_net.eval()
|
| 128 |
+
for param in self.parameters():
|
| 129 |
+
param.requires_grad = False
|
| 130 |
+
else:
|
| 131 |
+
self.vgg_net.train()
|
| 132 |
+
for param in self.parameters():
|
| 133 |
+
param.requires_grad = True
|
| 134 |
+
|
| 135 |
+
if self.use_input_norm:
|
| 136 |
+
# the mean is for image with range [0, 1]
|
| 137 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
| 138 |
+
# the std is for image with range [0, 1]
|
| 139 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
| 140 |
+
|
| 141 |
+
def forward(self, x):
|
| 142 |
+
"""Forward function.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
Tensor: Forward results.
|
| 149 |
+
"""
|
| 150 |
+
if self.range_norm:
|
| 151 |
+
x = (x + 1) / 2
|
| 152 |
+
if self.use_input_norm:
|
| 153 |
+
x = (x - self.mean) / self.std
|
| 154 |
+
output = {}
|
| 155 |
+
|
| 156 |
+
for key, layer in self.vgg_net._modules.items():
|
| 157 |
+
x = layer(x)
|
| 158 |
+
if key in self.layer_name_list:
|
| 159 |
+
output[key] = x.clone()
|
| 160 |
+
|
| 161 |
+
return output
|
basicsr/archs/vqgan_arch.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import copy
|
| 6 |
+
import os
|
| 7 |
+
from basicsr.utils import get_root_logger
|
| 8 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 9 |
+
|
| 10 |
+
# Select Device
|
| 11 |
+
def select_device(prefer_coreml=False):
|
| 12 |
+
if torch.backends.mps.is_available() and prefer_coreml:
|
| 13 |
+
print("BasicSR Archs: Using CoreML backend (MPS).")
|
| 14 |
+
return torch.device("mps")
|
| 15 |
+
elif torch.cuda.is_available():
|
| 16 |
+
print("BasicSR Archs: Using CUDA backend.")
|
| 17 |
+
return torch.device("cuda")
|
| 18 |
+
else:
|
| 19 |
+
print("BasicSR Archs: Using CPU backend.")
|
| 20 |
+
return torch.device("cpu")
|
| 21 |
+
|
| 22 |
+
# Set device globally
|
| 23 |
+
DEVICE = select_device(prefer_coreml=True)
|
| 24 |
+
|
| 25 |
+
def normalize(in_channels):
|
| 26 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 27 |
+
|
| 28 |
+
@torch.jit.script
|
| 29 |
+
def swish(x):
|
| 30 |
+
return x * torch.sigmoid(x)
|
| 31 |
+
|
| 32 |
+
class VectorQuantizer(nn.Module):
|
| 33 |
+
def __init__(self, codebook_size, emb_dim, beta):
|
| 34 |
+
super(VectorQuantizer, self).__init__()
|
| 35 |
+
self.codebook_size = codebook_size
|
| 36 |
+
self.emb_dim = emb_dim
|
| 37 |
+
self.beta = beta
|
| 38 |
+
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
| 39 |
+
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
| 40 |
+
|
| 41 |
+
def forward(self, z):
|
| 42 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
| 43 |
+
z_flattened = z.view(-1, self.emb_dim)
|
| 44 |
+
|
| 45 |
+
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + \
|
| 46 |
+
(self.embedding.weight ** 2).sum(1) - \
|
| 47 |
+
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
| 48 |
+
|
| 49 |
+
mean_distance = torch.mean(d)
|
| 50 |
+
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
| 51 |
+
min_encoding_scores = torch.exp(-min_encoding_scores / 10)
|
| 52 |
+
|
| 53 |
+
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size, device=z.device)
|
| 54 |
+
min_encodings.scatter_(1, min_encoding_indices, 1)
|
| 55 |
+
|
| 56 |
+
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
| 57 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
| 58 |
+
z_q = z + (z_q - z).detach()
|
| 59 |
+
|
| 60 |
+
e_mean = torch.mean(min_encodings, dim=0)
|
| 61 |
+
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
| 62 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 63 |
+
|
| 64 |
+
return z_q, loss, {
|
| 65 |
+
"perplexity": perplexity,
|
| 66 |
+
"min_encodings": min_encodings,
|
| 67 |
+
"min_encoding_indices": min_encoding_indices,
|
| 68 |
+
"min_encoding_scores": min_encoding_scores,
|
| 69 |
+
"mean_distance": mean_distance
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
def get_codebook_feat(self, indices, shape):
|
| 73 |
+
indices = indices.view(-1, 1)
|
| 74 |
+
min_encodings = torch.zeros(indices.shape[0], self.codebook_size, device=indices.device)
|
| 75 |
+
min_encodings.scatter_(1, indices, 1)
|
| 76 |
+
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
| 77 |
+
|
| 78 |
+
if shape is not None:
|
| 79 |
+
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
| 80 |
+
|
| 81 |
+
return z_q
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class GumbelQuantizer(nn.Module):
|
| 85 |
+
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.codebook_size = codebook_size # number of embeddings
|
| 88 |
+
self.emb_dim = emb_dim # dimension of embedding
|
| 89 |
+
self.straight_through = straight_through
|
| 90 |
+
self.temperature = temp_init
|
| 91 |
+
self.kl_weight = kl_weight
|
| 92 |
+
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
| 93 |
+
self.embed = nn.Embedding(codebook_size, emb_dim)
|
| 94 |
+
|
| 95 |
+
def forward(self, z):
|
| 96 |
+
hard = self.straight_through if self.training else True
|
| 97 |
+
|
| 98 |
+
logits = self.proj(z)
|
| 99 |
+
|
| 100 |
+
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
| 101 |
+
|
| 102 |
+
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
| 103 |
+
|
| 104 |
+
# + kl divergence to the prior loss
|
| 105 |
+
qy = F.softmax(logits, dim=1)
|
| 106 |
+
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
| 107 |
+
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
| 108 |
+
|
| 109 |
+
return z_q, diff, {
|
| 110 |
+
"min_encoding_indices": min_encoding_indices
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Downsample(nn.Module):
|
| 115 |
+
def __init__(self, in_channels):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 118 |
+
|
| 119 |
+
def forward(self, x):
|
| 120 |
+
pad = (0, 1, 0, 1)
|
| 121 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 122 |
+
x = self.conv(x)
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class Upsample(nn.Module):
|
| 127 |
+
def __init__(self, in_channels):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 130 |
+
|
| 131 |
+
def forward(self, x):
|
| 132 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 133 |
+
x = self.conv(x)
|
| 134 |
+
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class ResBlock(nn.Module):
|
| 139 |
+
def __init__(self, in_channels, out_channels=None):
|
| 140 |
+
super(ResBlock, self).__init__()
|
| 141 |
+
self.in_channels = in_channels
|
| 142 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
| 143 |
+
self.norm1 = normalize(in_channels)
|
| 144 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 145 |
+
self.norm2 = normalize(out_channels)
|
| 146 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 147 |
+
if self.in_channels != self.out_channels:
|
| 148 |
+
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 149 |
+
|
| 150 |
+
def forward(self, x_in):
|
| 151 |
+
x = x_in
|
| 152 |
+
x = self.norm1(x)
|
| 153 |
+
x = swish(x)
|
| 154 |
+
x = self.conv1(x)
|
| 155 |
+
x = self.norm2(x)
|
| 156 |
+
x = swish(x)
|
| 157 |
+
x = self.conv2(x)
|
| 158 |
+
if self.in_channels != self.out_channels:
|
| 159 |
+
x_in = self.conv_out(x_in)
|
| 160 |
+
|
| 161 |
+
return x + x_in
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class AttnBlock(nn.Module):
|
| 165 |
+
def __init__(self, in_channels):
|
| 166 |
+
super().__init__()
|
| 167 |
+
self.in_channels = in_channels
|
| 168 |
+
|
| 169 |
+
self.norm = normalize(in_channels)
|
| 170 |
+
self.q = torch.nn.Conv2d(
|
| 171 |
+
in_channels,
|
| 172 |
+
in_channels,
|
| 173 |
+
kernel_size=1,
|
| 174 |
+
stride=1,
|
| 175 |
+
padding=0
|
| 176 |
+
)
|
| 177 |
+
self.k = torch.nn.Conv2d(
|
| 178 |
+
in_channels,
|
| 179 |
+
in_channels,
|
| 180 |
+
kernel_size=1,
|
| 181 |
+
stride=1,
|
| 182 |
+
padding=0
|
| 183 |
+
)
|
| 184 |
+
self.v = torch.nn.Conv2d(
|
| 185 |
+
in_channels,
|
| 186 |
+
in_channels,
|
| 187 |
+
kernel_size=1,
|
| 188 |
+
stride=1,
|
| 189 |
+
padding=0
|
| 190 |
+
)
|
| 191 |
+
self.proj_out = torch.nn.Conv2d(
|
| 192 |
+
in_channels,
|
| 193 |
+
in_channels,
|
| 194 |
+
kernel_size=1,
|
| 195 |
+
stride=1,
|
| 196 |
+
padding=0
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def forward(self, x):
|
| 200 |
+
h_ = x
|
| 201 |
+
h_ = self.norm(h_)
|
| 202 |
+
q = self.q(h_)
|
| 203 |
+
k = self.k(h_)
|
| 204 |
+
v = self.v(h_)
|
| 205 |
+
|
| 206 |
+
# compute attention
|
| 207 |
+
b, c, h, w = q.shape
|
| 208 |
+
q = q.reshape(b, c, h*w)
|
| 209 |
+
q = q.permute(0, 2, 1)
|
| 210 |
+
k = k.reshape(b, c, h*w)
|
| 211 |
+
w_ = torch.bmm(q, k)
|
| 212 |
+
w_ = w_ * (int(c)**(-0.5))
|
| 213 |
+
w_ = F.softmax(w_, dim=2)
|
| 214 |
+
|
| 215 |
+
# attend to values
|
| 216 |
+
v = v.reshape(b, c, h*w)
|
| 217 |
+
w_ = w_.permute(0, 2, 1)
|
| 218 |
+
h_ = torch.bmm(v, w_)
|
| 219 |
+
h_ = h_.reshape(b, c, h, w)
|
| 220 |
+
|
| 221 |
+
h_ = self.proj_out(h_)
|
| 222 |
+
|
| 223 |
+
return x+h_
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class Encoder(nn.Module):
|
| 227 |
+
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.nf = nf
|
| 230 |
+
self.num_resolutions = len(ch_mult)
|
| 231 |
+
self.num_res_blocks = num_res_blocks
|
| 232 |
+
self.resolution = resolution
|
| 233 |
+
self.attn_resolutions = attn_resolutions
|
| 234 |
+
|
| 235 |
+
curr_res = self.resolution
|
| 236 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
| 237 |
+
|
| 238 |
+
blocks = []
|
| 239 |
+
# initial convultion
|
| 240 |
+
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
| 241 |
+
|
| 242 |
+
# residual and downsampling blocks, with attention on smaller res (16x16)
|
| 243 |
+
for i in range(self.num_resolutions):
|
| 244 |
+
block_in_ch = nf * in_ch_mult[i]
|
| 245 |
+
block_out_ch = nf * ch_mult[i]
|
| 246 |
+
for _ in range(self.num_res_blocks):
|
| 247 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
| 248 |
+
block_in_ch = block_out_ch
|
| 249 |
+
if curr_res in attn_resolutions:
|
| 250 |
+
blocks.append(AttnBlock(block_in_ch))
|
| 251 |
+
|
| 252 |
+
if i != self.num_resolutions - 1:
|
| 253 |
+
blocks.append(Downsample(block_in_ch))
|
| 254 |
+
curr_res = curr_res // 2
|
| 255 |
+
|
| 256 |
+
# non-local attention block
|
| 257 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
| 258 |
+
blocks.append(AttnBlock(block_in_ch))
|
| 259 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
| 260 |
+
|
| 261 |
+
# normalise and convert to latent size
|
| 262 |
+
blocks.append(normalize(block_in_ch))
|
| 263 |
+
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
| 264 |
+
self.blocks = nn.ModuleList(blocks)
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
for block in self.blocks:
|
| 268 |
+
x = block(x)
|
| 269 |
+
|
| 270 |
+
return x
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class Generator(nn.Module):
|
| 274 |
+
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
| 275 |
+
super().__init__()
|
| 276 |
+
self.nf = nf
|
| 277 |
+
self.ch_mult = ch_mult
|
| 278 |
+
self.num_resolutions = len(self.ch_mult)
|
| 279 |
+
self.num_res_blocks = res_blocks
|
| 280 |
+
self.resolution = img_size
|
| 281 |
+
self.attn_resolutions = attn_resolutions
|
| 282 |
+
self.in_channels = emb_dim
|
| 283 |
+
self.out_channels = 3
|
| 284 |
+
block_in_ch = self.nf * self.ch_mult[-1]
|
| 285 |
+
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
| 286 |
+
|
| 287 |
+
blocks = []
|
| 288 |
+
# initial conv
|
| 289 |
+
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
| 290 |
+
|
| 291 |
+
# non-local attention block
|
| 292 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
| 293 |
+
blocks.append(AttnBlock(block_in_ch))
|
| 294 |
+
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
| 295 |
+
|
| 296 |
+
for i in reversed(range(self.num_resolutions)):
|
| 297 |
+
block_out_ch = self.nf * self.ch_mult[i]
|
| 298 |
+
|
| 299 |
+
for _ in range(self.num_res_blocks):
|
| 300 |
+
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
| 301 |
+
block_in_ch = block_out_ch
|
| 302 |
+
|
| 303 |
+
if curr_res in self.attn_resolutions:
|
| 304 |
+
blocks.append(AttnBlock(block_in_ch))
|
| 305 |
+
|
| 306 |
+
if i != 0:
|
| 307 |
+
blocks.append(Upsample(block_in_ch))
|
| 308 |
+
curr_res = curr_res * 2
|
| 309 |
+
|
| 310 |
+
blocks.append(normalize(block_in_ch))
|
| 311 |
+
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
| 312 |
+
|
| 313 |
+
self.blocks = nn.ModuleList(blocks)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def forward(self, x):
|
| 317 |
+
for block in self.blocks:
|
| 318 |
+
x = block(x)
|
| 319 |
+
|
| 320 |
+
return x
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# Autoencoder with device transfer
|
| 324 |
+
@ARCH_REGISTRY.register()
|
| 325 |
+
class VQAutoEncoder(nn.Module):
|
| 326 |
+
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2,
|
| 327 |
+
attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
| 328 |
+
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
| 329 |
+
super().__init__()
|
| 330 |
+
logger = get_root_logger()
|
| 331 |
+
self.in_channels = 3
|
| 332 |
+
self.nf = nf
|
| 333 |
+
self.codebook_size = codebook_size
|
| 334 |
+
self.embed_dim = emb_dim
|
| 335 |
+
self.ch_mult = ch_mult
|
| 336 |
+
self.resolution = img_size
|
| 337 |
+
self.attn_resolutions = attn_resolutions
|
| 338 |
+
self.quantizer_type = quantizer
|
| 339 |
+
|
| 340 |
+
self.encoder = Encoder(
|
| 341 |
+
self.in_channels, self.nf, self.embed_dim, self.ch_mult,
|
| 342 |
+
res_blocks, self.resolution, self.attn_resolutions
|
| 343 |
+
).to(DEVICE)
|
| 344 |
+
|
| 345 |
+
if self.quantizer_type == "nearest":
|
| 346 |
+
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, beta).to(DEVICE)
|
| 347 |
+
else:
|
| 348 |
+
self.quantize = GumbelQuantizer(
|
| 349 |
+
self.codebook_size, self.embed_dim, emb_dim,
|
| 350 |
+
gumbel_straight_through, gumbel_kl_weight
|
| 351 |
+
).to(DEVICE)
|
| 352 |
+
|
| 353 |
+
self.generator = Generator(
|
| 354 |
+
self.nf, self.embed_dim, self.ch_mult, res_blocks,
|
| 355 |
+
self.resolution, self.attn_resolutions
|
| 356 |
+
).to(DEVICE)
|
| 357 |
+
|
| 358 |
+
if model_path is not None:
|
| 359 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
| 360 |
+
if 'params_ema' in chkpt:
|
| 361 |
+
self.load_state_dict(chkpt['params_ema'])
|
| 362 |
+
logger.info(f'Loaded VQGAN from: {model_path} [params_ema]')
|
| 363 |
+
elif 'params' in chkpt:
|
| 364 |
+
self.load_state_dict(chkpt['params'])
|
| 365 |
+
logger.info(f'Loaded VQGAN from: {model_path} [params]')
|
| 366 |
+
else:
|
| 367 |
+
raise ValueError("Invalid model format!")
|
| 368 |
+
|
| 369 |
+
def forward(self, x):
|
| 370 |
+
x = x.to(DEVICE)
|
| 371 |
+
x = self.encoder(x)
|
| 372 |
+
quant, codebook_loss, quant_stats = self.quantize(x)
|
| 373 |
+
x = self.generator(quant)
|
| 374 |
+
return x, codebook_loss, quant_stats
|
| 375 |
+
|
| 376 |
+
@ARCH_REGISTRY.register()
|
| 377 |
+
class VQGANDiscriminator(nn.Module):
|
| 378 |
+
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
| 379 |
+
super().__init__()
|
| 380 |
+
layers = [
|
| 381 |
+
nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
|
| 382 |
+
nn.LeakyReLU(0.2, True)
|
| 383 |
+
]
|
| 384 |
+
nf_mult = 1
|
| 385 |
+
for n in range(1, n_layers):
|
| 386 |
+
prev = nf_mult
|
| 387 |
+
nf_mult = min(2 ** n, 8)
|
| 388 |
+
layers += [
|
| 389 |
+
nn.Conv2d(ndf * prev, ndf * nf_mult, 4, 2, 1, bias=False),
|
| 390 |
+
nn.BatchNorm2d(ndf * nf_mult),
|
| 391 |
+
nn.LeakyReLU(0.2, True)
|
| 392 |
+
]
|
| 393 |
+
layers += [
|
| 394 |
+
nn.Conv2d(ndf * nf_mult, 1, 4, 1, 1)
|
| 395 |
+
]
|
| 396 |
+
self.main = nn.Sequential(*layers).to(DEVICE)
|
| 397 |
+
|
| 398 |
+
if model_path:
|
| 399 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
| 400 |
+
if 'params_d' in chkpt:
|
| 401 |
+
self.load_state_dict(chkpt['params_d'])
|
| 402 |
+
elif 'params' in chkpt:
|
| 403 |
+
self.load_state_dict(chkpt['params'])
|
| 404 |
+
|
| 405 |
+
def forward(self, x):
|
| 406 |
+
return self.main(x.to(DEVICE))
|
basicsr/data/__init__.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from functools import partial
|
| 8 |
+
from os import path as osp
|
| 9 |
+
|
| 10 |
+
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
| 11 |
+
from basicsr.utils import get_root_logger, scandir
|
| 12 |
+
from basicsr.utils.dist_util import get_dist_info
|
| 13 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 14 |
+
|
| 15 |
+
__all__ = ['build_dataset', 'build_dataloader']
|
| 16 |
+
|
| 17 |
+
# automatically scan and import dataset modules for registry
|
| 18 |
+
# scan all the files under the data folder with '_dataset' in file names
|
| 19 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
| 20 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
| 21 |
+
# import all the dataset modules
|
| 22 |
+
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_dataset(dataset_opt):
|
| 26 |
+
"""Build dataset from options.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
dataset_opt (dict): Configuration for dataset. It must constain:
|
| 30 |
+
name (str): Dataset name.
|
| 31 |
+
type (str): Dataset type.
|
| 32 |
+
"""
|
| 33 |
+
dataset_opt = deepcopy(dataset_opt)
|
| 34 |
+
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
|
| 35 |
+
logger = get_root_logger()
|
| 36 |
+
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
|
| 37 |
+
return dataset
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
|
| 41 |
+
"""Build dataloader.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
dataset (torch.utils.data.Dataset): Dataset.
|
| 45 |
+
dataset_opt (dict): Dataset options. It contains the following keys:
|
| 46 |
+
phase (str): 'train' or 'val'.
|
| 47 |
+
num_worker_per_gpu (int): Number of workers for each GPU.
|
| 48 |
+
batch_size_per_gpu (int): Training batch size for each GPU.
|
| 49 |
+
num_gpu (int): Number of GPUs. Used only in the train phase.
|
| 50 |
+
Default: 1.
|
| 51 |
+
dist (bool): Whether in distributed training. Used only in the train
|
| 52 |
+
phase. Default: False.
|
| 53 |
+
sampler (torch.utils.data.sampler): Data sampler. Default: None.
|
| 54 |
+
seed (int | None): Seed. Default: None
|
| 55 |
+
"""
|
| 56 |
+
phase = dataset_opt['phase']
|
| 57 |
+
rank, _ = get_dist_info()
|
| 58 |
+
if phase == 'train':
|
| 59 |
+
if dist: # distributed training
|
| 60 |
+
batch_size = dataset_opt['batch_size_per_gpu']
|
| 61 |
+
num_workers = dataset_opt['num_worker_per_gpu']
|
| 62 |
+
else: # non-distributed training
|
| 63 |
+
multiplier = 1 if num_gpu == 0 else num_gpu
|
| 64 |
+
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
|
| 65 |
+
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
|
| 66 |
+
dataloader_args = dict(
|
| 67 |
+
dataset=dataset,
|
| 68 |
+
batch_size=batch_size,
|
| 69 |
+
shuffle=False,
|
| 70 |
+
num_workers=num_workers,
|
| 71 |
+
sampler=sampler,
|
| 72 |
+
drop_last=True)
|
| 73 |
+
if sampler is None:
|
| 74 |
+
dataloader_args['shuffle'] = True
|
| 75 |
+
dataloader_args['worker_init_fn'] = partial(
|
| 76 |
+
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
|
| 77 |
+
elif phase in ['val', 'test']: # validation
|
| 78 |
+
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
|
| 81 |
+
|
| 82 |
+
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
|
| 83 |
+
|
| 84 |
+
prefetch_mode = dataset_opt.get('prefetch_mode')
|
| 85 |
+
if prefetch_mode == 'cpu': # CPUPrefetcher
|
| 86 |
+
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
|
| 87 |
+
logger = get_root_logger()
|
| 88 |
+
logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
|
| 89 |
+
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
|
| 90 |
+
else:
|
| 91 |
+
# prefetch_mode=None: Normal dataloader
|
| 92 |
+
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
|
| 93 |
+
return torch.utils.data.DataLoader(**dataloader_args)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
|
| 97 |
+
# Set the worker seed to num_workers * rank + worker_id + seed
|
| 98 |
+
worker_seed = num_workers * rank + worker_id + seed
|
| 99 |
+
np.random.seed(worker_seed)
|
| 100 |
+
random.seed(worker_seed)
|
basicsr/data/data_sampler.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data.sampler import Sampler
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EnlargedSampler(Sampler):
|
| 7 |
+
"""Sampler that restricts data loading to a subset of the dataset.
|
| 8 |
+
|
| 9 |
+
Modified from torch.utils.data.distributed.DistributedSampler
|
| 10 |
+
Support enlarging the dataset for iteration-based training, for saving
|
| 11 |
+
time when restart the dataloader after each epoch
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
dataset (torch.utils.data.Dataset): Dataset used for sampling.
|
| 15 |
+
num_replicas (int | None): Number of processes participating in
|
| 16 |
+
the training. It is usually the world_size.
|
| 17 |
+
rank (int | None): Rank of the current process within num_replicas.
|
| 18 |
+
ratio (int): Enlarging ratio. Default: 1.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, dataset, num_replicas, rank, ratio=1):
|
| 22 |
+
self.dataset = dataset
|
| 23 |
+
self.num_replicas = num_replicas
|
| 24 |
+
self.rank = rank
|
| 25 |
+
self.epoch = 0
|
| 26 |
+
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
|
| 27 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 28 |
+
|
| 29 |
+
def __iter__(self):
|
| 30 |
+
# deterministically shuffle based on epoch
|
| 31 |
+
g = torch.Generator()
|
| 32 |
+
g.manual_seed(self.epoch)
|
| 33 |
+
indices = torch.randperm(self.total_size, generator=g).tolist()
|
| 34 |
+
|
| 35 |
+
dataset_size = len(self.dataset)
|
| 36 |
+
indices = [v % dataset_size for v in indices]
|
| 37 |
+
|
| 38 |
+
# subsample
|
| 39 |
+
indices = indices[self.rank:self.total_size:self.num_replicas]
|
| 40 |
+
assert len(indices) == self.num_samples
|
| 41 |
+
|
| 42 |
+
return iter(indices)
|
| 43 |
+
|
| 44 |
+
def __len__(self):
|
| 45 |
+
return self.num_samples
|
| 46 |
+
|
| 47 |
+
def set_epoch(self, epoch):
|
| 48 |
+
self.epoch = epoch
|
basicsr/data/data_util.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from os import path as osp
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from basicsr.data.transforms import mod_crop
|
| 8 |
+
from basicsr.utils import img2tensor, scandir
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths, or an image folder
            whose entries are read in sorted order.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(scandir(path, full_path=True))
    # BGR uint8 -> float32 in [0, 1]; channel order fixed by img2tensor below.
    frames = [cv2.imread(p).astype(np.float32) / 255. for p in img_paths]
    if require_mod_crop:
        frames = [mod_crop(frame, scale) for frame in frames]
    frames = img2tensor(frames, bgr2rgb=True, float32=True)
    return torch.stack(frames, dim=0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list of ``num_frames`` frames centred on ``crt_idx``.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last = max_frame_num - 1  # switch to 0-based indexing
    half = num_frames // 2

    def _pad(i):
        # Map an out-of-range index back into [0, last] per the padding mode.
        if 0 <= i <= last:
            return i
        if padding == 'replicate':
            return 0 if i < 0 else last
        if padding == 'reflection':
            return -i if i < 0 else 2 * last - i
        if padding == 'reflection_circle':
            return crt_idx + half - i if i < 0 else (crt_idx - half) - (i - last)
        # 'circle'
        return num_frames + i if i < 0 else i - num_frames

    return [_pad(i) for i in range(crt_idx - half, crt_idx + half + 1)]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. Each line records 1) image name (with extension),
    2) image shape, 3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key. Note that the
    same key is used for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, (f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # The two lmdbs must describe the same set of keys.
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    return [{f'{input_key}_path': lmdb_key, f'{gt_key}_path': lmdb_key} for lmdb_key in sorted(input_lmdb_keys)]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, (f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    # First whitespace-separated token of each line is the gt image name.
    with open(meta_info_file, 'r') as fin:
        gt_names = [line.split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        stem, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(stem)}{ext}'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, (f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_names = list(scandir(input_folder))
    gt_names = list(scandir(gt_folder))
    assert len(input_names) == len(gt_names), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_names)}, {len(gt_names)}.')
    paths = []
    for gt_name in gt_names:
        stem, ext = osp.splitext(osp.basename(gt_name))
        # The input name is derived from the gt stem through the template.
        input_name = f'{filename_tmpl.format(stem)}{ext}'
        assert input_name in input_names, f'{input_name} is not in {input_key}_paths.'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def paths_from_folder(folder):
    """Generate full paths for every entry under ``folder``.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    return [osp.join(folder, name) for name in scandir(folder)]
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def paths_from_lmdb(folder):
    """Generate lmdb keys from an lmdb folder's meta_info.txt.

    Args:
        folder (str): Folder path; must end with ``.lmdb``.

    Returns:
        list[str]: Lmdb keys (image names without extension).

    Raises:
        ValueError: If ``folder`` is not in lmdb format.
    """
    if not folder.endswith('.lmdb'):
        # Bugfix: original message read 'Folder {folder}folder should in lmdb
        # format.' (missing space and missing 'be').
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        # Each line looks like 'name.png (h,w,c) level'; keep the stem only.
        paths = [line.split('.')[0] for line in fin]
    return paths
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate the Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The (kernel_size, kernel_size) Gaussian kernel.
    """
    # `scipy.ndimage.filters` is a deprecated alias namespace (removed in
    # modern SciPy); import the function from `scipy.ndimage` directly.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # Set the element at the middle to one: a Dirac delta.
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # Gaussian-smoothing the Dirac yields the Gaussian filter itself.
    return gaussian_filter(kernel, sigma)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsamping with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        # A single (t, c, h, w) clip: add a batch dim, dropped again at the end.
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    # Fold (b, t, c) into the batch dim so one single-channel conv handles all planes.
    x = x.view(-1, 1, h, w)
    # Pad by an extra `scale * 2` beyond the kernel half-width; the surplus
    # border is trimmed after the strided convolution below.
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    # Sigma = 0.4 * scale follows the DUF convention.
    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    # Strided conv performs blur + downsample in one pass.
    x = F.conv2d(x, gaussian_filter, stride=scale)
    # Drop the 2-pixel border introduced by the extra padding above.
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
|
basicsr/data/prefetch_dataloader.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import queue as Queue
|
| 2 |
+
import threading
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        super().__init__()
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True  # do not block interpreter shutdown
        self.start()

    def run(self):
        # Producer side: push every item, then a None sentinel on exhaustion.
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        item = self.queue.get()
        if item is None:
            raise StopIteration
        return item

    def __iter__(self):
        return self
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # Wrap the stock iterator so batches are prepared in a background thread.
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CPUPrefetcher():
    """CPU prefetcher: a thin iterator wrapper with explicit next/reset.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        # Return the next batch, or None once the loader is exhausted.
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        # Restart iteration from the beginning of the original loader.
        self.loader = iter(self.ori_loader)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CUDAPrefetcher():
    """CUDA (or MPS/CPU) prefetcher.

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt

        # Cross-platform device detection
        # NOTE(review): assumes opt always has a 'num_gpu' key — confirm with callers.
        if opt['num_gpu'] != 0 and torch.cuda.is_available():
            self.device = torch.device('cuda')
            # Side stream lets host->device copies overlap with compute.
            self.stream = torch.cuda.Stream()
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
            self.stream = None  # no CUDA streams on MPS
        else:
            self.device = torch.device('cpu')
            self.stream = None

        # Eagerly fetch the first batch so next() always has one ready.
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            # Exhausted: next() will hand out None once.
            self.batch = None
            return None

        if self.stream is not None:
            # Asynchronous copy on the side stream (CUDA path).
            with torch.cuda.stream(self.stream):
                for k, v in self.batch.items():
                    if torch.is_tensor(v):
                        self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
        else:
            # Synchronous copy for the MPS/CPU paths.
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device)

    def next(self):
        if self.stream is not None:
            # Block the current stream until the prefetch copies are done.
            torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
|
basicsr/data/transforms.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Crops so both spatial dimensions become divisible by ``scale``.

    Args:
        img (ndarray): Input image (2-D or 3-D).
        scale (int): Scale factor.

    Returns:
        ndarray: Result image (a cropped copy).
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    h, w = img.shape[:2]
    return img[:h - h % scale, :w - w % scale, ...]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth (used in error messages).

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the gt/lq sizes mismatch the scale, or the lq image is
            smaller than the requested patch.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # Bugfix: the original passed TWO f-strings to ValueError, so the
        # exception message rendered as a tuple; join into one message.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop the corresponding gt patch at the scaled location
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    NOTE(review): cv2.flip(img, code, img) writes IN PLACE, so the input
    arrays themselves are mutated — confirm callers do not reuse them.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]: Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    # Each transform fires independently with probability 0.5.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            # Transpose swaps h and w; combined with the flips this yields
            # the 90-degree-multiple rotations.
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            # A horizontal flip negates the horizontal flow component.
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            # A vertical flip negates the vertical flow component.
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            # Swap the (dx, dy) channels to match the transposed axes.
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate an image around a center point.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.
    """
    h, w = img.shape[:2]
    if center is None:
        center = (w // 2, h // 2)
    rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale)
    return cv2.warpAffine(img, rotation_matrix, (w, h))
|
basicsr/losses/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
from basicsr.utils import get_root_logger
|
| 4 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
| 5 |
+
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
|
| 6 |
+
gradient_penalty_loss, r1_penalty)
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
|
| 10 |
+
'r1_penalty', 'g_path_regularize'
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_loss(opt):
    """Instantiate a loss from its options dict.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Loss class name registered in LOSS_REGISTRY.
    """
    opt = deepcopy(opt)  # pop() below must not mutate the caller's dict
    loss_cls = LOSS_REGISTRY.get(opt.pop('type'))
    loss = loss_cls(**opt)
    get_root_logger().info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
|
basicsr/losses/loss_util.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from torch.nn import functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # F._Reduction enum: none -> 0, elementwise_mean -> 1, sum -> 2.
    # Also validates `reduction`, raising on unknown strings.
    reduction_enum = F._Reduction.get_enum(reduction)
    if reduction_enum == 0:
        return loss
    if reduction_enum == 1:
        return loss.mean()
    return loss.sum()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply an element-wise weight, then reduce the loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    if weight is None:
        # Unweighted: defer entirely to the plain reduction.
        return reduce_loss(loss, reduction)

    # Weight must match the loss rank; its channel dim must either match
    # or be broadcastable (size 1).
    assert weight.dim() == loss.dim()
    assert weight.size(1) in (1, loss.size(1))
    loss = loss * weight

    if reduction == 'sum':
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # Normalise by the total weight. A size-1 channel weight is
        # broadcast over all channels, so scale the denominator to match.
        if weight.size(1) > 1:
            denom = weight.sum()
        else:
            denom = weight.sum() * loss.size(1)
        return loss.sum() / denom
    # reduction == 'none': return the weighted element-wise loss as-is.
    return loss
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    The wrapped function must have the signature ``loss_func(pred, target,
    **kwargs)`` and return the element-wise (unreduced) loss. The decorated
    version gains ``weight=None`` and ``reduction='mean'`` keyword
    arguments, which are handled by :func:`weight_reduce_loss`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # Compute the element-wise loss, then weight/reduce it.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
|
basicsr/losses/losses.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import lpips
|
| 3 |
+
import torch
|
| 4 |
+
from torch import autograd as autograd
|
| 5 |
+
from torch import nn as nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from basicsr.archs.vgg_arch import VGGFeatureExtractor
|
| 9 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
| 10 |
+
from .loss_util import weighted_loss
|
| 11 |
+
|
| 12 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute difference |pred - target| (no reduction)."""
    return (pred - target).abs()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared difference (pred - target)^2 (no reduction)."""
    return (pred - target).pow(2)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier penalty sqrt(diff^2 + eps) (no reduction).

    `eps` keeps the gradient finite at diff == 0.
    """
    diff = pred - target
    return (diff * diff + eps).sqrt()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the (optionally weighted) L1 loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        raw = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * raw
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the (optionally weighted) MSE loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        raw = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * raw
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the (optionally weighted) Charbonnier loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        raw = charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
        return self.loss_weight * raw
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted TV (total variation) loss.

    Penalises absolute differences between vertically and horizontally
    adjacent pixels, optionally masked by an element-wise weight map.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        # Bug fix: the original indexed `weight` unconditionally, raising
        # a TypeError whenever the documented default (weight=None) was used.
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            # Crop the weight map the same way the prediction is cropped.
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        return x_diff + y_diff
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # Bug fix: `torch.nn.L2loss` does not exist and raised an
            # AttributeError at construction; the L2 criterion is MSELoss.
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'mse':
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            # Frobenius norm is computed inline in forward().
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        # extract vgg features (gt is detached: no gradient to the target)
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss (compares Gram matrices of the features)
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        # Normalise by the number of elements so the scale is size-invariant.
        gram = features.bmm(features_t) / (c * h * w)
        return gram
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
    """LPIPS perceptual loss using a VGG backbone.

    Args:
        loss_weight (float): Scale applied to the mean LPIPS distance.
            Default: 1.0.
        use_input_norm (bool): If True, normalize inputs with the ImageNet
            mean/std for images in [0, 1]. Default: True.
        range_norm (bool): If True, map images from [-1, 1] to [0, 1]
            before normalization. Default: False.
    """

    def __init__(self,
                 loss_weight=1.0,
                 use_input_norm=True,
                 range_norm=False,):
        super(LPIPSLoss, self).__init__()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if self.use_input_norm:
            # ImageNet statistics for images with range [0, 1].
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target):
        """Return the weighted mean LPIPS distance between pred and target."""
        if self.range_norm:
            pred = (pred + 1) / 2
            target = (target + 1) / 2
        if self.use_input_norm:
            pred = (pred - self.mean) / self.std
            target = (target - self.mean) / self.std
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Bind the underlying criterion once, at construction time.
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Build the target label for the configured GAN type.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            # The wgan criteria consume the boolean label directly.
            return target_is_real
        label_val = self.real_label_val if target_is_real else self.fake_label_val
        return input.new_ones(input.size()) * label_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for a discriminator.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type == 'hinge':
            if is_disc:  # hinge loss for discriminators
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # hinge loss for generators
                loss = -input.mean()
        else:  # every other gan type dispatches to the bound criterion
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator. The core idea is to
    penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution
    and the discriminator is equal to 0 on the data manifold, the
    gradient penalty ensures that the discriminator cannot create
    a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.

    Ref:
        Eq. 9 in Which training methods for GANs do actually converge.
    """
    grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    # Squared gradient norm per sample, averaged over the batch.
    per_sample = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1)
    return per_sample.mean()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization for the generator (StyleGAN2).

    Args:
        fake_img (Tensor): Generated images of shape (n, c, h, w).
        latents (Tensor): Latent codes the images were generated from.
        mean_path_length (float | Tensor): Running mean of path lengths.
        decay (float): EMA decay for the running mean. Default: 0.01.

    Returns:
        tuple: (path penalty, detached mean of path lengths, detached
        updated running mean).
    """
    # Scale the noise so the expected gradient magnitude is independent
    # of the image resolution.
    noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = grad.pow(2).sum(2).mean(1).sqrt()

    # Exponential moving average of the path length.
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # One random interpolation factor per sample, broadcast over C/H/W,
    # created on real_data's device/dtype.
    alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))

    # interpolate between real_data and fake_data
    mixed = alpha * real_data + (1. - alpha) * fake_data
    mixed = autograd.Variable(mixed, requires_grad=True)

    disc_out = discriminator(mixed)
    gradients = autograd.grad(
        outputs=disc_out,
        inputs=mixed,
        grad_outputs=torch.ones_like(disc_out),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if weight is not None:
        gradients = gradients * weight

    # Penalise deviation of the per-channel gradient norm from 1.
    penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        penalty /= torch.mean(weight)

    return penalty
|
basicsr/metrics/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
| 4 |
+
from .psnr_ssim import calculate_psnr, calculate_ssim
|
| 5 |
+
|
| 6 |
+
__all__ = ['calculate_psnr', 'calculate_ssim']
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def calculate_metric(data, opt):
    """Compute a registered metric from data and options.

    Args:
        data (dict): Keyword inputs forwarded to the metric function.
        opt (dict): Configuration. Must contain the key ``type`` (str)
            naming a function registered in ``METRIC_REGISTRY``; the
            remaining entries are forwarded as keyword arguments.

    Returns:
        The metric value.
    """
    cfg = deepcopy(opt)  # pop() below must not mutate the caller's dict
    metric_fn = METRIC_REGISTRY.get(cfg.pop('type'))
    return metric_fn(**data, **cfg)
|
basicsr/metrics/metric_util.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from basicsr.utils.matlab_functions import bgr2ycbcr
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.
    """
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
    # Promote a 2-D (h, w) array to (h, w, 1) before any transposition.
    if img.ndim == 2:
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    scaled = img.astype(np.float32) / 255.
    if scaled.ndim == 3 and scaled.shape[2] == 3:
        # Only genuine 3-channel (BGR) images are converted; grayscale
        # inputs pass through unchanged.
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.
|
basicsr/metrics/psnr_ssim.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from basicsr.metrics.metric_util import reorder_image, to_y_channel
|
| 5 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@METRIC_REGISTRY.register()
def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result.
    """
    # Fix: corrected the 'differnet' typo in the shape-mismatch message.
    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    # float64 avoids overflow/precision loss in the squared-error sum.
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        # Identical images: infinite PSNR by convention.
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """
    # Stability constants from the SSIM paper, scaled for L = 255.
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # 11x11 Gaussian window with sigma = 1.5; the [5:-5] slices keep only
    # the fully-valid (non-border) region of each filtered output.
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    return (numerator / denominator).mean()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@METRIC_REGISTRY.register()
def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
    """

    # Fixed typo in the assertion message ('differnet' -> 'different').
    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
    # Normalize both images to HWC float64 before comparison.
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        # Compare on the luma (Y) channel only, as is common in SR papers.
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # Per-channel SSIM, averaged across channels.
    ssims = []
    for i in range(img1.shape[2]):
        ssims.append(_ssim(img1[..., i], img2[..., i]))
    return np.array(ssims).mean()
|
basicsr/models/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
from basicsr.utils import get_root_logger, scandir
|
| 6 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 7 |
+
|
| 8 |
+
__all__ = ['build_model']
|
| 9 |
+
|
| 10 |
+
# automatically scan and import model modules for registry
|
| 11 |
+
# scan all the files under the 'models' folder and collect files ending with
|
| 12 |
+
# '_model.py'
|
| 13 |
+
model_folder = osp.dirname(osp.abspath(__file__))
|
| 14 |
+
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
|
| 15 |
+
# import all the model modules
|
| 16 |
+
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type, used as the registry key.

    Returns:
        The model instance registered under ``opt['model_type']``.
    """
    # Deep-copy so the registered model class cannot mutate the caller's opt.
    opt = deepcopy(opt)
    model = MODEL_REGISTRY.get(opt['model_type'])(opt)
    logger = get_root_logger()
    logger.info(f'Model [{model.__class__.__name__}] is created.')
    return model
|
basicsr/ops/__init__.py
ADDED
|
File without changes
|
basicsr/ops/dcn/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
|
| 2 |
+
modulated_deform_conv)
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
|
| 6 |
+
'modulated_deform_conv'
|
| 7 |
+
]
|
basicsr/ops/dcn/deform_conv.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn as nn
|
| 4 |
+
from torch.autograd import Function
|
| 5 |
+
from torch.autograd.function import once_differentiable
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from torch.nn.modules.utils import _pair, _single
|
| 8 |
+
|
| 9 |
+
# Prefer the pre-compiled CUDA extension; if it is absent, optionally fall
# back to JIT compilation when the env var BASICSR_JIT is set to 'True'.
# If neither applies, deform_conv_ext stays undefined and the ops below will
# fail at call time.
try:
    from . import deform_conv_ext
except ImportError:
    import os
    BASICSR_JIT = os.getenv('BASICSR_JIT')
    if BASICSR_JIT == 'True':
        from torch.utils.cpp_extension import load
        module_path = os.path.dirname(__file__)
        deform_conv_ext = load(
            'deform_conv',
            sources=[
                os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
                os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
                os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
            ],
        )
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DeformConvFunction(Function):
    """Autograd Function wrapping the compiled deformable-conv CUDA ops.

    Only CUDA tensors are supported: both forward and backward raise
    NotImplementedError for CPU inputs.
    """

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        # Only 4D (N, C, H, W) inputs are accepted.
        if input is not None and input.dim() != 4:
            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            # The extension processes the batch in chunks of cur_im2col_step
            # images; the batch size must be divisible by the chunk size.
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
            deform_conv_ext.deform_conv_forward(input, weight,
                                                offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                ctx.deformable_groups, cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

            # Input and offset gradients are computed by a single kernel, so
            # both buffers are allocated if either grad is required.
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
                                                           grad_offset, weight, ctx.bufs_[0], weight.size(3),
                                                           weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                           ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                           ctx.deformable_groups, cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
                                                                ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                                weight.size(2), ctx.stride[1], ctx.stride[0],
                                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                                                                ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                                                                cur_im2col_step)

        # One None per non-tensor forward argument (stride ... im2col_step).
        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        # Standard conv output-size formula applied per spatial dimension.
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
        return output_size
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ModulatedDeformConvFunction(Function):
    """Autograd Function for modulated deformable convolution (DCNv2).

    In addition to offsets, a per-location modulation mask scales each sampled
    value. CUDA-only, like :class:`DeformConvFunction`.
    """

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        # Save tensors only when a backward pass can actually happen.
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
                                                      ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
                                                      ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
                                                      ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        # The extension computes all gradients in one call, accumulating into
        # these pre-zeroed buffers.
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
                                                       grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
                                                       grad_output, weight.shape[2], weight.shape[3], ctx.stride,
                                                       ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
                                                       ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        # One None per non-tensor forward argument (stride ... deformable_groups).
        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        # Standard conv output-size formula; stride/padding/dilation are
        # scalars here (not pairs), so H and W use the same values.
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Functional interfaces: call the autograd Functions directly.
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class DeformConv(nn.Module):
    """Deformable convolution layer (offsets supplied by the caller).

    See :class:`DeformConvPack` for a variant that predicts the offsets from
    the input itself. Bias is not supported.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        # Bias is not supported by the underlying deform_conv op.
        assert not bias
        assert in_channels % groups == 0, \
            f'in_channels {in_channels} is not divisible by groups {groups}'
        assert out_channels % groups == 0, \
            f'out_channels {out_channels} is not divisible ' \
            f'by groups {groups}'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init with bound 1/sqrt(fan_in), fan_in = C_in * kH * kW.
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        # To fix an assert error in deform_conv_cuda.cpp:128
        # input image is smaller than kernel
        input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
        if input_pad:
            # Zero-pad input and offset up to kernel size, then crop the
            # corresponding rows/cols from the output afterwards.
            pad_h = max(self.kernel_size[0] - x.size(2), 0)
            pad_w = max(self.kernel_size[1] - x.size(3), 0)
            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                          self.deformable_groups)
        if input_pad:
            out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
        return out
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class DeformConvPack(DeformConv):
    """A Deformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        # Predicts 2 offsets (x, y) per kernel position per deformable group.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init so the layer initially behaves like a regular convolution.
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                           self.deformable_groups)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class ModulatedDeformConv(nn.Module):
    """Modulated deformable convolution layer (DCNv2).

    Offsets and the modulation mask are supplied by the caller in
    :meth:`forward`; see :class:`ModulatedDeformConvPack` for a variant that
    predicts them internally.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        # NOTE: stride/padding/dilation are kept as scalars here (the CUDA op
        # receives the same value for H and W).
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.init_weights()

    def init_weights(self):
        # Uniform weight init with bound 1/sqrt(fan_in); zero bias.
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class ModulatedDeformConvPack(ModulatedDeformConv):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        # Predicts 3 values per kernel position per deformable group:
        # 2 offsets (x, y) + 1 modulation mask value.
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_weights()

    def init_weights(self):
        super(ModulatedDeformConvPack, self).init_weights()
        # Guard: init_weights is also called from the parent __init__, before
        # conv_offset exists.
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset(x)
        # Split channels into x-offsets, y-offsets, and the mask; the mask is
        # squashed to (0, 1) by a sigmoid.
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
|
basicsr/ops/dcn/src/deform_conv_cuda.cpp
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// modify from
|
| 2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
| 3 |
+
|
| 4 |
+
#include <torch/extension.h>
|
| 5 |
+
#include <ATen/DeviceGuard.h>
|
| 6 |
+
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
| 11 |
+
const int channels, const int height, const int width,
|
| 12 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
| 13 |
+
const int pad_w, const int stride_h, const int stride_w,
|
| 14 |
+
const int dilation_h, const int dilation_w,
|
| 15 |
+
const int parallel_imgs, const int deformable_group,
|
| 16 |
+
at::Tensor data_col);
|
| 17 |
+
|
| 18 |
+
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
| 19 |
+
const int channels, const int height, const int width,
|
| 20 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
| 21 |
+
const int pad_w, const int stride_h, const int stride_w,
|
| 22 |
+
const int dilation_h, const int dilation_w,
|
| 23 |
+
const int parallel_imgs, const int deformable_group,
|
| 24 |
+
at::Tensor grad_im);
|
| 25 |
+
|
| 26 |
+
void deformable_col2im_coord(
|
| 27 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
| 28 |
+
const at::Tensor data_offset, const int channels, const int height,
|
| 29 |
+
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
| 30 |
+
const int pad_w, const int stride_h, const int stride_w,
|
| 31 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
| 32 |
+
const int deformable_group, at::Tensor grad_offset);
|
| 33 |
+
|
| 34 |
+
void modulated_deformable_im2col_cuda(
|
| 35 |
+
const at::Tensor data_im, const at::Tensor data_offset,
|
| 36 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
| 37 |
+
const int height_im, const int width_im, const int height_col,
|
| 38 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
| 39 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 40 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
| 41 |
+
at::Tensor data_col);
|
| 42 |
+
|
| 43 |
+
void modulated_deformable_col2im_cuda(
|
| 44 |
+
const at::Tensor data_col, const at::Tensor data_offset,
|
| 45 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
| 46 |
+
const int height_im, const int width_im, const int height_col,
|
| 47 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
| 48 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 49 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
| 50 |
+
at::Tensor grad_im);
|
| 51 |
+
|
| 52 |
+
void modulated_deformable_col2im_coord_cuda(
|
| 53 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
| 54 |
+
const at::Tensor data_offset, const at::Tensor data_mask,
|
| 55 |
+
const int batch_size, const int channels, const int height_im,
|
| 56 |
+
const int width_im, const int height_col, const int width_col,
|
| 57 |
+
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
| 58 |
+
const int stride_h, const int stride_w, const int dilation_h,
|
| 59 |
+
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
| 60 |
+
at::Tensor grad_mask);
|
| 61 |
+
|
| 62 |
+
// Validates the geometry of a deformable-convolution call: weight layout,
// kernel/stride/dilation positivity, input/offset shapes and — when supplied —
// the gradOutput shape. Aborts via TORCH_CHECK / AT_ERROR on the first
// inconsistency.
//
// gradOutput may be NULL (forward pass); all other tensors are required.
//
// FIX: TORCH_CHECK and AT_ERROR concatenate their arguments (operator<<-style
// via c10::str); they do NOT interpret printf format specifiers. The original
// messages used "%s"/"%d"/"%ld" and therefore printed the literal specifiers
// followed by the values. Messages are rewritten in concatenation style.
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                 int padW, int dilationH, int dilationW, int group,
                 int deformable_group) {
  TORCH_CHECK(weight.ndimension() == 4,
              "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
              "but got: ",
              weight.ndimension());

  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got kH: ", kH,
              " kW: ", kW);

  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
              "kernel size should be consistent with weight, but got kH: ", kH,
              " kW: ", kW, " weight.size(2): ", weight.size(2),
              " weight.size(3): ", weight.size(3));

  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got dH: ", dH,
              " dW: ", dW);

  TORCH_CHECK(
      dilationW > 0 && dilationH > 0,
      "dilation should be greater than 0, but got dilationH: ", dilationH,
      " dilationW: ", dilationW);

  // Dimension indices for an unbatched (C,H,W) tensor; shifted by one below
  // when the input is batched (N,C,H,W).
  int ndim = input.ndimension();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(ndim == 3 || ndim == 4,
              "3D or 4D input tensor expected but got: ", ndim);

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
  // Standard convolution output-size formula.
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

  TORCH_CHECK(nInputPlane % deformable_group == 0,
              "input channels must divide deformable group size");

  if (outputWidth < 1 || outputHeight < 1)
    AT_ERROR("Given input size: (", nInputPlane, " x ", inputHeight, " x ",
             inputWidth, "). Calculated output size: (", nOutputPlane, " x ",
             outputHeight, " x ", outputWidth,
             "). Output size is too small");

  // NOTE(review): for a 3D input the channel dimension is dimf (0), not 1;
  // this check mirrors the upstream behavior and is left unchanged.
  TORCH_CHECK(input.size(1) == nInputPlane,
              "invalid number of input planes, expected: ", nInputPlane,
              ", but got: ", input.size(1));

  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
              "input image is smaller than kernel");

  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
              "invalid spatial size of offset, expected height: ", outputHeight,
              " width: ", outputWidth, ", but got height: ", offset.size(2),
              " width: ", offset.size(3));

  // The offset tensor carries two values (dy, dx) per kernel tap per
  // deformable group.
  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
              "invalid number of channels of offset");

  if (gradOutput != NULL) {
    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
                "invalid number of gradOutput planes, expected: ", nOutputPlane,
                ", but got: ", gradOutput->size(dimf));

    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
                 gradOutput->size(dimw) == outputWidth),
                "invalid size of gradOutput, expected height: ", outputHeight,
                " width: ", outputWidth, ", but got height: ",
                gradOutput->size(dimh), " width: ", gradOutput->size(dimw));
  }
}
|
| 151 |
+
|
| 152 |
+
// Forward pass of (unmodulated) deformable convolution on CUDA.
//
// Computes output = weight * im2col(input, offset) group-by-group, processing
// im2col_step images per GEMM to amortize kernel launches. `output`, `columns`
// and `ones` are caller-provided work buffers that are resized/overwritten
// here. Returns 1 unconditionally (legacy TH-style status code).
//
// Shapes (batched case):
//   input:  (N, C_in, H, W)
//   offset: (N, deformable_group*2*kH*kW, H_out, W_out)
//   weight: (C_out, C_in/group, kH, kW)
//   output: (N, C_out, H_out, W_out) after the final view.
// A 3D (unbatched) input is temporarily promoted to batch size 1 and demoted
// again before returning.
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step) {
  // todo: resize columns to include im2col: done
  // todo: add im2col_step as input
  // todo: add new output buffer and transpose it to output (or directly
  // transpose output) todo: possibly change data indexing because of
  // parallel_imgs

  // Validate all tensor geometry up front (NULL => no gradOutput to check).
  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  weight = weight.contiguous();

  // batch == 0 records that the caller passed an unbatched (3D) input so we
  // can strip the fake batch dimension again at the end.
  int batch = 1;
  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input.unsqueeze_(0);
    offset.unsqueeze_(0);
  }

  // todo: assert batchsize dividable by im2col_step

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  // Fold im2col_step consecutive images into one "macro batch" so that a
  // single GEMM covers im2col_step images at once.
  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                        outputHeight, outputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // `ones` is only used by the bias path of other entry points; keep it big
  // enough for one output plane.
  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
    ones = at::ones({outputHeight, outputWidth}, input.options());
  }

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  // Accumulate GEMM results here; transposed into `output` afterwards.
  at::Tensor output_buffer =
      at::zeros({batchSize / im2col_step, nOutputPlane,
                 im2col_step * outputHeight, outputWidth},
                output.options());

  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // Unfold im2col_step images (with their learned offsets) into `columns`.
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // Split the channel dimension into `group` independent convolutions.
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    // Per-group GEMM: out += W_g (C_out/g x C_in/g*kH*kW) @ columns_g.
    for (int g = 0; g < group; g++) {
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(weight[g].flatten(1), columns[g])
                                  .view_as(output_buffer[elt][g]);
    }
  }

  // Merge the group dimension back into channels.
  output_buffer = output_buffer.view(
      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
       output_buffer.size(3), output_buffer.size(4)});

  // The GEMM produced (macro, C_out, step, H, W); swap to (macro, step, C_out,
  // H, W) so the copy below lands in normal NCHW order.
  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                      im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  output.copy_(output_buffer);
  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  // Restore caller-visible shapes.
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    output = output.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}
|
| 261 |
+
|
| 262 |
+
// Backward pass of deformable convolution w.r.t. the input and the offsets.
//
// For each macro-batch of im2col_step images: back-project gradOutput through
// the weights into column space (per-group GEMM), then scatter the columns
// back to gradOffset (deformable_col2im_coord) and gradInput
// (deformable_col2im). `gradInput`, `gradOffset` and `columns` are
// caller-provided buffers overwritten in place. Returns 1 (legacy status).
//
// FIX: the offset batch-size TORCH_CHECK carried a stray literal `3` argument;
// TORCH_CHECK concatenates its arguments, so the message printed as
// "3invalid batch size of offset". The stray argument is removed.
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                    at::Tensor gradOutput, at::Tensor gradInput,
                                    at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW,
                                    int dilationH, int group,
                                    int deformable_group, int im2col_step) {
  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();
  weight = weight.contiguous();

  // batch == 0 records an unbatched (3D) call; undone before returning.
  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // change order of grad output: fold im2col_step images per macro batch and
  // bring the channel dimension forward for the per-group GEMM below.
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
                              inputHeight, inputWidth});
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), group, gradOutput.size(1) / group,
         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

    // columns_g = W_g^T @ gradOutput_g (beta=0 overwrites, alpha=1).
    for (int g = 0; g < group; g++) {
      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});

    // Scatter column gradients to the offset field...
    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
                            dilationH, dilationW, im2col_step, deformable_group,
                            gradOffset[elt]);

    // ...and to the input image.
    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, gradInput[elt]);
  }

  // Undo the transpose/fold and restore caller-visible shapes.
  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  gradOffset = gradOffset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
    gradOffset =
        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}
|
| 375 |
+
|
| 376 |
+
// Backward pass of deformable convolution w.r.t. the weights.
//
// Re-runs the deformable im2col on the saved input/offset, then accumulates
// gradWeight += scale * gradOutput @ columns^T per group and macro-batch.
// `gradWeight` accumulates across the whole batch; `columns` and `ones` are
// scratch buffers. Bias gradients are not computed here (see the commented-out
// gradBias parameter). Returns 1 (legacy status).
int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight, // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  // todo: transpose and reshape outGrad
  // todo: reshape columns
  // todo: add im2col_step as input

  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();

  // batch == 0 records an unbatched (3D) call; undone before returning.
  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view(
        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = gradWeight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // Fold im2col_step images per macro batch and bring channels forward so a
  // contiguous copy (gradOutputBuffer) can be flattened per group below.
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  // Materialize the transposed gradOutput into a contiguous buffer.
  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
                             outputHeight, outputWidth});
  gradOutputBuffer.copy_(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
                             im2col_step * outputHeight, outputWidth});

  // gradOutput itself is restored to its caller-visible shape right away.
  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // Recompute the unfolded input columns for this macro batch.
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // divide into group
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight =
        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
                         gradWeight.size(2), gradWeight.size(3)});

    // gradWeight_g += scale * gradOutput_g @ columns_g^T (beta=1 accumulates).
    for (int g = 0; g < group; g++) {
      gradWeight[g] = gradWeight[g]
                          .flatten(1)
                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
                                  columns[g].transpose(1, 0), 1.0, scale)
                          .view_as(gradWeight[g]);
    }
    // Merge the group dimension back for the next iteration.
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0),
         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
                                  gradWeight.size(2), gradWeight.size(3),
                                  gradWeight.size(4)});
  }

  // Restore caller-visible shapes.
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
  }

  return 1;
}
|
| 489 |
+
|
| 490 |
+
// Forward pass of *modulated* deformable convolution (DCNv2) on CUDA:
// each sampling location has a learned offset AND a learned scalar mask.
// Processes one image per GEMM (no im2col_step batching). `output`, `columns`
// and `ones` are caller-provided work buffers overwritten here; the optional
// bias is broadcast-added at the end.
//
// FIX 1: the kernel-mismatch AT_ERROR passed `kernel_h_` (the weight's size)
// where the *requested* `kernel_h` belongs, so both sides of the "vs" printed
// the same values. FIX 2: AT_ERROR concatenates its arguments and does not
// interpret "%d", so the messages are rewritten in concatenation style.
void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape wont match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels wont match: (", channels, " vs ",
             channels_kernel * group, ").");

  // Standard convolution output-size formula.
  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns (one image per GEMM, hence the factor 1)
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  // Split channels into `group` independent convolutions.
  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    // Unfold image b with its offsets and modulation mask into `columns`.
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    // Per-group GEMM: out_bg += W_g @ columns_g.
    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    // Merge group dimension back for the next iteration.
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    // Broadcast-add the per-channel bias.
    output += bias.view({1, bias.size(0), 1, 1});
  }
}
|
| 570 |
+
|
| 571 |
+
// Backward pass of *modulated* deformable convolution (DCNv2) on CUDA.
//
// Per image: (1) back-project grad_output through the weights into column
// space, then scatter to grad_offset/grad_mask and grad_input; (2) re-run
// im2col on the saved input to accumulate grad_weight (and grad_bias via a
// GEMM against the `ones` plane). All grad_* buffers are caller-provided and
// overwritten/accumulated in place.
//
// FIX 1: the kernel-mismatch AT_ERROR passed `kernel_h_` where the requested
// `kernel_h` belongs (same bug as the forward path), so both sides of the
// "vs" printed identical values. FIX 2: AT_ERROR concatenates arguments and
// does not interpret "%d"; messages rewritten in concatenation style.
void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape wont match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels wont match: (", channels, " vs ",
             channels_kernel * group, ").");

  // Standard convolution output-size formula.
  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  // Split grad_output channels into `group` independent convolutions.
  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide int group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    // columns_g = W_g^T @ grad_output_bg (beta=0 overwrites, alpha=1).
    for (int g = 0; g < group; g++) {
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      // grad_weight_g += grad_output_bg @ columns_g^T (accumulates across b).
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        // grad_bias_g += grad_output_bg @ ones (row-sum via GEMM).
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    // Merge group dimension back for the next iteration.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  // Restore grad_output to its caller-visible shape.
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}
|
basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
|
| 3 |
+
*
|
| 4 |
+
* COPYRIGHT
|
| 5 |
+
*
|
| 6 |
+
* All contributions by the University of California:
|
| 7 |
+
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
| 8 |
+
* All rights reserved.
|
| 9 |
+
*
|
| 10 |
+
* All other contributions:
|
| 11 |
+
* Copyright (c) 2014-2017, the respective contributors
|
| 12 |
+
* All rights reserved.
|
| 13 |
+
*
|
| 14 |
+
* Caffe uses a shared copyright model: each contributor holds copyright over
|
| 15 |
+
* their contributions to Caffe. The project versioning records all such
|
| 16 |
+
* contribution and copyright details. If a contributor wants to further mark
|
| 17 |
+
* their specific copyright on a particular contribution, they should indicate
|
| 18 |
+
* their copyright solely in the commit message of the change when it is
|
| 19 |
+
* committed.
|
| 20 |
+
*
|
| 21 |
+
* LICENSE
|
| 22 |
+
*
|
| 23 |
+
* Redistribution and use in source and binary forms, with or without
|
| 24 |
+
* modification, are permitted provided that the following conditions are met:
|
| 25 |
+
*
|
| 26 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 27 |
+
* list of conditions and the following disclaimer.
|
| 28 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 29 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 30 |
+
* and/or other materials provided with the distribution.
|
| 31 |
+
*
|
| 32 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 33 |
+
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 34 |
+
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 35 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
| 36 |
+
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
| 37 |
+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
| 38 |
+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
| 39 |
+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 40 |
+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
| 41 |
+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 42 |
+
*
|
| 43 |
+
* CONTRIBUTION AGREEMENT
|
| 44 |
+
*
|
| 45 |
+
* By contributing to the BVLC/caffe repository through pull-request, comment,
|
| 46 |
+
* or otherwise, the contributor releases their content to the
|
| 47 |
+
* license and copyright terms herein.
|
| 48 |
+
*
|
| 49 |
+
***************** END Caffe Copyright Notice and Disclaimer ********************
|
| 50 |
+
*
|
| 51 |
+
* Copyright (c) 2018 Microsoft
|
| 52 |
+
* Licensed under The MIT License [see LICENSE for details]
|
| 53 |
+
* \file modulated_deformable_im2col.cuh
|
| 54 |
+
* \brief Function definitions of converting an image to
|
| 55 |
+
* column matrix based on kernel, padding, dilation, and offset.
|
| 56 |
+
* These functions are mainly used in deformable convolution operators.
|
| 57 |
+
* \ref: https://arxiv.org/abs/1703.06211
|
| 58 |
+
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
|
| 59 |
+
*/
|
| 60 |
+
|
| 61 |
+
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
| 62 |
+
|
| 63 |
+
#include <ATen/ATen.h>
|
| 64 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 65 |
+
#include <THC/THCAtomics.cuh>
|
| 66 |
+
#include <stdio.h>
|
| 67 |
+
#include <math.h>
|
| 68 |
+
#include <float.h>
|
| 69 |
+
|
| 70 |
+
using namespace at;
|
| 71 |
+
|
| 72 |
+
// Grid-stride loop: each thread starts at its global id and advances by the
// total number of launched threads, so any n is covered even when the grid
// is clamped below the ideal size.
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;  // threads per block for every launch below
const int kMaxGridNum = 65535;      // conservative cap on gridDim.x

// Number of blocks needed to cover N work items, clamped to kMaxGridNum;
// the grid-stride loop inside the kernels absorbs any clamped remainder.
inline int GET_BLOCKS(const int N)
{
  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
|
| 83 |
+
|
| 84 |
+
// Bilinearly sample a (height x width) image stored with row stride
// data_width at the fractional position (h, w). Corner pixels that fall
// outside the image contribute zero.
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                               const int height, const int width, scalar_t h, scalar_t w)
{
  const int y0 = floor(h);
  const int x0 = floor(w);
  const int y1 = y0 + 1;
  const int x1 = x0 + 1;

  const scalar_t dy = h - y0;  // fractional parts
  const scalar_t dx = w - x0;
  const scalar_t ry = 1 - dy;  // complementary weights toward the low corner
  const scalar_t rx = 1 - dx;

  // Load the four neighbours, substituting 0 for out-of-bounds corners.
  const scalar_t p00 = (y0 >= 0 && x0 >= 0) ? bottom_data[y0 * data_width + x0] : static_cast<scalar_t>(0);
  const scalar_t p01 = (y0 >= 0 && x1 <= width - 1) ? bottom_data[y0 * data_width + x1] : static_cast<scalar_t>(0);
  const scalar_t p10 = (y1 <= height - 1 && x0 >= 0) ? bottom_data[y1 * data_width + x0] : static_cast<scalar_t>(0);
  const scalar_t p11 = (y1 <= height - 1 && x1 <= width - 1) ? bottom_data[y1 * data_width + x1] : static_cast<scalar_t>(0);

  return ry * rx * p00 + ry * dx * p01 + dy * rx * p10 + dy * dx * p11;
}
|
| 116 |
+
|
| 117 |
+
// Backward weight of input pixel (h, w) for a bilinear sample taken at
// (argmax_h, argmax_w): the product of the per-axis interpolation factors,
// zero when (h, w) is not one of the four neighbours or the sample point
// lies entirely outside the image.
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                        const int h, const int w, const int height, const int width)
{
  // Sample point fully outside the image: no gradient flows back.
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    return 0;

  const int h_lo = floor(argmax_h);
  const int w_lo = floor(argmax_w);
  const int h_hi = h_lo + 1;
  const int w_hi = w_lo + 1;

  // Factorize the bilinear weight into independent h and w terms; a pixel
  // matching neither neighbour on an axis gets factor 0, so the product
  // reproduces the original four-case assignment exactly.
  scalar_t fh = 0;
  if (h == h_lo)
    fh = h + 1 - argmax_h;
  else if (h == h_hi)
    fh = argmax_h + 1 - h;

  scalar_t fw = 0;
  if (w == w_lo)
    fw = w + 1 - argmax_w;
  else if (w == w_hi)
    fw = argmax_w + 1 - w;

  return fh * fw;
}
|
| 144 |
+
|
| 145 |
+
// Derivative of the bilinearly sampled value w.r.t. one sampling coordinate.
// bp_dir selects the coordinate: 0 -> d/d(argmax_h), 1 -> d/d(argmax_w).
// Each in-bounds corner pixel contributes its value times the (signed)
// derivative of its bilinear weight along the chosen axis.
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                          const int height, const int width, const scalar_t *im_data,
                                          const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // Sample point entirely outside the image: zero gradient.
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0)  // gradient w.r.t. the vertical coordinate argmax_h
  {
    // Low-row corners enter with -1 (moving h down reduces their weight),
    // high-row corners with +1; each term is scaled by the w-axis weight.
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)  // gradient w.r.t. the horizontal coordinate argmax_w
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}
|
| 189 |
+
|
| 190 |
+
// im2col with learned per-position offsets (deformable conv v1).
// One thread per output element; index decomposes to
// (c_im, b_col, h_col, w_col). For every kernel tap (i, j) the thread reads
// the tap's 2-channel (h, w) offset, bilinearly samples the input at the
// displaced location, and writes into data_col, whose pointer arithmetic
// implies layout (c_im*kernel_h*kernel_w, batch_size, height_col, width_col).
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
                                             const int height, const int width, const int kernel_h, const int kernel_w,
                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
                                             const int batch_size, const int num_channels, const int deformable_group,
                                             const int height_col, const int width_col,
                                             scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // index index of output matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    // top-left corner of the undeformed receptive field, in input coordinates
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;
    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    // offsets: 2*kernel_h*kernel_w channels per deformable group,
    // h- and w-offset planes interleaved per tap
    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        // sample only when the displaced point overlaps the image; else 0
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const scalar_t map_h = i * dilation_h + offset_h;
          //const scalar_t map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val;
        // successive taps are batch_size*height_col*width_col elements apart
        data_col_ptr += batch_size * height_col * width_col;
      }
    }
  }
}
|
| 244 |
+
|
| 245 |
+
// Host launcher for deformable_im2col_gpu_kernel: expands parallel_imgs
// images of data_im into column form (data_col) using the offset map.
// Supports float/double/half via AT_DISPATCH. Launch errors are reported
// with printf only — no exception is raised.
void deformable_im2col(
    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h, const int ksize_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, at::Tensor data_col)
{
  // num_axes should be smaller than block size
  // todo: check parallel_imgs is correctly passed in
  // standard convolution output-size formula
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  // one thread per (input channel, image, output position)
  int num_kernels = channels * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
            height_col, width_col, data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
  }
}
|
| 278 |
+
|
| 279 |
+
// col2im backward for deformable conv: scatters the column gradient
// (data_col) back to the input-image gradient (grad_im). One thread per
// column element; index decomposes to (c, i, j, b, h_out, w_out). Each
// thread recomputes its displaced sampling position and atomically adds the
// bilinearly-weighted gradient to the nearby input pixels.
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_offset,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output

    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    // same displaced sampling position the forward pass used for this tap
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
                                                        2 * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index];
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    // 5x5 search window around the truncated position; the abs(...) < 1
    // tests keep only the (up to four) true bilinear neighbours, and
    // get_gradient_weight gives 0 for any others.
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          // atomic: many column elements may scatter into the same pixel
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}
|
| 336 |
+
|
| 337 |
+
// Host launcher for deformable_col2im_gpu_kernel: accumulates the gradient
// w.r.t. the input image into grad_im from the column gradient data_col.
// grad_im is accumulated into (atomicAdd in the kernel), so the caller is
// responsible for zeroing it first. Launch errors are reported via printf.
void deformable_col2im(
    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group,
    at::Tensor grad_im)
{

  // todo: make sure parallel_imgs is passed in correctly
  // standard convolution output-size formula
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  // one thread per column element (channel, tap, image, output position)
  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
  }
}
|
| 372 |
+
|
| 373 |
+
// Gradient w.r.t. the offset map. One thread per offset element; index
// decomposes to (b, c, h, w) where c ranges over the
// 2*kernel_h*kernel_w*deformable_group offset channels. Even/odd offset
// channels (bp_dir) hold h/w offsets respectively. Each thread sums, over
// all input channels of its deformable group, the column gradient times the
// derivative of the bilinear sample w.r.t. its coordinate. Writes are
// one-per-thread, so no atomics are needed here.
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
                                                   const scalar_t *data_im, const scalar_t *data_offset,
                                                   const int channels, const int height, const int width,
                                                   const int kernel_h, const int kernel_w,
                                                   const int pad_h, const int pad_w,
                                                   const int stride_h, const int stride_w,
                                                   const int dilation_h, const int dilation_w,
                                                   const int channel_per_deformable_group,
                                                   const int batch_size, const int offset_channels, const int deformable_group,
                                                   const int height_col, const int width_col, scalar_t *grad_offset)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    scalar_t val = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    // NOTE: here channel_per_deformable_group counts column channels
    // (input channels * kernel_h * kernel_w per group), unlike the other
    // kernels in this file where it counts input channels.
    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
                                                  batch_size * width_col * height_col;
    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
                                                        kernel_h * kernel_w * height_col * width_col;

    // offset channel relative to this deformable group
    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    // walk the column channels that used this offset tap: one per input
    // channel of the group, spaced kernel_h*kernel_w apart
    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;  // 0: h-offset, 1: w-offset

      // recover the tap (i, j) and output position this column entry came from
      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        // out-of-image sample: force a sentinel that yields zero weight
        inv_h = inv_w = -2;
      }
      const scalar_t weight = get_coordinate_weight(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos];
      cnt += 1;
    }

    grad_offset[index] = val;
  }
}
|
| 437 |
+
|
| 438 |
+
// Host launcher for deformable_col2im_coord_gpu_kernel: computes the
// gradient w.r.t. the offset map (grad_offset) from the column gradient
// (data_col), the input image (data_im) and the current offsets.
//
// Fix: unlike its siblings deformable_im2col / deformable_col2im, this
// launcher never checked cudaGetLastError(), so kernel-launch failures
// (e.g. an invalid configuration) were silently ignored. Report them the
// same way the siblings do.
void deformable_col2im_coord(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
    const int channels, const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{

  // standard convolution output-size formula
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  // one thread per offset element: (image, 2*K*K*group offset channels, h, w)
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
  // column channels per deformable group (input channels * K * K / groups)
  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
            height_col, width_col, grad_offset_);
      }));

  // Surface launch errors like the sibling launchers do.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err));
  }
}
|
| 466 |
+
|
| 467 |
+
// Bilinear sampling helper for the modulated (v2) kernels: sample a
// (height x width) image with row stride data_width at fractional (h, w);
// out-of-bounds corners contribute zero.
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                         const int height, const int width, scalar_t h, scalar_t w)
{
  const int row0 = floor(h);
  const int col0 = floor(w);
  const int row1 = row0 + 1;
  const int col1 = col0 + 1;

  const scalar_t fy = h - row0;  // fractional distances from the low corner
  const scalar_t fx = w - col0;
  const scalar_t gy = 1 - fy;
  const scalar_t gx = 1 - fx;

  // Gather the four surrounding pixels; out-of-range corners read as 0.
  const scalar_t q00 = (row0 >= 0 && col0 >= 0) ? bottom_data[row0 * data_width + col0] : static_cast<scalar_t>(0);
  const scalar_t q01 = (row0 >= 0 && col1 <= width - 1) ? bottom_data[row0 * data_width + col1] : static_cast<scalar_t>(0);
  const scalar_t q10 = (row1 <= height - 1 && col0 >= 0) ? bottom_data[row1 * data_width + col0] : static_cast<scalar_t>(0);
  const scalar_t q11 = (row1 <= height - 1 && col1 <= width - 1) ? bottom_data[row1 * data_width + col1] : static_cast<scalar_t>(0);

  return gy * gx * q00 + gy * fx * q01 + fy * gx * q10 + fy * fx * q11;
}
|
| 498 |
+
|
| 499 |
+
// Backward weight of input pixel (h, w) for a bilinear sample at
// (argmax_h, argmax_w) — modulated (v2) variant. Zero unless (h, w) is one
// of the four bilinear neighbours of an in-image sample point.
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                             const int h, const int w, const int height, const int width)
{
  // Sample point fully outside the image contributes nothing.
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    return 0;

  const int row_lo = floor(argmax_h);
  const int col_lo = floor(argmax_w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Bilinear weight factorized per axis; a non-neighbour on either axis
  // zeroes the product, matching the original four-case assignment.
  scalar_t wy = 0;
  if (h == row_lo)
    wy = h + 1 - argmax_h;
  else if (h == row_hi)
    wy = argmax_h + 1 - h;

  scalar_t wx = 0;
  if (w == col_lo)
    wx = w + 1 - argmax_w;
  else if (w == col_hi)
    wx = argmax_w + 1 - w;

  return wy * wx;
}
|
| 525 |
+
|
| 526 |
+
// Derivative of the bilinear sample w.r.t. one sampling coordinate —
// modulated (v2) variant. bp_dir selects the coordinate: 0 -> d/d(argmax_h),
// 1 -> d/d(argmax_w). Each in-bounds corner pixel contributes its value
// times the (signed) derivative of its bilinear weight along that axis.
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                               const int height, const int width, const scalar_t *im_data,
                                               const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    // Sample point entirely outside the image: zero gradient.
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0)  // gradient w.r.t. the vertical coordinate argmax_h
  {
    // Low-row corners enter negatively, high-row corners positively,
    // each scaled by its w-axis interpolation weight.
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)  // gradient w.r.t. the horizontal coordinate argmax_w
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}
|
| 569 |
+
|
| 570 |
+
// im2col for modulated deformable conv (v2): identical structure to
// deformable_im2col_gpu_kernel, but every kernel tap additionally reads a
// scalar mask from data_mask (kernel_h*kernel_w channels per deformable
// group) which multiplies the bilinearly sampled value.
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // index index of output matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    // top-left corner of the undeformed receptive field, in input coordinates
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    // offsets: 2*kernel_h*kernel_w channels per group (h/w interleaved)
    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    // masks: one channel per kernel tap per group
    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const float map_h = i * dilation_h + offset_h;
          //const float map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        // v2 modulation: scale the sampled value by the learned mask
        *data_col_ptr = val * mask;
        // successive taps are batch_size*height_col*width_col elements apart
        data_col_ptr += batch_size * height_col * width_col;
        //data_col_ptr += height_col * width_col;
      }
    }
  }
}
|
| 634 |
+
|
| 635 |
+
// col2im for modulated deformable convolution: scatters gradients from the
// column buffer back onto the input image. One thread per column element
// (channel x kernel tap x batch x output pixel); each thread atomically
// accumulates its masked gradient into the (up to) 4 input pixels that its
// fractional sampling position bilinearly touched.
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
    const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Decompose flat index into (kernel tap j,i | channel c | batch b | output pixel).
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;

    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
    // Fractional position this column element was sampled from in the forward pass.
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    // Incoming gradient, modulated by the same mask applied in the forward pass.
    const scalar_t cur_top_grad = data_col[index] * mask;
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    // Scan a small neighborhood; the |distance| < 1 checks keep only the
    // integer pixels actually covered by the bilinear kernel.
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          // atomicAdd: multiple column elements may map to the same input pixel.
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}
|
| 694 |
+
|
| 695 |
+
// Gradient w.r.t. the learned offsets and modulation masks.
// One thread per offset map element: (batch b, offset channel c, output pixel).
// Even offset channels are h-offsets, odd ones are w-offsets (bp_dir below);
// each thread reduces over all input channels of its deformable group.
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
    const scalar_t *data_col, const scalar_t *data_im,
    const scalar_t *data_offset, const scalar_t *data_mask,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int offset_channels, const int deformable_group,
    const int height_col, const int width_col,
    scalar_t *grad_offset, scalar_t *grad_mask)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    scalar_t val = 0, mval = 0;  // val: offset grad accum; mval: mask grad accum
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;  // counts input channels visited within this group
    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    // Offset channel within this deformable group (pairs of h/w per kernel tap).
    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    // Stride by col_step so every iteration lands on the same kernel tap of the
    // next input channel in the group.
    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;  // 0: d/d(offset_h), 1: d/d(offset_w)

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        // Sample fell outside the image: sentinel makes the coordinate weight 0.
        inv_h = inv_w = -2;
      }
      else
      {
        // Mask gradient: d(output)/d(mask) is the (unmasked) bilinear sample.
        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
      }
      const scalar_t weight = dmcn_get_coordinate_weight(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }
    grad_offset[index] = val;
    // Mask gradient is written once per (h,w) offset pair; the even (h) thread does it.
    if (offset_c % 2 == 0)
      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
  }
}
|
| 768 |
+
|
| 769 |
+
void modulated_deformable_im2col_cuda(
|
| 770 |
+
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
| 771 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
| 772 |
+
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
|
| 773 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 774 |
+
const int dilation_h, const int dilation_w,
|
| 775 |
+
const int deformable_group, at::Tensor data_col)
|
| 776 |
+
{
|
| 777 |
+
// num_axes should be smaller than block size
|
| 778 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
| 779 |
+
const int num_kernels = channels * batch_size * height_col * width_col;
|
| 780 |
+
|
| 781 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 782 |
+
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
|
| 783 |
+
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
|
| 784 |
+
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
| 785 |
+
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
|
| 786 |
+
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
| 787 |
+
|
| 788 |
+
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
| 789 |
+
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
|
| 790 |
+
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
|
| 791 |
+
batch_size, channels, deformable_group, height_col, width_col, data_col_);
|
| 792 |
+
}));
|
| 793 |
+
|
| 794 |
+
cudaError_t err = cudaGetLastError();
|
| 795 |
+
if (err != cudaSuccess)
|
| 796 |
+
{
|
| 797 |
+
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
| 798 |
+
}
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
void modulated_deformable_col2im_cuda(
|
| 802 |
+
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
|
| 803 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
| 804 |
+
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
| 805 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 806 |
+
const int dilation_h, const int dilation_w,
|
| 807 |
+
const int deformable_group, at::Tensor grad_im)
|
| 808 |
+
{
|
| 809 |
+
|
| 810 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
| 811 |
+
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
|
| 812 |
+
|
| 813 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 814 |
+
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
|
| 815 |
+
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
| 816 |
+
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
| 817 |
+
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
|
| 818 |
+
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
|
| 819 |
+
|
| 820 |
+
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
| 821 |
+
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
|
| 822 |
+
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
| 823 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
| 824 |
+
batch_size, deformable_group, height_col, width_col, grad_im_);
|
| 825 |
+
}));
|
| 826 |
+
|
| 827 |
+
cudaError_t err = cudaGetLastError();
|
| 828 |
+
if (err != cudaSuccess)
|
| 829 |
+
{
|
| 830 |
+
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
| 831 |
+
}
|
| 832 |
+
}
|
| 833 |
+
|
| 834 |
+
void modulated_deformable_col2im_coord_cuda(
|
| 835 |
+
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
| 836 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
| 837 |
+
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
| 838 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 839 |
+
const int dilation_h, const int dilation_w,
|
| 840 |
+
const int deformable_group,
|
| 841 |
+
at::Tensor grad_offset, at::Tensor grad_mask)
|
| 842 |
+
{
|
| 843 |
+
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
|
| 844 |
+
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
|
| 845 |
+
|
| 846 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 847 |
+
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
|
| 848 |
+
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
|
| 849 |
+
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
|
| 850 |
+
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
|
| 851 |
+
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
|
| 852 |
+
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
|
| 853 |
+
scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
|
| 854 |
+
|
| 855 |
+
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
|
| 856 |
+
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
|
| 857 |
+
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
| 858 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
| 859 |
+
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
|
| 860 |
+
grad_offset_, grad_mask_);
|
| 861 |
+
}));
|
| 862 |
+
cudaError_t err = cudaGetLastError();
|
| 863 |
+
if (err != cudaSuccess)
|
| 864 |
+
{
|
| 865 |
+
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
|
| 866 |
+
}
|
| 867 |
+
}
|
basicsr/ops/dcn/src/deform_conv_ext.cpp
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// modify from
|
| 2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
| 3 |
+
|
| 4 |
+
#include <torch/extension.h>
|
| 5 |
+
#include <ATen/DeviceGuard.h>
|
| 6 |
+
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#define WITH_CUDA // always use cuda
|
| 11 |
+
#ifdef WITH_CUDA
|
| 12 |
+
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
| 13 |
+
at::Tensor offset, at::Tensor output,
|
| 14 |
+
at::Tensor columns, at::Tensor ones, int kW,
|
| 15 |
+
int kH, int dW, int dH, int padW, int padH,
|
| 16 |
+
int dilationW, int dilationH, int group,
|
| 17 |
+
int deformable_group, int im2col_step);
|
| 18 |
+
|
| 19 |
+
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
| 20 |
+
at::Tensor gradOutput, at::Tensor gradInput,
|
| 21 |
+
at::Tensor gradOffset, at::Tensor weight,
|
| 22 |
+
at::Tensor columns, int kW, int kH, int dW,
|
| 23 |
+
int dH, int padW, int padH, int dilationW,
|
| 24 |
+
int dilationH, int group,
|
| 25 |
+
int deformable_group, int im2col_step);
|
| 26 |
+
|
| 27 |
+
int deform_conv_backward_parameters_cuda(
|
| 28 |
+
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
| 29 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
| 30 |
+
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
| 31 |
+
int padW, int padH, int dilationW, int dilationH, int group,
|
| 32 |
+
int deformable_group, float scale, int im2col_step);
|
| 33 |
+
|
| 34 |
+
void modulated_deform_conv_cuda_forward(
|
| 35 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
| 36 |
+
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
| 37 |
+
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
| 38 |
+
const int pad_h, const int pad_w, const int dilation_h,
|
| 39 |
+
const int dilation_w, const int group, const int deformable_group,
|
| 40 |
+
const bool with_bias);
|
| 41 |
+
|
| 42 |
+
void modulated_deform_conv_cuda_backward(
|
| 43 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
| 44 |
+
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
| 45 |
+
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
| 46 |
+
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
| 47 |
+
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
| 48 |
+
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
| 49 |
+
const bool with_bias);
|
| 50 |
+
#endif
|
| 51 |
+
|
| 52 |
+
// Device dispatch for deformable-conv (v1) forward.
// CUDA tensors go to deform_conv_forward_cuda; everything else raises
// (AT_ERROR throws) — there is no CPU implementation in this extension.
int deform_conv_forward(at::Tensor input, at::Tensor weight,
                        at::Tensor offset, at::Tensor output,
                        at::Tensor columns, at::Tensor ones, int kW,
                        int kH, int dW, int dH, int padW, int padH,
                        int dilationW, int dilationH, int group,
                        int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_forward_cuda(input, weight, offset, output, columns,
        ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
        deformable_group, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}
|
| 69 |
+
|
| 70 |
+
// Device dispatch for deformable-conv (v1) backward w.r.t. input and offsets.
// CUDA-only; raises via AT_ERROR on CPU tensors or non-CUDA builds.
int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
                               at::Tensor gradOutput, at::Tensor gradInput,
                               at::Tensor gradOffset, at::Tensor weight,
                               at::Tensor columns, int kW, int kH, int dW,
                               int dH, int padW, int padH, int dilationW,
                               int dilationH, int group,
                               int deformable_group, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_backward_input_cuda(input, offset, gradOutput,
        gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
        dilationW, dilationH, group, deformable_group, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}
|
| 88 |
+
|
| 89 |
+
// Device dispatch for deformable-conv (v1) backward w.r.t. the weights.
// `scale` is forwarded to the CUDA implementation (gradient scaling factor —
// semantics defined in deform_conv_cuda). CUDA-only; raises on CPU tensors.
int deform_conv_backward_parameters(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight, // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
        gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
        dilationH, group, deformable_group, scale, im2col_step);
#else
    AT_ERROR("deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("deform conv is not implemented on CPU");
}
|
| 106 |
+
|
| 107 |
+
// Device dispatch for modulated deformable conv (v2) forward.
// `ones` and `columns` are caller-provided workspace tensors; `with_bias`
// selects whether `bias` is applied. CUDA-only; raises on CPU tensors.
void modulated_deform_conv_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
        offset, mask, output, columns, kernel_h, kernel_w, stride_h,
        stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
        deformable_group, with_bias);
#else
    AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("modulated deform conv is not implemented on CPU");
}
|
| 126 |
+
|
| 127 |
+
// Device dispatch for modulated deformable conv (v2) backward.
// Fills grad_input / grad_weight / grad_bias / grad_offset / grad_mask from
// grad_output. CUDA-only; raises on CPU tensors.
void modulated_deform_conv_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  if (input.device().is_cuda()) {
#ifdef WITH_CUDA
    return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
        offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
        grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
        pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
        with_bias);
#else
    AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
  }
  AT_ERROR("modulated deform conv is not implemented on CPU");
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
// Python bindings: expose the dispatch wrappers above as the extension's API.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_conv_forward", &deform_conv_forward,
        "deform forward");
  m.def("deform_conv_backward_input", &deform_conv_backward_input,
        "deform_conv_backward_input");
  m.def("deform_conv_backward_parameters",
        &deform_conv_backward_parameters,
        "deform_conv_backward_parameters");
  m.def("modulated_deform_conv_forward",
        &modulated_deform_conv_forward,
        "modulated deform conv forward");
  m.def("modulated_deform_conv_backward",
        &modulated_deform_conv_backward,
        "modulated deform conv backward");
}
|
basicsr/ops/fused_act/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
| 2 |
+
|
| 3 |
+
__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
|
basicsr/ops/fused_act/fused_act.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.autograd import Function
|
| 6 |
+
|
| 7 |
+
# Load the compiled CUDA extension. Prefer the pre-built module shipped with
# the package; if it is absent, fall back to JIT compilation — but only when
# the user explicitly opts in via the BASICSR_JIT=True environment variable
# (JIT builds require a CUDA toolchain and take noticeable time).
try:
    from . import fused_act_ext
except ImportError:
    import os
    BASICSR_JIT = os.getenv('BASICSR_JIT')
    if BASICSR_JIT == 'True':
        from torch.utils.cpp_extension import load
        module_path = os.path.dirname(__file__)
        fused_act_ext = load(
            'fused',
            sources=[
                os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
                os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
            ],
        )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FusedLeakyReLUFunctionBackward(Function):
    """Backward pass of the fused bias + leaky ReLU, itself differentiable.

    Implemented as a separate autograd Function so that double-backward works:
    its own ``backward`` re-invokes the CUDA op in gradient mode.
    """

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        # `out` (the forward activation output) is the reference tensor the
        # CUDA op uses to decide which side of the leaky ReLU each element is on.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        # act=3 selects leaky ReLU, grad=1 selects the gradient branch
        # (see the switch in fused_bias_act_kernel.cu).
        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)

        # Bias gradient: reduce over every dimension except the channel dim (1).
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        grad_bias = grad_input.sum(dim).detach()

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        # Second-order gradient: same gradient-mode op applied to the incoming
        # grad-of-grad, with the bias grad-of-grad folded in as the bias input.
        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
                                                    ctx.scale)

        # No gradients for `out`, `negative_slope`, `scale`.
        return gradgrad_out, None, None, None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class FusedLeakyReLUFunction(Function):
    """Autograd wrapper for the fused (bias add + leaky ReLU + scale) CUDA op."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)
        # act=3 (leaky ReLU), grad=0 (forward branch); `empty` is the unused
        # reference tensor slot (see fused_bias_act_kernel.cu).
        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        # Only the output is needed for backward: the leaky-ReLU gradient is
        # determined by the sign of `out`.
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)

        # No gradients for `negative_slope` and `scale`.
        return grad_input, grad_bias, None, None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FusedLeakyReLU(nn.Module):
    """Leaky ReLU with a learnable per-channel bias, fused into one CUDA op.

    Args:
        channel (int): number of channels; one learnable bias value each.
        negative_slope (float): slope for negative inputs. Default: 0.2.
        scale (float): multiplier applied to the activation output.
            Default: sqrt(2).
    """

    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
        super().__init__()
        self.negative_slope = negative_slope
        self.scale = scale
        # Bias starts at zero so the module initially behaves like a plain
        # scaled leaky ReLU.
        self.bias = nn.Parameter(torch.zeros(channel))

    def forward(self, input):
        out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return out
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
    """Apply bias add + leaky ReLU + scaling in a single fused CUDA op.

    Args:
        input (Tensor): input features; CUDA tensor (enforced by the extension).
        bias (Tensor): per-channel bias, broadcast over the other dims.
        negative_slope (float): slope for negative inputs. Default: 0.2.
        scale (float): output multiplier. Default: sqrt(2).

    Returns:
        Tensor: ``leaky_relu(input + bias) * scale``, differentiable.
    """
    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
basicsr/ops/fused_act/src/fused_bias_act.cpp
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
|
| 2 |
+
#include <torch/extension.h>
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input,
|
| 6 |
+
const torch::Tensor& bias,
|
| 7 |
+
const torch::Tensor& refer,
|
| 8 |
+
int act, int grad, float alpha, float scale);
|
| 9 |
+
|
| 10 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
| 11 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 12 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
| 13 |
+
|
| 14 |
+
// Entry point exposed to Python: validates devices and forwards to the CUDA op.
// Only CUDA-ness is checked here; contiguity is not required because the op
// calls .contiguous() on its inputs (see fused_bias_act_kernel.cu).
// NOTE(review): `refer` and CHECK_INPUT are intentionally not checked here —
// refer may be an empty placeholder tensor.
torch::Tensor fused_bias_act(const torch::Tensor& input,
                             const torch::Tensor& bias,
                             const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
|
| 23 |
+
|
| 24 |
+
// Python binding: the extension exports a single function.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
|
basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
|
| 2 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
| 5 |
+
// To view a copy of this license, visit
|
| 6 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
| 7 |
+
|
| 8 |
+
#include <torch/types.h>
|
| 9 |
+
|
| 10 |
+
#include <ATen/ATen.h>
|
| 11 |
+
#include <ATen/AccumulateType.h>
|
| 12 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 13 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
| 14 |
+
|
| 15 |
+
#include <cuda.h>
|
| 16 |
+
#include <cuda_runtime.h>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
// CUDA kernel: for each element, optionally add a broadcast bias, apply the
// selected activation (or its derivative), and multiply by `scale`.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    // Each thread handles up to `loop_x` elements spaced blockDim.x apart.
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // Bias broadcasting: runs of `step_b` consecutive elements share
            // one of the `size_b` bias values (computed by the launcher).
            x += p_b[(xi / step_b) % size_b];
        }

        // Saved forward activation, only needed when computing gradients.
        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        // Dispatch key: tens digit = activation id, ones digit = derivative
        // order. act==1 -> linear, act==3 -> leaky ReLU.
        switch (act * 10 + grad) {
        default:
        case 10: y = x; break;     // linear, forward
        case 11: y = x; break;     // linear, 1st derivative (identity)
        case 12: y = 0.0; break;   // linear, 2nd derivative

        case 30: y = (x > 0.0) ? x : x * alpha; break;    // leaky ReLU, forward
        case 31: y = (ref > 0.0) ? x : x * alpha; break;  // 1st derivative: gate by saved forward value
        case 32: y = 0.0; break;                          // 2nd derivative
        }

        out[xi] = y * scale;
    }
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
// Host-side launcher: makes inputs contiguous, derives the bias broadcast
// stride, and launches fused_bias_act_kernel on the current CUDA stream.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    // Empty tensors act as "not provided".
    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    // Bias is applied along dim 1; elements that share one bias value are
    // step_b = prod(sizes of dims 2..) apart in flattened (contiguous) order.
    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    // 128 threads per block, each covering loop_x strided elements; grid is
    // the ceiling of size_x over the per-block workload.
    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
|
basicsr/ops/upfirdn2d/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .upfirdn2d import upfirdn2d
|
| 2 |
+
|
| 3 |
+
__all__ = ['upfirdn2d']
|
basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
|
| 2 |
+
#include <torch/extension.h>
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
| 6 |
+
int up_x, int up_y, int down_x, int down_y,
|
| 7 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
| 8 |
+
|
| 9 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
| 10 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 11 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
| 12 |
+
|
| 13 |
+
// Python-facing entry point: checks device placement, then forwards to the
// CUDA implementation. Padding values may be negative (see the kernels'
// clamped index math); contiguity is handled inside upfirdn2d_op.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
|
| 21 |
+
|
| 22 |
+
// Expose the operator to Python as `upfirdn2d`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
|
basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
|
| 2 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
| 5 |
+
// To view a copy of this license, visit
|
| 6 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
| 7 |
+
|
| 8 |
+
#include <torch/types.h>
|
| 9 |
+
|
| 10 |
+
#include <ATen/ATen.h>
|
| 11 |
+
#include <ATen/AccumulateType.h>
|
| 12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
| 13 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 14 |
+
|
| 15 |
+
#include <cuda.h>
|
| 16 |
+
#include <cuda_runtime.h>
|
| 17 |
+
|
| 18 |
+
// Integer division rounding toward negative infinity (C++ `/` truncates
// toward zero). Needed because padding offsets can make the numerator
// negative.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  // Truncation rounded toward zero, i.e. upward for a negative quotient:
  // step down one to get the floor.
  if (c * b > a) {
    c--;
  }

  return c;
}
|
| 27 |
+
|
| 28 |
+
// Geometry and launch parameters shared by the host launcher and the kernels.
struct UpFirDn2DKernelParams {
  int up_x;    // upsampling factor, x axis
  int up_y;    // upsampling factor, y axis
  int down_x;  // downsampling factor, x axis
  int down_y;  // downsampling factor, y axis
  int pad_x0;  // padding before / after each axis (negative values crop)
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;  // flattened leading dims (batch * channel in the Python wrapper)
  int in_h;
  int in_w;
  int minor_dim;  // flattened trailing dim (1 in the Python wrapper)
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;  // how many major indices each block iterates over
  int loop_x;      // how many x tiles/columns each block iterates over
};
|
| 49 |
+
|
| 50 |
+
// Generic fallback kernel for arbitrary up/down factors and kernel sizes.
// Each thread accumulates output pixels directly from global memory (no
// shared-memory tiling), looping over loop_x columns and loop_major majors.
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  // blockIdx.x/threadIdx.x jointly encode (out_y, minor index).
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  // Vertical input range [in_y, in_y+h) and the first kernel tap row that
  // contribute to this output row. mid_y is the position in the virtual
  // upsampled-and-padded signal; floor_div handles negative offsets.
  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      // Horizontal input range and first kernel tap column for this out_x.
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      // Pointer strides: input advances one pixel at a time; the kernel is
      // walked backwards in steps of the up factor (only every up-th tap
      // lines up with a real input sample).
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      // Accumulate the h x w input window against the matching kernel taps.
      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        // Rewind the row advance, then step to the next row.
        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}
|
| 107 |
+
|
| 108 |
+
// Tiled kernel for small, compile-time-known up/down factors and kernel
// sizes: each block stages the (flipped) kernel taps and an input tile in
// shared memory, then computes a tile_out_h x tile_out_w output tile.
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  // Input footprint required to produce one output tile.
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];    // kernel taps, flipped in both dims
  __shared__ volatile float sx[tile_in_h][tile_in_w];  // staged input tile

  // blockIdx.x encodes (output tile row, minor index).
  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  // Bitwise | is intentional here (branch-free combine of the bounds tests).
  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  // Cooperatively load the kernel, spatially flipped; taps beyond the
  // runtime kernel size are zero-filled (template sizes are upper bounds).
  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      // Tile origin in the virtual upsampled-and-padded signal and in the
      // actual input (floor_div handles negative offsets from padding).
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      // Protect sx against readers from the previous iteration.
      __syncthreads();

      // Stage the input tile; out-of-bounds pixels read as zero.
      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      // Each thread computes output pixels of the tile in a strided loop.
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        // First kernel tap aligned with (in_x, in_y) for this output's
        // upsampling phase.
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

        // Fully unrolled accumulation; tap strides of up_x/up_y select only
        // the taps that coincide with real input samples.
#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}
|
| 209 |
+
|
| 210 |
+
// Host-side launcher: derives the output geometry, selects a compile-time
// specialized tiled kernel for the common (up, down, kernel-size)
// combinations, and falls back to upfirdn2d_kernel_large otherwise.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  // Expected input layout: (major, in_h, in_w, minor). The Python wrapper
  // folds batch*channel into major and uses minor == 1.
  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  // Output size = floor((in*up + pad0 + pad1 - kernel) / down) + 1, written
  // as an integer division with the +down folded in.
  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  // mode indexes the template specializations below; -1 means no match and
  // routes to the generic large kernel. Later tests overwrite earlier ones,
  // so the tightest matching specialization wins.
  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  // 1: no resampling, kernel up to 4x4.
  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  // 2: no resampling, kernel up to 3x3.
  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  // 3: 2x upsampling, kernel up to 4x4.
  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  // 4: 2x upsampling, kernel up to 2x2.
  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  // 5: 2x downsampling, kernel up to 4x4.
  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  // 6: 2x downsampling, kernel up to 2x2.
  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    // Tiled path: 256 threads cooperate on one tile per block;
    // grid z is capped at 16384 majors via loop_major.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    // Generic path: threadIdx.x walks (out_y, minor), threadIdx.y walks
    // out_x, each y lane covering loop_x columns.
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      // NOTE(review): mode 6 reuses the 4x4-tap specialization even though
      // its kernels are <= 2x2; the unused taps are zero-filled in shared
      // memory, so the result is unchanged (just slightly more work).
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}
|
basicsr/ops/upfirdn2d/upfirdn2d.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.autograd import Function
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
# The compiled extension is normally built ahead of time (see basicsr/setup.py
# --cuda_ext). If that import fails, optionally JIT-compile it when the
# BASICSR_JIT environment variable is exactly 'True'; otherwise
# `upfirdn2d_ext` stays undefined and only the CPU (native) path works.
try:
    from . import upfirdn2d_ext
except ImportError:
    import os
    BASICSR_JIT = os.getenv('BASICSR_JIT')
    if BASICSR_JIT == 'True':
        from torch.utils.cpp_extension import load
        module_path = os.path.dirname(__file__)
        upfirdn2d_ext = load(
            'upfirdn2d',
            sources=[
                os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
                os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
            ],
        )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class UpFirDn2dBackward(Function):
    """Gradient of :class:`UpFirDn2d`.

    The backward pass of upfirdn2d is itself an upfirdn2d with the flipped
    kernel, swapped up/down factors and adjusted ("gradient") padding, so
    this Function reuses the same CUDA op — which also makes it
    differentiable again (double backward).
    """

    @staticmethod
    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
        # grad_kernel is the spatially flipped kernel saved by
        # UpFirDn2d.forward.

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # Fold batch*channel into the major dim; minor dim is 1, the layout
        # the CUDA op expects.
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        # Transposed op: flipped kernel, up/down swapped, gradient padding.
        grad_input = upfirdn2d_ext.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        # Stash the forward geometry for the double-backward pass.
        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        # Double backward: re-apply the original forward upfirdn2d.
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_ext.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
        #                                  ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])

        # Only grad_output gets a gradient; the other 8 forward args are
        # non-tensor parameters.
        return gradgrad_out, None, None, None, None, None, None, None, None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class UpFirDn2d(Function):
    """autograd Function wrapping the CUDA upfirdn2d op.

    Operates on (batch, channel, h, w) tensors; internally batch*channel is
    folded into a leading "major" dim with a trailing minor dim of 1, the
    layout expected by the extension.
    """

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        # Save the kernel and its 180-degree flip for the backward pass.
        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # "Gradient" padding that makes the transposed op in
        # UpFirDn2dBackward line up with this forward geometry.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        # Gradient flows to `input` only; kernel/up/down/pad get None.
        return grad_input, None, None, None, None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, apply an FIR filter, and downsample a batch of 2D maps.

    Args:
        input (Tensor): input of shape (batch, channel, height, width).
        kernel (Tensor): 2D FIR filter.
        up (int): upsampling factor applied to both axes. Default: 1.
        down (int): downsampling factor applied to both axes. Default: 1.
        pad (tuple[int]): either ``(pad0, pad1)``, applied symmetrically to
            both axes (previous behavior), or the full
            ``(pad_x0, pad_x1, pad_y0, pad_y1)``. Default: (0, 0).

    Returns:
        Tensor: filtered tensor of shape (batch, channel, out_h, out_w).
    """
    # Generalization: also accept the full 4-tuple padding. A 2-tuple keeps
    # the original symmetric expansion, so existing callers are unaffected.
    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == 'cpu':
        # Pure-PyTorch fallback (same math, no compiled extension).
        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[2], pad[3])
    else:
        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), pad)

    return out
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Pure-PyTorch reference implementation of upfirdn2d (CPU fallback).

    Upsamples by zero-stuffing, pads (negative padding crops), correlates
    with the kernel via a stride-1 conv2d, then downsamples by slicing.
    """
    _, channel, in_h, in_w = input.shape
    x = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = x.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample: insert (up - 1) zeros between neighboring samples per axis.
    x = x.view(-1, in_h, 1, in_w, 1, minor)
    x = F.pad(x, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    x = x.view(-1, in_h * up_y, in_w * up_x, minor)

    # Apply the positive part of the padding, then crop the negative part.
    x = F.pad(x, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    x = x[:, max(-pad_y0, 0):x.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):x.shape[2] - max(-pad_x1, 0), :]

    # Correlate with the kernel: conv2d computes cross-correlation, so the
    # kernel is flipped first to get a true convolution.
    x = x.permute(0, 3, 1, 2)
    x = x.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    x = F.conv2d(x, w)
    x = x.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    x = x.permute(0, 2, 3, 1)

    # Downsample: keep every (down_y, down_x)-th sample.
    x = x[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return x.view(-1, channel, out_h, out_w)
|
basicsr/setup.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
from setuptools import find_packages, setup
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
|
| 11 |
+
|
| 12 |
+
version_file = './basicsr/version.py'
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def readme():
    """Return the project README as a UTF-8 string (used as long_description)."""
    with open('README.md', encoding='utf-8') as f:
        return f.read()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_git_hash():
    """Return the current git commit sha, or 'unknown' if git is unavailable."""

    def _minimal_ext_cmd(cmd):
        # Run `cmd` with a minimal, locale-neutral environment so the output
        # does not depend on the user's shell configuration.
        env = {k: os.environ[k] for k in ('SYSTEMROOT', 'PATH', 'HOME') if k in os.environ}
        # LANGUAGE is used on win32
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        return subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]

    try:
        sha = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']).strip().decode('ascii')
    except OSError:
        # git missing or not runnable.
        sha = 'unknown'

    return sha
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_hash():
    """Best-effort short sha: from git, else from an existing version file."""
    if os.path.exists('.git'):
        return get_git_hash()[:7]
    if os.path.exists(version_file):
        try:
            from version import __version__
        except ImportError:
            raise ImportError('Unable to get git version')
        # The generated version string is '<short>+<sha>'.
        return __version__.split('+')[-1]
    return 'unknown'
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def write_version_py():
    """Regenerate basicsr/version.py from the VERSION file plus the git sha."""
    template = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('./basicsr/VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    # Numeric components stay bare ints; anything else (e.g. 'rc1') is quoted.
    VERSION_INFO = ', '.join(x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.'))

    with open(version_file, 'w') as f:
        f.write(template.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def get_version():
    """Read ``__version__`` from the generated version file.

    The file is executed into an explicit namespace dict. The original
    relied on ``exec`` mutating the function's ``locals()``, which is
    undefined behaviour in Python 3 and only happens to work on CPython.

    Returns:
        str: The ``__version__`` string defined in the version file.
    """
    namespace = {}
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Build a C++/CUDA extension spec for setuptools.

    Args:
        name (str): Extension name (last component of the dotted path).
        module (str): Dotted package path the sources live under.
        sources (list[str]): C++ sources, relative to ``module``.
        sources_cuda (list[str], optional): CUDA-only sources, included
            only when CUDA is available or ``FORCE_CUDA=1``. Default: None.

    Returns:
        setuptools.Extension: a ``CUDAExtension`` when CUDA is usable,
        otherwise a ``CppExtension``.
    """
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        # Fix: build a new list instead of `sources += sources_cuda`,
        # which mutated the caller's list in place.
        sources = sources + sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_requirements(filename='requirements.txt'):
    """Read dependency specifiers from *filename*, one per line."""
    with open(os.path.join('.', filename), 'r') as f:
        return [line.rstrip('\n') for line in f]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if __name__ == '__main__':
    # Optional CUDA extensions: enabled with the custom `--cuda_ext` flag,
    # which must be stripped from argv before setuptools parses arguments.
    if '--cuda_ext' in sys.argv:
        ext_modules = [
            make_cuda_ext(
                name='deform_conv_ext',
                module='ops.dcn',
                sources=['src/deform_conv_ext.cpp'],
                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            make_cuda_ext(
                name='fused_act_ext',
                module='ops.fused_act',
                sources=['src/fused_bias_act.cpp'],
                sources_cuda=['src/fused_bias_act_kernel.cu']),
            make_cuda_ext(
                name='upfirdn2d_ext',
                module='ops.upfirdn2d',
                sources=['src/upfirdn2d.cpp'],
                sources_cuda=['src/upfirdn2d_kernel.cu']),
        ]
        sys.argv.remove('--cuda_ext')
    else:
        ext_modules = []

    # Regenerate the version module before packaging so the installed
    # package records the current version and git sha.
    write_version_py()
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        include_package_data=True,
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        # BuildExtension handles nvcc/cxx flag dispatch for the CUDA ops.
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)
|
basicsr/train.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
import platform
|
| 9 |
+
from os import path as osp
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
from basicsr.data import build_dataloader, build_dataset
|
| 13 |
+
from basicsr.data.data_sampler import EnlargedSampler
|
| 14 |
+
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
|
| 15 |
+
from basicsr.models import build_model
|
| 16 |
+
from basicsr.utils import (
|
| 17 |
+
MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
|
| 18 |
+
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed
|
| 19 |
+
)
|
| 20 |
+
from basicsr.utils.dist_util import get_dist_info, init_dist
|
| 21 |
+
from basicsr.utils.options import dict2str, parse
|
| 22 |
+
|
| 23 |
+
# ----------- DEVICE SELECTION ----------
|
| 24 |
+
def select_device(prefer_coreml=True):
    """Pick the best available torch device: MPS (macOS) > CUDA > CPU.

    Args:
        prefer_coreml (bool): Prefer Apple's MPS backend when on macOS.
            Default: True.

    Returns:
        torch.device: The selected device.
    """
    mps_wanted = (torch.backends.mps.is_available() and prefer_coreml
                  and platform.system() == "Darwin")
    if mps_wanted:
        print("BasicSR: Using CoreML backend (MPS).")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("BasicSR: Using CUDA backend.")
        return torch.device("cuda")
    print("BasicSR: Using CPU backend.")
    return torch.device("cpu")
|
| 34 |
+
|
| 35 |
+
# Module-level device chosen once at import time; the rest of this training
# script reads DEVICE instead of re-probing the backends.
DEVICE = select_device(prefer_coreml=True)

# Silence the "lr_scheduler.step() before optimizer.step()" UserWarning
# PyTorch emits during warmup/resumed training.
warnings.filterwarnings("ignore", category=UserWarning)
|
| 39 |
+
|
| 40 |
+
def parse_options(root_path, is_train=True):
    """Parse CLI arguments and the option YAML into one config dict.

    Args:
        root_path (str): Repository root used by `parse` to resolve paths.
        is_train (bool): Whether options are parsed for training.
            Default: True.

    Returns:
        dict: Options with 'dist', 'rank', 'world_size' and 'manual_seed'
        filled in. Side effects: may initialize torch.distributed and
        seeds the RNGs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = parse(args.opt, root_path, is_train=is_train)

    # distributed settings
    # Distributed training is only supported on CUDA; force it off on
    # CPU/MPS regardless of the launcher flag.
    if args.launcher == 'none' or DEVICE.type != 'cuda':
        opt['dist'] = False
        print('Distributed training disabled.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed: generate one if the YAML did not pin it.
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    # Offset by rank so each process gets a distinct but reproducible seed.
    set_random_seed(seed + opt['rank'])

    return opt
|
| 69 |
+
|
| 70 |
+
def init_loggers(opt):
    """Create the root file logger and optional tensorboard/wandb loggers.

    Args:
        opt (dict): Full config; reads opt['path']['log'], opt['name'] and
            the opt['logger'] section.

    Returns:
        tuple: (logger, tb_logger); tb_logger is None when tensorboard
        logging is disabled.
    """
    log_path = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_path)
    logger.info(get_env_info())
    logger.info(dict2str(opt))

    wandb_cfg = opt['logger'].get('wandb')
    if wandb_cfg is not None and wandb_cfg.get('project') is not None:
        # wandb mirrors the tensorboard stream, so tb must be enabled too.
        assert opt['logger'].get('use_tb_logger') is True
        init_wandb_logger(opt)

    tb_logger = None
    if opt['logger'].get('use_tb_logger'):
        tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
    return logger, tb_logger
|
| 84 |
+
|
| 85 |
+
def create_train_val_dataloader(opt, logger):
    """Build the train and val dataloaders described in ``opt['datasets']``.

    Args:
        opt (dict): Full config; reads 'datasets', 'world_size', 'rank',
            'num_gpu', 'dist', 'manual_seed' and 'train.total_iter'.
        logger (logging.Logger): Logger for dataset statistics.

    Returns:
        tuple: (train_loader, train_sampler, val_loader, total_epochs,
        total_iters). Entries are None / 0 when the corresponding phase is
        absent from the config.

    Raises:
        ValueError: If a dataset phase other than 'train'/'val' appears.
    """
    # Initialize all outputs up front so a config missing the 'train' (or
    # 'val') phase returns None/0 instead of raising UnboundLocalError at
    # the return statement (the original only pre-set two of them).
    train_loader, train_sampler, val_loader = None, None, None
    total_epochs, total_iters = 0, 0
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
            train_set = build_dataset(dataset_opt)
            train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
            train_loader = build_dataloader(train_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=train_sampler, seed=opt['manual_seed'])
            # Iterations per epoch after the enlarge ratio, across all GPUs.
            num_iter_per_epoch = math.ceil(len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
            total_iters = int(opt['train']['total_iter'])
            total_epochs = math.ceil(total_iters / num_iter_per_epoch)
            logger.info(f'Training stats:\n\tTrain images: {len(train_set)}\n\tEnlarge ratio: {dataset_enlarge_ratio}\n\tBatch/GPU: {dataset_opt["batch_size_per_gpu"]}\n\tGPUs: {opt["world_size"]}\n\tIters/epoch: {num_iter_per_epoch}\n\tTotal epochs: {total_epochs}, Iters: {total_iters}')

        elif phase == 'val':
            val_set = build_dataset(dataset_opt)
            val_loader = build_dataloader(val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
            logger.info(f'Validation items in {dataset_opt["name"]}: {len(val_set)}')
        else:
            raise ValueError(f'Dataset phase {phase} not recognized.')

    return train_loader, train_sampler, val_loader, total_epochs, total_iters
|
| 106 |
+
|
| 107 |
+
def train_pipeline(root_path):
    """Full training entry point: config, loggers, data, model, train loop.

    Args:
        root_path (str): Repository root passed through to option parsing.
    """
    opt = parse_options(root_path, is_train=True)

    if DEVICE.type == 'cuda':
        # Fixed input sizes benefit from cudnn autotuning.
        torch.backends.cudnn.benchmark = True

    # Load the resume state (optimizer/scheduler/iter counters) if given.
    if opt['path'].get('resume_state'):
        resume_state = torch.load(opt['path']['resume_state'], map_location=DEVICE)
    else:
        resume_state = None

    if resume_state is None:
        # Fresh run: create experiment dirs; archive any stale tb_logger
        # dir (rank 0 only) so old event files don't mix with the new run.
        make_exp_dirs(opt)
        if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
            mkdir_and_rename(osp.join('tb_logger', opt['name']))

    logger, tb_logger = init_loggers(opt)
    train_loader, train_sampler, val_loader, total_epochs, total_iters = create_train_val_dataloader(opt, logger)

    if resume_state:
        check_resume(opt, resume_state['iter'])
        model = build_model(opt).to(DEVICE)
        model.resume_training(resume_state)
        logger.info(f"Resuming from epoch {resume_state['epoch']}, iter {resume_state['iter']}")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']
    else:
        model = build_model(opt).to(DEVICE)
        start_epoch = 0
        current_iter = 0

    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    # Prefetcher selection: CUDA prefetch only works on a CUDA device;
    # CPU/MPS fall back to the CPU prefetcher with a warning.
    prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
    if prefetch_mode is None or prefetch_mode == 'cpu' or DEVICE.type in ['cpu', 'mps']:
        if prefetch_mode == 'cuda' and DEVICE.type == 'mps':
            logger.warning("CUDA prefetch requested but MPS (CoreML) is in use. Falling back to CPU prefetch.")
        prefetcher = CPUPrefetcher(train_loader)
    elif prefetch_mode == 'cuda':
        if DEVICE.type != 'cuda':
            logger.warning("CUDA prefetch requested but CUDA unavailable. Using CPU prefetch.")
            prefetcher = CPUPrefetcher(train_loader)
        else:
            # CUDAPrefetcher overlaps H2D copies with compute; this
            # requires pinned host memory.
            if opt['datasets']['train'].get('pin_memory') is not True:
                raise ValueError('Set pin_memory=True for CUDAPrefetcher.')
            prefetcher = CUDAPrefetcher(train_loader, opt)
            logger.info(f'Using CUDA prefetcher')
    else:
        raise ValueError(f"Invalid prefetch_mode: {prefetch_mode}. Supported: 'cpu', 'cuda', None")

    logger.info(f'Start training at epoch {start_epoch}, iter {current_iter + 1}')
    start_time = time.time()
    # data_time / iter_time hold timestamps between iterations and elapsed
    # durations inside one; they are re-stamped at the end of each step.
    data_time, iter_time = time.time(), time.time()

    for epoch in range(start_epoch, total_epochs + 1):
        # Re-seed the sampler so shuffling differs per epoch under DDP.
        train_sampler.set_epoch(epoch)
        prefetcher.reset()
        train_data = prefetcher.next()

        while train_data is not None:
            data_time = time.time() - data_time
            current_iter += 1
            if current_iter > total_iters:
                break

            model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
            model.feed_data(train_data)
            model.optimize_parameters(current_iter)

            iter_time = time.time() - iter_time
            if current_iter % opt['logger']['print_freq'] == 0:
                log_vars = {'epoch': epoch, 'iter': current_iter}
                log_vars.update({'lrs': model.get_current_learning_rate()})
                log_vars.update({'time': iter_time, 'data_time': data_time})
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving model and training state.')
                model.save(epoch, current_iter)

            # Periodic validation, only when a val section and dataset exist.
            if opt.get('val') and opt['datasets'].get('val') and (current_iter % opt['val']['val_freq'] == 0):
                model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])

            data_time = time.time()
            iter_time = time.time()
            train_data = prefetcher.next()

    consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    logger.info(f'Training complete. Time: {consumed_time}')
    logger.info('Saving latest model.')
    # epoch/current_iter of -1 marks the "latest" checkpoint.
    model.save(epoch=-1, current_iter=-1)

    if opt.get('val') and opt['datasets'].get('val'):
        model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
    if tb_logger:
        tb_logger.close()
|
| 204 |
+
|
| 205 |
+
if __name__ == '__main__':
    # Repository root = two levels above this file (basicsr/train.py -> repo/).
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
|
basicsr/utils/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .file_client import FileClient
|
| 2 |
+
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
|
| 3 |
+
from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
|
| 4 |
+
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
# file_client.py
|
| 8 |
+
'FileClient',
|
| 9 |
+
# img_util.py
|
| 10 |
+
'img2tensor',
|
| 11 |
+
'tensor2img',
|
| 12 |
+
'imfrombytes',
|
| 13 |
+
'imwrite',
|
| 14 |
+
'crop_border',
|
| 15 |
+
# logger.py
|
| 16 |
+
'MessageLogger',
|
| 17 |
+
'init_tb_logger',
|
| 18 |
+
'init_wandb_logger',
|
| 19 |
+
'get_root_logger',
|
| 20 |
+
'get_env_info',
|
| 21 |
+
# misc.py
|
| 22 |
+
'set_random_seed',
|
| 23 |
+
'get_time_str',
|
| 24 |
+
'mkdir_and_rename',
|
| 25 |
+
'make_exp_dirs',
|
| 26 |
+
'scandir',
|
| 27 |
+
'check_resume',
|
| 28 |
+
'sizeof_fmt'
|
| 29 |
+
]
|
basicsr/utils/dist_util.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
import torch.multiprocessing as mp
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize torch.distributed for the given launcher.

    Args:
        launcher (str): 'pytorch' or 'slurm'.
        backend (str): torch.distributed backend. Default: 'nccl'.
        **kwargs: Forwarded to the launcher-specific initializer.

    Raises:
        ValueError: If the launcher is not recognized.
    """
    # CUDA does not survive fork(); make worker processes use spawn.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _init_dist_pytorch(backend, **kwargs):
    """Initialize distributed training from torch launcher env vars."""
    global_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    # Round-robin ranks onto the GPUs visible on this node.
    torch.cuda.set_device(global_rank % gpu_count)
    dist.init_process_group(backend=backend, **kwargs)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _init_dist_slurm(backend, port=None):
    """Initialize the distributed environment under SLURM.

    The master address is resolved from the SLURM node list. The master
    port comes from, in priority order: the ``port`` argument, an existing
    ``MASTER_PORT`` environment variable, or the torch default 29500.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # First hostname in the allocation acts as the rendezvous master.
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    else:
        # Keep an externally supplied MASTER_PORT; otherwise use the
        # torch.distributed default port.
        os.environ.setdefault('MASTER_PORT', '29500')
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_dist_info():
    """Return ``(rank, world_size)``; ``(0, 1)`` outside distributed runs."""
    if not (dist.is_available() and dist.is_initialized()):
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def master_only(func):
    """Decorator: run *func* only on rank 0; other ranks return None."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
basicsr/utils/download_util.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import requests
|
| 4 |
+
from torch.hub import download_url_to_file, get_dir
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from urllib.parse import urlparse
|
| 7 |
+
|
| 8 |
+
from .misc import sizeof_fmt
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def download_file_from_google_drive(file_id, save_path):
    """Download a file from Google Drive, handling the confirm-token step.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501

    Args:
        file_id (str): Google Drive file id.
        save_path (str): Destination path on disk.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    # Large files trigger a "can't virus scan" interstitial; the cookie's
    # confirm token lets us bypass it with a second request.
    token = get_confirm_token(response)
    if token:
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # get file size
    # Probe with a Range request; servers honouring it reply with a
    # Content-Range header carrying "<range>/<total>".
    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    print(response_file_size)
    if 'Content-Range' in response_file_size.headers:
        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
    else:
        # No Content-Range: size unknown, progress bar will be disabled.
        file_size = None

    save_response_content(response, save_path, file_size)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_confirm_token(response):
    """Return Google Drive's download-warning cookie value, if present."""
    return next(
        (value for key, value in response.cookies.items()
         if key.startswith('download_warning')),
        None,
    )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream an HTTP response body to *destination* on disk.

    Args:
        response: A ``requests`` response opened with ``stream=True``.
        destination (str): Output file path.
        file_size (int, optional): Total size in bytes; used only for the
            progress bar. Default: None (no progress bar).
        chunk_size (int): Read granularity in bytes. Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            # Fix: count bytes actually received (len(chunk)), not the
            # nominal chunk_size — the last or a keep-alive chunk can be
            # shorter, which previously inflated the reported progress.
            downloaded_size += len(chunk)
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
    if pbar is not None:
        pbar.close()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return a local path for *url*, downloading it if not already cached.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory for the cached file. If None, the
            pytorch hub checkpoints directory is used. Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Override for the file name derived from the URL.
            Default: None.

    Returns:
        str: Absolute path of the (possibly just downloaded) file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    if file_name is not None:
        filename = file_name
    else:
        filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    # Only hit the network when the file is not already cached.
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
basicsr/utils/file_client.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
| 2 |
+
from abc import ABCMeta, abstractmethod
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the
    file as texts.
    """

    @abstractmethod
    def get(self, filepath):
        # Return the file contents at `filepath` as bytes.
        pass

    @abstractmethod
    def get_text(self, filepath):
        # Return the file contents at `filepath` as str.
        pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MemcachedBackend(BaseStorageBackend):
    """Storage backend that fetches file bytes from a memcached cluster.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Extra path appended to ``sys.path`` so the
            ``mc`` module can be found. Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector is a handle pointing into the memcached buffer.
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch *filepath* from memcached and return it as bytes."""
        import mc
        self._client.Get(str(filepath), self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        # Text reads are not supported by this backend.
        raise NotImplementedError
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class HardDiskBackend(BaseStorageBackend):
    """Backend that reads files directly from the local filesystem."""

    def get(self, filepath):
        """Return the raw bytes stored at *filepath*."""
        with open(str(filepath), 'rb') as f:
            return f.read()

    def get_text(self, filepath):
        """Return the text contents of *filepath*."""
        with open(str(filepath), 'r') as f:
            return f.read()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (list): A list of several lmdb envs.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        # Normalize both arguments to lists so they can be zipped below.
        if isinstance(client_keys, str):
            client_keys = [client_keys]
        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        # One lmdb environment per client key.
        self._client = {}
        for key, path in zip(client_keys, self.db_paths):
            self._client[key] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing differnet lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
        env = self._client[client_key]
        with env.begin(write=False) as txn:
            return txn.get(filepath.encode('ascii'))

    def get_text(self, filepath):
        # Text reads are not supported by this backend.
        raise NotImplementedError
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class FileClient(object):
    """A general file client to access files in different backend.

    The client loads a file or text in a specified backend from its path
    and return it as a binary file. it can also register other backend
    accessor with a given name and backend class.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        """Instantiate the backend named *backend* with **kwargs."""
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        """Read *filepath* as bytes from the configured backend."""
        # client_key is used only for lmdb, where different fileclients have
        # different lmdb environments.
        if self.backend != 'lmdb':
            return self.client.get(filepath)
        return self.client.get(filepath, client_key)

    def get_text(self, filepath):
        """Read *filepath* as text from the configured backend."""
        return self.client.get_text(filepath)
|
basicsr/utils/img_util.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from torchvision.utils import make_grid
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy image(s) (HWC) to torch tensor(s) (CHW).

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. A single input yields a single
            tensor; a list input yields a list of tensors.
    """

    def _convert(im):
        # channel swap applies only to 3-channel images with bgr2rgb enabled
        if im.shape[2] == 3 and bgr2rgb:
            if im.dtype == 'float64':
                # cv2.cvtColor rejects float64 input
                im = im.astype('float32')
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(im.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if isinstance(imgs, list):
        return [_convert(im) for im in imgs]
    return _convert(imgs)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
            shape (H x W). The channel order is BGR.
    """
    is_single = torch.is_tensor(tensor)
    if not (is_single or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    tensors = [tensor] if is_single else tensor
    images = [_tensor2img_single(t, rgb2bgr, out_type, min_max) for t in tensors]
    # a single output (single tensor in, or one-element list) is unwrapped
    return images[0] if len(images) == 1 else images


def _tensor2img_single(t, rgb2bgr, out_type, min_max):
    """Convert one tensor; see `tensor2img` for the contract."""
    # NOTE: clamp_ is in-place on a detached view, matching the original
    # implementation's behaviour for tensors already float32 on CPU.
    t = t.squeeze(0).float().detach().cpu().clamp_(*min_max)
    t = (t - min_max[0]) / (min_max[1] - min_max[0])

    ndim = t.dim()
    if ndim == 4:
        grid = make_grid(t, nrow=int(math.sqrt(t.size(0))), normalize=False)
        img = grid.numpy().transpose(1, 2, 0)
        if rgb2bgr:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    elif ndim == 3:
        img = t.numpy().transpose(1, 2, 0)
        if img.shape[2] == 1:  # gray image: drop the channel axis
            img = np.squeeze(img, axis=2)
        elif rgb2bgr:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    elif ndim == 2:
        img = t.numpy()
    else:
        raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {ndim}')
    if out_type == np.uint8:
        # Unlike MATLAB, numpy uint8 casting does NOT round; round explicitly.
        img = (img * 255.0).round()
    return img.astype(out_type)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """Slightly faster variant of tensor2img for a (1, c, h, w) tensor.

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.
    """
    lo, hi = min_max
    img = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
    img = (img - lo) / (hi - lo) * 255
    # uint8 cast truncates (no rounding), same as the original implementation
    img = img.type(torch.uint8).cpu().numpy()
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if rgb2bgr else img
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def imfrombytes(content, flag='color', float32=False):
    """Decode an image from raw bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, will also norm
            to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    flag_map = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    buf = np.frombuffer(content, np.uint8)
    img = cv2.imdecode(buf, flag_map[flag])
    if float32:
        img = img.astype(np.float32) / 255.
    return img
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write an image to disk.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        # create any missing parent directories before writing
        target_dir = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(target_dir, exist_ok=True)
    return cv2.imwrite(file_path, img, params)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def crop_border(imgs, crop_border):
    """Crop `crop_border` pixels from each edge of the image(s).

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and width.

    Returns:
        list[ndarray] | ndarray: Cropped images (input is returned unchanged
            when crop_border is 0).
    """
    if crop_border == 0:
        return imgs
    window = slice(crop_border, -crop_border)
    if isinstance(imgs, list):
        return [im[window, window, ...] for im in imgs]
    return imgs[window, window, ...]
|
basicsr/utils/lmdb_util.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import lmdb
|
| 3 |
+
import sys
|
| 4 |
+
from multiprocessing import Pool
|
| 5 |
+
from os import path as osp
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
    example.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1)image name (with extension),
    2)image shape, and 3)compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path. Must end with '.lmdb' and must not
            already exist (the process exits otherwise).
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar.

            `arg` is the (key, img_byte, shape) tuple returned by
            read_img_worker; results arrive in completion order, hence
            the key-indexed dicts above.
            """
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image, then extrapolate with a 10x margin
        # (PNG sizes vary between images, so the estimate must be generous)
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    txn = env.begin(write=True)
    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
        pbar.update(1)
        pbar.set_description(f'Write {key}')
        key_byte = key.encode('ascii')
        if multiprocessing_read:
            # images were pre-read into memory above
            img_byte = dataset[key]
            h, w, c = shapes[key]
        else:
            # lazy path: read and encode each image on demand
            _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
            h, w, c = img_shape

        txn.put(key_byte, img_byte)
        # write meta information
        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
        if idx % batch == 0:
            # commit in batches to bound memory held by the transaction
            txn.commit()
            txn = env.begin(write=True)
    pbar.close()
    txn.commit()
    env.close()
    txt_file.close()
    print('\nFinish writing lmdb.')
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def read_img_worker(path, key, compress_level):
    """Read one image and re-encode it as PNG bytes.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape (h, w, c).
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # grayscale images come back 2-D; report them as single-channel
    if img.ndim == 2:
        height, width = img.shape
        channels = 1
    else:
        height, width, channels = img.shape
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
    return (key, img_byte, (height, width, channels))
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class LmdbMaker():
    """Incremental LMDB dataset writer.

    Opens a fresh lmdb environment plus a `meta_info.txt` side file, accepts
    pre-encoded image bytes through `put`, and commits in batches.

    Args:
        lmdb_path (str): Lmdb save path. Must end with '.lmdb' and must not
            already exist (the process exits otherwise).
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        """Store one encoded image under `key` and record its meta line."""
        self.counter += 1
        self.txn.put(key.encode('ascii'), img_byte)
        # write meta information
        height, width, channels = img_shape
        self.txt_file.write(f'{key}.png ({height},{width},{channels}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            # periodic commit keeps transaction memory bounded
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Flush the pending transaction and release all resources."""
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
|
basicsr/utils/logger.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import logging
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
from .dist_util import get_dist_info, master_only
|
| 6 |
+
|
| 7 |
+
initialized_logger = {}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MessageLogger():
    """Message logger for printing.

    Formats periodic training-progress messages (epoch/iter/lr, timing with
    an ETA estimate, and loss values) and optionally mirrors scalar values
    to a tensorboard logger.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval.
            train (dict): Contains 'total_iter' (int) for total iters.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        # print interval in iterations (stored but consumed by the caller)
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        # wall-clock reference for the ETA computed in __call__
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Format logging message.

        NOTE: `log_vars` is consumed destructively (`pop`); after the
        epoch/iter/lrs/time keys are removed, everything remaining is
        treated as a loss/metric value.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.
                time (float): Iter time.
                data_time (float): Data time for each iter.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
        for v in lrs:
            message += f'{v:.3e},'
        message += ')] '

        # time and estimated time
        if 'time' in log_vars.keys():
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            # ETA is extrapolated from the average seconds-per-iter so far
            total_time = time.time() - self.start_time
            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '
            # tensorboard logger (note: runs inside the loop, once per key;
            # keys beginning with 'l_' are grouped under 'losses/')
            if self.use_tb_logger:
                if k.startswith('l_'):
                    self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
                else:
                    self.tb_logger.add_scalar(k, v, current_iter)
        self.logger.info(message)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@master_only
def init_tb_logger(log_dir):
    """Create a tensorboard SummaryWriter that logs to `log_dir`."""
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log."""
    import wandb
    logger = logging.getLogger('basicsr')

    wandb_cfg = opt['logger']['wandb']
    project = wandb_cfg['project']
    resume_id = wandb_cfg.get('resume_id')
    if resume_id:
        # continue a previous wandb run
        wandb_id, resume = resume_id, 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id, resume = wandb.util.generate_id(), 'never'

    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default
    a StreamHandler is attached; when `log_file` is given (and this is the
    rank-0 process), a FileHandler is attached as well.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    if logger_name in initialized_logger:
        # already configured in this process: hand back the same logger
        return logger

    formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)
    logger.propagate = False

    rank, _ = get_dist_info()
    if rank != 0:
        # non-master ranks only report errors
        logger.setLevel('ERROR')
    elif log_file is not None:
        logger.setLevel(log_level)
        # open in append mode so logs from previous runs are kept
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    initialized_logger[logger_name] = True
    return logger
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_env_info():
    """Get environment information.
    Currently, only log the software version.

    Returns:
        str: An ASCII-art banner followed by BasicSR / PyTorch / TorchVision
            version lines.
    """
    # imported lazily so this module stays importable without torch installed
    import torch
    import torchvision

    from basicsr.version import __version__
    # NOTE(review): internal whitespace of this banner reconstructed from a
    # whitespace-stripped source view — verify against the original file.
    msg = r"""
                ____                _       _____  ____
               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
            /_____/ \__,_//____//_/ \___//____//_/ |_|
     ______                   __   __                 __      __
    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
    """
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg
|
basicsr/utils/matlab_functions.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def cubic(x):
    """Piecewise-cubic interpolation kernel used by calculate_weights_indices.

    Evaluates the bicubic convolution kernel elementwise: a cubic polynomial
    on |x| <= 1, another on 1 < |x| <= 2, and zero elsewhere.
    """
    ax = torch.abs(x)
    ax2 = ax * ax
    ax3 = ax2 * ax
    # contribution for |x| <= 1
    near = (1.5 * ax3 - 2.5 * ax2 + 1) * (ax <= 1).type_as(ax)
    # contribution for 1 < |x| <= 2 (zero outside that band)
    far = (-0.5 * ax3 + 2.5 * ax2 - 4 * ax + 2) * ((ax > 1) * (ax <= 2)).type_as(ax)
    return near + far
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel name; currently unused by this implementation
            (the cubic kernel is always applied) — kept for API symmetry
            with MATLAB's imresize.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: (weights, indices, sym_len_s, sym_len_e) where `weights` and
            `indices` are (out_length, p) tensors and the two ints give the
            symmetric padding needed at the start/end of the input.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel (rescaled when antialiasing a downscale)
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # amount of symmetric padding required so every index is in range,
    # then shift indices to address the padded input
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    # Accept a numpy HWC array by converting it to a CHW float tensor;
    # remember the input kind so the result can be converted back at the end.
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    # For each output row/column, calculate_weights_indices returns the cubic
    # interpolation weights, the input indices they apply to, and how much
    # symmetric padding is needed at the start/end of that dimension.
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying: pad the height dimension by mirroring border rows,
    # which reproduces MATLAB's symmetric boundary handling.
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    # Mirror the top rows into the leading pad region.
    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    # Mirror the bottom rows into the trailing pad region.
    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    # Vertical pass: each output row is a weighted sum (mv) of kernel_width
    # consecutive padded input rows. indices_h[i][0] is the first row index
    # for output row i (rows in a window are contiguous).
    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying (same mirroring scheme, now along the width axis).
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    # Horizontal pass: each output column is a weighted sum of kernel_width
    # consecutive padded columns of the vertically-resized image.
    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    # Convert back to a numpy HWC array if that is what the caller passed in.
    if numpy_type:
        out_2 = out_2.numpy().transpose(1, 2, 0)
    return out_2
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def rgb2ycbcr(img, y_only=False):
    """Convert an RGB image to YCbCr (ITU-R BT.601, Matlab-compatible).

    Produces the same results as Matlab's `rgb2ycbcr`. It differs from
    cv2.cvtColor's `RGB <-> YCrCb`, which implements the JPEG conversion.
    See https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion and
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): Input image, either np.uint8 with range [0, 255]
            or np.float32 with range [0, 1].
        y_only (bool): Whether to only return the Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image with the same type and range
            as the input image.
    """
    input_dtype = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        converted = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        weight = [[65.481, -37.797, 112.0],
                  [128.553, -74.203, -93.786],
                  [24.966, 112.0, -18.214]]
        converted = np.matmul(img, weight) + [16, 128, 128]
    return _convert_output_type_range(converted, input_dtype)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr (ITU-R BT.601, Matlab-compatible).

    The BGR counterpart of rgb2ycbcr. It differs from cv2.cvtColor's
    `BGR <-> YCrCb`, which implements the JPEG conversion.
    See https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion and
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): Input image, either np.uint8 with range [0, 255]
            or np.float32 with range [0, 1].
        y_only (bool): Whether to only return the Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image with the same type and range
            as the input image.
    """
    input_dtype = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        converted = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        weight = [[24.966, 112.0, -18.214],
                  [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]
        converted = np.matmul(img, weight) + [16, 128, 128]
    return _convert_output_type_range(converted, input_dtype)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB (ITU-R BT.601, Matlab-compatible).

    Produces the same results as Matlab's `ycbcr2rgb`. It differs from
    cv2.cvtColor's `YCrCb <-> RGB`, which implements the JPEG conversion.
    See https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion and
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): Input image, either np.uint8 with range [0, 255]
            or np.float32 with range [0, 1].

    Returns:
        ndarray: The converted RGB image with the same type and range
            as the input image.
    """
    input_dtype = img.dtype
    # The inverse transform operates on the [0, 255] range.
    img = _convert_input_type_range(img) * 255
    transform = [[0.00456621, 0.00456621, 0.00456621],
                 [0, -0.00153632, 0.00791071],
                 [0.00625893, -0.00318811, 0]]
    converted = np.matmul(img, transform) * 255.0 + [-222.921, 135.576, -276.836]
    return _convert_output_type_range(converted, input_dtype)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR (ITU-R BT.601, Matlab-compatible).

    The BGR counterpart of ycbcr2rgb. It differs from cv2.cvtColor's
    `YCrCb <-> BGR`, which implements the JPEG conversion.
    See https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion and
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): Input image, either np.uint8 with range [0, 255]
            or np.float32 with range [0, 1].

    Returns:
        ndarray: The converted BGR image with the same type and range
            as the input image.
    """
    input_dtype = img.dtype
    # The inverse transform operates on the [0, 255] range.
    img = _convert_input_type_range(img) * 255
    transform = [[0.00456621, 0.00456621, 0.00456621],
                 [0.00791071, -0.00153632, 0],
                 [0, -0.00318811, 0.00625893]]
    converted = np.matmul(img, transform) * 255.0 + [-276.836, 135.576, -222.921]
    return _convert_output_type_range(converted, input_dtype)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _convert_input_type_range(img):
|
| 294 |
+
"""Convert the type and range of the input image.
|
| 295 |
+
|
| 296 |
+
It converts the input image to np.float32 type and range of [0, 1].
|
| 297 |
+
It is mainly used for pre-processing the input image in colorspace
|
| 298 |
+
convertion functions such as rgb2ycbcr and ycbcr2rgb.
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
img (ndarray): The input image. It accepts:
|
| 302 |
+
1. np.uint8 type with range [0, 255];
|
| 303 |
+
2. np.float32 type with range [0, 1].
|
| 304 |
+
|
| 305 |
+
Returns:
|
| 306 |
+
(ndarray): The converted image with type of np.float32 and range of
|
| 307 |
+
[0, 1].
|
| 308 |
+
"""
|
| 309 |
+
img_type = img.dtype
|
| 310 |
+
img = img.astype(np.float32)
|
| 311 |
+
if img_type == np.float32:
|
| 312 |
+
pass
|
| 313 |
+
elif img_type == np.uint8:
|
| 314 |
+
img /= 255.
|
| 315 |
+
else:
|
| 316 |
+
raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
|
| 317 |
+
return img
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _convert_output_type_range(img, dst_type):
|
| 321 |
+
"""Convert the type and range of the image according to dst_type.
|
| 322 |
+
|
| 323 |
+
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
| 324 |
+
images will be converted to np.uint8 type with range [0, 255]. If
|
| 325 |
+
`dst_type` is np.float32, it converts the image to np.float32 type with
|
| 326 |
+
range [0, 1].
|
| 327 |
+
It is mainly used for post-processing images in colorspace convertion
|
| 328 |
+
functions such as rgb2ycbcr and ycbcr2rgb.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
img (ndarray): The image to be converted with np.float32 type and
|
| 332 |
+
range [0, 255].
|
| 333 |
+
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
| 334 |
+
converts the image to np.uint8 type with range [0, 255]. If
|
| 335 |
+
dst_type is np.float32, it converts the image to np.float32 type
|
| 336 |
+
with range [0, 1].
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
(ndarray): The converted image with desired type and range.
|
| 340 |
+
"""
|
| 341 |
+
if dst_type not in (np.uint8, np.float32):
|
| 342 |
+
raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
|
| 343 |
+
if dst_type == np.uint8:
|
| 344 |
+
img = img.round()
|
| 345 |
+
else:
|
| 346 |
+
img /= 255.
|
| 347 |
+
return img.astype(dst_type)
|
basicsr/utils/misc.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
from os import path as osp
|
| 7 |
+
|
| 8 |
+
from .dist_util import master_only
|
| 9 |
+
from .logger import get_root_logger
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def set_random_seed(seed):
    """Seed every relevant RNG: Python, NumPy, and PyTorch (CPU and CUDA)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_time_str():
    """Return the current local time as a 'YYYYmmdd_HHMMSS' string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with a timestamp and create a new one.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        # Archive the existing folder under a timestamped name instead of
        # overwriting it.
        archived = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {archived}', flush=True)
        os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@master_only
def make_exp_dirs(opt):
    """Create the experiment/result directory tree described in opt['path'].

    Runs on the master process only. The root folder (experiments_root when
    training, results_root otherwise) is archived first if it already
    exists; the remaining path entries are created in place, skipping keys
    that are not directories (strict_load / pretrain_network / resume).
    """
    paths = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(paths.pop(root_key))
    skip_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in paths.items():
        if not any(marker in key for marker in skip_markers):
            os.makedirs(path, exist_ok=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (names starting with '.') are skipped at every level.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            elif recursive and entry.is_dir():
                # Bug fix: only recurse into actual directories. The previous
                # code recursed into every non-matching entry, so a hidden
                # *file* made os.scandir raise NotADirectoryError.
                yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    When a resume state is configured, any explicitly-set pretrain paths are
    overridden (with a warning) to point at the checkpoint files saved for
    `resume_iter`, except for networks listed in
    opt['path']['ignore_resume_networks'].

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if not opt['path']['resume_state']:
        return
    # All keys of the form 'network_*' describe networks to resume.
    networks = [key for key in opt if key.startswith('network_')]
    if any(opt['path'].get(f'pretrain_{net}') is not None for net in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')
    # Point every (non-ignored) pretrain path at the saved checkpoint.
    ignore = opt['path'].get('ignore_resume_networks')
    for net in networks:
        basename = net.replace('network_', '')
        if ignore is None or basename not in ignore:
            name = f'pretrain_{net}'
            opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
            logger.info(f"Set {name} to {opt['path'][name]}")
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def sizeof_fmt(size, suffix='B'):
    """Get a human readable file size.

    Args:
        size (int): File size in bytes.
        suffix (str): Unit suffix. Default: 'B'.

    Return:
        str: Formatted file size, e.g. '1.5 KB'.
    """
    value = size
    for prefix in ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z'):
        if abs(value) < 1024.0:
            return f'{value:3.1f} {prefix}{suffix}'
        value = value / 1024.0
    # Anything past zetta falls through to yotta.
    return f'{value:3.1f} Y{suffix}'
|