Spaces:
Runtime error
Runtime error
test
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- .gitignore +3 -0
- JarvisIR/.gitignore +146 -0
- JarvisIR/.gitmodules +7 -0
- JarvisIR/LICENSE +21 -0
- JarvisIR/README.md +148 -0
- JarvisIR/demo_gradio.py +474 -0
- JarvisIR/docs/gradio_demo.md +36 -0
- JarvisIR/docs/sft_training.md +84 -0
- JarvisIR/package/README.md +155 -0
- JarvisIR/package/agent_tools.egg-info/PKG-INFO +7 -0
- JarvisIR/package/agent_tools.egg-info/SOURCES.txt +192 -0
- JarvisIR/package/agent_tools.egg-info/dependency_links.txt +1 -0
- JarvisIR/package/agent_tools.egg-info/top_level.txt +1 -0
- JarvisIR/package/agent_tools/ESRGAN/__init__.py +0 -0
- JarvisIR/package/agent_tools/ESRGAN/inference.py +49 -0
- JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus.yml +188 -0
- JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus_pairdata.yml +150 -0
- JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x2plus.yml +186 -0
- JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x4plus.yml +185 -0
- JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x2plus.yml +145 -0
- JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x4plus.yml +144 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/__init__.py +5 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/__init__.py +11 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/discriminator_arch.py +67 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/srvgg_arch.py +69 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/__init__.py +11 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_dataset.py +192 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_paired_dataset.py +108 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/__init__.py +11 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrgan_model.py +258 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrnet_model.py +188 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/train.py +11 -0
- JarvisIR/package/agent_tools/ESRGAN/realesrgan/utils.py +309 -0
- JarvisIR/package/agent_tools/HVICIDNet/inference.py +65 -0
- JarvisIR/package/agent_tools/HVICIDNet/loss/loss_utils.py +145 -0
- JarvisIR/package/agent_tools/HVICIDNet/loss/losses.py +193 -0
- JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_pris_params.npz +3 -0
- JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_utils.py +559 -0
- JarvisIR/package/agent_tools/HVICIDNet/loss/vgg_arch.py +239 -0
- JarvisIR/package/agent_tools/HVICIDNet/mods.py +153 -0
- JarvisIR/package/agent_tools/HVICIDNet/net/CIDNet.py +129 -0
- JarvisIR/package/agent_tools/HVICIDNet/net/HVI_transform.py +122 -0
- JarvisIR/package/agent_tools/HVICIDNet/net/LCA.py +93 -0
- JarvisIR/package/agent_tools/HVICIDNet/net/transformer_utils.py +71 -0
- JarvisIR/package/agent_tools/HVICIDNet/wavelet.py +65 -0
- JarvisIR/package/agent_tools/IDT/__init__.py +0 -0
- JarvisIR/package/agent_tools/IDT/analyse/cal_rf_bf.py +68 -0
- JarvisIR/package/agent_tools/IDT/configs/daytime_128.yml +55 -0
- JarvisIR/package/agent_tools/IDT/configs/daytime_256.yml +55 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
JarvisIR/package/agent_tools/RIDCP/.eggs/**/*.so filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/.eggs/
|
| 2 |
+
*.so
|
| 3 |
+
JarvisIR/package/agent_tools/RIDCP/.eggs
|
JarvisIR/.gitignore
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ignored folders
|
| 2 |
+
datasets/*
|
| 3 |
+
experiments*
|
| 4 |
+
experiments/*
|
| 5 |
+
# results/*
|
| 6 |
+
tb_logger*
|
| 7 |
+
pretrained_models
|
| 8 |
+
wandb/*
|
| 9 |
+
tmp*/*
|
| 10 |
+
*.sh
|
| 11 |
+
.vscode*
|
| 12 |
+
.github
|
| 13 |
+
# ignored files
|
| 14 |
+
version.py
|
| 15 |
+
|
| 16 |
+
# ignored files with suffix
|
| 17 |
+
*.html
|
| 18 |
+
*.pth
|
| 19 |
+
*.zip
|
| 20 |
+
|
| 21 |
+
# template
|
| 22 |
+
|
| 23 |
+
# Byte-compiled / optimized / DLL files
|
| 24 |
+
__pycache__/
|
| 25 |
+
*.py[cod]
|
| 26 |
+
*$py.class
|
| 27 |
+
|
| 28 |
+
# C extensions
|
| 29 |
+
*.so
|
| 30 |
+
|
| 31 |
+
# Distribution / packaging
|
| 32 |
+
.Python
|
| 33 |
+
build/
|
| 34 |
+
develop-eggs/
|
| 35 |
+
dist/
|
| 36 |
+
downloads/
|
| 37 |
+
eggs/
|
| 38 |
+
.eggs/
|
| 39 |
+
lib/
|
| 40 |
+
lib64/
|
| 41 |
+
parts/
|
| 42 |
+
sdist/
|
| 43 |
+
var/
|
| 44 |
+
wheels/
|
| 45 |
+
*.egg-info/
|
| 46 |
+
.installed.cfg
|
| 47 |
+
*.egg
|
| 48 |
+
MANIFEST
|
| 49 |
+
|
| 50 |
+
# PyInstaller
|
| 51 |
+
# Usually these files are written by a python script from a template
|
| 52 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 53 |
+
*.manifest
|
| 54 |
+
*.spec
|
| 55 |
+
|
| 56 |
+
# Installer logs
|
| 57 |
+
pip-log.txt
|
| 58 |
+
pip-delete-this-directory.txt
|
| 59 |
+
|
| 60 |
+
# Unit test / coverage reports
|
| 61 |
+
htmlcov/
|
| 62 |
+
.tox/
|
| 63 |
+
.coverage
|
| 64 |
+
.coverage.*
|
| 65 |
+
.cache
|
| 66 |
+
nosetests.xml
|
| 67 |
+
coverage.xml
|
| 68 |
+
*.cover
|
| 69 |
+
.hypothesis/
|
| 70 |
+
.pytest_cache/
|
| 71 |
+
|
| 72 |
+
# Translations
|
| 73 |
+
*.mo
|
| 74 |
+
*.pot
|
| 75 |
+
|
| 76 |
+
# Django stuff:
|
| 77 |
+
*.log
|
| 78 |
+
local_settings.py
|
| 79 |
+
db.sqlite3
|
| 80 |
+
|
| 81 |
+
# Flask stuff:
|
| 82 |
+
instance/
|
| 83 |
+
.webassets-cache
|
| 84 |
+
|
| 85 |
+
# Scrapy stuff:
|
| 86 |
+
.scrapy
|
| 87 |
+
|
| 88 |
+
# Sphinx documentation
|
| 89 |
+
docs/_build/
|
| 90 |
+
|
| 91 |
+
# PyBuilder
|
| 92 |
+
target/
|
| 93 |
+
|
| 94 |
+
# Jupyter Notebook
|
| 95 |
+
.ipynb_checkpoints
|
| 96 |
+
|
| 97 |
+
# pyenv
|
| 98 |
+
.python-version
|
| 99 |
+
|
| 100 |
+
# celery beat schedule file
|
| 101 |
+
celerybeat-schedule
|
| 102 |
+
|
| 103 |
+
# SageMath parsed files
|
| 104 |
+
*.sage.py
|
| 105 |
+
|
| 106 |
+
# Environments
|
| 107 |
+
.env
|
| 108 |
+
.venv
|
| 109 |
+
env/
|
| 110 |
+
venv/
|
| 111 |
+
ENV/
|
| 112 |
+
env.bak/
|
| 113 |
+
venv.bak/
|
| 114 |
+
|
| 115 |
+
# Spyder project settings
|
| 116 |
+
.spyderproject
|
| 117 |
+
.spyproject
|
| 118 |
+
|
| 119 |
+
# Rope project settings
|
| 120 |
+
.ropeproject
|
| 121 |
+
|
| 122 |
+
# mkdocs documentation
|
| 123 |
+
/site
|
| 124 |
+
|
| 125 |
+
# mypy
|
| 126 |
+
.mypy_cache/
|
| 127 |
+
|
| 128 |
+
checkpoint
|
| 129 |
+
*.zip
|
| 130 |
+
*.png
|
| 131 |
+
*.jpg
|
| 132 |
+
*.jpeg
|
| 133 |
+
*.json
|
| 134 |
+
.eggs
|
| 135 |
+
*.pth*
|
| 136 |
+
*.pk
|
| 137 |
+
*.pkl
|
| 138 |
+
*.mdb
|
| 139 |
+
*.pt
|
| 140 |
+
*.log
|
| 141 |
+
*.bin
|
| 142 |
+
__pycache__
|
| 143 |
+
Temp
|
| 144 |
+
work_dirs
|
| 145 |
+
src/sft/config
|
| 146 |
+
checkpoints
|
JarvisIR/.gitmodules
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "dependences/BasicSR"]
|
| 2 |
+
path = dependences/BasicSR
|
| 3 |
+
url = https://github.com/XPixelGroup/BasicSR.git
|
| 4 |
+
[submodule "src/sft/xtuner"]
|
| 5 |
+
path = src/sft/xtuner
|
| 6 |
+
url = https://github.com/InternLM/xtuner.git
|
| 7 |
+
branch = v0.1.23
|
JarvisIR/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Yunlong Lin
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
JarvisIR/README.md
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<div align="center">
|
| 3 |
+
<img src="assets/icon.png" alt="JarvisIR Logo" width="100px">
|
| 4 |
+
<h1>[CVPR' 2025] JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration</h1>
|
| 5 |
+
</div>
|
| 6 |
+
|
| 7 |
+
<a href="https://lyl1015.github.io/papers/CVPR2025_JarvisIR.pdf" target="_blank" rel="noopener noreferrer">
|
| 8 |
+
<img src="https://img.shields.io/badge/Paper-JarvisIR-b31b1b" alt="Paper PDF">
|
| 9 |
+
</a>
|
| 10 |
+
<!-- <a href="#"><img src="https://img.shields.io/badge/arXiv-即将发布-b31b1b" alt="arXiv"></a> -->
|
| 11 |
+
<a href="https://cvpr2025-jarvisir.github.io/"><img src="https://img.shields.io/badge/Project-Page-green" alt="Project Page"></a>
|
| 12 |
+
<a href="#"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Coming%20Soon-blue" alt="Demo"></a>
|
| 13 |
+
<a href="https://github.com/LYL1015/JarvisIR?tab=readme-ov-file/"><img src="https://img.shields.io/badge/GitHub-Code-black" alt="Code"></a>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
[Yunlong Lin](https://lyl1015.github.io/)<sup>1*♣</sup>, [Zixu Lin](https://github.com/)<sup>1*♣</sup>, [Haoyu Chen](https://haoyuchen.com/)<sup>2*</sup>, [Panwang Pan](https://paulpanwang.github.io/)<sup>3*</sup>, [Chenxin Li](https://chenxinli001.github.io/)<sup>6</sup>, [Sixiang Chen](https://ephemeral182.github.io/)<sup>2</sup>, [Kairun Wen](https://kairunwen.github.io/)<sup>1</sup>, [Yeying Jin](https://jinyeying.github.io/)<sup>4</sup>, [Wenbo Li](https://fenglinglwb.github.io/)<sup>5†</sup>, [Xinghao Ding](https://scholar.google.com/citations?user=k5hVBfMAAAAJ&hl=zh-CN)<sup>1†</sup>
|
| 17 |
+
|
| 18 |
+
<sup>1</sup>Xiamen University, <sup>2</sup>The Hong Kong University of Science and Technology (Guangzhou), <sup>3</sup>Bytedance's Pico, <sup>4</sup>Tencent, <sup>5</sup>Huawei Noah's Ark Lab, <sup>6</sup>The Chinese University of Hong Kong
|
| 19 |
+
<!-- <sup>*</sup>Equal Contribution <sup>♣</sup>Equal Contribution <sup>†</sup>Corresponding Author -->
|
| 20 |
+
Accepted by CVPR 2025
|
| 21 |
+
|
| 22 |
+
<!-- <div align="center">
|
| 23 |
+
<video width="800" controls>
|
| 24 |
+
<source src="assets/demo.mp4" type="video/mp4">
|
| 25 |
+
Your browser does not support the video tag.
|
| 26 |
+
</video>
|
| 27 |
+
<p>JarvisIR Demo Video: Showcasing image restoration capabilities under various adverse weather conditions</p>
|
| 28 |
+
</div> -->
|
| 29 |
+
https://github.com/user-attachments/assets/d9094fba-e24c-403e-90cb-b3d2f6e48939
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
</div>
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## :postbox: Updates
|
| 36 |
+
<!-- - 2023.12.04: Add an option to speed up the inference process by adjusting the number of denoising steps. -->
|
| 37 |
+
<!-- - 2024.2.9: Release our demo codes and models. Have fun! :yum: -->
|
| 38 |
+
- 2025.4.8: This repo is created.
|
| 39 |
+
|
| 40 |
+
## :diamonds: Overview
|
| 41 |
+
JarvisIR (CVPR 2025) is a VLM-powered agent designed to tackle the challenges of vision-centric perception systems under unpredictable and coupled weather degradations. It leverages the VLM as a controller to manage multiple expert restoration models, enabling robust and autonomous operation in real-world conditions. JarvisIR employs a novel two-stage framework consisting of supervised fine-tuning and human feedback alignment, allowing it to effectively fine-tune on large-scale real-world data in an unsupervised manner. Supported by CleanBench, a comprehensive dataset with 150K synthetic and 80K real instruction-response pairs, JarvisIR demonstrates superior decision-making and restoration capabilities, achieving a 50% improvement in the average of all perception metrics on CleanBench-Real.
|
| 42 |
+
<div align="center">
|
| 43 |
+
<img src="assets/teaser1.png" alt="JarvisIR Teaser" width="800px">
|
| 44 |
+
</div>
|
| 45 |
+
|
| 46 |
+
## :rocket: Method
|
| 47 |
+
|
| 48 |
+
JarvisIR implements an innovative two-stage framework that leverages a Vision-Language Model (VLM) as a controller to manage multiple expert restoration models:
|
| 49 |
+
|
| 50 |
+
1. **Supervised Fine-tuning Stage**: JarvisIR undergoes supervised fine-tuning on synthetic data from CleanBench to enable it to follow user instructions and recognize image degradation. This initial training allows the model to identify various types of image degradation and select appropriate restoration strategies.
|
| 51 |
+
|
| 52 |
+
2. **Human Feedback Alignment Stage**: We further finetune JarvisIR on CleanBench-Real using the MRRHF algorithm to improve system robustness, reduce hallucinations, and enhance generalizability under real-world adverse weather conditions. This stage ensures the model makes decisions that align with human expectations in complex real-world scenarios.
|
| 53 |
+
|
| 54 |
+
The core advantage of JarvisIR lies in its ability to handle multiple complex, coupled weather degradations and provide stable, reliable image inputs for autonomous driving perception systems.
|
| 55 |
+
|
| 56 |
+
<div align="center">
|
| 57 |
+
<img src="assets/framework.png" alt="JarvisIR Method" width="800px">
|
| 58 |
+
<p>Two-stage training framework of JarvisIR: The VLM controller analyzes input images, selects and coordinates expert models for optimal restoration</p>
|
| 59 |
+
</div>
|
| 60 |
+
|
| 61 |
+
## :bar_chart: CleanBench Dataset
|
| 62 |
+
|
| 63 |
+
To support the training and evaluation of JarvisIR, we introduce CleanBench, the first high-quality instruction-following dataset specifically curated for developing intelligent restoration systems. CleanBench contains **150K** synthetic and **80K** real instruction-response pairs, providing a comprehensive foundation for training and evaluating intelligent image restoration systems.
|
| 64 |
+
|
| 65 |
+
### Dataset Construction
|
| 66 |
+
|
| 67 |
+
The CleanBench dataset construction workflow consists of three main steps:
|
| 68 |
+
|
| 69 |
+
1. **Synthesis of Degraded Images**: We generate a diverse set of degraded images by applying various weather conditions and degradation types to clean images, creating realistic scenarios that autonomous driving systems might encounter.
|
| 70 |
+
|
| 71 |
+
2. **Generation of Assessment Reasoning and Optimal Task Sequence**: For each degraded image, we generate detailed assessments of the degradation types present and determine the optimal sequence of restoration tasks needed to effectively restore the image.
|
| 72 |
+
|
| 73 |
+
3. **Generation of Instruction-Response Pairs**: Based on the degradation assessment and restoration sequence, we create comprehensive instruction-response pairs that guide the model in understanding user requests and providing appropriate restoration solutions.
|
| 74 |
+
|
| 75 |
+
<div align="center">
|
| 76 |
+
<img src="assets/datasets.png" alt="CleanBench Dataset Construction" width="800px">
|
| 77 |
+
<p>CleanBench dataset construction workflow: from degraded image synthesis to instruction-response pair generation</p>
|
| 78 |
+
</div>
|
| 79 |
+
|
| 80 |
+
### Dataset Features
|
| 81 |
+
|
| 82 |
+
- **Comprehensive Coverage**: Includes various weather conditions (rain, snow, fog, night) and their combinations
|
| 83 |
+
- **High-Quality Annotations**: Detailed degradation assessments and optimal restoration sequences
|
| 84 |
+
- **Real-World Examples**: 80K real-world examples to ensure model generalization
|
| 85 |
+
- **Instruction Diversity**: Multiple instruction formats to enhance model adaptability
|
| 86 |
+
|
| 87 |
+
CleanBench serves as a crucial resource for training and evaluating intelligent image restoration systems, enabling models like JarvisIR to make informed decisions about restoration strategies in complex real-world scenarios.
|
| 88 |
+
|
| 89 |
+
## :computer: Getting Started
|
| 90 |
+
|
| 91 |
+
For sft training and environment setup preparation, please follow:
|
| 92 |
+
|
| 93 |
+
- [SFT Training](./docs/sft_training.md)
|
| 94 |
+
<!-- - [Dataset Preparation](./docs/dataset_preparation.md) -->
|
| 95 |
+
<!--
|
| 96 |
+
For sft training, please follow:
|
| 97 |
+
|
| 98 |
+
- [SFT Training](./docs/sft_training.md) -->
|
| 99 |
+
|
| 100 |
+
For gradio demo runing, please follow:
|
| 101 |
+
|
| 102 |
+
- [Gradio Demo](./docs/gradio_demo.md)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
## :toolbox: Expert Models
|
| 106 |
+
|
| 107 |
+
JarvisIR integrates multiple expert restoration models to handle various types of image degradation:
|
| 108 |
+
|
| 109 |
+
| Task | Model | Description |
|
| 110 |
+
|------|-------|-------------|
|
| 111 |
+
| **Super-resolution** | Real-ESRGAN | Fast GAN-based model for super-resolution, deblurring, and artifact removal |
|
| 112 |
+
| **Denoising** | SCUNet | Hybrid UNet-based model combining convolution and transformer blocks for robust denoising |
|
| 113 |
+
| **Deraining** | UDR-S2Former | Uncertainty-aware transformer model for rain streak removal |
|
| 114 |
+
| | Img2img-turbo-rain | Efficient SD-turbo based model for fast and effective rain removal |
|
| 115 |
+
| **Raindrop removal** | IDT | Transformer-based model for de-raining and raindrop removal |
|
| 116 |
+
| **Dehazing** | RIDCP | Efficient dehazing model utilizing high-quality codebook priors |
|
| 117 |
+
| | KANet | Efficient dehazing network using a localization-and-removal pipeline |
|
| 118 |
+
| **Desnowing** | Img2img-turbo-snow | Efficient model for removing snow artifacts while preserving natural scene details |
|
| 119 |
+
| | Snowmaster | Real-world image desnowing via MLLM with multi-model feedback optimization |
|
| 120 |
+
| **Low-light enhancement** | Retinexformer | One-stage Retinex-based Transformer for low-light image enhancement |
|
| 121 |
+
| | HVICIDNet | Lightweight transformer for low-light and exposure correction |
|
| 122 |
+
| | LightenDiff | Diffusion-based framework for low-light enhancement |
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
## :circus_tent: Checklist
|
| 126 |
+
|
| 127 |
+
- [x] Release preview inference code and gradio demo
|
| 128 |
+
- [x] Release SFT training code
|
| 129 |
+
- [ ] Release Hugging Face demo
|
| 130 |
+
- [ ] Release CleanBench data
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
## :pray: Acknowledgements
|
| 134 |
+
|
| 135 |
+
We would like to express our gratitude to [HuggingGPT](https://github.com/microsoft/JARVIS), [XTuner](https://github.com/InternLM/xtuner), and [RRHF](https://github.com/GanjinZero/RRHF) for their valuable open-source contributions which have provided important technical references for our work.
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
## :love_you_gesture: Citation
|
| 141 |
+
```bibtex
|
| 142 |
+
@inproceedings{jarvisir2025,
|
| 143 |
+
title={JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration},
|
| 144 |
+
author={Lin, Yunlong and Lin, Zixu and Chen, Haoyu and Pan, Panwang and Li, Chenxin and Chen, Sixiang and Kairun, Wen and Jin, Yeying and Li, Wenbo and Ding, Xinghao},
|
| 145 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 146 |
+
year={2025}
|
| 147 |
+
}
|
| 148 |
+
```
|
JarvisIR/demo_gradio.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import random
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
|
| 8 |
+
from threading import Thread
|
| 9 |
+
from agent_tools import RestorationToolkit
|
| 10 |
+
|
| 11 |
+
# Set CUDA device
|
| 12 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 13 |
+
|
| 14 |
+
# Model configuration
|
| 15 |
+
# XXX: Path to the fine-tuned LLaVA model
|
| 16 |
+
model_id = ""
|
| 17 |
+
|
| 18 |
+
# Available image restoration tasks and their corresponding models
|
| 19 |
+
all_tasks = " {denoise: [scunet, restormer], lighten: [retinexformer_fivek, hvicidnet, lightdiff], \
|
| 20 |
+
derain: [idt, turbo_rain, s2former], defog:[ridcp, kanet], \
|
| 21 |
+
desnow:[turbo_snow, snowmaster], super_resolution: [real_esrgan], \
|
| 22 |
+
}"
|
| 23 |
+
|
| 24 |
+
# Various prompt templates for querying the LLM about image degradation and restoration tasks
|
| 25 |
+
prompts_query2 = [
|
| 26 |
+
f"Considering the image's degradation, suggest the required tasks with explanations, and identify suitable tools for each task. Options for tasks and tools include: {all_tasks}.",
|
| 27 |
+
f"Given the image's degradation, outline the essential tasks along with justifications, and choose the appropriate tools for each task from the following options: {all_tasks}.",
|
| 28 |
+
f"Please specify the tasks required due to the image's degradation, explain the reasons, and select relevant tools for each task from the provided options: {all_tasks}.",
|
| 29 |
+
f"Based on the image degradation, determine the necessary tasks and their reasons, along with the appropriate tools for each task. Choose from these options: {all_tasks}.",
|
| 30 |
+
f"Identify the tasks required to address the image's degradation, including the reasons for each, and select tools from the options: {all_tasks}.",
|
| 31 |
+
f"Considering the degradation observed, list the tasks needed and their justifications, then pick the most suitable tools for each task from these options: {all_tasks}.",
|
| 32 |
+
f"Evaluate the image degradation, and based on that, provide the necessary tasks and reasons, along with tools chosen from the options: {all_tasks}.",
|
| 33 |
+
f"With respect to the image degradation, outline the tasks needed and explain why, selecting tools from the following list: {all_tasks}.",
|
| 34 |
+
f"Given the level of degradation in the image, specify tasks to address it, include reasons, and select tools for each task from: {all_tasks}.",
|
| 35 |
+
f"Examine the image's degradation, propose relevant tasks and their explanations, and identify tools from the options provided: {all_tasks}.",
|
| 36 |
+
f"Based on observed degradation, detail the tasks required, explain your choices, and select tools from these options: {all_tasks}.",
|
| 37 |
+
f"Using the image's degradation as a guide, list the necessary tasks, include explanations, and pick tools from the provided choices: {all_tasks}.",
|
| 38 |
+
f"Assess the image degradation, provide the essential tasks and reasons, and select the appropriate tools for each task from the options: {all_tasks}.",
|
| 39 |
+
f"According to the image's degradation, determine which tasks are necessary and why, choosing tools for each task from: {all_tasks}.",
|
| 40 |
+
f"Observe the degradation in the image, specify the needed tasks with justifications, and select appropriate tools from: {all_tasks}.",
|
| 41 |
+
f"Taking the image degradation into account, specify tasks needed, provide reasons, and choose tools from the following: {all_tasks}.",
|
| 42 |
+
f"Consider the image's degradation level, outline the tasks necessary, provide reasoning, and select suitable tools from: {all_tasks}.",
|
| 43 |
+
f"Evaluate the degradation in the image, identify tasks required, explain your choices, and pick tools from: {all_tasks}.",
|
| 44 |
+
f"Analyze the image degradation and suggest tasks with justifications, choosing the best tools from these options: {all_tasks}.",
|
| 45 |
+
f"Review the image degradation, and based on it, specify tasks needed, provide reasons, and select tools for each task from: {all_tasks}."
|
| 46 |
+
]
|
| 47 |
+
|
| 48 |
+
# Initialize models
|
| 49 |
+
print("Loading LLM model...")
|
| 50 |
+
|
| 51 |
+
# Initialize the image restoration toolkit
|
| 52 |
+
tool_engine = RestorationToolkit(score_weight=[0,0,0,0,0])
|
| 53 |
+
# Load the LLaVA model in half precision to reduce memory usage
|
| 54 |
+
model = LlavaForConditionalGeneration.from_pretrained(
|
| 55 |
+
model_id,
|
| 56 |
+
torch_dtype=torch.float16,
|
| 57 |
+
low_cpu_mem_usage=True,
|
| 58 |
+
).to(0)
|
| 59 |
+
processor = AutoProcessor.from_pretrained(model_id)
|
| 60 |
+
|
| 61 |
+
print("Loading tool engine...")
|
| 62 |
+
|
| 63 |
+
def parse_llm_response(response):
|
| 64 |
+
"""
|
| 65 |
+
Parse the LLM response to extract reason and answer sections
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
response (str): The raw response from the LLM
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
tuple: (reason, answer) extracted from the response
|
| 72 |
+
"""
|
| 73 |
+
reason_match = re.search(r'<reason>(.*?)</reason>', response, re.DOTALL)
|
| 74 |
+
answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
|
| 75 |
+
|
| 76 |
+
reason = reason_match.group(1).strip() if reason_match else "No reasoning provided"
|
| 77 |
+
answer = answer_match.group(1).strip() if answer_match else "No answer provided"
|
| 78 |
+
|
| 79 |
+
return reason, answer
|
| 80 |
+
|
| 81 |
+
def extract_models_from_answer(answer):
|
| 82 |
+
"""
|
| 83 |
+
Extract model names from the answer string using regex
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
answer (str): The answer string containing model recommendations
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
list: List of extracted model names
|
| 90 |
+
"""
|
| 91 |
+
# Pattern to match [type:xxx]:(model:xxx)
|
| 92 |
+
pattern = r'\[type:[^\]]+\]:\(model:([^)]+)\)'
|
| 93 |
+
models = re.findall(pattern, answer)
|
| 94 |
+
return models
|
| 95 |
+
|
| 96 |
+
def beautify_recommended_actions(answer, models):
    """
    Render the LLM's tool recommendations as formatted markdown.

    Args:
        answer (str): Raw answer text containing ``[type:...]:(model:...)`` entries.
        models (list): Model names already extracted from the answer.

    Returns:
        str: Markdown summary of the recommended processing pipeline.
    """
    # Emoji prefix for each known task type (generic wrench otherwise).
    icon_for = {
        'denoise': '🧹',
        'lighten': '💡',
        'derain': '🌧️',
        'defog': '🌫️',
        'desnow': '❄️',
        'super_resolution': '🔍',
    }

    pairs = re.findall(r'\[type:([^\]]+)\]:\(model:([^)]+)\)', answer)

    if not pairs:
        # Fall back to echoing the raw answer when no structured entries exist.
        model_list = ', '.join(models) if models else 'None'
        return f"**🎯 Recommended Actions:**\n\n{answer}\n\n**Extracted Models:** {model_list}"

    steps = []
    for raw_task, raw_model in pairs:
        task = raw_task.strip()
        # e.g. "super_resolution" -> "Super Resolution"
        label = task.title().replace('_', ' ')
        steps.append(f"{icon_for.get(task, '🔧')} {label}:`{raw_model.strip()}`")

    # Arrow-joined horizontal flow, followed by a short summary.
    return (
        "**🎯 Recommended Actions:**\n"
        "> "
        + " ➡ ".join(steps)
        + "\n\n"
        + f"**📋 Processing Pipeline:** {len(pairs)} steps\n"
        + f"**🛠️ Models to use:** {' → '.join(models)}"
    )
|
| 153 |
+
|
| 154 |
+
def resize_image_to_original(processed_image_path, original_size):
    """
    Resize a processed image back to the original input dimensions.

    Args:
        processed_image_path (str): Path to the processed image (may be None).
        original_size (tuple): Original image dimensions as (width, height).

    Returns:
        str: Path to the resized image, or the input path unchanged when the
        file does not exist (None passes through unchanged).
    """
    if processed_image_path and os.path.exists(processed_image_path):
        # Fix: ensure the output directory exists — this function used to
        # assume the caller had already created 'temp_outputs'.
        os.makedirs('temp_outputs', exist_ok=True)
        output_path = os.path.join('temp_outputs', 'final_result.png')

        # Context manager releases the source file handle promptly.
        with Image.open(processed_image_path) as img:
            # LANCZOS resampling preserves detail when rescaling.
            img.resize(original_size, Image.Resampling.LANCZOS).save(output_path)
        return output_path
    return processed_image_path
|
| 174 |
+
|
| 175 |
+
def get_llm_response_streaming(image_path):
    """
    Kick off LLM generation for an image and return a token streamer.

    Args:
        image_path (str): Path to the input image.

    Returns:
        TextIteratorStreamer: Yields generated tokens as they are produced;
        generation runs on a background thread so this call does not block.
    """
    # Pick one of the prompt templates at random.
    idx = random.randint(0, len(prompts_query2) - 1)
    instruction = prompts_query2[idx]

    # Llama-3 style chat template with the <image> placeholder for LLaVA.
    prompt = (
        f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{instruction}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # Prepare multimodal inputs on GPU 0 in half precision.
    raw_image = Image.open(image_path)
    model_inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

    # The streamer lets the UI consume tokens while generation is still running.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Greedy decoding (do_sample=False), capped at 400 new tokens.
    Thread(
        target=model.generate,
        kwargs=dict(**model_inputs, streamer=streamer, max_new_tokens=400, do_sample=False),
    ).start()

    return streamer
|
| 211 |
+
|
| 212 |
+
def process_image_with_tools(image_path, models, original_size):
    """
    Run the restoration pipeline on an image and restore its original size.

    Args:
        image_path (str): Path to the input image.
        models (list): Model names to apply in sequence.
        original_size (tuple): Original image dimensions as (width, height).

    Returns:
        str: Path to the final processed image, or None when no models given.
    """
    # Nothing to do without at least one model.
    if not models:
        return None

    output_dir = 'temp_outputs'
    os.makedirs(output_dir, exist_ok=True)

    # Apply the selected tools in order, then undo any resizing they did.
    result = tool_engine.process_image(models, image_path, output_dir)
    return resize_image_to_original(result['output_path'], original_size)
|
| 237 |
+
|
| 238 |
+
def process_full_pipeline(image):
    """
    Main processing pipeline with streaming UI updates.

    Streams the LLM's <reason>/<answer> output into a Gradio chat history as
    tokens arrive, then runs the recommended restoration tools on the image.

    Args:
        image (str): Path to the input image (Gradio filepath), or None.

    Yields:
        tuple: (chat_history, processed_image) for Gradio UI updates.
    """
    if image is None:
        # NOTE(review): a plain `return` with a value inside a generator does
        # not yield anything — Gradio receives no update for a None image.
        # Confirm this is intended (vs. `yield [], None`).
        return [], None

    try:
        # Remember the original size so the result can be restored later.
        original_img = Image.open(image)
        original_size = original_img.size

        # Initialize chat history for the UI; index 0 is the user message.
        chat_history = [("Image uploaded for analysis", None)]

        # Step 1: start background generation and get the token streamer.
        streamer = get_llm_response_streaming(image)

        # Streaming state machine over the accumulating response text.
        full_response = ""
        in_reason = False         # currently streaming the <reason> section
        in_answer = False         # currently streaming the <answer> section
        reason_displayed = False  # reasoning bubble finalized
        answer_displayed = False  # answer bubble finalized
        reasoning_added = False   # a reasoning entry exists in chat_history

        for new_text in streamer:
            full_response += new_text

            # Enter the reason section — or, on the very first token, start
            # showing raw content as reasoning even before <reason> appears.
            if ('<reason>' in full_response and not in_reason and not reason_displayed) or (not reasoning_added and not in_reason and not reason_displayed):
                in_reason = True
                reasoning_added = True

                if '<reason>' in full_response:
                    # Extract the content after the opening <reason> tag.
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_content = full_response[reason_start:].strip()
                else:
                    # No tag yet: show everything received so far as reasoning.
                    reason_content = full_response.strip()

                # Reasoning bubble lives at chat_history[1] from here on.
                chat_history.append((None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}"))
                yield chat_history, None

            # While inside the reason section, keep updating the same bubble.
            elif in_reason and not reason_displayed:
                if '</reason>' in full_response:
                    # Section complete: extract the full tagged span.
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_end = full_response.find('</reason>')
                    reason_content = full_response[reason_start:reason_end].strip()

                    # Replace the partial reasoning with the final text.
                    chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
                    reason_displayed = True
                    in_reason = False
                    yield chat_history, None
                else:
                    # Still streaming: show whatever follows <reason> (or all text).
                    if '<reason>' in full_response:
                        reason_start = full_response.find('<reason>') + len('<reason>')
                        reason_content = full_response[reason_start:].strip()
                    else:
                        reason_content = full_response.strip()

                    # Update the reasoning bubble in place.
                    chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
                    yield chat_history, None

            # Enter the answer section (only after reasoning is finalized).
            elif '<answer>' in full_response and not in_answer and not answer_displayed and reason_displayed:
                in_answer = True
                # Extract the content after the opening <answer> tag.
                answer_start = full_response.find('<answer>') + len('<answer>')
                answer_content = full_response[answer_start:]

                # Answer bubble lives at chat_history[2] from here on.
                models = extract_models_from_answer(answer_content)
                beautified = beautify_recommended_actions(answer_content, models)
                chat_history.append((None, beautified))
                yield chat_history, None

            # While inside the answer section, keep updating the same bubble.
            elif in_answer and not answer_displayed:
                if '</answer>' in full_response:
                    # Section complete: extract the full tagged span.
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_end = full_response.find('</answer>')
                    answer_content = full_response[answer_start:answer_end].strip()

                    # Parse the final answer into a tool pipeline.
                    models = extract_models_from_answer(answer_content)
                    beautified = beautify_recommended_actions(answer_content, models)
                    chat_history[2] = (None, beautified)
                    answer_displayed = True
                    in_answer = False
                    yield chat_history, None

                    # Run the recommended tools on the image.
                    if models:
                        chat_history.append((None, "**🔄 Processing image...**"))
                        yield chat_history, None

                        processed_image = process_image_with_tools(image, models, original_size)
                        chat_history[-1] = (None, "**✅ Processing Complete!**")
                        yield chat_history, processed_image
                        # Done: stop consuming the streamer early.
                        return
                    else:
                        chat_history.append((None, "**❌ No valid models found in the response**"))
                        yield chat_history, None
                        return
                else:
                    # Still streaming: show the partial answer so far.
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_content = full_response[answer_start:].strip()

                    # Update the answer bubble in place.
                    models = extract_models_from_answer(answer_content)
                    beautified = beautify_recommended_actions(answer_content, models)
                    chat_history[2] = (None, beautified)
                    yield chat_history, None

        # Fallback: streaming finished without a closed <answer> section —
        # re-parse the whole response and rebuild the chat from scratch.
        if not answer_displayed:
            reason, answer = parse_llm_response(full_response)
            models = extract_models_from_answer(answer)

            chat_history = [
                ("Image uploaded for analysis", None),
                (None, f"**🤔 Analysis & Reasoning:**\n\n{reason}"),
                (None, beautify_recommended_actions(answer, models))
            ]

            if models:
                chat_history.append((None, "**🔄 Processing image...**"))
                yield chat_history, None

                processed_image = process_image_with_tools(image, models, original_size)
                chat_history[-1] = (None, "**✅ Processing Complete!**")
                yield chat_history, processed_image
            else:
                chat_history.append((None, "**❌ No valid models found in the response**"))
                yield chat_history, None

    # Broad catch is deliberate here: any failure is surfaced in the chat UI
    # rather than crashing the Gradio worker.
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        chat_history = [
            ("Image uploaded for analysis", None),
            (None, f"**❌ Error occurred:**\n\n{error_msg}")
        ]
        yield chat_history, None
|
| 399 |
+
|
| 400 |
+
# Create Gradio interface
|
| 401 |
+
def create_interface():
    """
    Build the Gradio web interface for JarvisIR.

    Returns:
        gr.Blocks: The assembled interface, ready to be launched.
    """
    page_title = "JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration"

    with gr.Blocks(title=page_title, theme=gr.themes.Soft()) as demo:
        # Page header: project icon, title, and a one-line tagline.
        gr.Markdown("""
        # <img src="https://cvpr2025-jarvisir.github.io/imgs/icon.png" width="32" height="32" style="display: inline-block; vertical-align: middle; transform: translateY(-2px); margin-right: 1px;"/> JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration
        
        Upload an image and let JarvisIR analyze its degradation and recommend the best restoration tools!
        """)

        # Top row: image input on the left, analysis chat on the right.
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="📸 Upload Your Image",
                    type="filepath",
                    height=400,
                )
                process_btn = gr.Button(
                    "🚀 Analyze & Process",
                    variant="primary",
                    size="lg",
                )

            with gr.Column(scale=1):
                chatbot = gr.Chatbot(
                    label="💬 AI Analysis Chat",
                    height=400,
                    show_label=True,
                    bubble_full_width=False,
                )

        # Bottom row: the restored image.
        with gr.Row():
            output_image = gr.Image(label="✨ Processed Result", height=300)

        # Clicking the button streams pipeline updates into chat + output.
        process_btn.click(
            fn=process_full_pipeline,
            inputs=[input_image],
            outputs=[chatbot, output_image],
        )

        # Usage instructions below the main layout.
        gr.Markdown("### 📝 Instructions:")
        gr.Markdown("""
        1. **Upload an image** that needs restoration (blurry, dark, noisy, etc.)
        2. **Click 'Analyze & Process'** to let AI analyze the image
        3. **View the chat** to see AI's reasoning and recommendations in real-time
        4. **Check the result** - processed image restored to original dimensions
        """)

    return demo
|
| 465 |
+
|
| 466 |
+
if __name__ == "__main__":
    # Script entry point: build the UI and serve it.
    print("Starting Image Restoration Assistant...")
    app = create_interface()
    # Bind on every interface at port 7866; no public Gradio share link.
    app.launch(
        server_name="0.0.0.0",
        server_port=7866,
        share=False,
    )
|
JarvisIR/docs/gradio_demo.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gradio Demo Guide
|
| 2 |
+
|
| 3 |
+
This guide provides instructions on how to run the Gradio demo for JarvisIR.
|
| 4 |
+
|
| 5 |
+
## Environment Setup
|
| 6 |
+
|
| 7 |
+
Please follow the environment setup instructions in the [SFT Training Guide](./sft_training.md#environment-setup) to create the conda environment and install the necessary dependencies. The same environment can be used for running the Gradio demo.
|
| 8 |
+
|
| 9 |
+
Make sure you have activated the conda environment:
|
| 10 |
+
```bash
|
| 11 |
+
conda activate sft_jarvisir
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## Download Preview Weights
|
| 16 |
+
|
| 17 |
+
To run the Gradio demo, you need to download the preview weights and place them in the correct location:
|
| 18 |
+
|
| 19 |
+
1. Download the JarvisIR preview weights from Hugging Face
|
| 20 |
+
2. Create the weights directory (if it doesn't exist):
|
| 21 |
+
```bash
|
| 22 |
+
mkdir -p checkpoints/jarvisir-preview/
|
| 23 |
+
```
|
| 24 |
+
3. Place the downloaded weight files in the `checkpoints/jarvisir-preview/` directory
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
## Running the Demo
|
| 29 |
+
|
| 30 |
+
Once the environment is set up and activated, you can run the Gradio demo with the following command from the root directory of the project:
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
python demo_gradio.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
This will launch a web interface where you can interact with the model.
|
JarvisIR/docs/sft_training.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SFT (Supervised Fine-Tuning) Training Guide
|
| 2 |
+
|
| 3 |
+
This guide provides step-by-step instructions for setting up the environment and performing supervised fine-tuning using XTuner framework.
|
| 4 |
+
|
| 5 |
+
## Environment Setup
|
| 6 |
+
|
| 7 |
+
### 1. Create and Activate Virtual Environment
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
conda create -n sft_jarvisir python=3.10
|
| 11 |
+
conda activate sft_jarvisir
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
### 2. Install Dependencies
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
pip install -r requirements.txt
|
| 18 |
+
cd package/agent_tools/Retinexformer && python3 setup.py develop --no_cuda_ext && cd ..
|
| 19 |
+
cd RIDCP && python3 setup.py develop && cd ../..
|
| 20 |
+
pip install -e .
|
| 21 |
+
|
| 22 |
+
cd ../dependences/BasicSR
|
| 23 |
+
pip install -e .
|
| 24 |
+
|
| 25 |
+
cd ../src/sft/xtuner
|
| 26 |
+
|
| 27 |
+
# Install required packages
|
| 28 |
+
pip install -r requirements.txt
|
| 29 |
+
|
| 30 |
+
# Install XTuner with all features
|
| 31 |
+
pip install -e '.[all]'
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### 3. Verify Installation
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
# Check XTuner installation
|
| 38 |
+
xtuner version
|
| 39 |
+
|
| 40 |
+
# Verify GPU availability
|
| 41 |
+
python -c "import torch; print(f'CUDA Available: {torch.cuda.is_available()}')"
|
| 42 |
+
python -c "import torch; print(f'GPU Count: {torch.cuda.device_count()}')"
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Model Weights Setup
|
| 46 |
+
|
| 47 |
+
### 1. Create Model Directory Structure
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
# Create directories for models and datasets
|
| 51 |
+
cd ..
|
| 52 |
+
mkdir -p base_models
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
### 2. Download Base Model Weights
|
| 56 |
+
|
| 57 |
+
Download the following models to the base_models folder:
|
| 58 |
+
<table>
|
| 59 |
+
<tr>
|
| 60 |
+
<th>Model</th><th>Model Weights</th>
|
| 61 |
+
</tr>
|
| 62 |
+
<tr style="border-top: 2px solid">
|
| 63 |
+
<td>openai/clip-vit-large-patch14-336</td><td><a href="https://huggingface.co/openai/clip-vit-large-patch14-336"> 🤗 HuggingFace</a></td>
|
| 64 |
+
</tr>
|
| 65 |
+
<tr style="border-top: 2px solid">
|
| 66 |
+
<td>xtuner/llava-llama-3-8b</td><td><a href="https://huggingface.co/xtuner/llava-llama-3-8b"> 🤗 HuggingFace</a></td>
|
| 67 |
+
</tr>
|
| 68 |
+
<tr style="border-top: 2px solid">
|
| 69 |
+
<td>xtuner/llava-llama-3-8b-pretrain</td><td><a href="https://huggingface.co/xtuner/llava-llama-3-8b-pretrain"> 🤗 HuggingFace</a></td>
|
| 70 |
+
</tr>
|
| 71 |
+
</table>
|
| 72 |
+
|
| 73 |
+
### 3. Training with xtuner
|
| 74 |
+
- Modify the model and data paths in the **llava_8b_full.py** file.
|
| 75 |
+
- Run the following command to perform sft.
|
| 76 |
+
```bash
|
| 77 |
+
NPROC_PER_NODE=${GPU_NUM} xtuner train llava_8b_full.py --deepspeed deepspeed_zero2
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Acknowledgements
|
| 81 |
+
We would like to thank the [XTuner](https://github.com/InternLM/xtuner) team for open-sourcing such an excellent framework. For more fine-tuning methods, please refer to their official documentation and code repository.
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
JarvisIR/package/README.md
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image Restoration Expert Toolkit
|
| 2 |
+
|
| 3 |
+
JarvisIR integrates diverse expert tools for image restoration, targeting problems like low-light conditions, rain/snow, haze, resolution enhancement, and noise. We've optimized our toolkit for both efficiency and performance, with some variations from the original paper.
|
| 4 |
+
|
| 5 |
+
## Key Features
|
| 6 |
+
|
| 7 |
+
- Unified framework for multiple expert models.
|
| 8 |
+
- Simple API for testing each expert model.
|
| 9 |
+
- Quality assessment capabilities through reward functions
|
| 10 |
+
- Modular design for easy extension with new models
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
## Expert Model List
|
| 14 |
+
| Task | Tools | Model Description |
|
| 15 |
+
|---------|---------|------|
|
| 16 |
+
| **Super-resolution & Deblurring & Artifact removal** | Real-ESRGAN | Fast GAN for super-resolution, deblurring, and artifact removal. |
|
| 17 |
+
| **Denoising** | SCUNet | Hybrid UNet-based model combining convolution and transformer blocks, designed for robust denoising. |
|
| 18 |
+
| **Deraining** | UDR-S2Former | An uncertainty-aware transformer model for rain streak removal. |
|
| 19 |
+
| | Img2img-turbo-rain | Efficient model based on SD-turbo, designed for fast and effective rain removal in real-world images. |
|
| 20 |
+
| **Raindrop removal** | IDT | Transformer-based model for de-raining and raindrop removal. |
|
| 21 |
+
| **Dehazing** | RIDCP | Efficient dehazing model utilizing high-quality codebook priors |
|
| 22 |
+
| | KANet | Efficient dehazing network using a localization-and-removal pipeline. |
|
| 23 |
+
| **Desnowing** | Img2img-turbo-snow | Efficient model for removing snow artifacts while preserving natural scene details. |
|
| 24 |
+
| | Snowmaster | Real-world Image Desnowing via MLLM with Multi-Model Feedback Optimization |
|
| 25 |
+
| **Low-light enhancement** | Retinexformer | One-stage Retinex-based Transformer for Low-light Image Enhancement |
|
| 26 |
+
| | HVICIDNet | Lightweight transformer for low-light and exposure correction |
|
| 27 |
+
| | LightenDiff | Diffusion-based framework for low-light enhancement |
|
| 28 |
+
|
| 29 |
+
## Usage
|
| 30 |
+
|
| 31 |
+
### Basic Usage
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
from agent_tools import RestorationToolkit
|
| 35 |
+
|
| 36 |
+
# Initialize the toolkit
|
| 37 |
+
toolkit = RestorationToolkit(device='cuda')
|
| 38 |
+
|
| 39 |
+
# Process an image with a sequence of models
|
| 40 |
+
result = toolkit.process_image(
|
| 41 |
+
tools=['scunet', 'real_esrgan'], # Models to apply in sequence
|
| 42 |
+
img_path='path/to/input.jpg',
|
| 43 |
+
output_dir='path/to/output'
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Access the result
|
| 47 |
+
output_path = result['output_path']
|
| 48 |
+
quality_score = result['score']
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Running the Test API Server
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
from agent_tools import start_server
|
| 55 |
+
|
| 56 |
+
# Start the server with default models
|
| 57 |
+
start_server(host='0.0.0.0', port=5010)
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Or use the API with curl:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
curl -X POST http://localhost:5010/process_image \
|
| 64 |
+
-H "Content-Type: application/json" \
|
| 65 |
+
-d '{"img_path": "path/to/image.jpg", "models": ["scunet", "real_esrgan"]}'
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
## Cite
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
# Real-esrgan
|
| 72 |
+
@inproceedings{wang2021real,
|
| 73 |
+
title={Real-esrgan: Training real-world blind super-resolution with pure synthetic data},
|
| 74 |
+
author={Wang, Xintao and Xie, Liangbin and Dong, Chao and Shan, Ying},
|
| 75 |
+
booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
|
| 76 |
+
pages={1905--1914},
|
| 77 |
+
year={2021}
|
| 78 |
+
}
|
| 79 |
+
# img2img-turbo
|
| 80 |
+
@article{parmar2024one,
|
| 81 |
+
title={One-step image translation with text-to-image models},
|
| 82 |
+
author={Parmar, Gaurav and Park, Taesung and Narasimhan, Srinivasa and Zhu, Jun-Yan},
|
| 83 |
+
journal={arXiv preprint arXiv:2403.12036},
|
| 84 |
+
year={2024}
|
| 85 |
+
}
|
| 86 |
+
# Lightendiffusion
|
| 87 |
+
@article{jiang2024lightendiffusion,
|
| 88 |
+
title={Lightendiffusion: Unsupervised low-light image enhancement with latent-retinex diffusion models},
|
| 89 |
+
author={Jiang, Hai and Luo, Ao and Liu, Xiaohong and Han, Songchen and Liu, Shuaicheng},
|
| 90 |
+
journal={arXiv preprint arXiv:2407.08939},
|
| 91 |
+
year={2024}
|
| 92 |
+
}
|
| 93 |
+
# KANet
|
| 94 |
+
@article{feng2024advancing,
|
| 95 |
+
title={Advancing real-world image dehazing: perspective, modules, and training},
|
| 96 |
+
author={Feng, Yuxin and Ma, Long and Meng, Xiaozhe and Zhou, Fan and Liu, Risheng and Su, Zhuo},
|
| 97 |
+
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
| 98 |
+
year={2024},
|
| 99 |
+
publisher={IEEE}
|
| 100 |
+
}
|
| 101 |
+
# IDT
|
| 102 |
+
@article{xiao2022image,
|
| 103 |
+
title={Image de-raining transformer},
|
| 104 |
+
author={Xiao, Jie and Fu, Xueyang and Liu, Aiping and Wu, Feng and Zha, Zheng-Jun},
|
| 105 |
+
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
| 106 |
+
volume={45},
|
| 107 |
+
number={11},
|
| 108 |
+
pages={12978--12995},
|
| 109 |
+
year={2022},
|
| 110 |
+
publisher={IEEE}
|
| 111 |
+
}
|
| 112 |
+
# SCUNet
|
| 113 |
+
@article{zhang2023practical,
|
| 114 |
+
author = {Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Fan, Deng-Ping and Timofte, Radu and Gool, Luc Van},
|
| 115 |
+
title = {Practical Blind Image Denoising via Swin-Conv-UNet and Data Synthesis},
|
| 116 |
+
journal = {Machine Intelligence Research},
|
| 117 |
+
DOI = {10.1007/s11633-023-1466-0},
|
| 118 |
+
url = {https://doi.org/10.1007/s11633-023-1466-0},
|
| 119 |
+
volume={20},
|
| 120 |
+
number={6},
|
| 121 |
+
pages={822--836},
|
| 122 |
+
year={2023},
|
| 123 |
+
publisher={Springer}
|
| 124 |
+
}
|
| 125 |
+
# S2Former
|
| 126 |
+
@inproceedings{chen2023sparse,
|
| 127 |
+
title={Sparse sampling transformer with uncertainty-driven ranking for unified removal of raindrops and rain streaks},
|
| 128 |
+
author={Chen, Sixiang and Ye, Tian and Bai, Jinbin and Chen, Erkang and Shi, Jun and Zhu, Lei},
|
| 129 |
+
booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
|
| 130 |
+
pages={13106--13117},
|
| 131 |
+
year={2023}
|
| 132 |
+
}
|
| 133 |
+
# RIDCP
|
| 134 |
+
@inproceedings{wu2023ridcp,
|
| 135 |
+
title={Ridcp: Revitalizing real image dehazing via high-quality codebook priors},
|
| 136 |
+
author={Wu, Rui-Qi and Duan, Zheng-Peng and Guo, Chun-Le and Chai, Zhi and Li, Chongyi},
|
| 137 |
+
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
| 138 |
+
pages={22282--22291},
|
| 139 |
+
year={2023}
|
| 140 |
+
}
|
| 141 |
+
# HVI-CIDNet
|
| 142 |
+
@article{yan2024you,
|
| 143 |
+
title={You only need one color space: An efficient network for low-light image enhancement},
|
| 144 |
+
author={Yan, Qingsen and Feng, Yixu and Zhang, Cheng and Wang, Pei and Wu, Peng and Dong, Wei and Sun, Jinqiu and Zhang, Yanning},
|
| 145 |
+
journal={arXiv preprint arXiv:2402.05809},
|
| 146 |
+
year={2024}
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
# Retinexformer
|
| 150 |
+
@inproceedings{retinexformer,
|
| 151 |
+
title={Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement},
|
| 152 |
+
author={Yuanhao Cai and Hao Bian and Jing Lin and Haoqian Wang and Radu Timofte and Yulun Zhang},
|
| 153 |
+
booktitle={ICCV},
|
| 154 |
+
year={2023}
|
| 155 |
+
}
|
JarvisIR/package/agent_tools.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: agent_tools
|
| 3 |
+
Version: 1.0.0
|
| 4 |
+
Summary: Agent Tools
|
| 5 |
+
Classifier: Programming Language :: Python :: 3
|
| 6 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 7 |
+
Requires-Python: >=3.8
|
JarvisIR/package/agent_tools.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
agent_tools/__init__.py
|
| 4 |
+
agent_tools/iqa_reward.py
|
| 5 |
+
agent_tools/restoration_toolkit.py
|
| 6 |
+
agent_tools/tool_testing_api.py
|
| 7 |
+
agent_tools.egg-info/PKG-INFO
|
| 8 |
+
agent_tools.egg-info/SOURCES.txt
|
| 9 |
+
agent_tools.egg-info/dependency_links.txt
|
| 10 |
+
agent_tools.egg-info/top_level.txt
|
| 11 |
+
agent_tools/ESRGAN/__init__.py
|
| 12 |
+
agent_tools/ESRGAN/inference.py
|
| 13 |
+
agent_tools/ESRGAN/realesrgan/__init__.py
|
| 14 |
+
agent_tools/ESRGAN/realesrgan/train.py
|
| 15 |
+
agent_tools/ESRGAN/realesrgan/utils.py
|
| 16 |
+
agent_tools/ESRGAN/realesrgan/archs/__init__.py
|
| 17 |
+
agent_tools/ESRGAN/realesrgan/archs/discriminator_arch.py
|
| 18 |
+
agent_tools/ESRGAN/realesrgan/archs/srvgg_arch.py
|
| 19 |
+
agent_tools/ESRGAN/realesrgan/data/__init__.py
|
| 20 |
+
agent_tools/ESRGAN/realesrgan/data/realesrgan_dataset.py
|
| 21 |
+
agent_tools/ESRGAN/realesrgan/data/realesrgan_paired_dataset.py
|
| 22 |
+
agent_tools/ESRGAN/realesrgan/models/__init__.py
|
| 23 |
+
agent_tools/ESRGAN/realesrgan/models/realesrgan_model.py
|
| 24 |
+
agent_tools/ESRGAN/realesrgan/models/realesrnet_model.py
|
| 25 |
+
agent_tools/HVICIDNet/inference.py
|
| 26 |
+
agent_tools/HVICIDNet/mods.py
|
| 27 |
+
agent_tools/HVICIDNet/wavelet.py
|
| 28 |
+
agent_tools/HVICIDNet/loss/loss_utils.py
|
| 29 |
+
agent_tools/HVICIDNet/loss/losses.py
|
| 30 |
+
agent_tools/HVICIDNet/loss/niqe_utils.py
|
| 31 |
+
agent_tools/HVICIDNet/loss/vgg_arch.py
|
| 32 |
+
agent_tools/HVICIDNet/net/CIDNet.py
|
| 33 |
+
agent_tools/HVICIDNet/net/HVI_transform.py
|
| 34 |
+
agent_tools/HVICIDNet/net/LCA.py
|
| 35 |
+
agent_tools/HVICIDNet/net/transformer_utils.py
|
| 36 |
+
agent_tools/IDT/__init__.py
|
| 37 |
+
agent_tools/IDT/inference.py
|
| 38 |
+
agent_tools/IDT/analyse/cal_rf_bf.py
|
| 39 |
+
agent_tools/IDT/models/ICRA.py
|
| 40 |
+
agent_tools/IDT/models/IDT.py
|
| 41 |
+
agent_tools/IDT/models/Uformer.py
|
| 42 |
+
agent_tools/IDT/models/__init__.py
|
| 43 |
+
agent_tools/IDT/models/atgan.py
|
| 44 |
+
agent_tools/IDT/models/ddm.py
|
| 45 |
+
agent_tools/IDT/models/onego_genotypes_searched.py
|
| 46 |
+
agent_tools/IDT/models/onego_ops_derain.py
|
| 47 |
+
agent_tools/IDT/models/onego_se_nets.py
|
| 48 |
+
agent_tools/IDT/models/onego_train_model.py
|
| 49 |
+
agent_tools/IDT/models/restoration.py
|
| 50 |
+
agent_tools/IDT/models/restormer.py
|
| 51 |
+
agent_tools/IDT/models/transformer2d.py
|
| 52 |
+
agent_tools/IDT/models/unet.py
|
| 53 |
+
agent_tools/IDT/utils/__init__.py
|
| 54 |
+
agent_tools/IDT/utils/logging.py
|
| 55 |
+
agent_tools/IDT/utils/optimize.py
|
| 56 |
+
agent_tools/IDT/utils/sampling.py
|
| 57 |
+
agent_tools/KANet/LD_model1.py
|
| 58 |
+
agent_tools/KANet/base_networks.py
|
| 59 |
+
agent_tools/KANet/deconv.py
|
| 60 |
+
agent_tools/KANet/inference.py
|
| 61 |
+
agent_tools/KANet/transweather_model.py
|
| 62 |
+
agent_tools/LightenDiffusion/inference.py
|
| 63 |
+
agent_tools/LightenDiffusion/models/__init__.py
|
| 64 |
+
agent_tools/LightenDiffusion/models/ddm.py
|
| 65 |
+
agent_tools/LightenDiffusion/models/decom.py
|
| 66 |
+
agent_tools/LightenDiffusion/models/restoration.py
|
| 67 |
+
agent_tools/LightenDiffusion/models/unet.py
|
| 68 |
+
agent_tools/LightenDiffusion/utils/__init__.py
|
| 69 |
+
agent_tools/LightenDiffusion/utils/logging.py
|
| 70 |
+
agent_tools/LightenDiffusion/utils/optimize.py
|
| 71 |
+
agent_tools/LightenDiffusion/utils/sampling.py
|
| 72 |
+
agent_tools/RIDCP/__init__.py
|
| 73 |
+
agent_tools/RIDCP/inference.py
|
| 74 |
+
agent_tools/RIDCP/setup.py
|
| 75 |
+
agent_tools/RIDCP/basicsr_ridcp/__init__.py
|
| 76 |
+
agent_tools/RIDCP/basicsr_ridcp/active_codebook.py
|
| 77 |
+
agent_tools/RIDCP/basicsr_ridcp/test.py
|
| 78 |
+
agent_tools/RIDCP/basicsr_ridcp/train.py
|
| 79 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/__init__.py
|
| 80 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/arch_util.py
|
| 81 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/dehaze_vq_weight_arch.py
|
| 82 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/discriminator_arch.py
|
| 83 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/network_swinir.py
|
| 84 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/ridcp_utils.py
|
| 85 |
+
agent_tools/RIDCP/basicsr_ridcp/archs/vgg_arch.py
|
| 86 |
+
agent_tools/RIDCP/basicsr_ridcp/data/__init__.py
|
| 87 |
+
agent_tools/RIDCP/basicsr_ridcp/data/data_sampler.py
|
| 88 |
+
agent_tools/RIDCP/basicsr_ridcp/data/data_util.py
|
| 89 |
+
agent_tools/RIDCP/basicsr_ridcp/data/haze_online_dataset.py
|
| 90 |
+
agent_tools/RIDCP/basicsr_ridcp/data/prefetch_dataloader.py
|
| 91 |
+
agent_tools/RIDCP/basicsr_ridcp/data/transforms.py
|
| 92 |
+
agent_tools/RIDCP/basicsr_ridcp/losses/__init__.py
|
| 93 |
+
agent_tools/RIDCP/basicsr_ridcp/losses/loss_util.py
|
| 94 |
+
agent_tools/RIDCP/basicsr_ridcp/losses/losses.py
|
| 95 |
+
agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_fid_folder.py
|
| 96 |
+
agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_fid_stats_from_datasets.py
|
| 97 |
+
agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_lpips.py
|
| 98 |
+
agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_niqe.py
|
| 99 |
+
agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_psnr_ssim.py
|
| 100 |
+
agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_stylegan2_fid.py
|
| 101 |
+
agent_tools/RIDCP/basicsr_ridcp/models/__init__.py
|
| 102 |
+
agent_tools/RIDCP/basicsr_ridcp/models/base_model.py
|
| 103 |
+
agent_tools/RIDCP/basicsr_ridcp/models/dehaze_vq_model.py
|
| 104 |
+
agent_tools/RIDCP/basicsr_ridcp/models/lr_scheduler.py
|
| 105 |
+
agent_tools/RIDCP/basicsr_ridcp/ops/__init__.py
|
| 106 |
+
agent_tools/RIDCP/basicsr_ridcp/ops/dcn/__init__.py
|
| 107 |
+
agent_tools/RIDCP/basicsr_ridcp/ops/dcn/deform_conv.py
|
| 108 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/__init__.py
|
| 109 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/diffjpeg.py
|
| 110 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/dist_util.py
|
| 111 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/download_util.py
|
| 112 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/face_util.py
|
| 113 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/file_client.py
|
| 114 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/flow_util.py
|
| 115 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/img_process_util.py
|
| 116 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/img_util.py
|
| 117 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/lmdb_util.py
|
| 118 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/logger.py
|
| 119 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/matlab_functions.py
|
| 120 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/misc.py
|
| 121 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/options.py
|
| 122 |
+
agent_tools/RIDCP/basicsr_ridcp/utils/registry.py
|
| 123 |
+
agent_tools/RIDCP/utils/utils_image.py
|
| 124 |
+
agent_tools/Retinexformer/__init__.py
|
| 125 |
+
agent_tools/Retinexformer/inference.py
|
| 126 |
+
agent_tools/Retinexformer/setup.py
|
| 127 |
+
agent_tools/Retinexformer/Enhancement/utils.py
|
| 128 |
+
agent_tools/Retinexformer/basicsr_retinexformer/test.py
|
| 129 |
+
agent_tools/Retinexformer/basicsr_retinexformer/train.py
|
| 130 |
+
agent_tools/Retinexformer/basicsr_retinexformer/version.py
|
| 131 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/SDSD_image_dataset.py
|
| 132 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/SID_image_dataset.py
|
| 133 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/SMID_image_dataset.py
|
| 134 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/__init__.py
|
| 135 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/data_sampler.py
|
| 136 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/data_util.py
|
| 137 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/ffhq_dataset.py
|
| 138 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/paired_image_dataset.py
|
| 139 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/prefetch_dataloader.py
|
| 140 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/reds_dataset.py
|
| 141 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/single_image_dataset.py
|
| 142 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/transforms.py
|
| 143 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/util.py
|
| 144 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/video_test_dataset.py
|
| 145 |
+
agent_tools/Retinexformer/basicsr_retinexformer/data/vimeo90k_dataset.py
|
| 146 |
+
agent_tools/Retinexformer/basicsr_retinexformer/metrics/__init__.py
|
| 147 |
+
agent_tools/Retinexformer/basicsr_retinexformer/metrics/fid.py
|
| 148 |
+
agent_tools/Retinexformer/basicsr_retinexformer/metrics/metric_util.py
|
| 149 |
+
agent_tools/Retinexformer/basicsr_retinexformer/metrics/niqe.py
|
| 150 |
+
agent_tools/Retinexformer/basicsr_retinexformer/metrics/psnr_ssim.py
|
| 151 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/__init__.py
|
| 152 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/base_model.py
|
| 153 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/image_restoration_model.py
|
| 154 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/lr_scheduler.py
|
| 155 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/archs/MST_Plus_Plus_arch.py
|
| 156 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/archs/RetinexFormer_arch.py
|
| 157 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/archs/__init__.py
|
| 158 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/archs/arch_util.py
|
| 159 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/archs/layers.py
|
| 160 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/losses/__init__.py
|
| 161 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/losses/loss_util.py
|
| 162 |
+
agent_tools/Retinexformer/basicsr_retinexformer/models/losses/losses.py
|
| 163 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/__init__.py
|
| 164 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/bundle_submissions.py
|
| 165 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/create_lmdb.py
|
| 166 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/dist_util.py
|
| 167 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/download_util.py
|
| 168 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/face_util.py
|
| 169 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/file_client.py
|
| 170 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/flow_util.py
|
| 171 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/img_util.py
|
| 172 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/lmdb_util.py
|
| 173 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/logger.py
|
| 174 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/matlab_functions.py
|
| 175 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/misc.py
|
| 176 |
+
agent_tools/Retinexformer/basicsr_retinexformer/utils/options.py
|
| 177 |
+
agent_tools/S2Former/UDR_S2Former.py
|
| 178 |
+
agent_tools/S2Former/base_net_snow.py
|
| 179 |
+
agent_tools/S2Former/condconv.py
|
| 180 |
+
agent_tools/S2Former/inference.py
|
| 181 |
+
agent_tools/SCUNet/__init__.py
|
| 182 |
+
agent_tools/SCUNet/inference.py
|
| 183 |
+
agent_tools/SCUNet/models/network_scunet.py
|
| 184 |
+
agent_tools/SCUNet/utils/utils_image.py
|
| 185 |
+
agent_tools/SnowMaster/inference.py
|
| 186 |
+
agent_tools/SnowMaster/nafnet.py
|
| 187 |
+
agent_tools/SnowMaster/nafnet_utils.py
|
| 188 |
+
agent_tools/img2img_turbo/__init__.py
|
| 189 |
+
agent_tools/img2img_turbo/inference.py
|
| 190 |
+
agent_tools/img2img_turbo/src/cyclegan_turbo.py
|
| 191 |
+
agent_tools/img2img_turbo/src/model.py
|
| 192 |
+
agent_tools/img2img_turbo/src/my_utils/training_utils.py
|
JarvisIR/package/agent_tools.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
JarvisIR/package/agent_tools.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
agent_tools
|
JarvisIR/package/agent_tools/ESRGAN/__init__.py
ADDED
|
File without changes
|
JarvisIR/package/agent_tools/ESRGAN/inference.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import os
|
| 3 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 4 |
+
|
| 5 |
+
from .realesrgan import RealESRGANer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_esrgan_model(model_path, device):
|
| 9 |
+
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
| 10 |
+
netscale = 4
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# use dni to control the denoise strength
|
| 14 |
+
dni_weight = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# restorer
|
| 18 |
+
upsampler = RealESRGANer(
|
| 19 |
+
scale=netscale,
|
| 20 |
+
model_path=model_path,
|
| 21 |
+
dni_weight=dni_weight,
|
| 22 |
+
model=model,
|
| 23 |
+
tile=0,
|
| 24 |
+
tile_pad=10,
|
| 25 |
+
pre_pad=0,
|
| 26 |
+
half=True,
|
| 27 |
+
device=device)
|
| 28 |
+
return upsampler
|
| 29 |
+
|
| 30 |
+
def esrgan_predict(upsampler, input_image, output_dir, device,):
|
| 31 |
+
|
| 32 |
+
# determine models according to model names
|
| 33 |
+
outscale = 4 # the final upsampling scale
|
| 34 |
+
imgname, extension = os.path.splitext(os.path.basename(input_image))
|
| 35 |
+
|
| 36 |
+
img = cv2.imread(input_image, cv2.IMREAD_COLOR)
|
| 37 |
+
h, w, _ = img.shape
|
| 38 |
+
try:
|
| 39 |
+
output, _ = upsampler.enhance(img, outscale=outscale)
|
| 40 |
+
except RuntimeError as error:
|
| 41 |
+
print('Error', error)
|
| 42 |
+
print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
|
| 43 |
+
# resize back to the original resolution
|
| 44 |
+
output = cv2.resize(output, (w, h), interpolation=cv2.INTER_CUBIC)
|
| 45 |
+
|
| 46 |
+
save_path = os.path.join(output_dir, f'{imgname}.png')
|
| 47 |
+
cv2.imwrite(save_path, output)
|
| 48 |
+
|
| 49 |
+
return save_path
|
JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus.yml
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: finetune_RealESRGANx4plus_400k
|
| 3 |
+
model_type: RealESRGANModel
|
| 4 |
+
scale: 4
|
| 5 |
+
num_gpu: auto
|
| 6 |
+
manual_seed: 0
|
| 7 |
+
|
| 8 |
+
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
| 9 |
+
# USM the ground-truth
|
| 10 |
+
l1_gt_usm: True
|
| 11 |
+
percep_gt_usm: True
|
| 12 |
+
gan_gt_usm: False
|
| 13 |
+
|
| 14 |
+
# the first degradation process
|
| 15 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
| 16 |
+
resize_range: [0.15, 1.5]
|
| 17 |
+
gaussian_noise_prob: 0.5
|
| 18 |
+
noise_range: [1, 30]
|
| 19 |
+
poisson_scale_range: [0.05, 3]
|
| 20 |
+
gray_noise_prob: 0.4
|
| 21 |
+
jpeg_range: [30, 95]
|
| 22 |
+
|
| 23 |
+
# the second degradation process
|
| 24 |
+
second_blur_prob: 0.8
|
| 25 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
| 26 |
+
resize_range2: [0.3, 1.2]
|
| 27 |
+
gaussian_noise_prob2: 0.5
|
| 28 |
+
noise_range2: [1, 25]
|
| 29 |
+
poisson_scale_range2: [0.05, 2.5]
|
| 30 |
+
gray_noise_prob2: 0.4
|
| 31 |
+
jpeg_range2: [30, 95]
|
| 32 |
+
|
| 33 |
+
gt_size: 256
|
| 34 |
+
queue_size: 180
|
| 35 |
+
|
| 36 |
+
# dataset and data loader settings
|
| 37 |
+
datasets:
|
| 38 |
+
train:
|
| 39 |
+
name: DF2K+OST
|
| 40 |
+
type: RealESRGANDataset
|
| 41 |
+
dataroot_gt: datasets/DF2K
|
| 42 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
| 43 |
+
io_backend:
|
| 44 |
+
type: disk
|
| 45 |
+
|
| 46 |
+
blur_kernel_size: 21
|
| 47 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 48 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 49 |
+
sinc_prob: 0.1
|
| 50 |
+
blur_sigma: [0.2, 3]
|
| 51 |
+
betag_range: [0.5, 4]
|
| 52 |
+
betap_range: [1, 2]
|
| 53 |
+
|
| 54 |
+
blur_kernel_size2: 21
|
| 55 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 56 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 57 |
+
sinc_prob2: 0.1
|
| 58 |
+
blur_sigma2: [0.2, 1.5]
|
| 59 |
+
betag_range2: [0.5, 4]
|
| 60 |
+
betap_range2: [1, 2]
|
| 61 |
+
|
| 62 |
+
final_sinc_prob: 0.8
|
| 63 |
+
|
| 64 |
+
gt_size: 256
|
| 65 |
+
use_hflip: True
|
| 66 |
+
use_rot: False
|
| 67 |
+
|
| 68 |
+
# data loader
|
| 69 |
+
use_shuffle: true
|
| 70 |
+
num_worker_per_gpu: 5
|
| 71 |
+
batch_size_per_gpu: 12
|
| 72 |
+
dataset_enlarge_ratio: 1
|
| 73 |
+
prefetch_mode: ~
|
| 74 |
+
|
| 75 |
+
# Uncomment these for validation
|
| 76 |
+
# val:
|
| 77 |
+
# name: validation
|
| 78 |
+
# type: PairedImageDataset
|
| 79 |
+
# dataroot_gt: path_to_gt
|
| 80 |
+
# dataroot_lq: path_to_lq
|
| 81 |
+
# io_backend:
|
| 82 |
+
# type: disk
|
| 83 |
+
|
| 84 |
+
# network structures
|
| 85 |
+
network_g:
|
| 86 |
+
type: RRDBNet
|
| 87 |
+
num_in_ch: 3
|
| 88 |
+
num_out_ch: 3
|
| 89 |
+
num_feat: 64
|
| 90 |
+
num_block: 23
|
| 91 |
+
num_grow_ch: 32
|
| 92 |
+
|
| 93 |
+
network_d:
|
| 94 |
+
type: UNetDiscriminatorSN
|
| 95 |
+
num_in_ch: 3
|
| 96 |
+
num_feat: 64
|
| 97 |
+
skip_connection: True
|
| 98 |
+
|
| 99 |
+
# path
|
| 100 |
+
path:
|
| 101 |
+
# use the pre-trained Real-ESRNet model
|
| 102 |
+
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
|
| 103 |
+
param_key_g: params_ema
|
| 104 |
+
strict_load_g: true
|
| 105 |
+
pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
|
| 106 |
+
param_key_d: params
|
| 107 |
+
strict_load_d: true
|
| 108 |
+
resume_state: ~
|
| 109 |
+
|
| 110 |
+
# training settings
|
| 111 |
+
train:
|
| 112 |
+
ema_decay: 0.999
|
| 113 |
+
optim_g:
|
| 114 |
+
type: Adam
|
| 115 |
+
lr: !!float 1e-4
|
| 116 |
+
weight_decay: 0
|
| 117 |
+
betas: [0.9, 0.99]
|
| 118 |
+
optim_d:
|
| 119 |
+
type: Adam
|
| 120 |
+
lr: !!float 1e-4
|
| 121 |
+
weight_decay: 0
|
| 122 |
+
betas: [0.9, 0.99]
|
| 123 |
+
|
| 124 |
+
scheduler:
|
| 125 |
+
type: MultiStepLR
|
| 126 |
+
milestones: [400000]
|
| 127 |
+
gamma: 0.5
|
| 128 |
+
|
| 129 |
+
total_iter: 400000
|
| 130 |
+
warmup_iter: -1 # no warm up
|
| 131 |
+
|
| 132 |
+
# losses
|
| 133 |
+
pixel_opt:
|
| 134 |
+
type: L1Loss
|
| 135 |
+
loss_weight: 1.0
|
| 136 |
+
reduction: mean
|
| 137 |
+
# perceptual loss (content and style losses)
|
| 138 |
+
perceptual_opt:
|
| 139 |
+
type: PerceptualLoss
|
| 140 |
+
layer_weights:
|
| 141 |
+
# before relu
|
| 142 |
+
'conv1_2': 0.1
|
| 143 |
+
'conv2_2': 0.1
|
| 144 |
+
'conv3_4': 1
|
| 145 |
+
'conv4_4': 1
|
| 146 |
+
'conv5_4': 1
|
| 147 |
+
vgg_type: vgg19
|
| 148 |
+
use_input_norm: true
|
| 149 |
+
perceptual_weight: !!float 1.0
|
| 150 |
+
style_weight: 0
|
| 151 |
+
range_norm: false
|
| 152 |
+
criterion: l1
|
| 153 |
+
# gan loss
|
| 154 |
+
gan_opt:
|
| 155 |
+
type: GANLoss
|
| 156 |
+
gan_type: vanilla
|
| 157 |
+
real_label_val: 1.0
|
| 158 |
+
fake_label_val: 0.0
|
| 159 |
+
loss_weight: !!float 1e-1
|
| 160 |
+
|
| 161 |
+
net_d_iters: 1
|
| 162 |
+
net_d_init_iters: 0
|
| 163 |
+
|
| 164 |
+
# Uncomment these for validation
|
| 165 |
+
# validation settings
|
| 166 |
+
# val:
|
| 167 |
+
# val_freq: !!float 5e3
|
| 168 |
+
# save_img: True
|
| 169 |
+
|
| 170 |
+
# metrics:
|
| 171 |
+
# psnr: # metric name
|
| 172 |
+
# type: calculate_psnr
|
| 173 |
+
# crop_border: 4
|
| 174 |
+
# test_y_channel: false
|
| 175 |
+
|
| 176 |
+
# logging settings
|
| 177 |
+
logger:
|
| 178 |
+
print_freq: 100
|
| 179 |
+
save_checkpoint_freq: !!float 5e3
|
| 180 |
+
use_tb_logger: true
|
| 181 |
+
wandb:
|
| 182 |
+
project: ~
|
| 183 |
+
resume_id: ~
|
| 184 |
+
|
| 185 |
+
# dist training settings
|
| 186 |
+
dist_params:
|
| 187 |
+
backend: nccl
|
| 188 |
+
port: 29500
|
JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus_pairdata.yml
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: finetune_RealESRGANx4plus_400k_pairdata
|
| 3 |
+
model_type: RealESRGANModel
|
| 4 |
+
scale: 4
|
| 5 |
+
num_gpu: auto
|
| 6 |
+
manual_seed: 0
|
| 7 |
+
|
| 8 |
+
# USM the ground-truth
|
| 9 |
+
l1_gt_usm: True
|
| 10 |
+
percep_gt_usm: True
|
| 11 |
+
gan_gt_usm: False
|
| 12 |
+
|
| 13 |
+
high_order_degradation: False # do not use the high-order degradation generation process
|
| 14 |
+
|
| 15 |
+
# dataset and data loader settings
|
| 16 |
+
datasets:
|
| 17 |
+
train:
|
| 18 |
+
name: DIV2K
|
| 19 |
+
type: RealESRGANPairedDataset
|
| 20 |
+
dataroot_gt: datasets/DF2K
|
| 21 |
+
dataroot_lq: datasets/DF2K
|
| 22 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
|
| 23 |
+
io_backend:
|
| 24 |
+
type: disk
|
| 25 |
+
|
| 26 |
+
gt_size: 256
|
| 27 |
+
use_hflip: True
|
| 28 |
+
use_rot: False
|
| 29 |
+
|
| 30 |
+
# data loader
|
| 31 |
+
use_shuffle: true
|
| 32 |
+
num_worker_per_gpu: 5
|
| 33 |
+
batch_size_per_gpu: 12
|
| 34 |
+
dataset_enlarge_ratio: 1
|
| 35 |
+
prefetch_mode: ~
|
| 36 |
+
|
| 37 |
+
# Uncomment these for validation
|
| 38 |
+
# val:
|
| 39 |
+
# name: validation
|
| 40 |
+
# type: PairedImageDataset
|
| 41 |
+
# dataroot_gt: path_to_gt
|
| 42 |
+
# dataroot_lq: path_to_lq
|
| 43 |
+
# io_backend:
|
| 44 |
+
# type: disk
|
| 45 |
+
|
| 46 |
+
# network structures
|
| 47 |
+
network_g:
|
| 48 |
+
type: RRDBNet
|
| 49 |
+
num_in_ch: 3
|
| 50 |
+
num_out_ch: 3
|
| 51 |
+
num_feat: 64
|
| 52 |
+
num_block: 23
|
| 53 |
+
num_grow_ch: 32
|
| 54 |
+
|
| 55 |
+
network_d:
|
| 56 |
+
type: UNetDiscriminatorSN
|
| 57 |
+
num_in_ch: 3
|
| 58 |
+
num_feat: 64
|
| 59 |
+
skip_connection: True
|
| 60 |
+
|
| 61 |
+
# path
|
| 62 |
+
path:
|
| 63 |
+
# use the pre-trained Real-ESRNet model
|
| 64 |
+
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
|
| 65 |
+
param_key_g: params_ema
|
| 66 |
+
strict_load_g: true
|
| 67 |
+
pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
|
| 68 |
+
param_key_d: params
|
| 69 |
+
strict_load_d: true
|
| 70 |
+
resume_state: ~
|
| 71 |
+
|
| 72 |
+
# training settings
|
| 73 |
+
train:
|
| 74 |
+
ema_decay: 0.999
|
| 75 |
+
optim_g:
|
| 76 |
+
type: Adam
|
| 77 |
+
lr: !!float 1e-4
|
| 78 |
+
weight_decay: 0
|
| 79 |
+
betas: [0.9, 0.99]
|
| 80 |
+
optim_d:
|
| 81 |
+
type: Adam
|
| 82 |
+
lr: !!float 1e-4
|
| 83 |
+
weight_decay: 0
|
| 84 |
+
betas: [0.9, 0.99]
|
| 85 |
+
|
| 86 |
+
scheduler:
|
| 87 |
+
type: MultiStepLR
|
| 88 |
+
milestones: [400000]
|
| 89 |
+
gamma: 0.5
|
| 90 |
+
|
| 91 |
+
total_iter: 400000
|
| 92 |
+
warmup_iter: -1 # no warm up
|
| 93 |
+
|
| 94 |
+
# losses
|
| 95 |
+
pixel_opt:
|
| 96 |
+
type: L1Loss
|
| 97 |
+
loss_weight: 1.0
|
| 98 |
+
reduction: mean
|
| 99 |
+
# perceptual loss (content and style losses)
|
| 100 |
+
perceptual_opt:
|
| 101 |
+
type: PerceptualLoss
|
| 102 |
+
layer_weights:
|
| 103 |
+
# before relu
|
| 104 |
+
'conv1_2': 0.1
|
| 105 |
+
'conv2_2': 0.1
|
| 106 |
+
'conv3_4': 1
|
| 107 |
+
'conv4_4': 1
|
| 108 |
+
'conv5_4': 1
|
| 109 |
+
vgg_type: vgg19
|
| 110 |
+
use_input_norm: true
|
| 111 |
+
perceptual_weight: !!float 1.0
|
| 112 |
+
style_weight: 0
|
| 113 |
+
range_norm: false
|
| 114 |
+
criterion: l1
|
| 115 |
+
# gan loss
|
| 116 |
+
gan_opt:
|
| 117 |
+
type: GANLoss
|
| 118 |
+
gan_type: vanilla
|
| 119 |
+
real_label_val: 1.0
|
| 120 |
+
fake_label_val: 0.0
|
| 121 |
+
loss_weight: !!float 1e-1
|
| 122 |
+
|
| 123 |
+
net_d_iters: 1
|
| 124 |
+
net_d_init_iters: 0
|
| 125 |
+
|
| 126 |
+
# Uncomment these for validation
|
| 127 |
+
# validation settings
|
| 128 |
+
# val:
|
| 129 |
+
# val_freq: !!float 5e3
|
| 130 |
+
# save_img: True
|
| 131 |
+
|
| 132 |
+
# metrics:
|
| 133 |
+
# psnr: # metric name
|
| 134 |
+
# type: calculate_psnr
|
| 135 |
+
# crop_border: 4
|
| 136 |
+
# test_y_channel: false
|
| 137 |
+
|
| 138 |
+
# logging settings
|
| 139 |
+
logger:
|
| 140 |
+
print_freq: 100
|
| 141 |
+
save_checkpoint_freq: !!float 5e3
|
| 142 |
+
use_tb_logger: true
|
| 143 |
+
wandb:
|
| 144 |
+
project: ~
|
| 145 |
+
resume_id: ~
|
| 146 |
+
|
| 147 |
+
# dist training settings
|
| 148 |
+
dist_params:
|
| 149 |
+
backend: nccl
|
| 150 |
+
port: 29500
|
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x2plus.yml
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: train_RealESRGANx2plus_400k_B12G4
|
| 3 |
+
model_type: RealESRGANModel
|
| 4 |
+
scale: 2
|
| 5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
| 6 |
+
manual_seed: 0
|
| 7 |
+
|
| 8 |
+
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
| 9 |
+
# USM the ground-truth
|
| 10 |
+
l1_gt_usm: True
|
| 11 |
+
percep_gt_usm: True
|
| 12 |
+
gan_gt_usm: False
|
| 13 |
+
|
| 14 |
+
# the first degradation process
|
| 15 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
| 16 |
+
resize_range: [0.15, 1.5]
|
| 17 |
+
gaussian_noise_prob: 0.5
|
| 18 |
+
noise_range: [1, 30]
|
| 19 |
+
poisson_scale_range: [0.05, 3]
|
| 20 |
+
gray_noise_prob: 0.4
|
| 21 |
+
jpeg_range: [30, 95]
|
| 22 |
+
|
| 23 |
+
# the second degradation process
|
| 24 |
+
second_blur_prob: 0.8
|
| 25 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
| 26 |
+
resize_range2: [0.3, 1.2]
|
| 27 |
+
gaussian_noise_prob2: 0.5
|
| 28 |
+
noise_range2: [1, 25]
|
| 29 |
+
poisson_scale_range2: [0.05, 2.5]
|
| 30 |
+
gray_noise_prob2: 0.4
|
| 31 |
+
jpeg_range2: [30, 95]
|
| 32 |
+
|
| 33 |
+
gt_size: 256
|
| 34 |
+
queue_size: 180
|
| 35 |
+
|
| 36 |
+
# dataset and data loader settings
|
| 37 |
+
datasets:
|
| 38 |
+
train:
|
| 39 |
+
name: DF2K+OST
|
| 40 |
+
type: RealESRGANDataset
|
| 41 |
+
dataroot_gt: datasets/DF2K
|
| 42 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
| 43 |
+
io_backend:
|
| 44 |
+
type: disk
|
| 45 |
+
|
| 46 |
+
blur_kernel_size: 21
|
| 47 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 48 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 49 |
+
sinc_prob: 0.1
|
| 50 |
+
blur_sigma: [0.2, 3]
|
| 51 |
+
betag_range: [0.5, 4]
|
| 52 |
+
betap_range: [1, 2]
|
| 53 |
+
|
| 54 |
+
blur_kernel_size2: 21
|
| 55 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 56 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 57 |
+
sinc_prob2: 0.1
|
| 58 |
+
blur_sigma2: [0.2, 1.5]
|
| 59 |
+
betag_range2: [0.5, 4]
|
| 60 |
+
betap_range2: [1, 2]
|
| 61 |
+
|
| 62 |
+
final_sinc_prob: 0.8
|
| 63 |
+
|
| 64 |
+
gt_size: 256
|
| 65 |
+
use_hflip: True
|
| 66 |
+
use_rot: False
|
| 67 |
+
|
| 68 |
+
# data loader
|
| 69 |
+
use_shuffle: true
|
| 70 |
+
num_worker_per_gpu: 5
|
| 71 |
+
batch_size_per_gpu: 12
|
| 72 |
+
dataset_enlarge_ratio: 1
|
| 73 |
+
prefetch_mode: ~
|
| 74 |
+
|
| 75 |
+
# Uncomment these for validation
|
| 76 |
+
# val:
|
| 77 |
+
# name: validation
|
| 78 |
+
# type: PairedImageDataset
|
| 79 |
+
# dataroot_gt: path_to_gt
|
| 80 |
+
# dataroot_lq: path_to_lq
|
| 81 |
+
# io_backend:
|
| 82 |
+
# type: disk
|
| 83 |
+
|
| 84 |
+
# network structures
|
| 85 |
+
network_g:
|
| 86 |
+
type: RRDBNet
|
| 87 |
+
num_in_ch: 3
|
| 88 |
+
num_out_ch: 3
|
| 89 |
+
num_feat: 64
|
| 90 |
+
num_block: 23
|
| 91 |
+
num_grow_ch: 32
|
| 92 |
+
scale: 2
|
| 93 |
+
|
| 94 |
+
network_d:
|
| 95 |
+
type: UNetDiscriminatorSN
|
| 96 |
+
num_in_ch: 3
|
| 97 |
+
num_feat: 64
|
| 98 |
+
skip_connection: True
|
| 99 |
+
|
| 100 |
+
# path
|
| 101 |
+
path:
|
| 102 |
+
# use the pre-trained Real-ESRNet model
|
| 103 |
+
pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
|
| 104 |
+
param_key_g: params_ema
|
| 105 |
+
strict_load_g: true
|
| 106 |
+
resume_state: ~
|
| 107 |
+
|
| 108 |
+
# training settings
|
| 109 |
+
train:
|
| 110 |
+
ema_decay: 0.999
|
| 111 |
+
optim_g:
|
| 112 |
+
type: Adam
|
| 113 |
+
lr: !!float 1e-4
|
| 114 |
+
weight_decay: 0
|
| 115 |
+
betas: [0.9, 0.99]
|
| 116 |
+
optim_d:
|
| 117 |
+
type: Adam
|
| 118 |
+
lr: !!float 1e-4
|
| 119 |
+
weight_decay: 0
|
| 120 |
+
betas: [0.9, 0.99]
|
| 121 |
+
|
| 122 |
+
scheduler:
|
| 123 |
+
type: MultiStepLR
|
| 124 |
+
milestones: [400000]
|
| 125 |
+
gamma: 0.5
|
| 126 |
+
|
| 127 |
+
total_iter: 400000
|
| 128 |
+
warmup_iter: -1 # no warm up
|
| 129 |
+
|
| 130 |
+
# losses
|
| 131 |
+
pixel_opt:
|
| 132 |
+
type: L1Loss
|
| 133 |
+
loss_weight: 1.0
|
| 134 |
+
reduction: mean
|
| 135 |
+
# perceptual loss (content and style losses)
|
| 136 |
+
perceptual_opt:
|
| 137 |
+
type: PerceptualLoss
|
| 138 |
+
layer_weights:
|
| 139 |
+
# before relu
|
| 140 |
+
'conv1_2': 0.1
|
| 141 |
+
'conv2_2': 0.1
|
| 142 |
+
'conv3_4': 1
|
| 143 |
+
'conv4_4': 1
|
| 144 |
+
'conv5_4': 1
|
| 145 |
+
vgg_type: vgg19
|
| 146 |
+
use_input_norm: true
|
| 147 |
+
perceptual_weight: !!float 1.0
|
| 148 |
+
style_weight: 0
|
| 149 |
+
range_norm: false
|
| 150 |
+
criterion: l1
|
| 151 |
+
# gan loss
|
| 152 |
+
gan_opt:
|
| 153 |
+
type: GANLoss
|
| 154 |
+
gan_type: vanilla
|
| 155 |
+
real_label_val: 1.0
|
| 156 |
+
fake_label_val: 0.0
|
| 157 |
+
loss_weight: !!float 1e-1
|
| 158 |
+
|
| 159 |
+
net_d_iters: 1
|
| 160 |
+
net_d_init_iters: 0
|
| 161 |
+
|
| 162 |
+
# Uncomment these for validation
|
| 163 |
+
# validation settings
|
| 164 |
+
# val:
|
| 165 |
+
# val_freq: !!float 5e3
|
| 166 |
+
# save_img: True
|
| 167 |
+
|
| 168 |
+
# metrics:
|
| 169 |
+
# psnr: # metric name
|
| 170 |
+
# type: calculate_psnr
|
| 171 |
+
# crop_border: 4
|
| 172 |
+
# test_y_channel: false
|
| 173 |
+
|
| 174 |
+
# logging settings
|
| 175 |
+
logger:
|
| 176 |
+
print_freq: 100
|
| 177 |
+
save_checkpoint_freq: !!float 5e3
|
| 178 |
+
use_tb_logger: true
|
| 179 |
+
wandb:
|
| 180 |
+
project: ~
|
| 181 |
+
resume_id: ~
|
| 182 |
+
|
| 183 |
+
# dist training settings
|
| 184 |
+
dist_params:
|
| 185 |
+
backend: nccl
|
| 186 |
+
port: 29500
|
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x4plus.yml
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: train_RealESRGANx4plus_400k_B12G4
|
| 3 |
+
model_type: RealESRGANModel
|
| 4 |
+
scale: 4
|
| 5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
| 6 |
+
manual_seed: 0
|
| 7 |
+
|
| 8 |
+
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
| 9 |
+
# USM the ground-truth
|
| 10 |
+
l1_gt_usm: True
|
| 11 |
+
percep_gt_usm: True
|
| 12 |
+
gan_gt_usm: False
|
| 13 |
+
|
| 14 |
+
# the first degradation process
|
| 15 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
| 16 |
+
resize_range: [0.15, 1.5]
|
| 17 |
+
gaussian_noise_prob: 0.5
|
| 18 |
+
noise_range: [1, 30]
|
| 19 |
+
poisson_scale_range: [0.05, 3]
|
| 20 |
+
gray_noise_prob: 0.4
|
| 21 |
+
jpeg_range: [30, 95]
|
| 22 |
+
|
| 23 |
+
# the second degradation process
|
| 24 |
+
second_blur_prob: 0.8
|
| 25 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
| 26 |
+
resize_range2: [0.3, 1.2]
|
| 27 |
+
gaussian_noise_prob2: 0.5
|
| 28 |
+
noise_range2: [1, 25]
|
| 29 |
+
poisson_scale_range2: [0.05, 2.5]
|
| 30 |
+
gray_noise_prob2: 0.4
|
| 31 |
+
jpeg_range2: [30, 95]
|
| 32 |
+
|
| 33 |
+
gt_size: 256
|
| 34 |
+
queue_size: 180
|
| 35 |
+
|
| 36 |
+
# dataset and data loader settings
|
| 37 |
+
datasets:
|
| 38 |
+
train:
|
| 39 |
+
name: DF2K+OST
|
| 40 |
+
type: RealESRGANDataset
|
| 41 |
+
dataroot_gt: datasets/DF2K
|
| 42 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
| 43 |
+
io_backend:
|
| 44 |
+
type: disk
|
| 45 |
+
|
| 46 |
+
blur_kernel_size: 21
|
| 47 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 48 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 49 |
+
sinc_prob: 0.1
|
| 50 |
+
blur_sigma: [0.2, 3]
|
| 51 |
+
betag_range: [0.5, 4]
|
| 52 |
+
betap_range: [1, 2]
|
| 53 |
+
|
| 54 |
+
blur_kernel_size2: 21
|
| 55 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 56 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 57 |
+
sinc_prob2: 0.1
|
| 58 |
+
blur_sigma2: [0.2, 1.5]
|
| 59 |
+
betag_range2: [0.5, 4]
|
| 60 |
+
betap_range2: [1, 2]
|
| 61 |
+
|
| 62 |
+
final_sinc_prob: 0.8
|
| 63 |
+
|
| 64 |
+
gt_size: 256
|
| 65 |
+
use_hflip: True
|
| 66 |
+
use_rot: False
|
| 67 |
+
|
| 68 |
+
# data loader
|
| 69 |
+
use_shuffle: true
|
| 70 |
+
num_worker_per_gpu: 5
|
| 71 |
+
batch_size_per_gpu: 12
|
| 72 |
+
dataset_enlarge_ratio: 1
|
| 73 |
+
prefetch_mode: ~
|
| 74 |
+
|
| 75 |
+
# Uncomment these for validation
|
| 76 |
+
# val:
|
| 77 |
+
# name: validation
|
| 78 |
+
# type: PairedImageDataset
|
| 79 |
+
# dataroot_gt: path_to_gt
|
| 80 |
+
# dataroot_lq: path_to_lq
|
| 81 |
+
# io_backend:
|
| 82 |
+
# type: disk
|
| 83 |
+
|
| 84 |
+
# network structures
|
| 85 |
+
network_g:
|
| 86 |
+
type: RRDBNet
|
| 87 |
+
num_in_ch: 3
|
| 88 |
+
num_out_ch: 3
|
| 89 |
+
num_feat: 64
|
| 90 |
+
num_block: 23
|
| 91 |
+
num_grow_ch: 32
|
| 92 |
+
|
| 93 |
+
network_d:
|
| 94 |
+
type: UNetDiscriminatorSN
|
| 95 |
+
num_in_ch: 3
|
| 96 |
+
num_feat: 64
|
| 97 |
+
skip_connection: True
|
| 98 |
+
|
| 99 |
+
# path
|
| 100 |
+
path:
|
| 101 |
+
# use the pre-trained Real-ESRNet model
|
| 102 |
+
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
|
| 103 |
+
param_key_g: params_ema
|
| 104 |
+
strict_load_g: true
|
| 105 |
+
resume_state: ~
|
| 106 |
+
|
| 107 |
+
# training settings
|
| 108 |
+
train:
|
| 109 |
+
ema_decay: 0.999
|
| 110 |
+
optim_g:
|
| 111 |
+
type: Adam
|
| 112 |
+
lr: !!float 1e-4
|
| 113 |
+
weight_decay: 0
|
| 114 |
+
betas: [0.9, 0.99]
|
| 115 |
+
optim_d:
|
| 116 |
+
type: Adam
|
| 117 |
+
lr: !!float 1e-4
|
| 118 |
+
weight_decay: 0
|
| 119 |
+
betas: [0.9, 0.99]
|
| 120 |
+
|
| 121 |
+
scheduler:
|
| 122 |
+
type: MultiStepLR
|
| 123 |
+
milestones: [400000]
|
| 124 |
+
gamma: 0.5
|
| 125 |
+
|
| 126 |
+
total_iter: 400000
|
| 127 |
+
warmup_iter: -1 # no warm up
|
| 128 |
+
|
| 129 |
+
# losses
|
| 130 |
+
pixel_opt:
|
| 131 |
+
type: L1Loss
|
| 132 |
+
loss_weight: 1.0
|
| 133 |
+
reduction: mean
|
| 134 |
+
# perceptual loss (content and style losses)
|
| 135 |
+
perceptual_opt:
|
| 136 |
+
type: PerceptualLoss
|
| 137 |
+
layer_weights:
|
| 138 |
+
# before relu
|
| 139 |
+
'conv1_2': 0.1
|
| 140 |
+
'conv2_2': 0.1
|
| 141 |
+
'conv3_4': 1
|
| 142 |
+
'conv4_4': 1
|
| 143 |
+
'conv5_4': 1
|
| 144 |
+
vgg_type: vgg19
|
| 145 |
+
use_input_norm: true
|
| 146 |
+
perceptual_weight: !!float 1.0
|
| 147 |
+
style_weight: 0
|
| 148 |
+
range_norm: false
|
| 149 |
+
criterion: l1
|
| 150 |
+
# gan loss
|
| 151 |
+
gan_opt:
|
| 152 |
+
type: GANLoss
|
| 153 |
+
gan_type: vanilla
|
| 154 |
+
real_label_val: 1.0
|
| 155 |
+
fake_label_val: 0.0
|
| 156 |
+
loss_weight: !!float 1e-1
|
| 157 |
+
|
| 158 |
+
net_d_iters: 1
|
| 159 |
+
net_d_init_iters: 0
|
| 160 |
+
|
| 161 |
+
# Uncomment these for validation
|
| 162 |
+
# validation settings
|
| 163 |
+
# val:
|
| 164 |
+
# val_freq: !!float 5e3
|
| 165 |
+
# save_img: True
|
| 166 |
+
|
| 167 |
+
# metrics:
|
| 168 |
+
# psnr: # metric name
|
| 169 |
+
# type: calculate_psnr
|
| 170 |
+
# crop_border: 4
|
| 171 |
+
# test_y_channel: false
|
| 172 |
+
|
| 173 |
+
# logging settings
|
| 174 |
+
logger:
|
| 175 |
+
print_freq: 100
|
| 176 |
+
save_checkpoint_freq: !!float 5e3
|
| 177 |
+
use_tb_logger: true
|
| 178 |
+
wandb:
|
| 179 |
+
project: ~
|
| 180 |
+
resume_id: ~
|
| 181 |
+
|
| 182 |
+
# dist training settings
|
| 183 |
+
dist_params:
|
| 184 |
+
backend: nccl
|
| 185 |
+
port: 29500
|
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x2plus.yml
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: train_RealESRNetx2plus_1000k_B12G4
|
| 3 |
+
model_type: RealESRNetModel
|
| 4 |
+
scale: 2
|
| 5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
| 6 |
+
manual_seed: 0
|
| 7 |
+
|
| 8 |
+
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
| 9 |
+
gt_usm: True # USM the ground-truth
|
| 10 |
+
|
| 11 |
+
# the first degradation process
|
| 12 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
| 13 |
+
resize_range: [0.15, 1.5]
|
| 14 |
+
gaussian_noise_prob: 0.5
|
| 15 |
+
noise_range: [1, 30]
|
| 16 |
+
poisson_scale_range: [0.05, 3]
|
| 17 |
+
gray_noise_prob: 0.4
|
| 18 |
+
jpeg_range: [30, 95]
|
| 19 |
+
|
| 20 |
+
# the second degradation process
|
| 21 |
+
second_blur_prob: 0.8
|
| 22 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
| 23 |
+
resize_range2: [0.3, 1.2]
|
| 24 |
+
gaussian_noise_prob2: 0.5
|
| 25 |
+
noise_range2: [1, 25]
|
| 26 |
+
poisson_scale_range2: [0.05, 2.5]
|
| 27 |
+
gray_noise_prob2: 0.4
|
| 28 |
+
jpeg_range2: [30, 95]
|
| 29 |
+
|
| 30 |
+
gt_size: 256
|
| 31 |
+
queue_size: 180
|
| 32 |
+
|
| 33 |
+
# dataset and data loader settings
|
| 34 |
+
datasets:
|
| 35 |
+
train:
|
| 36 |
+
name: DF2K+OST
|
| 37 |
+
type: RealESRGANDataset
|
| 38 |
+
dataroot_gt: datasets/DF2K
|
| 39 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
| 40 |
+
io_backend:
|
| 41 |
+
type: disk
|
| 42 |
+
|
| 43 |
+
blur_kernel_size: 21
|
| 44 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 45 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 46 |
+
sinc_prob: 0.1
|
| 47 |
+
blur_sigma: [0.2, 3]
|
| 48 |
+
betag_range: [0.5, 4]
|
| 49 |
+
betap_range: [1, 2]
|
| 50 |
+
|
| 51 |
+
blur_kernel_size2: 21
|
| 52 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 53 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 54 |
+
sinc_prob2: 0.1
|
| 55 |
+
blur_sigma2: [0.2, 1.5]
|
| 56 |
+
betag_range2: [0.5, 4]
|
| 57 |
+
betap_range2: [1, 2]
|
| 58 |
+
|
| 59 |
+
final_sinc_prob: 0.8
|
| 60 |
+
|
| 61 |
+
gt_size: 256
|
| 62 |
+
use_hflip: True
|
| 63 |
+
use_rot: False
|
| 64 |
+
|
| 65 |
+
# data loader
|
| 66 |
+
use_shuffle: true
|
| 67 |
+
num_worker_per_gpu: 5
|
| 68 |
+
batch_size_per_gpu: 12
|
| 69 |
+
dataset_enlarge_ratio: 1
|
| 70 |
+
prefetch_mode: ~
|
| 71 |
+
|
| 72 |
+
# Uncomment these for validation
|
| 73 |
+
# val:
|
| 74 |
+
# name: validation
|
| 75 |
+
# type: PairedImageDataset
|
| 76 |
+
# dataroot_gt: path_to_gt
|
| 77 |
+
# dataroot_lq: path_to_lq
|
| 78 |
+
# io_backend:
|
| 79 |
+
# type: disk
|
| 80 |
+
|
| 81 |
+
# network structures
|
| 82 |
+
network_g:
|
| 83 |
+
type: RRDBNet
|
| 84 |
+
num_in_ch: 3
|
| 85 |
+
num_out_ch: 3
|
| 86 |
+
num_feat: 64
|
| 87 |
+
num_block: 23
|
| 88 |
+
num_grow_ch: 32
|
| 89 |
+
scale: 2
|
| 90 |
+
|
| 91 |
+
# path
|
| 92 |
+
path:
|
| 93 |
+
pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
|
| 94 |
+
param_key_g: params_ema
|
| 95 |
+
strict_load_g: False
|
| 96 |
+
resume_state: ~
|
| 97 |
+
|
| 98 |
+
# training settings
|
| 99 |
+
train:
|
| 100 |
+
ema_decay: 0.999
|
| 101 |
+
optim_g:
|
| 102 |
+
type: Adam
|
| 103 |
+
lr: !!float 2e-4
|
| 104 |
+
weight_decay: 0
|
| 105 |
+
betas: [0.9, 0.99]
|
| 106 |
+
|
| 107 |
+
scheduler:
|
| 108 |
+
type: MultiStepLR
|
| 109 |
+
milestones: [1000000]
|
| 110 |
+
gamma: 0.5
|
| 111 |
+
|
| 112 |
+
total_iter: 1000000
|
| 113 |
+
warmup_iter: -1 # no warm up
|
| 114 |
+
|
| 115 |
+
# losses
|
| 116 |
+
pixel_opt:
|
| 117 |
+
type: L1Loss
|
| 118 |
+
loss_weight: 1.0
|
| 119 |
+
reduction: mean
|
| 120 |
+
|
| 121 |
+
# Uncomment these for validation
|
| 122 |
+
# validation settings
|
| 123 |
+
# val:
|
| 124 |
+
# val_freq: !!float 5e3
|
| 125 |
+
# save_img: True
|
| 126 |
+
|
| 127 |
+
# metrics:
|
| 128 |
+
# psnr: # metric name
|
| 129 |
+
# type: calculate_psnr
|
| 130 |
+
# crop_border: 4
|
| 131 |
+
# test_y_channel: false
|
| 132 |
+
|
| 133 |
+
# logging settings
|
| 134 |
+
logger:
|
| 135 |
+
print_freq: 100
|
| 136 |
+
save_checkpoint_freq: !!float 5e3
|
| 137 |
+
use_tb_logger: true
|
| 138 |
+
wandb:
|
| 139 |
+
project: ~
|
| 140 |
+
resume_id: ~
|
| 141 |
+
|
| 142 |
+
# dist training settings
|
| 143 |
+
dist_params:
|
| 144 |
+
backend: nccl
|
| 145 |
+
port: 29500
|
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x4plus.yml
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: train_RealESRNetx4plus_1000k_B12G4
|
| 3 |
+
model_type: RealESRNetModel
|
| 4 |
+
scale: 4
|
| 5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
| 6 |
+
manual_seed: 0
|
| 7 |
+
|
| 8 |
+
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
| 9 |
+
gt_usm: True # USM the ground-truth
|
| 10 |
+
|
| 11 |
+
# the first degradation process
|
| 12 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
| 13 |
+
resize_range: [0.15, 1.5]
|
| 14 |
+
gaussian_noise_prob: 0.5
|
| 15 |
+
noise_range: [1, 30]
|
| 16 |
+
poisson_scale_range: [0.05, 3]
|
| 17 |
+
gray_noise_prob: 0.4
|
| 18 |
+
jpeg_range: [30, 95]
|
| 19 |
+
|
| 20 |
+
# the second degradation process
|
| 21 |
+
second_blur_prob: 0.8
|
| 22 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
| 23 |
+
resize_range2: [0.3, 1.2]
|
| 24 |
+
gaussian_noise_prob2: 0.5
|
| 25 |
+
noise_range2: [1, 25]
|
| 26 |
+
poisson_scale_range2: [0.05, 2.5]
|
| 27 |
+
gray_noise_prob2: 0.4
|
| 28 |
+
jpeg_range2: [30, 95]
|
| 29 |
+
|
| 30 |
+
gt_size: 256
|
| 31 |
+
queue_size: 180
|
| 32 |
+
|
| 33 |
+
# dataset and data loader settings
|
| 34 |
+
datasets:
|
| 35 |
+
train:
|
| 36 |
+
name: DF2K+OST
|
| 37 |
+
type: RealESRGANDataset
|
| 38 |
+
dataroot_gt: datasets/DF2K
|
| 39 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
| 40 |
+
io_backend:
|
| 41 |
+
type: disk
|
| 42 |
+
|
| 43 |
+
blur_kernel_size: 21
|
| 44 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 45 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 46 |
+
sinc_prob: 0.1
|
| 47 |
+
blur_sigma: [0.2, 3]
|
| 48 |
+
betag_range: [0.5, 4]
|
| 49 |
+
betap_range: [1, 2]
|
| 50 |
+
|
| 51 |
+
blur_kernel_size2: 21
|
| 52 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
| 53 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
| 54 |
+
sinc_prob2: 0.1
|
| 55 |
+
blur_sigma2: [0.2, 1.5]
|
| 56 |
+
betag_range2: [0.5, 4]
|
| 57 |
+
betap_range2: [1, 2]
|
| 58 |
+
|
| 59 |
+
final_sinc_prob: 0.8
|
| 60 |
+
|
| 61 |
+
gt_size: 256
|
| 62 |
+
use_hflip: True
|
| 63 |
+
use_rot: False
|
| 64 |
+
|
| 65 |
+
# data loader
|
| 66 |
+
use_shuffle: true
|
| 67 |
+
num_worker_per_gpu: 5
|
| 68 |
+
batch_size_per_gpu: 12
|
| 69 |
+
dataset_enlarge_ratio: 1
|
| 70 |
+
prefetch_mode: ~
|
| 71 |
+
|
| 72 |
+
# Uncomment these for validation
|
| 73 |
+
# val:
|
| 74 |
+
# name: validation
|
| 75 |
+
# type: PairedImageDataset
|
| 76 |
+
# dataroot_gt: path_to_gt
|
| 77 |
+
# dataroot_lq: path_to_lq
|
| 78 |
+
# io_backend:
|
| 79 |
+
# type: disk
|
| 80 |
+
|
| 81 |
+
# network structures
|
| 82 |
+
network_g:
|
| 83 |
+
type: RRDBNet
|
| 84 |
+
num_in_ch: 3
|
| 85 |
+
num_out_ch: 3
|
| 86 |
+
num_feat: 64
|
| 87 |
+
num_block: 23
|
| 88 |
+
num_grow_ch: 32
|
| 89 |
+
|
| 90 |
+
# path
|
| 91 |
+
path:
|
| 92 |
+
pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
|
| 93 |
+
param_key_g: params_ema
|
| 94 |
+
strict_load_g: true
|
| 95 |
+
resume_state: ~
|
| 96 |
+
|
| 97 |
+
# training settings
|
| 98 |
+
train:
|
| 99 |
+
ema_decay: 0.999
|
| 100 |
+
optim_g:
|
| 101 |
+
type: Adam
|
| 102 |
+
lr: !!float 2e-4
|
| 103 |
+
weight_decay: 0
|
| 104 |
+
betas: [0.9, 0.99]
|
| 105 |
+
|
| 106 |
+
scheduler:
|
| 107 |
+
type: MultiStepLR
|
| 108 |
+
milestones: [1000000]
|
| 109 |
+
gamma: 0.5
|
| 110 |
+
|
| 111 |
+
total_iter: 1000000
|
| 112 |
+
warmup_iter: -1 # no warm up
|
| 113 |
+
|
| 114 |
+
# losses
|
| 115 |
+
pixel_opt:
|
| 116 |
+
type: L1Loss
|
| 117 |
+
loss_weight: 1.0
|
| 118 |
+
reduction: mean
|
| 119 |
+
|
| 120 |
+
# Uncomment these for validation
|
| 121 |
+
# validation settings
|
| 122 |
+
# val:
|
| 123 |
+
# val_freq: !!float 5e3
|
| 124 |
+
# save_img: True
|
| 125 |
+
|
| 126 |
+
# metrics:
|
| 127 |
+
# psnr: # metric name
|
| 128 |
+
# type: calculate_psnr
|
| 129 |
+
# crop_border: 4
|
| 130 |
+
# test_y_channel: false
|
| 131 |
+
|
| 132 |
+
# logging settings
|
| 133 |
+
logger:
|
| 134 |
+
print_freq: 100
|
| 135 |
+
save_checkpoint_freq: !!float 5e3
|
| 136 |
+
use_tb_logger: true
|
| 137 |
+
wandb:
|
| 138 |
+
project: ~
|
| 139 |
+
resume_id: ~
|
| 140 |
+
|
| 141 |
+
# dist training settings
|
| 142 |
+
dist_params:
|
| 143 |
+
backend: nccl
|
| 144 |
+
port: 29500
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
from .archs import *
|
| 3 |
+
from .data import *
|
| 4 |
+
from .models import *
|
| 5 |
+
from .utils import *
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from basicsr.utils import scandir
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
# automatically scan and import arch modules for registry
|
| 6 |
+
# scan all the files that end with '_arch.py' under the archs folder
|
| 7 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
| 8 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
| 9 |
+
# import all the arch modules
|
| 10 |
+
from . import srvgg_arch
|
| 11 |
+
from . import discriminator_arch
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/discriminator_arch.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch.nn.utils import spectral_norm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@ARCH_REGISTRY.register()
|
| 8 |
+
class UNetDiscriminatorSN(nn.Module):
|
| 9 |
+
"""Defines a U-Net discriminator with spectral normalization (SN)
|
| 10 |
+
|
| 11 |
+
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
| 12 |
+
|
| 13 |
+
Arg:
|
| 14 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
| 15 |
+
num_feat (int): Channel number of base intermediate features. Default: 64.
|
| 16 |
+
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
|
| 20 |
+
super(UNetDiscriminatorSN, self).__init__()
|
| 21 |
+
self.skip_connection = skip_connection
|
| 22 |
+
norm = spectral_norm
|
| 23 |
+
# the first convolution
|
| 24 |
+
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
|
| 25 |
+
# downsample
|
| 26 |
+
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
|
| 27 |
+
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
|
| 28 |
+
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
|
| 29 |
+
# upsample
|
| 30 |
+
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
|
| 31 |
+
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
|
| 32 |
+
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
|
| 33 |
+
# extra convolutions
|
| 34 |
+
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
| 35 |
+
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
|
| 36 |
+
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
|
| 37 |
+
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
# downsample
|
| 40 |
+
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
|
| 41 |
+
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
|
| 42 |
+
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
|
| 43 |
+
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
|
| 44 |
+
|
| 45 |
+
# upsample
|
| 46 |
+
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
|
| 47 |
+
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
|
| 48 |
+
|
| 49 |
+
if self.skip_connection:
|
| 50 |
+
x4 = x4 + x2
|
| 51 |
+
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
|
| 52 |
+
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
|
| 53 |
+
|
| 54 |
+
if self.skip_connection:
|
| 55 |
+
x5 = x5 + x1
|
| 56 |
+
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
|
| 57 |
+
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
|
| 58 |
+
|
| 59 |
+
if self.skip_connection:
|
| 60 |
+
x6 = x6 + x0
|
| 61 |
+
|
| 62 |
+
# extra convolutions
|
| 63 |
+
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
|
| 64 |
+
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
|
| 65 |
+
out = self.conv9(out)
|
| 66 |
+
|
| 67 |
+
return out
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/srvgg_arch.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@ARCH_REGISTRY.register()
|
| 7 |
+
class SRVGGNetCompact(nn.Module):
|
| 8 |
+
"""A compact VGG-style network structure for super-resolution.
|
| 9 |
+
|
| 10 |
+
It is a compact network structure, which performs upsampling in the last layer and no convolution is
|
| 11 |
+
conducted on the HR feature space.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
num_in_ch (int): Channel number of inputs. Default: 3.
|
| 15 |
+
num_out_ch (int): Channel number of outputs. Default: 3.
|
| 16 |
+
num_feat (int): Channel number of intermediate features. Default: 64.
|
| 17 |
+
num_conv (int): Number of convolution layers in the body network. Default: 16.
|
| 18 |
+
upscale (int): Upsampling factor. Default: 4.
|
| 19 |
+
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
| 23 |
+
super(SRVGGNetCompact, self).__init__()
|
| 24 |
+
self.num_in_ch = num_in_ch
|
| 25 |
+
self.num_out_ch = num_out_ch
|
| 26 |
+
self.num_feat = num_feat
|
| 27 |
+
self.num_conv = num_conv
|
| 28 |
+
self.upscale = upscale
|
| 29 |
+
self.act_type = act_type
|
| 30 |
+
|
| 31 |
+
self.body = nn.ModuleList()
|
| 32 |
+
# the first conv
|
| 33 |
+
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
| 34 |
+
# the first activation
|
| 35 |
+
if act_type == 'relu':
|
| 36 |
+
activation = nn.ReLU(inplace=True)
|
| 37 |
+
elif act_type == 'prelu':
|
| 38 |
+
activation = nn.PReLU(num_parameters=num_feat)
|
| 39 |
+
elif act_type == 'leakyrelu':
|
| 40 |
+
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
| 41 |
+
self.body.append(activation)
|
| 42 |
+
|
| 43 |
+
# the body structure
|
| 44 |
+
for _ in range(num_conv):
|
| 45 |
+
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
| 46 |
+
# activation
|
| 47 |
+
if act_type == 'relu':
|
| 48 |
+
activation = nn.ReLU(inplace=True)
|
| 49 |
+
elif act_type == 'prelu':
|
| 50 |
+
activation = nn.PReLU(num_parameters=num_feat)
|
| 51 |
+
elif act_type == 'leakyrelu':
|
| 52 |
+
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
| 53 |
+
self.body.append(activation)
|
| 54 |
+
|
| 55 |
+
# the last conv
|
| 56 |
+
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
| 57 |
+
# upsample
|
| 58 |
+
self.upsampler = nn.PixelShuffle(upscale)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
out = x
|
| 62 |
+
for i in range(0, len(self.body)):
|
| 63 |
+
out = self.body[i](out)
|
| 64 |
+
|
| 65 |
+
out = self.upsampler(out)
|
| 66 |
+
# add the nearest upsampled image, so that the network learns the residual
|
| 67 |
+
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
| 68 |
+
out += base
|
| 69 |
+
return out
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from basicsr.utils import scandir
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
# automatically scan and import dataset modules for registry
|
| 6 |
+
# scan all the files that end with '_dataset.py' under the data folder
|
| 7 |
+
data_folder = osp.dirname(osp.abspath(__file__))
|
| 8 |
+
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
| 9 |
+
# import all the dataset modules
|
| 10 |
+
from . import realesrgan_paired_dataset
|
| 11 |
+
from . import realesrgan_dataset
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_dataset.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
import torch
|
| 9 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
| 10 |
+
from basicsr.data.transforms import augment
|
| 11 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 12 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 13 |
+
from torch.utils import data as data
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@DATASET_REGISTRY.register()
|
| 17 |
+
class RealESRGANDataset(data.Dataset):
|
| 18 |
+
"""Dataset used for Real-ESRGAN model:
|
| 19 |
+
Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
| 20 |
+
|
| 21 |
+
It loads gt (Ground-Truth) images, and augments them.
|
| 22 |
+
It also generates blur kernels and sinc kernels for generating low-quality images.
|
| 23 |
+
Note that the low-quality images are processed in tensors on GPUS for faster processing.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
opt (dict): Config for train datasets. It contains the following keys:
|
| 27 |
+
dataroot_gt (str): Data root path for gt.
|
| 28 |
+
meta_info (str): Path for meta information file.
|
| 29 |
+
io_backend (dict): IO backend type and other kwarg.
|
| 30 |
+
use_hflip (bool): Use horizontal flips.
|
| 31 |
+
use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
|
| 32 |
+
Please see more options in the codes.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, opt):
|
| 36 |
+
super(RealESRGANDataset, self).__init__()
|
| 37 |
+
self.opt = opt
|
| 38 |
+
self.file_client = None
|
| 39 |
+
self.io_backend_opt = opt['io_backend']
|
| 40 |
+
self.gt_folder = opt['dataroot_gt']
|
| 41 |
+
|
| 42 |
+
# file client (lmdb io backend)
|
| 43 |
+
if self.io_backend_opt['type'] == 'lmdb':
|
| 44 |
+
self.io_backend_opt['db_paths'] = [self.gt_folder]
|
| 45 |
+
self.io_backend_opt['client_keys'] = ['gt']
|
| 46 |
+
if not self.gt_folder.endswith('.lmdb'):
|
| 47 |
+
raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
|
| 48 |
+
with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
|
| 49 |
+
self.paths = [line.split('.')[0] for line in fin]
|
| 50 |
+
else:
|
| 51 |
+
# disk backend with meta_info
|
| 52 |
+
# Each line in the meta_info describes the relative path to an image
|
| 53 |
+
with open(self.opt['meta_info']) as fin:
|
| 54 |
+
paths = [line.strip().split(' ')[0] for line in fin]
|
| 55 |
+
self.paths = [os.path.join(self.gt_folder, v) for v in paths]
|
| 56 |
+
|
| 57 |
+
# blur settings for the first degradation
|
| 58 |
+
self.blur_kernel_size = opt['blur_kernel_size']
|
| 59 |
+
self.kernel_list = opt['kernel_list']
|
| 60 |
+
self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
|
| 61 |
+
self.blur_sigma = opt['blur_sigma']
|
| 62 |
+
self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
|
| 63 |
+
self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
|
| 64 |
+
self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
|
| 65 |
+
|
| 66 |
+
# blur settings for the second degradation
|
| 67 |
+
self.blur_kernel_size2 = opt['blur_kernel_size2']
|
| 68 |
+
self.kernel_list2 = opt['kernel_list2']
|
| 69 |
+
self.kernel_prob2 = opt['kernel_prob2']
|
| 70 |
+
self.blur_sigma2 = opt['blur_sigma2']
|
| 71 |
+
self.betag_range2 = opt['betag_range2']
|
| 72 |
+
self.betap_range2 = opt['betap_range2']
|
| 73 |
+
self.sinc_prob2 = opt['sinc_prob2']
|
| 74 |
+
|
| 75 |
+
# a final sinc filter
|
| 76 |
+
self.final_sinc_prob = opt['final_sinc_prob']
|
| 77 |
+
|
| 78 |
+
self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
|
| 79 |
+
# TODO: kernel range is now hard-coded, should be in the configure file
|
| 80 |
+
self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
|
| 81 |
+
self.pulse_tensor[10, 10] = 1
|
| 82 |
+
|
| 83 |
+
def __getitem__(self, index):
    """Load one GT image, augment it, and synthesize the degradation kernels.

    Args:
        index (int): Index into ``self.paths``.

    Returns:
        dict: with keys
            gt (Tensor): GT image, CHW, RGB, float32 in [0, 1].
            kernel1 (Tensor): blur kernel for the first degradation, padded to 21x21.
            kernel2 (Tensor): blur kernel for the second degradation, padded to 21x21.
            sinc_kernel (Tensor): final 21x21 sinc kernel (or the identity pulse tensor).
            gt_path (str): path of the GT image actually loaded (may differ from
                ``self.paths[index]`` if a retry picked a replacement file).

    Raises:
        IOError: if the file client fails for all retry attempts.
    """
    # Lazily create the file client so the dataset object stays picklable
    # for multiprocessing dataloader workers.
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # -------------------------------- Load gt images -------------------------------- #
    # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
    gt_path = self.paths[index]
    # avoid errors caused by high latency in reading files
    img_bytes = None
    retry = 3
    while retry > 0:
        try:
            img_bytes = self.file_client.get(gt_path, 'gt')
        except (IOError, OSError) as e:
            logger = get_root_logger()
            # logging.Logger.warn is deprecated in favor of .warning
            logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
            # change another file to read
            # BUGFIX: random.randint is inclusive on BOTH ends, so the upper
            # bound must be len - 1; the original `randint(0, self.__len__())`
            # could produce an out-of-range index and raise IndexError.
            index = random.randint(0, self.__len__() - 1)
            gt_path = self.paths[index]
            time.sleep(1)  # sleep 1s for occasional server congestion
        else:
            break
        finally:
            retry -= 1
    if img_bytes is None:
        # BUGFIX: previously `img_bytes` was left unbound after three failed
        # reads, producing a confusing NameError in imfrombytes below.
        raise IOError(f'Failed to read image after retries: {gt_path}')
    img_gt = imfrombytes(img_bytes, float32=True)

    # -------------------- Do augmentation for training: flip, rotation -------------------- #
    img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

    # crop or pad to 400
    # TODO: 400 is hard-coded. You may change it accordingly
    h, w = img_gt.shape[0:2]
    crop_pad_size = 400
    # pad (reflect-pad so no black borders are introduced)
    if h < crop_pad_size or w < crop_pad_size:
        pad_h = max(0, crop_pad_size - h)
        pad_w = max(0, crop_pad_size - w)
        img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
    # crop a random crop_pad_size x crop_pad_size patch
    if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
        h, w = img_gt.shape[0:2]
        # randomly choose top and left coordinates
        top = random.randint(0, h - crop_pad_size)
        left = random.randint(0, w - crop_pad_size)
        img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

    # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
    kernel_size = random.choice(self.kernel_range)
    if np.random.uniform() < self.opt['sinc_prob']:
        # this sinc filter setting is for kernels ranging from [7, 21]
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            self.betag_range,
            self.betap_range,
            noise_range=None)
    # pad kernel to the fixed 21x21 size expected by the model
    pad_size = (21 - kernel_size) // 2
    kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
    kernel_size = random.choice(self.kernel_range)
    if np.random.uniform() < self.opt['sinc_prob2']:
        if kernel_size < 13:
            omega_c = np.random.uniform(np.pi / 3, np.pi)
        else:
            omega_c = np.random.uniform(np.pi / 5, np.pi)
        kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
    else:
        kernel2 = random_mixed_kernels(
            self.kernel_list2,
            self.kernel_prob2,
            kernel_size,
            self.blur_sigma2,
            self.blur_sigma2, [-math.pi, math.pi],
            self.betag_range2,
            self.betap_range2,
            noise_range=None)

    # pad kernel
    pad_size = (21 - kernel_size) // 2
    kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

    # ------------------------------------- the final sinc kernel ------------------------------------- #
    if np.random.uniform() < self.opt['final_sinc_prob']:
        kernel_size = random.choice(self.kernel_range)
        omega_c = np.random.uniform(np.pi / 3, np.pi)
        sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
        sinc_kernel = torch.FloatTensor(sinc_kernel)
    else:
        # identity pulse: convolving with it leaves the image unchanged
        sinc_kernel = self.pulse_tensor

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
    kernel = torch.FloatTensor(kernel)
    kernel2 = torch.FloatTensor(kernel2)

    return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
    return return_d
|
| 190 |
+
|
| 191 |
+
def __len__(self):
    """Return the number of GT image paths held by this dataset."""
    num_paths = len(self.paths)
    return num_paths
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_paired_dataset.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
|
| 3 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
| 4 |
+
from basicsr.utils import FileClient, imfrombytes, img2tensor
|
| 5 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 6 |
+
from torch.utils import data as data
|
| 7 |
+
from torchvision.transforms.functional import normalize
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DATASET_REGISTRY.register()
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            mean (list, optional): Per-channel mean for normalization. Default: None (no normalization).
            std (list, optional): Per-channel std for normalization. Default: None (no normalization).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        # (idiom: dict.get replaces the `x if 'x' in opt else None` pattern)
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info') is not None:
            # disk backend with meta_info
            # Each line in the meta_info file is '<gt relpath>, <lq relpath>'
            with open(self.opt['meta_info']) as fin:
                lines = [line.strip() for line in fin]
            self.paths = []
            for line in lines:
                gt_rel, lq_rel = line.split(', ')
                self.paths.append({
                    'gt_path': os.path.join(self.gt_folder, gt_rel),
                    'lq_path': os.path.join(self.lq_folder, lq_rel),
                })
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        """Load the (lq, gt) pair at *index*, optionally crop/augment/normalize.

        Returns:
            dict: 'lq' and 'gt' tensors (CHW, RGB, float32) plus their paths.
        """
        # Lazily create the file client (keeps the dataset picklable for workers).
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize (in place) only when a mean or std is configured
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        """Return the number of (lq, gt) pairs."""
        return len(self.paths)
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
from basicsr.utils import scandir
from os import path as osp

# Automatically scan and import model modules for the registry.
# Every '*_model.py' file in this folder is imported so that its
# @MODEL_REGISTRY.register() decorator runs at package import time.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
# FIX: the original computed `model_filenames` (and imported importlib) but
# never used them, hard-coding `from . import realesrgan_model` /
# `realesrnet_model` instead. Use the scan result so new *_model.py files
# are picked up automatically.
_model_modules = [importlib.import_module(f'.{file_name}', package=__name__) for file_name in model_filenames]
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrgan_model.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
| 5 |
+
from basicsr.data.transforms import paired_random_crop
|
| 6 |
+
from basicsr.models.srgan_model import SRGANModel
|
| 7 |
+
from basicsr.utils import DiffJPEG, USMSharp
|
| 8 |
+
from basicsr.utils.img_process_util import filter2D
|
| 9 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 10 |
+
from collections import OrderedDict
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        super(RealESRGANModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        # NOTE(review): .cuda() above hard-requires a GPU — confirm CPU runs are unsupported.
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.
        """
        # initialize the queues lazily on the first call, sized from the first batch
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle the whole pool so dequeued samples mix degradations across batches
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue: overwrite the dequeued slots with the current batch
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue until the pool fills; the current batch passes through unchanged
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

        In training with high_order_degradation enabled, `data` carries 'gt',
        'kernel1', 'kernel2', and 'sinc_kernel' (see the dataset); otherwise it
        carries pre-made 'lq' (and optionally 'gt') pairs.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt_usm, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise (gaussian or poisson, possibly gray/single-channel)
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur (applied only with probability second_blur_prob)
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            # resize relative to the final LQ size (ori / scale_factor), not the current size
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round (quantize to 8-bit levels, like a saved image)
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
                                                                 self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True

    def optimize_parameters(self, current_iter):
        """One optimization step: update generator, then discriminator.

        Selects USM-sharpened or plain GT per loss according to the
        l1_gt_usm / percep_gt_usm / gan_gt_usm options.
        """
        # usm sharpening: default to the sharpened GT for every loss
        l1_gt = self.gt_usm
        percep_gt = self.gt_usm
        gan_gt = self.gt_usm
        if self.opt['l1_gt_usm'] is False:
            l1_gt = self.gt
        if self.opt['percep_gt_usm'] is False:
            percep_gt = self.gt
        if self.opt['gan_gt_usm'] is False:
            gan_gt = self.gt

        # optimize net_g (freeze discriminator)
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        # NOTE(review): if this condition is False (e.g. net_d_iters > 1),
        # l_g_total stays the int 0 and l_g_total.backward() below would raise
        # AttributeError — presumably net_d_iters is always 1 in configs; confirm.
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

        l_g_total.backward()
        self.optimizer_g.step()

        # optimize net_d (unfreeze discriminator)
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        # exponential moving average of generator weights, if enabled
        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrnet_model.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
| 5 |
+
from basicsr.data.transforms import paired_random_crop
|
| 6 |
+
from basicsr.models.sr_model import SRModel
|
| 7 |
+
from basicsr.utils import DiffJPEG, USMSharp
|
| 8 |
+
from basicsr.utils.img_process_util import filter2D
|
| 9 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@MODEL_REGISTRY.register()
|
| 14 |
+
class RealESRNetModel(SRModel):
|
| 15 |
+
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
| 16 |
+
|
| 17 |
+
It is trained without GAN losses.
|
| 18 |
+
It mainly performs:
|
| 19 |
+
1. randomly synthesize LQ images in GPU tensors
|
| 20 |
+
2. optimize the networks with GAN training.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, opt):
|
| 24 |
+
super(RealESRNetModel, self).__init__(opt)
|
| 25 |
+
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
| 26 |
+
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
| 27 |
+
self.queue_size = opt.get('queue_size', 180)
|
| 28 |
+
|
| 29 |
+
@torch.no_grad()
|
| 30 |
+
def _dequeue_and_enqueue(self):
|
| 31 |
+
"""It is the training pair pool for increasing the diversity in a batch.
|
| 32 |
+
|
| 33 |
+
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
| 34 |
+
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
| 35 |
+
to increase the degradation diversity in a batch.
|
| 36 |
+
"""
|
| 37 |
+
# initialize
|
| 38 |
+
b, c, h, w = self.lq.size()
|
| 39 |
+
if not hasattr(self, 'queue_lr'):
|
| 40 |
+
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
| 41 |
+
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
| 42 |
+
_, c, h, w = self.gt.size()
|
| 43 |
+
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
| 44 |
+
self.queue_ptr = 0
|
| 45 |
+
if self.queue_ptr == self.queue_size: # the pool is full
|
| 46 |
+
# do dequeue and enqueue
|
| 47 |
+
# shuffle
|
| 48 |
+
idx = torch.randperm(self.queue_size)
|
| 49 |
+
self.queue_lr = self.queue_lr[idx]
|
| 50 |
+
self.queue_gt = self.queue_gt[idx]
|
| 51 |
+
# get first b samples
|
| 52 |
+
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
| 53 |
+
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
| 54 |
+
# update the queue
|
| 55 |
+
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
| 56 |
+
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
| 57 |
+
|
| 58 |
+
self.lq = lq_dequeue
|
| 59 |
+
self.gt = gt_dequeue
|
| 60 |
+
else:
|
| 61 |
+
# only do enqueue
|
| 62 |
+
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
| 63 |
+
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
| 64 |
+
self.queue_ptr = self.queue_ptr + b
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
def feed_data(self, data):
    """Accept data from the dataloader, then apply a two-order degradation
    pipeline (blur -> resize -> noise -> JPEG, twice) to synthesize LQ images.

    During training with ``high_order_degradation`` enabled, ``data`` must
    provide 'gt', 'kernel1', 'kernel2' and 'sinc_kernel'; otherwise 'lq'
    (and optionally 'gt') are used directly.
    """
    if self.is_train and self.opt.get('high_order_degradation', True):
        # training data synthesis
        self.gt = data['gt'].to(self.device)
        # USM sharpen the GT images (optional, controlled by opt['gt_usm'])
        if self.opt['gt_usm'] is True:
            self.gt = self.usm_sharpener(self.gt)

        self.kernel1 = data['kernel1'].to(self.device)
        self.kernel2 = data['kernel2'].to(self.device)
        self.sinc_kernel = data['sinc_kernel'].to(self.device)

        ori_h, ori_w = self.gt.size()[2:4]

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(self.gt, self.kernel1)
        # random resize: direction sampled by resize_prob, magnitude by resize_range
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # add noise: Gaussian with prob gaussian_noise_prob, else Poisson
        gray_noise_prob = self.opt['gray_noise_prob']
        if np.random.uniform() < self.opt['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression with a per-sample random quality
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
        out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = self.jpeger(out, quality=jpeg_p)

        # ----------------------- The second degradation process ----------------------- #
        # blur (applied only with prob second_blur_prob)
        if np.random.uniform() < self.opt['second_blur_prob']:
            out = filter2D(out, self.kernel2)
        # random resize
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range2'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range2'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        # resize relative to the final LQ resolution (ori / opt['scale'])
        out = F.interpolate(
            out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
        # add noise
        gray_noise_prob = self.opt['gray_noise_prob2']
        if np.random.uniform() < self.opt['gaussian_noise_prob2']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range2'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter]
        # together as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, other combinations (sinc + JPEG + resize) introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
            out = filter2D(out, self.sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
            out = filter2D(out, self.sinc_kernel)

        # clamp and round to 8-bit quantization levels, back into [0, 1]
        self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        # random crop a gt_size patch (and the aligned LQ patch)
        gt_size = self.opt['gt_size']
        self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

        # training pair pool (see _dequeue_and_enqueue)
        self._dequeue_and_enqueue()
        self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
    else:
        # for paired training or validation: no synthetic degradation
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)
|
| 183 |
+
|
| 184 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Run (non-distributed) validation without the synthetic degradation.

    Temporarily clears ``self.is_train`` so that ``feed_data`` skips the
    on-the-fly degradation pipeline, then delegates to the parent class.
    The flag is restored in a ``finally`` block so that an exception during
    validation cannot leave the model stuck in eval-style feeding.
    """
    self.is_train = False
    try:
        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
    finally:
        self.is_train = True
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/train.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# Importing these packages registers the Real-ESRGAN architectures, datasets
# and models with BasicSR's registries (import side effect; no names used).
import realesrgan.archs
import realesrgan.data
import realesrgan.models

if __name__ == '__main__':
    # Project root = two levels above this file (realesrgan/ -> repo root).
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
|
JarvisIR/package/agent_tools/ESRGAN/realesrgan/utils.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import queue
|
| 6 |
+
import threading
|
| 7 |
+
import torch
|
| 8 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str | list[str]): The path to the pretrained model. It can be a url
            (will first download it automatically). A list of two paths triggers deep
            network interpolation (see :meth:`dni`).
        dni_weight (list[float] | None): Interpolation weights when ``model_path`` is a list.
        model (nn.Module): The defined network. Default: None.
        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
            input images into tiles, and then process each of them. Finally, they will be merged into one image.
            0 denotes for do not use tile. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device | None): Device to run on. NOTE(review): if None, the model
            stays on its current device (``.to(None)`` is a no-op) — confirm callers pass one.
        gpu_id (int | None): NOTE(review): accepted but unused in this implementation.
    """

    def __init__(self,
                 scale,
                 model_path,
                 dni_weight=None,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None  # set per-image in pre_process for scale 1/2
        self.half = half

        # initialize model
        self.device = device

        if isinstance(model_path, list):
            # dni: blend two checkpoints' weights
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, it will first download models to the folder: weights
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema (EMA weights) when the checkpoint has them
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation: linearly blend two checkpoints' weights.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``

        Args:
            net_a (str): Path of the first checkpoint.
            net_b (str): Path of the second checkpoint.
            dni_weight (list[float]): Two blending weights [w_a, w_b].
            key (str): State-dict key inside the checkpoints. Default: 'params'.
            loc (str): map_location for torch.load. Default: 'cpu'.

        Returns:
            dict: The loaded ``net_a`` checkpoint with blended weights under ``key``.
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a

    def pre_process(self, img):
        """Pre-process: convert HWC numpy to NCHW tensor, pre-pad, and mod-pad
        so that the spatial dims are divisible as the network requires.
        Stores the result in ``self.img``.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad: reflect-pad right/bottom to reduce border artifacts
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders (only needed for scale 1 and 2)
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference on the whole (padded) image at once
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding (context for the network)
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    # NOTE(review): on failure (e.g. CUDA OOM) this only logs;
                    # the code below then uses output_tile from the previous
                    # iteration, or raises NameError on the first tile. Consider
                    # re-raising or `continue` here.
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding (crop the context back off)
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        """Undo the padding added in pre_process (scaled up) and return the output."""
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        """Upscale a BGR (OpenCV-style) numpy image.

        Args:
            img (np.ndarray): HxW (gray), HxWx3 (BGR) or HxWx4 (BGRA) image,
                uint8 or uint16.
            outscale (float | None): Final rescale factor relative to the input;
                if it differs from the network scale, the result is resized.
            alpha_upsampler (str): 'realesrgan' to run the network on the alpha
                channel too, anything else to bilinearly resize alpha.

        Returns:
            tuple[np.ndarray, str]: The upscaled image (same dtype family as
            input) and the detected image mode ('L', 'RGB' or 'RGBA').
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        # NOTE(review): heuristic 16-bit detection — a genuinely dark 16-bit
        # image (max <= 256) would be treated as 8-bit.
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                # network expects 3 channels, so replicate alpha to RGB
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        # CHW RGB -> HWC BGR (back to OpenCV convention)
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class PrefetchReader(threading.Thread):
    """Background thread that decodes images ahead of the consumer.

    Iterate over an instance to receive decoded images in order; iteration
    stops once every path in ``img_list`` has been yielded.

    Args:
        img_list (list[str]): Paths of the images to read.
        num_prefetch_queue (int): Maximum number of decoded images buffered.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        # Producer side: decode each image and hand it over via the queue.
        for path in self.img_list:
            self.que.put(cv2.imread(path, cv2.IMREAD_UNCHANGED))
        # Sentinel tells __next__ that no more images are coming.
        self.que.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.que.get()
        if item is None:
            raise StopIteration
        return item
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class IOConsumer(threading.Thread):
    """Worker thread that writes processed images to disk.

    Consumes dicts with keys 'output' (image array) and 'save_path' from the
    queue; the string 'quit' terminates the worker.

    Args:
        opt: Options object (stored for reference).
        que (queue.Queue): Task queue shared with the producer.
        qid (int): Worker identifier, used in the completion message.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            task = self._queue.get()
            # 'quit' is the shutdown sentinel
            if isinstance(task, str) and task == 'quit':
                break
            img = task['output']
            dst = task['save_path']
            cv2.imwrite(dst, img)
        print(f'IO worker {self.qid} is done.')
|
JarvisIR/package/agent_tools/HVICIDNet/inference.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from .net.CIDNet import CIDNet
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import platform
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
def load_hvicidnet_model(path, device):
    """Build a CIDNet, load pretrained weights, and switch to eval mode.

    Args:
        path (str): Path to the checkpoint (a plain state_dict).
        device (str | torch.device): Device to place the model on.

    Returns:
        CIDNet: The loaded model, in eval mode, on ``device``.
    """
    model = CIDNet().to(device)
    # map_location keeps tensors on CPU during load; .to(device) above moves params
    model.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage))
    model.eval()
    return model
|
| 16 |
+
|
| 17 |
+
def hvicidnet_predict(model, input_img, output_dir, device):
    """Enhance a low-light image with HVI-CIDNet and save the result as PNG.

    Args:
        model: Loaded CIDNet instance (see ``load_hvicidnet_model``).
        input_img (str | PIL.Image.Image): Path to an image, or an opened PIL image.
        output_dir (str): Directory where the enhanced PNG is written.
        device (str | torch.device): Device to run inference on.

    Returns:
        str: Path of the saved enhanced image.
    """
    # Remember the original path BEFORE overwriting input_img. The previous
    # code re-tested isinstance(input_img, str) after the reassignment to a
    # PIL image, so the resize-back branch below could never execute.
    input_path = input_img if isinstance(input_img, str) else None
    if input_path is not None:
        img_name = os.path.basename(input_path).split('.')[0]
        input_img = Image.open(input_path)
    else:
        img_name = "output"  # default name if input is not a path

    pil2tensor = transforms.Compose([transforms.ToTensor()])
    inp = pil2tensor(input_img)

    # Reflect-pad so both spatial dims are multiples of 8 (network stride).
    factor = 8
    h, w = inp.shape[1], inp.shape[2]
    H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
    padh = H - h if h % factor != 0 else 0
    padw = W - w if w % factor != 0 else 0
    inp = F.pad(inp.unsqueeze(0), (0, padw, 0, padh), 'reflect')

    gamma = 1       # input gamma correction (1 = disabled)
    alpha_s = 1.0   # saturation strength
    alpha_i = 1.0   # intensity strength
    # torch.no_grad() covers inference here; the previous global
    # torch.set_grad_enabled(False) leaked disabled-autograd state to callers.
    with torch.no_grad():
        model.trans.alpha_s = alpha_s
        model.trans.alpha = alpha_i
        # Use the requested device instead of hard-coding .cuda(), so CPU
        # inference works too.
        output = model(inp.to(device) ** gamma)

    output = torch.clamp(output, 0, 1).cpu()
    output = output[:, :, :h, :w]  # drop the padding
    enhanced_img = transforms.ToPILImage()(output.squeeze(0))
    if input_path is not None:
        # Ensure the result matches the source resolution exactly.
        original_size = Image.open(input_path).size
        if enhanced_img.size != original_size:
            enhanced_img = enhanced_img.resize(original_size, Image.LANCZOS)

    # Save the output
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, img_name + '.png')
    enhanced_img.save(save_path)
    return save_path
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if __name__ == '__main__':
    # Smoke test: enhance one sample image with the pretrained generalization weights.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = "checkpoints/HVICIDNet/generalization.pth"
    img_path = "./Test_Input/108.png"
    output_folder = "./output"
    model = load_hvicidnet_model(model_path, device)
    save_path = hvicidnet_predict(model, img_path, output_folder, device)
    print(f"processed image saved to: {save_path}")
|
| 64 |
+
|
| 65 |
+
|
JarvisIR/package/agent_tools/HVICIDNet/loss/loss_utils.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import functools
|
| 4 |
+
from math import exp
|
| 5 |
+
from torch.autograd import Variable
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def reduce_loss(loss, reduction):
    """Reduce an element-wise loss tensor.

    Args:
        loss (Tensor): Element-wise loss values.
        reduction (str): One of 'none', 'mean' or 'sum'.

    Returns:
        Tensor: Mean or sum of ``loss``, or ``loss`` unchanged for 'none'.
    """
    # PyTorch's enum mapping: none -> 0, (elementwise_)mean -> 1, sum -> 2;
    # invalid strings raise inside get_enum.
    mode = F._Reduction.get_enum(reduction)
    if mode == 1:
        return loss.mean()
    if mode == 2:
        return loss.sum()
    return loss
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply an optional element-wise weight to a loss, then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights, broadcastable over channels
            (channel dim of size 1 or matching ``loss``). Default: None.
        reduction (str): 'none', 'mean' or 'sum'. Default: 'mean'.

    Returns:
        Tensor: Weighted and reduced loss values.
    """
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # Unweighted, or summed: plain reduction is correct.
    if weight is None or reduction == 'sum':
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # Normalize by total weight; a single-channel weight map is broadcast
        # across channels, so its sum must be scaled by the channel count.
        if weight.size(1) > 1:
            denom = weight.sum()
        else:
            denom = weight.sum() * loss.size(1)
        return loss.sum() / denom
    # weighted + 'none': return the element-wise weighted loss untouched
    return loss
|
| 60 |
+
|
| 61 |
+
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    The wrapped function must have the signature
    ``loss_func(pred, target, **kwargs)`` and return an un-reduced,
    element-wise loss. The decorated function gains ``weight=None`` and
    ``reduction='mean'`` keyword arguments, i.e. its signature becomes
    ``loss_func(pred, target, weight=None, reduction='mean', **kwargs)``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapped(pred, target, weight=None, reduction='mean', **kwargs):
        # element-wise loss from the undecorated function, then weight+reduce
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapped
|
| 100 |
+
|
| 101 |
+
@weighted_loss
def l1_loss(pred, target):
    """Element-wise L1 (MAE) loss; weighting/reduction are added by @weighted_loss."""
    return F.l1_loss(pred, target, reduction='none')
|
| 104 |
+
|
| 105 |
+
@weighted_loss
def mse_loss(pred, target):
    """Element-wise MSE (L2) loss; weighting/reduction are added by @weighted_loss."""
    return F.mse_loss(pred, target, reduction='none')
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    The kernel is centered at ``window_size // 2`` and sums to 1.
    """
    center = window_size // 2
    weights = torch.Tensor([exp(-(i - center) ** 2 / float(2 * sigma ** 2)) for i in range(window_size)])
    return weights / torch.sum(weights)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def create_window(window_size, channel=1):
    """Build a 2-D Gaussian window (sigma=1.5) shaped for grouped conv2d.

    Args:
        window_size (int): Side length of the square window.
        channel (int): Number of groups/channels to expand to. Default: 1.

    Returns:
        Variable: Tensor of shape (channel, 1, window_size, window_size).
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    # outer product of the 1-D kernel with itself gives the 2-D kernel
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    return Variable(kernel_2d.expand(channel, 1, window_size, window_size).contiguous())
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def map_ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches using a given smoothing window.

    Args:
        img1, img2 (Tensor): Batches of shape (N, C, H, W).
        window (Tensor): Smoothing kernel of shape (channel, 1, k, k).
        window_size (int): Kernel side length k (used for 'same' padding).
        channel (int): Number of channels (conv groups).
        size_average (bool): If True return a scalar mean; otherwise a
            per-sample mean of shape (N,).

    Returns:
        Tensor: SSIM value(s) in [-1, 1] (1 for identical images).
    """
    pad = window_size // 2

    # local means
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # local (co)variances via E[x^2] - E[x]^2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2

    # stability constants from the SSIM paper (L=1 dynamic range)
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)
|
JarvisIR/package/agent_tools/HVICIDNet/loss/losses.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from loss.vgg_arch import VGGFeatureExtractor, Registry
|
| 5 |
+
from loss.loss_utils import *
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
| 9 |
+
|
| 10 |
+
class L1Loss(nn.Module):
    """Mean-absolute-error (L1, MAE) loss with optional element-wise weighting.

    Args:
        loss_weight (float): Multiplier applied to the loss. Default: 1.0.
        reduction (str): One of 'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted L1 loss.

        Args:
            pred (Tensor): Predicted tensor of shape (N, C, H, W).
            target (Tensor): Ground-truth tensor of shape (N, C, H, W).
            weight (Tensor, optional): Element-wise weights, same shape.
                Default: None.
        """
        weighted = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * weighted
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class EdgeLoss(nn.Module):
    """Laplacian-style edge loss: MSE between high-frequency residuals of x and y."""

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(EdgeLoss, self).__init__()
        # Separable 5x5 binomial/Gaussian kernel, one copy per RGB channel.
        tap = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(tap.t(), tap).unsqueeze(0).repeat(3, 1, 1, 1).cuda()

        self.weight = loss_weight

    def conv_gauss(self, img):
        """Depthwise Gaussian blur with replicate padding."""
        n_channels, _, kw, kh = self.kernel.shape
        padded = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        return F.conv2d(padded, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """High-frequency residual: image minus its blur-downsample-upsample-blur."""
        filtered = self.conv_gauss(current)
        down = filtered[:, :, ::2, ::2]
        upsampled = torch.zeros_like(filtered)
        # Scatter the decimated samples back (x4 compensates the zero fill).
        upsampled[:, :, ::2, ::2] = down * 4
        return current - self.conv_gauss(upsampled)

    def forward(self, x, y):
        residual_loss = mse_loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return residual_loss * self.weight
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: True.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=True,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # Bug fix: torch.nn.L2loss does not exist (raised AttributeError);
            # L2 loss is mean squared error in torch.
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'mse':
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            self.criterion = None  # handled via torch.norm in forward
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: (percep_loss, style_loss); either element is None when the
            corresponding weight is 0.
        """
        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate the Gram matrix used by the style loss.

        Bug fix: this helper was referenced by the style-loss branch but never
        defined, so any call with ``style_weight > 0`` raised AttributeError.

        Args:
            x (Tensor): Tensor with shape (n, c, h, w).

        Returns:
            Tensor: Gram matrix of shape (n, c, c), normalized by c*h*w.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class SSIM(torch.nn.Module):
    """SSIM loss module: forward returns ``weight * (1 - SSIM(img1, img2))``."""

    def __init__(self, window_size=11, size_average=True, weight=1.):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)
        self.weight = weight

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # Reuse the cached window when the channel count and dtype still match;
        # otherwise rebuild it for the current input and cache the result.
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        similarity = map_ssim(img1, img2, window, self.window_size, channel, self.size_average)
        return (1. - similarity) * self.weight
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
|
JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_pris_params.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
|
| 3 |
+
size 11850
|
JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_utils.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy.ndimage import convolve
|
| 6 |
+
from scipy.special import gamma
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
def cubic(x):
    """Bicubic interpolation kernel (a = -0.5), evaluated element-wise.

    Nonzero on |x| <= 2; used by calculate_weights_indices.
    """
    ax = torch.abs(x)
    ax2 = ax * ax
    ax3 = ax2 * ax
    # Piecewise polynomial: inner segment for |x| <= 1, outer for 1 < |x| <= 2.
    inner = (1.5 * ax3 - 2.5 * ax2 + 1) * (ax <= 1).type_as(ax)
    outer = (-0.5 * ax3 + 2.5 * ax2 - 4 * ax + 2) * (((ax > 1) * (ax <= 2)).type_as(ax))
    return inner + outer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Interpolation kernel name (only 'cubic' is used here).
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: (weights, indices, sym_len_s, sym_len_e) where ``weights`` and
        ``indices`` are (out_length, p) tensors, and the two ints give how
        many pixels must be mirror-padded at the start/end of the input so
        that every index is valid.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Mirror-padding amounts: shift indices so they address the padded input.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
| 86 |
+
|
| 87 |
+
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
        (A numpy array of matching layout is returned for numpy input.)
    """
    squeeze_flag = False
    # Normalize input to a float (c, h, w) tensor; remember how to convert back.
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if img.ndim == 2:
            img = img[:, :, None]
            squeeze_flag = True
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if img.ndim == 2:
            img = img.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying (mirror-pad top and bottom so border taps are valid)
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    # Resample rows: each output row is a weighted sum of kernel_width input rows.
    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying (mirror-pad left and right)
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    # Resample columns the same way.
    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    # Restore the caller's layout: drop the added channel, convert back to numpy.
    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            out_2 = out_2.transpose(1, 2, 0)

    return out_2
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _convert_input_type_range(img):
|
| 180 |
+
"""Convert the type and range of the input image.
|
| 181 |
+
It converts the input image to np.float32 type and range of [0, 1].
|
| 182 |
+
It is mainly used for pre-processing the input image in colorspace
|
| 183 |
+
conversion functions such as rgb2ycbcr and ycbcr2rgb.
|
| 184 |
+
Args:
|
| 185 |
+
img (ndarray): The input image. It accepts:
|
| 186 |
+
1. np.uint8 type with range [0, 255];
|
| 187 |
+
2. np.float32 type with range [0, 1].
|
| 188 |
+
Returns:
|
| 189 |
+
(ndarray): The converted image with type of np.float32 and range of
|
| 190 |
+
[0, 1].
|
| 191 |
+
"""
|
| 192 |
+
img_type = img.dtype
|
| 193 |
+
img = img.astype(np.float32)
|
| 194 |
+
if img_type == np.float32:
|
| 195 |
+
pass
|
| 196 |
+
elif img_type == np.uint8:
|
| 197 |
+
img /= 255.
|
| 198 |
+
else:
|
| 199 |
+
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
|
| 200 |
+
return img
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def _convert_output_type_range(img, dst_type):
|
| 204 |
+
"""Convert the type and range of the image according to dst_type.
|
| 205 |
+
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
| 206 |
+
images will be converted to np.uint8 type with range [0, 255]. If
|
| 207 |
+
`dst_type` is np.float32, it converts the image to np.float32 type with
|
| 208 |
+
range [0, 1].
|
| 209 |
+
It is mainly used for post-processing images in colorspace conversion
|
| 210 |
+
functions such as rgb2ycbcr and ycbcr2rgb.
|
| 211 |
+
Args:
|
| 212 |
+
img (ndarray): The image to be converted with np.float32 type and
|
| 213 |
+
range [0, 255].
|
| 214 |
+
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
| 215 |
+
converts the image to np.uint8 type with range [0, 255]. If
|
| 216 |
+
dst_type is np.float32, it converts the image to np.float32 type
|
| 217 |
+
with range [0, 1].
|
| 218 |
+
Returns:
|
| 219 |
+
(ndarray): The converted image with desired type and range.
|
| 220 |
+
"""
|
| 221 |
+
if dst_type not in (np.uint8, np.float32):
|
| 222 |
+
raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
|
| 223 |
+
if dst_type == np.uint8:
|
| 224 |
+
img = img.round()
|
| 225 |
+
else:
|
| 226 |
+
img /= 255.
|
| 227 |
+
return img.astype(dst_type)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    Matches Matlab's `rgb2ycbcr` (ITU-R BT.601 for standard-definition TV);
    note this differs from cv2.cvtColor's JPEG-style `RGB <-> YCrCb`.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image, same type and range as the input.
    """
    src_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        coeffs = [[65.481, -37.797, 112.0],
                  [128.553, -74.203, -93.786],
                  [24.966, 112.0, -18.214]]
        out_img = np.matmul(img, coeffs) + [16, 128, 128]
    return _convert_output_type_range(out_img, src_type)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The BGR version of rgb2ycbcr (ITU-R BT.601 for standard-definition TV);
    note this differs from cv2.cvtColor's JPEG-style `BGR <-> YCrCb`.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image, same type and range as the input.
    """
    src_type = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        coeffs = [[24.966, 112.0, -18.214],
                  [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]
        out_img = np.matmul(img, coeffs) + [16, 128, 128]
    return _convert_output_type_range(out_img, src_type)
|
| 287 |
+
|
| 288 |
+
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    Matches Matlab's `ycbcr2rgb` (ITU-R BT.601 for standard-definition TV);
    note this differs from cv2.cvtColor's JPEG-style `YCrCb <-> RGB`.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image, same type and range as the input.
    """
    src_type = img.dtype
    img = _convert_input_type_range(img) * 255
    matrix = [[0.00456621, 0.00456621, 0.00456621],
              [0, -0.00153632, 0.00791071],
              [0.00625893, -0.00318811, 0]]  # noqa: E126
    out_img = np.matmul(img, matrix) * 255.0 + [-222.921, 135.576, -276.836]
    return _convert_output_type_range(out_img, src_type)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    normalized = img.astype(np.float32) / 255.
    if normalized.ndim == 3 and normalized.shape[2] == 3:
        # 3-channel input is assumed BGR here — TODO confirm against callers.
        normalized = bgr2ycbcr(normalized, y_only=True)[..., None]
    return 255. * normalized
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    NOTE(review): for a 2-D input with input_order='CHW' the expanded array is
    still transposed, contradicting the "no effect" claim — preserved as-is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.

    Returns:
        ndarray: reordered image.
    """

    if input_order not in ('HWC', 'CHW'):
        raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
    if img.ndim == 2:
        img = img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
|
| 348 |
+
|
| 349 |
+
def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    # Bug fix: the original ended with a bare `return`, discarding out_img and
    # always returning None. Scale to [0, 1] and return the result.
    return out_img / 255.
|
| 369 |
+
|
| 370 |
+
def tensor2img(tensor):
    """Convert a [0, 1]-range tensor to a uint8 numpy image, clipping to [0, 255]."""
    arr = (255. * tensor).data.cpu().numpy()
    # clamp before the uint8 cast to avoid wrap-around
    arr = np.clip(arr, 0, 255)
    return arr.astype(np.uint8)
|
| 377 |
+
|
| 378 |
+
def img2tensor(img):
    """Convert a [0, 255] image array to a float32 batch tensor in [0, 1].

    Grayscale (H, W) becomes (1, 1, H, W); color (H, W, C) becomes (1, C, H, W).
    """
    scaled = (img / 255.).astype('float32')
    if scaled.ndim == 2:
        batched = scaled[None, None, :, :]
    else:
        batched = np.transpose(scaled, (2, 0, 1))[None, ...]  # (1, C, H, W)
    batched = np.ascontiguousarray(batched, dtype=np.float32)
    return torch.from_numpy(batched)
|
| 388 |
+
|
| 389 |
+
def estimate_aggd_param(block):
|
| 390 |
+
"""Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
|
| 391 |
+
Args:
|
| 392 |
+
block (ndarray): 2D Image block.
|
| 393 |
+
Returns:
|
| 394 |
+
tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
|
| 395 |
+
distribution (Estimating the parames in Equation 7 in the paper).
|
| 396 |
+
"""
|
| 397 |
+
block = block.flatten()
|
| 398 |
+
gam = np.arange(0.2, 10.001, 0.001) # len = 9801
|
| 399 |
+
gam_reciprocal = np.reciprocal(gam)
|
| 400 |
+
r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
|
| 401 |
+
|
| 402 |
+
left_std = np.sqrt(np.mean(block[block < 0]**2))
|
| 403 |
+
right_std = np.sqrt(np.mean(block[block > 0]**2))
|
| 404 |
+
gammahat = left_std / right_std
|
| 405 |
+
rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
|
| 406 |
+
rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
|
| 407 |
+
array_position = np.argmin((r_gam - rhatnorm)**2)
|
| 408 |
+
|
| 409 |
+
alpha = gam[array_position]
|
| 410 |
+
beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
|
| 411 |
+
beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
|
| 412 |
+
return (alpha, beta_l, beta_r)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def compute_feature(block):
    """Compute features.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        list: Features with length of 18.
    """
    feat = []
    alpha, beta_l, beta_r = estimate_aggd_param(block)
    feat.extend([alpha, (beta_l + beta_r) / 2])

    # distortions disturb the fairly regular structure of natural images.
    # This deviation can be captured by analyzing the sample distribution of
    # the products of pairs of adjacent coefficients computed along
    # horizontal, vertical and diagonal orientations.
    for shift in ([0, 1], [1, 0], [1, 1], [1, -1]):
        shifted_block = np.roll(block, shift, axis=(0, 1))
        alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
        # Eq. 8
        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
        feat.extend([alpha, mean, beta_l, beta_r])
    return feat
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    ``Paper: Making a "Completely Blind" Image Quality Analyzer``

    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.

    For good performance, it is advisable by the official implementation to
    divide the distorted image in to the same size patched as used for the
    construction of multivariate Gaussian model.

    Args:
        img (ndarray): Input image whose quality needs to be computed. The
            image must be a gray or Y (of YCbCr) image with shape (h, w).
            Range [0, 255] with float type.
        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
            model calculated on the pristine dataset.
        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
            Gaussian model calculated on the pristine dataset.
        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
            image.
        block_size_h (int): Height of the blocks in to which image is divided.
            Default: 96 (the official recommended value).
        block_size_w (int): Width of the blocks in to which image is divided.
            Default: 96 (the official recommended value).

    Returns:
        float: NIQE quality score (lower means better perceptual quality).
    """
    assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    # crop image so both dimensions are exact multiples of the block size
    h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        # local mean / deviation maps used for MSCN normalization
        mu = convolve(img, gaussian_window, mode='nearest')
        sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        img_nomalized = (img - mu) / (sigma + 1)

        feat = []
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process ecah block (block size shrinks with the scale)
                block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
                                      idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
                feat.append(compute_feature(block))

        distparam.append(np.array(feat))

        if scale == 1:
            # downscale by 2 for the second pass; imresize works in [0, 1]
            img = imresize(img / 255., scale=0.5, antialiasing=True)
            img = img * 255.

    distparam = np.concatenate(distparam, axis=1)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    mu_distparam = np.nanmean(distparam, axis=0)
    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
    cov_distparam = np.cov(distparam_no_nan, rowvar=False)

    # compute niqe quality, Eq. 10 in the paper
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    quality = np.matmul(
        np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))

    quality = np.sqrt(quality)
    quality = float(np.squeeze(quality))
    return quality
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y', **kwargs):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    ``Paper: Making a "Completely Blind" Image Quality Analyzer``

    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip

    > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
    > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)

    We use the official params estimated from the pristine dataset.
    We use the recommended block size (96, 96) without overlaps.

    Args:
        img (ndarray): Input image whose quality needs to be computed.
            The input image must be in range [0, 255] with float/int type.
            The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
            If the input order is 'HWC' or 'CHW', it will be converted to gray
            or Y (of YCbCr) image according to the ``convert_to`` argument.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation.
        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
            Default: 'HWC'.
        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
            Default: 'y'.

    Returns:
        float: NIQE result.
    """
    import os

    # we use the official params estimated from the pristine dataset.
    # Prefer a path relative to this file so the metric does not depend on the
    # process's current working directory; fall back to the historical
    # cwd-relative location for backward compatibility.
    params_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'niqe_pris_params.npz')
    if not os.path.isfile(params_path):
        params_path = './loss/niqe_pris_params.npz'
    niqe_pris_params = np.load(params_path)
    mu_pris_param = niqe_pris_params['mu_pris_param']
    cov_pris_param = niqe_pris_params['cov_pris_param']
    gaussian_window = niqe_pris_params['gaussian_window']

    img = img.astype(np.float32)
    if input_order != 'HW':
        img = reorder_image(img, input_order=input_order)
        if convert_to == 'y':
            img = to_y_channel(img)
        elif convert_to == 'gray':
            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
        img = np.squeeze(img)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    # round is necessary for being consistent with MATLAB's result
    img = img.round()

    return niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
|
JarvisIR/package/agent_tools/HVICIDNet/loss/vgg_arch.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from torch import nn as nn
|
| 5 |
+
from torchvision.models import vgg as vgg
|
| 6 |
+
|
| 7 |
+
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        # duplicate registrations are refused so a later module cannot
        # silently shadow an earlier one
        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is not None:
            # used as a function call
            self._do_register(obj.__name__, obj)
            return

        # used as a decorator
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class)
            return func_or_class

        return deco

    def get(self, name):
        obj = self._obj_map.get(name)
        if obj is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return obj

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Global registries shared across the package; modules register their classes
# here by name so instances can be built from configuration strings.
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

# Local VGG19 weights; when this file is missing, VGGFeatureExtractor falls
# back to torchvision's downloaded pretrained weights.
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
# Human-readable layer names for each VGG variant, in the exact order the
# layers appear in torchvision's ``vgg*.features`` sequential module.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    augmented = []
    for layer in names:
        augmented.append(layer)
        if 'conv' in layer:
            # mirror the conv's position suffix, e.g. conv1_2 -> bn1_2
            augmented.append('bn' + layer.replace('conv', ''))
    return augmented
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    NOTE(review): the truncated network and the normalization buffers are
    moved to CUDA unconditionally (see the ``.cuda()`` calls below), so this
    module requires a GPU — confirm before running on CPU-only hosts.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        # layer-name table for the chosen variant; '_bn' variants get bn
        # names spliced in after every conv
        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        # prefer local weights when present; otherwise torchvision downloads
        # the pretrained weights.
        # NOTE(review): the ``pretrained=`` keyword is deprecated in newer
        # torchvision releases (replaced by ``weights=``) — confirm the
        # torchvision version this repo pins.
        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        # truncate to the deepest layer actually requested
        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net).cuda()

        if not requires_grad:
            # frozen feature extractor: eval mode, no gradients
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda())
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda())

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        # run the truncated VGG sequentially, capturing every requested
        # intermediate activation by its layer name
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
|
JarvisIR/package/agent_tools/HVICIDNet/mods.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import warnings
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 7 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class cross_attention(nn.Module):
    """Cross attention between two feature maps, with depthwise-separable
    convolutions as the query/key/value projections."""

    def __init__(self, dim, num_heads, dropout=0.):
        super(cross_attention, self).__init__()
        if dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (dim, num_heads)
            )
        self.num_heads = num_heads
        self.attention_head_size = int(dim / num_heads)

        # q comes from the primary stream; k and v from the context stream
        self.query = Depth_conv(in_ch=dim, out_ch=dim)
        self.key = Depth_conv(in_ch=dim, out_ch=dim)
        self.value = Depth_conv(in_ch=dim, out_ch=dim)

        self.dropout = nn.Dropout(dropout)

    def transpose_for_scores(self, x):
        # NOTE(review): despite the name, this only permutes axes
        # (B, C, H, W) -> (B, H, C, W); no per-head reshape happens here.
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, ctx):
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(ctx))
        v = self.transpose_for_scores(self.value(ctx))

        # scaled dot-product attention over the last two axes
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        probs = self.dropout(nn.Softmax(dim=-1)(scores))

        attended = torch.matmul(probs, v)
        return attended.permute(0, 2, 1, 3).contiguous()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Depth_conv(nn.Module):
    """Depthwise-separable convolution: a per-channel 3x3 conv followed by a
    1x1 pointwise conv that mixes channels and sets the output width."""

    def __init__(self, in_ch, out_ch):
        super(Depth_conv, self).__init__()
        # 3x3 conv applied independently to each input channel
        self.depth_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=1,
            groups=in_ch,
        )
        # 1x1 conv mixing channels / changing the channel count
        self.point_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=0,
            groups=1,
        )

    def forward(self, input):
        return self.point_conv(self.depth_conv(input))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Dilated_Resblock(nn.Module):
|
| 87 |
+
def __init__(self, in_channels, out_channels):
|
| 88 |
+
super(Dilated_Resblock, self).__init__()
|
| 89 |
+
|
| 90 |
+
sequence = list()
|
| 91 |
+
sequence += [
|
| 92 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
|
| 93 |
+
padding=1, dilation=(1, 1)),
|
| 94 |
+
nn.LeakyReLU(),
|
| 95 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
|
| 96 |
+
padding=2, dilation=(2, 2)),
|
| 97 |
+
nn.LeakyReLU(),
|
| 98 |
+
nn.Conv2d(out_channels, in_channels, kernel_size=(3, 3), stride=(1, 1),
|
| 99 |
+
padding=3, dilation=(3, 3)),
|
| 100 |
+
nn.LeakyReLU(),
|
| 101 |
+
nn.Conv2d(out_channels, in_channels, kernel_size=(3, 3), stride=(1, 1),
|
| 102 |
+
padding=2, dilation=(2, 2)),
|
| 103 |
+
nn.LeakyReLU(),
|
| 104 |
+
nn.Conv2d(out_channels, in_channels, kernel_size=(3, 3), stride=(1, 1),
|
| 105 |
+
padding=1, dilation=(1, 1))
|
| 106 |
+
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
self.model = nn.Sequential(*sequence)
|
| 110 |
+
|
| 111 |
+
def forward(self, x):
|
| 112 |
+
out = self.model(x) + x
|
| 113 |
+
|
| 114 |
+
return out
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class HFRM(nn.Module):
    """High-frequency refinement module.

    Input stacks three sub-bands along the batch axis, i.e. ``x`` is
    (3*k, C, H, W); the thirds are treated as HL, LH and HH respectively,
    processed separately and re-concatenated — presumably wavelet sub-bands;
    TODO confirm against the caller.
    """

    def __init__(self, in_channels, out_channels):
        super(HFRM, self).__init__()

        self.conv_head = Depth_conv(in_channels, out_channels)

        self.dilated_block_LH = Dilated_Resblock(out_channels, out_channels)
        self.dilated_block_HL = Dilated_Resblock(out_channels, out_channels)

        # the HH band attends to the LH and HL bands before refinement
        self.cross_attention0 = cross_attention(out_channels, num_heads=8)
        self.dilated_block_HH = Dilated_Resblock(out_channels, out_channels)
        self.conv_HH = nn.Conv2d(out_channels*2, out_channels, kernel_size=3, stride=1, padding=1)
        self.cross_attention1 = cross_attention(out_channels, num_heads=8)

        self.conv_tail = Depth_conv(out_channels, in_channels)

    def forward(self, x):

        # only b is used below; assumes b is divisible by 3 — TODO confirm
        b, c, h, w = x.shape

        residual = x

        x = self.conv_head(x)

        # split the stacked sub-bands along the batch axis
        x_HL, x_LH, x_HH = x[:b//3, ...], x[b//3:2*b//3, ...], x[2*b//3:, ...]

        x_HH_LH = self.cross_attention0(x_LH, x_HH)
        x_HH_HL = self.cross_attention1(x_HL, x_HH)

        x_HL = self.dilated_block_HL(x_HL)
        x_LH = self.dilated_block_LH(x_LH)

        # fuse the two attended views of HH along channels, then refine
        x_HH = self.dilated_block_HH(self.conv_HH(torch.cat((x_HH_LH, x_HH_HL), dim=1)))

        # re-stack the refined bands and add the global residual
        out = self.conv_tail(torch.cat((x_HL, x_LH, x_HH), dim=0))

        return out + residual
|
JarvisIR/package/agent_tools/HVICIDNet/net/CIDNet.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .HVI_transform import RGB_HVI
|
| 4 |
+
from .transformer_utils import *
|
| 5 |
+
from .LCA import *
|
| 6 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 7 |
+
|
| 8 |
+
class CIDNet(nn.Module, PyTorchModelHubMixin):
    """HVI-CIDNet: dual-branch U-Net operating in the HVI color space.

    An RGB image is transformed to HVI; the HV (chroma) plane and the I
    (intensity) plane are encoded/decoded in two parallel branches that
    exchange information via LCA cross-attention blocks at every scale.

    Args:
        channels (list[int]): Channel widths for the four scales.
        heads (list[int]): Attention head counts for the four scales.
        norm (bool): Whether the up/downsample blocks use normalization.
    """

    def __init__(self,
                 channels=[36, 36, 72, 144],
                 heads=[1, 2, 4, 8],
                 norm=False
                 ):
        super(CIDNet, self).__init__()

        [ch1, ch2, ch3, ch4] = channels
        [head1, head2, head3, head4] = heads

        # HV_ways: encoder/decoder for the 2-channel HV (chroma) plane
        self.HVE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(3, ch1, 3, stride=1, padding=0,bias=False)
        )
        self.HVE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
        self.HVE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
        self.HVE_block3 = NormDownsample(ch3, ch4, use_norm = norm)

        self.HVD_block3 = NormUpsample(ch4, ch3, use_norm = norm)
        self.HVD_block2 = NormUpsample(ch3, ch2, use_norm = norm)
        self.HVD_block1 = NormUpsample(ch2, ch1, use_norm = norm)
        self.HVD_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 2, 3, stride=1, padding=0,bias=False)
        )

        # I_ways: encoder/decoder for the 1-channel intensity plane
        self.IE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(1, ch1, 3, stride=1, padding=0,bias=False),
        )
        self.IE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
        self.IE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
        self.IE_block3 = NormDownsample(ch3, ch4, use_norm = norm)

        self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.ID_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 1, 3, stride=1, padding=0,bias=False),
        )

        # cross-attention blocks exchanging information between branches
        # (1-3 on the encoder path, 4-6 on the decoder path)
        self.HV_LCA1 = HV_LCA(ch2, head2)
        self.HV_LCA2 = HV_LCA(ch3, head3)
        self.HV_LCA3 = HV_LCA(ch4, head4)
        self.HV_LCA4 = HV_LCA(ch4, head4)
        self.HV_LCA5 = HV_LCA(ch3, head3)
        self.HV_LCA6 = HV_LCA(ch2, head2)

        self.I_LCA1 = I_LCA(ch2, head2)
        self.I_LCA2 = I_LCA(ch3, head3)
        self.I_LCA3 = I_LCA(ch4, head4)
        self.I_LCA4 = I_LCA(ch4, head4)
        self.I_LCA5 = I_LCA(ch3, head3)
        self.I_LCA6 = I_LCA(ch2, head2)

        # RGB <-> HVI color-space transform
        self.trans = RGB_HVI()

    def forward(self, x):
        dtypes = x.dtype
        hvi = self.trans.HVIT(x)
        # intensity plane is the third HVI channel
        i = hvi[:,2,:,:].unsqueeze(1).to(dtypes)
        # low
        i_enc0 = self.IE_block0(i)
        i_enc1 = self.IE_block1(i_enc0)
        hv_0 = self.HVE_block0(hvi)
        hv_1 = self.HVE_block1(hv_0)
        i_jump0 = i_enc0
        hv_jump0 = hv_0

        i_enc2 = self.I_LCA1(i_enc1, hv_1)
        hv_2 = self.HV_LCA1(hv_1, i_enc1)
        v_jump1 = i_enc2
        hv_jump1 = hv_2
        i_enc2 = self.IE_block2(i_enc2)
        hv_2 = self.HVE_block2(hv_2)

        i_enc3 = self.I_LCA2(i_enc2, hv_2)
        hv_3 = self.HV_LCA2(hv_2, i_enc2)
        v_jump2 = i_enc3
        hv_jump2 = hv_3
        # NOTE(review): downsampling feeds i_enc2/hv_2 (pre-LCA2 features),
        # so the I_LCA2/HV_LCA2 outputs only survive via the skip
        # connections v_jump2/hv_jump2 — looks intentional in the upstream
        # architecture, but confirm before "fixing".
        i_enc3 = self.IE_block3(i_enc2)
        hv_3 = self.HVE_block3(hv_2)

        i_enc4 = self.I_LCA3(i_enc3, hv_3)
        hv_4 = self.HV_LCA3(hv_3, i_enc3)

        # bottleneck exchange
        i_dec4 = self.I_LCA4(i_enc4,hv_4)
        hv_4 = self.HV_LCA4(hv_4, i_enc4)

        hv_3 = self.HVD_block3(hv_4, hv_jump2)
        i_dec3 = self.ID_block3(i_dec4, v_jump2)
        i_dec2 = self.I_LCA5(i_dec3, hv_3)
        hv_2 = self.HV_LCA5(hv_3, i_dec3)

        hv_2 = self.HVD_block2(hv_2, hv_jump1)
        # NOTE(review): upsampling consumes i_dec3, not the I_LCA5 output
        # it overwrites in i_dec2 — same pattern as the encoder above.
        i_dec2 = self.ID_block2(i_dec3, v_jump1)

        i_dec1 = self.I_LCA6(i_dec2, hv_2)
        hv_1 = self.HV_LCA6(hv_2, i_dec2)

        i_dec1 = self.ID_block1(i_dec1, i_jump0)
        i_dec0 = self.ID_block0(i_dec1)
        hv_1 = self.HVD_block1(hv_1, hv_jump0)
        hv_0 = self.HVD_block0(hv_1)

        # residual in HVI space, then back to RGB
        output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
        output_rgb = self.trans.PHVIT(output_hvi)

        return output_rgb

    def HVIT(self,x):
        # expose the RGB -> HVI transform for external use (e.g. losses)
        hvi = self.trans.HVIT(x)
        return hvi
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
JarvisIR/package/agent_tools/HVICIDNet/net/HVI_transform.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
pi = 3.141592653589793
|
| 5 |
+
|
| 6 |
+
class RGB_HVI(nn.Module):
|
| 7 |
+
def __init__(self):
|
| 8 |
+
super(RGB_HVI, self).__init__()
|
| 9 |
+
self.density_k = torch.nn.Parameter(torch.full([1],0.2)) # k is reciprocal to the paper mentioned
|
| 10 |
+
self.gated = False
|
| 11 |
+
self.gated2= False
|
| 12 |
+
self.alpha = 1.0
|
| 13 |
+
self.alpha_s = 1.3
|
| 14 |
+
self.this_k = 0
|
| 15 |
+
|
| 16 |
+
def HVIT(self, img):
|
| 17 |
+
eps = 1e-8
|
| 18 |
+
device = img.device
|
| 19 |
+
dtypes = img.dtype
|
| 20 |
+
hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
|
| 21 |
+
value = img.max(1)[0].to(dtypes)
|
| 22 |
+
img_min = img.min(1)[0].to(dtypes)
|
| 23 |
+
hue[img[:,2]==value] = 4.0 + ( (img[:,0]-img[:,1]) / (value - img_min + eps)) [img[:,2]==value]
|
| 24 |
+
hue[img[:,1]==value] = 2.0 + ( (img[:,2]-img[:,0]) / (value - img_min + eps)) [img[:,1]==value]
|
| 25 |
+
hue[img[:,0]==value] = (0.0 + ((img[:,1]-img[:,2]) / (value - img_min + eps)) [img[:,0]==value]) % 6
|
| 26 |
+
|
| 27 |
+
hue[img.min(1)[0]==value] = 0.0
|
| 28 |
+
hue = hue/6.0
|
| 29 |
+
|
| 30 |
+
saturation = (value - img_min ) / (value + eps )
|
| 31 |
+
saturation[value==0] = 0
|
| 32 |
+
|
| 33 |
+
hue = hue.unsqueeze(1)
|
| 34 |
+
saturation = saturation.unsqueeze(1)
|
| 35 |
+
value = value.unsqueeze(1)
|
| 36 |
+
|
| 37 |
+
k = self.density_k
|
| 38 |
+
self.this_k = k.item()
|
| 39 |
+
|
| 40 |
+
color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
|
| 41 |
+
ch = (2.0 * pi * hue).cos()
|
| 42 |
+
cv = (2.0 * pi * hue).sin()
|
| 43 |
+
H = color_sensitive * saturation * ch
|
| 44 |
+
V = color_sensitive * saturation * cv
|
| 45 |
+
I = value
|
| 46 |
+
xyz = torch.cat([H, V, I],dim=1)
|
| 47 |
+
return xyz
|
| 48 |
+
|
| 49 |
+
def PHVIT(self, img):
|
| 50 |
+
eps = 1e-8
|
| 51 |
+
H,V,I = img[:,0,:,:],img[:,1,:,:],img[:,2,:,:]
|
| 52 |
+
|
| 53 |
+
# clip
|
| 54 |
+
H = torch.clamp(H,-1,1)
|
| 55 |
+
V = torch.clamp(V,-1,1)
|
| 56 |
+
I = torch.clamp(I,0,1)
|
| 57 |
+
|
| 58 |
+
v = I
|
| 59 |
+
k = self.this_k
|
| 60 |
+
color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
|
| 61 |
+
H = (H) / (color_sensitive + eps)
|
| 62 |
+
V = (V) / (color_sensitive + eps)
|
| 63 |
+
H = torch.clamp(H,-1,1)
|
| 64 |
+
V = torch.clamp(V,-1,1)
|
| 65 |
+
h = torch.atan2(V + eps,H + eps) / (2*pi)
|
| 66 |
+
h = h%1
|
| 67 |
+
s = torch.sqrt(H**2 + V**2 + eps)
|
| 68 |
+
|
| 69 |
+
if self.gated:
|
| 70 |
+
s = s * self.alpha_s
|
| 71 |
+
|
| 72 |
+
s = torch.clamp(s,0,1)
|
| 73 |
+
v = torch.clamp(v,0,1)
|
| 74 |
+
|
| 75 |
+
r = torch.zeros_like(h)
|
| 76 |
+
g = torch.zeros_like(h)
|
| 77 |
+
b = torch.zeros_like(h)
|
| 78 |
+
|
| 79 |
+
hi = torch.floor(h * 6.0)
|
| 80 |
+
f = h * 6.0 - hi
|
| 81 |
+
p = v * (1. - s)
|
| 82 |
+
q = v * (1. - (f * s))
|
| 83 |
+
t = v * (1. - ((1. - f) * s))
|
| 84 |
+
|
| 85 |
+
hi0 = hi==0
|
| 86 |
+
hi1 = hi==1
|
| 87 |
+
hi2 = hi==2
|
| 88 |
+
hi3 = hi==3
|
| 89 |
+
hi4 = hi==4
|
| 90 |
+
hi5 = hi==5
|
| 91 |
+
|
| 92 |
+
r[hi0] = v[hi0]
|
| 93 |
+
g[hi0] = t[hi0]
|
| 94 |
+
b[hi0] = p[hi0]
|
| 95 |
+
|
| 96 |
+
r[hi1] = q[hi1]
|
| 97 |
+
g[hi1] = v[hi1]
|
| 98 |
+
b[hi1] = p[hi1]
|
| 99 |
+
|
| 100 |
+
r[hi2] = p[hi2]
|
| 101 |
+
g[hi2] = v[hi2]
|
| 102 |
+
b[hi2] = t[hi2]
|
| 103 |
+
|
| 104 |
+
r[hi3] = p[hi3]
|
| 105 |
+
g[hi3] = q[hi3]
|
| 106 |
+
b[hi3] = v[hi3]
|
| 107 |
+
|
| 108 |
+
r[hi4] = t[hi4]
|
| 109 |
+
g[hi4] = p[hi4]
|
| 110 |
+
b[hi4] = v[hi4]
|
| 111 |
+
|
| 112 |
+
r[hi5] = v[hi5]
|
| 113 |
+
g[hi5] = p[hi5]
|
| 114 |
+
b[hi5] = q[hi5]
|
| 115 |
+
|
| 116 |
+
r = r.unsqueeze(1)
|
| 117 |
+
g = g.unsqueeze(1)
|
| 118 |
+
b = b.unsqueeze(1)
|
| 119 |
+
rgb = torch.cat([r, g, b], dim=1)
|
| 120 |
+
if self.gated2:
|
| 121 |
+
rgb = rgb * self.alpha
|
| 122 |
+
return rgb
|
JarvisIR/package/agent_tools/HVICIDNet/net/LCA.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
from .transformer_utils import *
|
| 5 |
+
|
| 6 |
+
# Cross Attention Block
|
| 7 |
+
class CAB(nn.Module):
|
| 8 |
+
def __init__(self, dim, num_heads, bias):
|
| 9 |
+
super(CAB, self).__init__()
|
| 10 |
+
self.num_heads = num_heads
|
| 11 |
+
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
| 12 |
+
|
| 13 |
+
self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 14 |
+
self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
|
| 15 |
+
self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
|
| 16 |
+
self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
|
| 17 |
+
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 18 |
+
|
| 19 |
+
def forward(self, x, y):
|
| 20 |
+
b, c, h, w = x.shape
|
| 21 |
+
|
| 22 |
+
q = self.q_dwconv(self.q(x))
|
| 23 |
+
kv = self.kv_dwconv(self.kv(y))
|
| 24 |
+
k, v = kv.chunk(2, dim=1)
|
| 25 |
+
|
| 26 |
+
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 27 |
+
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 28 |
+
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 29 |
+
|
| 30 |
+
q = torch.nn.functional.normalize(q, dim=-1)
|
| 31 |
+
k = torch.nn.functional.normalize(k, dim=-1)
|
| 32 |
+
|
| 33 |
+
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
| 34 |
+
attn = nn.functional.softmax(attn,dim=-1)
|
| 35 |
+
|
| 36 |
+
out = (attn @ v)
|
| 37 |
+
|
| 38 |
+
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
|
| 39 |
+
|
| 40 |
+
out = self.project_out(out)
|
| 41 |
+
return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Intensity Enhancement Layer
|
| 45 |
+
class IEL(nn.Module):
|
| 46 |
+
def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
|
| 47 |
+
super(IEL, self).__init__()
|
| 48 |
+
|
| 49 |
+
hidden_features = int(dim*ffn_expansion_factor)
|
| 50 |
+
|
| 51 |
+
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
|
| 52 |
+
|
| 53 |
+
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
|
| 54 |
+
self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
|
| 55 |
+
self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
|
| 56 |
+
|
| 57 |
+
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
|
| 58 |
+
|
| 59 |
+
self.Tanh = nn.Tanh()
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
x = self.project_in(x)
|
| 62 |
+
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
| 63 |
+
x1 = self.Tanh(self.dwconv1(x1)) + x1
|
| 64 |
+
x2 = self.Tanh(self.dwconv2(x2)) + x2
|
| 65 |
+
x = x1 * x2
|
| 66 |
+
x = self.project_out(x)
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Lightweight Cross Attention
|
| 71 |
+
class HV_LCA(nn.Module):
    """Lightweight Cross Attention for the HV (chroma) stream.

    Cross-attends x toward y (residually), then refines with a gated
    feed-forward (IEL, structurally identical to the paper's CDL).
    """

    def __init__(self, dim, num_heads, bias=False):
        super(HV_LCA, self).__init__()
        self.gdfn = IEL(dim)  # IEL and CDL have the same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        # NOTE(review): unlike I_LCA, the feed-forward output here REPLACES x
        # (no residual) — this matches the original implementation.
        return self.gdfn(self.norm(x))
|
| 82 |
+
|
| 83 |
+
class I_LCA(nn.Module):
    """Lightweight Cross Attention for the I (intensity) stream.

    Same layout as HV_LCA, but both the attention and the gated feed-forward
    are applied residually.
    """

    def __init__(self, dim, num_heads, bias=False):
        super(I_LCA, self).__init__()
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        attended = self.ffn(self.norm(x), self.norm(y))
        x = x + attended
        return x + self.gdfn(self.norm(x))
|
JarvisIR/package/agent_tools/HVICIDNet/net/transformer_utils.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two data formats: channels_last (default) or channels_first.

    channels_last expects inputs shaped (batch, height, width, channels) and
    delegates to ``F.layer_norm``; channels_first expects (batch, channels,
    height, width) and normalises over dim 1 manually (biased variance, as
    ``F.layer_norm`` uses).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            mu = x.mean(1, keepdim=True)
            var = (x - mu).pow(2).mean(1, keepdim=True)  # biased variance over channels
            normed = (x - mu) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
| 30 |
+
|
| 31 |
+
class NormDownsample(nn.Module):
    """Conv + bilinear resize (default x0.5) + PReLU, optionally LayerNorm'd."""

    def __init__(self, in_ch, out_ch, scale=0.5, use_norm=False):
        super(NormDownsample, self).__init__()
        self.use_norm = use_norm
        if use_norm:
            self.norm = LayerNorm(out_ch)
        self.prelu = nn.PReLU()
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.UpsamplingBilinear2d(scale_factor=scale),
        )

    def forward(self, x):
        out = self.prelu(self.down(x))
        return self.norm(out) if self.use_norm else out
|
| 49 |
+
|
| 50 |
+
class NormUpsample(nn.Module):
    """Upsample x (conv + bilinear x2), fuse with skip y by concat + 1x1 conv,
    PReLU, and optionally LayerNorm."""

    def __init__(self, in_ch, out_ch, scale=2, use_norm=False):
        super(NormUpsample, self).__init__()
        self.use_norm = use_norm
        if use_norm:
            self.norm = LayerNorm(out_ch)
        self.prelu = nn.PReLU()
        self.up_scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.UpsamplingBilinear2d(scale_factor=scale),
        )
        # 1x1 conv that fuses the upsampled features with the skip connection
        self.up = nn.Conv2d(out_ch * 2, out_ch, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x, y):
        fused = self.up(torch.cat([self.up_scale(x), y], dim=1))
        fused = self.prelu(fused)
        return self.norm(fused) if self.use_norm else fused
|
| 71 |
+
|
JarvisIR/package/agent_tools/HVICIDNet/wavelet.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def Normalize(x):
    """Linearly rescale *x* so its values span [0, 255].

    NOTE(review): divides by (max - min); a constant input produces NaN/inf —
    same as the original behavior.
    """
    lo, hi = x.min(), x.max()
    return 255 * (x - lo) / (hi - lo)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def dwt_init(x):
    """Single-level 2-D Haar DWT.

    Input (B, C, H, W); returns the four subbands (LL, HL, LH, HH) stacked
    along the batch dimension: (4*B, C, H/2, W/2).
    """
    even_rows = x[:, :, 0::2, :] / 2
    odd_rows = x[:, :, 1::2, :] / 2
    a = even_rows[:, :, :, 0::2]  # even row, even col
    b = odd_rows[:, :, :, 0::2]   # odd row,  even col
    c = even_rows[:, :, :, 1::2]  # even row, odd col
    d = odd_rows[:, :, :, 1::2]   # odd row,  odd col

    ll = a + b + c + d
    hl = -a - b + c + d
    lh = -a + b - c + d
    hh = a - b - c + d

    return torch.cat((ll, hl, lh, hh), 0)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# 使用哈尔 haar 小波变换来实现二维离散小波
|
| 30 |
+
def iwt_init(x):
    """Single-level inverse 2-D Haar DWT (inverse of ``dwt_init``).

    Input (4*B, C, H, W) with subbands (LL, HL, LH, HH) stacked along the
    batch dimension; returns the reconstruction (B, C, 2*H, 2*W) as float32
    on the input's device.
    """
    r = 2
    in_batch, channels, height, width = x.size()
    out_batch = in_batch // (r ** 2)

    ll = x[:out_batch] / 2
    hl = x[out_batch:2 * out_batch] / 2
    lh = x[2 * out_batch:3 * out_batch] / 2
    hh = x[3 * out_batch:4 * out_batch] / 2

    recon = torch.zeros([out_batch, channels, r * height, r * width]).float().to(x.device)
    recon[:, :, 0::2, 0::2] = ll - hl - lh + hh
    recon[:, :, 1::2, 0::2] = ll - hl + lh - hh
    recon[:, :, 0::2, 1::2] = ll + hl - lh - hh
    recon[:, :, 1::2, 1::2] = ll + hl + lh + hh

    return recon
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class DWT(nn.Module):
    """Module wrapper around ``dwt_init`` (forward Haar wavelet transform)."""

    def __init__(self):
        super(DWT, self).__init__()
        # fixed signal transform, not a learned convolution: no gradients needed
        self.requires_grad = False

    def forward(self, x):
        return dwt_init(x)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class IWT(nn.Module):
    """Module wrapper around ``iwt_init`` (inverse Haar wavelet transform)."""

    def __init__(self):
        super(IWT, self).__init__()
        # fixed signal transform: no gradients needed
        self.requires_grad = False

    def forward(self, x):
        return iwt_init(x)
|
JarvisIR/package/agent_tools/IDT/__init__.py
ADDED
|
File without changes
|
JarvisIR/package/agent_tools/IDT/analyse/cal_rf_bf.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
def getimgarr(path_1):
    """Load every image in directory *path_1* (sorted by filename) into one float ndarray.

    Fixes: builds paths with ``os.path.join`` instead of string concatenation
    (portable across separators), and replaces the index loop with a
    comprehension.

    Returns an (N, H, W, 3) float array — assumes all images share one size,
    TODO confirm. NOTE(review): ``cv2.imread`` returns None for unreadable
    files, which would make the resulting array ragged/object-typed, same as
    the original behavior.
    """
    names = sorted(os.listdir(path_1))
    imgs = [cv2.imread(os.path.join(path_1, name)) for name in names]
    return np.array(imgs, dtype=float)
|
| 15 |
+
|
| 16 |
+
#inp_path = '/disk1/release/DayRainDrop/Clear'
|
| 17 |
+
#inp_path = '/disk1/release/NightRainDrop/Clear'
|
| 18 |
+
#inp_path = '/disk1/release/DayRainDrop_Test/Clear'
|
| 19 |
+
inp_path = '/disk1/release/NightRainDrop_Test/Clear'

# tallies: every frame / background-focused frames / raindrop-focused frames
countall = 0
countbf = 0
countrf = 0

sid = sorted(os.listdir(inp_path))
print(sid, len(sid))

for scene in sid:
    inp_path_clean = os.path.join(inp_path, scene)
    inp_path_blur = inp_path_clean.replace('/Clear/', '/Blur/')

    # keep only .png frames, sorted by name
    Droplist = [f for f in sorted(os.listdir(inp_path_clean)) if f.endswith('.png')]

    for frame in Droplist:
        cleaninp_frame_name = os.path.join(inp_path_clean, frame)
        blurinp_frame_name = cleaninp_frame_name.replace('/Clear/', '/Blur/')
        countall += 1
        cleanimage = Image.open(cleaninp_frame_name)
        blurimage = Image.open(blurinp_frame_name)
        # PIL Image equality compares mode/size/pixel data: identical clean
        # and blur frames mean the raindrop plane was in focus elsewhere,
        # i.e. the background is the focused layer
        if cleanimage == blurimage:
            countbf += 1
            parts = cleaninp_frame_name.split(os.sep)
            print(os.path.join(parts[-2], parts[-1]))
        else:
            countrf += 1

print('-countall-', countall, '-background focus-', countbf, '-raindrop focus-', countrf)
|
| 57 |
+
|
| 58 |
+
#DayRainDrop -countall- 5442,-background focus- 1836,-raindrop focus- 3606,
|
| 59 |
+
#NightRainDrop -countall- 9744,-background focus- 4906,-raindrop focus- 4838,
|
| 60 |
+
#Total: 15186
|
| 61 |
+
|
| 62 |
+
#DayRainDrop_Train -countall- 8655,-background focus- 4143,-raindrop focus- 4512
|
| 63 |
+
#DayRainDrop_Test -countall- 729 -background focus- 261 -raindrop focus- 468
|
| 64 |
+
#NightRainDrop_Train -countall- 4713,-background focus- 1575,-raindrop focus- 3138
|
| 65 |
+
#NightRainDrop_Test -countall- 1089 -background focus- 763 -raindrop focus- 326
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
JarvisIR/package/agent_tools/IDT/configs/daytime_128.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
dataset: "RainDrop"
|
| 3 |
+
#image_size: 64 # DIT, RDiffusion, onego
|
| 4 |
+
#image_size: 256 # ICRA, Uformer, atgan
|
| 5 |
+
image_size: 128 # IDT, restormer
|
| 6 |
+
channels: 3
|
| 7 |
+
num_workers: 8
|
| 8 |
+
data_dir: "/data4/lx_workspace/datasets/SFT-data/mini/daytime_driving_rainy/rainy/"
|
| 9 |
+
conditional: True
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
in_channels: 3
|
| 13 |
+
out_ch: 3
|
| 14 |
+
ch: 128
|
| 15 |
+
ch_mult: [1, 1, 2, 2, 4, 4]
|
| 16 |
+
num_res_blocks: 2
|
| 17 |
+
attn_resolutions: [16, ]
|
| 18 |
+
dropout: 0.0
|
| 19 |
+
ema_rate: 0.999
|
| 20 |
+
ema: True
|
| 21 |
+
resamp_with_conv: True
|
| 22 |
+
|
| 23 |
+
diffusion:
|
| 24 |
+
beta_schedule: linear
|
| 25 |
+
beta_start: 0.0001
|
| 26 |
+
beta_end: 0.02
|
| 27 |
+
num_diffusion_timesteps: 1000
|
| 28 |
+
|
| 29 |
+
training:
|
| 30 |
+
#-----------DIT, RDiffusion, onego, ICRA--------------
|
| 31 |
+
# patch_n: 4
|
| 32 |
+
# batch_size: 16
|
| 33 |
+
# #---------IDT Uformer restormer--------------
|
| 34 |
+
patch_n: 1
|
| 35 |
+
batch_size: 4
|
| 36 |
+
#-----------atgan--------------
|
| 37 |
+
# patch_n: 1
|
| 38 |
+
# batch_size: 32
|
| 39 |
+
#-----------DIT, RDiffusion--------------
|
| 40 |
+
#n_epochs: 37042
|
| 41 |
+
n_epochs: 401
|
| 42 |
+
n_iters: 20000000
|
| 43 |
+
snapshot_freq: 50
|
| 44 |
+
validation_freq: 10000
|
| 45 |
+
|
| 46 |
+
sampling:
|
| 47 |
+
batch_size: 4
|
| 48 |
+
last_only: True
|
| 49 |
+
|
| 50 |
+
optim:
|
| 51 |
+
weight_decay: 0.000
|
| 52 |
+
optimizer: "Adam"
|
| 53 |
+
lr: 0.00002
|
| 54 |
+
amsgrad: False
|
| 55 |
+
eps: 0.00000001
|
JarvisIR/package/agent_tools/IDT/configs/daytime_256.yml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
dataset: "RainDrop"
|
| 3 |
+
#image_size: 64 # DIT, RDiffusion, onego
|
| 4 |
+
image_size: 256 # ICRA, Uformer, atgan
|
| 5 |
+
# image_size: 128 # IDT, restormer
|
| 6 |
+
channels: 3
|
| 7 |
+
num_workers: 8
|
| 8 |
+
data_dir: "/data4/lx_workspace/datasets/SFT-data/mini/daytime_driving_rainy/rainy/"
|
| 9 |
+
conditional: True
|
| 10 |
+
|
| 11 |
+
model:
|
| 12 |
+
in_channels: 3
|
| 13 |
+
out_ch: 3
|
| 14 |
+
ch: 128
|
| 15 |
+
ch_mult: [1, 1, 2, 2, 4, 4]
|
| 16 |
+
num_res_blocks: 2
|
| 17 |
+
attn_resolutions: [16, ]
|
| 18 |
+
dropout: 0.0
|
| 19 |
+
ema_rate: 0.999
|
| 20 |
+
ema: True
|
| 21 |
+
resamp_with_conv: True
|
| 22 |
+
|
| 23 |
+
diffusion:
|
| 24 |
+
beta_schedule: linear
|
| 25 |
+
beta_start: 0.0001
|
| 26 |
+
beta_end: 0.02
|
| 27 |
+
num_diffusion_timesteps: 1000
|
| 28 |
+
|
| 29 |
+
training:
|
| 30 |
+
#-----------DIT, RDiffusion, onego, ICRA--------------
|
| 31 |
+
# patch_n: 4
|
| 32 |
+
# batch_size: 16
|
| 33 |
+
# #---------IDT Uformer restormer--------------
|
| 34 |
+
patch_n: 1
|
| 35 |
+
batch_size: 4
|
| 36 |
+
#-----------atgan--------------
|
| 37 |
+
# patch_n: 1
|
| 38 |
+
# batch_size: 32
|
| 39 |
+
#-----------DIT, RDiffusion--------------
|
| 40 |
+
#n_epochs: 37042
|
| 41 |
+
n_epochs: 401
|
| 42 |
+
n_iters: 20000000
|
| 43 |
+
snapshot_freq: 50
|
| 44 |
+
validation_freq: 10000
|
| 45 |
+
|
| 46 |
+
sampling:
|
| 47 |
+
batch_size: 4
|
| 48 |
+
last_only: True
|
| 49 |
+
|
| 50 |
+
optim:
|
| 51 |
+
weight_decay: 0.000
|
| 52 |
+
optimizer: "Adam"
|
| 53 |
+
lr: 0.00002
|
| 54 |
+
amsgrad: False
|
| 55 |
+
eps: 0.00000001
|