Esmaill1 commited on
Commit
3365dbf
·
1 Parent(s): 96f4ff4

Refactor code structure for improved readability and maintainability

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. codeformer/.dockerignore +0 -15
  2. codeformer/.gitattributes +0 -3
  3. codeformer/.gitignore +0 -131
  4. codeformer/DOCUMENTATION.md +0 -158
  5. codeformer/Dockerfile +0 -39
  6. codeformer/LICENSE +0 -35
  7. codeformer/README.md +0 -229
  8. codeformer/app.py +0 -303
  9. codeformer/basicsr/VERSION +0 -1
  10. codeformer/basicsr/__init__.py +0 -11
  11. codeformer/basicsr/archs/__init__.py +0 -25
  12. codeformer/basicsr/archs/arcface_arch.py +0 -245
  13. codeformer/basicsr/archs/arch_util.py +0 -318
  14. codeformer/basicsr/archs/codeformer_arch.py +0 -280
  15. codeformer/basicsr/archs/rrdbnet_arch.py +0 -119
  16. codeformer/basicsr/archs/vgg_arch.py +0 -161
  17. codeformer/basicsr/archs/vqgan_arch.py +0 -434
  18. codeformer/basicsr/data/__init__.py +0 -100
  19. codeformer/basicsr/data/data_sampler.py +0 -48
  20. codeformer/basicsr/data/data_util.py +0 -392
  21. codeformer/basicsr/data/ffhq_blind_dataset.py +0 -299
  22. codeformer/basicsr/data/ffhq_blind_joint_dataset.py +0 -324
  23. codeformer/basicsr/data/gaussian_kernels.py +0 -690
  24. codeformer/basicsr/data/paired_image_dataset.py +0 -101
  25. codeformer/basicsr/data/prefetch_dataloader.py +0 -125
  26. codeformer/basicsr/data/transforms.py +0 -165
  27. codeformer/basicsr/losses/__init__.py +0 -26
  28. codeformer/basicsr/losses/loss_util.py +0 -95
  29. codeformer/basicsr/losses/losses.py +0 -455
  30. codeformer/basicsr/metrics/__init__.py +0 -19
  31. codeformer/basicsr/metrics/metric_util.py +0 -45
  32. codeformer/basicsr/metrics/psnr_ssim.py +0 -128
  33. codeformer/basicsr/models/__init__.py +0 -30
  34. codeformer/basicsr/models/base_model.py +0 -322
  35. codeformer/basicsr/models/codeformer_idx_model.py +0 -220
  36. codeformer/basicsr/models/codeformer_joint_model.py +0 -350
  37. codeformer/basicsr/models/codeformer_model.py +0 -332
  38. codeformer/basicsr/models/lr_scheduler.py +0 -96
  39. codeformer/basicsr/models/sr_model.py +0 -209
  40. codeformer/basicsr/models/vqgan_model.py +0 -285
  41. codeformer/basicsr/ops/dcn/__init__.py +0 -7
  42. codeformer/basicsr/ops/dcn/deform_conv.py +0 -377
  43. codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp +0 -685
  44. codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu +0 -867
  45. codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp +0 -164
  46. codeformer/basicsr/ops/fused_act/__init__.py +0 -3
  47. codeformer/basicsr/ops/fused_act/fused_act.py +0 -89
  48. codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp +0 -26
  49. codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu +0 -100
  50. codeformer/basicsr/ops/upfirdn2d/__init__.py +0 -3
codeformer/.dockerignore DELETED
@@ -1,15 +0,0 @@
1
- .git
2
- .gitignore
3
- __pycache__
4
- *.pyc
5
- *.pyo
6
- *.pyd
7
- .DS_Store
8
- weights/
9
- results/
10
- inputs/cropped_faces/
11
- inputs/gray_faces/
12
- inputs/masked_faces/
13
- inputs/whole_imgs/
14
- output/
15
- web-demos/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/.gitattributes DELETED
@@ -1,3 +0,0 @@
1
- *.png filter=lfs diff=lfs merge=lfs -text
2
- *.jpg filter=lfs diff=lfs merge=lfs -text
3
- *.jpeg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
codeformer/.gitignore DELETED
@@ -1,131 +0,0 @@
1
- .vscode
2
-
3
- # ignored files
4
- version.py
5
-
6
- # ignored files with suffix
7
- *.html
8
- *.png
9
- *.jpeg
10
- *.jpg
11
- *.pt
12
- *.gif
13
- *.pth
14
- *.dat
15
- *.zip
16
-
17
- # template
18
-
19
- # Byte-compiled / optimized / DLL files
20
- __pycache__/
21
- *.py[cod]
22
- *$py.class
23
-
24
- # C extensions
25
- *.so
26
-
27
- # Distribution / packaging
28
- .Python
29
- build/
30
- develop-eggs/
31
- dist/
32
- downloads/
33
- eggs/
34
- .eggs/
35
- lib/
36
- lib64/
37
- parts/
38
- sdist/
39
- var/
40
- wheels/
41
- *.egg-info/
42
- .installed.cfg
43
- *.egg
44
- MANIFEST
45
-
46
- # PyInstaller
47
- # Usually these files are written by a python script from a template
48
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
- *.manifest
50
- *.spec
51
-
52
- # Installer logs
53
- pip-log.txt
54
- pip-delete-this-directory.txt
55
-
56
- # Unit test / coverage reports
57
- htmlcov/
58
- .tox/
59
- .coverage
60
- .coverage.*
61
- .cache
62
- nosetests.xml
63
- coverage.xml
64
- *.cover
65
- .hypothesis/
66
- .pytest_cache/
67
-
68
- # Translations
69
- *.mo
70
- *.pot
71
-
72
- # Django stuff:
73
- *.log
74
- local_settings.py
75
- db.sqlite3
76
-
77
- # Flask stuff:
78
- instance/
79
- .webassets-cache
80
-
81
- # Scrapy stuff:
82
- .scrapy
83
-
84
- # Sphinx documentation
85
- docs/_build/
86
-
87
- # PyBuilder
88
- target/
89
-
90
- # Jupyter Notebook
91
- .ipynb_checkpoints
92
-
93
- # pyenv
94
- .python-version
95
-
96
- # celery beat schedule file
97
- celerybeat-schedule
98
-
99
- # SageMath parsed files
100
- *.sage.py
101
-
102
- # Environments
103
- .env
104
- .venv
105
- env/
106
- venv/
107
- ENV/
108
- env.bak/
109
- venv.bak/
110
-
111
- # Spyder project settings
112
- .spyderproject
113
- .spyproject
114
-
115
- # Rope project settings
116
- .ropeproject
117
-
118
- # mkdocs documentation
119
- /site
120
-
121
- # mypy
122
- .mypy_cache/
123
-
124
- # project
125
- results/
126
- experiments/
127
- tb_logger/
128
- run.sh
129
- *debug*
130
- *_old*
131
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/DOCUMENTATION.md DELETED
@@ -1,158 +0,0 @@
1
- # CodeFormer Face Restoration - Project Documentation
2
-
3
- ## 1. Introduction
4
-
5
- **CodeFormer** is a robust blind face restoration algorithm designed to restore old, degraded, or AI-generated face images. It utilizes a **Codebook Lookup Transformer** (VQGAN-based) to predict high-quality facial features even from severe degradation, ensuring that the restored faces look natural and faithful to the original identity.
6
-
7
- This project wraps the core CodeFormer research code into a deployable, user-friendly **Flask Web Application**, containerized with **Docker** for easy deployment on platforms like Hugging Face Spaces.
8
-
9
- ### Key Features
10
- * **Blind Face Restoration:** Restores faces from low-quality inputs without knowing the specific degradation details.
11
- * **Background Enhancement:** Uses **Real-ESRGAN** to upscale and enhance the non-face background regions of the image.
12
- * **Face Alignment & Paste-back:** Automatically detects faces, aligns them for processing, and seamlessly blends them back into the original image.
13
- * **Adjustable Fidelity:** Users can balance between restoration quality (hallucinating details) and identity fidelity (keeping the original look).
14
-
15
- ---
16
-
17
- ## 2. System Architecture
18
-
19
- The application is built on a Python/PyTorch backend served via Flask.
20
-
21
- ### 2.1 Technology Stack
22
- * **Framework:** Flask (Python Web Server)
23
- * **Deep Learning:** PyTorch, TorchVision
24
- * **Image Processing:** OpenCV, NumPy, Pillow
25
- * **Core Libraries:** `basicsr` (Basic Super-Restoration), `facelib` (Face detection/utils)
26
- * **Frontend:** HTML5, Bootstrap 5, Jinja2 Templates
27
- * **Containerization:** Docker (CUDA-enabled)
28
-
29
- ### 2.2 Directory Structure
30
- ```
31
- CodeFormer/
32
- ├── app.py # Main Flask application entry point
33
- ├── Dockerfile # Container configuration
34
- ├── requirements.txt # Python dependencies
35
- ├── basicsr/ # Core AI framework (Super-Resolution tools)
36
- ├── facelib/ # Face detection and alignment utilities
37
- ├── templates/ # HTML Frontend
38
- │ ├── index.html # Upload interface
39
- │ └── result.html # Results display
40
- ├── static/ # Static assets (css, js, uploads)
41
- │ ├── uploads/ # Temporary storage for input images
42
- │ └── results/ # Temporary storage for processed output
43
- └── weights/ # Pre-trained model weights (downloaded on startup)
44
- ├── CodeFormer/ # CodeFormer model (.pth)
45
- ├── facelib/ # Detection (RetinaFace) and Parsing models
46
- └── realesrgan/ # Background upscaler (Real-ESRGAN)
47
- ```
48
-
49
- ### 2.3 Logic Flow
50
- 1. **Input:** User uploads an image via the Web UI.
51
- 2. **Pre-processing (`app.py`):**
52
- * Image is saved to `static/uploads`.
53
- * Parameters (fidelity, upscale factor) are parsed.
54
- 3. **Inference Pipeline:**
55
- * **Detection:** `facelib` detects faces in the image using RetinaFace.
56
- * **Alignment:** Faces are cropped and aligned to a standard 512x512 resolution.
57
- * **Restoration:** The **CodeFormer** model processes the aligned faces.
58
- * **Upscaling (Optional):** The background is upscaled using **Real-ESRGAN**.
59
- * **Paste-back:** Restored faces are warped back to their original positions and blended.
60
- 4. **Output:** The final image is saved to `static/results` and displayed to the user.
61
-
62
- ---
63
-
64
- ## 3. Installation & Deployment
65
-
66
- ### 3.1 Docker Deployment (Recommended)
67
- The project is optimized for Docker.
68
-
69
- **Prerequisites:** Docker, NVIDIA GPU (optional, but recommended).
70
-
71
- 1. **Build the Image:**
72
- ```bash
73
- docker build -t codeformer-app .
74
- ```
75
-
76
- 2. **Run the Container:**
77
- ```bash
78
- # Run on port 7860 (Standard for HF Spaces)
79
- docker run -it -p 7860:7860 codeformer-app
80
- ```
81
- *Note: To use GPU, add the `--gpus all` flag to the run command.*
82
-
83
- ### 3.2 Hugging Face Spaces Deployment
84
- This repository is configured for direct deployment to Hugging Face.
85
-
86
- 1. Create a **Docker** Space on Hugging Face.
87
- 2. Push this entire repository to the Space's Git remote.
88
- ```bash
89
- git remote add hf git@hf.co:spaces/USERNAME/SPACE_NAME
90
- git push hf main
91
- ```
92
- 3. The Space will build (approx. 5-10 mins) and launch automatically.
93
-
94
- ### 3.3 Local Development
95
- 1. **Install Environment:**
96
- ```bash
97
- conda create -n codeformer python=3.8
98
- conda activate codeformer
99
- pip install -r requirements.txt
100
- ```
101
- 2. **Install Basicsr:**
102
- ```bash
103
- python basicsr/setup.py install
104
- ```
105
- 3. **Run App:**
106
- ```bash
107
- python app.py
108
- ```
109
-
110
- ---
111
-
112
- ## 4. User Guide (Web Interface)
113
-
114
- ### 4.1 Interface Controls
115
-
116
- * **Input Image:** Supports standard formats (JPG, PNG, WEBP). Drag and drop supported.
117
- * **Fidelity Weight (w):**
118
- * **Range:** 0.0 to 1.0.
119
- * **0.0 (Better Quality):** The model "hallucinates" more details. Results look very sharp and high-quality but may slightly alter the person's identity (look less like the original).
120
- * **1.0 (Better Identity):** The model sticks strictly to the original features. Results are faithful to the original photo but might be blurrier or contain more artifacts.
121
- * **Recommended:** 0.5 is a balanced default.
122
- * **Upscale Factor:**
123
- * Scales the final output resolution (1x, 2x, or 4x).
124
- * *Note: Higher scaling requires more VRAM.*
125
- * **Enhance Background:**
126
- * If checked, runs Real-ESRGAN on the non-face areas.
127
- * *Recommendation:* Keep checked for full-photo restoration. Uncheck if you only care about the face or are running on limited hardware.
128
- * **Upsample Face:**
129
- * If checked, the restored face is also upsampled to match the background resolution.
130
-
131
- ### 4.2 Viewing Results
132
- The result page features an interactive **Before/After Slider**. Drag the handle left and right to compare the pixels of the original versus the restored image directly.
133
-
134
- ---
135
-
136
- ## 5. Technical Details
137
-
138
- ### 5.1 Model Weights
139
- The application automatically checks for and downloads the following weights to the `weights/` directory on startup:
140
-
141
- | Model | Path | Description |
142
- | :--- | :--- | :--- |
143
- | **CodeFormer** | `weights/CodeFormer/codeformer.pth` | Main restoration model. |
144
- | **RetinaFace** | `weights/facelib/detection_Resnet50_Final.pth` | Face detection. |
145
- | **ParseNet** | `weights/facelib/parsing_parsenet.pth` | Face parsing (segmentation). |
146
- | **Real-ESRGAN** | `weights/realesrgan/RealESRGAN_x2plus.pth` | Background upscaler (x2). |
147
-
148
- ### 5.2 Performance Notes
149
- * **Memory:** The full pipeline (CodeFormer + Real-ESRGAN) requires significant RAM/VRAM. On CPU-only environments (like basic HF Spaces), processing a single image may take 30-60 seconds.
150
- * **Git LFS:** Image assets in this repository are tracked with Git LFS to keep the repo size manageable.
151
-
152
- ---
153
-
154
- ## 6. Credits & References
155
-
156
- * **Original Paper:** [Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)](https://arxiv.org/abs/2206.11253)
157
- * **Authors:** Shangchen Zhou, Kelvin C.K. Chan, Chongyi Li, Chen Change Loy (S-Lab, Nanyang Technological University).
158
- * **Original Repository:** [sczhou/CodeFormer](https://github.com/sczhou/CodeFormer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/Dockerfile DELETED
@@ -1,39 +0,0 @@
1
- FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
2
-
3
- WORKDIR /code
4
-
5
- # Install system dependencies
6
- RUN apt-get update && apt-get install -y \
7
- libgl1 \
8
- libglib2.0-0 \
9
- git \
10
- ninja-build \
11
- && rm -rf /var/lib/apt/lists/*
12
-
13
- # Copy requirements
14
- COPY requirements.txt .
15
-
16
- # Install python dependencies
17
- RUN pip install --no-cache-dir -r requirements.txt
18
-
19
- # Copy application code
20
- COPY . .
21
-
22
- # Create necessary directories and set permissions
23
- RUN mkdir -p weights inputs output static && \
24
- chmod 777 weights inputs output static
25
-
26
- # Install basicsr (build extensions in-place)
27
- RUN python basicsr/setup.py build_ext --inplace
28
-
29
- # Create a non-root user and switch to it
30
- RUN useradd -m -u 1000 user
31
- USER user
32
- ENV HOME=/home/user \
33
- PATH=/home/user/.local/bin:$PATH
34
-
35
- WORKDIR /code
36
-
37
- EXPOSE 7860
38
-
39
- CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/LICENSE DELETED
@@ -1,35 +0,0 @@
1
- S-Lab License 1.0
2
-
3
- Copyright 2022 S-Lab
4
-
5
- Redistribution and use for non-commercial purpose in source and
6
- binary forms, with or without modification, are permitted provided
7
- that the following conditions are met:
8
-
9
- 1. Redistributions of source code must retain the above copyright
10
- notice, this list of conditions and the following disclaimer.
11
-
12
- 2. Redistributions in binary form must reproduce the above copyright
13
- notice, this list of conditions and the following disclaimer in
14
- the documentation and/or other materials provided with the
15
- distribution.
16
-
17
- 3. Neither the name of the copyright holder nor the names of its
18
- contributors may be used to endorse or promote products derived
19
- from this software without specific prior written permission.
20
-
21
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25
- HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
-
33
- In the event that redistribution and/or use for commercial purpose in
34
- source or binary forms, with or without modification is required,
35
- please contact the contributor(s) of the work.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/README.md DELETED
@@ -1,229 +0,0 @@
1
- ---
2
- title: CodeFormer
3
- emoji: 👤
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- app_file: app.py
8
- pinned: false
9
- ---
10
-
11
- <p align="center">
12
- <img src="assets/CodeFormer_logo.png" height=110>
13
- </p>
14
-
15
- ## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)
16
-
17
- [Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
18
-
19
-
20
- <a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer) ![Visitors](https://api.infinitescript.com/badgen/count?name=sczhou/CodeFormer&ltext=Visitors)
21
-
22
-
23
- [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
24
-
25
- S-Lab, Nanyang Technological University
26
-
27
- <img src="assets/network.jpg" width="800px"/>
28
-
29
-
30
- :star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
31
-
32
-
33
- ### Update
34
- - **2023.07.20**: Integrated to :panda_face: [OpenXLab](https://openxlab.org.cn/apps). Try out online demo! [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer)
35
- - **2023.04.19**: :whale: Training codes and config files are public available now.
36
- - **2023.04.09**: Add features of inpainting and colorization for cropped and aligned face images.
37
- - **2023.02.10**: Include `dlib` as a new face detector option, it produces more accurate face identity.
38
- - **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
39
- - **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
40
- - **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
41
- - [**More**](docs/history_changelog.md)
42
-
43
- ### TODO
44
- - [x] Add training code and config files
45
- - [x] Add checkpoint and script for face inpainting
46
- - [x] Add checkpoint and script for face colorization
47
- - [x] ~~Add background image enhancement~~
48
-
49
- #### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
50
- [<img src="assets/imgsli_1.jpg" height="226px"/>](https://imgsli.com/MTI3NTE2) [<img src="assets/imgsli_2.jpg" height="226px"/>](https://imgsli.com/MTI3NTE1) [<img src="assets/imgsli_3.jpg" height="226px"/>](https://imgsli.com/MTI3NTIw)
51
-
52
- #### Face Restoration
53
-
54
- <img src="assets/restoration_result1.png" width="400px"/> <img src="assets/restoration_result2.png" width="400px"/>
55
- <img src="assets/restoration_result3.png" width="400px"/> <img src="assets/restoration_result4.png" width="400px"/>
56
-
57
- #### Face Color Enhancement and Restoration
58
-
59
- <img src="assets/color_enhancement_result1.png" width="400px"/> <img src="assets/color_enhancement_result2.png" width="400px"/>
60
-
61
- #### Face Inpainting
62
-
63
- <img src="assets/inpainting_result1.png" width="400px"/> <img src="assets/inpainting_result2.png" width="400px"/>
64
-
65
-
66
-
67
- ### Dependencies and Installation
68
-
69
- - Pytorch >= 1.7.1
70
- - CUDA >= 10.1
71
- - Other required packages in `requirements.txt`
72
- ```
73
- # git clone this repository
74
- git clone https://github.com/sczhou/CodeFormer
75
- cd CodeFormer
76
-
77
- # create new anaconda env
78
- conda create -n codeformer python=3.8 -y
79
- conda activate codeformer
80
-
81
- # install python dependencies
82
- pip3 install -r requirements.txt
83
- python basicsr/setup.py develop
84
- conda install -c conda-forge dlib (only for face detection or cropping with dlib)
85
- ```
86
- <!-- conda install -c conda-forge dlib -->
87
-
88
- ### Quick Inference
89
-
90
- #### Download Pre-trained Models:
91
- Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
92
- ```
93
- python scripts/download_pretrained_models.py facelib
94
- python scripts/download_pretrained_models.py dlib (only for dlib face detector)
95
- ```
96
-
97
- Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
98
- ```
99
- python scripts/download_pretrained_models.py CodeFormer
100
- ```
101
-
102
- #### Prepare Testing Data:
103
- You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder. You can get the cropped and aligned faces by running the following command:
104
- ```
105
- # you may need to install dlib via: conda install -c conda-forge dlib
106
- python scripts/crop_align_face.py -i [input folder] -o [output folder]
107
- ```
108
-
109
-
110
- #### Testing:
111
- [Note] If you want to compare CodeFormer in your paper, please run the following command indicating `--has_aligned` (for cropped and aligned face), as the command for the whole image will involve a process of face-background fusion that may damage hair texture on the boundary, which leads to unfair comparison.
112
-
113
- Fidelity weight *w* lays in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result. The results will be saved in the `results` folder.
114
-
115
-
116
- 🧑🏻 Face Restoration (cropped and aligned face)
117
- ```
118
- # For cropped and aligned faces (512x512)
119
- python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
120
- ```
121
-
122
- :framed_picture: Whole Image Enhancement
123
- ```
124
- # For whole image
125
- # Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
126
- # Add '--face_upsample' to further upsample restorated face with Real-ESRGAN
127
- python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
128
- ```
129
-
130
- :clapper: Video Enhancement
131
- ```
132
- # For Windows/Mac users, please install ffmpeg first
133
- conda install -c conda-forge ffmpeg
134
- ```
135
- ```
136
- # For video clips
137
- # Video path should end with '.mp4'|'.mov'|'.avi'
138
- python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
139
- ```
140
-
141
- 🌈 Face Colorization (cropped and aligned face)
142
- ```
143
- # For cropped and aligned faces (512x512)
144
- # Colorize black and white or faded photo
145
- python inference_colorization.py --input_path [image folder]|[image path]
146
- ```
147
-
148
- 🎨 Face Inpainting (cropped and aligned face)
149
- ```
150
- # For cropped and aligned faces (512x512)
151
- # Inputs could be masked by white brush using an image editing app (e.g., Photoshop)
152
- # (check out the examples in inputs/masked_faces)
153
- python inference_inpainting.py --input_path [image folder]|[image path]
154
- ```
155
- ### Training:
156
- The training commands can be found in the documents: [English](docs/train.md) **|** [简体中文](docs/train_CN.md).
157
-
158
- ### License
159
-
160
- This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
161
-
162
- ---
163
- ### 🐼 Ecosystem Applications & Deployments
164
-
165
- CodeFormer has been widely adopted and deployed across a broad range (>20) of online applications, platforms, API services, and independent websites, and has also been integrated into many open-source projects and toolkits.
166
-
167
- > Only demos on **Hugging Face Space**, **Replicate**, and **OpenXLab** are official deployments **maintained by the authors**. All other demos, APIs, apps, websites, and integrations listed below are **third-party (non-official)** and are not affiliated with the CodeFormer authors. Please verify their legitimacy to avoid potential financial loss.
168
-
169
-
170
- #### Websites (Non-official)
171
-
172
- ⚠️⚠️⚠️ The following websites are **not official and are not operated by us**. They use our models without any license or authorization. Please verify their legitimacy to avoid potential financial loss.
173
-
174
-
175
- | Website | Link | Notes |
176
- |---------|------|--------|
177
- | CodeFormer.net | https://codeformer.net/ | Non-official website |
178
- | CodeFormer.cn | https://www.codeformer.cn/ | Non-official website |
179
- | CodeFormerAI.com | https://codeformerai.com/ | Non-official website |
180
-
181
- #### Online Demos / API Platforms
182
-
183
- | Platform | Link | Notes |
184
- |----------|------|--------|
185
- | Hugging Face | https://huggingface.co/spaces/sczhou/CodeFormer | Maintained by Authors |
186
- | Replicate | https://replicate.com/sczhou/codeformer | Maintained by Authors |
187
- | OpenXLab | https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer |Maintained by Authors |
188
- | Segmind | https://www.segmind.com/models/codeformer | Non-official |
189
- | Sieve | https://www.sievedata.com/functions/sieve/codeformer | Non-official |
190
- | Fal.ai | https://fal.ai/models/fal-ai/codeformer | Non-official |
191
- | VaikerAI | https://vaikerai.com/sczhou/codeformer | Non-official |
192
- | Scade.pro | https://www.scade.pro/processors/lucataco-codeformer | Non-official |
193
- | Grandline | https://www.grandline.ai/model/codeformer | Non-official |
194
- | AI Demos | https://aidemos.com/tools/codeformer | Non-official |
195
- | Synexa | https://synexa.ai/explore/sczhou/codeformer | Non-official |
196
- | RentPrompts | https://rentprompts.ai/models/Codeformer | Non-official |
197
- | ElevaticsAI | https://elevatics.ai/models/super-resolution/codeformer | Non-official |
198
- | Anakin.ai | https://anakin.ai/apps/codeformer-online-face-restoration-by-codeformer-19343 | Non-official |
199
- | Relayto | https://relayto.com/explore/codeformer-yf9rj8kwc7zsr | Non-official |
200
-
201
-
202
- #### Open-Source Projects & Toolkits
203
-
204
- | Project / Toolkit | Link | Notes |
205
- |-------------------|------|--------|
206
- | Stable Diffusion GUI | https://nmkd.itch.io/t2i-gui | Integration |
207
- | Stable Diffusion WebUI | https://github.com/AUTOMATIC1111/stable-diffusion-webui | Integration |
208
- | ChaiNNer | https://github.com/chaiNNer-org/chaiNNer | Integration |
209
- | PyPI | https://pypi.org/project/codeformer/ ; https://pypi.org/project/codeformer-pip/ | Python packages |
210
- | ComfyUI | https://stable-diffusion-art.com/codeformer/ | Integration |
211
-
212
- ---
213
- ### Acknowledgement
214
-
215
- This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
216
-
217
- ### Citation
218
- If our work is useful for your research, please consider citing:
219
-
220
- @inproceedings{zhou2022codeformer,
221
- author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
222
- title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
223
- booktitle = {NeurIPS},
224
- year = {2022}
225
- }
226
-
227
-
228
- ### Contact
229
- If you have any questions, please feel free to reach me out at `shangchenzhou@gmail.com`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/app.py DELETED
@@ -1,303 +0,0 @@
1
- """
2
- CodeFormer Flask Application
3
- Deployment on Hugging Face Spaces
4
- """
5
-
6
- import os
7
- import cv2
8
- import torch
9
- import uuid
10
- import numpy as np
11
- import zipfile
12
- from flask import Flask, render_template, request, send_file, url_for
13
- from werkzeug.utils import secure_filename
14
-
15
- from torchvision.transforms.functional import normalize
16
- from basicsr.archs.rrdbnet_arch import RRDBNet
17
- from basicsr.utils import imwrite, img2tensor, tensor2img
18
- from basicsr.utils.download_util import load_file_from_url
19
- from basicsr.utils.misc import gpu_is_available, get_device
20
- from basicsr.utils.realesrgan_utils import RealESRGANer
21
- from basicsr.utils.registry import ARCH_REGISTRY
22
-
23
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
24
- from facelib.utils.misc import is_gray
25
-
26
- # --- Initialization ---
27
- app = Flask(__name__)
28
- app.config['UPLOAD_FOLDER'] = 'static/uploads'
29
- app.config['RESULT_FOLDER'] = 'static/results'
30
- app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB limit
31
-
32
- # Ensure directories exist
33
- os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
34
- os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
35
- os.makedirs('weights/CodeFormer', exist_ok=True)
36
- os.makedirs('weights/facelib', exist_ok=True)
37
- os.makedirs('weights/realesrgan', exist_ok=True)
38
-
39
- # Pretrained model URLs
40
- pretrain_model_url = {
41
- 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
42
- 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
43
- 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
44
- 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
45
- }
46
-
47
- def download_weights():
48
- if not os.path.exists('weights/CodeFormer/codeformer.pth'):
49
- load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='weights/CodeFormer', progress=True, file_name=None)
50
- if not os.path.exists('weights/facelib/detection_Resnet50_Final.pth'):
51
- load_file_from_url(url=pretrain_model_url['detection'], model_dir='weights/facelib', progress=True, file_name=None)
52
- if not os.path.exists('weights/facelib/parsing_parsenet.pth'):
53
- load_file_from_url(url=pretrain_model_url['parsing'], model_dir='weights/facelib', progress=True, file_name=None)
54
- if not os.path.exists('weights/realesrgan/RealESRGAN_x2plus.pth'):
55
- load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='weights/realesrgan', progress=True, file_name=None)
56
-
57
- # Download weights on startup
58
- print("Checking weights...")
59
- download_weights()
60
-
61
- # Global models
62
- device = get_device()
63
- upsampler = None
64
- codeformer_net = None
65
-
66
- def init_models():
67
- global upsampler, codeformer_net
68
-
69
- # RealESRGAN
70
- half = True if gpu_is_available() else False
71
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
72
- upsampler = RealESRGANer(
73
- scale=2,
74
- model_path="weights/realesrgan/RealESRGAN_x2plus.pth",
75
- model=model,
76
- tile=400,
77
- tile_pad=40,
78
- pre_pad=0,
79
- half=half,
80
- )
81
-
82
- # CodeFormer
83
- codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
84
- dim_embd=512,
85
- codebook_size=1024,
86
- n_head=8,
87
- n_layers=9,
88
- connect_list=["32", "64", "128", "256"],
89
- ).to(device)
90
-
91
- ckpt_path = "weights/CodeFormer/codeformer.pth"
92
- checkpoint = torch.load(ckpt_path)["params_ema"]
93
- codeformer_net.load_state_dict(checkpoint)
94
- codeformer_net.eval()
95
- print("Models loaded successfully.")
96
-
97
- init_models()
98
-
99
- def process_image(img_path, background_enhance, face_upsample, upscale, codeformer_fidelity):
100
- """Core inference logic"""
101
- try:
102
- # Defaults
103
- has_aligned = False
104
- only_center_face = False
105
- draw_box = False
106
- detection_model = "retinaface_resnet50"
107
-
108
- img = cv2.imread(img_path, cv2.IMREAD_COLOR)
109
-
110
- # Memory safety checks
111
- upscale = int(upscale)
112
- if upscale > 4: upscale = 4
113
- if upscale > 2 and max(img.shape[:2]) > 1000: upscale = 2
114
- if max(img.shape[:2]) > 1500:
115
- upscale = 1
116
- background_enhance = False
117
- face_upsample = False
118
-
119
- face_helper = FaceRestoreHelper(
120
- upscale,
121
- face_size=512,
122
- crop_ratio=(1, 1),
123
- det_model=detection_model,
124
- save_ext="png",
125
- use_parse=True,
126
- device=device,
127
- )
128
-
129
- bg_upsampler = upsampler if background_enhance else None
130
- face_upsampler = upsampler if face_upsample else None
131
-
132
- if has_aligned:
133
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
134
- face_helper.is_gray = is_gray(img, threshold=5)
135
- face_helper.cropped_faces = [img]
136
- else:
137
- face_helper.read_image(img)
138
- face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
139
- face_helper.align_warp_face()
140
-
141
- # Face restoration
142
- for idx, cropped_face in enumerate(face_helper.cropped_faces):
143
- cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
144
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
145
- cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
146
-
147
- try:
148
- with torch.no_grad():
149
- output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
150
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
151
- except Exception as e:
152
- print(f"Inference error: {e}")
153
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
154
-
155
- restored_face = restored_face.astype("uint8")
156
- face_helper.add_restored_face(restored_face)
157
-
158
- # Paste back
159
- if not has_aligned:
160
- bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None
161
- face_helper.get_inverse_affine(None)
162
-
163
- if face_upsample and face_upsampler:
164
- restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler)
165
- else:
166
- restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
167
- else:
168
- restored_img = face_helper.restored_faces[0]
169
-
170
- return restored_img
171
-
172
- except Exception as e:
173
- print(f"Global processing error: {e}")
174
- return None
175
-
176
- # --- Routes ---
177
-
178
- @app.route('/', methods=['GET'])
179
- def index():
180
- return render_template('index.html')
181
-
182
- @app.route('/process', methods=['POST'])
183
- def process():
184
- if 'image' not in request.files:
185
- return "No image uploaded", 400
186
-
187
- files = request.files.getlist('image')
188
- if not files or files[0].filename == '':
189
- return "No selected file", 400
190
-
191
- results = []
192
-
193
- # Get params (same for all images)
194
- try:
195
- fidelity = float(request.form.get('fidelity', 0.5))
196
- upscale = 4 # Enforce 4x upscale
197
- background_enhance = 'background_enhance' in request.form
198
- face_upsample = 'face_upsample' in request.form
199
- except ValueError:
200
- return "Invalid parameters", 400
201
-
202
- for file in files:
203
- if file.filename == '': continue
204
-
205
- # Save input
206
- filename = str(uuid.uuid4()) + "_" + secure_filename(file.filename)
207
- input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
208
- file.save(input_path)
209
-
210
- # Process
211
- result_img = process_image(input_path, background_enhance, face_upsample, upscale, fidelity)
212
-
213
- if result_img is None:
214
- continue # Skip failed images or handle error appropriately
215
-
216
- # Save output
217
- output_filename = "result_" + filename.rsplit('.', 1)[0] + ".png"
218
- output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
219
- imwrite(result_img, output_path)
220
-
221
- # Generate preview (max 1000px width/height)
222
- preview_filename = "preview_" + output_filename
223
- preview_path = os.path.join(app.config['RESULT_FOLDER'], preview_filename)
224
-
225
- h, w = result_img.shape[:2]
226
- if max(h, w) > 1000:
227
- scale = 1000 / max(h, w)
228
- preview_img = cv2.resize(result_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
229
- imwrite(preview_img, preview_path)
230
- else:
231
- preview_filename = output_filename
232
-
233
- results.append({
234
- 'original': filename,
235
- 'preview': preview_filename,
236
- 'download': output_filename
237
- })
238
-
239
- if not results:
240
- return "Processing failed for all images", 500
241
-
242
- # Create ZIP of all results
243
- zip_filename = f"batch_{uuid.uuid4()}.zip"
244
- zip_path = os.path.join(app.config['RESULT_FOLDER'], zip_filename)
245
-
246
- with zipfile.ZipFile(zip_path, 'w') as zipf:
247
- for item in results:
248
- file_path = os.path.join(app.config['RESULT_FOLDER'], item['download'])
249
- zipf.write(file_path, item['download'])
250
-
251
- return render_template('result.html', results=results, zip_filename=zip_filename)
252
-
253
- @app.route('/api/restore', methods=['POST'])
254
- @app.route('/api/process', methods=['POST'])
255
- def api_restore():
256
- """JSON API for programmatic restoration (used by id-maker)"""
257
- if 'image' not in request.files:
258
- return {"status": "error", "message": "No image uploaded"}, 400
259
-
260
- file = request.files['image']
261
- if file.filename == '':
262
- return {"status": "error", "message": "No selected file"}, 400
263
-
264
- try:
265
- # 1. Get parameters
266
- fidelity = float(request.form.get('fidelity', 0.5))
267
- upscale = int(request.form.get('upscale', 1))
268
- bg_enhance = request.form.get('background_enhance', 'false').lower() == 'true'
269
- face_upscale = request.form.get('face_upsample', 'false').lower() == 'true'
270
-
271
- # 2. Save input
272
- filename = str(uuid.uuid4()) + "_" + secure_filename(file.filename)
273
- input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
274
- file.save(input_path)
275
-
276
- # 3. Process
277
- result_img = process_image(input_path, bg_enhance, face_upscale, upscale, fidelity)
278
-
279
- if result_img is None:
280
- return {"status": "error", "message": "Processing failed"}, 500
281
-
282
- # 4. Save output
283
- output_filename = "api_res_" + filename.rsplit('.', 1)[0] + ".png"
284
- output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
285
- imwrite(result_img, output_path)
286
-
287
- # 5. Return JSON with result URL
288
- base_url = request.host_url.rstrip('/')
289
- result_url = f"{base_url}/static/results/{output_filename}"
290
-
291
- return {
292
- "status": "success",
293
- "results": [{"image_url": result_url}],
294
- "message": "Restoration complete"
295
- }
296
-
297
- except Exception as e:
298
- return {"status": "error", "message": str(e)}, 500
299
-
300
- if __name__ == '__main__':
301
- # Docker/HF Spaces entry point
302
- port = int(os.environ.get("CODEFORMER_PORT", 7860))
303
- app.run(host='0.0.0.0', port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/VERSION DELETED
@@ -1 +0,0 @@
1
- 1.3.2
 
 
codeformer/basicsr/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- # https://github.com/xinntao/BasicSR
2
- # flake8: noqa
3
- from .archs import *
4
- from .data import *
5
- from .losses import *
6
- from .metrics import *
7
- from .models import *
8
- from .ops import *
9
- from .train import *
10
- from .utils import *
11
- from .version import __gitsha__, __version__
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- import importlib
2
- from copy import deepcopy
3
- from os import path as osp
4
-
5
- from basicsr.utils import get_root_logger, scandir
6
- from basicsr.utils.registry import ARCH_REGISTRY
7
-
8
- __all__ = ['build_network']
9
-
10
- # automatically scan and import arch modules for registry
11
- # scan all the files under the 'archs' folder and collect files ending with
12
- # '_arch.py'
13
- arch_folder = osp.dirname(osp.abspath(__file__))
14
- arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
15
- # import all the arch modules
16
- _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
17
-
18
-
19
- def build_network(opt):
20
- opt = deepcopy(opt)
21
- network_type = opt.pop('type')
22
- net = ARCH_REGISTRY.get(network_type)(**opt)
23
- logger = get_root_logger()
24
- logger.info(f'Network [{net.__class__.__name__}] is created.')
25
- return net
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/arcface_arch.py DELETED
@@ -1,245 +0,0 @@
1
- import torch.nn as nn
2
- from basicsr.utils.registry import ARCH_REGISTRY
3
-
4
-
5
- def conv3x3(inplanes, outplanes, stride=1):
6
- """A simple wrapper for 3x3 convolution with padding.
7
-
8
- Args:
9
- inplanes (int): Channel number of inputs.
10
- outplanes (int): Channel number of outputs.
11
- stride (int): Stride in convolution. Default: 1.
12
- """
13
- return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
14
-
15
-
16
- class BasicBlock(nn.Module):
17
- """Basic residual block used in the ResNetArcFace architecture.
18
-
19
- Args:
20
- inplanes (int): Channel number of inputs.
21
- planes (int): Channel number of outputs.
22
- stride (int): Stride in convolution. Default: 1.
23
- downsample (nn.Module): The downsample module. Default: None.
24
- """
25
- expansion = 1 # output channel expansion ratio
26
-
27
- def __init__(self, inplanes, planes, stride=1, downsample=None):
28
- super(BasicBlock, self).__init__()
29
- self.conv1 = conv3x3(inplanes, planes, stride)
30
- self.bn1 = nn.BatchNorm2d(planes)
31
- self.relu = nn.ReLU(inplace=True)
32
- self.conv2 = conv3x3(planes, planes)
33
- self.bn2 = nn.BatchNorm2d(planes)
34
- self.downsample = downsample
35
- self.stride = stride
36
-
37
- def forward(self, x):
38
- residual = x
39
-
40
- out = self.conv1(x)
41
- out = self.bn1(out)
42
- out = self.relu(out)
43
-
44
- out = self.conv2(out)
45
- out = self.bn2(out)
46
-
47
- if self.downsample is not None:
48
- residual = self.downsample(x)
49
-
50
- out += residual
51
- out = self.relu(out)
52
-
53
- return out
54
-
55
-
56
- class IRBlock(nn.Module):
57
- """Improved residual block (IR Block) used in the ResNetArcFace architecture.
58
-
59
- Args:
60
- inplanes (int): Channel number of inputs.
61
- planes (int): Channel number of outputs.
62
- stride (int): Stride in convolution. Default: 1.
63
- downsample (nn.Module): The downsample module. Default: None.
64
- use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
65
- """
66
- expansion = 1 # output channel expansion ratio
67
-
68
- def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
69
- super(IRBlock, self).__init__()
70
- self.bn0 = nn.BatchNorm2d(inplanes)
71
- self.conv1 = conv3x3(inplanes, inplanes)
72
- self.bn1 = nn.BatchNorm2d(inplanes)
73
- self.prelu = nn.PReLU()
74
- self.conv2 = conv3x3(inplanes, planes, stride)
75
- self.bn2 = nn.BatchNorm2d(planes)
76
- self.downsample = downsample
77
- self.stride = stride
78
- self.use_se = use_se
79
- if self.use_se:
80
- self.se = SEBlock(planes)
81
-
82
- def forward(self, x):
83
- residual = x
84
- out = self.bn0(x)
85
- out = self.conv1(out)
86
- out = self.bn1(out)
87
- out = self.prelu(out)
88
-
89
- out = self.conv2(out)
90
- out = self.bn2(out)
91
- if self.use_se:
92
- out = self.se(out)
93
-
94
- if self.downsample is not None:
95
- residual = self.downsample(x)
96
-
97
- out += residual
98
- out = self.prelu(out)
99
-
100
- return out
101
-
102
-
103
- class Bottleneck(nn.Module):
104
- """Bottleneck block used in the ResNetArcFace architecture.
105
-
106
- Args:
107
- inplanes (int): Channel number of inputs.
108
- planes (int): Channel number of outputs.
109
- stride (int): Stride in convolution. Default: 1.
110
- downsample (nn.Module): The downsample module. Default: None.
111
- """
112
- expansion = 4 # output channel expansion ratio
113
-
114
- def __init__(self, inplanes, planes, stride=1, downsample=None):
115
- super(Bottleneck, self).__init__()
116
- self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
117
- self.bn1 = nn.BatchNorm2d(planes)
118
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
119
- self.bn2 = nn.BatchNorm2d(planes)
120
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
121
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
122
- self.relu = nn.ReLU(inplace=True)
123
- self.downsample = downsample
124
- self.stride = stride
125
-
126
- def forward(self, x):
127
- residual = x
128
-
129
- out = self.conv1(x)
130
- out = self.bn1(out)
131
- out = self.relu(out)
132
-
133
- out = self.conv2(out)
134
- out = self.bn2(out)
135
- out = self.relu(out)
136
-
137
- out = self.conv3(out)
138
- out = self.bn3(out)
139
-
140
- if self.downsample is not None:
141
- residual = self.downsample(x)
142
-
143
- out += residual
144
- out = self.relu(out)
145
-
146
- return out
147
-
148
-
149
- class SEBlock(nn.Module):
150
- """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
151
-
152
- Args:
153
- channel (int): Channel number of inputs.
154
- reduction (int): Channel reduction ration. Default: 16.
155
- """
156
-
157
- def __init__(self, channel, reduction=16):
158
- super(SEBlock, self).__init__()
159
- self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
160
- self.fc = nn.Sequential(
161
- nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
162
- nn.Sigmoid())
163
-
164
- def forward(self, x):
165
- b, c, _, _ = x.size()
166
- y = self.avg_pool(x).view(b, c)
167
- y = self.fc(y).view(b, c, 1, 1)
168
- return x * y
169
-
170
-
171
- @ARCH_REGISTRY.register()
172
- class ResNetArcFace(nn.Module):
173
- """ArcFace with ResNet architectures.
174
-
175
- Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
176
-
177
- Args:
178
- block (str): Block used in the ArcFace architecture.
179
- layers (tuple(int)): Block numbers in each layer.
180
- use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
181
- """
182
-
183
- def __init__(self, block, layers, use_se=True):
184
- if block == 'IRBlock':
185
- block = IRBlock
186
- self.inplanes = 64
187
- self.use_se = use_se
188
- super(ResNetArcFace, self).__init__()
189
-
190
- self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
191
- self.bn1 = nn.BatchNorm2d(64)
192
- self.prelu = nn.PReLU()
193
- self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
194
- self.layer1 = self._make_layer(block, 64, layers[0])
195
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
196
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
197
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
198
- self.bn4 = nn.BatchNorm2d(512)
199
- self.dropout = nn.Dropout()
200
- self.fc5 = nn.Linear(512 * 8 * 8, 512)
201
- self.bn5 = nn.BatchNorm1d(512)
202
-
203
- # initialization
204
- for m in self.modules():
205
- if isinstance(m, nn.Conv2d):
206
- nn.init.xavier_normal_(m.weight)
207
- elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
208
- nn.init.constant_(m.weight, 1)
209
- nn.init.constant_(m.bias, 0)
210
- elif isinstance(m, nn.Linear):
211
- nn.init.xavier_normal_(m.weight)
212
- nn.init.constant_(m.bias, 0)
213
-
214
- def _make_layer(self, block, planes, num_blocks, stride=1):
215
- downsample = None
216
- if stride != 1 or self.inplanes != planes * block.expansion:
217
- downsample = nn.Sequential(
218
- nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
219
- nn.BatchNorm2d(planes * block.expansion),
220
- )
221
- layers = []
222
- layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
223
- self.inplanes = planes
224
- for _ in range(1, num_blocks):
225
- layers.append(block(self.inplanes, planes, use_se=self.use_se))
226
-
227
- return nn.Sequential(*layers)
228
-
229
- def forward(self, x):
230
- x = self.conv1(x)
231
- x = self.bn1(x)
232
- x = self.prelu(x)
233
- x = self.maxpool(x)
234
-
235
- x = self.layer1(x)
236
- x = self.layer2(x)
237
- x = self.layer3(x)
238
- x = self.layer4(x)
239
- x = self.bn4(x)
240
- x = self.dropout(x)
241
- x = x.view(x.size(0), -1)
242
- x = self.fc5(x)
243
- x = self.bn5(x)
244
-
245
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/arch_util.py DELETED
@@ -1,318 +0,0 @@
1
- import collections.abc
2
- import math
3
- import torch
4
- import torchvision
5
- import warnings
6
- from distutils.version import LooseVersion
7
- from itertools import repeat
8
- from torch import nn as nn
9
- from torch.nn import functional as F
10
- from torch.nn import init as init
11
- from torch.nn.modules.batchnorm import _BatchNorm
12
-
13
- from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
14
- from basicsr.utils import get_root_logger
15
-
16
-
17
- @torch.no_grad()
18
- def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
19
- """Initialize network weights.
20
-
21
- Args:
22
- module_list (list[nn.Module] | nn.Module): Modules to be initialized.
23
- scale (float): Scale initialized weights, especially for residual
24
- blocks. Default: 1.
25
- bias_fill (float): The value to fill bias. Default: 0
26
- kwargs (dict): Other arguments for initialization function.
27
- """
28
- if not isinstance(module_list, list):
29
- module_list = [module_list]
30
- for module in module_list:
31
- for m in module.modules():
32
- if isinstance(m, nn.Conv2d):
33
- init.kaiming_normal_(m.weight, **kwargs)
34
- m.weight.data *= scale
35
- if m.bias is not None:
36
- m.bias.data.fill_(bias_fill)
37
- elif isinstance(m, nn.Linear):
38
- init.kaiming_normal_(m.weight, **kwargs)
39
- m.weight.data *= scale
40
- if m.bias is not None:
41
- m.bias.data.fill_(bias_fill)
42
- elif isinstance(m, _BatchNorm):
43
- init.constant_(m.weight, 1)
44
- if m.bias is not None:
45
- m.bias.data.fill_(bias_fill)
46
-
47
-
48
- def make_layer(basic_block, num_basic_block, **kwarg):
49
- """Make layers by stacking the same blocks.
50
-
51
- Args:
52
- basic_block (nn.module): nn.module class for basic block.
53
- num_basic_block (int): number of blocks.
54
-
55
- Returns:
56
- nn.Sequential: Stacked blocks in nn.Sequential.
57
- """
58
- layers = []
59
- for _ in range(num_basic_block):
60
- layers.append(basic_block(**kwarg))
61
- return nn.Sequential(*layers)
62
-
63
-
64
- class ResidualBlockNoBN(nn.Module):
65
- """Residual block without BN.
66
-
67
- It has a style of:
68
- ---Conv-ReLU-Conv-+-
69
- |________________|
70
-
71
- Args:
72
- num_feat (int): Channel number of intermediate features.
73
- Default: 64.
74
- res_scale (float): Residual scale. Default: 1.
75
- pytorch_init (bool): If set to True, use pytorch default init,
76
- otherwise, use default_init_weights. Default: False.
77
- """
78
-
79
- def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
80
- super(ResidualBlockNoBN, self).__init__()
81
- self.res_scale = res_scale
82
- self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
83
- self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
84
- self.relu = nn.ReLU(inplace=True)
85
-
86
- if not pytorch_init:
87
- default_init_weights([self.conv1, self.conv2], 0.1)
88
-
89
- def forward(self, x):
90
- identity = x
91
- out = self.conv2(self.relu(self.conv1(x)))
92
- return identity + out * self.res_scale
93
-
94
-
95
- class Upsample(nn.Sequential):
96
- """Upsample module.
97
-
98
- Args:
99
- scale (int): Scale factor. Supported scales: 2^n and 3.
100
- num_feat (int): Channel number of intermediate features.
101
- """
102
-
103
- def __init__(self, scale, num_feat):
104
- m = []
105
- if (scale & (scale - 1)) == 0: # scale = 2^n
106
- for _ in range(int(math.log(scale, 2))):
107
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
108
- m.append(nn.PixelShuffle(2))
109
- elif scale == 3:
110
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
111
- m.append(nn.PixelShuffle(3))
112
- else:
113
- raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
114
- super(Upsample, self).__init__(*m)
115
-
116
-
117
- def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
118
- """Warp an image or feature map with optical flow.
119
-
120
- Args:
121
- x (Tensor): Tensor with size (n, c, h, w).
122
- flow (Tensor): Tensor with size (n, h, w, 2), normal value.
123
- interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
124
- padding_mode (str): 'zeros' or 'border' or 'reflection'.
125
- Default: 'zeros'.
126
- align_corners (bool): Before pytorch 1.3, the default value is
127
- align_corners=True. After pytorch 1.3, the default value is
128
- align_corners=False. Here, we use the True as default.
129
-
130
- Returns:
131
- Tensor: Warped image or feature map.
132
- """
133
- assert x.size()[-2:] == flow.size()[1:3]
134
- _, _, h, w = x.size()
135
- # create mesh grid
136
- grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
137
- grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
138
- grid.requires_grad = False
139
-
140
- vgrid = grid + flow
141
- # scale grid to [-1,1]
142
- vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
143
- vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
144
- vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
145
- output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
146
-
147
- # TODO, what if align_corners=False
148
- return output
149
-
150
-
151
- def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
152
- """Resize a flow according to ratio or shape.
153
-
154
- Args:
155
- flow (Tensor): Precomputed flow. shape [N, 2, H, W].
156
- size_type (str): 'ratio' or 'shape'.
157
- sizes (list[int | float]): the ratio for resizing or the final output
158
- shape.
159
- 1) The order of ratio should be [ratio_h, ratio_w]. For
160
- downsampling, the ratio should be smaller than 1.0 (i.e., ratio
161
- < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
162
- ratio > 1.0).
163
- 2) The order of output_size should be [out_h, out_w].
164
- interp_mode (str): The mode of interpolation for resizing.
165
- Default: 'bilinear'.
166
- align_corners (bool): Whether align corners. Default: False.
167
-
168
- Returns:
169
- Tensor: Resized flow.
170
- """
171
- _, _, flow_h, flow_w = flow.size()
172
- if size_type == 'ratio':
173
- output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
174
- elif size_type == 'shape':
175
- output_h, output_w = sizes[0], sizes[1]
176
- else:
177
- raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
178
-
179
- input_flow = flow.clone()
180
- ratio_h = output_h / flow_h
181
- ratio_w = output_w / flow_w
182
- input_flow[:, 0, :, :] *= ratio_w
183
- input_flow[:, 1, :, :] *= ratio_h
184
- resized_flow = F.interpolate(
185
- input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
186
- return resized_flow
187
-
188
-
189
- # TODO: may write a cpp file
190
- def pixel_unshuffle(x, scale):
191
- """ Pixel unshuffle.
192
-
193
- Args:
194
- x (Tensor): Input feature with shape (b, c, hh, hw).
195
- scale (int): Downsample ratio.
196
-
197
- Returns:
198
- Tensor: the pixel unshuffled feature.
199
- """
200
- b, c, hh, hw = x.size()
201
- out_channel = c * (scale**2)
202
- assert hh % scale == 0 and hw % scale == 0
203
- h = hh // scale
204
- w = hw // scale
205
- x_view = x.view(b, c, h, scale, w, scale)
206
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
207
-
208
-
209
- class DCNv2Pack(ModulatedDeformConvPack):
210
- """Modulated deformable conv for deformable alignment.
211
-
212
- Different from the official DCNv2Pack, which generates offsets and masks
213
- from the preceding features, this DCNv2Pack takes another different
214
- features to generate offsets and masks.
215
-
216
- Ref:
217
- Delving Deep into Deformable Alignment in Video Super-Resolution.
218
- """
219
-
220
- def forward(self, x, feat):
221
- out = self.conv_offset(feat)
222
- o1, o2, mask = torch.chunk(out, 3, dim=1)
223
- offset = torch.cat((o1, o2), dim=1)
224
- mask = torch.sigmoid(mask)
225
-
226
- offset_absmean = torch.mean(torch.abs(offset))
227
- if offset_absmean > 50:
228
- logger = get_root_logger()
229
- logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
230
-
231
- if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
232
- return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
233
- self.dilation, mask)
234
- else:
235
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
236
- self.dilation, self.groups, self.deformable_groups)
237
-
238
-
239
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
240
- # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
241
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
242
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
243
- def norm_cdf(x):
244
- # Computes standard normal cumulative distribution function
245
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
246
-
247
- if (mean < a - 2 * std) or (mean > b + 2 * std):
248
- warnings.warn(
249
- 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
250
- 'The distribution of values may be incorrect.',
251
- stacklevel=2)
252
-
253
- with torch.no_grad():
254
- # Values are generated by using a truncated uniform distribution and
255
- # then using the inverse CDF for the normal distribution.
256
- # Get upper and lower cdf values
257
- low = norm_cdf((a - mean) / std)
258
- up = norm_cdf((b - mean) / std)
259
-
260
- # Uniformly fill tensor with values from [low, up], then translate to
261
- # [2l-1, 2u-1].
262
- tensor.uniform_(2 * low - 1, 2 * up - 1)
263
-
264
- # Use inverse cdf transform for normal distribution to get truncated
265
- # standard normal
266
- tensor.erfinv_()
267
-
268
- # Transform to proper mean, std
269
- tensor.mul_(std * math.sqrt(2.))
270
- tensor.add_(mean)
271
-
272
- # Clamp to ensure it's in the proper range
273
- tensor.clamp_(min=a, max=b)
274
- return tensor
275
-
276
-
277
- def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
278
- r"""Fills the input Tensor with values drawn from a truncated
279
- normal distribution.
280
-
281
- From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
282
-
283
- The values are effectively drawn from the
284
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
285
- with values outside :math:`[a, b]` redrawn until they are within
286
- the bounds. The method used for generating the random values works
287
- best when :math:`a \leq \text{mean} \leq b`.
288
-
289
- Args:
290
- tensor: an n-dimensional `torch.Tensor`
291
- mean: the mean of the normal distribution
292
- std: the standard deviation of the normal distribution
293
- a: the minimum cutoff value
294
- b: the maximum cutoff value
295
-
296
- Examples:
297
- >>> w = torch.empty(3, 5)
298
- >>> nn.init.trunc_normal_(w)
299
- """
300
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
301
-
302
-
303
- # From PyTorch
304
- def _ntuple(n):
305
-
306
- def parse(x):
307
- if isinstance(x, collections.abc.Iterable):
308
- return x
309
- return tuple(repeat(x, n))
310
-
311
- return parse
312
-
313
-
314
- to_1tuple = _ntuple(1)
315
- to_2tuple = _ntuple(2)
316
- to_3tuple = _ntuple(3)
317
- to_4tuple = _ntuple(4)
318
- to_ntuple = _ntuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/codeformer_arch.py DELETED
@@ -1,280 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn, Tensor
5
- import torch.nn.functional as F
6
- from typing import Optional, List
7
-
8
- from basicsr.archs.vqgan_arch import *
9
- from basicsr.utils import get_root_logger
10
- from basicsr.utils.registry import ARCH_REGISTRY
11
-
12
def calc_mean_std(feat, eps=1e-5):
    """Compute per-sample, per-channel mean and std of a 4D feature map.

    Used by adaptive_instance_normalization (AdaIN).

    Args:
        feat (Tensor): 4D tensor of shape (b, c, h, w).
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple(Tensor, Tensor): (mean, std), each of shape (b, c, 1, 1).
    """
    assert feat.dim() == 4, 'The input feature should be 4D tensor.'
    n, ch = feat.size()[:2]
    flat = feat.view(n, ch, -1)
    # Unbiased variance over the flattened spatial axis, stabilized by eps.
    std = (flat.var(dim=2) + eps).sqrt().view(n, ch, 1, 1)
    mean = flat.mean(dim=2).view(n, ch, 1, 1)
    return mean, std
27
-
28
-
29
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization (AdaIN).

    Re-normalizes the reference features so they carry the channel-wise
    color/illumination statistics of the degraded features.

    Args:
        content_feat (Tensor): The reference feature, shape (b, c, h, w).
        style_feat (Tensor): The degraded features, same shape.

    Returns:
        Tensor: ``content_feat`` normalized to zero mean/unit std per
        channel, then rescaled to the statistics of ``style_feat``.
    """
    shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    centered = content_feat - content_mean.expand(shape)
    normalized = centered / content_std.expand(shape)
    return normalized * style_std.expand(shape) + style_mean.expand(shape)
44
-
45
-
46
class PositionEmbeddingSine(nn.Module):
    """Sinusoidal 2D position embedding.

    This is a more standard version of the position embedding, very similar
    to the one used by the Attention Is All You Need paper, generalized to
    work on images: one sin/cos encoding per spatial axis, concatenated on
    the channel dimension.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        # num_pos_feats: channels per axis; output has 2 * num_pos_feats channels.
        # temperature: base of the geometric frequency progression.
        # normalize: if True, rescale pixel coordinates into [0, scale].
        # scale: coordinate range used when normalizing; defaults to 2*pi.
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        # x: (b, c, h, w) feature map; only its size/device are used.
        # mask: optional (b, h, w) bool padding mask (True = padded).
        # Returns: (b, 2*num_pos_feats, h, w) position embedding.
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        # Cumulative counts of valid pixels give 1-based coordinates per axis.
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Divide by the last (largest) coordinate so values span [0, scale].
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Geometric frequencies; consecutive channel pairs share a frequency.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels with cos on odd channels.
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        # Concatenate y- then x-encodings, move channels to dim 1.
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
87
-
88
- def _get_activation_fn(activation):
89
- """Return an activation function given a string"""
90
- if activation == "relu":
91
- return F.relu
92
- if activation == "gelu":
93
- return F.gelu
94
- if activation == "glu":
95
- return F.glu
96
- raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
97
-
98
-
99
class TransformerSALayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP, each
    wrapped in a residual connection (LayerNorm applied before each sublayer).
    """

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        # embed_dim: token embedding size; nhead: attention heads;
        # dim_mlp: hidden width of the feed-forward sublayer.
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Position embeddings are added to queries/keys only, not to values.
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # tgt is sequence-first, (seq, batch, embed_dim), as expected by
        # nn.MultiheadAttention's default layout.

        # self attention (pre-norm: normalize, attend, residual add)
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn (pre-norm feed-forward with residual add)
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt
135
-
136
class Fuse_sft_block(nn.Module):
    """SFT-style fusion block.

    Encodes the concatenation of encoder and decoder features, predicts a
    per-pixel affine (scale, shift) modulation from it, and adds the
    modulated decoder features back as a weighted residual.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        # w controls fidelity vs. quality: the residual is scaled by w.
        fused = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        modulated = dec_feat * self.scale(fused) + self.shift(fused)
        return dec_feat + w * modulated
158
-
159
-
160
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """CodeFormer face-restoration network.

    Built on a fixed VQGAN autoencoder: a transformer predicts codebook
    indices from the degraded (low-quality) encoder features, the predicted
    code entries are decoded by the VQGAN generator, and optional SFT fusion
    blocks inject encoder features into the generator for fidelity control.

    Args:
        dim_embd (int): Transformer embedding dimension.
        n_head (int): Number of attention heads per transformer layer.
        n_layers (int): Number of TransformerSALayer blocks.
        codebook_size (int): Number of entries in the VQGAN codebook.
        latent_size (int): Number of latent tokens (e.g. 16x16 = 256).
        connect_list (list[str]): Spatial resolutions at which encoder
            features are fused into the generator.
        fix_modules (list[str]): Submodules whose parameters are frozen.
        vqgan_path (str | None): Optional checkpoint ('params_ema') to load.
    """

    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                codebook_size=1024, latent_size=256,
                connect_list=['32', '64', '128', '256'],
                fix_modules=['quantize','generator'], vqgan_path=None):
        # Fixed VQGAN backbone config: 512px, nf=64, ch_mult [1,2,2,4,4,8],
        # nearest quantizer, 2 res blocks, attention at 16x16.
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        # Freeze pretrained modules (typically quantizer and generator).
        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        # Learned positional embedding, one vector per latent token.
        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        # Projects 256-dim VQGAN latents into the transformer width.
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                    for _ in range(self.n_layers)])

        # logits_predict head: per-token logits over the codebook
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        # Channel count of encoder/generator features at each resolution.
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # Block indices at which to tap features, keyed by resolution.
        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict: one SFT fusion block per connected resolution
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        # Standard transformer init: N(0, 0.02) weights, zero biases,
        # LayerNorm reset to identity.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        """Restore an input image.

        Args:
            x (Tensor): Input image batch.
            w (float): Fidelity weight for SFT fusion; 0 disables fusion.
            detach_16 (bool): Detach quantized features (training stage III).
            code_only (bool): If True, return only (logits, lq_feat)
                (training stage II).
            adain (bool): Apply AdaIN between quantized and LQ features.

        Returns:
            (out, logits, lq_feat) — or (logits, lq_feat) if ``code_only``.
        """
        # ################### Encoder #####################
        # Run the encoder, caching features at the resolutions to be fused.
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        # Hard top-1 code selection, then look up codebook features.
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        # Decode, fusing cached encoder features at the configured blocks.
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/rrdbnet_arch.py DELETED
@@ -1,119 +0,0 @@
1
- import torch
2
- from torch import nn as nn
3
- from torch.nn import functional as F
4
-
5
- from basicsr.utils.registry import ARCH_REGISTRY
6
- from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
-
8
-
9
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN. Five 3x3 convolutions with dense
    connections (each conv sees all previous outputs); the final output is
    added back to the input as a scaled residual.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # Each conv's input width grows by num_grow_ch (dense connectivity).
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        # Last conv maps back to num_feat so the residual add is valid.
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization (small init scale of 0.1, ESRGAN convention)
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x
40
-
41
-
42
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN: three chained ResidualDenseBlocks whose
    combined output is added back to the input as a scaled residual.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        # Residual scaled by 0.2, empirically better (ESRGAN convention).
        return out * 0.2 + x
64
-
65
-
66
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling scale factor (1, 2, or 4). Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # For scale 1/2, pixel-unshuffle in forward() multiplies input
        # channels by 16/4 respectively, so widen conv_first to match.
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample (two fixed 2x stages -> net 4x after any unshuffle)
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Shrink spatial size first for scale 1/2 so the fixed 4x upsampling
        # below yields the requested overall scale.
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        # Long skip connection around the RRDB trunk.
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/vgg_arch.py DELETED
@@ -1,161 +0,0 @@
1
- import os
2
- import torch
3
- from collections import OrderedDict
4
- from torch import nn as nn
5
- from torchvision.models import vgg as vgg
6
-
7
- from basicsr.utils.registry import ARCH_REGISTRY
8
-
9
# Local path checked first for VGG19 weights; falls back to the torchvision
# download if absent (see VGGFeatureExtractor.__init__).
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
# Ordered layer names for each torchvision VGG variant, matching the
# sequential order of `vgg.features`; used to address layers by name.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}
34
-
35
-
36
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    result = []
    for layer_name in names:
        result.append(layer_name)
        if 'conv' in layer_name:
            # e.g. 'conv1_2' -> follow with 'bn1_2'
            result.append('bn' + layer_name.replace('conv', ''))
    return result
52
-
53
-
54
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        # Layer-name table is keyed without the '_bn' suffix; bn entries are
        # interleaved afterwards for the *_bn variants.
        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        # Prefer a local checkpoint; otherwise let torchvision download it.
        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        # Truncate the feature stack after the deepest requested layer.
        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        # Freeze (eval + no grad) unless the caller wants a trainable VGG.
        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: Features keyed by the requested layer names.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        # Walk layers in order, snapshotting outputs of the requested ones.
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/archs/vqgan_arch.py DELETED
@@ -1,434 +0,0 @@
1
- '''
2
- VQGAN code, adapted from the original created by the Unleashing Transformers authors:
3
- https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
4
-
5
- '''
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import copy
11
- from basicsr.utils import get_root_logger
12
- from basicsr.utils.registry import ARCH_REGISTRY
13
-
14
def normalize(in_channels):
    """Build the VQGAN-standard normalization layer: GroupNorm with 32
    groups, affine parameters, and eps=1e-6."""
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
16
-
17
-
18
@torch.jit.script
def swish(x):
    """Swish/SiLU activation: x * sigmoid(x), compiled with TorchScript."""
    return torch.sigmoid(x) * x
21
-
22
-
23
- # Define VQVAE classes
24
class VectorQuantizer(nn.Module):
    """VQ-VAE nearest-neighbour vector quantizer.

    Maps each spatial feature vector to its closest codebook embedding and
    applies the straight-through estimator so gradients flow to the encoder.
    """

    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        # Uniform init in +-1/codebook_size (standard VQ-VAE initialization).
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # z: (b, c, h, w) encoder output with c == emb_dim.
        # Returns (z_q, loss, stats) where z_q has the same shape as z.
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        # min_encoding_scores = torch.exp(-min_encoding_scores/10)

        # One-hot selection matrix over the codebook.
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding: codebook term + beta-weighted commitment term
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # perplexity: exp of the entropy of average codebook usage
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
            }

    def get_codebook_feat(self, indices, shape):
        # Look up codebook features for precomputed indices (used at
        # inference by CodeFormer after the transformer predicts codes).
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1,1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q
85
-
86
-
87
class GumbelQuantizer(nn.Module):
    """Gumbel-softmax vector quantizer (differentiable alternative to
    nearest-neighbour VQ), with a KL-to-uniform-prior regularizer.
    """

    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        # Soft (differentiable) sampling during training unless
        # straight_through is set; always hard one-hot at eval time.
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        # Weighted sum of codebook vectors per spatial position.
        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss (uniform prior over codes)
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
        }
115
-
116
-
117
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 conv after asymmetric
    (right/bottom) zero padding, as in the original VQGAN."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Pad one pixel on the right and bottom only, then convolve.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
127
-
128
-
129
class Upsample(nn.Module):
    """Double spatial resolution: 2x nearest-neighbour interpolation
    followed by a 3x3 convolution."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
139
-
140
-
141
class ResBlock(nn.Module):
    """VQGAN residual block: two (GroupNorm -> swish -> 3x3 conv) stages with
    a skip connection; a 1x1 conv adapts the skip when channels change.

    Args:
        in_channels (int): Input channel count.
        out_channels (int | None): Output channel count; defaults to
            ``in_channels`` when None.
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Bug fix: use the resolved self.out_channels below instead of the
        # raw out_channels argument, which is None when the default is used
        # and would crash nn.Conv2d / normalize.
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip connection matches the output width.
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in
165
-
166
-
167
class AttnBlock(nn.Module):
    """Single-head spatial self-attention block (VQGAN style).

    Queries/keys/values are produced by 1x1 convs; attention is computed
    over all h*w positions and the result is added residually to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention: w_[b, i, j] = softmax_j(q_i . k_j / sqrt(c))
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)   # (b, hw, c)
        k = k.reshape(b, c, h*w)  # (b, c, hw)
        w_ = torch.bmm(q, k)      # (b, hw, hw)
        w_ = w_ * (int(c)**(-0.5))  # scale by 1/sqrt(c)
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)  # transpose so bmm sums over key positions
        h_ = torch.bmm(v, w_)     # (b, c, hw)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # residual connection
        return x+h_
227
-
228
-
229
class Encoder(nn.Module):
    """VQGAN encoder: strided conv pyramid of ResBlocks with attention at
    the configured low resolutions, ending in a projection to the latent
    embedding dimension.

    Args:
        in_channels (int): Input image channels.
        nf (int): Base channel count; scaled by ch_mult per level.
        emb_dim (int): Output latent embedding dimension.
        ch_mult (list[int]): Per-level channel multipliers.
        num_res_blocks (int): ResBlocks per resolution level.
        resolution (int): Input spatial resolution.
        attn_resolutions (list[int]): Resolutions at which to insert
            attention blocks.
    """

    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        # Prepend 1 so level i reads nf*in_ch_mult[i] and writes nf*ch_mult[i].
        in_ch_mult = (1,)+tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            # Downsample between levels (not after the last one).
            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block (bottleneck: res-attn-res)
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        # ModuleList (not Sequential) so callers can tap intermediate
        # features by block index (see CodeFormer.fuse_encoder_block).
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x
274
-
275
-
276
class Generator(nn.Module):
    """VQGAN decoder/generator: mirrors the Encoder, upsampling latents back
    to a 3-channel image through ResBlocks with attention at the configured
    resolutions.

    Args:
        nf (int): Base channel count; scaled by ch_mult per level.
        emb_dim (int): Input latent embedding dimension.
        ch_mult (list[int]): Per-level channel multipliers (reversed here).
        res_blocks (int): ResBlocks per resolution level.
        img_size (int): Output spatial resolution.
        attn_resolutions (list[int]): Resolutions at which to insert
            attention blocks.
    """

    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3  # RGB output
        block_in_ch = self.nf * self.ch_mult[-1]
        # Latent resolution after the encoder's downsampling.
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block (bottleneck: res-attn-res)
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # Walk resolution levels from coarsest to finest.
        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            # Upsample between levels (not after the finest one).
            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        # ModuleList (not Sequential) so callers can fuse features at
        # specific block indices (see CodeFormer.fuse_generator_block).
        self.blocks = nn.ModuleList(blocks)


    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x
324
-
325
-
326
- @ARCH_REGISTRY.register()
327
- class VQAutoEncoder(nn.Module):
328
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
329
- beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
330
- super().__init__()
331
- logger = get_root_logger()
332
- self.in_channels = 3
333
- self.nf = nf
334
- self.n_blocks = res_blocks
335
- self.codebook_size = codebook_size
336
- self.embed_dim = emb_dim
337
- self.ch_mult = ch_mult
338
- self.resolution = img_size
339
- self.attn_resolutions = attn_resolutions
340
- self.quantizer_type = quantizer
341
- self.encoder = Encoder(
342
- self.in_channels,
343
- self.nf,
344
- self.embed_dim,
345
- self.ch_mult,
346
- self.n_blocks,
347
- self.resolution,
348
- self.attn_resolutions
349
- )
350
- if self.quantizer_type == "nearest":
351
- self.beta = beta #0.25
352
- self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
353
- elif self.quantizer_type == "gumbel":
354
- self.gumbel_num_hiddens = emb_dim
355
- self.straight_through = gumbel_straight_through
356
- self.kl_weight = gumbel_kl_weight
357
- self.quantize = GumbelQuantizer(
358
- self.codebook_size,
359
- self.embed_dim,
360
- self.gumbel_num_hiddens,
361
- self.straight_through,
362
- self.kl_weight
363
- )
364
- self.generator = Generator(
365
- self.nf,
366
- self.embed_dim,
367
- self.ch_mult,
368
- self.n_blocks,
369
- self.resolution,
370
- self.attn_resolutions
371
- )
372
-
373
- if model_path is not None:
374
- chkpt = torch.load(model_path, map_location='cpu')
375
- if 'params_ema' in chkpt:
376
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
377
- logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
378
- elif 'params' in chkpt:
379
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
380
- logger.info(f'vqgan is loaded from: {model_path} [params]')
381
- else:
382
- raise ValueError(f'Wrong params!')
383
-
384
-
385
- def forward(self, x):
386
- x = self.encoder(x)
387
- quant, codebook_loss, quant_stats = self.quantize(x)
388
- x = self.generator(quant)
389
- return x, codebook_loss, quant_stats
390
-
391
-
392
-
393
- # patch based discriminator
394
- @ARCH_REGISTRY.register()
395
- class VQGANDiscriminator(nn.Module):
396
- def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
397
- super().__init__()
398
-
399
- layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
400
- ndf_mult = 1
401
- ndf_mult_prev = 1
402
- for n in range(1, n_layers): # gradually increase the number of filters
403
- ndf_mult_prev = ndf_mult
404
- ndf_mult = min(2 ** n, 8)
405
- layers += [
406
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
407
- nn.BatchNorm2d(ndf * ndf_mult),
408
- nn.LeakyReLU(0.2, True)
409
- ]
410
-
411
- ndf_mult_prev = ndf_mult
412
- ndf_mult = min(2 ** n_layers, 8)
413
-
414
- layers += [
415
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
416
- nn.BatchNorm2d(ndf * ndf_mult),
417
- nn.LeakyReLU(0.2, True)
418
- ]
419
-
420
- layers += [
421
- nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
422
- self.main = nn.Sequential(*layers)
423
-
424
- if model_path is not None:
425
- chkpt = torch.load(model_path, map_location='cpu')
426
- if 'params_d' in chkpt:
427
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
428
- elif 'params' in chkpt:
429
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
430
- else:
431
- raise ValueError(f'Wrong params!')
432
-
433
- def forward(self, x):
434
- return self.main(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/__init__.py DELETED
@@ -1,100 +0,0 @@
1
- import importlib
2
- import numpy as np
3
- import random
4
- import torch
5
- import torch.utils.data
6
- from copy import deepcopy
7
- from functools import partial
8
- from os import path as osp
9
-
10
- from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
- from basicsr.utils import get_root_logger, scandir
12
- from basicsr.utils.dist_util import get_dist_info
13
- from basicsr.utils.registry import DATASET_REGISTRY
14
-
15
- __all__ = ['build_dataset', 'build_dataloader']
16
-
17
- # automatically scan and import dataset modules for registry
18
- # scan all the files under the data folder with '_dataset' in file names
19
- data_folder = osp.dirname(osp.abspath(__file__))
20
- dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
- # import all the dataset modules
22
- _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
-
24
-
25
- def build_dataset(dataset_opt):
26
- """Build dataset from options.
27
-
28
- Args:
29
- dataset_opt (dict): Configuration for dataset. It must constain:
30
- name (str): Dataset name.
31
- type (str): Dataset type.
32
- """
33
- dataset_opt = deepcopy(dataset_opt)
34
- dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
- logger = get_root_logger()
36
- logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
37
- return dataset
38
-
39
-
40
- def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
- """Build dataloader.
42
-
43
- Args:
44
- dataset (torch.utils.data.Dataset): Dataset.
45
- dataset_opt (dict): Dataset options. It contains the following keys:
46
- phase (str): 'train' or 'val'.
47
- num_worker_per_gpu (int): Number of workers for each GPU.
48
- batch_size_per_gpu (int): Training batch size for each GPU.
49
- num_gpu (int): Number of GPUs. Used only in the train phase.
50
- Default: 1.
51
- dist (bool): Whether in distributed training. Used only in the train
52
- phase. Default: False.
53
- sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
- seed (int | None): Seed. Default: None
55
- """
56
- phase = dataset_opt['phase']
57
- rank, _ = get_dist_info()
58
- if phase == 'train':
59
- if dist: # distributed training
60
- batch_size = dataset_opt['batch_size_per_gpu']
61
- num_workers = dataset_opt['num_worker_per_gpu']
62
- else: # non-distributed training
63
- multiplier = 1 if num_gpu == 0 else num_gpu
64
- batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
- num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
- dataloader_args = dict(
67
- dataset=dataset,
68
- batch_size=batch_size,
69
- shuffle=False,
70
- num_workers=num_workers,
71
- sampler=sampler,
72
- drop_last=True)
73
- if sampler is None:
74
- dataloader_args['shuffle'] = True
75
- dataloader_args['worker_init_fn'] = partial(
76
- worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
- elif phase in ['val', 'test']: # validation
78
- dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
- else:
80
- raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
81
-
82
- dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
-
84
- prefetch_mode = dataset_opt.get('prefetch_mode')
85
- if prefetch_mode == 'cpu': # CPUPrefetcher
86
- num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
87
- logger = get_root_logger()
88
- logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
89
- return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
90
- else:
91
- # prefetch_mode=None: Normal dataloader
92
- # prefetch_mode='cuda': dataloader for CUDAPrefetcher
93
- return torch.utils.data.DataLoader(**dataloader_args)
94
-
95
-
96
- def worker_init_fn(worker_id, num_workers, rank, seed):
97
- # Set the worker seed to num_workers * rank + worker_id + seed
98
- worker_seed = num_workers * rank + worker_id + seed
99
- np.random.seed(worker_seed)
100
- random.seed(worker_seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/data_sampler.py DELETED
@@ -1,48 +0,0 @@
1
- import math
2
- import torch
3
- from torch.utils.data.sampler import Sampler
4
-
5
-
6
- class EnlargedSampler(Sampler):
7
- """Sampler that restricts data loading to a subset of the dataset.
8
-
9
- Modified from torch.utils.data.distributed.DistributedSampler
10
- Support enlarging the dataset for iteration-based training, for saving
11
- time when restart the dataloader after each epoch
12
-
13
- Args:
14
- dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
- num_replicas (int | None): Number of processes participating in
16
- the training. It is usually the world_size.
17
- rank (int | None): Rank of the current process within num_replicas.
18
- ratio (int): Enlarging ratio. Default: 1.
19
- """
20
-
21
- def __init__(self, dataset, num_replicas, rank, ratio=1):
22
- self.dataset = dataset
23
- self.num_replicas = num_replicas
24
- self.rank = rank
25
- self.epoch = 0
26
- self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
- self.total_size = self.num_samples * self.num_replicas
28
-
29
- def __iter__(self):
30
- # deterministically shuffle based on epoch
31
- g = torch.Generator()
32
- g.manual_seed(self.epoch)
33
- indices = torch.randperm(self.total_size, generator=g).tolist()
34
-
35
- dataset_size = len(self.dataset)
36
- indices = [v % dataset_size for v in indices]
37
-
38
- # subsample
39
- indices = indices[self.rank:self.total_size:self.num_replicas]
40
- assert len(indices) == self.num_samples
41
-
42
- return iter(indices)
43
-
44
- def __len__(self):
45
- return self.num_samples
46
-
47
- def set_epoch(self, epoch):
48
- self.epoch = epoch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/data_util.py DELETED
@@ -1,392 +0,0 @@
1
- import cv2
2
- import math
3
- import numpy as np
4
- import torch
5
- from os import path as osp
6
- from PIL import Image, ImageDraw
7
- from torch.nn import functional as F
8
-
9
- from basicsr.data.transforms import mod_crop
10
- from basicsr.utils import img2tensor, scandir
11
-
12
-
13
- def read_img_seq(path, require_mod_crop=False, scale=1):
14
- """Read a sequence of images from a given folder path.
15
-
16
- Args:
17
- path (list[str] | str): List of image paths or image folder path.
18
- require_mod_crop (bool): Require mod crop for each image.
19
- Default: False.
20
- scale (int): Scale factor for mod_crop. Default: 1.
21
-
22
- Returns:
23
- Tensor: size (t, c, h, w), RGB, [0, 1].
24
- """
25
- if isinstance(path, list):
26
- img_paths = path
27
- else:
28
- img_paths = sorted(list(scandir(path, full_path=True)))
29
- imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
30
- if require_mod_crop:
31
- imgs = [mod_crop(img, scale) for img in imgs]
32
- imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
33
- imgs = torch.stack(imgs, dim=0)
34
- return imgs
35
-
36
-
37
- def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
38
- """Generate an index list for reading `num_frames` frames from a sequence
39
- of images.
40
-
41
- Args:
42
- crt_idx (int): Current center index.
43
- max_frame_num (int): Max number of the sequence of images (from 1).
44
- num_frames (int): Reading num_frames frames.
45
- padding (str): Padding mode, one of
46
- 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
47
- Examples: current_idx = 0, num_frames = 5
48
- The generated frame indices under different padding mode:
49
- replicate: [0, 0, 0, 1, 2]
50
- reflection: [2, 1, 0, 1, 2]
51
- reflection_circle: [4, 3, 0, 1, 2]
52
- circle: [3, 4, 0, 1, 2]
53
-
54
- Returns:
55
- list[int]: A list of indices.
56
- """
57
- assert num_frames % 2 == 1, 'num_frames should be an odd number.'
58
- assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
59
-
60
- max_frame_num = max_frame_num - 1 # start from 0
61
- num_pad = num_frames // 2
62
-
63
- indices = []
64
- for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
65
- if i < 0:
66
- if padding == 'replicate':
67
- pad_idx = 0
68
- elif padding == 'reflection':
69
- pad_idx = -i
70
- elif padding == 'reflection_circle':
71
- pad_idx = crt_idx + num_pad - i
72
- else:
73
- pad_idx = num_frames + i
74
- elif i > max_frame_num:
75
- if padding == 'replicate':
76
- pad_idx = max_frame_num
77
- elif padding == 'reflection':
78
- pad_idx = max_frame_num * 2 - i
79
- elif padding == 'reflection_circle':
80
- pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
81
- else:
82
- pad_idx = i - num_frames
83
- else:
84
- pad_idx = i
85
- indices.append(pad_idx)
86
- return indices
87
-
88
-
89
- def paired_paths_from_lmdb(folders, keys):
90
- """Generate paired paths from lmdb files.
91
-
92
- Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
93
-
94
- lq.lmdb
95
- ├── data.mdb
96
- ├── lock.mdb
97
- ├── meta_info.txt
98
-
99
- The data.mdb and lock.mdb are standard lmdb files and you can refer to
100
- https://lmdb.readthedocs.io/en/release/ for more details.
101
-
102
- The meta_info.txt is a specified txt file to record the meta information
103
- of our datasets. It will be automatically created when preparing
104
- datasets by our provided dataset tools.
105
- Each line in the txt file records
106
- 1)image name (with extension),
107
- 2)image shape,
108
- 3)compression level, separated by a white space.
109
- Example: `baboon.png (120,125,3) 1`
110
-
111
- We use the image name without extension as the lmdb key.
112
- Note that we use the same key for the corresponding lq and gt images.
113
-
114
- Args:
115
- folders (list[str]): A list of folder path. The order of list should
116
- be [input_folder, gt_folder].
117
- keys (list[str]): A list of keys identifying folders. The order should
118
- be in consistent with folders, e.g., ['lq', 'gt'].
119
- Note that this key is different from lmdb keys.
120
-
121
- Returns:
122
- list[str]: Returned path list.
123
- """
124
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
125
- f'But got {len(folders)}')
126
- assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
127
- input_folder, gt_folder = folders
128
- input_key, gt_key = keys
129
-
130
- if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
131
- raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
132
- f'formats. But received {input_key}: {input_folder}; '
133
- f'{gt_key}: {gt_folder}')
134
- # ensure that the two meta_info files are the same
135
- with open(osp.join(input_folder, 'meta_info.txt')) as fin:
136
- input_lmdb_keys = [line.split('.')[0] for line in fin]
137
- with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
138
- gt_lmdb_keys = [line.split('.')[0] for line in fin]
139
- if set(input_lmdb_keys) != set(gt_lmdb_keys):
140
- raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
141
- else:
142
- paths = []
143
- for lmdb_key in sorted(input_lmdb_keys):
144
- paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
145
- return paths
146
-
147
-
148
- def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
149
- """Generate paired paths from an meta information file.
150
-
151
- Each line in the meta information file contains the image names and
152
- image shape (usually for gt), separated by a white space.
153
-
154
- Example of an meta information file:
155
- ```
156
- 0001_s001.png (480,480,3)
157
- 0001_s002.png (480,480,3)
158
- ```
159
-
160
- Args:
161
- folders (list[str]): A list of folder path. The order of list should
162
- be [input_folder, gt_folder].
163
- keys (list[str]): A list of keys identifying folders. The order should
164
- be in consistent with folders, e.g., ['lq', 'gt'].
165
- meta_info_file (str): Path to the meta information file.
166
- filename_tmpl (str): Template for each filename. Note that the
167
- template excludes the file extension. Usually the filename_tmpl is
168
- for files in the input folder.
169
-
170
- Returns:
171
- list[str]: Returned path list.
172
- """
173
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
174
- f'But got {len(folders)}')
175
- assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
176
- input_folder, gt_folder = folders
177
- input_key, gt_key = keys
178
-
179
- with open(meta_info_file, 'r') as fin:
180
- gt_names = [line.split(' ')[0] for line in fin]
181
-
182
- paths = []
183
- for gt_name in gt_names:
184
- basename, ext = osp.splitext(osp.basename(gt_name))
185
- input_name = f'{filename_tmpl.format(basename)}{ext}'
186
- input_path = osp.join(input_folder, input_name)
187
- gt_path = osp.join(gt_folder, gt_name)
188
- paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
189
- return paths
190
-
191
-
192
- def paired_paths_from_folder(folders, keys, filename_tmpl):
193
- """Generate paired paths from folders.
194
-
195
- Args:
196
- folders (list[str]): A list of folder path. The order of list should
197
- be [input_folder, gt_folder].
198
- keys (list[str]): A list of keys identifying folders. The order should
199
- be in consistent with folders, e.g., ['lq', 'gt'].
200
- filename_tmpl (str): Template for each filename. Note that the
201
- template excludes the file extension. Usually the filename_tmpl is
202
- for files in the input folder.
203
-
204
- Returns:
205
- list[str]: Returned path list.
206
- """
207
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
208
- f'But got {len(folders)}')
209
- assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
210
- input_folder, gt_folder = folders
211
- input_key, gt_key = keys
212
-
213
- input_paths = list(scandir(input_folder))
214
- gt_paths = list(scandir(gt_folder))
215
- assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
216
- f'{len(input_paths)}, {len(gt_paths)}.')
217
- paths = []
218
- for gt_path in gt_paths:
219
- basename, ext = osp.splitext(osp.basename(gt_path))
220
- input_name = f'{filename_tmpl.format(basename)}{ext}'
221
- input_path = osp.join(input_folder, input_name)
222
- assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
223
- gt_path = osp.join(gt_folder, gt_path)
224
- paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
225
- return paths
226
-
227
-
228
- def paths_from_folder(folder):
229
- """Generate paths from folder.
230
-
231
- Args:
232
- folder (str): Folder path.
233
-
234
- Returns:
235
- list[str]: Returned path list.
236
- """
237
-
238
- paths = list(scandir(folder))
239
- paths = [osp.join(folder, path) for path in paths]
240
- return paths
241
-
242
-
243
- def paths_from_lmdb(folder):
244
- """Generate paths from lmdb.
245
-
246
- Args:
247
- folder (str): Folder path.
248
-
249
- Returns:
250
- list[str]: Returned path list.
251
- """
252
- if not folder.endswith('.lmdb'):
253
- raise ValueError(f'Folder {folder}folder should in lmdb format.')
254
- with open(osp.join(folder, 'meta_info.txt')) as fin:
255
- paths = [line.split('.')[0] for line in fin]
256
- return paths
257
-
258
-
259
- def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
260
- """Generate Gaussian kernel used in `duf_downsample`.
261
-
262
- Args:
263
- kernel_size (int): Kernel size. Default: 13.
264
- sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
265
-
266
- Returns:
267
- np.array: The Gaussian kernel.
268
- """
269
- from scipy.ndimage import filters as filters
270
- kernel = np.zeros((kernel_size, kernel_size))
271
- # set element at the middle to one, a dirac delta
272
- kernel[kernel_size // 2, kernel_size // 2] = 1
273
- # gaussian-smooth the dirac, resulting in a gaussian filter
274
- return filters.gaussian_filter(kernel, sigma)
275
-
276
-
277
- def duf_downsample(x, kernel_size=13, scale=4):
278
- """Downsamping with Gaussian kernel used in the DUF official code.
279
-
280
- Args:
281
- x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
282
- kernel_size (int): Kernel size. Default: 13.
283
- scale (int): Downsampling factor. Supported scale: (2, 3, 4).
284
- Default: 4.
285
-
286
- Returns:
287
- Tensor: DUF downsampled frames.
288
- """
289
- assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
290
-
291
- squeeze_flag = False
292
- if x.ndim == 4:
293
- squeeze_flag = True
294
- x = x.unsqueeze(0)
295
- b, t, c, h, w = x.size()
296
- x = x.view(-1, 1, h, w)
297
- pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
298
- x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
299
-
300
- gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
301
- gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
302
- x = F.conv2d(x, gaussian_filter, stride=scale)
303
- x = x[:, :, 2:-2, 2:-2]
304
- x = x.view(b, t, c, x.size(2), x.size(3))
305
- if squeeze_flag:
306
- x = x.squeeze(0)
307
- return x
308
-
309
-
310
- def brush_stroke_mask(img, color=(255,255,255)):
311
- min_num_vertex = 8
312
- max_num_vertex = 28
313
- mean_angle = 2*math.pi / 5
314
- angle_range = 2*math.pi / 12
315
- # training large mask ratio (training setting)
316
- min_width = 30
317
- max_width = 70
318
- # very large mask ratio (test setting and refine after 200k)
319
- # min_width = 80
320
- # max_width = 120
321
- def generate_mask(H, W, img=None):
322
- average_radius = math.sqrt(H*H+W*W) / 8
323
- mask = Image.new('RGB', (W, H), 0)
324
- if img is not None: mask = img # Image.fromarray(img)
325
-
326
- for _ in range(np.random.randint(1, 4)):
327
- num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
328
- angle_min = mean_angle - np.random.uniform(0, angle_range)
329
- angle_max = mean_angle + np.random.uniform(0, angle_range)
330
- angles = []
331
- vertex = []
332
- for i in range(num_vertex):
333
- if i % 2 == 0:
334
- angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
335
- else:
336
- angles.append(np.random.uniform(angle_min, angle_max))
337
-
338
- h, w = mask.size
339
- vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
340
- for i in range(num_vertex):
341
- r = np.clip(
342
- np.random.normal(loc=average_radius, scale=average_radius//2),
343
- 0, 2*average_radius)
344
- new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
345
- new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
346
- vertex.append((int(new_x), int(new_y)))
347
-
348
- draw = ImageDraw.Draw(mask)
349
- width = int(np.random.uniform(min_width, max_width))
350
- draw.line(vertex, fill=color, width=width)
351
- for v in vertex:
352
- draw.ellipse((v[0] - width//2,
353
- v[1] - width//2,
354
- v[0] + width//2,
355
- v[1] + width//2),
356
- fill=color)
357
-
358
- return mask
359
-
360
- width, height = img.size
361
- mask = generate_mask(height, width, img)
362
- return mask
363
-
364
-
365
- def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
366
- """Generate a random free form mask with configuration.
367
- Args:
368
- config: Config should have configuration including IMG_SHAPES,
369
- VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
370
- Returns:
371
- tuple: (top, left, height, width)
372
- Link:
373
- https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
374
- """
375
- height = shape[0]
376
- width = shape[1]
377
- mask = np.zeros((height, width), np.float32)
378
- times = np.random.randint(times-5, times)
379
- for i in range(times):
380
- start_x = np.random.randint(width)
381
- start_y = np.random.randint(height)
382
- for j in range(1 + np.random.randint(5)):
383
- angle = 0.01 + np.random.randint(max_angle)
384
- if i % 2 == 0:
385
- angle = 2 * 3.1415926 - angle
386
- length = 10 + np.random.randint(max_len-20, max_len)
387
- brush_w = 5 + np.random.randint(max_width-30, max_width)
388
- end_x = (start_x + length * np.sin(angle)).astype(np.int32)
389
- end_y = (start_y + length * np.cos(angle)).astype(np.int32)
390
- cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
391
- start_x, start_y = end_x, end_y
392
- return mask.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/ffhq_blind_dataset.py DELETED
@@ -1,299 +0,0 @@
1
- import cv2
2
- import math
3
- import random
4
- import numpy as np
5
- import os.path as osp
6
- from scipy.io import loadmat
7
- from PIL import Image
8
- import torch
9
- import torch.utils.data as data
10
- from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
11
- adjust_hue, adjust_saturation, normalize)
12
- from basicsr.data import gaussian_kernels as gaussian_kernels
13
- from basicsr.data.transforms import augment
14
- from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
15
- from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
16
- from basicsr.utils.registry import DATASET_REGISTRY
17
-
18
- @DATASET_REGISTRY.register()
19
- class FFHQBlindDataset(data.Dataset):
20
-
21
- def __init__(self, opt):
22
- super(FFHQBlindDataset, self).__init__()
23
- logger = get_root_logger()
24
- self.opt = opt
25
- # file client (io backend)
26
- self.file_client = None
27
- self.io_backend_opt = opt['io_backend']
28
-
29
- self.gt_folder = opt['dataroot_gt']
30
- self.gt_size = opt.get('gt_size', 512)
31
- self.in_size = opt.get('in_size', 512)
32
- assert self.gt_size >= self.in_size, 'Wrong setting.'
33
-
34
- self.mean = opt.get('mean', [0.5, 0.5, 0.5])
35
- self.std = opt.get('std', [0.5, 0.5, 0.5])
36
-
37
- self.component_path = opt.get('component_path', None)
38
- self.latent_gt_path = opt.get('latent_gt_path', None)
39
-
40
- if self.component_path is not None:
41
- self.crop_components = True
42
- self.components_dict = torch.load(self.component_path)
43
- self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
44
- self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
45
- self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
46
- else:
47
- self.crop_components = False
48
-
49
- if self.latent_gt_path is not None:
50
- self.load_latent_gt = True
51
- self.latent_gt_dict = torch.load(self.latent_gt_path)
52
- else:
53
- self.load_latent_gt = False
54
-
55
- if self.io_backend_opt['type'] == 'lmdb':
56
- self.io_backend_opt['db_paths'] = self.gt_folder
57
- if not self.gt_folder.endswith('.lmdb'):
58
- raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
59
- with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
60
- self.paths = [line.split('.')[0] for line in fin]
61
- else:
62
- self.paths = paths_from_folder(self.gt_folder)
63
-
64
- # inpainting mask
65
- self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
66
- if self.gen_inpaint_mask:
67
- logger.info(f'generate mask ...')
68
- # self.mask_max_angle = opt.get('mask_max_angle', 10)
69
- # self.mask_max_len = opt.get('mask_max_len', 150)
70
- # self.mask_max_width = opt.get('mask_max_width', 50)
71
- # self.mask_draw_times = opt.get('mask_draw_times', 10)
72
- # # print
73
- # logger.info(f'mask_max_angle: {self.mask_max_angle}')
74
- # logger.info(f'mask_max_len: {self.mask_max_len}')
75
- # logger.info(f'mask_max_width: {self.mask_max_width}')
76
- # logger.info(f'mask_draw_times: {self.mask_draw_times}')
77
-
78
- # perform corrupt
79
- self.use_corrupt = opt.get('use_corrupt', True)
80
- self.use_motion_kernel = False
81
- # self.use_motion_kernel = opt.get('use_motion_kernel', True)
82
-
83
- if self.use_motion_kernel:
84
- self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
85
- motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
86
- self.motion_kernels = torch.load(motion_kernel_path)
87
-
88
- if self.use_corrupt and not self.gen_inpaint_mask:
89
- # degradation configurations
90
- self.blur_kernel_size = opt['blur_kernel_size']
91
- self.blur_sigma = opt['blur_sigma']
92
- self.kernel_list = opt['kernel_list']
93
- self.kernel_prob = opt['kernel_prob']
94
- self.downsample_range = opt['downsample_range']
95
- self.noise_range = opt['noise_range']
96
- self.jpeg_range = opt['jpeg_range']
97
- # print
98
- logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
99
- logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
100
- logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
101
- logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
102
-
103
- # color jitter
104
- self.color_jitter_prob = opt.get('color_jitter_prob', None)
105
- self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
106
- self.color_jitter_shift = opt.get('color_jitter_shift', 20)
107
- if self.color_jitter_prob is not None:
108
- logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
109
-
110
- # to gray
111
- self.gray_prob = opt.get('gray_prob', 0.0)
112
- if self.gray_prob is not None:
113
- logger.info(f'Use random gray. Prob: {self.gray_prob}')
114
- self.color_jitter_shift /= 255.
115
-
116
- @staticmethod
117
- def color_jitter(img, shift):
118
- """jitter color: randomly jitter the RGB values, in numpy formats"""
119
- jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
120
- img = img + jitter_val
121
- img = np.clip(img, 0, 1)
122
- return img
123
-
124
- @staticmethod
125
- def color_jitter_pt(img, brightness, contrast, saturation, hue):
126
- """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
127
- fn_idx = torch.randperm(4)
128
- for fn_id in fn_idx:
129
- if fn_id == 0 and brightness is not None:
130
- brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
131
- img = adjust_brightness(img, brightness_factor)
132
-
133
- if fn_id == 1 and contrast is not None:
134
- contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
135
- img = adjust_contrast(img, contrast_factor)
136
-
137
- if fn_id == 2 and saturation is not None:
138
- saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
139
- img = adjust_saturation(img, saturation_factor)
140
-
141
- if fn_id == 3 and hue is not None:
142
- hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
143
- img = adjust_hue(img, hue_factor)
144
- return img
145
-
146
-
147
- def get_component_locations(self, name, status):
148
- components_bbox = self.components_dict[name]
149
- if status[0]: # hflip
150
- # exchange right and left eye
151
- tmp = components_bbox['left_eye']
152
- components_bbox['left_eye'] = components_bbox['right_eye']
153
- components_bbox['right_eye'] = tmp
154
- # modify the width coordinate
155
- components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
156
- components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
157
- components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
158
- components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
159
-
160
- locations_gt = {}
161
- locations_in = {}
162
- for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
163
- mean = components_bbox[part][0:2]
164
- half_len = components_bbox[part][2]
165
- if 'eye' in part:
166
- half_len *= self.eye_enlarge_ratio
167
- elif part == 'nose':
168
- half_len *= self.nose_enlarge_ratio
169
- elif part == 'mouth':
170
- half_len *= self.mouth_enlarge_ratio
171
- loc = np.hstack((mean - half_len + 1, mean + half_len))
172
- loc = torch.from_numpy(loc).float()
173
- locations_gt[part] = loc
174
- loc_in = loc/(self.gt_size//self.in_size)
175
- locations_in[part] = loc_in
176
- return locations_gt, locations_in
177
-
178
-
179
- def __getitem__(self, index):
180
- if self.file_client is None:
181
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
182
-
183
- # load gt image
184
- gt_path = self.paths[index]
185
- name = osp.basename(gt_path)[:-4]
186
- img_bytes = self.file_client.get(gt_path)
187
- img_gt = imfrombytes(img_bytes, float32=True)
188
-
189
- # random horizontal flip
190
- img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
191
-
192
- if self.load_latent_gt:
193
- if status[0]:
194
- latent_gt = self.latent_gt_dict['hflip'][name]
195
- else:
196
- latent_gt = self.latent_gt_dict['orig'][name]
197
-
198
- if self.crop_components:
199
- locations_gt, locations_in = self.get_component_locations(name, status)
200
-
201
- # generate in image
202
- img_in = img_gt
203
- if self.use_corrupt and not self.gen_inpaint_mask:
204
- # motion blur
205
- if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
206
- m_i = random.randint(0,31)
207
- k = self.motion_kernels[f'{m_i:02d}']
208
- img_in = cv2.filter2D(img_in,-1,k)
209
-
210
- # gaussian blur
211
- kernel = gaussian_kernels.random_mixed_kernels(
212
- self.kernel_list,
213
- self.kernel_prob,
214
- self.blur_kernel_size,
215
- self.blur_sigma,
216
- self.blur_sigma,
217
- [-math.pi, math.pi],
218
- noise_range=None)
219
- img_in = cv2.filter2D(img_in, -1, kernel)
220
-
221
- # downsample
222
- scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
223
- img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
224
-
225
- # noise
226
- if self.noise_range is not None:
227
- noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
228
- noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
229
- img_in = img_in + noise
230
- img_in = np.clip(img_in, 0, 1)
231
-
232
- # jpeg
233
- if self.jpeg_range is not None:
234
- jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
235
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
236
- _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
237
- img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
238
-
239
- # resize to in_size
240
- img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
241
-
242
- # if self.gen_inpaint_mask:
243
- # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
244
- # max_angle = self.mask_max_angle, max_len = self.mask_max_len,
245
- # max_width = self.mask_max_width, times = self.mask_draw_times)
246
- # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
247
- # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
248
-
249
- # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
250
-
251
- if self.gen_inpaint_mask:
252
- img_in = (img_in*255).astype('uint8')
253
- img_in = brush_stroke_mask(Image.fromarray(img_in))
254
- img_in = np.array(img_in) / 255.
255
-
256
- # random color jitter (only for lq)
257
- if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
258
- img_in = self.color_jitter(img_in, self.color_jitter_shift)
259
- # random to gray (only for lq)
260
- if self.gray_prob and np.random.uniform() < self.gray_prob:
261
- img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
262
- img_in = np.tile(img_in[:, :, None], [1, 1, 3])
263
-
264
- # BGR to RGB, HWC to CHW, numpy to tensor
265
- img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
266
-
267
- # random color jitter (pytorch version) (only for lq)
268
- if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
269
- brightness = self.opt.get('brightness', (0.5, 1.5))
270
- contrast = self.opt.get('contrast', (0.5, 1.5))
271
- saturation = self.opt.get('saturation', (0, 1.5))
272
- hue = self.opt.get('hue', (-0.1, 0.1))
273
- img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
274
-
275
- # round and clip
276
- img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
277
-
278
- # Set vgg range_norm=True if use the normalization here
279
- # normalize
280
- normalize(img_in, self.mean, self.std, inplace=True)
281
- normalize(img_gt, self.mean, self.std, inplace=True)
282
-
283
- return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
284
-
285
- if self.crop_components:
286
- return_dict['locations_in'] = locations_in
287
- return_dict['locations_gt'] = locations_gt
288
-
289
- if self.load_latent_gt:
290
- return_dict['latent_gt'] = latent_gt
291
-
292
- # if self.gen_inpaint_mask:
293
- # return_dict['inpaint_mask'] = inpaint_mask
294
-
295
- return return_dict
296
-
297
-
298
- def __len__(self):
299
- return len(self.paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/ffhq_blind_joint_dataset.py DELETED
@@ -1,324 +0,0 @@
1
- import cv2
2
- import math
3
- import random
4
- import numpy as np
5
- import os.path as osp
6
- from scipy.io import loadmat
7
- import torch
8
- import torch.utils.data as data
9
- from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
10
- adjust_hue, adjust_saturation, normalize)
11
- from basicsr.data import gaussian_kernels as gaussian_kernels
12
- from basicsr.data.transforms import augment
13
- from basicsr.data.data_util import paths_from_folder
14
- from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
15
- from basicsr.utils.registry import DATASET_REGISTRY
16
-
17
- @DATASET_REGISTRY.register()
18
- class FFHQBlindJointDataset(data.Dataset):
19
-
20
- def __init__(self, opt):
21
- super(FFHQBlindJointDataset, self).__init__()
22
- logger = get_root_logger()
23
- self.opt = opt
24
- # file client (io backend)
25
- self.file_client = None
26
- self.io_backend_opt = opt['io_backend']
27
-
28
- self.gt_folder = opt['dataroot_gt']
29
- self.gt_size = opt.get('gt_size', 512)
30
- self.in_size = opt.get('in_size', 512)
31
- assert self.gt_size >= self.in_size, 'Wrong setting.'
32
-
33
- self.mean = opt.get('mean', [0.5, 0.5, 0.5])
34
- self.std = opt.get('std', [0.5, 0.5, 0.5])
35
-
36
- self.component_path = opt.get('component_path', None)
37
- self.latent_gt_path = opt.get('latent_gt_path', None)
38
-
39
- if self.component_path is not None:
40
- self.crop_components = True
41
- self.components_dict = torch.load(self.component_path)
42
- self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
43
- self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
44
- self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
45
- else:
46
- self.crop_components = False
47
-
48
- if self.latent_gt_path is not None:
49
- self.load_latent_gt = True
50
- self.latent_gt_dict = torch.load(self.latent_gt_path)
51
- else:
52
- self.load_latent_gt = False
53
-
54
- if self.io_backend_opt['type'] == 'lmdb':
55
- self.io_backend_opt['db_paths'] = self.gt_folder
56
- if not self.gt_folder.endswith('.lmdb'):
57
- raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
58
- with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
59
- self.paths = [line.split('.')[0] for line in fin]
60
- else:
61
- self.paths = paths_from_folder(self.gt_folder)
62
-
63
- # perform corrupt
64
- self.use_corrupt = opt.get('use_corrupt', True)
65
- self.use_motion_kernel = False
66
- # self.use_motion_kernel = opt.get('use_motion_kernel', True)
67
-
68
- if self.use_motion_kernel:
69
- self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
70
- motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
71
- self.motion_kernels = torch.load(motion_kernel_path)
72
-
73
- if self.use_corrupt:
74
- # degradation configurations
75
- self.blur_kernel_size = self.opt['blur_kernel_size']
76
- self.kernel_list = self.opt['kernel_list']
77
- self.kernel_prob = self.opt['kernel_prob']
78
- # Small degradation
79
- self.blur_sigma = self.opt['blur_sigma']
80
- self.downsample_range = self.opt['downsample_range']
81
- self.noise_range = self.opt['noise_range']
82
- self.jpeg_range = self.opt['jpeg_range']
83
- # Large degradation
84
- self.blur_sigma_large = self.opt['blur_sigma_large']
85
- self.downsample_range_large = self.opt['downsample_range_large']
86
- self.noise_range_large = self.opt['noise_range_large']
87
- self.jpeg_range_large = self.opt['jpeg_range_large']
88
-
89
- # print
90
- logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
91
- logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
92
- logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
93
- logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
94
-
95
- # color jitter
96
- self.color_jitter_prob = opt.get('color_jitter_prob', None)
97
- self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
98
- self.color_jitter_shift = opt.get('color_jitter_shift', 20)
99
- if self.color_jitter_prob is not None:
100
- logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
101
-
102
- # to gray
103
- self.gray_prob = opt.get('gray_prob', 0.0)
104
- if self.gray_prob is not None:
105
- logger.info(f'Use random gray. Prob: {self.gray_prob}')
106
- self.color_jitter_shift /= 255.
107
-
108
- @staticmethod
109
- def color_jitter(img, shift):
110
- """jitter color: randomly jitter the RGB values, in numpy formats"""
111
- jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
112
- img = img + jitter_val
113
- img = np.clip(img, 0, 1)
114
- return img
115
-
116
- @staticmethod
117
- def color_jitter_pt(img, brightness, contrast, saturation, hue):
118
- """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
119
- fn_idx = torch.randperm(4)
120
- for fn_id in fn_idx:
121
- if fn_id == 0 and brightness is not None:
122
- brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
123
- img = adjust_brightness(img, brightness_factor)
124
-
125
- if fn_id == 1 and contrast is not None:
126
- contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
127
- img = adjust_contrast(img, contrast_factor)
128
-
129
- if fn_id == 2 and saturation is not None:
130
- saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
131
- img = adjust_saturation(img, saturation_factor)
132
-
133
- if fn_id == 3 and hue is not None:
134
- hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
135
- img = adjust_hue(img, hue_factor)
136
- return img
137
-
138
-
139
- def get_component_locations(self, name, status):
140
- components_bbox = self.components_dict[name]
141
- if status[0]: # hflip
142
- # exchange right and left eye
143
- tmp = components_bbox['left_eye']
144
- components_bbox['left_eye'] = components_bbox['right_eye']
145
- components_bbox['right_eye'] = tmp
146
- # modify the width coordinate
147
- components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
148
- components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
149
- components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
150
- components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
151
-
152
- locations_gt = {}
153
- locations_in = {}
154
- for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
155
- mean = components_bbox[part][0:2]
156
- half_len = components_bbox[part][2]
157
- if 'eye' in part:
158
- half_len *= self.eye_enlarge_ratio
159
- elif part == 'nose':
160
- half_len *= self.nose_enlarge_ratio
161
- elif part == 'mouth':
162
- half_len *= self.mouth_enlarge_ratio
163
- loc = np.hstack((mean - half_len + 1, mean + half_len))
164
- loc = torch.from_numpy(loc).float()
165
- locations_gt[part] = loc
166
- loc_in = loc/(self.gt_size//self.in_size)
167
- locations_in[part] = loc_in
168
- return locations_gt, locations_in
169
-
170
-
171
- def __getitem__(self, index):
172
- if self.file_client is None:
173
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
174
-
175
- # load gt image
176
- gt_path = self.paths[index]
177
- name = osp.basename(gt_path)[:-4]
178
- img_bytes = self.file_client.get(gt_path)
179
- img_gt = imfrombytes(img_bytes, float32=True)
180
-
181
- # random horizontal flip
182
- img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
183
-
184
- if self.load_latent_gt:
185
- if status[0]:
186
- latent_gt = self.latent_gt_dict['hflip'][name]
187
- else:
188
- latent_gt = self.latent_gt_dict['orig'][name]
189
-
190
- if self.crop_components:
191
- locations_gt, locations_in = self.get_component_locations(name, status)
192
-
193
- # generate in image
194
- img_in = img_gt
195
- if self.use_corrupt:
196
- # motion blur
197
- if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
198
- m_i = random.randint(0,31)
199
- k = self.motion_kernels[f'{m_i:02d}']
200
- img_in = cv2.filter2D(img_in,-1,k)
201
-
202
- # gaussian blur
203
- kernel = gaussian_kernels.random_mixed_kernels(
204
- self.kernel_list,
205
- self.kernel_prob,
206
- self.blur_kernel_size,
207
- self.blur_sigma,
208
- self.blur_sigma,
209
- [-math.pi, math.pi],
210
- noise_range=None)
211
- img_in = cv2.filter2D(img_in, -1, kernel)
212
-
213
- # downsample
214
- scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
215
- img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
216
-
217
- # noise
218
- if self.noise_range is not None:
219
- noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
220
- noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
221
- img_in = img_in + noise
222
- img_in = np.clip(img_in, 0, 1)
223
-
224
- # jpeg
225
- if self.jpeg_range is not None:
226
- jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
227
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
228
- _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
229
- img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
230
-
231
- # resize to in_size
232
- img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
233
-
234
-
235
- # generate in_large with large degradation
236
- img_in_large = img_gt
237
-
238
- if self.use_corrupt:
239
- # motion blur
240
- if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
241
- m_i = random.randint(0,31)
242
- k = self.motion_kernels[f'{m_i:02d}']
243
- img_in_large = cv2.filter2D(img_in_large,-1,k)
244
-
245
- # gaussian blur
246
- kernel = gaussian_kernels.random_mixed_kernels(
247
- self.kernel_list,
248
- self.kernel_prob,
249
- self.blur_kernel_size,
250
- self.blur_sigma_large,
251
- self.blur_sigma_large,
252
- [-math.pi, math.pi],
253
- noise_range=None)
254
- img_in_large = cv2.filter2D(img_in_large, -1, kernel)
255
-
256
- # downsample
257
- scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
258
- img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
259
-
260
- # noise
261
- if self.noise_range_large is not None:
262
- noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
263
- noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
264
- img_in_large = img_in_large + noise
265
- img_in_large = np.clip(img_in_large, 0, 1)
266
-
267
- # jpeg
268
- if self.jpeg_range_large is not None:
269
- jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
270
- encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
271
- _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
272
- img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.
273
-
274
- # resize to in_size
275
- img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
276
-
277
-
278
- # random color jitter (only for lq)
279
- if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
280
- img_in = self.color_jitter(img_in, self.color_jitter_shift)
281
- img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
282
- # random to gray (only for lq)
283
- if self.gray_prob and np.random.uniform() < self.gray_prob:
284
- img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
285
- img_in = np.tile(img_in[:, :, None], [1, 1, 3])
286
- img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
287
- img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])
288
-
289
- # BGR to RGB, HWC to CHW, numpy to tensor
290
- img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)
291
-
292
- # random color jitter (pytorch version) (only for lq)
293
- if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
294
- brightness = self.opt.get('brightness', (0.5, 1.5))
295
- contrast = self.opt.get('contrast', (0.5, 1.5))
296
- saturation = self.opt.get('saturation', (0, 1.5))
297
- hue = self.opt.get('hue', (-0.1, 0.1))
298
- img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
299
- img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)
300
-
301
- # round and clip
302
- img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
303
- img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.
304
-
305
- # Set vgg range_norm=True if use the normalization here
306
- # normalize
307
- normalize(img_in, self.mean, self.std, inplace=True)
308
- normalize(img_in_large, self.mean, self.std, inplace=True)
309
- normalize(img_gt, self.mean, self.std, inplace=True)
310
-
311
- return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}
312
-
313
- if self.crop_components:
314
- return_dict['locations_in'] = locations_in
315
- return_dict['locations_gt'] = locations_gt
316
-
317
- if self.load_latent_gt:
318
- return_dict['latent_gt'] = latent_gt
319
-
320
- return return_dict
321
-
322
-
323
- def __len__(self):
324
- return len(self.paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/gaussian_kernels.py DELETED
@@ -1,690 +0,0 @@
1
- import math
2
- import numpy as np
3
- import random
4
- from scipy.ndimage.interpolation import shift
5
- from scipy.stats import multivariate_normal
6
-
7
-
8
- def sigma_matrix2(sig_x, sig_y, theta):
9
- """Calculate the rotated sigma matrix (two dimensional matrix).
10
- Args:
11
- sig_x (float):
12
- sig_y (float):
13
- theta (float): Radian measurement.
14
- Returns:
15
- ndarray: Rotated sigma matrix.
16
- """
17
- D = np.array([[sig_x**2, 0], [0, sig_y**2]])
18
- U = np.array([[np.cos(theta), -np.sin(theta)],
19
- [np.sin(theta), np.cos(theta)]])
20
- return np.dot(U, np.dot(D, U.T))
21
-
22
-
23
- def mesh_grid(kernel_size):
24
- """Generate the mesh grid, centering at zero.
25
- Args:
26
- kernel_size (int):
27
- Returns:
28
- xy (ndarray): with the shape (kernel_size, kernel_size, 2)
29
- xx (ndarray): with the shape (kernel_size, kernel_size)
30
- yy (ndarray): with the shape (kernel_size, kernel_size)
31
- """
32
- ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
33
- xx, yy = np.meshgrid(ax, ax)
34
- xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
35
- yy.reshape(kernel_size * kernel_size,
36
- 1))).reshape(kernel_size, kernel_size, 2)
37
- return xy, xx, yy
38
-
39
-
40
- def pdf2(sigma_matrix, grid):
41
- """Calculate PDF of the bivariate Gaussian distribution.
42
- Args:
43
- sigma_matrix (ndarray): with the shape (2, 2)
44
- grid (ndarray): generated by :func:`mesh_grid`,
45
- with the shape (K, K, 2), K is the kernel size.
46
- Returns:
47
- kernel (ndarrray): un-normalized kernel.
48
- """
49
- inverse_sigma = np.linalg.inv(sigma_matrix)
50
- kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
51
- return kernel
52
-
53
-
54
- def cdf2(D, grid):
55
- """Calculate the CDF of the standard bivariate Gaussian distribution.
56
- Used in skewed Gaussian distribution.
57
- Args:
58
- D (ndarrasy): skew matrix.
59
- grid (ndarray): generated by :func:`mesh_grid`,
60
- with the shape (K, K, 2), K is the kernel size.
61
- Returns:
62
- cdf (ndarray): skewed cdf.
63
- """
64
- rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
65
- grid = np.dot(grid, D)
66
- cdf = rv.cdf(grid)
67
- return cdf
68
-
69
-
70
- def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
71
- """Generate a bivariate skew Gaussian kernel.
72
- Described in `A multivariate skew normal distribution`_ by Shi et. al (2004).
73
- Args:
74
- kernel_size (int):
75
- sig_x (float):
76
- sig_y (float):
77
- theta (float): Radian measurement.
78
- D (ndarrasy): skew matrix.
79
- grid (ndarray, optional): generated by :func:`mesh_grid`,
80
- with the shape (K, K, 2), K is the kernel size. Default: None
81
- Returns:
82
- kernel (ndarray): normalized kernel.
83
- .. _A multivariate skew normal distribution:
84
- https://www.sciencedirect.com/science/article/pii/S0047259X03001313
85
- """
86
- if grid is None:
87
- grid, _, _ = mesh_grid(kernel_size)
88
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
89
- pdf = pdf2(sigma_matrix, grid)
90
- cdf = cdf2(D, grid)
91
- kernel = pdf * cdf
92
- kernel = kernel / np.sum(kernel)
93
- return kernel
94
-
95
-
96
- def mass_center_shift(kernel_size, kernel):
97
- """Calculate the shift of the mass center of a kenrel.
98
- Args:
99
- kernel_size (int):
100
- kernel (ndarray): normalized kernel.
101
- Returns:
102
- delta_h (float):
103
- delta_w (float):
104
- """
105
- ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
106
- col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
107
- delta_h = np.dot(row_sum, ax)
108
- delta_w = np.dot(col_sum, ax)
109
- return delta_h, delta_w
110
-
111
-
112
- def bivariate_skew_Gaussian_center(kernel_size,
113
- sig_x,
114
- sig_y,
115
- theta,
116
- D,
117
- grid=None):
118
- """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
119
- Args:
120
- kernel_size (int):
121
- sig_x (float):
122
- sig_y (float):
123
- theta (float): Radian measurement.
124
- D (ndarrasy): skew matrix.
125
- grid (ndarray, optional): generated by :func:`mesh_grid`,
126
- with the shape (K, K, 2), K is the kernel size. Default: None
127
- Returns:
128
- kernel (ndarray): centered and normalized kernel.
129
- """
130
- if grid is None:
131
- grid, _, _ = mesh_grid(kernel_size)
132
- kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
133
- delta_h, delta_w = mass_center_shift(kernel_size, kernel)
134
- kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
135
- kernel = kernel / np.sum(kernel)
136
- return kernel
137
-
138
-
139
- def bivariate_anisotropic_Gaussian(kernel_size,
140
- sig_x,
141
- sig_y,
142
- theta,
143
- grid=None):
144
- """Generate a bivariate anisotropic Gaussian kernel.
145
- Args:
146
- kernel_size (int):
147
- sig_x (float):
148
- sig_y (float):
149
- theta (float): Radian measurement.
150
- grid (ndarray, optional): generated by :func:`mesh_grid`,
151
- with the shape (K, K, 2), K is the kernel size. Default: None
152
- Returns:
153
- kernel (ndarray): normalized kernel.
154
- """
155
- if grid is None:
156
- grid, _, _ = mesh_grid(kernel_size)
157
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
158
- kernel = pdf2(sigma_matrix, grid)
159
- kernel = kernel / np.sum(kernel)
160
- return kernel
161
-
162
-
163
- def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
164
- """Generate a bivariate isotropic Gaussian kernel.
165
- Args:
166
- kernel_size (int):
167
- sig (float):
168
- grid (ndarray, optional): generated by :func:`mesh_grid`,
169
- with the shape (K, K, 2), K is the kernel size. Default: None
170
- Returns:
171
- kernel (ndarray): normalized kernel.
172
- """
173
- if grid is None:
174
- grid, _, _ = mesh_grid(kernel_size)
175
- sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
176
- kernel = pdf2(sigma_matrix, grid)
177
- kernel = kernel / np.sum(kernel)
178
- return kernel
179
-
180
-
181
- def bivariate_generalized_Gaussian(kernel_size,
182
- sig_x,
183
- sig_y,
184
- theta,
185
- beta,
186
- grid=None):
187
- """Generate a bivariate generalized Gaussian kernel.
188
- Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
189
- by Pascal et. al (2013).
190
- Args:
191
- kernel_size (int):
192
- sig_x (float):
193
- sig_y (float):
194
- theta (float): Radian measurement.
195
- beta (float): shape parameter, beta = 1 is the normal distribution.
196
- grid (ndarray, optional): generated by :func:`mesh_grid`,
197
- with the shape (K, K, 2), K is the kernel size. Default: None
198
- Returns:
199
- kernel (ndarray): normalized kernel.
200
- .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
201
- https://arxiv.org/abs/1302.6498
202
- """
203
- if grid is None:
204
- grid, _, _ = mesh_grid(kernel_size)
205
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
206
- inverse_sigma = np.linalg.inv(sigma_matrix)
207
- kernel = np.exp(
208
- -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
209
- kernel = kernel / np.sum(kernel)
210
- return kernel
211
-
212
-
213
- def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
214
- """Generate a plateau-like anisotropic kernel.
215
- 1 / (1+x^(beta))
216
- Args:
217
- kernel_size (int):
218
- sig_x (float):
219
- sig_y (float):
220
- theta (float): Radian measurement.
221
- beta (float): shape parameter, beta = 1 is the normal distribution.
222
- grid (ndarray, optional): generated by :func:`mesh_grid`,
223
- with the shape (K, K, 2), K is the kernel size. Default: None
224
- Returns:
225
- kernel (ndarray): normalized kernel.
226
- """
227
- if grid is None:
228
- grid, _, _ = mesh_grid(kernel_size)
229
- sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
230
- inverse_sigma = np.linalg.inv(sigma_matrix)
231
- kernel = np.reciprocal(
232
- np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
233
- kernel = kernel / np.sum(kernel)
234
- return kernel
235
-
236
-
237
- def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
238
- """Generate a plateau-like isotropic kernel.
239
- 1 / (1+x^(beta))
240
- Args:
241
- kernel_size (int):
242
- sig (float):
243
- beta (float): shape parameter, beta = 1 is the normal distribution.
244
- grid (ndarray, optional): generated by :func:`mesh_grid`,
245
- with the shape (K, K, 2), K is the kernel size. Default: None
246
- Returns:
247
- kernel (ndarray): normalized kernel.
248
- """
249
- if grid is None:
250
- grid, _, _ = mesh_grid(kernel_size)
251
- sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
252
- inverse_sigma = np.linalg.inv(sigma_matrix)
253
- kernel = np.reciprocal(
254
- np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
255
- kernel = kernel / np.sum(kernel)
256
- return kernel
257
-
258
-
259
- def random_bivariate_skew_Gaussian_center(kernel_size,
260
- sigma_x_range,
261
- sigma_y_range,
262
- rotation_range,
263
- noise_range=None,
264
- strict=False):
265
- """Randomly generate bivariate skew Gaussian kernels at center.
266
- Args:
267
- kernel_size (int):
268
- sigma_x_range (tuple): [0.6, 5]
269
- sigma_y_range (tuple): [0.6, 5]
270
- rotation range (tuple): [-math.pi, math.pi]
271
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
272
- Returns:
273
- kernel (ndarray):
274
- """
275
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
276
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
277
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
278
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
279
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
280
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
281
- if strict:
282
- sigma_max = np.max([sigma_x, sigma_y])
283
- sigma_min = np.min([sigma_x, sigma_y])
284
- sigma_x, sigma_y = sigma_max, sigma_min
285
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
286
-
287
- sigma_max = np.max([sigma_x, sigma_y])
288
- thres = 3 / sigma_max
289
- D = [[np.random.uniform(-thres, thres),
290
- np.random.uniform(-thres, thres)],
291
- [np.random.uniform(-thres, thres),
292
- np.random.uniform(-thres, thres)]]
293
-
294
- kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
295
- rotation, D)
296
-
297
- # add multiplicative noise
298
- if noise_range is not None:
299
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
300
- noise = np.random.uniform(
301
- noise_range[0], noise_range[1], size=kernel.shape)
302
- kernel = kernel * noise
303
- kernel = kernel / np.sum(kernel)
304
- if strict:
305
- return kernel, sigma_x, sigma_y, rotation, D
306
- else:
307
- return kernel
308
-
309
-
310
- def random_bivariate_anisotropic_Gaussian(kernel_size,
311
- sigma_x_range,
312
- sigma_y_range,
313
- rotation_range,
314
- noise_range=None,
315
- strict=False):
316
- """Randomly generate bivariate anisotropic Gaussian kernels.
317
- Args:
318
- kernel_size (int):
319
- sigma_x_range (tuple): [0.6, 5]
320
- sigma_y_range (tuple): [0.6, 5]
321
- rotation range (tuple): [-math.pi, math.pi]
322
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
323
- Returns:
324
- kernel (ndarray):
325
- """
326
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
327
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
328
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
329
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
330
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
331
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
332
- if strict:
333
- sigma_max = np.max([sigma_x, sigma_y])
334
- sigma_min = np.min([sigma_x, sigma_y])
335
- sigma_x, sigma_y = sigma_max, sigma_min
336
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
337
-
338
- kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
339
- rotation)
340
-
341
- # add multiplicative noise
342
- if noise_range is not None:
343
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
344
- noise = np.random.uniform(
345
- noise_range[0], noise_range[1], size=kernel.shape)
346
- kernel = kernel * noise
347
- kernel = kernel / np.sum(kernel)
348
- if strict:
349
- return kernel, sigma_x, sigma_y, rotation
350
- else:
351
- return kernel
352
-
353
-
354
- def random_bivariate_isotropic_Gaussian(kernel_size,
355
- sigma_range,
356
- noise_range=None,
357
- strict=False):
358
- """Randomly generate bivariate isotropic Gaussian kernels.
359
- Args:
360
- kernel_size (int):
361
- sigma_range (tuple): [0.6, 5]
362
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
363
- Returns:
364
- kernel (ndarray):
365
- """
366
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
367
- assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
368
- sigma = np.random.uniform(sigma_range[0], sigma_range[1])
369
-
370
- kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
371
-
372
- # add multiplicative noise
373
- if noise_range is not None:
374
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
375
- noise = np.random.uniform(
376
- noise_range[0], noise_range[1], size=kernel.shape)
377
- kernel = kernel * noise
378
- kernel = kernel / np.sum(kernel)
379
- if strict:
380
- return kernel, sigma
381
- else:
382
- return kernel
383
-
384
-
385
- def random_bivariate_generalized_Gaussian(kernel_size,
386
- sigma_x_range,
387
- sigma_y_range,
388
- rotation_range,
389
- beta_range,
390
- noise_range=None,
391
- strict=False):
392
- """Randomly generate bivariate generalized Gaussian kernels.
393
- Args:
394
- kernel_size (int):
395
- sigma_x_range (tuple): [0.6, 5]
396
- sigma_y_range (tuple): [0.6, 5]
397
- rotation range (tuple): [-math.pi, math.pi]
398
- beta_range (tuple): [0.5, 8]
399
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
400
- Returns:
401
- kernel (ndarray):
402
- """
403
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
404
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
405
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
406
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
407
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
408
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
409
- if strict:
410
- sigma_max = np.max([sigma_x, sigma_y])
411
- sigma_min = np.min([sigma_x, sigma_y])
412
- sigma_x, sigma_y = sigma_max, sigma_min
413
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
414
- if np.random.uniform() < 0.5:
415
- beta = np.random.uniform(beta_range[0], 1)
416
- else:
417
- beta = np.random.uniform(1, beta_range[1])
418
-
419
- kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
420
- rotation, beta)
421
-
422
- # add multiplicative noise
423
- if noise_range is not None:
424
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
425
- noise = np.random.uniform(
426
- noise_range[0], noise_range[1], size=kernel.shape)
427
- kernel = kernel * noise
428
- kernel = kernel / np.sum(kernel)
429
- if strict:
430
- return kernel, sigma_x, sigma_y, rotation, beta
431
- else:
432
- return kernel
433
-
434
-
435
- def random_bivariate_plateau_type1(kernel_size,
436
- sigma_x_range,
437
- sigma_y_range,
438
- rotation_range,
439
- beta_range,
440
- noise_range=None,
441
- strict=False):
442
- """Randomly generate bivariate plateau type1 kernels.
443
- Args:
444
- kernel_size (int):
445
- sigma_x_range (tuple): [0.6, 5]
446
- sigma_y_range (tuple): [0.6, 5]
447
- rotation range (tuple): [-math.pi/2, math.pi/2]
448
- beta_range (tuple): [1, 4]
449
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
450
- Returns:
451
- kernel (ndarray):
452
- """
453
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
454
- assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
455
- assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
456
- assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
457
- sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
458
- sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
459
- if strict:
460
- sigma_max = np.max([sigma_x, sigma_y])
461
- sigma_min = np.min([sigma_x, sigma_y])
462
- sigma_x, sigma_y = sigma_max, sigma_min
463
- rotation = np.random.uniform(rotation_range[0], rotation_range[1])
464
- if np.random.uniform() < 0.5:
465
- beta = np.random.uniform(beta_range[0], 1)
466
- else:
467
- beta = np.random.uniform(1, beta_range[1])
468
-
469
- kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
470
- beta)
471
-
472
- # add multiplicative noise
473
- if noise_range is not None:
474
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
475
- noise = np.random.uniform(
476
- noise_range[0], noise_range[1], size=kernel.shape)
477
- kernel = kernel * noise
478
- kernel = kernel / np.sum(kernel)
479
- if strict:
480
- return kernel, sigma_x, sigma_y, rotation, beta
481
- else:
482
- return kernel
483
-
484
-
485
- def random_bivariate_plateau_type1_iso(kernel_size,
486
- sigma_range,
487
- beta_range,
488
- noise_range=None,
489
- strict=False):
490
- """Randomly generate bivariate plateau type1 kernels (iso).
491
- Args:
492
- kernel_size (int):
493
- sigma_range (tuple): [0.6, 5]
494
- beta_range (tuple): [1, 4]
495
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
496
- Returns:
497
- kernel (ndarray):
498
- """
499
- assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
500
- assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
501
- sigma = np.random.uniform(sigma_range[0], sigma_range[1])
502
- beta = np.random.uniform(beta_range[0], beta_range[1])
503
-
504
- kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
505
-
506
- # add multiplicative noise
507
- if noise_range is not None:
508
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
509
- noise = np.random.uniform(
510
- noise_range[0], noise_range[1], size=kernel.shape)
511
- kernel = kernel * noise
512
- kernel = kernel / np.sum(kernel)
513
- if strict:
514
- return kernel, sigma, beta
515
- else:
516
- return kernel
517
-
518
-
519
- def random_mixed_kernels(kernel_list,
520
- kernel_prob,
521
- kernel_size=21,
522
- sigma_x_range=[0.6, 5],
523
- sigma_y_range=[0.6, 5],
524
- rotation_range=[-math.pi, math.pi],
525
- beta_range=[0.5, 8],
526
- noise_range=None):
527
- """Randomly generate mixed kernels.
528
- Args:
529
- kernel_list (tuple): a list name of kenrel types,
530
- support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
531
- kernel_prob (tuple): corresponding kernel probability for each kernel type
532
- kernel_size (int):
533
- sigma_x_range (tuple): [0.6, 5]
534
- sigma_y_range (tuple): [0.6, 5]
535
- rotation range (tuple): [-math.pi, math.pi]
536
- beta_range (tuple): [0.5, 8]
537
- noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
538
- Returns:
539
- kernel (ndarray):
540
- """
541
- kernel_type = random.choices(kernel_list, kernel_prob)[0]
542
- if kernel_type == 'iso':
543
- kernel = random_bivariate_isotropic_Gaussian(
544
- kernel_size, sigma_x_range, noise_range=noise_range)
545
- elif kernel_type == 'aniso':
546
- kernel = random_bivariate_anisotropic_Gaussian(
547
- kernel_size,
548
- sigma_x_range,
549
- sigma_y_range,
550
- rotation_range,
551
- noise_range=noise_range)
552
- elif kernel_type == 'skew':
553
- kernel = random_bivariate_skew_Gaussian_center(
554
- kernel_size,
555
- sigma_x_range,
556
- sigma_y_range,
557
- rotation_range,
558
- noise_range=noise_range)
559
- elif kernel_type == 'generalized':
560
- kernel = random_bivariate_generalized_Gaussian(
561
- kernel_size,
562
- sigma_x_range,
563
- sigma_y_range,
564
- rotation_range,
565
- beta_range,
566
- noise_range=noise_range)
567
- elif kernel_type == 'plateau_iso':
568
- kernel = random_bivariate_plateau_type1_iso(
569
- kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
570
- elif kernel_type == 'plateau_aniso':
571
- kernel = random_bivariate_plateau_type1(
572
- kernel_size,
573
- sigma_x_range,
574
- sigma_y_range,
575
- rotation_range,
576
- beta_range,
577
- noise_range=noise_range)
578
- # add multiplicative noise
579
- if noise_range is not None:
580
- assert noise_range[0] < noise_range[1], 'Wrong noise range.'
581
- noise = np.random.uniform(
582
- noise_range[0], noise_range[1], size=kernel.shape)
583
- kernel = kernel * noise
584
- kernel = kernel / np.sum(kernel)
585
- return kernel
586
-
587
-
588
- def show_one_kernel():
589
- import matplotlib.pyplot as plt
590
- kernel_size = 21
591
-
592
- # bivariate skew Gaussian
593
- D = [[0, 0], [0, 0]]
594
- D = [[3 / 4, 0], [0, 0.5]]
595
- kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
596
- # bivariate anisotropic Gaussian
597
- kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
598
- # bivariate anisotropic Gaussian
599
- kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
600
- # bivariate generalized Gaussian
601
- kernel = bivariate_generalized_Gaussian(
602
- kernel_size, 2, 4, -math.pi / 4, beta=4)
603
-
604
- delta_h, delta_w = mass_center_shift(kernel_size, kernel)
605
- print(delta_h, delta_w)
606
-
607
- fig, axs = plt.subplots(nrows=2, ncols=2)
608
- # axs.set_axis_off()
609
- ax = axs[0][0]
610
- im = ax.matshow(kernel, cmap='jet', origin='upper')
611
- fig.colorbar(im, ax=ax)
612
-
613
- # image
614
- ax = axs[0][1]
615
- kernel_vis = kernel - np.min(kernel)
616
- kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
617
- ax.imshow(kernel_vis, interpolation='nearest')
618
-
619
- _, xx, yy = mesh_grid(kernel_size)
620
- # contour
621
- ax = axs[1][0]
622
- CS = ax.contour(xx, yy, kernel, origin='upper')
623
- ax.clabel(CS, inline=1, fontsize=3)
624
-
625
- # contourf
626
- ax = axs[1][1]
627
- kernel = kernel / np.max(kernel)
628
- p = ax.contourf(
629
- xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
630
- fig.colorbar(p)
631
-
632
- plt.show()
633
-
634
-
635
- def show_plateau_kernel():
636
- import matplotlib.pyplot as plt
637
- kernel_size = 21
638
-
639
- kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
640
- kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
641
- kernel_gau = bivariate_generalized_Gaussian(
642
- kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
643
- delta_h, delta_w = mass_center_shift(kernel_size, kernel)
644
- print(delta_h, delta_w)
645
-
646
- # kernel_slice = kernel[10, :]
647
- # kernel_gau_slice = kernel_gau[10, :]
648
- # kernel_norm_slice = kernel_norm[10, :]
649
- # fig, ax = plt.subplots()
650
- # t = list(range(1, 22))
651
-
652
- # ax.plot(t, kernel_gau_slice)
653
- # ax.plot(t, kernel_slice)
654
- # ax.plot(t, kernel_norm_slice)
655
-
656
- # t = np.arange(0, 10, 0.1)
657
- # y = np.exp(-0.5 * t)
658
- # y2 = np.reciprocal(1 + t)
659
- # print(t.shape)
660
- # print(y.shape)
661
- # ax.plot(t, y)
662
- # ax.plot(t, y2)
663
- # plt.show()
664
-
665
- fig, axs = plt.subplots(nrows=2, ncols=2)
666
- # axs.set_axis_off()
667
- ax = axs[0][0]
668
- im = ax.matshow(kernel, cmap='jet', origin='upper')
669
- fig.colorbar(im, ax=ax)
670
-
671
- # image
672
- ax = axs[0][1]
673
- kernel_vis = kernel - np.min(kernel)
674
- kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
675
- ax.imshow(kernel_vis, interpolation='nearest')
676
-
677
- _, xx, yy = mesh_grid(kernel_size)
678
- # contour
679
- ax = axs[1][0]
680
- CS = ax.contour(xx, yy, kernel, origin='upper')
681
- ax.clabel(CS, inline=1, fontsize=3)
682
-
683
- # contourf
684
- ax = axs[1][1]
685
- kernel = kernel / np.max(kernel)
686
- p = ax.contourf(
687
- xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
688
- fig.colorbar(p)
689
-
690
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/paired_image_dataset.py DELETED
@@ -1,101 +0,0 @@
1
- from torch.utils import data as data
2
- from torchvision.transforms.functional import normalize
3
-
4
- from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5
- from basicsr.data.transforms import augment, paired_random_crop
6
- from basicsr.utils import FileClient, imfrombytes, img2tensor
7
- from basicsr.utils.registry import DATASET_REGISTRY
8
-
9
-
10
- @DATASET_REGISTRY.register()
11
- class PairedImageDataset(data.Dataset):
12
- """Paired image dataset for image restoration.
13
-
14
- Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
15
- GT image pairs.
16
-
17
- There are three modes:
18
- 1. 'lmdb': Use lmdb files.
19
- If opt['io_backend'] == lmdb.
20
- 2. 'meta_info_file': Use meta information file to generate paths.
21
- If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
22
- 3. 'folder': Scan folders to generate paths.
23
- The rest.
24
-
25
- Args:
26
- opt (dict): Config for train datasets. It contains the following keys:
27
- dataroot_gt (str): Data root path for gt.
28
- dataroot_lq (str): Data root path for lq.
29
- meta_info_file (str): Path for meta information file.
30
- io_backend (dict): IO backend type and other kwarg.
31
- filename_tmpl (str): Template for each filename. Note that the
32
- template excludes the file extension. Default: '{}'.
33
- gt_size (int): Cropped patched size for gt patches.
34
- use_flip (bool): Use horizontal flips.
35
- use_rot (bool): Use rotation (use vertical flip and transposing h
36
- and w for implementation).
37
-
38
- scale (bool): Scale, which will be added automatically.
39
- phase (str): 'train' or 'val'.
40
- """
41
-
42
- def __init__(self, opt):
43
- super(PairedImageDataset, self).__init__()
44
- self.opt = opt
45
- # file client (io backend)
46
- self.file_client = None
47
- self.io_backend_opt = opt['io_backend']
48
- self.mean = opt['mean'] if 'mean' in opt else None
49
- self.std = opt['std'] if 'std' in opt else None
50
-
51
- self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
52
- if 'filename_tmpl' in opt:
53
- self.filename_tmpl = opt['filename_tmpl']
54
- else:
55
- self.filename_tmpl = '{}'
56
-
57
- if self.io_backend_opt['type'] == 'lmdb':
58
- self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
59
- self.io_backend_opt['client_keys'] = ['lq', 'gt']
60
- self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
61
- elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
62
- self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
63
- self.opt['meta_info_file'], self.filename_tmpl)
64
- else:
65
- self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
66
-
67
- def __getitem__(self, index):
68
- if self.file_client is None:
69
- self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
70
-
71
- scale = self.opt['scale']
72
-
73
- # Load gt and lq images. Dimension order: HWC; channel order: BGR;
74
- # image range: [0, 1], float32.
75
- gt_path = self.paths[index]['gt_path']
76
- img_bytes = self.file_client.get(gt_path, 'gt')
77
- img_gt = imfrombytes(img_bytes, float32=True)
78
- lq_path = self.paths[index]['lq_path']
79
- img_bytes = self.file_client.get(lq_path, 'lq')
80
- img_lq = imfrombytes(img_bytes, float32=True)
81
-
82
- # augmentation for training
83
- if self.opt['phase'] == 'train':
84
- gt_size = self.opt['gt_size']
85
- # random crop
86
- img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
87
- # flip, rotation
88
- img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
89
-
90
- # TODO: color space transform
91
- # BGR to RGB, HWC to CHW, numpy to tensor
92
- img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
93
- # normalize
94
- if self.mean is not None or self.std is not None:
95
- normalize(img_lq, self.mean, self.std, inplace=True)
96
- normalize(img_gt, self.mean, self.std, inplace=True)
97
-
98
- return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
99
-
100
- def __len__(self):
101
- return len(self.paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/prefetch_dataloader.py DELETED
@@ -1,125 +0,0 @@
1
- import queue as Queue
2
- import threading
3
- import torch
4
- from torch.utils.data import DataLoader
5
-
6
-
7
- class PrefetchGenerator(threading.Thread):
8
- """A general prefetch generator.
9
-
10
- Ref:
11
- https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
-
13
- Args:
14
- generator: Python generator.
15
- num_prefetch_queue (int): Number of prefetch queue.
16
- """
17
-
18
- def __init__(self, generator, num_prefetch_queue):
19
- threading.Thread.__init__(self)
20
- self.queue = Queue.Queue(num_prefetch_queue)
21
- self.generator = generator
22
- self.daemon = True
23
- self.start()
24
-
25
- def run(self):
26
- for item in self.generator:
27
- self.queue.put(item)
28
- self.queue.put(None)
29
-
30
- def __next__(self):
31
- next_item = self.queue.get()
32
- if next_item is None:
33
- raise StopIteration
34
- return next_item
35
-
36
- def __iter__(self):
37
- return self
38
-
39
-
40
- class PrefetchDataLoader(DataLoader):
41
- """Prefetch version of dataloader.
42
-
43
- Ref:
44
- https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
-
46
- TODO:
47
- Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
- ddp.
49
-
50
- Args:
51
- num_prefetch_queue (int): Number of prefetch queue.
52
- kwargs (dict): Other arguments for dataloader.
53
- """
54
-
55
- def __init__(self, num_prefetch_queue, **kwargs):
56
- self.num_prefetch_queue = num_prefetch_queue
57
- super(PrefetchDataLoader, self).__init__(**kwargs)
58
-
59
- def __iter__(self):
60
- return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
-
62
-
63
- class CPUPrefetcher():
64
- """CPU prefetcher.
65
-
66
- Args:
67
- loader: Dataloader.
68
- """
69
-
70
- def __init__(self, loader):
71
- self.ori_loader = loader
72
- self.loader = iter(loader)
73
-
74
- def next(self):
75
- try:
76
- return next(self.loader)
77
- except StopIteration:
78
- return None
79
-
80
- def reset(self):
81
- self.loader = iter(self.ori_loader)
82
-
83
-
84
- class CUDAPrefetcher():
85
- """CUDA prefetcher.
86
-
87
- Ref:
88
- https://github.com/NVIDIA/apex/issues/304#
89
-
90
- It may consums more GPU memory.
91
-
92
- Args:
93
- loader: Dataloader.
94
- opt (dict): Options.
95
- """
96
-
97
- def __init__(self, loader, opt):
98
- self.ori_loader = loader
99
- self.loader = iter(loader)
100
- self.opt = opt
101
- self.stream = torch.cuda.Stream()
102
- self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
- self.preload()
104
-
105
- def preload(self):
106
- try:
107
- self.batch = next(self.loader) # self.batch is a dict
108
- except StopIteration:
109
- self.batch = None
110
- return None
111
- # put tensors to gpu
112
- with torch.cuda.stream(self.stream):
113
- for k, v in self.batch.items():
114
- if torch.is_tensor(v):
115
- self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
-
117
- def next(self):
118
- torch.cuda.current_stream().wait_stream(self.stream)
119
- batch = self.batch
120
- self.preload()
121
- return batch
122
-
123
- def reset(self):
124
- self.loader = iter(self.ori_loader)
125
- self.preload()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/data/transforms.py DELETED
@@ -1,165 +0,0 @@
1
- import cv2
2
- import random
3
-
4
-
5
- def mod_crop(img, scale):
6
- """Mod crop images, used during testing.
7
-
8
- Args:
9
- img (ndarray): Input image.
10
- scale (int): Scale factor.
11
-
12
- Returns:
13
- ndarray: Result image.
14
- """
15
- img = img.copy()
16
- if img.ndim in (2, 3):
17
- h, w = img.shape[0], img.shape[1]
18
- h_remainder, w_remainder = h % scale, w % scale
19
- img = img[:h - h_remainder, :w - w_remainder, ...]
20
- else:
21
- raise ValueError(f'Wrong img ndim: {img.ndim}.')
22
- return img
23
-
24
-
25
- def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
26
- """Paired random crop.
27
-
28
- It crops lists of lq and gt images with corresponding locations.
29
-
30
- Args:
31
- img_gts (list[ndarray] | ndarray): GT images. Note that all images
32
- should have the same shape. If the input is an ndarray, it will
33
- be transformed to a list containing itself.
34
- img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
35
- should have the same shape. If the input is an ndarray, it will
36
- be transformed to a list containing itself.
37
- gt_patch_size (int): GT patch size.
38
- scale (int): Scale factor.
39
- gt_path (str): Path to ground-truth.
40
-
41
- Returns:
42
- list[ndarray] | ndarray: GT images and LQ images. If returned results
43
- only have one element, just return ndarray.
44
- """
45
-
46
- if not isinstance(img_gts, list):
47
- img_gts = [img_gts]
48
- if not isinstance(img_lqs, list):
49
- img_lqs = [img_lqs]
50
-
51
- h_lq, w_lq, _ = img_lqs[0].shape
52
- h_gt, w_gt, _ = img_gts[0].shape
53
- lq_patch_size = gt_patch_size // scale
54
-
55
- if h_gt != h_lq * scale or w_gt != w_lq * scale:
56
- raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
57
- f'multiplication of LQ ({h_lq}, {w_lq}).')
58
- if h_lq < lq_patch_size or w_lq < lq_patch_size:
59
- raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
60
- f'({lq_patch_size}, {lq_patch_size}). '
61
- f'Please remove {gt_path}.')
62
-
63
- # randomly choose top and left coordinates for lq patch
64
- top = random.randint(0, h_lq - lq_patch_size)
65
- left = random.randint(0, w_lq - lq_patch_size)
66
-
67
- # crop lq patch
68
- img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
69
-
70
- # crop corresponding gt patch
71
- top_gt, left_gt = int(top * scale), int(left * scale)
72
- img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
73
- if len(img_gts) == 1:
74
- img_gts = img_gts[0]
75
- if len(img_lqs) == 1:
76
- img_lqs = img_lqs[0]
77
- return img_gts, img_lqs
78
-
79
-
80
- def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
81
- """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
82
-
83
- We use vertical flip and transpose for rotation implementation.
84
- All the images in the list use the same augmentation.
85
-
86
- Args:
87
- imgs (list[ndarray] | ndarray): Images to be augmented. If the input
88
- is an ndarray, it will be transformed to a list.
89
- hflip (bool): Horizontal flip. Default: True.
90
- rotation (bool): Ratotation. Default: True.
91
- flows (list[ndarray]: Flows to be augmented. If the input is an
92
- ndarray, it will be transformed to a list.
93
- Dimension is (h, w, 2). Default: None.
94
- return_status (bool): Return the status of flip and rotation.
95
- Default: False.
96
-
97
- Returns:
98
- list[ndarray] | ndarray: Augmented images and flows. If returned
99
- results only have one element, just return ndarray.
100
-
101
- """
102
- hflip = hflip and random.random() < 0.5
103
- vflip = rotation and random.random() < 0.5
104
- rot90 = rotation and random.random() < 0.5
105
-
106
- def _augment(img):
107
- if hflip: # horizontal
108
- cv2.flip(img, 1, img)
109
- if vflip: # vertical
110
- cv2.flip(img, 0, img)
111
- if rot90:
112
- img = img.transpose(1, 0, 2)
113
- return img
114
-
115
- def _augment_flow(flow):
116
- if hflip: # horizontal
117
- cv2.flip(flow, 1, flow)
118
- flow[:, :, 0] *= -1
119
- if vflip: # vertical
120
- cv2.flip(flow, 0, flow)
121
- flow[:, :, 1] *= -1
122
- if rot90:
123
- flow = flow.transpose(1, 0, 2)
124
- flow = flow[:, :, [1, 0]]
125
- return flow
126
-
127
- if not isinstance(imgs, list):
128
- imgs = [imgs]
129
- imgs = [_augment(img) for img in imgs]
130
- if len(imgs) == 1:
131
- imgs = imgs[0]
132
-
133
- if flows is not None:
134
- if not isinstance(flows, list):
135
- flows = [flows]
136
- flows = [_augment_flow(flow) for flow in flows]
137
- if len(flows) == 1:
138
- flows = flows[0]
139
- return imgs, flows
140
- else:
141
- if return_status:
142
- return imgs, (hflip, vflip, rot90)
143
- else:
144
- return imgs
145
-
146
-
147
- def img_rotate(img, angle, center=None, scale=1.0):
148
- """Rotate image.
149
-
150
- Args:
151
- img (ndarray): Image to be rotated.
152
- angle (float): Rotation angle in degrees. Positive values mean
153
- counter-clockwise rotation.
154
- center (tuple[int]): Rotation center. If the center is None,
155
- initialize it as the center of the image. Default: None.
156
- scale (float): Isotropic scale factor. Default: 1.0.
157
- """
158
- (h, w) = img.shape[:2]
159
-
160
- if center is None:
161
- center = (w // 2, h // 2)
162
-
163
- matrix = cv2.getRotationMatrix2D(center, angle, scale)
164
- rotated_img = cv2.warpAffine(img, matrix, (w, h))
165
- return rotated_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/losses/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- from copy import deepcopy
2
-
3
- from basicsr.utils import get_root_logger
4
- from basicsr.utils.registry import LOSS_REGISTRY
5
- from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
- gradient_penalty_loss, r1_penalty)
7
-
8
- __all__ = [
9
- 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
- 'r1_penalty', 'g_path_regularize'
11
- ]
12
-
13
-
14
- def build_loss(opt):
15
- """Build loss from options.
16
-
17
- Args:
18
- opt (dict): Configuration. It must constain:
19
- type (str): Model type.
20
- """
21
- opt = deepcopy(opt)
22
- loss_type = opt.pop('type')
23
- loss = LOSS_REGISTRY.get(loss_type)(**opt)
24
- logger = get_root_logger()
25
- logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/losses/loss_util.py DELETED
@@ -1,95 +0,0 @@
1
- import functools
2
- from torch.nn import functional as F
3
-
4
-
5
- def reduce_loss(loss, reduction):
6
- """Reduce loss as specified.
7
-
8
- Args:
9
- loss (Tensor): Elementwise loss tensor.
10
- reduction (str): Options are 'none', 'mean' and 'sum'.
11
-
12
- Returns:
13
- Tensor: Reduced loss tensor.
14
- """
15
- reduction_enum = F._Reduction.get_enum(reduction)
16
- # none: 0, elementwise_mean:1, sum: 2
17
- if reduction_enum == 0:
18
- return loss
19
- elif reduction_enum == 1:
20
- return loss.mean()
21
- else:
22
- return loss.sum()
23
-
24
-
25
- def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
- """Apply element-wise weight and reduce loss.
27
-
28
- Args:
29
- loss (Tensor): Element-wise loss.
30
- weight (Tensor): Element-wise weights. Default: None.
31
- reduction (str): Same as built-in losses of PyTorch. Options are
32
- 'none', 'mean' and 'sum'. Default: 'mean'.
33
-
34
- Returns:
35
- Tensor: Loss values.
36
- """
37
- # if weight is specified, apply element-wise weight
38
- if weight is not None:
39
- assert weight.dim() == loss.dim()
40
- assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
- loss = loss * weight
42
-
43
- # if weight is not specified or reduction is sum, just reduce the loss
44
- if weight is None or reduction == 'sum':
45
- loss = reduce_loss(loss, reduction)
46
- # if reduction is mean, then compute mean over weight region
47
- elif reduction == 'mean':
48
- if weight.size(1) > 1:
49
- weight = weight.sum()
50
- else:
51
- weight = weight.sum() * loss.size(1)
52
- loss = loss.sum() / weight
53
-
54
- return loss
55
-
56
-
57
- def weighted_loss(loss_func):
58
- """Create a weighted version of a given loss function.
59
-
60
- To use this decorator, the loss function must have the signature like
61
- `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
- element-wise loss without any reduction. This decorator will add weight
63
- and reduction arguments to the function. The decorated function will have
64
- the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
- **kwargs)`.
66
-
67
- :Example:
68
-
69
- >>> import torch
70
- >>> @weighted_loss
71
- >>> def l1_loss(pred, target):
72
- >>> return (pred - target).abs()
73
-
74
- >>> pred = torch.Tensor([0, 2, 3])
75
- >>> target = torch.Tensor([1, 1, 1])
76
- >>> weight = torch.Tensor([1, 0, 1])
77
-
78
- >>> l1_loss(pred, target)
79
- tensor(1.3333)
80
- >>> l1_loss(pred, target, weight)
81
- tensor(1.5000)
82
- >>> l1_loss(pred, target, reduction='none')
83
- tensor([1., 1., 2.])
84
- >>> l1_loss(pred, target, weight, reduction='sum')
85
- tensor(3.)
86
- """
87
-
88
- @functools.wraps(loss_func)
89
- def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
- # get element-wise loss
91
- loss = loss_func(pred, target, **kwargs)
92
- loss = weight_reduce_loss(loss, weight, reduction)
93
- return loss
94
-
95
- return wrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/losses/losses.py DELETED
@@ -1,455 +0,0 @@
1
- import math
2
- import lpips
3
- import torch
4
- from torch import autograd as autograd
5
- from torch import nn as nn
6
- from torch.nn import functional as F
7
-
8
- from basicsr.archs.vgg_arch import VGGFeatureExtractor
9
- from basicsr.utils.registry import LOSS_REGISTRY
10
- from .loss_util import weighted_loss
11
-
12
- _reduction_modes = ['none', 'mean', 'sum']
13
-
14
-
15
- @weighted_loss
16
- def l1_loss(pred, target):
17
- return F.l1_loss(pred, target, reduction='none')
18
-
19
-
20
- @weighted_loss
21
- def mse_loss(pred, target):
22
- return F.mse_loss(pred, target, reduction='none')
23
-
24
-
25
- @weighted_loss
26
- def charbonnier_loss(pred, target, eps=1e-12):
27
- return torch.sqrt((pred - target)**2 + eps)
28
-
29
-
30
- @LOSS_REGISTRY.register()
31
- class L1Loss(nn.Module):
32
- """L1 (mean absolute error, MAE) loss.
33
-
34
- Args:
35
- loss_weight (float): Loss weight for L1 loss. Default: 1.0.
36
- reduction (str): Specifies the reduction to apply to the output.
37
- Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
38
- """
39
-
40
- def __init__(self, loss_weight=1.0, reduction='mean'):
41
- super(L1Loss, self).__init__()
42
- if reduction not in ['none', 'mean', 'sum']:
43
- raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
44
-
45
- self.loss_weight = loss_weight
46
- self.reduction = reduction
47
-
48
- def forward(self, pred, target, weight=None, **kwargs):
49
- """
50
- Args:
51
- pred (Tensor): of shape (N, C, H, W). Predicted tensor.
52
- target (Tensor): of shape (N, C, H, W). Ground truth tensor.
53
- weight (Tensor, optional): of shape (N, C, H, W). Element-wise
54
- weights. Default: None.
55
- """
56
- return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
57
-
58
-
59
- @LOSS_REGISTRY.register()
60
- class MSELoss(nn.Module):
61
- """MSE (L2) loss.
62
-
63
- Args:
64
- loss_weight (float): Loss weight for MSE loss. Default: 1.0.
65
- reduction (str): Specifies the reduction to apply to the output.
66
- Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
67
- """
68
-
69
- def __init__(self, loss_weight=1.0, reduction='mean'):
70
- super(MSELoss, self).__init__()
71
- if reduction not in ['none', 'mean', 'sum']:
72
- raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
73
-
74
- self.loss_weight = loss_weight
75
- self.reduction = reduction
76
-
77
- def forward(self, pred, target, weight=None, **kwargs):
78
- """
79
- Args:
80
- pred (Tensor): of shape (N, C, H, W). Predicted tensor.
81
- target (Tensor): of shape (N, C, H, W). Ground truth tensor.
82
- weight (Tensor, optional): of shape (N, C, H, W). Element-wise
83
- weights. Default: None.
84
- """
85
- return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
86
-
87
-
88
- @LOSS_REGISTRY.register()
89
- class CharbonnierLoss(nn.Module):
90
- """Charbonnier loss (one variant of Robust L1Loss, a differentiable
91
- variant of L1Loss).
92
-
93
- Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
94
- Super-Resolution".
95
-
96
- Args:
97
- loss_weight (float): Loss weight for L1 loss. Default: 1.0.
98
- reduction (str): Specifies the reduction to apply to the output.
99
- Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
100
- eps (float): A value used to control the curvature near zero.
101
- Default: 1e-12.
102
- """
103
-
104
- def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
105
- super(CharbonnierLoss, self).__init__()
106
- if reduction not in ['none', 'mean', 'sum']:
107
- raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
108
-
109
- self.loss_weight = loss_weight
110
- self.reduction = reduction
111
- self.eps = eps
112
-
113
- def forward(self, pred, target, weight=None, **kwargs):
114
- """
115
- Args:
116
- pred (Tensor): of shape (N, C, H, W). Predicted tensor.
117
- target (Tensor): of shape (N, C, H, W). Ground truth tensor.
118
- weight (Tensor, optional): of shape (N, C, H, W). Element-wise
119
- weights. Default: None.
120
- """
121
- return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
122
-
123
-
124
- @LOSS_REGISTRY.register()
125
- class WeightedTVLoss(L1Loss):
126
- """Weighted TV loss.
127
-
128
- Args:
129
- loss_weight (float): Loss weight. Default: 1.0.
130
- """
131
-
132
- def __init__(self, loss_weight=1.0):
133
- super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
134
-
135
- def forward(self, pred, weight=None):
136
- y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
137
- x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
138
-
139
- loss = x_diff + y_diff
140
-
141
- return loss
142
-
143
-
144
- @LOSS_REGISTRY.register()
145
- class PerceptualLoss(nn.Module):
146
- """Perceptual loss with commonly used style loss.
147
-
148
- Args:
149
- layer_weights (dict): The weight for each layer of vgg feature.
150
- Here is an example: {'conv5_4': 1.}, which means the conv5_4
151
- feature layer (before relu5_4) will be extracted with weight
152
- 1.0 in calculting losses.
153
- vgg_type (str): The type of vgg network used as feature extractor.
154
- Default: 'vgg19'.
155
- use_input_norm (bool): If True, normalize the input image in vgg.
156
- Default: True.
157
- range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
158
- Default: False.
159
- perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
160
- loss will be calculated and the loss will multiplied by the
161
- weight. Default: 1.0.
162
- style_weight (float): If `style_weight > 0`, the style loss will be
163
- calculated and the loss will multiplied by the weight.
164
- Default: 0.
165
- criterion (str): Criterion used for perceptual loss. Default: 'l1'.
166
- """
167
-
168
- def __init__(self,
169
- layer_weights,
170
- vgg_type='vgg19',
171
- use_input_norm=True,
172
- range_norm=False,
173
- perceptual_weight=1.0,
174
- style_weight=0.,
175
- criterion='l1'):
176
- super(PerceptualLoss, self).__init__()
177
- self.perceptual_weight = perceptual_weight
178
- self.style_weight = style_weight
179
- self.layer_weights = layer_weights
180
- self.vgg = VGGFeatureExtractor(
181
- layer_name_list=list(layer_weights.keys()),
182
- vgg_type=vgg_type,
183
- use_input_norm=use_input_norm,
184
- range_norm=range_norm)
185
-
186
- self.criterion_type = criterion
187
- if self.criterion_type == 'l1':
188
- self.criterion = torch.nn.L1Loss()
189
- elif self.criterion_type == 'l2':
190
- self.criterion = torch.nn.L2loss()
191
- elif self.criterion_type == 'mse':
192
- self.criterion = torch.nn.MSELoss(reduction='mean')
193
- elif self.criterion_type == 'fro':
194
- self.criterion = None
195
- else:
196
- raise NotImplementedError(f'{criterion} criterion has not been supported.')
197
-
198
- def forward(self, x, gt):
199
- """Forward function.
200
-
201
- Args:
202
- x (Tensor): Input tensor with shape (n, c, h, w).
203
- gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
204
-
205
- Returns:
206
- Tensor: Forward results.
207
- """
208
- # extract vgg features
209
- x_features = self.vgg(x)
210
- gt_features = self.vgg(gt.detach())
211
-
212
- # calculate perceptual loss
213
- if self.perceptual_weight > 0:
214
- percep_loss = 0
215
- for k in x_features.keys():
216
- if self.criterion_type == 'fro':
217
- percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
218
- else:
219
- percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
220
- percep_loss *= self.perceptual_weight
221
- else:
222
- percep_loss = None
223
-
224
- # calculate style loss
225
- if self.style_weight > 0:
226
- style_loss = 0
227
- for k in x_features.keys():
228
- if self.criterion_type == 'fro':
229
- style_loss += torch.norm(
230
- self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
231
- else:
232
- style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
233
- gt_features[k])) * self.layer_weights[k]
234
- style_loss *= self.style_weight
235
- else:
236
- style_loss = None
237
-
238
- return percep_loss, style_loss
239
-
240
- def _gram_mat(self, x):
241
- """Calculate Gram matrix.
242
-
243
- Args:
244
- x (torch.Tensor): Tensor with shape of (n, c, h, w).
245
-
246
- Returns:
247
- torch.Tensor: Gram matrix.
248
- """
249
- n, c, h, w = x.size()
250
- features = x.view(n, c, w * h)
251
- features_t = features.transpose(1, 2)
252
- gram = features.bmm(features_t) / (c * h * w)
253
- return gram
254
-
255
-
256
- @LOSS_REGISTRY.register()
257
- class LPIPSLoss(nn.Module):
258
- def __init__(self,
259
- loss_weight=1.0,
260
- use_input_norm=True,
261
- range_norm=False,):
262
- super(LPIPSLoss, self).__init__()
263
- self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
264
- self.loss_weight = loss_weight
265
- self.use_input_norm = use_input_norm
266
- self.range_norm = range_norm
267
-
268
- if self.use_input_norm:
269
- # the mean is for image with range [0, 1]
270
- self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
271
- # the std is for image with range [0, 1]
272
- self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
273
-
274
- def forward(self, pred, target):
275
- if self.range_norm:
276
- pred = (pred + 1) / 2
277
- target = (target + 1) / 2
278
- if self.use_input_norm:
279
- pred = (pred - self.mean) / self.std
280
- target = (target - self.mean) / self.std
281
- lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
282
- return self.loss_weight * lpips_loss.mean()
283
-
284
-
285
- @LOSS_REGISTRY.register()
286
- class GANLoss(nn.Module):
287
- """Define GAN loss.
288
-
289
- Args:
290
- gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
291
- real_label_val (float): The value for real label. Default: 1.0.
292
- fake_label_val (float): The value for fake label. Default: 0.0.
293
- loss_weight (float): Loss weight. Default: 1.0.
294
- Note that loss_weight is only for generators; and it is always 1.0
295
- for discriminators.
296
- """
297
-
298
- def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
299
- super(GANLoss, self).__init__()
300
- self.gan_type = gan_type
301
- self.loss_weight = loss_weight
302
- self.real_label_val = real_label_val
303
- self.fake_label_val = fake_label_val
304
-
305
- if self.gan_type == 'vanilla':
306
- self.loss = nn.BCEWithLogitsLoss()
307
- elif self.gan_type == 'lsgan':
308
- self.loss = nn.MSELoss()
309
- elif self.gan_type == 'wgan':
310
- self.loss = self._wgan_loss
311
- elif self.gan_type == 'wgan_softplus':
312
- self.loss = self._wgan_softplus_loss
313
- elif self.gan_type == 'hinge':
314
- self.loss = nn.ReLU()
315
- else:
316
- raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
317
-
318
- def _wgan_loss(self, input, target):
319
- """wgan loss.
320
-
321
- Args:
322
- input (Tensor): Input tensor.
323
- target (bool): Target label.
324
-
325
- Returns:
326
- Tensor: wgan loss.
327
- """
328
- return -input.mean() if target else input.mean()
329
-
330
- def _wgan_softplus_loss(self, input, target):
331
- """wgan loss with soft plus. softplus is a smooth approximation to the
332
- ReLU function.
333
-
334
- In StyleGAN2, it is called:
335
- Logistic loss for discriminator;
336
- Non-saturating loss for generator.
337
-
338
- Args:
339
- input (Tensor): Input tensor.
340
- target (bool): Target label.
341
-
342
- Returns:
343
- Tensor: wgan loss.
344
- """
345
- return F.softplus(-input).mean() if target else F.softplus(input).mean()
346
-
347
- def get_target_label(self, input, target_is_real):
348
- """Get target label.
349
-
350
- Args:
351
- input (Tensor): Input tensor.
352
- target_is_real (bool): Whether the target is real or fake.
353
-
354
- Returns:
355
- (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
356
- return Tensor.
357
- """
358
-
359
- if self.gan_type in ['wgan', 'wgan_softplus']:
360
- return target_is_real
361
- target_val = (self.real_label_val if target_is_real else self.fake_label_val)
362
- return input.new_ones(input.size()) * target_val
363
-
364
- def forward(self, input, target_is_real, is_disc=False):
365
- """
366
- Args:
367
- input (Tensor): The input for the loss module, i.e., the network
368
- prediction.
369
- target_is_real (bool): Whether the targe is real or fake.
370
- is_disc (bool): Whether the loss for discriminators or not.
371
- Default: False.
372
-
373
- Returns:
374
- Tensor: GAN loss value.
375
- """
376
- if self.gan_type == 'hinge':
377
- if is_disc: # for discriminators in hinge-gan
378
- input = -input if target_is_real else input
379
- loss = self.loss(1 + input).mean()
380
- else: # for generators in hinge-gan
381
- loss = -input.mean()
382
- else: # other gan types
383
- target_label = self.get_target_label(input, target_is_real)
384
- loss = self.loss(input, target_label)
385
-
386
- # loss_weight is always 1.0 for discriminators
387
- return loss if is_disc else loss * self.loss_weight
388
-
389
-
390
- def r1_penalty(real_pred, real_img):
391
- """R1 regularization for discriminator. The core idea is to
392
- penalize the gradient on real data alone: when the
393
- generator distribution produces the true data distribution
394
- and the discriminator is equal to 0 on the data manifold, the
395
- gradient penalty ensures that the discriminator cannot create
396
- a non-zero gradient orthogonal to the data manifold without
397
- suffering a loss in the GAN game.
398
-
399
- Ref:
400
- Eq. 9 in Which training methods for GANs do actually converge.
401
- """
402
- grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
403
- grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
404
- return grad_penalty
405
-
406
-
407
- def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
408
- noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
409
- grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
410
- path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
411
-
412
- path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
413
-
414
- path_penalty = (path_lengths - path_mean).pow(2).mean()
415
-
416
- return path_penalty, path_lengths.detach().mean(), path_mean.detach()
417
-
418
-
419
- def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
420
- """Calculate gradient penalty for wgan-gp.
421
-
422
- Args:
423
- discriminator (nn.Module): Network for the discriminator.
424
- real_data (Tensor): Real input data.
425
- fake_data (Tensor): Fake input data.
426
- weight (Tensor): Weight tensor. Default: None.
427
-
428
- Returns:
429
- Tensor: A tensor for gradient penalty.
430
- """
431
-
432
- batch_size = real_data.size(0)
433
- alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
434
-
435
- # interpolate between real_data and fake_data
436
- interpolates = alpha * real_data + (1. - alpha) * fake_data
437
- interpolates = autograd.Variable(interpolates, requires_grad=True)
438
-
439
- disc_interpolates = discriminator(interpolates)
440
- gradients = autograd.grad(
441
- outputs=disc_interpolates,
442
- inputs=interpolates,
443
- grad_outputs=torch.ones_like(disc_interpolates),
444
- create_graph=True,
445
- retain_graph=True,
446
- only_inputs=True)[0]
447
-
448
- if weight is not None:
449
- gradients = gradients * weight
450
-
451
- gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
452
- if weight is not None:
453
- gradients_penalty /= torch.mean(weight)
454
-
455
- return gradients_penalty
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/metrics/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- from copy import deepcopy
2
-
3
- from basicsr.utils.registry import METRIC_REGISTRY
4
- from .psnr_ssim import calculate_psnr, calculate_ssim
5
-
6
- __all__ = ['calculate_psnr', 'calculate_ssim']
7
-
8
-
9
- def calculate_metric(data, opt):
10
- """Calculate metric from data and options.
11
-
12
- Args:
13
- opt (dict): Configuration. It must constain:
14
- type (str): Model type.
15
- """
16
- opt = deepcopy(opt)
17
- metric_type = opt.pop('type')
18
- metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
19
- return metric
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/metrics/metric_util.py DELETED
@@ -1,45 +0,0 @@
1
- import numpy as np
2
-
3
- from basicsr.utils.matlab_functions import bgr2ycbcr
4
-
5
-
6
- def reorder_image(img, input_order='HWC'):
7
- """Reorder images to 'HWC' order.
8
-
9
- If the input_order is (h, w), return (h, w, 1);
10
- If the input_order is (c, h, w), return (h, w, c);
11
- If the input_order is (h, w, c), return as it is.
12
-
13
- Args:
14
- img (ndarray): Input image.
15
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
16
- If the input image shape is (h, w), input_order will not have
17
- effects. Default: 'HWC'.
18
-
19
- Returns:
20
- ndarray: reordered image.
21
- """
22
-
23
- if input_order not in ['HWC', 'CHW']:
24
- raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
25
- if len(img.shape) == 2:
26
- img = img[..., None]
27
- if input_order == 'CHW':
28
- img = img.transpose(1, 2, 0)
29
- return img
30
-
31
-
32
- def to_y_channel(img):
33
- """Change to Y channel of YCbCr.
34
-
35
- Args:
36
- img (ndarray): Images with range [0, 255].
37
-
38
- Returns:
39
- (ndarray): Images with range [0, 255] (float type) without round.
40
- """
41
- img = img.astype(np.float32) / 255.
42
- if img.ndim == 3 and img.shape[2] == 3:
43
- img = bgr2ycbcr(img, y_only=True)
44
- img = img[..., None]
45
- return img * 255.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/metrics/psnr_ssim.py DELETED
@@ -1,128 +0,0 @@
1
- import cv2
2
- import numpy as np
3
-
4
- from basicsr.metrics.metric_util import reorder_image, to_y_channel
5
- from basicsr.utils.registry import METRIC_REGISTRY
6
-
7
-
8
- @METRIC_REGISTRY.register()
9
- def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
10
- """Calculate PSNR (Peak Signal-to-Noise Ratio).
11
-
12
- Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13
-
14
- Args:
15
- img1 (ndarray): Images with range [0, 255].
16
- img2 (ndarray): Images with range [0, 255].
17
- crop_border (int): Cropped pixels in each edge of an image. These
18
- pixels are not involved in the PSNR calculation.
19
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
20
- Default: 'HWC'.
21
- test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
22
-
23
- Returns:
24
- float: psnr result.
25
- """
26
-
27
- assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
28
- if input_order not in ['HWC', 'CHW']:
29
- raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
30
- img1 = reorder_image(img1, input_order=input_order)
31
- img2 = reorder_image(img2, input_order=input_order)
32
- img1 = img1.astype(np.float64)
33
- img2 = img2.astype(np.float64)
34
-
35
- if crop_border != 0:
36
- img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
37
- img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
38
-
39
- if test_y_channel:
40
- img1 = to_y_channel(img1)
41
- img2 = to_y_channel(img2)
42
-
43
- mse = np.mean((img1 - img2)**2)
44
- if mse == 0:
45
- return float('inf')
46
- return 20. * np.log10(255. / np.sqrt(mse))
47
-
48
-
49
- def _ssim(img1, img2):
50
- """Calculate SSIM (structural similarity) for one channel images.
51
-
52
- It is called by func:`calculate_ssim`.
53
-
54
- Args:
55
- img1 (ndarray): Images with range [0, 255] with order 'HWC'.
56
- img2 (ndarray): Images with range [0, 255] with order 'HWC'.
57
-
58
- Returns:
59
- float: ssim result.
60
- """
61
-
62
- C1 = (0.01 * 255)**2
63
- C2 = (0.03 * 255)**2
64
-
65
- img1 = img1.astype(np.float64)
66
- img2 = img2.astype(np.float64)
67
- kernel = cv2.getGaussianKernel(11, 1.5)
68
- window = np.outer(kernel, kernel.transpose())
69
-
70
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
71
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
72
- mu1_sq = mu1**2
73
- mu2_sq = mu2**2
74
- mu1_mu2 = mu1 * mu2
75
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
76
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
77
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
78
-
79
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
80
- return ssim_map.mean()
81
-
82
-
83
- @METRIC_REGISTRY.register()
84
- def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
85
- """Calculate SSIM (structural similarity).
86
-
87
- Ref:
88
- Image quality assessment: From error visibility to structural similarity
89
-
90
- The results are the same as that of the official released MATLAB code in
91
- https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92
-
93
- For three-channel images, SSIM is calculated for each channel and then
94
- averaged.
95
-
96
- Args:
97
- img1 (ndarray): Images with range [0, 255].
98
- img2 (ndarray): Images with range [0, 255].
99
- crop_border (int): Cropped pixels in each edge of an image. These
100
- pixels are not involved in the SSIM calculation.
101
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
102
- Default: 'HWC'.
103
- test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104
-
105
- Returns:
106
- float: ssim result.
107
- """
108
-
109
- assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
110
- if input_order not in ['HWC', 'CHW']:
111
- raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
112
- img1 = reorder_image(img1, input_order=input_order)
113
- img2 = reorder_image(img2, input_order=input_order)
114
- img1 = img1.astype(np.float64)
115
- img2 = img2.astype(np.float64)
116
-
117
- if crop_border != 0:
118
- img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
119
- img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120
-
121
- if test_y_channel:
122
- img1 = to_y_channel(img1)
123
- img2 = to_y_channel(img2)
124
-
125
- ssims = []
126
- for i in range(img1.shape[2]):
127
- ssims.append(_ssim(img1[..., i], img2[..., i]))
128
- return np.array(ssims).mean()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- import importlib
2
- from copy import deepcopy
3
- from os import path as osp
4
-
5
- from basicsr.utils import get_root_logger, scandir
6
- from basicsr.utils.registry import MODEL_REGISTRY
7
-
8
- __all__ = ['build_model']
9
-
10
- # automatically scan and import model modules for registry
11
- # scan all the files under the 'models' folder and collect files ending with
12
- # '_model.py'
13
- model_folder = osp.dirname(osp.abspath(__file__))
14
- model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
15
- # import all the model modules
16
- _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
17
-
18
-
19
- def build_model(opt):
20
- """Build model from options.
21
-
22
- Args:
23
- opt (dict): Configuration. It must constain:
24
- model_type (str): Model type.
25
- """
26
- opt = deepcopy(opt)
27
- model = MODEL_REGISTRY.get(opt['model_type'])(opt)
28
- logger = get_root_logger()
29
- logger.info(f'Model [{model.__class__.__name__}] is created.')
30
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/base_model.py DELETED
@@ -1,322 +0,0 @@
1
- import logging
2
- import os
3
- import torch
4
- from collections import OrderedDict
5
- from copy import deepcopy
6
- from torch.nn.parallel import DataParallel, DistributedDataParallel
7
-
8
- from basicsr.models import lr_scheduler as lr_scheduler
9
- from basicsr.utils.dist_util import master_only
10
-
11
- logger = logging.getLogger('basicsr')
12
-
13
-
14
- class BaseModel():
15
- """Base model."""
16
-
17
- def __init__(self, opt):
18
- self.opt = opt
19
- self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
20
- self.is_train = opt['is_train']
21
- self.schedulers = []
22
- self.optimizers = []
23
-
24
- def feed_data(self, data):
25
- pass
26
-
27
- def optimize_parameters(self):
28
- pass
29
-
30
- def get_current_visuals(self):
31
- pass
32
-
33
- def save(self, epoch, current_iter):
34
- """Save networks and training state."""
35
- pass
36
-
37
- def validation(self, dataloader, current_iter, tb_logger, save_img=False):
38
- """Validation function.
39
-
40
- Args:
41
- dataloader (torch.utils.data.DataLoader): Validation dataloader.
42
- current_iter (int): Current iteration.
43
- tb_logger (tensorboard logger): Tensorboard logger.
44
- save_img (bool): Whether to save images. Default: False.
45
- """
46
- if self.opt['dist']:
47
- self.dist_validation(dataloader, current_iter, tb_logger, save_img)
48
- else:
49
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
50
-
51
- def model_ema(self, decay=0.999):
52
- net_g = self.get_bare_model(self.net_g)
53
-
54
- net_g_params = dict(net_g.named_parameters())
55
- net_g_ema_params = dict(self.net_g_ema.named_parameters())
56
-
57
- for k in net_g_ema_params.keys():
58
- net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
59
-
60
- def get_current_log(self):
61
- return self.log_dict
62
-
63
- def model_to_device(self, net):
64
- """Model to device. It also warps models with DistributedDataParallel
65
- or DataParallel.
66
-
67
- Args:
68
- net (nn.Module)
69
- """
70
- net = net.to(self.device)
71
- if self.opt['dist']:
72
- find_unused_parameters = self.opt.get('find_unused_parameters', False)
73
- net = DistributedDataParallel(
74
- net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
75
- elif self.opt['num_gpu'] > 1:
76
- net = DataParallel(net)
77
- return net
78
-
79
- def get_optimizer(self, optim_type, params, lr, **kwargs):
80
- if optim_type == 'Adam':
81
- optimizer = torch.optim.Adam(params, lr, **kwargs)
82
- else:
83
- raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
84
- return optimizer
85
-
86
- def setup_schedulers(self):
87
- """Set up schedulers."""
88
- train_opt = self.opt['train']
89
- scheduler_type = train_opt['scheduler'].pop('type')
90
- if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
91
- for optimizer in self.optimizers:
92
- self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
93
- elif scheduler_type == 'CosineAnnealingRestartLR':
94
- for optimizer in self.optimizers:
95
- self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
96
- else:
97
- raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
98
-
99
- def get_bare_model(self, net):
100
- """Get bare model, especially under wrapping with
101
- DistributedDataParallel or DataParallel.
102
- """
103
- if isinstance(net, (DataParallel, DistributedDataParallel)):
104
- net = net.module
105
- return net
106
-
107
- @master_only
108
- def print_network(self, net):
109
- """Print the str and parameter number of a network.
110
-
111
- Args:
112
- net (nn.Module)
113
- """
114
- if isinstance(net, (DataParallel, DistributedDataParallel)):
115
- net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
116
- else:
117
- net_cls_str = f'{net.__class__.__name__}'
118
-
119
- net = self.get_bare_model(net)
120
- net_str = str(net)
121
- net_params = sum(map(lambda x: x.numel(), net.parameters()))
122
-
123
- logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
124
- logger.info(net_str)
125
-
126
- def _set_lr(self, lr_groups_l):
127
- """Set learning rate for warmup.
128
-
129
- Args:
130
- lr_groups_l (list): List for lr_groups, each for an optimizer.
131
- """
132
- for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
133
- for param_group, lr in zip(optimizer.param_groups, lr_groups):
134
- param_group['lr'] = lr
135
-
136
- def _get_init_lr(self):
137
- """Get the initial lr, which is set by the scheduler.
138
- """
139
- init_lr_groups_l = []
140
- for optimizer in self.optimizers:
141
- init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
142
- return init_lr_groups_l
143
-
144
- def update_learning_rate(self, current_iter, warmup_iter=-1):
145
- """Update learning rate.
146
-
147
- Args:
148
- current_iter (int): Current iteration.
149
- warmup_iter (int): Warmup iter numbers. -1 for no warmup.
150
- Default: -1.
151
- """
152
- if current_iter > 1:
153
- for scheduler in self.schedulers:
154
- scheduler.step()
155
- # set up warm-up learning rate
156
- if current_iter < warmup_iter:
157
- # get initial lr for each group
158
- init_lr_g_l = self._get_init_lr()
159
- # modify warming-up learning rates
160
- # currently only support linearly warm up
161
- warm_up_lr_l = []
162
- for init_lr_g in init_lr_g_l:
163
- warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
164
- # set learning rate
165
- self._set_lr(warm_up_lr_l)
166
-
167
- def get_current_learning_rate(self):
168
- return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
169
-
170
- @master_only
171
- def save_network(self, net, net_label, current_iter, param_key='params'):
172
- """Save networks.
173
-
174
- Args:
175
- net (nn.Module | list[nn.Module]): Network(s) to be saved.
176
- net_label (str): Network label.
177
- current_iter (int): Current iter number.
178
- param_key (str | list[str]): The parameter key(s) to save network.
179
- Default: 'params'.
180
- """
181
- if current_iter == -1:
182
- current_iter = 'latest'
183
- save_filename = f'{net_label}_{current_iter}.pth'
184
- save_path = os.path.join(self.opt['path']['models'], save_filename)
185
-
186
- net = net if isinstance(net, list) else [net]
187
- param_key = param_key if isinstance(param_key, list) else [param_key]
188
- assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
189
-
190
- save_dict = {}
191
- for net_, param_key_ in zip(net, param_key):
192
- net_ = self.get_bare_model(net_)
193
- state_dict = net_.state_dict()
194
- for key, param in state_dict.items():
195
- if key.startswith('module.'): # remove unnecessary 'module.'
196
- key = key[7:]
197
- state_dict[key] = param.cpu()
198
- save_dict[param_key_] = state_dict
199
-
200
- torch.save(save_dict, save_path)
201
-
202
- def _print_different_keys_loading(self, crt_net, load_net, strict=True):
203
- """Print keys with differnet name or different size when loading models.
204
-
205
- 1. Print keys with differnet names.
206
- 2. If strict=False, print the same key but with different tensor size.
207
- It also ignore these keys with different sizes (not load).
208
-
209
- Args:
210
- crt_net (torch model): Current network.
211
- load_net (dict): Loaded network.
212
- strict (bool): Whether strictly loaded. Default: True.
213
- """
214
- crt_net = self.get_bare_model(crt_net)
215
- crt_net = crt_net.state_dict()
216
- crt_net_keys = set(crt_net.keys())
217
- load_net_keys = set(load_net.keys())
218
-
219
- if crt_net_keys != load_net_keys:
220
- logger.warning('Current net - loaded net:')
221
- for v in sorted(list(crt_net_keys - load_net_keys)):
222
- logger.warning(f' {v}')
223
- logger.warning('Loaded net - current net:')
224
- for v in sorted(list(load_net_keys - crt_net_keys)):
225
- logger.warning(f' {v}')
226
-
227
- # check the size for the same keys
228
- if not strict:
229
- common_keys = crt_net_keys & load_net_keys
230
- for k in common_keys:
231
- if crt_net[k].size() != load_net[k].size():
232
- logger.warning(f'Size different, ignore [{k}]: crt_net: '
233
- f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
234
- load_net[k + '.ignore'] = load_net.pop(k)
235
-
236
- def load_network(self, net, load_path, strict=True, param_key='params'):
237
- """Load network.
238
-
239
- Args:
240
- load_path (str): The path of networks to be loaded.
241
- net (nn.Module): Network.
242
- strict (bool): Whether strictly loaded.
243
- param_key (str): The parameter key of loaded network. If set to
244
- None, use the root 'path'.
245
- Default: 'params'.
246
- """
247
- net = self.get_bare_model(net)
248
- logger.info(f'Loading {net.__class__.__name__} model from {load_path}.')
249
- load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
250
- if param_key is not None:
251
- if param_key not in load_net and 'params' in load_net:
252
- param_key = 'params'
253
- logger.info('Loading: params_ema does not exist, use params.')
254
- load_net = load_net[param_key]
255
- # remove unnecessary 'module.'
256
- for k, v in deepcopy(load_net).items():
257
- if k.startswith('module.'):
258
- load_net[k[7:]] = v
259
- load_net.pop(k)
260
- self._print_different_keys_loading(net, load_net, strict)
261
- net.load_state_dict(load_net, strict=strict)
262
-
263
- @master_only
264
- def save_training_state(self, epoch, current_iter):
265
- """Save training states during training, which will be used for
266
- resuming.
267
-
268
- Args:
269
- epoch (int): Current epoch.
270
- current_iter (int): Current iteration.
271
- """
272
- if current_iter != -1:
273
- state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
274
- for o in self.optimizers:
275
- state['optimizers'].append(o.state_dict())
276
- for s in self.schedulers:
277
- state['schedulers'].append(s.state_dict())
278
- save_filename = f'{current_iter}.state'
279
- save_path = os.path.join(self.opt['path']['training_states'], save_filename)
280
- torch.save(state, save_path)
281
-
282
- def resume_training(self, resume_state):
283
- """Reload the optimizers and schedulers for resumed training.
284
-
285
- Args:
286
- resume_state (dict): Resume state.
287
- """
288
- resume_optimizers = resume_state['optimizers']
289
- resume_schedulers = resume_state['schedulers']
290
- assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
291
- assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
292
- for i, o in enumerate(resume_optimizers):
293
- self.optimizers[i].load_state_dict(o)
294
- for i, s in enumerate(resume_schedulers):
295
- self.schedulers[i].load_state_dict(s)
296
-
297
- def reduce_loss_dict(self, loss_dict):
298
- """reduce loss dict.
299
-
300
- In distributed training, it averages the losses among different GPUs .
301
-
302
- Args:
303
- loss_dict (OrderedDict): Loss dict.
304
- """
305
- with torch.no_grad():
306
- if self.opt['dist']:
307
- keys = []
308
- losses = []
309
- for name, value in loss_dict.items():
310
- keys.append(name)
311
- losses.append(value)
312
- losses = torch.stack(losses, 0)
313
- torch.distributed.reduce(losses, dst=0)
314
- if self.opt['rank'] == 0:
315
- losses /= self.opt['world_size']
316
- loss_dict = {key: loss for key, loss in zip(keys, losses)}
317
-
318
- log_dict = OrderedDict()
319
- for name, value in loss_dict.items():
320
- log_dict[name] = value.mean().item()
321
-
322
- return log_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/codeformer_idx_model.py DELETED
@@ -1,220 +0,0 @@
1
- import torch
2
- from collections import OrderedDict
3
- from os import path as osp
4
- from tqdm import tqdm
5
-
6
- from basicsr.archs import build_network
7
- from basicsr.metrics import calculate_metric
8
- from basicsr.utils import get_root_logger, imwrite, tensor2img
9
- from basicsr.utils.registry import MODEL_REGISTRY
10
- import torch.nn.functional as F
11
- from .sr_model import SRModel
12
-
13
-
14
- @MODEL_REGISTRY.register()
15
- class CodeFormerIdxModel(SRModel):
16
- def feed_data(self, data):
17
- self.gt = data['gt'].to(self.device)
18
- self.input = data['in'].to(self.device)
19
- self.b = self.gt.shape[0]
20
-
21
- if 'latent_gt' in data:
22
- self.idx_gt = data['latent_gt'].to(self.device)
23
- self.idx_gt = self.idx_gt.view(self.b, -1)
24
- else:
25
- self.idx_gt = None
26
-
27
- def init_training_settings(self):
28
- logger = get_root_logger()
29
- train_opt = self.opt['train']
30
-
31
- self.ema_decay = train_opt.get('ema_decay', 0)
32
- if self.ema_decay > 0:
33
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
34
- # define network net_g with Exponential Moving Average (EMA)
35
- # net_g_ema is used only for testing on one GPU and saving
36
- # There is no need to wrap with DistributedDataParallel
37
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
38
- # load pretrained model
39
- load_path = self.opt['path'].get('pretrain_network_g', None)
40
- if load_path is not None:
41
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
42
- else:
43
- self.model_ema(0) # copy net_g weight
44
- self.net_g_ema.eval()
45
-
46
- if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
47
- self.generate_idx_gt = False
48
- elif self.opt.get('network_vqgan', None) is not None:
49
- self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
50
- self.hq_vqgan_fix.eval()
51
- self.generate_idx_gt = True
52
- for param in self.hq_vqgan_fix.parameters():
53
- param.requires_grad = False
54
- else:
55
- raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')
56
-
57
- logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
58
-
59
- self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
60
- self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
61
- self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
62
- self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
63
-
64
- self.net_g.train()
65
-
66
- # set up optimizers and schedulers
67
- self.setup_optimizers()
68
- self.setup_schedulers()
69
-
70
-
71
- def setup_optimizers(self):
72
- train_opt = self.opt['train']
73
- # optimizer g
74
- optim_params_g = []
75
- for k, v in self.net_g.named_parameters():
76
- if v.requires_grad:
77
- optim_params_g.append(v)
78
- else:
79
- logger = get_root_logger()
80
- logger.warning(f'Params {k} will not be optimized.')
81
- optim_type = train_opt['optim_g'].pop('type')
82
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
83
- self.optimizers.append(self.optimizer_g)
84
-
85
-
86
- def optimize_parameters(self, current_iter):
87
- logger = get_root_logger()
88
- # optimize net_g
89
- self.optimizer_g.zero_grad()
90
-
91
- if self.generate_idx_gt:
92
- x = self.hq_vqgan_fix.encoder(self.gt)
93
- _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
94
- min_encoding_indices = quant_stats['min_encoding_indices']
95
- self.idx_gt = min_encoding_indices.view(self.b, -1)
96
-
97
- if self.hq_feat_loss:
98
- # quant_feats
99
- quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
100
-
101
- logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
102
-
103
- l_g_total = 0
104
- loss_dict = OrderedDict()
105
- # hq_feat_loss
106
- if self.hq_feat_loss: # codebook loss
107
- l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
108
- l_g_total += l_feat_encoder
109
- loss_dict['l_feat_encoder'] = l_feat_encoder
110
-
111
- # cross_entropy_loss
112
- if self.cross_entropy_loss:
113
- # b(hw)n -> bn(hw)
114
- cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
115
- l_g_total += cross_entropy_loss
116
- loss_dict['cross_entropy_loss'] = cross_entropy_loss
117
-
118
- l_g_total.backward()
119
- self.optimizer_g.step()
120
-
121
- if self.ema_decay > 0:
122
- self.model_ema(decay=self.ema_decay)
123
-
124
- self.log_dict = self.reduce_loss_dict(loss_dict)
125
-
126
-
127
- def test(self):
128
- with torch.no_grad():
129
- if hasattr(self, 'net_g_ema'):
130
- self.net_g_ema.eval()
131
- self.output, _, _ = self.net_g_ema(self.input, w=0)
132
- else:
133
- logger = get_root_logger()
134
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
135
- self.net_g.eval()
136
- self.output, _, _ = self.net_g(self.input, w=0)
137
- self.net_g.train()
138
-
139
-
140
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
141
- if self.opt['rank'] == 0:
142
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
143
-
144
-
145
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
146
- dataset_name = dataloader.dataset.opt['name']
147
- with_metrics = self.opt['val'].get('metrics') is not None
148
- if with_metrics:
149
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
150
- pbar = tqdm(total=len(dataloader), unit='image')
151
-
152
- for idx, val_data in enumerate(dataloader):
153
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
154
- self.feed_data(val_data)
155
- self.test()
156
-
157
- visuals = self.get_current_visuals()
158
- sr_img = tensor2img([visuals['result']])
159
- if 'gt' in visuals:
160
- gt_img = tensor2img([visuals['gt']])
161
- del self.gt
162
-
163
- # tentative for out of GPU memory
164
- del self.lq
165
- del self.output
166
- torch.cuda.empty_cache()
167
-
168
- if save_img:
169
- if self.opt['is_train']:
170
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
171
- f'{img_name}_{current_iter}.png')
172
- else:
173
- if self.opt['val']['suffix']:
174
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
175
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
176
- else:
177
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
178
- f'{img_name}_{self.opt["name"]}.png')
179
- imwrite(sr_img, save_img_path)
180
-
181
- if with_metrics:
182
- # calculate metrics
183
- for name, opt_ in self.opt['val']['metrics'].items():
184
- metric_data = dict(img1=sr_img, img2=gt_img)
185
- self.metric_results[name] += calculate_metric(metric_data, opt_)
186
- pbar.update(1)
187
- pbar.set_description(f'Test {img_name}')
188
- pbar.close()
189
-
190
- if with_metrics:
191
- for metric in self.metric_results.keys():
192
- self.metric_results[metric] /= (idx + 1)
193
-
194
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
195
-
196
-
197
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
198
- log_str = f'Validation {dataset_name}\n'
199
- for metric, value in self.metric_results.items():
200
- log_str += f'\t # {metric}: {value:.4f}\n'
201
- logger = get_root_logger()
202
- logger.info(log_str)
203
- if tb_logger:
204
- for metric, value in self.metric_results.items():
205
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
206
-
207
-
208
- def get_current_visuals(self):
209
- out_dict = OrderedDict()
210
- out_dict['gt'] = self.gt.detach().cpu()
211
- out_dict['result'] = self.output.detach().cpu()
212
- return out_dict
213
-
214
-
215
- def save(self, epoch, current_iter):
216
- if self.ema_decay > 0:
217
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
218
- else:
219
- self.save_network(self.net_g, 'net_g', current_iter)
220
- self.save_training_state(epoch, current_iter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/codeformer_joint_model.py DELETED
@@ -1,350 +0,0 @@
1
- import torch
2
- from collections import OrderedDict
3
- from os import path as osp
4
- from tqdm import tqdm
5
-
6
-
7
- from basicsr.archs import build_network
8
- from basicsr.losses import build_loss
9
- from basicsr.metrics import calculate_metric
10
- from basicsr.utils import get_root_logger, imwrite, tensor2img
11
- from basicsr.utils.registry import MODEL_REGISTRY
12
- import torch.nn.functional as F
13
- from .sr_model import SRModel
14
-
15
-
16
- @MODEL_REGISTRY.register()
17
- class CodeFormerJointModel(SRModel):
18
- def feed_data(self, data):
19
- self.gt = data['gt'].to(self.device)
20
- self.input = data['in'].to(self.device)
21
- self.input_large_de = data['in_large_de'].to(self.device)
22
- self.b = self.gt.shape[0]
23
-
24
- if 'latent_gt' in data:
25
- self.idx_gt = data['latent_gt'].to(self.device)
26
- self.idx_gt = self.idx_gt.view(self.b, -1)
27
- else:
28
- self.idx_gt = None
29
-
30
- def init_training_settings(self):
31
- logger = get_root_logger()
32
- train_opt = self.opt['train']
33
-
34
- self.ema_decay = train_opt.get('ema_decay', 0)
35
- if self.ema_decay > 0:
36
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
37
- # define network net_g with Exponential Moving Average (EMA)
38
- # net_g_ema is used only for testing on one GPU and saving
39
- # There is no need to wrap with DistributedDataParallel
40
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
41
- # load pretrained model
42
- load_path = self.opt['path'].get('pretrain_network_g', None)
43
- if load_path is not None:
44
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
45
- else:
46
- self.model_ema(0) # copy net_g weight
47
- self.net_g_ema.eval()
48
-
49
- if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
50
- self.generate_idx_gt = False
51
- elif self.opt.get('network_vqgan', None) is not None:
52
- self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
53
- self.hq_vqgan_fix.eval()
54
- self.generate_idx_gt = True
55
- for param in self.hq_vqgan_fix.parameters():
56
- param.requires_grad = False
57
- else:
58
- raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')
59
-
60
- logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
61
-
62
- self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
63
- self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
64
- self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
65
- self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
66
- self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
67
-
68
- # define network net_d
69
- self.net_d = build_network(self.opt['network_d'])
70
- self.net_d = self.model_to_device(self.net_d)
71
- self.print_network(self.net_d)
72
-
73
- # load pretrained models
74
- load_path = self.opt['path'].get('pretrain_network_d', None)
75
- if load_path is not None:
76
- self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
77
-
78
- self.net_g.train()
79
- self.net_d.train()
80
-
81
- # define losses
82
- if train_opt.get('pixel_opt'):
83
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
84
- else:
85
- self.cri_pix = None
86
-
87
- if train_opt.get('perceptual_opt'):
88
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
89
- else:
90
- self.cri_perceptual = None
91
-
92
- if train_opt.get('gan_opt'):
93
- self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
94
-
95
-
96
- self.fix_generator = train_opt.get('fix_generator', True)
97
- logger.info(f'fix_generator: {self.fix_generator}')
98
-
99
- self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
100
- self.net_d_iters = train_opt.get('net_d_iters', 1)
101
- self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
102
-
103
- # set up optimizers and schedulers
104
- self.setup_optimizers()
105
- self.setup_schedulers()
106
-
107
- def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
108
- recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
109
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
110
-
111
- d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
112
- d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
113
- return d_weight
114
-
115
- def setup_optimizers(self):
116
- train_opt = self.opt['train']
117
- # optimizer g
118
- optim_params_g = []
119
- for k, v in self.net_g.named_parameters():
120
- if v.requires_grad:
121
- optim_params_g.append(v)
122
- else:
123
- logger = get_root_logger()
124
- logger.warning(f'Params {k} will not be optimized.')
125
- optim_type = train_opt['optim_g'].pop('type')
126
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
127
- self.optimizers.append(self.optimizer_g)
128
- # optimizer d
129
- optim_type = train_opt['optim_d'].pop('type')
130
- self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
131
- self.optimizers.append(self.optimizer_d)
132
-
133
- def gray_resize_for_identity(self, out, size=128):
134
- out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
135
- out_gray = out_gray.unsqueeze(1)
136
- out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
137
- return out_gray
138
-
139
- def optimize_parameters(self, current_iter):
140
- logger = get_root_logger()
141
- # optimize net_g
142
- for p in self.net_d.parameters():
143
- p.requires_grad = False
144
-
145
- self.optimizer_g.zero_grad()
146
-
147
- if self.generate_idx_gt:
148
- x = self.hq_vqgan_fix.encoder(self.gt)
149
- output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
150
- min_encoding_indices = quant_stats['min_encoding_indices']
151
- self.idx_gt = min_encoding_indices.view(self.b, -1)
152
-
153
- if current_iter <= 40000: # small degradation
154
- small_per_n = 1
155
- w = 1
156
- elif current_iter <= 80000: # small degradation
157
- small_per_n = 1
158
- w = 1.3
159
- elif current_iter <= 120000: # large degradation
160
- small_per_n = 120000
161
- w = 0
162
- else: # mixed degradation
163
- small_per_n = 15
164
- w = 1.3
165
-
166
- if current_iter % small_per_n == 0:
167
- self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True)
168
- large_de = False
169
- else:
170
- logits, lq_feat = self.net_g(self.input_large_de, code_only=True)
171
- large_de = True
172
-
173
- if self.hq_feat_loss:
174
- # quant_feats
175
- quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
176
-
177
- l_g_total = 0
178
- loss_dict = OrderedDict()
179
- if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
180
- # hq_feat_loss
181
- if not 'transformer' in self.opt['network_g']['fix_modules']:
182
- if self.hq_feat_loss: # codebook loss
183
- l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
184
- l_g_total += l_feat_encoder
185
- loss_dict['l_feat_encoder'] = l_feat_encoder
186
-
187
- # cross_entropy_loss
188
- if self.cross_entropy_loss:
189
- # b(hw)n -> bn(hw)
190
- cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
191
- l_g_total += cross_entropy_loss
192
- loss_dict['cross_entropy_loss'] = cross_entropy_loss
193
-
194
- # pixel loss
195
- if not large_de: # when large degradation don't need image-level loss
196
- if self.cri_pix:
197
- l_g_pix = self.cri_pix(self.output, self.gt)
198
- l_g_total += l_g_pix
199
- loss_dict['l_g_pix'] = l_g_pix
200
-
201
- # perceptual loss
202
- if self.cri_perceptual:
203
- l_g_percep = self.cri_perceptual(self.output, self.gt)
204
- l_g_total += l_g_percep
205
- loss_dict['l_g_percep'] = l_g_percep
206
-
207
- # gan loss
208
- if current_iter > self.net_d_start_iter:
209
- fake_g_pred = self.net_d(self.output)
210
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
211
- recon_loss = l_g_pix + l_g_percep
212
- if not self.fix_generator:
213
- last_layer = self.net_g.module.generator.blocks[-1].weight
214
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
215
- else:
216
- largest_fuse_size = self.opt['network_g']['connect_list'][-1]
217
- last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
218
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
219
-
220
- d_weight *= self.scale_adaptive_gan_weight # 0.8
221
- loss_dict['d_weight'] = d_weight
222
- l_g_total += d_weight * l_g_gan
223
- loss_dict['l_g_gan'] = d_weight * l_g_gan
224
-
225
- l_g_total.backward()
226
- self.optimizer_g.step()
227
-
228
- if self.ema_decay > 0:
229
- self.model_ema(decay=self.ema_decay)
230
-
231
- # optimize net_d
232
- if not large_de:
233
- if current_iter > self.net_d_start_iter:
234
- for p in self.net_d.parameters():
235
- p.requires_grad = True
236
-
237
- self.optimizer_d.zero_grad()
238
- # real
239
- real_d_pred = self.net_d(self.gt)
240
- l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
241
- loss_dict['l_d_real'] = l_d_real
242
- loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
243
- l_d_real.backward()
244
- # fake
245
- fake_d_pred = self.net_d(self.output.detach())
246
- l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
247
- loss_dict['l_d_fake'] = l_d_fake
248
- loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
249
- l_d_fake.backward()
250
-
251
- self.optimizer_d.step()
252
-
253
- self.log_dict = self.reduce_loss_dict(loss_dict)
254
-
255
-
256
- def test(self):
257
- with torch.no_grad():
258
- if hasattr(self, 'net_g_ema'):
259
- self.net_g_ema.eval()
260
- self.output, _, _ = self.net_g_ema(self.input, w=1)
261
- else:
262
- logger = get_root_logger()
263
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
264
- self.net_g.eval()
265
- self.output, _, _ = self.net_g(self.input, w=1)
266
- self.net_g.train()
267
-
268
-
269
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
270
- if self.opt['rank'] == 0:
271
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
272
-
273
-
274
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
275
- dataset_name = dataloader.dataset.opt['name']
276
- with_metrics = self.opt['val'].get('metrics') is not None
277
- if with_metrics:
278
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
279
- pbar = tqdm(total=len(dataloader), unit='image')
280
-
281
- for idx, val_data in enumerate(dataloader):
282
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
283
- self.feed_data(val_data)
284
- self.test()
285
-
286
- visuals = self.get_current_visuals()
287
- sr_img = tensor2img([visuals['result']])
288
- if 'gt' in visuals:
289
- gt_img = tensor2img([visuals['gt']])
290
- del self.gt
291
-
292
- # tentative for out of GPU memory
293
- del self.lq
294
- del self.output
295
- torch.cuda.empty_cache()
296
-
297
- if save_img:
298
- if self.opt['is_train']:
299
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
300
- f'{img_name}_{current_iter}.png')
301
- else:
302
- if self.opt['val']['suffix']:
303
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
304
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
305
- else:
306
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
307
- f'{img_name}_{self.opt["name"]}.png')
308
- imwrite(sr_img, save_img_path)
309
-
310
- if with_metrics:
311
- # calculate metrics
312
- for name, opt_ in self.opt['val']['metrics'].items():
313
- metric_data = dict(img1=sr_img, img2=gt_img)
314
- self.metric_results[name] += calculate_metric(metric_data, opt_)
315
- pbar.update(1)
316
- pbar.set_description(f'Test {img_name}')
317
- pbar.close()
318
-
319
- if with_metrics:
320
- for metric in self.metric_results.keys():
321
- self.metric_results[metric] /= (idx + 1)
322
-
323
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
324
-
325
-
326
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
327
- log_str = f'Validation {dataset_name}\n'
328
- for metric, value in self.metric_results.items():
329
- log_str += f'\t # {metric}: {value:.4f}\n'
330
- logger = get_root_logger()
331
- logger.info(log_str)
332
- if tb_logger:
333
- for metric, value in self.metric_results.items():
334
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
335
-
336
-
337
- def get_current_visuals(self):
338
- out_dict = OrderedDict()
339
- out_dict['gt'] = self.gt.detach().cpu()
340
- out_dict['result'] = self.output.detach().cpu()
341
- return out_dict
342
-
343
-
344
- def save(self, epoch, current_iter):
345
- if self.ema_decay > 0:
346
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
347
- else:
348
- self.save_network(self.net_g, 'net_g', current_iter)
349
- self.save_network(self.net_d, 'net_d', current_iter)
350
- self.save_training_state(epoch, current_iter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/codeformer_model.py DELETED
@@ -1,332 +0,0 @@
1
- import torch
2
- from collections import OrderedDict
3
- from os import path as osp
4
- from tqdm import tqdm
5
-
6
- from basicsr.archs import build_network
7
- from basicsr.losses import build_loss
8
- from basicsr.metrics import calculate_metric
9
- from basicsr.utils import get_root_logger, imwrite, tensor2img
10
- from basicsr.utils.registry import MODEL_REGISTRY
11
- import torch.nn.functional as F
12
- from .sr_model import SRModel
13
-
14
-
15
- @MODEL_REGISTRY.register()
16
- class CodeFormerModel(SRModel):
17
- def feed_data(self, data):
18
- self.gt = data['gt'].to(self.device)
19
- self.input = data['in'].to(self.device)
20
- self.b = self.gt.shape[0]
21
-
22
- if 'latent_gt' in data:
23
- self.idx_gt = data['latent_gt'].to(self.device)
24
- self.idx_gt = self.idx_gt.view(self.b, -1)
25
- else:
26
- self.idx_gt = None
27
-
28
- def init_training_settings(self):
29
- logger = get_root_logger()
30
- train_opt = self.opt['train']
31
-
32
- self.ema_decay = train_opt.get('ema_decay', 0)
33
- if self.ema_decay > 0:
34
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
35
- # define network net_g with Exponential Moving Average (EMA)
36
- # net_g_ema is used only for testing on one GPU and saving
37
- # There is no need to wrap with DistributedDataParallel
38
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
39
- # load pretrained model
40
- load_path = self.opt['path'].get('pretrain_network_g', None)
41
- if load_path is not None:
42
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
43
- else:
44
- self.model_ema(0) # copy net_g weight
45
- self.net_g_ema.eval()
46
-
47
- if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
48
- self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
49
- self.hq_vqgan_fix.eval()
50
- self.generate_idx_gt = True
51
- for param in self.hq_vqgan_fix.parameters():
52
- param.requires_grad = False
53
- else:
54
- self.generate_idx_gt = False
55
-
56
- self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
57
- self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
58
- self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
59
- self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
60
- self.fidelity_weight = train_opt.get('fidelity_weight', 1.0)
61
- self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
62
-
63
-
64
- self.net_g.train()
65
- # define network net_d
66
- if self.fidelity_weight > 0:
67
- self.net_d = build_network(self.opt['network_d'])
68
- self.net_d = self.model_to_device(self.net_d)
69
- self.print_network(self.net_d)
70
-
71
- # load pretrained models
72
- load_path = self.opt['path'].get('pretrain_network_d', None)
73
- if load_path is not None:
74
- self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
75
-
76
- self.net_d.train()
77
-
78
- # define losses
79
- if train_opt.get('pixel_opt'):
80
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
81
- else:
82
- self.cri_pix = None
83
-
84
- if train_opt.get('perceptual_opt'):
85
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
86
- else:
87
- self.cri_perceptual = None
88
-
89
- if train_opt.get('gan_opt'):
90
- self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
91
-
92
-
93
- self.fix_generator = train_opt.get('fix_generator', True)
94
- logger.info(f'fix_generator: {self.fix_generator}')
95
-
96
- self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
97
- self.net_d_iters = train_opt.get('net_d_iters', 1)
98
- self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
99
-
100
- # set up optimizers and schedulers
101
- self.setup_optimizers()
102
- self.setup_schedulers()
103
-
104
- def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
105
- recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
106
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
107
-
108
- d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
109
- d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
110
- return d_weight
111
-
112
- def setup_optimizers(self):
113
- train_opt = self.opt['train']
114
- # optimizer g
115
- optim_params_g = []
116
- for k, v in self.net_g.named_parameters():
117
- if v.requires_grad:
118
- optim_params_g.append(v)
119
- else:
120
- logger = get_root_logger()
121
- logger.warning(f'Params {k} will not be optimized.')
122
- optim_type = train_opt['optim_g'].pop('type')
123
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
124
- self.optimizers.append(self.optimizer_g)
125
- # optimizer d
126
- if self.fidelity_weight > 0:
127
- optim_type = train_opt['optim_d'].pop('type')
128
- self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
129
- self.optimizers.append(self.optimizer_d)
130
-
131
- def gray_resize_for_identity(self, out, size=128):
132
- out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
133
- out_gray = out_gray.unsqueeze(1)
134
- out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
135
- return out_gray
136
-
137
- def optimize_parameters(self, current_iter):
138
- logger = get_root_logger()
139
- # optimize net_g
140
- for p in self.net_d.parameters():
141
- p.requires_grad = False
142
-
143
- self.optimizer_g.zero_grad()
144
-
145
- if self.generate_idx_gt:
146
- x = self.hq_vqgan_fix.encoder(self.gt)
147
- output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
148
- min_encoding_indices = quant_stats['min_encoding_indices']
149
- self.idx_gt = min_encoding_indices.view(self.b, -1)
150
-
151
- if self.fidelity_weight > 0:
152
- self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True)
153
- else:
154
- logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
155
-
156
- if self.hq_feat_loss:
157
- # quant_feats
158
- quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
159
-
160
- l_g_total = 0
161
- loss_dict = OrderedDict()
162
- if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
163
- # hq_feat_loss
164
- if self.hq_feat_loss: # codebook loss
165
- l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
166
- l_g_total += l_feat_encoder
167
- loss_dict['l_feat_encoder'] = l_feat_encoder
168
-
169
- # cross_entropy_loss
170
- if self.cross_entropy_loss:
171
- # b(hw)n -> bn(hw)
172
- cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
173
- l_g_total += cross_entropy_loss
174
- loss_dict['cross_entropy_loss'] = cross_entropy_loss
175
-
176
- if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss
177
- # pixel loss
178
- if self.cri_pix:
179
- l_g_pix = self.cri_pix(self.output, self.gt)
180
- l_g_total += l_g_pix
181
- loss_dict['l_g_pix'] = l_g_pix
182
-
183
- # perceptual loss
184
- if self.cri_perceptual:
185
- l_g_percep = self.cri_perceptual(self.output, self.gt)
186
- l_g_total += l_g_percep
187
- loss_dict['l_g_percep'] = l_g_percep
188
-
189
- # gan loss
190
- if current_iter > self.net_d_start_iter:
191
- fake_g_pred = self.net_d(self.output)
192
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
193
- recon_loss = l_g_pix + l_g_percep
194
- if not self.fix_generator:
195
- last_layer = self.net_g.module.generator.blocks[-1].weight
196
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
197
- else:
198
- largest_fuse_size = self.opt['network_g']['connect_list'][-1]
199
- last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
200
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
201
-
202
- d_weight *= self.scale_adaptive_gan_weight # 0.8
203
- loss_dict['d_weight'] = d_weight
204
- l_g_total += d_weight * l_g_gan
205
- loss_dict['l_g_gan'] = d_weight * l_g_gan
206
-
207
- l_g_total.backward()
208
- self.optimizer_g.step()
209
-
210
- if self.ema_decay > 0:
211
- self.model_ema(decay=self.ema_decay)
212
-
213
- # optimize net_d
214
- if current_iter > self.net_d_start_iter and self.fidelity_weight > 0:
215
- for p in self.net_d.parameters():
216
- p.requires_grad = True
217
-
218
- self.optimizer_d.zero_grad()
219
- # real
220
- real_d_pred = self.net_d(self.gt)
221
- l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
222
- loss_dict['l_d_real'] = l_d_real
223
- loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
224
- l_d_real.backward()
225
- # fake
226
- fake_d_pred = self.net_d(self.output.detach())
227
- l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
228
- loss_dict['l_d_fake'] = l_d_fake
229
- loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
230
- l_d_fake.backward()
231
-
232
- self.optimizer_d.step()
233
-
234
- self.log_dict = self.reduce_loss_dict(loss_dict)
235
-
236
-
237
- def test(self):
238
- with torch.no_grad():
239
- if hasattr(self, 'net_g_ema'):
240
- self.net_g_ema.eval()
241
- self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight)
242
- else:
243
- logger = get_root_logger()
244
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
245
- self.net_g.eval()
246
- self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight)
247
- self.net_g.train()
248
-
249
-
250
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
251
- if self.opt['rank'] == 0:
252
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
253
-
254
-
255
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
256
- dataset_name = dataloader.dataset.opt['name']
257
- with_metrics = self.opt['val'].get('metrics') is not None
258
- if with_metrics:
259
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
260
- pbar = tqdm(total=len(dataloader), unit='image')
261
-
262
- for idx, val_data in enumerate(dataloader):
263
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
264
- self.feed_data(val_data)
265
- self.test()
266
-
267
- visuals = self.get_current_visuals()
268
- sr_img = tensor2img([visuals['result']])
269
- if 'gt' in visuals:
270
- gt_img = tensor2img([visuals['gt']])
271
- del self.gt
272
-
273
- # tentative for out of GPU memory
274
- del self.lq
275
- del self.output
276
- torch.cuda.empty_cache()
277
-
278
- if save_img:
279
- if self.opt['is_train']:
280
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
281
- f'{img_name}_{current_iter}.png')
282
- else:
283
- if self.opt['val']['suffix']:
284
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
285
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
286
- else:
287
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
288
- f'{img_name}_{self.opt["name"]}.png')
289
- imwrite(sr_img, save_img_path)
290
-
291
- if with_metrics:
292
- # calculate metrics
293
- for name, opt_ in self.opt['val']['metrics'].items():
294
- metric_data = dict(img1=sr_img, img2=gt_img)
295
- self.metric_results[name] += calculate_metric(metric_data, opt_)
296
- pbar.update(1)
297
- pbar.set_description(f'Test {img_name}')
298
- pbar.close()
299
-
300
- if with_metrics:
301
- for metric in self.metric_results.keys():
302
- self.metric_results[metric] /= (idx + 1)
303
-
304
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
305
-
306
-
307
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
308
- log_str = f'Validation {dataset_name}\n'
309
- for metric, value in self.metric_results.items():
310
- log_str += f'\t # {metric}: {value:.4f}\n'
311
- logger = get_root_logger()
312
- logger.info(log_str)
313
- if tb_logger:
314
- for metric, value in self.metric_results.items():
315
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
316
-
317
-
318
- def get_current_visuals(self):
319
- out_dict = OrderedDict()
320
- out_dict['gt'] = self.gt.detach().cpu()
321
- out_dict['result'] = self.output.detach().cpu()
322
- return out_dict
323
-
324
-
325
- def save(self, epoch, current_iter):
326
- if self.ema_decay > 0:
327
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
328
- else:
329
- self.save_network(self.net_g, 'net_g', current_iter)
330
- if self.fidelity_weight > 0:
331
- self.save_network(self.net_d, 'net_d', current_iter)
332
- self.save_training_state(epoch, current_iter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/lr_scheduler.py DELETED
@@ -1,96 +0,0 @@
1
- import math
2
- from collections import Counter
3
- from torch.optim.lr_scheduler import _LRScheduler
4
-
5
-
6
- class MultiStepRestartLR(_LRScheduler):
7
- """ MultiStep with restarts learning rate scheme.
8
-
9
- Args:
10
- optimizer (torch.nn.optimizer): Torch optimizer.
11
- milestones (list): Iterations that will decrease learning rate.
12
- gamma (float): Decrease ratio. Default: 0.1.
13
- restarts (list): Restart iterations. Default: [0].
14
- restart_weights (list): Restart weights at each restart iteration.
15
- Default: [1].
16
- last_epoch (int): Used in _LRScheduler. Default: -1.
17
- """
18
-
19
- def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
20
- self.milestones = Counter(milestones)
21
- self.gamma = gamma
22
- self.restarts = restarts
23
- self.restart_weights = restart_weights
24
- assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25
- super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
26
-
27
- def get_lr(self):
28
- if self.last_epoch in self.restarts:
29
- weight = self.restart_weights[self.restarts.index(self.last_epoch)]
30
- return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
31
- if self.last_epoch not in self.milestones:
32
- return [group['lr'] for group in self.optimizer.param_groups]
33
- return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
34
-
35
-
36
- def get_position_from_periods(iteration, cumulative_period):
37
- """Get the position from a period list.
38
-
39
- It will return the index of the right-closest number in the period list.
40
- For example, the cumulative_period = [100, 200, 300, 400],
41
- if iteration == 50, return 0;
42
- if iteration == 210, return 2;
43
- if iteration == 300, return 2.
44
-
45
- Args:
46
- iteration (int): Current iteration.
47
- cumulative_period (list[int]): Cumulative period list.
48
-
49
- Returns:
50
- int: The position of the right-closest number in the period list.
51
- """
52
- for i, period in enumerate(cumulative_period):
53
- if iteration <= period:
54
- return i
55
-
56
-
57
- class CosineAnnealingRestartLR(_LRScheduler):
58
- """ Cosine annealing with restarts learning rate scheme.
59
-
60
- An example of config:
61
- periods = [10, 10, 10, 10]
62
- restart_weights = [1, 0.5, 0.5, 0.5]
63
- eta_min=1e-7
64
-
65
- It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
66
- scheduler will restart with the weights in restart_weights.
67
-
68
- Args:
69
- optimizer (torch.nn.optimizer): Torch optimizer.
70
- periods (list): Period for each cosine anneling cycle.
71
- restart_weights (list): Restart weights at each restart iteration.
72
- Default: [1].
73
- eta_min (float): The mimimum lr. Default: 0.
74
- last_epoch (int): Used in _LRScheduler. Default: -1.
75
- """
76
-
77
- def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
78
- self.periods = periods
79
- self.restart_weights = restart_weights
80
- self.eta_min = eta_min
81
- assert (len(self.periods) == len(
82
- self.restart_weights)), 'periods and restart_weights should have the same length.'
83
- self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
84
- super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
85
-
86
- def get_lr(self):
87
- idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
88
- current_weight = self.restart_weights[idx]
89
- nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
90
- current_period = self.periods[idx]
91
-
92
- return [
93
- self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
94
- (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
95
- for base_lr in self.base_lrs
96
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/sr_model.py DELETED
@@ -1,209 +0,0 @@
1
- import torch
2
- from collections import OrderedDict
3
- from os import path as osp
4
- from tqdm import tqdm
5
-
6
- from basicsr.archs import build_network
7
- from basicsr.losses import build_loss
8
- from basicsr.metrics import calculate_metric
9
- from basicsr.utils import get_root_logger, imwrite, tensor2img
10
- from basicsr.utils.registry import MODEL_REGISTRY
11
- from .base_model import BaseModel
12
-
13
- @MODEL_REGISTRY.register()
14
- class SRModel(BaseModel):
15
- """Base SR model for single image super-resolution."""
16
-
17
- def __init__(self, opt):
18
- super(SRModel, self).__init__(opt)
19
-
20
- # define network
21
- self.net_g = build_network(opt['network_g'])
22
- self.net_g = self.model_to_device(self.net_g)
23
- self.print_network(self.net_g)
24
-
25
- # load pretrained models
26
- load_path = self.opt['path'].get('pretrain_network_g', None)
27
- if load_path is not None:
28
- param_key = self.opt['path'].get('param_key_g', 'params')
29
- self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
30
-
31
- if self.is_train:
32
- self.init_training_settings()
33
-
34
- def init_training_settings(self):
35
- self.net_g.train()
36
- train_opt = self.opt['train']
37
-
38
- self.ema_decay = train_opt.get('ema_decay', 0)
39
- if self.ema_decay > 0:
40
- logger = get_root_logger()
41
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
42
- # define network net_g with Exponential Moving Average (EMA)
43
- # net_g_ema is used only for testing on one GPU and saving
44
- # There is no need to wrap with DistributedDataParallel
45
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
46
- # load pretrained model
47
- load_path = self.opt['path'].get('pretrain_network_g', None)
48
- if load_path is not None:
49
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
50
- else:
51
- self.model_ema(0) # copy net_g weight
52
- self.net_g_ema.eval()
53
-
54
- # define losses
55
- if train_opt.get('pixel_opt'):
56
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
57
- else:
58
- self.cri_pix = None
59
-
60
- if train_opt.get('perceptual_opt'):
61
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
62
- else:
63
- self.cri_perceptual = None
64
-
65
- if self.cri_pix is None and self.cri_perceptual is None:
66
- raise ValueError('Both pixel and perceptual losses are None.')
67
-
68
- # set up optimizers and schedulers
69
- self.setup_optimizers()
70
- self.setup_schedulers()
71
-
72
- def setup_optimizers(self):
73
- train_opt = self.opt['train']
74
- optim_params = []
75
- for k, v in self.net_g.named_parameters():
76
- if v.requires_grad:
77
- optim_params.append(v)
78
- else:
79
- logger = get_root_logger()
80
- logger.warning(f'Params {k} will not be optimized.')
81
-
82
- optim_type = train_opt['optim_g'].pop('type')
83
- self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
84
- self.optimizers.append(self.optimizer_g)
85
-
86
- def feed_data(self, data):
87
- self.lq = data['lq'].to(self.device)
88
- if 'gt' in data:
89
- self.gt = data['gt'].to(self.device)
90
-
91
- def optimize_parameters(self, current_iter):
92
- self.optimizer_g.zero_grad()
93
- self.output = self.net_g(self.lq)
94
-
95
- l_total = 0
96
- loss_dict = OrderedDict()
97
- # pixel loss
98
- if self.cri_pix:
99
- l_pix = self.cri_pix(self.output, self.gt)
100
- l_total += l_pix
101
- loss_dict['l_pix'] = l_pix
102
- # perceptual loss
103
- if self.cri_perceptual:
104
- l_percep, l_style = self.cri_perceptual(self.output, self.gt)
105
- if l_percep is not None:
106
- l_total += l_percep
107
- loss_dict['l_percep'] = l_percep
108
- if l_style is not None:
109
- l_total += l_style
110
- loss_dict['l_style'] = l_style
111
-
112
- l_total.backward()
113
- self.optimizer_g.step()
114
-
115
- self.log_dict = self.reduce_loss_dict(loss_dict)
116
-
117
- if self.ema_decay > 0:
118
- self.model_ema(decay=self.ema_decay)
119
-
120
- def test(self):
121
- if hasattr(self, 'ema_decay'):
122
- self.net_g_ema.eval()
123
- with torch.no_grad():
124
- self.output = self.net_g_ema(self.lq)
125
- else:
126
- self.net_g.eval()
127
- with torch.no_grad():
128
- self.output = self.net_g(self.lq)
129
- self.net_g.train()
130
-
131
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
132
- if self.opt['rank'] == 0:
133
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
134
-
135
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
136
- dataset_name = dataloader.dataset.opt['name']
137
- with_metrics = self.opt['val'].get('metrics') is not None
138
- if with_metrics:
139
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
140
- pbar = tqdm(total=len(dataloader), unit='image')
141
-
142
- for idx, val_data in enumerate(dataloader):
143
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
144
- self.feed_data(val_data)
145
- self.test()
146
-
147
- visuals = self.get_current_visuals()
148
- sr_img = tensor2img([visuals['result']])
149
- if 'gt' in visuals:
150
- gt_img = tensor2img([visuals['gt']])
151
- del self.gt
152
-
153
- # tentative for out of GPU memory
154
- del self.lq
155
- del self.output
156
- torch.cuda.empty_cache()
157
-
158
- if save_img:
159
- if self.opt['is_train']:
160
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
161
- f'{img_name}_{current_iter}.png')
162
- else:
163
- if self.opt['val']['suffix']:
164
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
165
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
166
- else:
167
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
168
- f'{img_name}_{self.opt["name"]}.png')
169
- imwrite(sr_img, save_img_path)
170
-
171
- if with_metrics:
172
- # calculate metrics
173
- for name, opt_ in self.opt['val']['metrics'].items():
174
- metric_data = dict(img1=sr_img, img2=gt_img)
175
- self.metric_results[name] += calculate_metric(metric_data, opt_)
176
- pbar.update(1)
177
- pbar.set_description(f'Test {img_name}')
178
- pbar.close()
179
-
180
- if with_metrics:
181
- for metric in self.metric_results.keys():
182
- self.metric_results[metric] /= (idx + 1)
183
-
184
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
185
-
186
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
187
- log_str = f'Validation {dataset_name}\n'
188
- for metric, value in self.metric_results.items():
189
- log_str += f'\t # {metric}: {value:.4f}\n'
190
- logger = get_root_logger()
191
- logger.info(log_str)
192
- if tb_logger:
193
- for metric, value in self.metric_results.items():
194
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
195
-
196
- def get_current_visuals(self):
197
- out_dict = OrderedDict()
198
- out_dict['lq'] = self.lq.detach().cpu()
199
- out_dict['result'] = self.output.detach().cpu()
200
- if hasattr(self, 'gt'):
201
- out_dict['gt'] = self.gt.detach().cpu()
202
- return out_dict
203
-
204
- def save(self, epoch, current_iter):
205
- if hasattr(self, 'ema_decay'):
206
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
207
- else:
208
- self.save_network(self.net_g, 'net_g', current_iter)
209
- self.save_training_state(epoch, current_iter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/models/vqgan_model.py DELETED
@@ -1,285 +0,0 @@
1
- import torch
2
- from collections import OrderedDict
3
- from os import path as osp
4
- from tqdm import tqdm
5
-
6
- from basicsr.archs import build_network
7
- from basicsr.losses import build_loss
8
- from basicsr.metrics import calculate_metric
9
- from basicsr.utils import get_root_logger, imwrite, tensor2img
10
- from basicsr.utils.registry import MODEL_REGISTRY
11
- import torch.nn.functional as F
12
- from .sr_model import SRModel
13
-
14
-
15
- @MODEL_REGISTRY.register()
16
- class VQGANModel(SRModel):
17
- def feed_data(self, data):
18
- self.gt = data['gt'].to(self.device)
19
- self.b = self.gt.shape[0]
20
-
21
-
22
- def init_training_settings(self):
23
- logger = get_root_logger()
24
- train_opt = self.opt['train']
25
-
26
- self.ema_decay = train_opt.get('ema_decay', 0)
27
- if self.ema_decay > 0:
28
- logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
29
- # define network net_g with Exponential Moving Average (EMA)
30
- # net_g_ema is used only for testing on one GPU and saving
31
- # There is no need to wrap with DistributedDataParallel
32
- self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
33
- # load pretrained model
34
- load_path = self.opt['path'].get('pretrain_network_g', None)
35
- if load_path is not None:
36
- self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
37
- else:
38
- self.model_ema(0) # copy net_g weight
39
- self.net_g_ema.eval()
40
-
41
- # define network net_d
42
- self.net_d = build_network(self.opt['network_d'])
43
- self.net_d = self.model_to_device(self.net_d)
44
- self.print_network(self.net_d)
45
-
46
- # load pretrained models
47
- load_path = self.opt['path'].get('pretrain_network_d', None)
48
- if load_path is not None:
49
- self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
50
-
51
- self.net_g.train()
52
- self.net_d.train()
53
-
54
- # define losses
55
- if train_opt.get('pixel_opt'):
56
- self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
57
- else:
58
- self.cri_pix = None
59
-
60
- if train_opt.get('perceptual_opt'):
61
- self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
62
- else:
63
- self.cri_perceptual = None
64
-
65
- if train_opt.get('gan_opt'):
66
- self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
67
-
68
- if train_opt.get('codebook_opt'):
69
- self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
70
- else:
71
- self.l_weight_codebook = 1.0
72
-
73
- self.vqgan_quantizer = self.opt['network_g']['quantizer']
74
- logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
75
-
76
- self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
77
- self.net_d_iters = train_opt.get('net_d_iters', 1)
78
- self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
79
- self.disc_weight = train_opt.get('disc_weight', 0.8)
80
-
81
- # set up optimizers and schedulers
82
- self.setup_optimizers()
83
- self.setup_schedulers()
84
-
85
- def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
86
- recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
87
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
88
-
89
- d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
90
- d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
91
- return d_weight
92
-
93
- def adopt_weight(self, weight, global_step, threshold=0, value=0.):
94
- if global_step < threshold:
95
- weight = value
96
- return weight
97
-
98
- def setup_optimizers(self):
99
- train_opt = self.opt['train']
100
- # optimizer g
101
- optim_params_g = []
102
- for k, v in self.net_g.named_parameters():
103
- if v.requires_grad:
104
- optim_params_g.append(v)
105
- else:
106
- logger = get_root_logger()
107
- logger.warning(f'Params {k} will not be optimized.')
108
- optim_type = train_opt['optim_g'].pop('type')
109
- self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
110
- self.optimizers.append(self.optimizer_g)
111
- # optimizer d
112
- optim_type = train_opt['optim_d'].pop('type')
113
- self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
114
- self.optimizers.append(self.optimizer_d)
115
-
116
-
117
- def optimize_parameters(self, current_iter):
118
- logger = get_root_logger()
119
- loss_dict = OrderedDict()
120
- if self.opt['network_g']['quantizer'] == 'gumbel':
121
- self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
122
- if current_iter%1000 == 0:
123
- logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
124
-
125
- # optimize net_g
126
- for p in self.net_d.parameters():
127
- p.requires_grad = False
128
-
129
- self.optimizer_g.zero_grad()
130
- self.output, l_codebook, quant_stats = self.net_g(self.gt)
131
-
132
- l_codebook = l_codebook*self.l_weight_codebook
133
-
134
- l_g_total = 0
135
- if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
136
- # pixel loss
137
- if self.cri_pix:
138
- l_g_pix = self.cri_pix(self.output, self.gt)
139
- l_g_total += l_g_pix
140
- loss_dict['l_g_pix'] = l_g_pix
141
- # perceptual loss
142
- if self.cri_perceptual:
143
- l_g_percep = self.cri_perceptual(self.output, self.gt)
144
- l_g_total += l_g_percep
145
- loss_dict['l_g_percep'] = l_g_percep
146
-
147
- # gan loss
148
- if current_iter > self.net_d_start_iter:
149
- # fake_g_pred = self.net_d(self.output_1024)
150
- fake_g_pred = self.net_d(self.output)
151
- l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
152
- recon_loss = l_g_total
153
- last_layer = self.net_g.module.generator.blocks[-1].weight
154
- d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
155
- d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
156
- d_weight *= self.disc_weight # tamming setting 0.8
157
- l_g_total += d_weight * l_g_gan
158
- loss_dict['l_g_gan'] = d_weight * l_g_gan
159
-
160
- l_g_total += l_codebook
161
- loss_dict['l_codebook'] = l_codebook
162
-
163
- l_g_total.backward()
164
- self.optimizer_g.step()
165
-
166
- # optimize net_d
167
- if current_iter > self.net_d_start_iter:
168
- for p in self.net_d.parameters():
169
- p.requires_grad = True
170
-
171
- self.optimizer_d.zero_grad()
172
- # real
173
- real_d_pred = self.net_d(self.gt)
174
- l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
175
- loss_dict['l_d_real'] = l_d_real
176
- loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
177
- l_d_real.backward()
178
- # fake
179
- fake_d_pred = self.net_d(self.output.detach())
180
- l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
181
- loss_dict['l_d_fake'] = l_d_fake
182
- loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
183
- l_d_fake.backward()
184
- self.optimizer_d.step()
185
-
186
- self.log_dict = self.reduce_loss_dict(loss_dict)
187
-
188
- if self.ema_decay > 0:
189
- self.model_ema(decay=self.ema_decay)
190
-
191
-
192
- def test(self):
193
- with torch.no_grad():
194
- if hasattr(self, 'net_g_ema'):
195
- self.net_g_ema.eval()
196
- self.output, _, _ = self.net_g_ema(self.gt)
197
- else:
198
- logger = get_root_logger()
199
- logger.warning('Do not have self.net_g_ema, use self.net_g.')
200
- self.net_g.eval()
201
- self.output, _, _ = self.net_g(self.gt)
202
- self.net_g.train()
203
-
204
-
205
- def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
206
- if self.opt['rank'] == 0:
207
- self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
208
-
209
-
210
- def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
211
- dataset_name = dataloader.dataset.opt['name']
212
- with_metrics = self.opt['val'].get('metrics') is not None
213
- if with_metrics:
214
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
215
- pbar = tqdm(total=len(dataloader), unit='image')
216
-
217
- for idx, val_data in enumerate(dataloader):
218
- img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
219
- self.feed_data(val_data)
220
- self.test()
221
-
222
- visuals = self.get_current_visuals()
223
- sr_img = tensor2img([visuals['result']])
224
- if 'gt' in visuals:
225
- gt_img = tensor2img([visuals['gt']])
226
- del self.gt
227
-
228
- # tentative for out of GPU memory
229
- del self.lq
230
- del self.output
231
- torch.cuda.empty_cache()
232
-
233
- if save_img:
234
- if self.opt['is_train']:
235
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
236
- f'{img_name}_{current_iter}.png')
237
- else:
238
- if self.opt['val']['suffix']:
239
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
240
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
241
- else:
242
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
243
- f'{img_name}_{self.opt["name"]}.png')
244
- imwrite(sr_img, save_img_path)
245
-
246
- if with_metrics:
247
- # calculate metrics
248
- for name, opt_ in self.opt['val']['metrics'].items():
249
- metric_data = dict(img1=sr_img, img2=gt_img)
250
- self.metric_results[name] += calculate_metric(metric_data, opt_)
251
- pbar.update(1)
252
- pbar.set_description(f'Test {img_name}')
253
- pbar.close()
254
-
255
- if with_metrics:
256
- for metric in self.metric_results.keys():
257
- self.metric_results[metric] /= (idx + 1)
258
-
259
- self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
260
-
261
-
262
- def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
263
- log_str = f'Validation {dataset_name}\n'
264
- for metric, value in self.metric_results.items():
265
- log_str += f'\t # {metric}: {value:.4f}\n'
266
- logger = get_root_logger()
267
- logger.info(log_str)
268
- if tb_logger:
269
- for metric, value in self.metric_results.items():
270
- tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
271
-
272
-
273
- def get_current_visuals(self):
274
- out_dict = OrderedDict()
275
- out_dict['gt'] = self.gt.detach().cpu()
276
- out_dict['result'] = self.output.detach().cpu()
277
- return out_dict
278
-
279
- def save(self, epoch, current_iter):
280
- if self.ema_decay > 0:
281
- self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
282
- else:
283
- self.save_network(self.net_g, 'net_g', current_iter)
284
- self.save_network(self.net_d, 'net_d', current_iter)
285
- self.save_training_state(epoch, current_iter)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/dcn/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
2
- modulated_deform_conv)
3
-
4
- __all__ = [
5
- 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6
- 'modulated_deform_conv'
7
- ]
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/dcn/deform_conv.py DELETED
@@ -1,377 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn as nn
4
- from torch.autograd import Function
5
- from torch.autograd.function import once_differentiable
6
- from torch.nn import functional as F
7
- from torch.nn.modules.utils import _pair, _single
8
-
9
- try:
10
- from . import deform_conv_ext
11
- except ImportError:
12
- import os
13
- BASICSR_JIT = os.getenv('BASICSR_JIT')
14
- if BASICSR_JIT == 'True':
15
- from torch.utils.cpp_extension import load
16
- module_path = os.path.dirname(__file__)
17
- deform_conv_ext = load(
18
- 'deform_conv',
19
- sources=[
20
- os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
21
- os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
22
- os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
23
- ],
24
- )
25
-
26
-
27
- class DeformConvFunction(Function):
28
-
29
- @staticmethod
30
- def forward(ctx,
31
- input,
32
- offset,
33
- weight,
34
- stride=1,
35
- padding=0,
36
- dilation=1,
37
- groups=1,
38
- deformable_groups=1,
39
- im2col_step=64):
40
- if input is not None and input.dim() != 4:
41
- raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
42
- ctx.stride = _pair(stride)
43
- ctx.padding = _pair(padding)
44
- ctx.dilation = _pair(dilation)
45
- ctx.groups = groups
46
- ctx.deformable_groups = deformable_groups
47
- ctx.im2col_step = im2col_step
48
-
49
- ctx.save_for_backward(input, offset, weight)
50
-
51
- output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
52
-
53
- ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
54
-
55
- if not input.is_cuda:
56
- raise NotImplementedError
57
- else:
58
- cur_im2col_step = min(ctx.im2col_step, input.shape[0])
59
- assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
60
- deform_conv_ext.deform_conv_forward(input, weight,
61
- offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
62
- weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
63
- ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
64
- ctx.deformable_groups, cur_im2col_step)
65
- return output
66
-
67
- @staticmethod
68
- @once_differentiable
69
- def backward(ctx, grad_output):
70
- input, offset, weight = ctx.saved_tensors
71
-
72
- grad_input = grad_offset = grad_weight = None
73
-
74
- if not grad_output.is_cuda:
75
- raise NotImplementedError
76
- else:
77
- cur_im2col_step = min(ctx.im2col_step, input.shape[0])
78
- assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
79
-
80
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
81
- grad_input = torch.zeros_like(input)
82
- grad_offset = torch.zeros_like(offset)
83
- deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
84
- grad_offset, weight, ctx.bufs_[0], weight.size(3),
85
- weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
86
- ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
87
- ctx.deformable_groups, cur_im2col_step)
88
-
89
- if ctx.needs_input_grad[2]:
90
- grad_weight = torch.zeros_like(weight)
91
- deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
92
- ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
93
- weight.size(2), ctx.stride[1], ctx.stride[0],
94
- ctx.padding[1], ctx.padding[0], ctx.dilation[1],
95
- ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
96
- cur_im2col_step)
97
-
98
- return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
99
-
100
- @staticmethod
101
- def _output_size(input, weight, padding, dilation, stride):
102
- channels = weight.size(0)
103
- output_size = (input.size(0), channels)
104
- for d in range(input.dim() - 2):
105
- in_size = input.size(d + 2)
106
- pad = padding[d]
107
- kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
108
- stride_ = stride[d]
109
- output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
110
- if not all(map(lambda s: s > 0, output_size)):
111
- raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
112
- return output_size
113
-
114
-
115
- class ModulatedDeformConvFunction(Function):
116
-
117
- @staticmethod
118
- def forward(ctx,
119
- input,
120
- offset,
121
- mask,
122
- weight,
123
- bias=None,
124
- stride=1,
125
- padding=0,
126
- dilation=1,
127
- groups=1,
128
- deformable_groups=1):
129
- ctx.stride = stride
130
- ctx.padding = padding
131
- ctx.dilation = dilation
132
- ctx.groups = groups
133
- ctx.deformable_groups = deformable_groups
134
- ctx.with_bias = bias is not None
135
- if not ctx.with_bias:
136
- bias = input.new_empty(1) # fake tensor
137
- if not input.is_cuda:
138
- raise NotImplementedError
139
- if weight.requires_grad or mask.requires_grad or offset.requires_grad \
140
- or input.requires_grad:
141
- ctx.save_for_backward(input, offset, mask, weight, bias)
142
- output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
143
- ctx._bufs = [input.new_empty(0), input.new_empty(0)]
144
- deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
145
- ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
146
- ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
147
- ctx.groups, ctx.deformable_groups, ctx.with_bias)
148
- return output
149
-
150
- @staticmethod
151
- @once_differentiable
152
- def backward(ctx, grad_output):
153
- if not grad_output.is_cuda:
154
- raise NotImplementedError
155
- input, offset, mask, weight, bias = ctx.saved_tensors
156
- grad_input = torch.zeros_like(input)
157
- grad_offset = torch.zeros_like(offset)
158
- grad_mask = torch.zeros_like(mask)
159
- grad_weight = torch.zeros_like(weight)
160
- grad_bias = torch.zeros_like(bias)
161
- deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
162
- grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
163
- grad_output, weight.shape[2], weight.shape[3], ctx.stride,
164
- ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
165
- ctx.groups, ctx.deformable_groups, ctx.with_bias)
166
- if not ctx.with_bias:
167
- grad_bias = None
168
-
169
- return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
170
-
171
- @staticmethod
172
- def _infer_shape(ctx, input, weight):
173
- n = input.size(0)
174
- channels_out = weight.size(0)
175
- height, width = input.shape[2:4]
176
- kernel_h, kernel_w = weight.shape[2:4]
177
- height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
178
- width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
179
- return n, channels_out, height_out, width_out
180
-
181
-
182
- deform_conv = DeformConvFunction.apply
183
- modulated_deform_conv = ModulatedDeformConvFunction.apply
184
-
185
-
186
- class DeformConv(nn.Module):
187
-
188
- def __init__(self,
189
- in_channels,
190
- out_channels,
191
- kernel_size,
192
- stride=1,
193
- padding=0,
194
- dilation=1,
195
- groups=1,
196
- deformable_groups=1,
197
- bias=False):
198
- super(DeformConv, self).__init__()
199
-
200
- assert not bias
201
- assert in_channels % groups == 0, \
202
- f'in_channels {in_channels} is not divisible by groups {groups}'
203
- assert out_channels % groups == 0, \
204
- f'out_channels {out_channels} is not divisible ' \
205
- f'by groups {groups}'
206
-
207
- self.in_channels = in_channels
208
- self.out_channels = out_channels
209
- self.kernel_size = _pair(kernel_size)
210
- self.stride = _pair(stride)
211
- self.padding = _pair(padding)
212
- self.dilation = _pair(dilation)
213
- self.groups = groups
214
- self.deformable_groups = deformable_groups
215
- # enable compatibility with nn.Conv2d
216
- self.transposed = False
217
- self.output_padding = _single(0)
218
-
219
- self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
220
-
221
- self.reset_parameters()
222
-
223
- def reset_parameters(self):
224
- n = self.in_channels
225
- for k in self.kernel_size:
226
- n *= k
227
- stdv = 1. / math.sqrt(n)
228
- self.weight.data.uniform_(-stdv, stdv)
229
-
230
- def forward(self, x, offset):
231
- # To fix an assert error in deform_conv_cuda.cpp:128
232
- # input image is smaller than kernel
233
- input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
234
- if input_pad:
235
- pad_h = max(self.kernel_size[0] - x.size(2), 0)
236
- pad_w = max(self.kernel_size[1] - x.size(3), 0)
237
- x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
238
- offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
239
- out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
240
- self.deformable_groups)
241
- if input_pad:
242
- out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
243
- return out
244
-
245
-
246
- class DeformConvPack(DeformConv):
247
- """A Deformable Conv Encapsulation that acts as normal Conv layers.
248
-
249
- Args:
250
- in_channels (int): Same as nn.Conv2d.
251
- out_channels (int): Same as nn.Conv2d.
252
- kernel_size (int or tuple[int]): Same as nn.Conv2d.
253
- stride (int or tuple[int]): Same as nn.Conv2d.
254
- padding (int or tuple[int]): Same as nn.Conv2d.
255
- dilation (int or tuple[int]): Same as nn.Conv2d.
256
- groups (int): Same as nn.Conv2d.
257
- bias (bool or str): If specified as `auto`, it will be decided by the
258
- norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
259
- False.
260
- """
261
-
262
- _version = 2
263
-
264
- def __init__(self, *args, **kwargs):
265
- super(DeformConvPack, self).__init__(*args, **kwargs)
266
-
267
- self.conv_offset = nn.Conv2d(
268
- self.in_channels,
269
- self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
270
- kernel_size=self.kernel_size,
271
- stride=_pair(self.stride),
272
- padding=_pair(self.padding),
273
- dilation=_pair(self.dilation),
274
- bias=True)
275
- self.init_offset()
276
-
277
- def init_offset(self):
278
- self.conv_offset.weight.data.zero_()
279
- self.conv_offset.bias.data.zero_()
280
-
281
- def forward(self, x):
282
- offset = self.conv_offset(x)
283
- return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
284
- self.deformable_groups)
285
-
286
-
287
- class ModulatedDeformConv(nn.Module):
288
-
289
- def __init__(self,
290
- in_channels,
291
- out_channels,
292
- kernel_size,
293
- stride=1,
294
- padding=0,
295
- dilation=1,
296
- groups=1,
297
- deformable_groups=1,
298
- bias=True):
299
- super(ModulatedDeformConv, self).__init__()
300
- self.in_channels = in_channels
301
- self.out_channels = out_channels
302
- self.kernel_size = _pair(kernel_size)
303
- self.stride = stride
304
- self.padding = padding
305
- self.dilation = dilation
306
- self.groups = groups
307
- self.deformable_groups = deformable_groups
308
- self.with_bias = bias
309
- # enable compatibility with nn.Conv2d
310
- self.transposed = False
311
- self.output_padding = _single(0)
312
-
313
- self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
314
- if bias:
315
- self.bias = nn.Parameter(torch.Tensor(out_channels))
316
- else:
317
- self.register_parameter('bias', None)
318
- self.init_weights()
319
-
320
- def init_weights(self):
321
- n = self.in_channels
322
- for k in self.kernel_size:
323
- n *= k
324
- stdv = 1. / math.sqrt(n)
325
- self.weight.data.uniform_(-stdv, stdv)
326
- if self.bias is not None:
327
- self.bias.data.zero_()
328
-
329
- def forward(self, x, offset, mask):
330
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
331
- self.groups, self.deformable_groups)
332
-
333
-
334
- class ModulatedDeformConvPack(ModulatedDeformConv):
335
- """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
336
-
337
- Args:
338
- in_channels (int): Same as nn.Conv2d.
339
- out_channels (int): Same as nn.Conv2d.
340
- kernel_size (int or tuple[int]): Same as nn.Conv2d.
341
- stride (int or tuple[int]): Same as nn.Conv2d.
342
- padding (int or tuple[int]): Same as nn.Conv2d.
343
- dilation (int or tuple[int]): Same as nn.Conv2d.
344
- groups (int): Same as nn.Conv2d.
345
- bias (bool or str): If specified as `auto`, it will be decided by the
346
- norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
347
- False.
348
- """
349
-
350
- _version = 2
351
-
352
- def __init__(self, *args, **kwargs):
353
- super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
354
-
355
- self.conv_offset = nn.Conv2d(
356
- self.in_channels,
357
- self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
358
- kernel_size=self.kernel_size,
359
- stride=_pair(self.stride),
360
- padding=_pair(self.padding),
361
- dilation=_pair(self.dilation),
362
- bias=True)
363
- self.init_weights()
364
-
365
- def init_weights(self):
366
- super(ModulatedDeformConvPack, self).init_weights()
367
- if hasattr(self, 'conv_offset'):
368
- self.conv_offset.weight.data.zero_()
369
- self.conv_offset.bias.data.zero_()
370
-
371
- def forward(self, x):
372
- out = self.conv_offset(x)
373
- o1, o2, mask = torch.chunk(out, 3, dim=1)
374
- offset = torch.cat((o1, o2), dim=1)
375
- mask = torch.sigmoid(mask)
376
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
377
- self.groups, self.deformable_groups)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp DELETED
@@ -1,685 +0,0 @@
1
- // modify from
2
- // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3
-
4
- #include <torch/extension.h>
5
- #include <ATen/DeviceGuard.h>
6
-
7
- #include <cmath>
8
- #include <vector>
9
-
10
- void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
11
- const int channels, const int height, const int width,
12
- const int ksize_h, const int ksize_w, const int pad_h,
13
- const int pad_w, const int stride_h, const int stride_w,
14
- const int dilation_h, const int dilation_w,
15
- const int parallel_imgs, const int deformable_group,
16
- at::Tensor data_col);
17
-
18
- void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
19
- const int channels, const int height, const int width,
20
- const int ksize_h, const int ksize_w, const int pad_h,
21
- const int pad_w, const int stride_h, const int stride_w,
22
- const int dilation_h, const int dilation_w,
23
- const int parallel_imgs, const int deformable_group,
24
- at::Tensor grad_im);
25
-
26
- void deformable_col2im_coord(
27
- const at::Tensor data_col, const at::Tensor data_im,
28
- const at::Tensor data_offset, const int channels, const int height,
29
- const int width, const int ksize_h, const int ksize_w, const int pad_h,
30
- const int pad_w, const int stride_h, const int stride_w,
31
- const int dilation_h, const int dilation_w, const int parallel_imgs,
32
- const int deformable_group, at::Tensor grad_offset);
33
-
34
- void modulated_deformable_im2col_cuda(
35
- const at::Tensor data_im, const at::Tensor data_offset,
36
- const at::Tensor data_mask, const int batch_size, const int channels,
37
- const int height_im, const int width_im, const int height_col,
38
- const int width_col, const int kernel_h, const int kenerl_w,
39
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
40
- const int dilation_h, const int dilation_w, const int deformable_group,
41
- at::Tensor data_col);
42
-
43
- void modulated_deformable_col2im_cuda(
44
- const at::Tensor data_col, const at::Tensor data_offset,
45
- const at::Tensor data_mask, const int batch_size, const int channels,
46
- const int height_im, const int width_im, const int height_col,
47
- const int width_col, const int kernel_h, const int kenerl_w,
48
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
49
- const int dilation_h, const int dilation_w, const int deformable_group,
50
- at::Tensor grad_im);
51
-
52
- void modulated_deformable_col2im_coord_cuda(
53
- const at::Tensor data_col, const at::Tensor data_im,
54
- const at::Tensor data_offset, const at::Tensor data_mask,
55
- const int batch_size, const int channels, const int height_im,
56
- const int width_im, const int height_col, const int width_col,
57
- const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
58
- const int stride_h, const int stride_w, const int dilation_h,
59
- const int dilation_w, const int deformable_group, at::Tensor grad_offset,
60
- at::Tensor grad_mask);
61
-
62
- void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
63
- at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
64
- int padW, int dilationH, int dilationW, int group,
65
- int deformable_group) {
66
- TORCH_CHECK(weight.ndimension() == 4,
67
- "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
68
- "but got: %s",
69
- weight.ndimension());
70
-
71
- TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
72
-
73
- TORCH_CHECK(kW > 0 && kH > 0,
74
- "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
75
- kW);
76
-
77
- TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
78
- "kernel size should be consistent with weight, ",
79
- "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
80
- kW, weight.size(2), weight.size(3));
81
-
82
- TORCH_CHECK(dW > 0 && dH > 0,
83
- "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
84
-
85
- TORCH_CHECK(
86
- dilationW > 0 && dilationH > 0,
87
- "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
88
- dilationH, dilationW);
89
-
90
- int ndim = input.ndimension();
91
- int dimf = 0;
92
- int dimh = 1;
93
- int dimw = 2;
94
-
95
- if (ndim == 4) {
96
- dimf++;
97
- dimh++;
98
- dimw++;
99
- }
100
-
101
- TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
102
- ndim);
103
-
104
- long nInputPlane = weight.size(1) * group;
105
- long inputHeight = input.size(dimh);
106
- long inputWidth = input.size(dimw);
107
- long nOutputPlane = weight.size(0);
108
- long outputHeight =
109
- (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
110
- long outputWidth =
111
- (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
112
-
113
- TORCH_CHECK(nInputPlane % deformable_group == 0,
114
- "input channels must divide deformable group size");
115
-
116
- if (outputWidth < 1 || outputHeight < 1)
117
- AT_ERROR(
118
- "Given input size: (%ld x %ld x %ld). "
119
- "Calculated output size: (%ld x %ld x %ld). Output size is too small",
120
- nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
121
- outputWidth);
122
-
123
- TORCH_CHECK(input.size(1) == nInputPlane,
124
- "invalid number of input planes, expected: %d, but got: %d",
125
- nInputPlane, input.size(1));
126
-
127
- TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
128
- "input image is smaller than kernel");
129
-
130
- TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
131
- "invalid spatial size of offset, expected height: %d width: %d, but "
132
- "got height: %d width: %d",
133
- outputHeight, outputWidth, offset.size(2), offset.size(3));
134
-
135
- TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
136
- "invalid number of channels of offset");
137
-
138
- if (gradOutput != NULL) {
139
- TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
140
- "invalid number of gradOutput planes, expected: %d, but got: %d",
141
- nOutputPlane, gradOutput->size(dimf));
142
-
143
- TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
144
- gradOutput->size(dimw) == outputWidth),
145
- "invalid size of gradOutput, expected height: %d width: %d , but "
146
- "got height: %d width: %d",
147
- outputHeight, outputWidth, gradOutput->size(dimh),
148
- gradOutput->size(dimw));
149
- }
150
- }
151
-
152
- int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
153
- at::Tensor offset, at::Tensor output,
154
- at::Tensor columns, at::Tensor ones, int kW,
155
- int kH, int dW, int dH, int padW, int padH,
156
- int dilationW, int dilationH, int group,
157
- int deformable_group, int im2col_step) {
158
- // todo: resize columns to include im2col: done
159
- // todo: add im2col_step as input
160
- // todo: add new output buffer and transpose it to output (or directly
161
- // transpose output) todo: possibly change data indexing because of
162
- // parallel_imgs
163
-
164
- shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
165
- dilationH, dilationW, group, deformable_group);
166
- at::DeviceGuard guard(input.device());
167
-
168
- input = input.contiguous();
169
- offset = offset.contiguous();
170
- weight = weight.contiguous();
171
-
172
- int batch = 1;
173
- if (input.ndimension() == 3) {
174
- // Force batch
175
- batch = 0;
176
- input.unsqueeze_(0);
177
- offset.unsqueeze_(0);
178
- }
179
-
180
- // todo: assert batchsize dividable by im2col_step
181
-
182
- long batchSize = input.size(0);
183
- long nInputPlane = input.size(1);
184
- long inputHeight = input.size(2);
185
- long inputWidth = input.size(3);
186
-
187
- long nOutputPlane = weight.size(0);
188
-
189
- long outputWidth =
190
- (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
191
- long outputHeight =
192
- (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
193
-
194
- TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
195
-
196
- output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
197
- outputHeight, outputWidth});
198
- columns = at::zeros(
199
- {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
200
- input.options());
201
-
202
- if (ones.ndimension() != 2 ||
203
- ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
204
- ones = at::ones({outputHeight, outputWidth}, input.options());
205
- }
206
-
207
- input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
208
- inputHeight, inputWidth});
209
- offset =
210
- offset.view({batchSize / im2col_step, im2col_step,
211
- deformable_group * 2 * kH * kW, outputHeight, outputWidth});
212
-
213
- at::Tensor output_buffer =
214
- at::zeros({batchSize / im2col_step, nOutputPlane,
215
- im2col_step * outputHeight, outputWidth},
216
- output.options());
217
-
218
- output_buffer = output_buffer.view(
219
- {output_buffer.size(0), group, output_buffer.size(1) / group,
220
- output_buffer.size(2), output_buffer.size(3)});
221
-
222
- for (int elt = 0; elt < batchSize / im2col_step; elt++) {
223
- deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
224
- inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
225
- dilationW, im2col_step, deformable_group, columns);
226
-
227
- columns = columns.view({group, columns.size(0) / group, columns.size(1)});
228
- weight = weight.view({group, weight.size(0) / group, weight.size(1),
229
- weight.size(2), weight.size(3)});
230
-
231
- for (int g = 0; g < group; g++) {
232
- output_buffer[elt][g] = output_buffer[elt][g]
233
- .flatten(1)
234
- .addmm_(weight[g].flatten(1), columns[g])
235
- .view_as(output_buffer[elt][g]);
236
- }
237
- }
238
-
239
- output_buffer = output_buffer.view(
240
- {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
241
- output_buffer.size(3), output_buffer.size(4)});
242
-
243
- output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
244
- im2col_step, outputHeight, outputWidth});
245
- output_buffer.transpose_(1, 2);
246
- output.copy_(output_buffer);
247
- output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
248
-
249
- input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
250
- offset = offset.view(
251
- {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
252
-
253
- if (batch == 0) {
254
- output = output.view({nOutputPlane, outputHeight, outputWidth});
255
- input = input.view({nInputPlane, inputHeight, inputWidth});
256
- offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
257
- }
258
-
259
- return 1;
260
- }
261
-
262
- int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
263
- at::Tensor gradOutput, at::Tensor gradInput,
264
- at::Tensor gradOffset, at::Tensor weight,
265
- at::Tensor columns, int kW, int kH, int dW,
266
- int dH, int padW, int padH, int dilationW,
267
- int dilationH, int group,
268
- int deformable_group, int im2col_step) {
269
- shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
270
- dilationH, dilationW, group, deformable_group);
271
- at::DeviceGuard guard(input.device());
272
-
273
- input = input.contiguous();
274
- offset = offset.contiguous();
275
- gradOutput = gradOutput.contiguous();
276
- weight = weight.contiguous();
277
-
278
- int batch = 1;
279
-
280
- if (input.ndimension() == 3) {
281
- // Force batch
282
- batch = 0;
283
- input = input.view({1, input.size(0), input.size(1), input.size(2)});
284
- offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
285
- gradOutput = gradOutput.view(
286
- {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
287
- }
288
-
289
- long batchSize = input.size(0);
290
- long nInputPlane = input.size(1);
291
- long inputHeight = input.size(2);
292
- long inputWidth = input.size(3);
293
-
294
- long nOutputPlane = weight.size(0);
295
-
296
- long outputWidth =
297
- (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
298
- long outputHeight =
299
- (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
300
-
301
- TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
302
- gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
303
- columns = at::zeros(
304
- {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
305
- input.options());
306
-
307
- // change order of grad output
308
- gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
309
- nOutputPlane, outputHeight, outputWidth});
310
- gradOutput.transpose_(1, 2);
311
-
312
- gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
313
- inputHeight, inputWidth});
314
- input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
315
- inputHeight, inputWidth});
316
- gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
317
- deformable_group * 2 * kH * kW, outputHeight,
318
- outputWidth});
319
- offset =
320
- offset.view({batchSize / im2col_step, im2col_step,
321
- deformable_group * 2 * kH * kW, outputHeight, outputWidth});
322
-
323
- for (int elt = 0; elt < batchSize / im2col_step; elt++) {
324
- // divide into groups
325
- columns = columns.view({group, columns.size(0) / group, columns.size(1)});
326
- weight = weight.view({group, weight.size(0) / group, weight.size(1),
327
- weight.size(2), weight.size(3)});
328
- gradOutput = gradOutput.view(
329
- {gradOutput.size(0), group, gradOutput.size(1) / group,
330
- gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
331
-
332
- for (int g = 0; g < group; g++) {
333
- columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
334
- gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
335
- }
336
-
337
- columns =
338
- columns.view({columns.size(0) * columns.size(1), columns.size(2)});
339
- gradOutput = gradOutput.view(
340
- {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
341
- gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
342
-
343
- deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
344
- inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
345
- dilationH, dilationW, im2col_step, deformable_group,
346
- gradOffset[elt]);
347
-
348
- deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
349
- inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
350
- dilationW, im2col_step, deformable_group, gradInput[elt]);
351
- }
352
-
353
- gradOutput.transpose_(1, 2);
354
- gradOutput =
355
- gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
356
-
357
- gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
358
- input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
359
- gradOffset = gradOffset.view(
360
- {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
361
- offset = offset.view(
362
- {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
363
-
364
- if (batch == 0) {
365
- gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
366
- input = input.view({nInputPlane, inputHeight, inputWidth});
367
- gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
368
- offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
369
- gradOffset =
370
- gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
371
- }
372
-
373
- return 1;
374
- }
375
-
376
- int deform_conv_backward_parameters_cuda(
377
- at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
378
- at::Tensor gradWeight, // at::Tensor gradBias,
379
- at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
380
- int padW, int padH, int dilationW, int dilationH, int group,
381
- int deformable_group, float scale, int im2col_step) {
382
- // todo: transpose and reshape outGrad
383
- // todo: reshape columns
384
- // todo: add im2col_step as input
385
-
386
- shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
387
- padW, dilationH, dilationW, group, deformable_group);
388
- at::DeviceGuard guard(input.device());
389
-
390
- input = input.contiguous();
391
- offset = offset.contiguous();
392
- gradOutput = gradOutput.contiguous();
393
-
394
- int batch = 1;
395
-
396
- if (input.ndimension() == 3) {
397
- // Force batch
398
- batch = 0;
399
- input = input.view(
400
- at::IntList({1, input.size(0), input.size(1), input.size(2)}));
401
- gradOutput = gradOutput.view(
402
- {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
403
- }
404
-
405
- long batchSize = input.size(0);
406
- long nInputPlane = input.size(1);
407
- long inputHeight = input.size(2);
408
- long inputWidth = input.size(3);
409
-
410
- long nOutputPlane = gradWeight.size(0);
411
-
412
- long outputWidth =
413
- (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
414
- long outputHeight =
415
- (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
416
-
417
- TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
418
-
419
- columns = at::zeros(
420
- {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
421
- input.options());
422
-
423
- gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
424
- nOutputPlane, outputHeight, outputWidth});
425
- gradOutput.transpose_(1, 2);
426
-
427
- at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
428
- gradOutputBuffer =
429
- gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
430
- outputHeight, outputWidth});
431
- gradOutputBuffer.copy_(gradOutput);
432
- gradOutputBuffer =
433
- gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
434
- im2col_step * outputHeight, outputWidth});
435
-
436
- gradOutput.transpose_(1, 2);
437
- gradOutput =
438
- gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
439
-
440
- input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
441
- inputHeight, inputWidth});
442
- offset =
443
- offset.view({batchSize / im2col_step, im2col_step,
444
- deformable_group * 2 * kH * kW, outputHeight, outputWidth});
445
-
446
- for (int elt = 0; elt < batchSize / im2col_step; elt++) {
447
- deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
448
- inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
449
- dilationW, im2col_step, deformable_group, columns);
450
-
451
- // divide into group
452
- gradOutputBuffer = gradOutputBuffer.view(
453
- {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
454
- gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
455
- columns = columns.view({group, columns.size(0) / group, columns.size(1)});
456
- gradWeight =
457
- gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
458
- gradWeight.size(2), gradWeight.size(3)});
459
-
460
- for (int g = 0; g < group; g++) {
461
- gradWeight[g] = gradWeight[g]
462
- .flatten(1)
463
- .addmm_(gradOutputBuffer[elt][g].flatten(1),
464
- columns[g].transpose(1, 0), 1.0, scale)
465
- .view_as(gradWeight[g]);
466
- }
467
- gradOutputBuffer = gradOutputBuffer.view(
468
- {gradOutputBuffer.size(0),
469
- gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
470
- gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
471
- columns =
472
- columns.view({columns.size(0) * columns.size(1), columns.size(2)});
473
- gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
474
- gradWeight.size(2), gradWeight.size(3),
475
- gradWeight.size(4)});
476
- }
477
-
478
- input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
479
- offset = offset.view(
480
- {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
481
-
482
- if (batch == 0) {
483
- gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
484
- input = input.view({nInputPlane, inputHeight, inputWidth});
485
- }
486
-
487
- return 1;
488
- }
489
-
490
- void modulated_deform_conv_cuda_forward(
491
- at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
492
- at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
493
- int kernel_h, int kernel_w, const int stride_h, const int stride_w,
494
- const int pad_h, const int pad_w, const int dilation_h,
495
- const int dilation_w, const int group, const int deformable_group,
496
- const bool with_bias) {
497
- TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
498
- TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
499
- at::DeviceGuard guard(input.device());
500
-
501
- const int batch = input.size(0);
502
- const int channels = input.size(1);
503
- const int height = input.size(2);
504
- const int width = input.size(3);
505
-
506
- const int channels_out = weight.size(0);
507
- const int channels_kernel = weight.size(1);
508
- const int kernel_h_ = weight.size(2);
509
- const int kernel_w_ = weight.size(3);
510
-
511
- if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
512
- AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
513
- kernel_h_, kernel_w, kernel_h_, kernel_w_);
514
- if (channels != channels_kernel * group)
515
- AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
516
- channels, channels_kernel * group);
517
-
518
- const int height_out =
519
- (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
520
- const int width_out =
521
- (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
522
-
523
- if (ones.ndimension() != 2 ||
524
- ones.size(0) * ones.size(1) < height_out * width_out) {
525
- // Resize plane and fill with ones...
526
- ones = at::ones({height_out, width_out}, input.options());
527
- }
528
-
529
- // resize output
530
- output = output.view({batch, channels_out, height_out, width_out}).zero_();
531
- // resize temporary columns
532
- columns =
533
- at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
534
- input.options());
535
-
536
- output = output.view({output.size(0), group, output.size(1) / group,
537
- output.size(2), output.size(3)});
538
-
539
- for (int b = 0; b < batch; b++) {
540
- modulated_deformable_im2col_cuda(
541
- input[b], offset[b], mask[b], 1, channels, height, width, height_out,
542
- width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
543
- dilation_h, dilation_w, deformable_group, columns);
544
-
545
- // divide into group
546
- weight = weight.view({group, weight.size(0) / group, weight.size(1),
547
- weight.size(2), weight.size(3)});
548
- columns = columns.view({group, columns.size(0) / group, columns.size(1)});
549
-
550
- for (int g = 0; g < group; g++) {
551
- output[b][g] = output[b][g]
552
- .flatten(1)
553
- .addmm_(weight[g].flatten(1), columns[g])
554
- .view_as(output[b][g]);
555
- }
556
-
557
- weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
558
- weight.size(3), weight.size(4)});
559
- columns =
560
- columns.view({columns.size(0) * columns.size(1), columns.size(2)});
561
- }
562
-
563
- output = output.view({output.size(0), output.size(1) * output.size(2),
564
- output.size(3), output.size(4)});
565
-
566
- if (with_bias) {
567
- output += bias.view({1, bias.size(0), 1, 1});
568
- }
569
- }
570
-
571
- void modulated_deform_conv_cuda_backward(
572
- at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
573
- at::Tensor offset, at::Tensor mask, at::Tensor columns,
574
- at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
575
- at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
576
- int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
577
- int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
578
- const bool with_bias) {
579
- TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
580
- TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
581
- at::DeviceGuard guard(input.device());
582
-
583
- const int batch = input.size(0);
584
- const int channels = input.size(1);
585
- const int height = input.size(2);
586
- const int width = input.size(3);
587
-
588
- const int channels_kernel = weight.size(1);
589
- const int kernel_h_ = weight.size(2);
590
- const int kernel_w_ = weight.size(3);
591
- if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
592
- AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
593
- kernel_h_, kernel_w, kernel_h_, kernel_w_);
594
- if (channels != channels_kernel * group)
595
- AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
596
- channels, channels_kernel * group);
597
-
598
- const int height_out =
599
- (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
600
- const int width_out =
601
- (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
602
-
603
- if (ones.ndimension() != 2 ||
604
- ones.size(0) * ones.size(1) < height_out * width_out) {
605
- // Resize plane and fill with ones...
606
- ones = at::ones({height_out, width_out}, input.options());
607
- }
608
-
609
- grad_input = grad_input.view({batch, channels, height, width});
610
- columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
611
- input.options());
612
-
613
- grad_output =
614
- grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
615
- grad_output.size(2), grad_output.size(3)});
616
-
617
- for (int b = 0; b < batch; b++) {
618
- // divide int group
619
- columns = columns.view({group, columns.size(0) / group, columns.size(1)});
620
- weight = weight.view({group, weight.size(0) / group, weight.size(1),
621
- weight.size(2), weight.size(3)});
622
-
623
- for (int g = 0; g < group; g++) {
624
- columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
625
- grad_output[b][g].flatten(1), 0.0f, 1.0f);
626
- }
627
-
628
- columns =
629
- columns.view({columns.size(0) * columns.size(1), columns.size(2)});
630
- weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
631
- weight.size(3), weight.size(4)});
632
-
633
- // gradient w.r.t. input coordinate data
634
- modulated_deformable_col2im_coord_cuda(
635
- columns, input[b], offset[b], mask[b], 1, channels, height, width,
636
- height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
637
- stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
638
- grad_mask[b]);
639
- // gradient w.r.t. input data
640
- modulated_deformable_col2im_cuda(
641
- columns, offset[b], mask[b], 1, channels, height, width, height_out,
642
- width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
643
- dilation_h, dilation_w, deformable_group, grad_input[b]);
644
-
645
- // gradient w.r.t. weight, dWeight should accumulate across the batch and
646
- // group
647
- modulated_deformable_im2col_cuda(
648
- input[b], offset[b], mask[b], 1, channels, height, width, height_out,
649
- width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
650
- dilation_h, dilation_w, deformable_group, columns);
651
-
652
- columns = columns.view({group, columns.size(0) / group, columns.size(1)});
653
- grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
654
- grad_weight.size(1), grad_weight.size(2),
655
- grad_weight.size(3)});
656
- if (with_bias)
657
- grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
658
-
659
- for (int g = 0; g < group; g++) {
660
- grad_weight[g] =
661
- grad_weight[g]
662
- .flatten(1)
663
- .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
664
- .view_as(grad_weight[g]);
665
- if (with_bias) {
666
- grad_bias[g] =
667
- grad_bias[g]
668
- .view({-1, 1})
669
- .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
670
- .view(-1);
671
- }
672
- }
673
-
674
- columns =
675
- columns.view({columns.size(0) * columns.size(1), columns.size(2)});
676
- grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
677
- grad_weight.size(2), grad_weight.size(3),
678
- grad_weight.size(4)});
679
- if (with_bias)
680
- grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
681
- }
682
- grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
683
- grad_output.size(2), grad_output.size(3),
684
- grad_output.size(4)});
685
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu DELETED
@@ -1,867 +0,0 @@
1
- /*!
2
- ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3
- *
4
- * COPYRIGHT
5
- *
6
- * All contributions by the University of California:
7
- * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8
- * All rights reserved.
9
- *
10
- * All other contributions:
11
- * Copyright (c) 2014-2017, the respective contributors
12
- * All rights reserved.
13
- *
14
- * Caffe uses a shared copyright model: each contributor holds copyright over
15
- * their contributions to Caffe. The project versioning records all such
16
- * contribution and copyright details. If a contributor wants to further mark
17
- * their specific copyright on a particular contribution, they should indicate
18
- * their copyright solely in the commit message of the change when it is
19
- * committed.
20
- *
21
- * LICENSE
22
- *
23
- * Redistribution and use in source and binary forms, with or without
24
- * modification, are permitted provided that the following conditions are met:
25
- *
26
- * 1. Redistributions of source code must retain the above copyright notice, this
27
- * list of conditions and the following disclaimer.
28
- * 2. Redistributions in binary form must reproduce the above copyright notice,
29
- * this list of conditions and the following disclaimer in the documentation
30
- * and/or other materials provided with the distribution.
31
- *
32
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36
- * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42
- *
43
- * CONTRIBUTION AGREEMENT
44
- *
45
- * By contributing to the BVLC/caffe repository through pull-request, comment,
46
- * or otherwise, the contributor releases their content to the
47
- * license and copyright terms herein.
48
- *
49
- ***************** END Caffe Copyright Notice and Disclaimer ********************
50
- *
51
- * Copyright (c) 2018 Microsoft
52
- * Licensed under The MIT License [see LICENSE for details]
53
- * \file modulated_deformable_im2col.cuh
54
- * \brief Function definitions of converting an image to
55
- * column matrix based on kernel, padding, dilation, and offset.
56
- * These functions are mainly used in deformable convolution operators.
57
- * \ref: https://arxiv.org/abs/1703.06211
58
- * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
59
- */
60
-
61
- // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
62
-
63
- #include <ATen/ATen.h>
64
- #include <ATen/cuda/CUDAContext.h>
65
- #include <THC/THCAtomics.cuh>
66
- #include <stdio.h>
67
- #include <math.h>
68
- #include <float.h>
69
-
70
- using namespace at;
71
-
72
- #define CUDA_KERNEL_LOOP(i, n) \
73
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
74
- i += blockDim.x * gridDim.x)
75
-
76
- const int CUDA_NUM_THREADS = 1024;
77
- const int kMaxGridNum = 65535;
78
-
79
- inline int GET_BLOCKS(const int N)
80
- {
81
- return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
82
- }
83
-
84
- template <typename scalar_t>
85
- __device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
86
- const int height, const int width, scalar_t h, scalar_t w)
87
- {
88
-
89
- int h_low = floor(h);
90
- int w_low = floor(w);
91
- int h_high = h_low + 1;
92
- int w_high = w_low + 1;
93
-
94
- scalar_t lh = h - h_low;
95
- scalar_t lw = w - w_low;
96
- scalar_t hh = 1 - lh, hw = 1 - lw;
97
-
98
- scalar_t v1 = 0;
99
- if (h_low >= 0 && w_low >= 0)
100
- v1 = bottom_data[h_low * data_width + w_low];
101
- scalar_t v2 = 0;
102
- if (h_low >= 0 && w_high <= width - 1)
103
- v2 = bottom_data[h_low * data_width + w_high];
104
- scalar_t v3 = 0;
105
- if (h_high <= height - 1 && w_low >= 0)
106
- v3 = bottom_data[h_high * data_width + w_low];
107
- scalar_t v4 = 0;
108
- if (h_high <= height - 1 && w_high <= width - 1)
109
- v4 = bottom_data[h_high * data_width + w_high];
110
-
111
- scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
112
-
113
- scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
114
- return val;
115
- }
116
-
117
- template <typename scalar_t>
118
- __device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
119
- const int h, const int w, const int height, const int width)
120
- {
121
-
122
- if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
123
- {
124
- //empty
125
- return 0;
126
- }
127
-
128
- int argmax_h_low = floor(argmax_h);
129
- int argmax_w_low = floor(argmax_w);
130
- int argmax_h_high = argmax_h_low + 1;
131
- int argmax_w_high = argmax_w_low + 1;
132
-
133
- scalar_t weight = 0;
134
- if (h == argmax_h_low && w == argmax_w_low)
135
- weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
136
- if (h == argmax_h_low && w == argmax_w_high)
137
- weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
138
- if (h == argmax_h_high && w == argmax_w_low)
139
- weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
140
- if (h == argmax_h_high && w == argmax_w_high)
141
- weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
142
- return weight;
143
- }
144
-
145
- template <typename scalar_t>
146
- __device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
147
- const int height, const int width, const scalar_t *im_data,
148
- const int data_width, const int bp_dir)
149
- {
150
-
151
- if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
152
- {
153
- //empty
154
- return 0;
155
- }
156
-
157
- int argmax_h_low = floor(argmax_h);
158
- int argmax_w_low = floor(argmax_w);
159
- int argmax_h_high = argmax_h_low + 1;
160
- int argmax_w_high = argmax_w_low + 1;
161
-
162
- scalar_t weight = 0;
163
-
164
- if (bp_dir == 0)
165
- {
166
- if (argmax_h_low >= 0 && argmax_w_low >= 0)
167
- weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
168
- if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
169
- weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
170
- if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
171
- weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
172
- if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
173
- weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
174
- }
175
- else if (bp_dir == 1)
176
- {
177
- if (argmax_h_low >= 0 && argmax_w_low >= 0)
178
- weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
179
- if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
180
- weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
181
- if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
182
- weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
183
- if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
184
- weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
185
- }
186
-
187
- return weight;
188
- }
189
-
190
- template <typename scalar_t>
191
- __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
192
- const int height, const int width, const int kernel_h, const int kernel_w,
193
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
194
- const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
195
- const int batch_size, const int num_channels, const int deformable_group,
196
- const int height_col, const int width_col,
197
- scalar_t *data_col)
198
- {
199
- CUDA_KERNEL_LOOP(index, n)
200
- {
201
- // index index of output matrix
202
- const int w_col = index % width_col;
203
- const int h_col = (index / width_col) % height_col;
204
- const int b_col = (index / width_col / height_col) % batch_size;
205
- const int c_im = (index / width_col / height_col) / batch_size;
206
- const int c_col = c_im * kernel_h * kernel_w;
207
-
208
- // compute deformable group index
209
- const int deformable_group_index = c_im / channel_per_deformable_group;
210
-
211
- const int h_in = h_col * stride_h - pad_h;
212
- const int w_in = w_col * stride_w - pad_w;
213
- scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
214
- //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
215
- const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
216
- const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
217
-
218
- for (int i = 0; i < kernel_h; ++i)
219
- {
220
- for (int j = 0; j < kernel_w; ++j)
221
- {
222
- const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
223
- const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
224
- const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
225
- const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
226
- scalar_t val = static_cast<scalar_t>(0);
227
- const scalar_t h_im = h_in + i * dilation_h + offset_h;
228
- const scalar_t w_im = w_in + j * dilation_w + offset_w;
229
- if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
230
- {
231
- //const scalar_t map_h = i * dilation_h + offset_h;
232
- //const scalar_t map_w = j * dilation_w + offset_w;
233
- //const int cur_height = height - h_in;
234
- //const int cur_width = width - w_in;
235
- //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
236
- val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
237
- }
238
- *data_col_ptr = val;
239
- data_col_ptr += batch_size * height_col * width_col;
240
- }
241
- }
242
- }
243
- }
244
-
245
- void deformable_im2col(
246
- const at::Tensor data_im, const at::Tensor data_offset, const int channels,
247
- const int height, const int width, const int ksize_h, const int ksize_w,
248
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
249
- const int dilation_h, const int dilation_w, const int parallel_imgs,
250
- const int deformable_group, at::Tensor data_col)
251
- {
252
- // num_axes should be smaller than block size
253
- // todo: check parallel_imgs is correctly passed in
254
- int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
255
- int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
256
- int num_kernels = channels * height_col * width_col * parallel_imgs;
257
- int channel_per_deformable_group = channels / deformable_group;
258
-
259
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
260
- data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
261
- const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
262
- const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
263
- scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
264
-
265
- deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
266
- num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
267
- pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
268
- channel_per_deformable_group, parallel_imgs, channels, deformable_group,
269
- height_col, width_col, data_col_);
270
- }));
271
-
272
- cudaError_t err = cudaGetLastError();
273
- if (err != cudaSuccess)
274
- {
275
- printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
276
- }
277
- }
278
-
279
- template <typename scalar_t>
280
- __global__ void deformable_col2im_gpu_kernel(
281
- const int n, const scalar_t *data_col, const scalar_t *data_offset,
282
- const int channels, const int height, const int width,
283
- const int kernel_h, const int kernel_w,
284
- const int pad_h, const int pad_w,
285
- const int stride_h, const int stride_w,
286
- const int dilation_h, const int dilation_w,
287
- const int channel_per_deformable_group,
288
- const int batch_size, const int deformable_group,
289
- const int height_col, const int width_col,
290
- scalar_t *grad_im)
291
- {
292
- CUDA_KERNEL_LOOP(index, n)
293
- {
294
- const int j = (index / width_col / height_col / batch_size) % kernel_w;
295
- const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
296
- const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
297
- // compute the start and end of the output
298
-
299
- const int deformable_group_index = c / channel_per_deformable_group;
300
-
301
- int w_out = index % width_col;
302
- int h_out = (index / width_col) % height_col;
303
- int b = (index / width_col / height_col) % batch_size;
304
- int w_in = w_out * stride_w - pad_w;
305
- int h_in = h_out * stride_h - pad_h;
306
-
307
- const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
308
- 2 * kernel_h * kernel_w * height_col * width_col;
309
- const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
310
- const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
311
- const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
312
- const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
313
- const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
314
- const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
315
-
316
- const scalar_t cur_top_grad = data_col[index];
317
- const int cur_h = (int)cur_inv_h_data;
318
- const int cur_w = (int)cur_inv_w_data;
319
- for (int dy = -2; dy <= 2; dy++)
320
- {
321
- for (int dx = -2; dx <= 2; dx++)
322
- {
323
- if (cur_h + dy >= 0 && cur_h + dy < height &&
324
- cur_w + dx >= 0 && cur_w + dx < width &&
325
- abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
326
- abs(cur_inv_w_data - (cur_w + dx)) < 1)
327
- {
328
- int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
329
- scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
330
- atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
331
- }
332
- }
333
- }
334
- }
335
- }
336
-
337
- void deformable_col2im(
338
- const at::Tensor data_col, const at::Tensor data_offset, const int channels,
339
- const int height, const int width, const int ksize_h,
340
- const int ksize_w, const int pad_h, const int pad_w,
341
- const int stride_h, const int stride_w,
342
- const int dilation_h, const int dilation_w,
343
- const int parallel_imgs, const int deformable_group,
344
- at::Tensor grad_im)
345
- {
346
-
347
- // todo: make sure parallel_imgs is passed in correctly
348
- int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
349
- int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
350
- int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
351
- int channel_per_deformable_group = channels / deformable_group;
352
-
353
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
354
- data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
355
- const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
356
- const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
357
- scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
358
-
359
- deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
360
- num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
361
- ksize_w, pad_h, pad_w, stride_h, stride_w,
362
- dilation_h, dilation_w, channel_per_deformable_group,
363
- parallel_imgs, deformable_group, height_col, width_col, grad_im_);
364
- }));
365
-
366
- cudaError_t err = cudaGetLastError();
367
- if (err != cudaSuccess)
368
- {
369
- printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
370
- }
371
- }
372
-
373
- template <typename scalar_t>
374
- __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
375
- const scalar_t *data_im, const scalar_t *data_offset,
376
- const int channels, const int height, const int width,
377
- const int kernel_h, const int kernel_w,
378
- const int pad_h, const int pad_w,
379
- const int stride_h, const int stride_w,
380
- const int dilation_h, const int dilation_w,
381
- const int channel_per_deformable_group,
382
- const int batch_size, const int offset_channels, const int deformable_group,
383
- const int height_col, const int width_col, scalar_t *grad_offset)
384
- {
385
- CUDA_KERNEL_LOOP(index, n)
386
- {
387
- scalar_t val = 0;
388
- int w = index % width_col;
389
- int h = (index / width_col) % height_col;
390
- int c = (index / width_col / height_col) % offset_channels;
391
- int b = (index / width_col / height_col) / offset_channels;
392
- // compute the start and end of the output
393
-
394
- const int deformable_group_index = c / (2 * kernel_h * kernel_w);
395
- const int col_step = kernel_h * kernel_w;
396
- int cnt = 0;
397
- const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
398
- batch_size * width_col * height_col;
399
- const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
400
- channel_per_deformable_group / kernel_h / kernel_w * height * width;
401
- const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
402
- kernel_h * kernel_w * height_col * width_col;
403
-
404
- const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
405
-
406
- for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
407
- {
408
- const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
409
- const int bp_dir = offset_c % 2;
410
-
411
- int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
412
- int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
413
- int w_out = col_pos % width_col;
414
- int h_out = (col_pos / width_col) % height_col;
415
- int w_in = w_out * stride_w - pad_w;
416
- int h_in = h_out * stride_h - pad_h;
417
- const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
418
- const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
419
- const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
420
- const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
421
- scalar_t inv_h = h_in + i * dilation_h + offset_h;
422
- scalar_t inv_w = w_in + j * dilation_w + offset_w;
423
- if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
424
- {
425
- inv_h = inv_w = -2;
426
- }
427
- const scalar_t weight = get_coordinate_weight(
428
- inv_h, inv_w,
429
- height, width, data_im_ptr + cnt * height * width, width, bp_dir);
430
- val += weight * data_col_ptr[col_pos];
431
- cnt += 1;
432
- }
433
-
434
- grad_offset[index] = val;
435
- }
436
- }
437
-
438
- void deformable_col2im_coord(
439
- const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
440
- const int channels, const int height, const int width, const int ksize_h,
441
- const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
442
- const int stride_w, const int dilation_h, const int dilation_w,
443
- const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
444
- {
445
-
446
- int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
447
- int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
448
- int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
449
- int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
450
-
451
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
452
- data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
453
- const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
454
- const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
455
- const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
456
- scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
457
-
458
- deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
459
- num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
460
- ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
461
- dilation_h, dilation_w, channel_per_deformable_group,
462
- parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
463
- height_col, width_col, grad_offset_);
464
- }));
465
- }
466
-
467
- template <typename scalar_t>
468
- __device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
469
- const int height, const int width, scalar_t h, scalar_t w)
470
- {
471
- int h_low = floor(h);
472
- int w_low = floor(w);
473
- int h_high = h_low + 1;
474
- int w_high = w_low + 1;
475
-
476
- scalar_t lh = h - h_low;
477
- scalar_t lw = w - w_low;
478
- scalar_t hh = 1 - lh, hw = 1 - lw;
479
-
480
- scalar_t v1 = 0;
481
- if (h_low >= 0 && w_low >= 0)
482
- v1 = bottom_data[h_low * data_width + w_low];
483
- scalar_t v2 = 0;
484
- if (h_low >= 0 && w_high <= width - 1)
485
- v2 = bottom_data[h_low * data_width + w_high];
486
- scalar_t v3 = 0;
487
- if (h_high <= height - 1 && w_low >= 0)
488
- v3 = bottom_data[h_high * data_width + w_low];
489
- scalar_t v4 = 0;
490
- if (h_high <= height - 1 && w_high <= width - 1)
491
- v4 = bottom_data[h_high * data_width + w_high];
492
-
493
- scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
494
-
495
- scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
496
- return val;
497
- }
498
-
499
- template <typename scalar_t>
500
- __device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
501
- const int h, const int w, const int height, const int width)
502
- {
503
- if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
504
- {
505
- //empty
506
- return 0;
507
- }
508
-
509
- int argmax_h_low = floor(argmax_h);
510
- int argmax_w_low = floor(argmax_w);
511
- int argmax_h_high = argmax_h_low + 1;
512
- int argmax_w_high = argmax_w_low + 1;
513
-
514
- scalar_t weight = 0;
515
- if (h == argmax_h_low && w == argmax_w_low)
516
- weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
517
- if (h == argmax_h_low && w == argmax_w_high)
518
- weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
519
- if (h == argmax_h_high && w == argmax_w_low)
520
- weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
521
- if (h == argmax_h_high && w == argmax_w_high)
522
- weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
523
- return weight;
524
- }
525
-
526
- template <typename scalar_t>
527
- __device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
528
- const int height, const int width, const scalar_t *im_data,
529
- const int data_width, const int bp_dir)
530
- {
531
- if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
532
- {
533
- //empty
534
- return 0;
535
- }
536
-
537
- int argmax_h_low = floor(argmax_h);
538
- int argmax_w_low = floor(argmax_w);
539
- int argmax_h_high = argmax_h_low + 1;
540
- int argmax_w_high = argmax_w_low + 1;
541
-
542
- scalar_t weight = 0;
543
-
544
- if (bp_dir == 0)
545
- {
546
- if (argmax_h_low >= 0 && argmax_w_low >= 0)
547
- weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
548
- if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
549
- weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
550
- if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
551
- weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
552
- if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
553
- weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
554
- }
555
- else if (bp_dir == 1)
556
- {
557
- if (argmax_h_low >= 0 && argmax_w_low >= 0)
558
- weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
559
- if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
560
- weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
561
- if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
562
- weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
563
- if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
564
- weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
565
- }
566
-
567
- return weight;
568
- }
569
-
570
- template <typename scalar_t>
571
- __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
572
- const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
573
- const int height, const int width, const int kernel_h, const int kernel_w,
574
- const int pad_h, const int pad_w,
575
- const int stride_h, const int stride_w,
576
- const int dilation_h, const int dilation_w,
577
- const int channel_per_deformable_group,
578
- const int batch_size, const int num_channels, const int deformable_group,
579
- const int height_col, const int width_col,
580
- scalar_t *data_col)
581
- {
582
- CUDA_KERNEL_LOOP(index, n)
583
- {
584
- // index index of output matrix
585
- const int w_col = index % width_col;
586
- const int h_col = (index / width_col) % height_col;
587
- const int b_col = (index / width_col / height_col) % batch_size;
588
- const int c_im = (index / width_col / height_col) / batch_size;
589
- const int c_col = c_im * kernel_h * kernel_w;
590
-
591
- // compute deformable group index
592
- const int deformable_group_index = c_im / channel_per_deformable_group;
593
-
594
- const int h_in = h_col * stride_h - pad_h;
595
- const int w_in = w_col * stride_w - pad_w;
596
-
597
- scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
598
- //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
599
- const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
600
- const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
601
-
602
- const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
603
-
604
- for (int i = 0; i < kernel_h; ++i)
605
- {
606
- for (int j = 0; j < kernel_w; ++j)
607
- {
608
- const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
609
- const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
610
- const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
611
- const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
612
- const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
613
- const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
614
- scalar_t val = static_cast<scalar_t>(0);
615
- const scalar_t h_im = h_in + i * dilation_h + offset_h;
616
- const scalar_t w_im = w_in + j * dilation_w + offset_w;
617
- //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
618
- if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
619
- {
620
- //const float map_h = i * dilation_h + offset_h;
621
- //const float map_w = j * dilation_w + offset_w;
622
- //const int cur_height = height - h_in;
623
- //const int cur_width = width - w_in;
624
- //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
625
- val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
626
- }
627
- *data_col_ptr = val * mask;
628
- data_col_ptr += batch_size * height_col * width_col;
629
- //data_col_ptr += height_col * width_col;
630
- }
631
- }
632
- }
633
- }
634
-
635
- template <typename scalar_t>
636
- __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
637
- const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
638
- const int channels, const int height, const int width,
639
- const int kernel_h, const int kernel_w,
640
- const int pad_h, const int pad_w,
641
- const int stride_h, const int stride_w,
642
- const int dilation_h, const int dilation_w,
643
- const int channel_per_deformable_group,
644
- const int batch_size, const int deformable_group,
645
- const int height_col, const int width_col,
646
- scalar_t *grad_im)
647
- {
648
- CUDA_KERNEL_LOOP(index, n)
649
- {
650
- const int j = (index / width_col / height_col / batch_size) % kernel_w;
651
- const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
652
- const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
653
- // compute the start and end of the output
654
-
655
- const int deformable_group_index = c / channel_per_deformable_group;
656
-
657
- int w_out = index % width_col;
658
- int h_out = (index / width_col) % height_col;
659
- int b = (index / width_col / height_col) % batch_size;
660
- int w_in = w_out * stride_w - pad_w;
661
- int h_in = h_out * stride_h - pad_h;
662
-
663
- const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
664
- const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
665
- const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
666
- const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
667
- const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
668
- const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
669
- const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
670
- const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
671
- const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
672
- const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
673
-
674
- const scalar_t cur_top_grad = data_col[index] * mask;
675
- const int cur_h = (int)cur_inv_h_data;
676
- const int cur_w = (int)cur_inv_w_data;
677
- for (int dy = -2; dy <= 2; dy++)
678
- {
679
- for (int dx = -2; dx <= 2; dx++)
680
- {
681
- if (cur_h + dy >= 0 && cur_h + dy < height &&
682
- cur_w + dx >= 0 && cur_w + dx < width &&
683
- abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
684
- abs(cur_inv_w_data - (cur_w + dx)) < 1)
685
- {
686
- int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
687
- scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
688
- atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
689
- }
690
- }
691
- }
692
- }
693
- }
694
-
695
- template <typename scalar_t>
696
- __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
697
- const scalar_t *data_col, const scalar_t *data_im,
698
- const scalar_t *data_offset, const scalar_t *data_mask,
699
- const int channels, const int height, const int width,
700
- const int kernel_h, const int kernel_w,
701
- const int pad_h, const int pad_w,
702
- const int stride_h, const int stride_w,
703
- const int dilation_h, const int dilation_w,
704
- const int channel_per_deformable_group,
705
- const int batch_size, const int offset_channels, const int deformable_group,
706
- const int height_col, const int width_col,
707
- scalar_t *grad_offset, scalar_t *grad_mask)
708
- {
709
- CUDA_KERNEL_LOOP(index, n)
710
- {
711
- scalar_t val = 0, mval = 0;
712
- int w = index % width_col;
713
- int h = (index / width_col) % height_col;
714
- int c = (index / width_col / height_col) % offset_channels;
715
- int b = (index / width_col / height_col) / offset_channels;
716
- // compute the start and end of the output
717
-
718
- const int deformable_group_index = c / (2 * kernel_h * kernel_w);
719
- const int col_step = kernel_h * kernel_w;
720
- int cnt = 0;
721
- const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
722
- const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
723
- const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
724
- const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
725
-
726
- const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
727
-
728
- for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
729
- {
730
- const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
731
- const int bp_dir = offset_c % 2;
732
-
733
- int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
734
- int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
735
- int w_out = col_pos % width_col;
736
- int h_out = (col_pos / width_col) % height_col;
737
- int w_in = w_out * stride_w - pad_w;
738
- int h_in = h_out * stride_h - pad_h;
739
- const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
740
- const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
741
- const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
742
- const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
743
- const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
744
- const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
745
- scalar_t inv_h = h_in + i * dilation_h + offset_h;
746
- scalar_t inv_w = w_in + j * dilation_w + offset_w;
747
- if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
748
- {
749
- inv_h = inv_w = -2;
750
- }
751
- else
752
- {
753
- mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
754
- }
755
- const scalar_t weight = dmcn_get_coordinate_weight(
756
- inv_h, inv_w,
757
- height, width, data_im_ptr + cnt * height * width, width, bp_dir);
758
- val += weight * data_col_ptr[col_pos] * mask;
759
- cnt += 1;
760
- }
761
- // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
762
- grad_offset[index] = val;
763
- if (offset_c % 2 == 0)
764
- // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
765
- grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
766
- }
767
- }
768
-
769
- void modulated_deformable_im2col_cuda(
770
- const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
771
- const int batch_size, const int channels, const int height_im, const int width_im,
772
- const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
773
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
774
- const int dilation_h, const int dilation_w,
775
- const int deformable_group, at::Tensor data_col)
776
- {
777
- // num_axes should be smaller than block size
778
- const int channel_per_deformable_group = channels / deformable_group;
779
- const int num_kernels = channels * batch_size * height_col * width_col;
780
-
781
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
782
- data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
783
- const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
784
- const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
785
- const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
786
- scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
787
-
788
- modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
789
- num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
790
- pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
791
- batch_size, channels, deformable_group, height_col, width_col, data_col_);
792
- }));
793
-
794
- cudaError_t err = cudaGetLastError();
795
- if (err != cudaSuccess)
796
- {
797
- printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
798
- }
799
- }
800
-
801
- void modulated_deformable_col2im_cuda(
802
- const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
803
- const int batch_size, const int channels, const int height_im, const int width_im,
804
- const int height_col, const int width_col, const int kernel_h, const int kernel_w,
805
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
806
- const int dilation_h, const int dilation_w,
807
- const int deformable_group, at::Tensor grad_im)
808
- {
809
-
810
- const int channel_per_deformable_group = channels / deformable_group;
811
- const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
812
-
813
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
814
- data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
815
- const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
816
- const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
817
- const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
818
- scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
819
-
820
- modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
821
- num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
822
- kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
823
- dilation_h, dilation_w, channel_per_deformable_group,
824
- batch_size, deformable_group, height_col, width_col, grad_im_);
825
- }));
826
-
827
- cudaError_t err = cudaGetLastError();
828
- if (err != cudaSuccess)
829
- {
830
- printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
831
- }
832
- }
833
-
834
- void modulated_deformable_col2im_coord_cuda(
835
- const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
836
- const int batch_size, const int channels, const int height_im, const int width_im,
837
- const int height_col, const int width_col, const int kernel_h, const int kernel_w,
838
- const int pad_h, const int pad_w, const int stride_h, const int stride_w,
839
- const int dilation_h, const int dilation_w,
840
- const int deformable_group,
841
- at::Tensor grad_offset, at::Tensor grad_mask)
842
- {
843
- const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
844
- const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
845
-
846
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
847
- data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
848
- const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
849
- const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
850
- const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
851
- const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
852
- scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
853
- scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
854
-
855
- modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
856
- num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
857
- kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
858
- dilation_h, dilation_w, channel_per_deformable_group,
859
- batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
860
- grad_offset_, grad_mask_);
861
- }));
862
- cudaError_t err = cudaGetLastError();
863
- if (err != cudaSuccess)
864
- {
865
- printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
866
- }
867
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp DELETED
@@ -1,164 +0,0 @@
1
- // modify from
2
- // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3
-
4
- #include <torch/extension.h>
5
- #include <ATen/DeviceGuard.h>
6
-
7
- #include <cmath>
8
- #include <vector>
9
-
10
- #define WITH_CUDA // always use cuda
11
- #ifdef WITH_CUDA
12
- int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
13
- at::Tensor offset, at::Tensor output,
14
- at::Tensor columns, at::Tensor ones, int kW,
15
- int kH, int dW, int dH, int padW, int padH,
16
- int dilationW, int dilationH, int group,
17
- int deformable_group, int im2col_step);
18
-
19
- int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
20
- at::Tensor gradOutput, at::Tensor gradInput,
21
- at::Tensor gradOffset, at::Tensor weight,
22
- at::Tensor columns, int kW, int kH, int dW,
23
- int dH, int padW, int padH, int dilationW,
24
- int dilationH, int group,
25
- int deformable_group, int im2col_step);
26
-
27
- int deform_conv_backward_parameters_cuda(
28
- at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
29
- at::Tensor gradWeight, // at::Tensor gradBias,
30
- at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
31
- int padW, int padH, int dilationW, int dilationH, int group,
32
- int deformable_group, float scale, int im2col_step);
33
-
34
- void modulated_deform_conv_cuda_forward(
35
- at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
36
- at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
37
- int kernel_h, int kernel_w, const int stride_h, const int stride_w,
38
- const int pad_h, const int pad_w, const int dilation_h,
39
- const int dilation_w, const int group, const int deformable_group,
40
- const bool with_bias);
41
-
42
- void modulated_deform_conv_cuda_backward(
43
- at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
44
- at::Tensor offset, at::Tensor mask, at::Tensor columns,
45
- at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
46
- at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
47
- int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
48
- int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
49
- const bool with_bias);
50
- #endif
51
-
52
- int deform_conv_forward(at::Tensor input, at::Tensor weight,
53
- at::Tensor offset, at::Tensor output,
54
- at::Tensor columns, at::Tensor ones, int kW,
55
- int kH, int dW, int dH, int padW, int padH,
56
- int dilationW, int dilationH, int group,
57
- int deformable_group, int im2col_step) {
58
- if (input.device().is_cuda()) {
59
- #ifdef WITH_CUDA
60
- return deform_conv_forward_cuda(input, weight, offset, output, columns,
61
- ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
62
- deformable_group, im2col_step);
63
- #else
64
- AT_ERROR("deform conv is not compiled with GPU support");
65
- #endif
66
- }
67
- AT_ERROR("deform conv is not implemented on CPU");
68
- }
69
-
70
- int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
71
- at::Tensor gradOutput, at::Tensor gradInput,
72
- at::Tensor gradOffset, at::Tensor weight,
73
- at::Tensor columns, int kW, int kH, int dW,
74
- int dH, int padW, int padH, int dilationW,
75
- int dilationH, int group,
76
- int deformable_group, int im2col_step) {
77
- if (input.device().is_cuda()) {
78
- #ifdef WITH_CUDA
79
- return deform_conv_backward_input_cuda(input, offset, gradOutput,
80
- gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
81
- dilationW, dilationH, group, deformable_group, im2col_step);
82
- #else
83
- AT_ERROR("deform conv is not compiled with GPU support");
84
- #endif
85
- }
86
- AT_ERROR("deform conv is not implemented on CPU");
87
- }
88
-
89
- int deform_conv_backward_parameters(
90
- at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
91
- at::Tensor gradWeight, // at::Tensor gradBias,
92
- at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
93
- int padW, int padH, int dilationW, int dilationH, int group,
94
- int deformable_group, float scale, int im2col_step) {
95
- if (input.device().is_cuda()) {
96
- #ifdef WITH_CUDA
97
- return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
98
- gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
99
- dilationH, group, deformable_group, scale, im2col_step);
100
- #else
101
- AT_ERROR("deform conv is not compiled with GPU support");
102
- #endif
103
- }
104
- AT_ERROR("deform conv is not implemented on CPU");
105
- }
106
-
107
- void modulated_deform_conv_forward(
108
- at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
109
- at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
110
- int kernel_h, int kernel_w, const int stride_h, const int stride_w,
111
- const int pad_h, const int pad_w, const int dilation_h,
112
- const int dilation_w, const int group, const int deformable_group,
113
- const bool with_bias) {
114
- if (input.device().is_cuda()) {
115
- #ifdef WITH_CUDA
116
- return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
117
- offset, mask, output, columns, kernel_h, kernel_w, stride_h,
118
- stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
119
- deformable_group, with_bias);
120
- #else
121
- AT_ERROR("modulated deform conv is not compiled with GPU support");
122
- #endif
123
- }
124
- AT_ERROR("modulated deform conv is not implemented on CPU");
125
- }
126
-
127
- void modulated_deform_conv_backward(
128
- at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
129
- at::Tensor offset, at::Tensor mask, at::Tensor columns,
130
- at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
131
- at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
132
- int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
133
- int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
134
- const bool with_bias) {
135
- if (input.device().is_cuda()) {
136
- #ifdef WITH_CUDA
137
- return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
138
- offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
139
- grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
140
- pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
141
- with_bias);
142
- #else
143
- AT_ERROR("modulated deform conv is not compiled with GPU support");
144
- #endif
145
- }
146
- AT_ERROR("modulated deform conv is not implemented on CPU");
147
- }
148
-
149
-
150
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
151
- m.def("deform_conv_forward", &deform_conv_forward,
152
- "deform forward");
153
- m.def("deform_conv_backward_input", &deform_conv_backward_input,
154
- "deform_conv_backward_input");
155
- m.def("deform_conv_backward_parameters",
156
- &deform_conv_backward_parameters,
157
- "deform_conv_backward_parameters");
158
- m.def("modulated_deform_conv_forward",
159
- &modulated_deform_conv_forward,
160
- "modulated deform conv forward");
161
- m.def("modulated_deform_conv_backward",
162
- &modulated_deform_conv_backward,
163
- "modulated deform conv backward");
164
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/fused_act/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
-
3
- __all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
 
 
 
 
codeformer/basicsr/ops/fused_act/fused_act.py DELETED
@@ -1,89 +0,0 @@
1
- # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
2
-
3
- import torch
4
- from torch import nn
5
- from torch.autograd import Function
6
-
7
- try:
8
- from . import fused_act_ext
9
- except ImportError:
10
- import os
11
- BASICSR_JIT = os.getenv('BASICSR_JIT')
12
- if BASICSR_JIT == 'True':
13
- from torch.utils.cpp_extension import load
14
- module_path = os.path.dirname(__file__)
15
- fused_act_ext = load(
16
- 'fused',
17
- sources=[
18
- os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
19
- os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
20
- ],
21
- )
22
-
23
-
24
- class FusedLeakyReLUFunctionBackward(Function):
25
-
26
- @staticmethod
27
- def forward(ctx, grad_output, out, negative_slope, scale):
28
- ctx.save_for_backward(out)
29
- ctx.negative_slope = negative_slope
30
- ctx.scale = scale
31
-
32
- empty = grad_output.new_empty(0)
33
-
34
- grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
35
-
36
- dim = [0]
37
-
38
- if grad_input.ndim > 2:
39
- dim += list(range(2, grad_input.ndim))
40
-
41
- grad_bias = grad_input.sum(dim).detach()
42
-
43
- return grad_input, grad_bias
44
-
45
- @staticmethod
46
- def backward(ctx, gradgrad_input, gradgrad_bias):
47
- out, = ctx.saved_tensors
48
- gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
49
- ctx.scale)
50
-
51
- return gradgrad_out, None, None, None
52
-
53
-
54
- class FusedLeakyReLUFunction(Function):
55
-
56
- @staticmethod
57
- def forward(ctx, input, bias, negative_slope, scale):
58
- empty = input.new_empty(0)
59
- out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
60
- ctx.save_for_backward(out)
61
- ctx.negative_slope = negative_slope
62
- ctx.scale = scale
63
-
64
- return out
65
-
66
- @staticmethod
67
- def backward(ctx, grad_output):
68
- out, = ctx.saved_tensors
69
-
70
- grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
71
-
72
- return grad_input, grad_bias, None, None
73
-
74
-
75
- class FusedLeakyReLU(nn.Module):
76
-
77
- def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
78
- super().__init__()
79
-
80
- self.bias = nn.Parameter(torch.zeros(channel))
81
- self.negative_slope = negative_slope
82
- self.scale = scale
83
-
84
- def forward(self, input):
85
- return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
86
-
87
-
88
- def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
89
- return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp DELETED
@@ -1,26 +0,0 @@
1
- // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
2
- #include <torch/extension.h>
3
-
4
-
5
- torch::Tensor fused_bias_act_op(const torch::Tensor& input,
6
- const torch::Tensor& bias,
7
- const torch::Tensor& refer,
8
- int act, int grad, float alpha, float scale);
9
-
10
- #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12
- #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13
-
14
- torch::Tensor fused_bias_act(const torch::Tensor& input,
15
- const torch::Tensor& bias,
16
- const torch::Tensor& refer,
17
- int act, int grad, float alpha, float scale) {
18
- CHECK_CUDA(input);
19
- CHECK_CUDA(bias);
20
-
21
- return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
22
- }
23
-
24
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
25
- m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
26
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu DELETED
@@ -1,100 +0,0 @@
1
- // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
2
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
3
- //
4
- // This work is made available under the Nvidia Source Code License-NC.
5
- // To view a copy of this license, visit
6
- // https://nvlabs.github.io/stylegan2/license.html
7
-
8
- #include <torch/types.h>
9
-
10
- #include <ATen/ATen.h>
11
- #include <ATen/AccumulateType.h>
12
- #include <ATen/cuda/CUDAContext.h>
13
- #include <ATen/cuda/CUDAApplyUtils.cuh>
14
-
15
- #include <cuda.h>
16
- #include <cuda_runtime.h>
17
-
18
-
19
- template <typename scalar_t>
20
- static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
21
- int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
22
- int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
23
-
24
- scalar_t zero = 0.0;
25
-
26
- for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
27
- scalar_t x = p_x[xi];
28
-
29
- if (use_bias) {
30
- x += p_b[(xi / step_b) % size_b];
31
- }
32
-
33
- scalar_t ref = use_ref ? p_ref[xi] : zero;
34
-
35
- scalar_t y;
36
-
37
- switch (act * 10 + grad) {
38
- default:
39
- case 10: y = x; break;
40
- case 11: y = x; break;
41
- case 12: y = 0.0; break;
42
-
43
- case 30: y = (x > 0.0) ? x : x * alpha; break;
44
- case 31: y = (ref > 0.0) ? x : x * alpha; break;
45
- case 32: y = 0.0; break;
46
- }
47
-
48
- out[xi] = y * scale;
49
- }
50
- }
51
-
52
-
53
- torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
54
- int act, int grad, float alpha, float scale) {
55
- int curDevice = -1;
56
- cudaGetDevice(&curDevice);
57
- cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
58
-
59
- auto x = input.contiguous();
60
- auto b = bias.contiguous();
61
- auto ref = refer.contiguous();
62
-
63
- int use_bias = b.numel() ? 1 : 0;
64
- int use_ref = ref.numel() ? 1 : 0;
65
-
66
- int size_x = x.numel();
67
- int size_b = b.numel();
68
- int step_b = 1;
69
-
70
- for (int i = 1 + 1; i < x.dim(); i++) {
71
- step_b *= x.size(i);
72
- }
73
-
74
- int loop_x = 4;
75
- int block_size = 4 * 32;
76
- int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
77
-
78
- auto y = torch::empty_like(x);
79
-
80
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
81
- fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
82
- y.data_ptr<scalar_t>(),
83
- x.data_ptr<scalar_t>(),
84
- b.data_ptr<scalar_t>(),
85
- ref.data_ptr<scalar_t>(),
86
- act,
87
- grad,
88
- alpha,
89
- scale,
90
- loop_x,
91
- size_x,
92
- step_b,
93
- size_b,
94
- use_bias,
95
- use_ref
96
- );
97
- });
98
-
99
- return y;
100
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codeformer/basicsr/ops/upfirdn2d/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .upfirdn2d import upfirdn2d
2
-
3
- __all__ = ['upfirdn2d']