Upload 89 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +9 -0
- RDNet-main/RDNet-main/.gitignore +1 -0
- RDNet-main/RDNet-main/README.md +99 -0
- RDNet-main/RDNet-main/VOC2012_224_train_png.txt +0 -0
- RDNet-main/RDNet-main/data/VOC2012_224_train_png.txt +0 -0
- RDNet-main/RDNet-main/data/__pycache__/dataset_sir.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/data/__pycache__/image_folder.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/data/__pycache__/torchdata.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/data/__pycache__/transforms.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/data/dataset_sir.py +332 -0
- RDNet-main/RDNet-main/data/image_folder.py +51 -0
- RDNet-main/RDNet-main/data/real_test.txt +20 -0
- RDNet-main/RDNet-main/data/torchdata.py +67 -0
- RDNet-main/RDNet-main/data/transforms.py +301 -0
- RDNet-main/RDNet-main/engine.py +178 -0
- RDNet-main/RDNet-main/figures/Input_car.jpg +0 -0
- RDNet-main/RDNet-main/figures/Input_class.png +3 -0
- RDNet-main/RDNet-main/figures/Input_green.png +3 -0
- RDNet-main/RDNet-main/figures/Ours_car.png +3 -0
- RDNet-main/RDNet-main/figures/Ours_class.png +3 -0
- RDNet-main/RDNet-main/figures/Ours_green.png +3 -0
- RDNet-main/RDNet-main/figures/Ours_white.png +3 -0
- RDNet-main/RDNet-main/figures/Title.png +0 -0
- RDNet-main/RDNet-main/figures/input_white.jpg +0 -0
- RDNet-main/RDNet-main/figures/net.png +3 -0
- RDNet-main/RDNet-main/figures/result.png +3 -0
- RDNet-main/RDNet-main/figures/vis.png +3 -0
- RDNet-main/RDNet-main/models/__init__.py +11 -0
- RDNet-main/RDNet-main/models/__pycache__/__init__.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/__pycache__/base_model.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/__pycache__/cls_model_eval_nocls_reg.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/__pycache__/losses.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/__pycache__/networks.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/__pycache__/vgg.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/__pycache__/vit_feature_extractor.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/arch/NAFNET.py +480 -0
- RDNet-main/RDNet-main/models/arch/RDnet_.py +202 -0
- RDNet-main/RDNet-main/models/arch/__pycache__/RDnet_.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/arch/__pycache__/classifier.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/arch/__pycache__/focalnet.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/arch/__pycache__/modules_sig.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/arch/__pycache__/reverse_function.cpython-38.pyc +0 -0
- RDNet-main/RDNet-main/models/arch/classifier.py +49 -0
- RDNet-main/RDNet-main/models/arch/decode.py +36 -0
- RDNet-main/RDNet-main/models/arch/focalnet.py +589 -0
- RDNet-main/RDNet-main/models/arch/modules_sig.py +304 -0
- RDNet-main/RDNet-main/models/arch/reverse_function.py +153 -0
- RDNet-main/RDNet-main/models/arch/vgg.py +90 -0
- RDNet-main/RDNet-main/models/base_model.py +71 -0
- RDNet-main/RDNet-main/models/cls_model_eval_nocls_reg.py +517 -0
.gitattributes
CHANGED

```
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/Input_class.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/Input_green.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/net.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/Ours_car.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/Ours_class.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/Ours_green.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/Ours_white.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/result.png filter=lfs diff=lfs merge=lfs -text
+RDNet-main/RDNet-main/figures/vis.png filter=lfs diff=lfs merge=lfs -text
```
RDNet-main/RDNet-main/.gitignore
ADDED

```
.DS_Store
```
RDNet-main/RDNet-main/README.md
ADDED

<p align="center">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Title.png?raw=true" width=95%>
</p>

# Reversible Decoupling Network for Single Image Reflection Removal

<div align="center">

[](https://paperswithcode.com/sota/reflection-removal-on-sir-2-objects?p=reversible-decoupling-network-for-single)
[](https://paperswithcode.com/sota/reflection-removal-on-sir-2-wild?p=reversible-decoupling-network-for-single)
[](https://paperswithcode.com/sota/reflection-removal-on-sir-2-postcard?p=reversible-decoupling-network-for-single)
[](https://paperswithcode.com/sota/reflection-removal-on-nature?p=reversible-decoupling-network-for-single)
[](https://paperswithcode.com/sota/reflection-removal-on-real20?p=reversible-decoupling-network-for-single)

</div>
<p align="center" style="font-size: larger;">
<a href="https://arxiv.org/abs/2410.08063">Reversible Decoupling Network for Single Image Reflection Removal</a>
</p>
<p align="center">
<a href="https://github.com/WHiTEWoLFJ">Hao Zhao</a> ⚔️,
<a href="https://github.com/lime-j">Mingjia Li</a> ⚔️,
<a href="https://github.com/mingcv">Qiming Hu</a>,
<a href="https://sites.google.com/view/xjguo">Xiaojie Guo</a> 🦅
</p>
<p align="center">(⚔️: equal contribution, 🦅: corresponding author)</p>

<p align="center">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/net.png?raw=true" width=95%>
</p>

**Our work has been accepted by CVPR 2025! See you at the conference!**

<details>
<summary>Click for the Abstract of RDNet</summary>
We present a Reversible Decoupling Network (RDNet), which employs a reversible encoder to secure valuable information while flexibly decoupling transmission- and reflection-relevant features during the forward pass. Furthermore, we customize a transmission-rate-aware prompt generator to dynamically calibrate features, further boosting performance. Extensive experiments demonstrate the superiority of RDNet over existing SOTA methods on five widely adopted benchmark datasets.
</details>

## 🚀 Todo

- [ ] Release the training code of RDNet.

## 🌠 Gallery

<table class="center">
<tr>
<td><p style="text-align: center">Class Room</p></td>
<td><p style="text-align: center">White Wall Chamber</p></td>
</tr>
<tr>
<td>
<div style="width: 100%; max-width: 600px; position: relative;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Input_class.png?raw=true" style="width: 100%; height: 300px; display: block;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_class.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
</div>
</td>
<td>
<div style="width: 100%; max-width: 600px; position: relative;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/input_white.jpg?raw=true" style="width: 100%; height: 300px; display: block;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_white.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
</div>
</td>
</tr>
<tr>
<td><p style="text-align: center">Car Window</p></td>
<td><p style="text-align: center">Very Green Office</p></td>
</tr>
<tr>
<td>
<div style="width: 100%; max-width: 600px; position: relative;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Input_car.jpg?raw=true" style="width: 100%; height: 300px; display: block;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_car.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
</div>
</td>
<td>
<div style="width: 100%; max-width: 600px; position: relative;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Input_green.png?raw=true" style="width: 100%; height: 300px; display: block;">
<img src="https://github.com/lime-j/RDNet/blob/main/figures/Ours_green.png?raw=true" style="width: 100%; height: 300px; display: block; position: absolute; top: 0; left: 0; opacity: 0; transition: opacity 0.5s;" onmouseover="this.style.opacity=1;" onmouseout="this.style.opacity=0;">
</div>
</td>
</tr>
</table>

## Requirements

We recommend torch 2.x for our code, but it should work fine with most modern versions.

```
pip install "torch>=2.0" torchvision
pip install einops ema-pytorch fsspec fvcore huggingface-hub matplotlib numpy opencv-python omegaconf pytorch-msssim scikit-image scikit-learn scipy tensorboard tensorboardx wandb timm
```

# Testing

The checkpoint for the main network is available at https://checkpoints.mingjia.li/rdnet.pth, and the checkpoint for cls_model is at https://checkpoints.mingjia.li/cls_model.pth. Please put cls_model.pth under the "pretrained" folder.

```
python3 test_sirs.py --icnn_path <path to the main checkpoint> --resume
```
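For convenience, here is a minimal Python sketch for fetching both checkpoints (this assumes the two URLs above remain live; `urllib` is from the standard library):

```python
# Hypothetical helper: download the two released checkpoints.
import os
import urllib.request

os.makedirs('pretrained', exist_ok=True)
urllib.request.urlretrieve('https://checkpoints.mingjia.li/rdnet.pth', 'rdnet.pth')
urllib.request.urlretrieve('https://checkpoints.mingjia.li/cls_model.pth',
                           os.path.join('pretrained', 'cls_model.pth'))
```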
# Training

The training script will be released in a few days.
RDNet-main/RDNet-main/VOC2012_224_train_png.txt
ADDED
The diff for this file is too large to render. See the raw diff.

RDNet-main/RDNet-main/data/VOC2012_224_train_png.txt
ADDED
The diff for this file is too large to render. See the raw diff.
RDNet-main/RDNet-main/data/__pycache__/dataset_sir.cpython-38.pyc
ADDED
Binary file (10.9 kB)

RDNet-main/RDNet-main/data/__pycache__/image_folder.cpython-38.pyc
ADDED
Binary file (1.58 kB)

RDNet-main/RDNet-main/data/__pycache__/torchdata.cpython-38.pyc
ADDED
Binary file (2.86 kB)

RDNet-main/RDNet-main/data/__pycache__/transforms.cpython-38.pyc
ADDED
Binary file (9.37 kB)
RDNet-main/RDNet-main/data/dataset_sir.py
ADDED

```python
import math
import os.path
import random
from os.path import join

import cv2
import numpy as np
import torch.utils.data
import torchvision.transforms.functional as TF
from PIL import Image
from scipy.signal import convolve2d

from data.image_folder import make_dataset
from data.torchdata import Dataset as BaseDataset
from data.transforms import to_tensor


def __scale_width(img, target_width):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    h = math.ceil(h / 2.) * 2  # round up to even
    return img.resize((w, h), Image.BICUBIC)


def __scale_height(img, target_height):
    ow, oh = img.size
    if oh == target_height:
        return img
    h = target_height
    w = int(target_height * ow / oh)
    w = math.ceil(w / 2.) * 2  # round up to even
    return img.resize((w, h), Image.BICUBIC)


def paired_data_transforms(img_1, img_2, unaligned_transforms=False):
    def get_params(img, output_size):
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    # Resize so the shorter side is an even number in [320, 640], then apply
    # shared flips/rotations and a 320x320 crop to both images.
    target_size = int(random.randint(320, 640) / 2.) * 2
    ow, oh = img_1.size
    if ow >= oh:
        img_1 = __scale_height(img_1, target_size)
        img_2 = __scale_height(img_2, target_size)
    else:
        img_1 = __scale_width(img_1, target_size)
        img_2 = __scale_width(img_2, target_size)

    if random.random() < 0.5:
        img_1 = TF.hflip(img_1)
        img_2 = TF.hflip(img_2)

    if random.random() < 0.5:
        angle = random.choice([90, 180, 270])
        img_1 = TF.rotate(img_1, angle)
        img_2 = TF.rotate(img_2, angle)

    i, j, h, w = get_params(img_1, (320, 320))
    img_1 = TF.crop(img_1, i, j, h, w)

    if unaligned_transforms:
        # random shift of the second crop to simulate misalignment
        i_shift = random.randint(-10, 10)
        j_shift = random.randint(-10, 10)
        i += i_shift
        j += j_shift

    img_2 = TF.crop(img_2, i, j, h, w)

    return img_1, img_2


class ReflectionSynthesis(object):
    def __init__(self):
        # Kernel sizes and sampling probabilities of the Gaussian blur
        self.kernel_sizes = [5, 7, 9, 11]
        self.kernel_probs = [0.1, 0.2, 0.3, 0.4]

        # Sigma of the Gaussian blur
        self.sigma_range = [2, 5]
        self.alpha_range = [0.8, 1.0]
        self.beta_range = [0.4, 1.0]

    def __call__(self, T_, R_):
        T_ = np.asarray(T_, np.float32) / 255.
        R_ = np.asarray(R_, np.float32) / 255.

        kernel_size = np.random.choice(self.kernel_sizes, p=self.kernel_probs)
        sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1])
        kernel = cv2.getGaussianKernel(kernel_size, sigma)
        kernel2d = np.dot(kernel, kernel.T)
        for i in range(3):
            R_[..., i] = convolve2d(R_[..., i], kernel2d, mode='same')

        a = np.random.uniform(self.alpha_range[0], self.alpha_range[1])
        b = np.random.uniform(self.beta_range[0], self.beta_range[1])
        T, R = a * T_, b * R_

        if random.random() < 0.7:
            I = T + R - T * R
        else:
            I = T + R
            if np.max(I) > 1:
                m = I[I > 1]
                m = (np.mean(m) - 1) * 1.3
                I = np.clip(T + np.clip(R - m, 0, 1), 0, 1)

        return T_, R_, I


class DataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, batch_size, shuffle, *args, **kwargs):
        super(DataLoader, self).__init__(dataset, batch_size, shuffle, *args, **kwargs)
        self.shuffle = shuffle

    def reset(self):
        if self.shuffle:
            print('Reset Dataset...')
            self.dataset.reset()


class DSRDataset(BaseDataset):
    def __init__(self, datadir, fns=None, size=None, enable_transforms=True):
        super(DSRDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.enable_transforms = enable_transforms
        sortkey = lambda key: os.path.split(key)[-1]
        self.paths = sorted(make_dataset(datadir, fns), key=sortkey)
        if size is not None:
            self.paths = np.random.choice(self.paths, size)

        self.syn_model = ReflectionSynthesis()
        self.reset(shuffle=False)

    def reset(self, shuffle=True):
        if shuffle:
            random.shuffle(self.paths)
        num_paths = len(self.paths) // 2
        self.B_paths = self.paths[0:num_paths]
        self.R_paths = self.paths[num_paths:2 * num_paths]

    def data_synthesis(self, t_img, r_img):
        if self.enable_transforms:
            t_img, r_img = paired_data_transforms(t_img, r_img)

        t_img, r_img, m_img = self.syn_model(t_img, r_img)

        B = TF.to_tensor(t_img)
        R = TF.to_tensor(r_img)
        M = TF.to_tensor(m_img)

        return B, R, M

    def __getitem__(self, index):
        index_B = index % len(self.B_paths)
        index_R = index % len(self.R_paths)

        B_path = self.B_paths[index_B]
        R_path = self.R_paths[index_R]

        t_img = Image.open(B_path).convert('RGB')
        r_img = Image.open(R_path).convert('RGB')

        B, R, M = self.data_synthesis(t_img, r_img)
        fn = os.path.basename(B_path)
        return {'input': M, 'target_t': B, 'target_r': M - B, 'fn': fn, 'real': False}

    def __len__(self):
        if self.size is not None:
            return min(max(len(self.B_paths), len(self.R_paths)), self.size)
        else:
            return max(len(self.B_paths), len(self.R_paths))


class DSRTestDataset(BaseDataset):
    def __init__(self, datadir, fns=None, size=None, enable_transforms=False, unaligned_transforms=False,
                 round_factor=1, flag=None, if_align=True):
        super(DSRTestDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.fns = fns or os.listdir(join(datadir, 'blended'))
        self.enable_transforms = enable_transforms
        self.unaligned_transforms = unaligned_transforms
        self.round_factor = round_factor
        self.flag = flag
        self.if_align = True  # note: the if_align argument is overridden here

        if size is not None:
            self.fns = self.fns[:size]

    def align(self, x1, x2):
        h, w = x1.height, x1.width
        h, w = h // 32 * 32, w // 32 * 32  # round down to a multiple of 32
        x1 = x1.resize((w, h))
        x2 = x2.resize((w, h))
        return x1, x2

    def __getitem__(self, index):
        fn = self.fns[index]

        t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
        m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')

        if self.if_align:
            t_img, m_img = self.align(t_img, m_img)

        if self.enable_transforms:
            t_img, m_img = paired_data_transforms(t_img, m_img, self.unaligned_transforms)

        B = TF.to_tensor(t_img)
        M = TF.to_tensor(m_img)

        dic = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': M - B}
        if self.flag is not None:
            dic.update(self.flag)
        return dic

    def __len__(self):
        if self.size is not None:
            return min(len(self.fns), self.size)
        else:
            return len(self.fns)


class SIRTestDataset(BaseDataset):
    def __init__(self, datadir, fns=None, size=None, if_align=True):
        super(SIRTestDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.fns = fns or os.listdir(join(datadir, 'blended'))
        self.if_align = if_align

        if size is not None:
            self.fns = self.fns[:size]

    def align(self, x1, x2, x3):
        h, w = x1.height, x1.width
        h, w = h // 32 * 32, w // 32 * 32
        x1 = x1.resize((w, h))
        x2 = x2.resize((w, h))
        x3 = x3.resize((w, h))
        return x1, x2, x3

    def __getitem__(self, index):
        fn = self.fns[index]

        t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB')
        r_img = Image.open(join(self.datadir, 'reflection_layer', fn)).convert('RGB')
        m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB')

        if self.if_align:
            t_img, r_img, m_img = self.align(t_img, r_img, m_img)

        B = TF.to_tensor(t_img)
        R = TF.to_tensor(r_img)
        M = TF.to_tensor(m_img)

        dic = {'input': M, 'target_t': B, 'fn': fn, 'real': True, 'target_r': R, 'target_r_hat': M - B}
        return dic

    def __len__(self):
        if self.size is not None:
            return min(len(self.fns), self.size)
        else:
            return len(self.fns)


class RealDataset(BaseDataset):
    def __init__(self, datadir, fns=None, size=None):
        super(RealDataset, self).__init__()
        self.size = size
        self.datadir = datadir
        self.fns = fns or os.listdir(join(datadir))

        if size is not None:
            self.fns = self.fns[:size]

    def align(self, x):
        h, w = x.height, x.width
        h, w = h // 32 * 32, w // 32 * 32
        x = x.resize((w, h))
        return x

    def __getitem__(self, index):
        fn = self.fns[index]
        B = -1  # no ground truth is available for real inputs
        m_img = Image.open(join(self.datadir, fn)).convert('RGB')
        M = to_tensor(self.align(m_img))
        data = {'input': M, 'target_t': B, 'fn': fn}
        return data

    def __len__(self):
        if self.size is not None:
            return min(len(self.fns), self.size)
        else:
            return len(self.fns)


class FusionDataset(BaseDataset):
    def __init__(self, datasets, fusion_ratios=None):
        self.datasets = datasets
        self.size = sum([len(dataset) for dataset in datasets])
        self.fusion_ratios = fusion_ratios or [1. / len(datasets)] * len(datasets)
        print('[i] using a fusion dataset: %d %s imgs fused with ratio %s' % (
            self.size, [len(dataset) for dataset in datasets], self.fusion_ratios))

    def reset(self):
        for dataset in self.datasets:
            dataset.reset()

    def __getitem__(self, index):
        # Sample one of the sub-datasets according to fusion_ratios,
        # falling back to the last dataset if no earlier one is chosen.
        residual = 1
        for i, ratio in enumerate(self.fusion_ratios):
            if random.random() < ratio / residual or i == len(self.fusion_ratios) - 1:
                dataset = self.datasets[i]
                return dataset[index % len(dataset)]
            residual -= ratio

    def __len__(self):
        return self.size
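For orientation, a minimal sketch of how the pieces in the file above compose one synthetic training triple, in the same order as `DSRDataset.data_synthesis` (the two image paths are placeholders, not files shipped with the repo):

```python
# Hypothetical usage of the synthesis pipeline defined in data/dataset_sir.py.
from PIL import Image
from data.dataset_sir import ReflectionSynthesis, paired_data_transforms

t_img = Image.open('some_transmission.png').convert('RGB')  # placeholder path
r_img = Image.open('some_reflection.png').convert('RGB')    # placeholder path

# Both images become aligned 320x320 crops, so they can be blended pixel-wise.
t_img, r_img = paired_data_transforms(t_img, r_img)

syn = ReflectionSynthesis()
T, R, I = syn(t_img, r_img)  # float32 arrays in [0, 1]; I is the blended input
print(T.shape, R.shape, I.shape)  # (320, 320, 3) each
```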
RDNet-main/RDNet-main/data/image_folder.py
ADDED

```python
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def read_fns(filename):
    with open(filename) as f:
        fns = f.readlines()
        fns = [fn.strip() for fn in fns]
    return fns


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, fns=None):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    if fns is None:
        # Walk the whole tree and collect every image file.
        for root, _, fnames in sorted(os.walk(dir)):
            for fname in fnames:
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append(path)
    else:
        # Only keep the explicitly listed filenames, relative to dir.
        for fname in fns:
            if is_image_file(fname):
                path = os.path.join(dir, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')
```
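A quick sketch of how `make_dataset` pairs with `read_fns` (the directory below is a placeholder; the list file is the repo's own `data/real_test.txt`):

```python
# Hypothetical usage; 'datasets/real_test' is a placeholder directory.
from data.image_folder import make_dataset, read_fns

all_paths = make_dataset('datasets/real_test')            # walk the whole tree
subset = make_dataset('datasets/real_test',
                      fns=read_fns('data/real_test.txt'))  # only listed names
print(len(all_paths), len(subset))
```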
RDNet-main/RDNet-main/data/real_test.txt
ADDED

```
3.jpg
4.jpg
9.jpg
12.jpg
15.jpg
22.jpg
23.jpg
25.jpg
29.jpg
39.jpg
46.jpg
47.jpg
58.jpg
86.jpg
87.jpg
89.jpg
93.jpg
103.jpg
107.jpg
110.jpg
```
RDNet-main/RDNet-main/data/torchdata.py
ADDED

```python
import bisect
import warnings


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

    def reset(self):
        return


class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets, as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate which sub-dataset idx falls into, then offset into it.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
```
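Because `Dataset.__add__` returns a `ConcatDataset`, two datasets concatenate with plain `+`; a self-contained sketch (the `Range` class is a toy illustration, not part of the repo):

```python
from data.torchdata import Dataset

class Range(Dataset):
    """Toy dataset yielding 0..n-1, just to demonstrate concatenation."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return i

combined = Range(3) + Range(2)          # same as ConcatDataset([Range(3), Range(2)])
print(len(combined))                    # 5
print([combined[i] for i in range(5)])  # [0, 1, 2, 0, 1]
```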
RDNet-main/RDNet-main/data/transforms.py
ADDED

```python
from __future__ import division

import math
import random

import torch
from PIL import Image

try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import scipy.stats as st
import cv2
import collections.abc
import torchvision.transforms as transforms
import util.util as util
from scipy.signal import convolve2d


# utility
def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def arrshow(arr):
    Image.fromarray(arr.astype(np.uint8)).show()


def get_transform(opt):
    transform_list = []
    osizes = util.parse_args(opt.loadSize)
    fineSize = util.parse_args(opt.fineSize)
    if opt.resize_or_crop == 'resize_and_crop':
        transform_list.append(
            transforms.RandomChoice([
                transforms.Resize([osize, osize], Image.BICUBIC) for osize in osizes
            ]))
        transform_list.append(transforms.RandomCrop(fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    return transforms.Compose(transform_list)


to_norm_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.5, 0.5, 0.5),
        (0.5, 0.5, 0.5)
    )
])

to_tensor = transforms.ToTensor()


def __scale_width(img, target_width):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    h = math.ceil(h / 2.) * 2  # round up to even
    return img.resize((w, h), Image.BICUBIC)


# functional
def gaussian_blur(img, kernel_size, sigma):
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    img = np.asarray(img)
    # the 3rd dimension (i.e. inter-band) would be filtered, which is unwanted for our purpose
    # new = gaussian_filter(img, sigma=sigma, truncate=truncate)
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    elif isinstance(kernel_size, collections.abc.Sequence):
        assert len(kernel_size) == 2
    new = cv2.GaussianBlur(img, kernel_size, sigma)  # apply gaussian filter band by band
    return Image.fromarray(new)


# transforms
class GaussianBlur(object):
    def __init__(self, kernel_size=11, sigma=3):
        self.kernel_size = kernel_size
        self.sigma = sigma

    def __call__(self, img):
        return gaussian_blur(img, self.kernel_size, self.sigma)


class ReflectionSythesis_0(object):
    """Reflection image data synthesis for weakly-supervised learning
    of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*
    """

    def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3,
                 high_gamma=1.3, low_delta=0.4, high_delta=1.8):
        self.kernel_sizes = kernel_sizes or [11]
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.low_gamma = low_gamma
        self.high_gamma = high_gamma
        self.low_delta = low_delta
        self.high_delta = high_delta
        print('[i] reflection synthesis model: {}'.format({
            'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
            'low_gamma': low_gamma, 'high_gamma': high_gamma}))

    def __call__(self, B, R):
        if not _is_pil_image(B):
            raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
        if not _is_pil_image(R):
            raise TypeError('R should be PIL Image. Got {}'.format(type(R)))
        B_ = np.asarray(B, np.float32)
        # replace the background with either a dark uniform value or a random flat color
        if random.random() < 0.4:
            B_ = np.tile(np.random.uniform(0, 30, (1, 1, 1)), B_.shape) / 255.
        else:
            B_ = np.tile(np.random.normal(50, 50, (1, 1, 3)), (B_.shape[0], B_.shape[1], 1)).clip(0, 255) / 255.
        R_ = np.asarray(R, np.float32) / 255.

        kernel_size = np.random.choice(self.kernel_sizes)  # note: the kernel below is built with a fixed size of 11
        sigma = np.random.uniform(self.low_sigma, self.high_sigma)
        gamma = np.random.uniform(self.low_gamma, self.high_gamma)
        delta = np.random.uniform(self.low_delta, self.high_delta)
        R_blur = R_
        kernel = cv2.getGaussianKernel(11, sigma)
        kernel2d = np.dot(kernel, kernel.T)

        for i in range(3):
            R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')

        R_blur = np.clip(R_blur - np.mean(R_blur) * gamma, 0, 1)
        R_blur = np.clip(R_blur * delta, 0, 1)
        M_ = np.clip(R_blur + B_, 0, 1)

        return B_, R_blur, M_


class ReflectionSythesis_1(object):
    """Reflection image data synthesis for weakly-supervised learning
    of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"*
    """

    def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3, high_gamma=1.3):
        self.kernel_sizes = kernel_sizes or [11]
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.low_gamma = low_gamma
        self.high_gamma = high_gamma
        print('[i] reflection synthesis model: {}'.format({
            'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma,
            'low_gamma': low_gamma, 'high_gamma': high_gamma}))

    def __call__(self, B, R):
        if not _is_pil_image(B):
            raise TypeError('B should be PIL Image. Got {}'.format(type(B)))
        if not _is_pil_image(R):
            raise TypeError('R should be PIL Image. Got {}'.format(type(R)))

        B_ = np.asarray(B, np.float32) / 255.
        R_ = np.asarray(R, np.float32) / 255.

        kernel_size = np.random.choice(self.kernel_sizes)  # note: the kernel below is built with a fixed size of 11
        sigma = np.random.uniform(self.low_sigma, self.high_sigma)
        gamma = np.random.uniform(self.low_gamma, self.high_gamma)
        R_blur = R_
        kernel = cv2.getGaussianKernel(11, sigma)
        kernel2d = np.dot(kernel, kernel.T)

        for i in range(3):
            R_blur[..., i] = convolve2d(R_blur[..., i], kernel2d, mode='same')

        M_ = B_ + R_blur

        if np.max(M_) > 1:
            m = M_[M_ > 1]
            m = (np.mean(m) - 1) * gamma
            R_blur = np.clip(R_blur - m, 0, 1)
            M_ = np.clip(R_blur + B_, 0, 1)

        return B_, R_blur, M_


class Sobel(object):
    def __call__(self, img):
        if not _is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

        gray_img = np.array(img.convert('L'))
        x = cv2.Sobel(gray_img, cv2.CV_16S, 1, 0)
        y = cv2.Sobel(gray_img, cv2.CV_16S, 0, 1)

        absX = cv2.convertScaleAbs(x)
        absY = cv2.convertScaleAbs(y)

        dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
        return Image.fromarray(dst)


class ReflectionSythesis_2(object):
    """Reflection image data synthesis for weakly-supervised learning
    of CVPR 2018 paper *"Single Image Reflection Separation with Perceptual Losses"*
    """

    def __init__(self, kernel_sizes=None):
        self.kernel_sizes = kernel_sizes or np.linspace(1, 5, 80)

    @staticmethod
    def gkern(kernlen=100, nsig=1):
        """Returns a 2D Gaussian kernel array."""
        interval = (2 * nsig + 1.) / (kernlen)
        x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
        kern1d = np.diff(st.norm.cdf(x))
        kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
        kernel = kernel_raw / kernel_raw.sum()
        kernel = kernel / kernel.max()
        return kernel

    def __call__(self, t, r):
        t = np.float32(t) / 255.
        r = np.float32(r) / 255.
        ori_t = t
        # create a vignetting mask
        g_mask = self.gkern(560, 3)
        g_mask = np.dstack((g_mask, g_mask, g_mask))
        sigma = self.kernel_sizes[np.random.randint(0, len(self.kernel_sizes))]

        # blend in linear space (gamma 2.2), blur the reflection, then re-apply gamma
        t = np.power(t, 2.2)
        r = np.power(r, 2.2)

        sz = int(2 * np.ceil(2 * sigma) + 1)

        r_blur = cv2.GaussianBlur(r, (sz, sz), sigma, sigma, 0)
        blend = r_blur + t

        att = 1.08 + np.random.random() / 10.0

        for i in range(3):
            maski = blend[:, :, i] > 1
            mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
            r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
        r_blur[r_blur >= 1] = 1
        r_blur[r_blur <= 0] = 0

        h, w = r_blur.shape[0:2]
        neww = np.random.randint(0, 560 - w - 10)
        newh = np.random.randint(0, 560 - h - 10)
        alpha1 = g_mask[newh:newh + h, neww:neww + w, :]
        alpha2 = 1 - np.random.random() / 5.0
        r_blur_mask = np.multiply(r_blur, alpha1)
        blend = r_blur_mask + t * alpha2

        t = np.power(t, 1 / 2.2)
        r_blur_mask = np.power(r_blur_mask, 1 / 2.2)
        blend = np.power(blend, 1 / 2.2)
        blend[blend >= 1] = 1
        blend[blend <= 0] = 0

        return np.float32(ori_t), np.float32(r_blur_mask), np.float32(blend)


# Examples
if __name__ == '__main__':
    """cv2 imread"""
    # img = cv2.imread('testdata_reflection_real/19-input.png')
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img2 = cv2.GaussianBlur(img, (11, 11), 3)

    """Sobel Operator"""
    # img = np.array(Image.open('datasets/VOC224/train/B/2007_000250.png').convert('L'))

    """Reflection Synthesis"""
    b = Image.open('')
    r = Image.open('')
    G = ReflectionSythesis_0()
    b_, r_blur, m = G(b, r)  # __call__ returns the triple (B_, R_blur, M_)
    arrshow(m * 255)
```
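A hedged usage sketch for `ReflectionSythesis_2`: the inputs must be smaller than roughly 550x550, since a window of the 560x560 vignetting mask is cropped around them (the image paths are placeholders):

```python
# Hypothetical usage; image paths are placeholders.
from PIL import Image
from data.transforms import ReflectionSythesis_2

syn = ReflectionSythesis_2()
t = Image.open('transmission.png').convert('RGB').resize((400, 300))
r = Image.open('reflection.png').convert('RGB').resize((400, 300))
t_out, r_blur_mask, blend = syn(t, r)  # float32 arrays in [0, 1]
```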
RDNet-main/RDNet-main/engine.py
ADDED

```python
import torch
import util.util as util
from models import make_model
import time
import os
import sys
from os.path import join
from util.visualizer import Visualizer
import tqdm
import visdom
import numpy as np
from tools import mutils


class Engine(object):
    def __init__(self, opt, eval_dataset_real, eval_dataset_solidobject, eval_dataset_postcard, eval_dataloader_wild):
        self.opt = opt
        self.writer = None
        self.visualizer = None
        self.model = None
        self.best_val_loss = 1e6
        self.eval_dataset_real = eval_dataset_real
        self.eval_dataset_solidobject = eval_dataset_solidobject
        self.eval_dataset_postcard = eval_dataset_postcard
        self.eval_dataloader_wild = eval_dataloader_wild
        self.result_dir = os.path.join(f'./experiment/{self.opt.name}/results',
                                       mutils.get_formatted_time())
        self.biggest_psnr = 0
        self.__setup()

    def __setup(self):
        self.basedir = join('experiment', self.opt.name)
        os.makedirs(self.basedir, exist_ok=True)

        opt = self.opt

        """Model"""
        self.model = make_model(self.opt.model)  # models.__dict__[self.opt.model]()
        self.model.initialize(opt)
        if True:
            print("IN")
            self.writer = util.get_summary_writer(os.path.join(self.basedir, 'logs'))
            self.visualizer = Visualizer(opt)

    def train(self, train_loader, **kwargs):
        print('\nEpoch: %d' % self.epoch)
        avg_meters = util.AverageMeters()
        opt = self.opt
        model = self.model
        epoch = self.epoch

        epoch_start_time = time.time()
        for i, data in tqdm.tqdm(enumerate(train_loader)):

            iter_start_time = time.time()
            iterations = self.iterations

            model.set_input(data, mode='train')
            model.optimize_parameters(**kwargs)

            errors = model.get_current_errors()
            avg_meters.update(errors)
            util.progress_bar(i, len(train_loader), str(avg_meters))
            util.write_loss(self.writer, 'train', avg_meters, iterations)
            if iterations % 100 == 0:
                # Periodically log the current restoration next to its input.
                imgs = []
                output_clean, output_reflection, input = model.return_output()
                # output_clean, input = model.return_output()

                output_clean = np.transpose(output_clean, (2, 0, 1)) / 255
                # output_reflection = np.transpose(output_reflection, (2, 0, 1)) / 255
                input = np.transpose(input, (2, 0, 1)) / 255
                imgs.append(output_clean)
                # imgs.append(output_reflection)
                imgs.append(input)
                util.get_visual(self.writer, iterations, imgs)
            if iterations % opt.print_freq == 0 and opt.display_id != 0:
                t = (time.time() - iter_start_time)

            self.iterations += 1

        self.epoch += 1

        if True:  # not self.opt.no_log:
            if self.epoch % opt.save_epoch_freq == 0:
                save_dir = os.path.join(self.result_dir, '%03d' % self.epoch)
                os.makedirs(save_dir, exist_ok=True)
                matrix_real = self.eval(self.eval_dataset_real, dataset_name='testdata_real20',
                                        savedir=save_dir, suffix='real20')
                matrix_solid = self.eval(self.eval_dataset_solidobject, dataset_name='testdata_solidobject',
                                         savedir=save_dir, suffix='solidobject')
                matrix_post = self.eval(self.eval_dataset_postcard, dataset_name='testdata_postcard',
                                        savedir=save_dir, suffix='postcard')
                matrix_wild = self.eval(self.eval_dataloader_wild, dataset_name='testdata_wild',
                                        savedir=save_dir, suffix='wild')
                # Weight each split's PSNR by its number of images (474 in total).
                sum_PSNR_real = matrix_real['PSNR'] * 20
                sum_PSNR_solid = matrix_solid['PSNR'] * 200
                sum_PSNR_post = matrix_post['PSNR'] * 199
                sum_PSNR_wild = matrix_wild['PSNR'] * 55
                print("sum_PSNR_real: ", matrix_real['PSNR'], "sum_PSNR_solid: ", matrix_solid['PSNR'],
                      "sum_PSNR_post: ", matrix_post['PSNR'], "sum_PSNR_wild: ", matrix_wild['PSNR'])
                sum_PSNR = float(sum_PSNR_real + sum_PSNR_solid + sum_PSNR_post + sum_PSNR_wild) / 474.0
                print('Overall PSNR:', sum_PSNR)
                if sum_PSNR > self.biggest_psnr:
                    self.biggest_psnr = sum_PSNR
                    print('saving the model at epoch %d, iters %d' % (self.epoch, self.iterations))
                    model.save()
                print('highest: ', self.biggest_psnr, ' name: ', opt.name)

            print('saving the latest model at the end of epoch %d, iters %d' %
                  (self.epoch, self.iterations))
            model.save(label='latest')

        print('Time Taken: %d sec' %
              (time.time() - epoch_start_time))

        # model.update_learning_rate()
        try:
            train_loader.reset()
        except AttributeError:
            pass

    def eval(self, val_loader, dataset_name, savedir='./tmp', loss_key=None, **kwargs):
        # print(dataset_name)
        if savedir is not None:
            os.makedirs(savedir, exist_ok=True)
            self.f = open(os.path.join(savedir, 'metrics.txt'), 'w+')
            self.f.write(dataset_name + '\n')
        avg_meters = util.AverageMeters()
        model = self.model
        opt = self.opt
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                if self.opt.select is not None and data['fn'][0] not in [f'{self.opt.select}.jpg']:
                    continue
                index = model.eval(data, savedir=savedir, **kwargs)

                # print(data['fn'][0], index)
                if savedir is not None:
                    self.f.write(f"{data['fn'][0]} {index['PSNR']} {index['SSIM']}\n")
                avg_meters.update(index)
                util.progress_bar(i, len(val_loader), str(avg_meters))

        if not opt.no_log:
            util.write_loss(self.writer, join('eval', dataset_name), avg_meters, self.epoch)

        if loss_key is not None:
            val_loss = avg_meters[loss_key]
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                print('saving the best model at the end of epoch %d, iters %d' %
                      (self.epoch, self.iterations))
                model.save(label='best_{}_{}'.format(loss_key, dataset_name))

        return avg_meters

    def test(self, test_loader, savedir=None, **kwargs):
        model = self.model
        opt = self.opt
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                model.test(data, savedir=savedir, **kwargs)
                util.progress_bar(i, len(test_loader))

    def save_eval(self, label):
        self.model.save_eval(label)

    @property
    def iterations(self):
        return self.model.iterations

    @iterations.setter
    def iterations(self, i):
        self.model.iterations = i

    @property
    def epoch(self):
        return self.model.epoch

    @epoch.setter
    def epoch(self, e):
        self.model.epoch = e
```
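The epoch-level model-selection metric in `train` above is a size-weighted mean PSNR over the four evaluation splits (20, 200, 199 and 55 images; 474 in total). Spelled out as a standalone sketch:

```python
def weighted_psnr(real20, solid, postcard, wild):
    # Split sizes act as weights: 20 + 200 + 199 + 55 = 474 images in total.
    return (real20 * 20 + solid * 200 + postcard * 199 + wild * 55) / 474.0

# e.g. per-split PSNRs of 23, 25, 24 and 26 dB give about 24.61 dB overall.
print(weighted_psnr(23.0, 25.0, 24.0, 26.0))
```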
RDNet-main/RDNet-main/figures/Input_car.jpg
ADDED

RDNet-main/RDNet-main/figures/Input_class.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/Input_green.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/Ours_car.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/Ours_class.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/Ours_green.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/Ours_white.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/Title.png
ADDED

RDNet-main/RDNet-main/figures/input_white.jpg
ADDED

RDNet-main/RDNet-main/figures/net.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/result.png
ADDED (Git LFS)

RDNet-main/RDNet-main/figures/vis.png
ADDED (Git LFS)
RDNet-main/RDNet-main/models/__init__.py
ADDED

```python
import importlib

from models.arch import *

from models.cls_model_eval_nocls_reg import ClsModel as ClsReg


def make_model(name: str):
    # Note: the name argument is currently unused; ClsReg is always returned.
    model = ClsReg()
    return model
```
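A call sketch for the factory above; the argument is accepted for interface compatibility but, as noted in the comment, currently ignored:

```python
from models import make_model

model = make_model('rdnet')  # currently always returns a ClsReg instance
```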
RDNet-main/RDNet-main/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (414 Bytes)

RDNet-main/RDNet-main/models/__pycache__/base_model.cpython-38.pyc
ADDED
Binary file (3.02 kB)

RDNet-main/RDNet-main/models/__pycache__/cls_model_eval_nocls_reg.cpython-38.pyc
ADDED
Binary file (17.4 kB)

RDNet-main/RDNet-main/models/__pycache__/losses.cpython-38.pyc
ADDED
Binary file (15.3 kB)

RDNet-main/RDNet-main/models/__pycache__/networks.cpython-38.pyc
ADDED
Binary file (9.34 kB)

RDNet-main/RDNet-main/models/__pycache__/vgg.cpython-38.pyc
ADDED
Binary file (2.15 kB)

RDNet-main/RDNet-main/models/__pycache__/vit_feature_extractor.cpython-38.pyc
ADDED
Binary file (6.95 kB)
RDNet-main/RDNet-main/models/arch/NAFNET.py
ADDED
@@ -0,0 +1,480 @@
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------

'''
Simple Baselines for Image Restoration

@article{chen2022simple,
  title={Simple Baselines for Image Restoration},
  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
  journal={arXiv preprint arXiv:2204.04676},
  year={2022}
}
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
# from .models.archs.arch_util import LayerNorm2d
import sys
sys.path.append('/ghome/zhuyr/Deref_RW/networks/')

class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None

class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2
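A quick, self-contained shape check of the two primitives above — channel-wise LayerNorm over NCHW tensors, and the gate that halves the channel count (shapes are illustrative; assumes NAFNET.py is importable as models.arch.NAFNET):

import torch
from models.arch.NAFNET import LayerNorm2d, SimpleGate

x = torch.randn(2, 8, 16, 16)
print(LayerNorm2d(8)(x).shape)  # torch.Size([2, 8, 16, 16]) - shape preserved
print(SimpleGate()(x).shape)    # torch.Size([2, 4, 16, 16]) - channels halved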
class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma

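Because the gate halves channels only inside the block and conv3/conv5 project back to c, a NAFBlock is shape-preserving, which is what lets the encoder/decoder stages below stack an arbitrary number of them. Illustrative check (assumes the module import path):

import torch
from models.arch.NAFNET import NAFBlock

blk = NAFBlock(c=32)
print(blk(torch.randn(1, 32, 64, 64)).shape)  # torch.Size([1, 32, 64, 64])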
class NAFNet(nn.Module):

    def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28],
                 dec_blk_nums=[1, 1, 1, 1], global_residual=False, drop_flag=False, drop_rate=0.4):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.global_residual = global_residual
        self.drop_flag = drop_flag

        if drop_flag:
            self.dropout = nn.Dropout2d(p=drop_rate)

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        base_inp = inp[:, :3, :, :]
        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        if self.drop_flag:
            x = self.dropout(x)

        x = self.ending(x)
        if self.global_residual:
            # print(x.shape, inp.shape, base_inp.shape)
            x = x + base_inp
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

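check_image_size pads H and W up to the next multiple of 2**len(encoders) so the four stride-2 downsamples divide evenly; forward then crops back to the original (H, W). A self-contained numeric check of the same arithmetic (sizes are illustrative):

import torch
import torch.nn.functional as F

padder_size = 16                      # 2 ** 4 encoder stages
h, w = 250, 375
mod_pad_h = (padder_size - h % padder_size) % padder_size
mod_pad_w = (padder_size - w % padder_size) % padder_size
x = F.pad(torch.randn(1, 3, h, w), (0, mod_pad_w, 0, mod_pad_h))
print(x.shape)                        # torch.Size([1, 3, 256, 384])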
class NAFNet_wDetHead(nn.Module):

    def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28],
                 dec_blk_nums=[1, 1, 1, 1], global_residual=False, drop_flag=False, drop_rate=0.4,
                 concat=False, merge_manner=0):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.global_residual = global_residual
        self.drop_flag = drop_flag
        self.concat = concat
        self.merge_manner = merge_manner

        if drop_flag:
            self.dropout = nn.Dropout2d(p=drop_rate)

        # --------------------------- Merge sparse & Img -------------------------------------------------------
        self.intro_Det = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                                   bias=True)
        self.DetEnc = nn.Sequential(*[NAFBlock(width) for _ in range(3)])
        if self.concat:
            self.Merge_conv = nn.Conv2d(in_channels=width * 2, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                                        bias=True)
        else:
            self.Merge_conv = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, padding=1, stride=1,
                                        groups=1,
                                        bias=True)
        # --------------------------- Merge sparse & Img -------------------------------------------------------

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, spare_ref):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        # NOTE: spare_ref is not padded here, so it must already match the padded size of inp.
        base_inp = inp  # [:, :3, :, :]
        x = self.intro(inp)

        fea_sparse = self.DetEnc(self.intro_Det(spare_ref))

        if self.merge_manner == 0 and self.concat:
            x = torch.cat([x, fea_sparse], dim=1)
            x = self.Merge_conv(x)
        elif self.merge_manner == 1 and not self.concat:
            x = x + fea_sparse
            x = self.Merge_conv(x)
        elif self.merge_manner == 2 and not self.concat:
            x = x + fea_sparse * x
            x = self.Merge_conv(x)
        else:
            # no merge operation is applied for this flag combination
            print('Merge Flag Error!!!(No Merge Operation) ---zyr 1031 ')

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        if self.drop_flag:
            x = self.dropout(x)

        x = self.ending(x)
        if self.global_residual:
            # print(x.shape, inp.shape, base_inp.shape)
            x = x + base_inp
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

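The three merge manners above fuse the image features with the sparse-reflection features in different ways before they pass through Merge_conv. Self-contained stand-ins (width = 4, tensors are illustrative):

import torch

x = torch.randn(1, 4, 8, 8)           # image features from self.intro
fea_sparse = torch.randn(1, 4, 8, 8)  # sparse-reflection features from DetEnc

merged_concat = torch.cat([x, fea_sparse], dim=1)  # manner 0: (1, 8, 8, 8), needs concat=True
merged_add = x + fea_sparse                        # manner 1: plain additive fusion
merged_mod = x + fea_sparse * x                    # manner 2: multiplicative modulation
print(merged_concat.shape, merged_add.shape, merged_mod.shape)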
class NAFNet_refine(nn.Module):

    def __init__(self, img_channel=6, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28],
                 dec_blk_nums=[1, 1, 1, 1], global_residual=False):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=3, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.global_residual = global_residual

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, pre_pred):
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)
        pre_pred = self.check_image_size(pre_pred)

        network_in = torch.cat([inp, pre_pred], dim=1)

        x = self.intro(network_in)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        if self.global_residual:
            # residual over the first three channels (the degraded input image);
            # the original wrote inp[:3, :, :, :], which indexes the batch dimension
            x = x + inp[:, :3, :, :]
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

def print_param_number(net):
    print('#generator parameters:', sum(param.numel() for param in net.parameters()))

if __name__ == '__main__':
    img_channel = 3
    width = 32

    # enc_blks = [2, 2, 4, 8]
    # middle_blk_num = 12
    # dec_blks = [2, 2, 2, 2]

    # enc_blks = [1, 1, 1, 28]
    # middle_blk_num = 1
    # dec_blks = [1, 1, 1, 1]

    enc_blks = [1, 1, 1, 28]
    middle_blk_num = 1
    dec_blks = [1, 1, 1, 1]

    net = NAFNet_wDetHead(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
                          enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, global_residual=True,
                          concat=True, merge_manner=2)  # .cuda()
    # NOTE: concat=True with merge_manner=2 matches none of the merge branches,
    # so this demo run prints the 'Merge Flag Error' message and skips merging.
    # print(net)
    size = 352
    input = torch.randn([1, 3, 128, 128])  # .cuda()  inp_shape = (5, 3, 128, 128)
    spare = torch.randn([1, 1, 128, 128])
    print(net(input, spare).size())
    print_param_number(net)

    # net_local = NAFNetLocal()  # .cuda()
    # print_param_number(net)
    # print(net_local(input).size())
    # inp_shape = (3, 256, 256)
    #
    # from ptflops import get_model_complexity_info
    #
    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
    #
    # params = float(params[:-3])
    # macs = float(macs[:-4])
    #
    # print(macs, params)
RDNet-main/RDNet-main/models/arch/RDnet_.py
ADDED
@@ -0,0 +1,202 @@
import numpy as np
from models.arch.focalnet import build_focalnet
import torch
import torch.nn as nn
from models.arch.modules_sig import ConvNextBlock, Decoder, LayerNorm, NAFBlock, SimDecoder, UpSampleConvnext
from models.arch.reverse_function import ReverseFunction
from timm.models.layers import trunc_normal_

class Fusion(nn.Module):
    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        self.down = nn.Sequential(
            nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
            LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
        ) if level in [1, 2, 3] else nn.Identity()
        if not first_col:
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):

        c_down, c_up = args
        channels_down = c_down.size(1)  # typo fixed (was `channels_dowm`); the value is unused
        if self.first_col:
            x_clean = self.down(c_down)
            return x_clean
        if c_up is not None:
            channels_up = c_up.size(1)
        if self.level == 3:
            x_clean = self.down(c_down)
        else:
            x_clean = self.up(c_up) + self.down(c_down)

        return x_clean
class Level(nn.Module):
    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0, block_type=ConvNextBlock) -> None:
        super().__init__()
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [block_type(channels[level], expansion * channels[level], channels[level], kernel_size=kernel_size,
                              layer_scale_init_value=1e-6, drop_path=dp_rate[countlayer + i]) for i in
                   range(layers[level])]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
        x = self.fusion(*args)
        x_clean = self.blocks(x)
        return x_clean


class SubNet(nn.Module):
    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory, block_type=ConvNextBlock) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates, block_type=block_type)

    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args
        c0 = self.alpha0 * c0 + self.level0(x, c1)
        c1 = self.alpha1 * c1 + self.level1(c0, c2)
        c2 = self.alpha2 * c2 + self.level2(c1, c3)
        c3 = self.alpha3 * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        x, c0, c1, c2, c3 = args
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(
            local_funs, alpha, *args)

        return c0, c1, c2, c3

    def forward(self, *args):

        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)
        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign
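_clamp_abs keeps each shortcut scale away from zero while preserving its sign, so the reversible update c = alpha * c + f(...) stays numerically invertible when the backward pass divides by alpha. A self-contained check of that clamp (values are illustrative):

import torch

def clamp_abs(data, value):
    with torch.no_grad():
        sign = data.sign()
        data.abs_().clamp_(value)  # clamp magnitude from below
        data *= sign               # restore the sign

a = torch.tensor([0.5, -0.0004])
clamp_abs(a, 1e-3)
print(a)  # tensor([ 0.5000, -0.0010])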
+
class StarReLU(nn.Module):
|
| 112 |
+
"""
|
| 113 |
+
StarReLU: s * relu(x) ** 2 + b
|
| 114 |
+
"""
|
| 115 |
+
def __init__(self, scale_value=1.0, bias_value=0.0,
|
| 116 |
+
scale_learnable=True, bias_learnable=True,
|
| 117 |
+
mode=None, inplace=True):
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.inplace = inplace
|
| 120 |
+
self.relu = nn.ReLU(inplace=inplace)
|
| 121 |
+
self.scale = nn.Parameter(scale_value * torch.ones(1),
|
| 122 |
+
requires_grad=scale_learnable)
|
| 123 |
+
self.bias = nn.Parameter(bias_value * torch.ones(1),
|
| 124 |
+
requires_grad=bias_learnable)
|
| 125 |
+
def forward(self, x):
|
| 126 |
+
return self.scale * self.relu(x)**2 + self.bias
|
| 127 |
+
|
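A quick numeric check of the StarReLU formula with the default s=1, b=0 (inputs are illustrative):

import torch

x = torch.tensor([-2.0, 0.5, 3.0])
print(torch.relu(x) ** 2)  # tensor([0.0000, 0.2500, 9.0000])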
class FullNet_NLP(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5, loss_col=4, kernel_size=3, num_classes=1000,
                 drop_path=0.0, save_memory=True, inter_supv=True, head_init_scale=None, pretrained_cols=16) -> None:
        super().__init__()
        self.num_subnet = num_subnet
        self.Loss_col = (loss_col + 1)
        self.inter_supv = inter_supv
        self.channels = channels
        self.layers = layers
        self.stem_comp = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=5, stride=2, padding=2),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )
        self.prompt = nn.Sequential(nn.Linear(in_features=6, out_features=512),
                                    StarReLU(),
                                    nn.Linear(in_features=512, out_features=channels[0]),
                                    StarReLU(),
                                    )
        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = True if i == 0 else False
            self.add_module(f'subnet{str(i)}', SubNet(
                channels, layers, kernel_size, first_col,
                dp_rates=dp_rate, save_memory=save_memory,
                block_type=NAFBlock))

        channels.reverse()
        self.decoder_blocks = nn.ModuleList(
            [Decoder(depth=[1, 1, 1, 1], dim=channels, block_type=NAFBlock, kernel_size=3) for _ in
             range(3)])

        self.apply(self._init_weights)
        self.baseball = build_focalnet('focalnet_L_384_22k_fl4')
        self.baseball_adapter = nn.ModuleList()
        self.baseball_adapter.append(nn.Conv2d(192, 64, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192, 64, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192 * 2, 64 * 2, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192 * 4, 64 * 4, kernel_size=1))
        self.baseball_adapter.append(nn.Conv2d(192 * 8, 64 * 8, kernel_size=1))

    def forward(self, x_in, alpha, prompt=True):
        x_cls_out = []
        x_img_out = []
        c0, c1, c2, c3 = 0, 0, 0, 0
        interval = self.num_subnet // 4

        x_base, x_stem = self.baseball(x_in)
        c0, c1, c2, c3 = x_base
        x_stem = self.baseball_adapter[0](x_stem)
        c0, c1, c2, c3 = self.baseball_adapter[1](c0), \
                         self.baseball_adapter[2](c1), \
                         self.baseball_adapter[3](c2), \
                         self.baseball_adapter[4](c3)
        if prompt:
            prompt_alpha = self.prompt(alpha)
            prompt_alpha = prompt_alpha.unsqueeze(-1).unsqueeze(-1)
            x = prompt_alpha * x_stem
        else:
            x = x_stem
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
            if i > (self.num_subnet - self.Loss_col):
                x_img_out.append(torch.cat([x_in, x_in], dim=-3) - self.decoder_blocks[-1](c3, c2, c1, c0))

        return x_cls_out, x_img_out

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
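The prompt path above reduces to feature-wise modulation: a 6-d conditioning vector is mapped to per-channel scales that multiply the stem features. A runnable stand-in (nn.ReLU substitutes for StarReLU here; the 64-channel stem and all shapes are illustrative):

import torch
import torch.nn as nn

prompt_mlp = nn.Sequential(nn.Linear(6, 512), nn.ReLU(), nn.Linear(512, 64))
alpha = torch.randn(2, 6)            # e.g. regressed degradation parameters
x_stem = torch.randn(2, 64, 96, 96)  # adapted FocalNet stem features
scale = prompt_mlp(alpha).unsqueeze(-1).unsqueeze(-1)  # (2, 64, 1, 1)
print((scale * x_stem).shape)        # torch.Size([2, 64, 96, 96])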
RDNet-main/RDNet-main/models/arch/__pycache__/RDnet_.cpython-38.pyc
ADDED
Binary file (8.23 kB)
RDNet-main/RDNet-main/models/arch/__pycache__/classifier.cpython-38.pyc
ADDED
Binary file (2.14 kB)
RDNet-main/RDNet-main/models/arch/__pycache__/focalnet.cpython-38.pyc
ADDED
Binary file (15.8 kB)
RDNet-main/RDNet-main/models/arch/__pycache__/modules_sig.cpython-38.pyc
ADDED
Binary file (11 kB)
RDNet-main/RDNet-main/models/arch/__pycache__/reverse_function.cpython-38.pyc
ADDED
Binary file (4.74 kB)
RDNet-main/RDNet-main/models/arch/classifier.py
ADDED
@@ -0,0 +1,49 @@
import torch.nn as nn
import timm
import torch
import torch.nn.functional as F

class PretrainedConvNext(nn.Module):
    def __init__(self, model_name='convnext_base', pretrained=True):
        super(PretrainedConvNext, self).__init__()
        # Load the ConvNext backbone from timm.
        # NOTE: the `pretrained` argument is not forwarded here, so timm weights are not loaded.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0)
        self.head = nn.Linear(768, 6)

    def forward(self, x):
        with torch.no_grad():
            cls_input = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True)
        # Forward pass through the ConvNext model
        out = self.model(cls_input)
        out = self.head(out)
        # alpha, beta = out[..., :3].unsqueeze(-1).unsqueeze(-1), \
        #               out[..., 3:].unsqueeze(-1).unsqueeze(-1)

        # out = alpha * x + beta
        # print(out.shape)
        return out  # alpha, beta  # out[..., :3], out[..., 3:]

class PretrainedConvNext_e2e(nn.Module):
    def __init__(self, model_name='convnext_base', pretrained=True):
        super(PretrainedConvNext_e2e, self).__init__()
        # Load the pretrained ConvNext model from timm
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.head = nn.Linear(768, 6)

    def forward(self, x):
        with torch.no_grad():
            cls_input = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=True)
        # Forward pass through the ConvNext model
        out = self.model(cls_input)
        out = self.head(out)
        alpha, beta = out[..., :3].unsqueeze(-1).unsqueeze(-1), \
                      out[..., 3:].unsqueeze(-1).unsqueeze(-1)

        out = alpha * x + beta
        # print(out.shape)
        return out

if __name__ == "__main__":
    model = PretrainedConvNext('convnext_small_in22k')
    print("Testing PretrainedConvNext model...")
    # Assuming a dummy input tensor of size (1, 3, 224, 224) similar to an image in the ImageNet dataset
    dummy_input = torch.randn(20, 3, 224, 224)
    output = model(dummy_input)  # forward returns a single (B, 6) tensor, not a pair
    print("Output shape:", output.shape)
    print("Test completed successfully.")
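The e2e variant turns the 6 regressed values into a per-channel affine correction, alpha * x + beta, over the RGB input. A self-contained illustration of that split (the numbers are made up for the demo):

import torch

out = torch.tensor([[1.1, 0.9, 1.0, 0.02, -0.01, 0.0]])  # (B, 6) head output
alpha = out[..., :3].unsqueeze(-1).unsqueeze(-1)          # (B, 3, 1, 1) gains
beta = out[..., 3:].unsqueeze(-1).unsqueeze(-1)           # (B, 3, 1, 1) offsets
x = torch.rand(1, 3, 8, 8)
print((alpha * x + beta).shape)                           # torch.Size([1, 3, 8, 8])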
RDNet-main/RDNet-main/models/arch/decode.py
ADDED
@@ -0,0 +1,36 @@
import torch.nn as nn

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

cfgs = {
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}


class VGG(nn.Module):
    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features

    def forward(self, x):
        x = self.features(x)
        return x  # the original forward was missing this return statement

def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    return model

def encoder(pretrained=False, progress=True, **kwargs):
    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
RDNet-main/RDNet-main/models/arch/focalnet.py
ADDED
@@ -0,0 +1,589 @@
# --------------------------------------------------------
# FocalNet for Semantic Segmentation
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang
# --------------------------------------------------------
import math
import time
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class FocalModulation(nn.Module):
    """ Focal Modulation

    Args:
        dim (int): Number of input channels.
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        focal_level (int): Number of focal levels
        focal_window (int): Focal window size at focal level 1
        focal_factor (int, default=2): Step to increase the focal window
        use_postln (bool, default=False): Whether use post-modulation layernorm
    """

    def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False,
                 use_postln_in_modulation=False, normalize_modulator=False):

        super().__init__()
        self.dim = dim

        # specific args for focalv3
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.normalize_modulator = normalize_modulator

        self.f = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=True)
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.focal_layers = nn.ModuleList()

        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
                              padding=kernel_size // 2, bias=False),
                    nn.GELU(),
                )
            )

    def forward(self, x):
        """ Forward function.

        Args:
            x: input features with shape of (B, H, W, C)
        """
        B, nH, nW, C = x.shape
        x = self.f(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)

        ctx_all = 0
        for l in range(self.focal_level):
            ctx = self.focal_layers[l](ctx)
            ctx_all = ctx_all + ctx * gates[:, l:l + 1]
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        x_out = q * self.h(ctx_all)
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_postln_in_modulation:
            x_out = self.ln(x_out)
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out
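An illustrative shape walk-through of focal modulation: the f() projection emits a query (C channels), a context map (C channels), and focal_level + 1 gate maps, and the modulated output keeps the channels-last input shape (dim and sizes below are made up for the demo; assumes the module import path):

import torch
from models.arch.focalnet import FocalModulation

fm = FocalModulation(dim=8, focal_level=2, focal_window=7)
x = torch.randn(1, 16, 16, 8)  # channels-last (B, H, W, C)
print(fm(x).shape)             # torch.Size([1, 16, 16, 8])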
class FocalModulationBlock(nn.Module):
    """ Focal Modulation Block.

    Args:
        dim (int): Number of input channels.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): number of focal levels
        focal_window (int): focal kernel size at level 1
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 focal_level=2, focal_window=9,
                 use_postln=False, use_postln_in_modulation=False,
                 normalize_modulator=False,
                 use_layerscale=False,
                 layerscale_value=1e-4):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.use_postln = use_postln
        self.use_layerscale = use_layerscale

        self.norm1 = norm_layer(dim)
        self.modulation = FocalModulation(
            dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop,
            use_postln_in_modulation=use_postln_in_modulation,
            normalize_modulator=normalize_modulator,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if self.use_layerscale:
            self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        if not self.use_postln:
            x = self.norm1(x)
        x = x.view(B, H, W, C)

        # FM
        x = self.modulation(x).view(B, H * W, C)
        if self.use_postln:
            x = self.norm1(x)

        # FFN
        x = shortcut + self.drop_path(self.gamma_1 * x)

        if self.use_postln:
            x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
        else:
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        return x
class BasicLayer(nn.Module):
    """ A basic focal modulation layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        focal_level (int): Number of focal levels
        focal_window (int): Focal window size at focal level 1
        use_conv_embed (bool): Use overlapped convolution for patch embedding or not. Default: False
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self,
                 dim,
                 depth,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 focal_window=9,
                 focal_level=2,
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 normalize_modulator=False,
                 use_layerscale=False,
                 use_checkpoint=False
                 ):
        super().__init__()
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            FocalModulationBlock(
                dim=dim,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                focal_window=focal_window,
                focal_level=focal_level,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
                use_layerscale=use_layerscale,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                patch_size=2,
                in_chans=dim, embed_dim=2 * dim,
                use_conv_embed=use_conv_embed,
                norm_layer=norm_layer,
                is_stem=False
            )
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
            x_down = self.downsample(x_reshaped)
            x_down = x_down.flatten(2).transpose(1, 2)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
+
class PatchEmbed(nn.Module):
|
| 284 |
+
""" Image to Patch Embedding
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
patch_size (int): Patch token size. Default: 4.
|
| 288 |
+
in_chans (int): Number of input image channels. Default: 3.
|
| 289 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
| 290 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
| 291 |
+
use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False
|
| 292 |
+
is_stem (bool): Is the stem block or not.
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):
|
| 296 |
+
super().__init__()
|
| 297 |
+
patch_size = to_2tuple(patch_size)
|
| 298 |
+
self.patch_size = patch_size
|
| 299 |
+
|
| 300 |
+
self.in_chans = in_chans
|
| 301 |
+
self.embed_dim = embed_dim
|
| 302 |
+
|
| 303 |
+
if use_conv_embed:
|
| 304 |
+
# if we choose to use conv embedding, then we treat the stem and non-stem differently
|
| 305 |
+
if is_stem:
|
| 306 |
+
kernel_size = 7; padding = 3; stride = 2
|
| 307 |
+
else:
|
| 308 |
+
kernel_size = 3; padding = 1; stride = 2
|
| 309 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
| 310 |
+
else:
|
| 311 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 312 |
+
|
| 313 |
+
if norm_layer is not None:
|
| 314 |
+
self.norm = norm_layer(embed_dim)
|
| 315 |
+
else:
|
| 316 |
+
self.norm = None
|
| 317 |
+
|
| 318 |
+
def forward(self, x):
|
| 319 |
+
"""Forward function."""
|
| 320 |
+
_, _, H, W = x.size()
|
| 321 |
+
if W % self.patch_size[1] != 0:
|
| 322 |
+
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
| 323 |
+
if H % self.patch_size[0] != 0:
|
| 324 |
+
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
| 325 |
+
|
| 326 |
+
x = self.proj(x) # B C Wh Ww
|
| 327 |
+
if self.norm is not None:
|
| 328 |
+
Wh, Ww = x.size(2), x.size(3)
|
| 329 |
+
x = x.flatten(2).transpose(1, 2)
|
| 330 |
+
x = self.norm(x)
|
| 331 |
+
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
| 332 |
+
|
| 333 |
+
return x
|
| 334 |
+
|
| 335 |
+
|
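With the default non-overlapping projection, patch embedding maps (B, 3, H, W) to (B, embed_dim, H/4, W/4), padding H and W up to multiples of the patch size first. Illustrative check (sizes are made up; assumes the module import path):

import torch
from models.arch.focalnet import PatchEmbed

pe = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
print(pe(torch.randn(1, 3, 222, 222)).shape)  # torch.Size([1, 96, 56, 56]) after padding to 224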
class FocalNet(nn.Module):
    """ FocalNet backbone.

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute postion embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Dropout rate.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        focal_levels (Sequence[int]): Number of focal levels at four stages
        focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
        use_conv_embed (bool): Whether use overlapped convolution for patch embedding
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=1600,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.3,  # 0.3 or 0.4 works better for large+ models
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 focal_levels=[3, 3, 3, 3],
                 focal_windows=[3, 3, 3, 3],
                 use_conv_embed=False,
                 use_postln=False,
                 use_postln_in_modulation=False,
                 use_layerscale=False,
                 normalize_modulator=False,
                 use_checkpoint=False,
                 ):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,
            use_conv_embed=use_conv_embed, is_stem=True)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_window=focal_windows[i_layer],
                focal_level=focal_levels[i_layer],
                use_conv_embed=use_conv_embed,
                use_postln=use_postln,
                use_postln_in_modulation=use_postln_in_modulation,
                normalize_modulator=normalize_modulator,
                use_layerscale=use_layerscale,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            # NOTE: get_root_logger and load_checkpoint come from the mmseg/mmcv
            # tooling of the original FocalNet segmentation repo; they are not
            # imported in this file, so the string-path branch is unresolved here.
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x_emb = self.patch_embed(x)
        Wh, Ww = x_emb.size(2), x_emb.size(3)

        x = x_emb.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return outs, x_emb

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(FocalNet, self).train(mode)
        self._freeze_stages()
| 501 |
+
def build_focalnet(modelname, **kw):
|
| 502 |
+
assert modelname in [
|
| 503 |
+
'focalnet_L_384_22k',
|
| 504 |
+
'focalnet_L_384_22k_fl4',
|
| 505 |
+
'focalnet_XL_384_22k',
|
| 506 |
+
'focalnet_XL_384_22k_fl4',
|
| 507 |
+
'focalnet_H_224_22k',
|
| 508 |
+
'focalnet_H_224_22k_fl4',
|
| 509 |
+
]
|
| 510 |
+
|
| 511 |
+
if 'focal_levels' in kw:
|
| 512 |
+
kw['focal_levels'] = [kw['focal_levels']] * 4
|
| 513 |
+
|
| 514 |
+
if 'focal_windows' in kw:
|
| 515 |
+
kw['focal_windows'] = [kw['focal_windows']] * 4
|
| 516 |
+
|
| 517 |
+
model_para_dict = {
|
| 518 |
+
'focalnet_L_384_22k': dict(
|
| 519 |
+
embed_dim=192,
|
| 520 |
+
depths=[ 2, 2, 18, 2 ],
|
| 521 |
+
focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
|
| 522 |
+
focal_windows=kw.get('focal_windows', [5, 5, 5, 5]),
|
| 523 |
+
use_conv_embed=True,
|
| 524 |
+
use_postln=True,
|
| 525 |
+
use_postln_in_modulation=False,
|
| 526 |
+
use_layerscale=True,
|
| 527 |
+
normalize_modulator=False,
|
| 528 |
+
),
|
| 529 |
+
'focalnet_L_384_22k_fl4': dict(
|
| 530 |
+
embed_dim=192,
|
| 531 |
+
depths=[ 2, 2, 18, 2 ],
|
| 532 |
+
focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
|
| 533 |
+
focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
|
| 534 |
+
use_conv_embed=True,
|
| 535 |
+
use_postln=True,
|
| 536 |
+
use_postln_in_modulation=False,
|
| 537 |
+
use_layerscale=True,
|
| 538 |
+
normalize_modulator=True,
|
| 539 |
+
),
|
| 540 |
+
'focalnet_XL_384_22k': dict(
|
| 541 |
+
embed_dim=256,
|
| 542 |
+
depths=[ 2, 2, 18, 2 ],
|
| 543 |
+
focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
|
| 544 |
+
focal_windows=kw.get('focal_windows', [5, 5, 5, 5]),
|
| 545 |
+
use_conv_embed=True,
|
| 546 |
+
use_postln=True,
|
| 547 |
+
use_postln_in_modulation=False,
|
| 548 |
+
use_layerscale=True,
|
| 549 |
+
normalize_modulator=False,
|
| 550 |
+
),
|
| 551 |
+
'focalnet_XL_384_22k_fl4': dict(
|
| 552 |
+
embed_dim=256,
|
| 553 |
+
depths=[ 2, 2, 18, 2 ],
|
| 554 |
+
focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
|
| 555 |
+
focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
|
| 556 |
+
use_conv_embed=True,
|
| 557 |
+
use_postln=True,
|
| 558 |
+
use_postln_in_modulation=False,
|
| 559 |
+
use_layerscale=True,
|
| 560 |
+
normalize_modulator=True,
|
| 561 |
+
),
|
| 562 |
+
'focalnet_H_224_22k': dict(
|
| 563 |
+
embed_dim=352,
|
| 564 |
+
depths=[ 2, 2, 18, 2 ],
|
| 565 |
+
focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
|
| 566 |
+
focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
|
| 567 |
+
use_conv_embed=True,
|
| 568 |
+
use_postln=True,
|
| 569 |
+
use_layerscale=True,
|
| 570 |
+
use_postln_in_modulation=True,
|
| 571 |
+
normalize_modulator=False,
|
| 572 |
+
),
|
| 573 |
+
'focalnet_H_224_22k_fl4': dict(
|
| 574 |
+
embed_dim=352,
|
| 575 |
+
depths=[ 2, 2, 18, 2 ],
|
| 576 |
+
focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
|
| 577 |
+
focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
|
| 578 |
+
use_conv_embed=True,
|
| 579 |
+
use_postln=True,
|
| 580 |
+
use_postln_in_modulation=True,
|
| 581 |
+
use_layerscale=True,
|
| 582 |
+
normalize_modulator=False,
|
| 583 |
+
),
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
kw_cgf = model_para_dict[modelname]
|
| 587 |
+
kw_cgf.update(kw)
|
| 588 |
+
model = FocalNet(**kw_cgf)
|
| 589 |
+
return model
|
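
For orientation, a minimal usage sketch (not part of the diff): building one of the registered FocalNet variants and inspecting its multi-scale outputs. The model name, overrides, and input size are illustrative, and no pretrained weights are loaded here.

import torch

backbone = build_focalnet('focalnet_L_384_22k', focal_levels=3, focal_windows=5)
backbone.eval()
with torch.no_grad():
    feats, x_emb = backbone(torch.randn(1, 3, 384, 384))
# one (N, C, H, W) feature map per stage listed in out_indices,
# plus the patch-embedded input x_emb
for f in feats:
    print(f.shape)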
RDNet-main/RDNet-main/models/arch/modules_sig.py
ADDED
@@ -0,0 +1,304 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None


class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):
    def __init__(self, dim, expand_dim, out_dim, kernel_size=3, layer_scale_init_value=1e-6, drop_path=0.):
        super().__init__()
        drop_out_rate = 0.
        dw_channel = expand_dim
        self.conv1 = nn.Conv2d(in_channels=dim, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        # note: padding=1 assumes the default kernel_size of 3
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=kernel_size, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = expand_dim
        self.conv4 = nn.Conv2d(in_channels=dim, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=out_dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(dim)
        self.norm2 = LayerNorm2d(dim)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.ones((1, dim, 1, 1)) * layer_scale_init_value, requires_grad=True)
        self.gamma = nn.Parameter(torch.ones((1, dim, 1, 1)) * layer_scale_init_value, requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class UpSampleConvnext(nn.Module):
    def __init__(self, ratio, inchannel, outchannel):
        super().__init__()
        self.ratio = ratio
        self.channel_reschedule = nn.Sequential(
            # LayerNorm(inchannel, eps=1e-6, data_format="channels_last"),
            nn.Linear(inchannel, outchannel),
            LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
        self.upsample = nn.Upsample(scale_factor=2 ** ratio, mode='bilinear')

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.channel_reschedule(x)
        x = x.permute(0, 3, 1, 2)

        return self.upsample(x)


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first", elementwise_affine=True):
        super().__init__()
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if self.elementwise_affine:
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class ConvNextBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        in_channel (int): Number of input channels.
        hidden_dim (int): Number of hidden channels in the pointwise expansion.
        out_channel (int): Number of output channels.
        kernel_size (int): Kernel size of the depthwise conv. Default: 3.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        drop_path (float): Stochastic depth rate. Default: 0.0.
    """
    def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=in_channel)  # depthwise conv
        self.norm = nn.LayerNorm(in_channel, eps=1e-6)
        self.pwconv1 = nn.Linear(in_channel, hidden_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, out_channel)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class Decoder(nn.Module):
    def __init__(self, depth=[2, 2, 2, 2], dim=[112, 72, 40, 24], block_type=None, kernel_size=3) -> None:
        super().__init__()
        self.depth = depth
        self.dim = dim
        self.block_type = block_type
        self._build_decode_layer(dim, depth, kernel_size)
        self.pixelshuffle = nn.PixelShuffle(2)
        # self.star_relu = StarReLU()
        self.projback_ = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=2 ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(2)
        )
        self.projback_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=2 ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(2)
        )

    def _build_decode_layer(self, dim, depth, kernel_size):
        normal_layers = nn.ModuleList()
        upsample_layers = nn.ModuleList()
        proj_layers = nn.ModuleList()

        norm_layer = LayerNorm

        # stages 0-2: decoding path for the clean (transmission) branch
        for i in range(1, len(dim)):
            module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*module))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i - 1], dim[i], 1, 1),
                norm_layer(dim[i]),
                # StarReLU()
                nn.GELU()
            ))
        # stages 3-5: decoding path for the reflection branch
        for i in range(1, len(dim)):
            module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*module))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i - 1], dim[i], 1, 1),
                norm_layer(dim[i]),
            ))
        self.normal_layers = normal_layers
        self.upsample_layers = upsample_layers
        self.proj_layers = proj_layers

    def _forward_stage(self, stage, x):
        x = self.proj_layers[stage](x)
        x = self.upsample_layers[stage](x)
        return self.normal_layers[stage](x)

    def forward(self, c3, c2, c1, c0):
        c0_clean, c0_ref = c0, c0
        c1_clean, c1_ref = c1, c1
        c2_clean, c2_ref = c2, c2
        c3_clean, c3_ref = c3, c3
        x_clean = self._forward_stage(0, c3_clean) * c2_clean
        x_clean = self._forward_stage(1, x_clean) * c1_clean
        x_clean = self._forward_stage(2, x_clean) * c0_clean
        x_clean = self.projback_(x_clean)

        x_ref = self._forward_stage(3, c3_ref) * c2_ref
        x_ref = self._forward_stage(4, x_ref) * c1_ref
        x_ref = self._forward_stage(5, x_ref) * c0_ref
        x_ref = self.projback_2(x_ref)

        x = torch.cat((x_clean, x_ref), dim=1)
        return x


class SimDecoder(nn.Module):
    def __init__(self, in_channel, encoder_stride) -> None:
        super().__init__()
        self.projback = nn.Sequential(
            LayerNorm(in_channel),
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=encoder_stride ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(encoder_stride),
        )

    def forward(self, c3):
        return self.projback(c3)


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """
    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias
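
A small shape sketch (illustrative, not part of the diff): SimpleGate halves the channel count by construction, so NAFBlock needs out_dim == dim for its residual additions to line up.

import torch

gate = SimpleGate()
print(gate(torch.randn(2, 48, 8, 8)).shape)   # torch.Size([2, 24, 8, 8])

blk = NAFBlock(dim=24, expand_dim=48, out_dim=24)
print(blk(torch.randn(2, 24, 32, 32)).shape)  # torch.Size([2, 24, 32, 32])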
RDNet-main/RDNet-main/models/arch/reverse_function.py
ADDED
@@ -0,0 +1,153 @@
import torch
from typing import Any, List, Tuple


def get_gpu_states(fwd_gpu_devices) -> List[torch.Tensor]:
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_states


def get_gpu_device(*args):
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
    return fwd_gpu_devices


def set_device_states(fwd_cpu_state, devices, states) -> None:
    torch.set_rng_state(fwd_cpu_state)
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)


def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = True
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def get_cpu_and_gpu_states(gpu_devices):
    return torch.get_rng_state(), get_gpu_states(gpu_devices)


class ReverseFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_functions, alpha, *args):
        l0, l1, l2, l3 = run_functions
        alpha0, alpha1, alpha2, alpha3 = alpha
        ctx.run_functions = run_functions
        ctx.alpha = alpha
        ctx.preserve_rng_state = True

        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                                   "dtype": torch.get_autocast_cpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}

        assert len(args) == 5
        [x, c0, c1, c2, c3] = args
        if type(c0) == int:
            ctx.first_col = True
        else:
            ctx.first_col = False
        with torch.no_grad():
            gpu_devices = get_gpu_device(*args)
            ctx.gpu_devices = gpu_devices
            ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
            c0 = l0(x, c1) + c0 * alpha0
            ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
            c1 = l1(c0, c2) + c1 * alpha1
            ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
            c2 = l2(c1, c3) + c2 * alpha2
            ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
            c3 = l3(c2, None) + c3 * alpha3
        ctx.save_for_backward(x, c0, c1, c2, c3)
        return x, c0, c1, c2, c3

    @staticmethod
    def backward(ctx, *grad_outputs):
        x, c0, c1, c2, c3 = ctx.saved_tensors
        l0, l1, l2, l3 = ctx.run_functions
        alpha0, alpha1, alpha2, alpha3 = ctx.alpha
        gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
        (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))

        with torch.enable_grad(), \
                torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):

            g3_up = g3_right
            g3_left = g3_up * alpha3  ## shortcut
            set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
            oup3 = l3(c2, None)
            torch.autograd.backward(oup3, g3_up, retain_graph=True)
            with torch.no_grad():
                c3_left = (1 / alpha3) * (c3 - oup3)  ## feature reverse
            g2_up = g2_right + c2.grad
            g2_left = g2_up * alpha2  ## shortcut

            (c3_left,) = detach_and_grad((c3_left,))
            set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
            oup2 = l2(c1, c3_left)
            torch.autograd.backward(oup2, g2_up, retain_graph=True)
            c3_left.requires_grad = False
            cout3 = c3_left * alpha3  ## alpha3 update
            torch.autograd.backward(cout3, g3_up)

            with torch.no_grad():
                c2_left = (1 / alpha2) * (c2 - oup2)  ## feature reverse
            g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
            g1_up = g1_right + c1.grad
            g1_left = g1_up * alpha1  ## shortcut

            (c2_left,) = detach_and_grad((c2_left,))
            set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
            oup1 = l1(c0, c2_left)
            torch.autograd.backward(oup1, g1_up, retain_graph=True)
            c2_left.requires_grad = False
            cout2 = c2_left * alpha2  ## alpha2 update
            torch.autograd.backward(cout2, g2_up)

            with torch.no_grad():
                c1_left = (1 / alpha1) * (c1 - oup1)  ## feature reverse
            g0_up = g0_right + c0.grad
            g0_left = g0_up * alpha0  ## shortcut
            g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left  ## Fusion

            (c1_left,) = detach_and_grad((c1_left,))
            set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
            oup0 = l0(x, c1_left)
            torch.autograd.backward(oup0, g0_up, retain_graph=True)
            c1_left.requires_grad = False
            cout1 = c1_left * alpha1  ## alpha1 update
            torch.autograd.backward(cout1, g1_up)

            with torch.no_grad():
                c0_left = (1 / alpha0) * (c0 - oup0)  ## feature reverse
            gx_up = x.grad  ## Fusion
            g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left  ## Fusion
            c0_left.requires_grad = False
            cout0 = c0_left * alpha0  ## alpha0 update
            torch.autograd.backward(cout0, g0_up)

        if ctx.first_col:
            return None, None, gx_up, None, None, None, None
        else:
            return None, None, gx_up, g0_left, g1_left, g2_left, g3_left
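
The identity this Function exploits, as a toy sketch (illustrative, not part of the diff): each column computes c_new = l(x, c_other) + alpha * c_old, so the backward pass can recover c_old exactly via the "feature reverse" step instead of storing activations.

import torch

alpha = 0.5
l = lambda x, c: torch.tanh(x + c)              # stand-in for a column block
x, c_other, c_old = torch.randn(3, 4, 4).unbind(0)

c_new = l(x, c_other) + alpha * c_old           # forward update
c_rec = (c_new - l(x, c_other)) / alpha         # "feature reverse" in backward
print(torch.allclose(c_old, c_rec, atol=1e-6))  # True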
RDNet-main/RDNet-main/models/arch/vgg.py
ADDED
@@ -0,0 +1,90 @@
from collections import namedtuple

import torch
from torchvision import models


class Vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3)
        return out


class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X, indices=None):
        if indices is None:
            indices = [2, 7, 12, 21, 30]
        out = []
        # run the feature stack layer by layer, collecting the requested outputs
        for i in range(indices[-1]):
            X = self.vgg_pretrained_features[i](X)
            if (i + 1) in indices:
                out.append(X)

        return out


if __name__ == '__main__':
    vgg = Vgg19()
    import ipdb

    ipdb.set_trace()
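
Illustrative usage (not part of the diff; downloads torchvision's pretrained VGG-19 weights on first run): the five default feature maps have 64 + 128 + 256 + 512 + 512 = 1472 channels, which is where the hypercolumn width used elsewhere in this repo comes from.

import torch

vgg = Vgg19(requires_grad=False).eval()
with torch.no_grad():
    feats = vgg(torch.randn(1, 3, 224, 224))    # indices default to [2, 7, 12, 21, 30]
print([f.shape[1] for f in feats])              # [64, 128, 256, 512, 512]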
RDNet-main/RDNet-main/models/base_model.py
ADDED
@@ -0,0 +1,71 @@
import os
import torch
import util.util as util


class BaseModel:
    def name(self):
        return self.__class__.__name__.lower()

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        last_split = opt.checkpoints_dir.split('/')[-1]
        if opt.resume and last_split != 'checkpoints' and (last_split != opt.name or opt.supp_eval):
            self.save_dir = opt.checkpoints_dir
            self.model_save_dir = os.path.join(opt.checkpoints_dir.replace(opt.checkpoints_dir.split('/')[-1], ''),
                                               opt.name)
        else:
            self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
            self.model_save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self._count = 0

    def set_input(self, input):
        self.input = input

    def forward(self, mode='train'):
        pass

    # used at test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def print_optimizer_param(self):
        print(self.optimizers[-1])

    def save(self, label=None):
        epoch = self.epoch
        iterations = self.iterations

        if label is None:
            model_name = os.path.join(self.model_save_dir, self.opt.name + '_%03d_%08d.pt' % (epoch, iterations))
        else:
            model_name = os.path.join(self.model_save_dir, self.opt.name + '_' + label + '.pt')

        torch.save(self.state_dict(), model_name)

    def save_eval(self, label=None):
        model_name = os.path.join(self.model_save_dir, label + '.pt')

        torch.save(self.state_dict_eval(), model_name)

    def _init_optimizer(self, optimizers):
        self.optimizers = optimizers
        for optimizer in self.optimizers:
            util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
            util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)
RDNet-main/RDNet-main/models/cls_model_eval_nocls_reg.py
ADDED
@@ -0,0 +1,517 @@
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from models.losses import DINOLoss
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
from ema_pytorch import EMA
|
| 9 |
+
from models.arch.classifier import PretrainedConvNext
|
| 10 |
+
import util.util as util
|
| 11 |
+
import util.index as index
|
| 12 |
+
import models.networks as networks
|
| 13 |
+
import models.losses as losses
|
| 14 |
+
from models import arch
|
| 15 |
+
#from models.arch.dncnn import effnetv2_s
|
| 16 |
+
from .base_model import BaseModel
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from os.path import join
|
| 19 |
+
#from torchviz import make_dot
|
| 20 |
+
from models.arch.RDnet_ import FullNet_NLP
|
| 21 |
+
import timm
|
| 22 |
+
|
| 23 |
+
def tensor2im(image_tensor, imtype=np.uint8):
|
| 24 |
+
image_tensor = image_tensor.detach()
|
| 25 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
| 26 |
+
image_numpy = np.clip(image_numpy, 0, 1)
|
| 27 |
+
if image_numpy.shape[0] == 1:
|
| 28 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
| 29 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
|
| 30 |
+
# image_numpy = image_numpy.astype(imtype)
|
| 31 |
+
return image_numpy
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class EdgeMap(nn.Module):
|
| 35 |
+
def __init__(self, scale=1):
|
| 36 |
+
super(EdgeMap, self).__init__()
|
| 37 |
+
self.scale = scale
|
| 38 |
+
self.requires_grad = False
|
| 39 |
+
|
| 40 |
+
def forward(self, img):
|
| 41 |
+
img = img / self.scale
|
| 42 |
+
|
| 43 |
+
N, C, H, W = img.shape
|
| 44 |
+
gradX = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device)
|
| 45 |
+
gradY = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device)
|
| 46 |
+
|
| 47 |
+
gradx = (img[..., 1:, :] - img[..., :-1, :]).abs().sum(dim=1, keepdim=True)
|
| 48 |
+
grady = (img[..., 1:] - img[..., :-1]).abs().sum(dim=1, keepdim=True)
|
| 49 |
+
|
| 50 |
+
gradX[..., :-1, :] += gradx
|
| 51 |
+
gradX[..., 1:, :] += gradx
|
| 52 |
+
gradX[..., 1:-1, :] /= 2
|
| 53 |
+
|
| 54 |
+
gradY[..., :-1] += grady
|
| 55 |
+
gradY[..., 1:] += grady
|
| 56 |
+
gradY[..., 1:-1] /= 2
|
| 57 |
+
|
| 58 |
+
# edge = (gradX + gradY) / 2
|
| 59 |
+
edge = (gradX + gradY)
|
| 60 |
+
|
| 61 |
+
return edge
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class YTMTNetBase(BaseModel):
|
| 65 |
+
def _init_optimizer(self, optimizers):
|
| 66 |
+
self.optimizers = optimizers
|
| 67 |
+
for optimizer in self.optimizers:
|
| 68 |
+
util.set_opt_param(optimizer, 'initial_lr', self.opt.lr)
|
| 69 |
+
util.set_opt_param(optimizer, 'weight_decay', self.opt.wd)
|
| 70 |
+
|
| 71 |
+
def set_input(self, data, mode='train'):
|
| 72 |
+
target_t = None
|
| 73 |
+
target_r = None
|
| 74 |
+
data_name = None
|
| 75 |
+
identity = False
|
| 76 |
+
mode = mode.lower()
|
| 77 |
+
if mode == 'train':
|
| 78 |
+
input, target_t, target_r = data['input'], data['target_t'], data['target_r']
|
| 79 |
+
elif mode == 'eval':
|
| 80 |
+
input, target_t, target_r, data_name = data['input'], data['target_t'], data['target_r'], data['fn']
|
| 81 |
+
elif mode == 'test':
|
| 82 |
+
input, data_name = data['input'], data['fn']
|
| 83 |
+
else:
|
| 84 |
+
raise NotImplementedError('Mode [%s] is not implemented' % mode)
|
| 85 |
+
|
| 86 |
+
if len(self.gpu_ids) > 0: # transfer data into gpu
|
| 87 |
+
input = input.to(device=self.gpu_ids[0])
|
| 88 |
+
if target_t is not None:
|
| 89 |
+
target_t = target_t.to(device=self.gpu_ids[0])
|
| 90 |
+
if target_r is not None:
|
| 91 |
+
target_r = target_r.to(device=self.gpu_ids[0])
|
| 92 |
+
|
| 93 |
+
self.input = input
|
| 94 |
+
self.identity = identity
|
| 95 |
+
self.input_edge = self.edge_map(self.input)
|
| 96 |
+
self.target_t = target_t
|
| 97 |
+
self.target_r = target_r
|
| 98 |
+
self.data_name = data_name
|
| 99 |
+
|
| 100 |
+
self.issyn = False if 'real' in data else True
|
| 101 |
+
self.aligned = False if 'unaligned' in data else True
|
| 102 |
+
|
| 103 |
+
if target_t is not None:
|
| 104 |
+
self.target_edge = self.edge_map(self.target_t)
|
| 105 |
+
|
| 106 |
+
def eval(self, data, savedir=None, suffix=None, pieapp=None):
|
| 107 |
+
self._eval()
|
| 108 |
+
self.set_input(data, 'eval')
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
self.forward_eval()
|
| 111 |
+
|
| 112 |
+
output_i = tensor2im(self.output_j[6])
|
| 113 |
+
output_j = tensor2im(self.output_j[7])
|
| 114 |
+
target = tensor2im(self.target_t)
|
| 115 |
+
target_r = tensor2im(self.target_r)
|
| 116 |
+
|
| 117 |
+
if self.aligned:
|
| 118 |
+
res = index.quality_assess(output_i, target)
|
| 119 |
+
else:
|
| 120 |
+
res = {}
|
| 121 |
+
|
| 122 |
+
if savedir is not None:
|
| 123 |
+
if self.data_name is not None:
|
| 124 |
+
name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
|
| 125 |
+
savedir = join(savedir, suffix, name)
|
| 126 |
+
os.makedirs(savedir, exist_ok=True)
|
| 127 |
+
Image.fromarray(output_i.astype(np.uint8)).save(
|
| 128 |
+
join(savedir, '{}_t.png'.format(self.opt.name)))
|
| 129 |
+
Image.fromarray(output_j.astype(np.uint8)).save(
|
| 130 |
+
join(savedir, '{}_r.png'.format(self.opt.name)))
|
| 131 |
+
Image.fromarray(target.astype(np.uint8)).save(join(savedir, 't_label.png'))
|
| 132 |
+
Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, 'm_input.png'))
|
| 133 |
+
else:
|
| 134 |
+
if not os.path.exists(join(savedir, 'transmission_layer')):
|
| 135 |
+
os.makedirs(join(savedir, 'transmission_layer'))
|
| 136 |
+
os.makedirs(join(savedir, 'blended'))
|
| 137 |
+
Image.fromarray(target.astype(np.uint8)).save(
|
| 138 |
+
join(savedir, 'transmission_layer', str(self._count) + '.png'))
|
| 139 |
+
Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(
|
| 140 |
+
join(savedir, 'blended', str(self._count) + '.png'))
|
| 141 |
+
self._count += 1
|
| 142 |
+
|
| 143 |
+
return res
|
| 144 |
+
|
| 145 |
+
def test(self, data, savedir=None):
|
| 146 |
+
# only the 1st input of the whole minibatch would be evaluated
|
| 147 |
+
self._eval()
|
| 148 |
+
self.set_input(data, 'test')
|
| 149 |
+
|
| 150 |
+
if self.data_name is not None and savedir is not None:
|
| 151 |
+
name = os.path.splitext(os.path.basename(self.data_name[0]))[0]
|
| 152 |
+
if not os.path.exists(join(savedir, name)):
|
| 153 |
+
os.makedirs(join(savedir, name))
|
| 154 |
+
|
| 155 |
+
if os.path.exists(join(savedir, name, '{}.png'.format(self.opt.name))):
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
output_i, output_j = self.forward()
|
| 160 |
+
output_i = tensor2im(output_i)
|
| 161 |
+
output_j = tensor2im(output_j)
|
| 162 |
+
if self.data_name is not None and savedir is not None:
|
| 163 |
+
Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name, '{}_l.png'.format(self.opt.name)))
|
| 164 |
+
Image.fromarray(output_j.astype(np.uint8)).save(join(savedir, name, '{}_r.png'.format(self.opt.name)))
|
| 165 |
+
Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, name, 'm_input.png'))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ClsModel(YTMTNetBase):
|
| 169 |
+
def name(self):
|
| 170 |
+
return 'ytmtnet'
|
| 171 |
+
|
| 172 |
+
def __init__(self):
|
| 173 |
+
self.epoch = 0
|
| 174 |
+
self.iterations = 0
|
| 175 |
+
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 176 |
+
self.net_c = None
|
| 177 |
+
|
| 178 |
+
def print_network(self):
|
| 179 |
+
print('--------------------- Model ---------------------')
|
| 180 |
+
print('##################### NetG #####################')
|
| 181 |
+
networks.print_network(self.net_i)
|
| 182 |
+
if self.isTrain and self.opt.lambda_gan > 0:
|
| 183 |
+
print('##################### NetD #####################')
|
| 184 |
+
networks.print_network(self.netD)
|
| 185 |
+
|
| 186 |
+
def _eval(self):
|
| 187 |
+
self.net_i.eval()
|
| 188 |
+
self.net_c.eval()
|
| 189 |
+
|
| 190 |
+
def _train(self):
|
| 191 |
+
self.net_i.train()
|
| 192 |
+
self.net_c.eval()
|
| 193 |
+
def initialize(self, opt):
|
| 194 |
+
self.opt=opt
|
| 195 |
+
BaseModel.initialize(self, opt)
|
| 196 |
+
|
| 197 |
+
in_channels = 3
|
| 198 |
+
self.vgg = None
|
| 199 |
+
|
| 200 |
+
if opt.hyper:
|
| 201 |
+
self.vgg = losses.Vgg19(requires_grad=False).to(self.device)
|
| 202 |
+
in_channels += 1472
|
| 203 |
+
channels = [64, 128, 256, 512]
|
| 204 |
+
layers = [2, 2, 4, 2]
|
| 205 |
+
num_subnet = opt.num_subnet
|
| 206 |
+
self.net_c = PretrainedConvNext("convnext_small_in22k").cuda()
|
| 207 |
+
|
| 208 |
+
self.net_c.load_state_dict(torch.load('pretrained/cls_model.pth')['icnn'])
|
| 209 |
+
|
| 210 |
+
self.net_i = FullNet_NLP(channels, layers, num_subnet, opt.loss_col,num_classes=1000, drop_path=0,save_memory=True, inter_supv=True, head_init_scale=None, kernel_size=3).to(self.device)
|
| 211 |
+
|
| 212 |
+
self.edge_map = EdgeMap(scale=1).to(self.device)
|
| 213 |
+
|
| 214 |
+
if self.isTrain:
|
| 215 |
+
self.loss_dic = losses.init_loss(opt, self.Tensor)
|
| 216 |
+
vggloss = losses.ContentLoss()
|
| 217 |
+
vggloss.initialize(losses.VGGLoss(self.vgg))
|
| 218 |
+
self.loss_dic['t_vgg'] = vggloss
|
| 219 |
+
|
| 220 |
+
cxloss = losses.ContentLoss()
|
| 221 |
+
if opt.unaligned_loss == 'vgg':
|
| 222 |
+
cxloss.initialize(losses.VGGLoss(self.vgg, weights=[0.1], indices=[opt.vgg_layer]))
|
| 223 |
+
elif opt.unaligned_loss == 'ctx':
|
| 224 |
+
cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1], indices=[8, 13, 22]))
|
| 225 |
+
elif opt.unaligned_loss == 'mse':
|
| 226 |
+
cxloss.initialize(nn.MSELoss())
|
| 227 |
+
elif opt.unaligned_loss == 'ctx_vgg':
|
| 228 |
+
cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1, 0.1, 0.1, 0.1], indices=[8, 13, 22, 31],
|
| 229 |
+
criterions=[losses.CX_loss] * 3 + [nn.L1Loss()]))
|
| 230 |
+
else:
|
| 231 |
+
raise NotImplementedError
|
| 232 |
+
self.scaler=torch.cuda.amp.GradScaler()
|
| 233 |
+
with torch.autocast(device_type='cuda',dtype=torch.float16):
|
| 234 |
+
self.dinoloss=DINOLoss()
|
| 235 |
+
self.loss_dic['t_cx'] = cxloss
|
| 236 |
+
|
| 237 |
+
self.optimizer_G = torch.optim.Adam(self.net_i.parameters(),
|
| 238 |
+
lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.wd)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
self._init_optimizer([self.optimizer_G])
|
| 242 |
+
|
| 243 |
+
if opt.resume:
|
| 244 |
+
self.load(self, opt.resume_epoch)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def backward_D(self):
|
| 248 |
+
loss_D=[]
|
| 249 |
+
weight=self.opt.weight_loss
|
| 250 |
+
for p in self.netD.parameters():
|
| 251 |
+
p.requires_grad = True
|
| 252 |
+
for i in range(4):
|
| 253 |
+
loss_D_1, pred_fake_1, pred_real_1 = self.loss_dic['gan'].get_loss(
|
| 254 |
+
self.netD, self.input, self.output_j[2*i], self.target_t)
|
| 255 |
+
loss_D.append(loss_D_1*weight)
|
| 256 |
+
weight+=self.opt.weight_loss
|
| 257 |
+
loss_sum=sum(loss_D)
|
| 258 |
+
|
| 259 |
+
self.loss_D, self.pred_fake, self.pred_real = (loss_sum, pred_fake_1, pred_real_1)
|
| 260 |
+
|
| 261 |
+
(self.loss_D * self.opt.lambda_gan).backward(retain_graph=True)
|
| 262 |
+
|
| 263 |
+
def get_loss(self, out_l, out_r):
|
| 264 |
+
loss_G_GAN_sum=[]
|
| 265 |
+
loss_icnn_pixel_sum=[]
|
| 266 |
+
loss_rcnn_pixel_sum=[]
|
| 267 |
+
loss_icnn_vgg_sum=[]
|
| 268 |
+
weight=self.opt.weight_loss
|
| 269 |
+
for i in range(self.opt.loss_col):
|
| 270 |
+
out_r_clean=out_r[2*i]
|
| 271 |
+
out_r_reflection=out_r[2*i+1]
|
| 272 |
+
if i != self.opt.loss_col -1:
|
| 273 |
+
loss_G_GAN = 0
|
| 274 |
+
loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
|
| 275 |
+
loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
|
| 276 |
+
loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg
|
| 277 |
+
else:
|
| 278 |
+
if self.opt.lambda_gan>0:
|
| 279 |
+
|
| 280 |
+
loss_G_GAN=0
|
| 281 |
+
else:
|
| 282 |
+
loss_G_GAN=0
|
| 283 |
+
loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss(out_r_clean, self.target_t)
|
| 284 |
+
loss_rcnn_pixel = self.loss_dic['r_pixel'].get_loss(out_r_reflection, self.target_r) * 1.5 * self.opt.r_pixel_weight
|
| 285 |
+
loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss(out_r_clean, self.target_t) * self.opt.lambda_vgg
|
| 286 |
+
|
| 287 |
+
loss_G_GAN_sum.append(loss_G_GAN*weight)
|
| 288 |
+
loss_icnn_pixel_sum.append(loss_icnn_pixel*weight)
|
| 289 |
+
loss_rcnn_pixel_sum.append(loss_rcnn_pixel*weight)
|
| 290 |
+
loss_icnn_vgg_sum.append(loss_icnn_vgg*weight)
|
| 291 |
+
weight=weight+self.opt.weight_loss
|
| 292 |
+
return sum(loss_G_GAN_sum), sum(loss_icnn_pixel_sum), sum(loss_rcnn_pixel_sum), sum(loss_icnn_vgg_sum)
|
| 293 |
+
|
| 294 |
+
def backward_G(self):
|
| 295 |
+
|
| 296 |
+
self.loss_G_GAN,self.loss_icnn_pixel, self.loss_rcnn_pixel, \
|
| 297 |
+
self.loss_icnn_vgg = self.get_loss(self.output_i, self.output_j)
|
| 298 |
+
|
| 299 |
+
self.loss_exclu = self.exclusion_loss(self.output_i, self.output_j, 3)
|
| 300 |
+
|
| 301 |
+
self.loss_recons = self.loss_dic['recons'](self.output_i, self.output_j, self.input) * 0.2
|
| 302 |
+
|
| 303 |
+
self.loss_G = self.loss_G_GAN +self.loss_icnn_pixel + self.loss_rcnn_pixel + \
|
| 304 |
+
self.loss_icnn_vgg
|
| 305 |
+
self.scaler.scale(self.loss_G).backward()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def hyper_column(self, input_img):
|
| 310 |
+
hypercolumn = self.vgg(input_img)
|
| 311 |
+
_, C, H, W = input_img.shape
|
| 312 |
+
hypercolumn = [F.interpolate(feature.detach(), size=(H, W), mode='bilinear', align_corners=False) for
|
| 313 |
+
feature in hypercolumn]
|
| 314 |
+
input_i = [input_img]
|
| 315 |
+
input_i.extend(hypercolumn)
|
| 316 |
+
input_i = torch.cat(input_i, dim=1)
|
| 317 |
+
return input_i
|
| 318 |
+
|
| 319 |
+
def forward(self):
|
| 320 |
+
# without edge
|
| 321 |
+
|
| 322 |
+
self.output_j=[]
|
| 323 |
+
input_i = self.input
|
| 324 |
+
if self.vgg is not None:
|
| 325 |
+
input_i = self.hyper_column(input_i)
|
| 326 |
+
with torch.no_grad():
|
| 327 |
+
ipt = self.net_c(input_i)
|
| 328 |
+
output_i, output_j = self.net_i(input_i,ipt,prompt=True)
|
| 329 |
+
self.output_i = output_i
|
| 330 |
+
for i in range(self.opt.loss_col):
|
| 331 |
+
out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
|
| 332 |
+
self.output_j.append(out_clean)
|
| 333 |
+
self.output_j.append(out_reflection)
|
| 334 |
+
return self.output_i, self.output_j
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
@torch.no_grad()
|
| 338 |
+
def forward_eval(self):
|
| 339 |
+
|
| 340 |
+
self.output_j=[]
|
| 341 |
+
input_i = self.input
|
| 342 |
+
if self.vgg is not None:
|
| 343 |
+
input_i = self.hyper_column(input_i)
|
| 344 |
+
ipt = self.net_c(input_i)
|
| 345 |
+
|
| 346 |
+
output_i, output_j = self.net_i(input_i,ipt,prompt=True)
|
| 347 |
+
self.output_i = output_i #alpha * output_i + beta
|
| 348 |
+
for i in range(self.opt.loss_col):
|
| 349 |
+
out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
|
| 350 |
+
self.output_j.append(out_clean)
|
| 351 |
+
self.output_j.append(out_reflection)
|
| 352 |
+
return self.output_i, self.output_j
|
| 353 |
+
|
| 354 |
+
def optimize_parameters(self):
|
| 355 |
+
self._train()
|
| 356 |
+
self.forward()
|
| 357 |
+
self.optimizer_G.zero_grad()
|
| 358 |
+
self.backward_G()
|
| 359 |
+
self.optimizer_G.step()
|
| 360 |
+
|
| 361 |
+
def return_output(self):
|
| 362 |
+
output_clean = self.output_j[1]
|
| 363 |
+
output_reflection = self.output_j[0]
|
| 364 |
+
output_clean = tensor2im(output_clean).astype(np.uint8)
|
| 365 |
+
output_reflection = tensor2im(output_reflection).astype(np.uint8)
|
| 366 |
+
input=tensor2im(self.input)
|
| 367 |
+
return output_clean,output_reflection,input
|

    def exclusion_loss(self, img_T, img_R, level=3, eps=1e-6):
        loss_gra = []
        weight = 0.25
        for i in range(4):
            grad_x_loss = []
            grad_y_loss = []
            img_T = self.output_j[2 * i]
            img_R = self.output_j[2 * i + 1]
            for l in range(level):
                grad_x_T, grad_y_T = self.compute_grad(img_T)
                grad_x_R, grad_y_R = self.compute_grad(img_R)

                alphax = (2.0 * torch.mean(torch.abs(grad_x_T))) / (torch.mean(torch.abs(grad_x_R)) + eps)
                alphay = (2.0 * torch.mean(torch.abs(grad_y_T))) / (torch.mean(torch.abs(grad_y_R)) + eps)

                # 2 * sigmoid(x) - 1 == tanh(x / 2): a tanh squashing of the gradients
                gradx1_s = (torch.sigmoid(grad_x_T) * 2) - 1
                grady1_s = (torch.sigmoid(grad_y_T) * 2) - 1
                gradx2_s = (torch.sigmoid(grad_x_R * alphax) * 2) - 1
                grady2_s = (torch.sigmoid(grad_y_R * alphay) * 2) - 1

                grad_x_loss.append((torch.mean(torch.mul(gradx1_s.pow(2), gradx2_s.pow(2))) + eps) ** 0.25)
                grad_y_loss.append((torch.mean(torch.mul(grady1_s.pow(2), grady2_s.pow(2))) + eps) ** 0.25)

                img_T = F.interpolate(img_T, scale_factor=0.5, mode='bilinear')
                img_R = F.interpolate(img_R, scale_factor=0.5, mode='bilinear')
            loss_gradxy = torch.sum(sum(grad_x_loss) / 3) + torch.sum(sum(grad_y_loss) / 3)
            loss_gra.append(loss_gradxy * weight)
            weight += 0.25

        return sum(loss_gra) / 2
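    # Math behind exclusion_loss (explanatory comment, not in the original
    # upload). Per pyramid level, with T the transmission and R the reflection
    # estimate, this is the gradient-exclusion penalty of Zhang et al.,
    # "Single Image Reflection Separation with Perceptual Losses" (CVPR 2018):
    #
    #   Psi(T, R) = tanh(grad T / 2)^2 * tanh(alpha * grad R / 2)^2
    #   term      = (mean(Psi) + eps)^(1/4),
    #   alpha     = 2 * mean|grad T| / mean|grad R|
    #
    # pushing the two layers not to share strong edges at the same pixel.
    # The i-th column pair is weighted 0.25 * (i + 1), so later (more refined)
    # predictions are penalized more heavily.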

    def contain_loss(self, img_T, img_R, img_I, eps=1e-6):
        pix_num = np.prod(img_I.shape)
        predict_tx, predict_ty = self.compute_grad(img_T)
        predict_rx, predict_ry = self.compute_grad(img_R)
        input_x, input_y = self.compute_grad(img_I)

        out = torch.norm(predict_tx / (input_x + eps), 2) ** 2 + \
              torch.norm(predict_ty / (input_y + eps), 2) ** 2 + \
              torch.norm(predict_rx / (input_x + eps), 2) ** 2 + \
              torch.norm(predict_ry / (input_y + eps), 2) ** 2

        return out / pix_num
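    # contain_loss penalizes gradients of the predicted layers relative to the
    # input's gradients: with I the blended input, it is
    #
    #   (||grad T / grad I||_2^2 + ||grad R / grad I||_2^2) / #pixels
    #
    # with eps keeping the division finite, discouraging either layer from
    # placing edges where the input itself is flat. (Explanatory comment, not
    # in the original upload.)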

    def compute_grad(self, img):
        gradx = img[:, :, 1:, :] - img[:, :, :-1, :]
        grady = img[:, :, :, 1:] - img[:, :, :, :-1]
        return gradx, grady
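    # compute_grad returns forward differences: for an (N, C, H, W) tensor the
    # outputs have shapes (N, C, H-1, W) and (N, C, H, W-1) respectively,
    # which is why the losses above only pair x-terms with x-terms and
    # y-terms with y-terms. (Explanatory comment, not in the original upload.)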

    def load(self, model, resume_epoch=None):
        icnn_path = model.opt.icnn_path
        state_dict = torch.load(icnn_path)
        model.net_i.load_state_dict(state_dict['icnn'])
        return state_dict
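    # Checkpoint round-trip sketch (assumption: keys mirror state_dict()
    # below; only net_i is restored by load() as written). A full resume
    # would also restore the optimizer and counters, e.g.:
    #
    #   ckpt = torch.load(opt.icnn_path, map_location='cpu')
    #   model.net_i.load_state_dict(ckpt['icnn'])
    #   model.optimizer_G.load_state_dict(ckpt['opt_g'])
    #   model.epoch, model.iterations = ckpt['epoch'], ckpt['iterations']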

    def state_dict(self):
        state_dict = {
            'icnn': self.net_i.state_dict(),
            'opt_g': self.optimizer_G.state_dict(),
            # 'ema': self.ema.state_dict(),
            'epoch': self.epoch, 'iterations': self.iterations
        }

        if self.opt.lambda_gan > 0:
            state_dict.update({
                'opt_d': self.optimizer_D.state_dict(),
                'netD': self.netD.state_dict(),
            })

        return state_dict


class AvgPool2d(nn.Module):
    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out
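    # How the else-branch works (explanatory comment, not in the original
    # upload): s is a 2-D prefix sum (summed-area table), zero-padded by one
    # row/column, so the sum over any k1 x k2 window is recovered from four
    # corner lookups:
    #
    #   sum(window at (i, j)) = s[i+k1, j+k2] - s[i, j+k2] - s[i+k1, j] + s[i, j]
    #
    # turning sliding-window average pooling into O(1) work per output pixel.
    # This class and replace_layers below follow the test-time local converter
    # (TLC) idea used in NAFNet-style restoration models: a network trained on
    # train_size crops keeps a comparable pooling window on larger test images.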


def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            # compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            assert m.output_size == 1
            setattr(model, n, pool)
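# Usage sketch (assumption: not shown in this diff; mirrors how NAFNet applies
# its TLC wrapper). Swap every global nn.AdaptiveAvgPool2d(1) for the
# size-aware AvgPool2d before full-resolution inference; the sizes below are
# illustrative, not taken from this repo:
#
#   net = model.net_i                                   # restoration network
#   replace_layers(net, base_size=(384, 384),
#                  train_size=(1, 3, 224, 224), fast_imp=False)
#   with torch.no_grad():
#       out = net(high_res_input, ipt, prompt=True)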