LYL1015 committed on
Commit
eea83e8
·
1 Parent(s): 3d2f97b
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +3 -0
  3. JarvisIR/.gitignore +146 -0
  4. JarvisIR/.gitmodules +7 -0
  5. JarvisIR/LICENSE +21 -0
  6. JarvisIR/README.md +148 -0
  7. JarvisIR/demo_gradio.py +474 -0
  8. JarvisIR/docs/gradio_demo.md +36 -0
  9. JarvisIR/docs/sft_training.md +84 -0
  10. JarvisIR/package/README.md +155 -0
  11. JarvisIR/package/agent_tools.egg-info/PKG-INFO +7 -0
  12. JarvisIR/package/agent_tools.egg-info/SOURCES.txt +192 -0
  13. JarvisIR/package/agent_tools.egg-info/dependency_links.txt +1 -0
  14. JarvisIR/package/agent_tools.egg-info/top_level.txt +1 -0
  15. JarvisIR/package/agent_tools/ESRGAN/__init__.py +0 -0
  16. JarvisIR/package/agent_tools/ESRGAN/inference.py +49 -0
  17. JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus.yml +188 -0
  18. JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus_pairdata.yml +150 -0
  19. JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x2plus.yml +186 -0
  20. JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x4plus.yml +185 -0
  21. JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x2plus.yml +145 -0
  22. JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x4plus.yml +144 -0
  23. JarvisIR/package/agent_tools/ESRGAN/realesrgan/__init__.py +5 -0
  24. JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/__init__.py +11 -0
  25. JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/discriminator_arch.py +67 -0
  26. JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/srvgg_arch.py +69 -0
  27. JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/__init__.py +11 -0
  28. JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_dataset.py +192 -0
  29. JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_paired_dataset.py +108 -0
  30. JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/__init__.py +11 -0
  31. JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrgan_model.py +258 -0
  32. JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrnet_model.py +188 -0
  33. JarvisIR/package/agent_tools/ESRGAN/realesrgan/train.py +11 -0
  34. JarvisIR/package/agent_tools/ESRGAN/realesrgan/utils.py +309 -0
  35. JarvisIR/package/agent_tools/HVICIDNet/inference.py +65 -0
  36. JarvisIR/package/agent_tools/HVICIDNet/loss/loss_utils.py +145 -0
  37. JarvisIR/package/agent_tools/HVICIDNet/loss/losses.py +193 -0
  38. JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_pris_params.npz +3 -0
  39. JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_utils.py +559 -0
  40. JarvisIR/package/agent_tools/HVICIDNet/loss/vgg_arch.py +239 -0
  41. JarvisIR/package/agent_tools/HVICIDNet/mods.py +153 -0
  42. JarvisIR/package/agent_tools/HVICIDNet/net/CIDNet.py +129 -0
  43. JarvisIR/package/agent_tools/HVICIDNet/net/HVI_transform.py +122 -0
  44. JarvisIR/package/agent_tools/HVICIDNet/net/LCA.py +93 -0
  45. JarvisIR/package/agent_tools/HVICIDNet/net/transformer_utils.py +71 -0
  46. JarvisIR/package/agent_tools/HVICIDNet/wavelet.py +65 -0
  47. JarvisIR/package/agent_tools/IDT/__init__.py +0 -0
  48. JarvisIR/package/agent_tools/IDT/analyse/cal_rf_bf.py +68 -0
  49. JarvisIR/package/agent_tools/IDT/configs/daytime_128.yml +55 -0
  50. JarvisIR/package/agent_tools/IDT/configs/daytime_256.yml +55 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ JarvisIR/package/agent_tools/RIDCP/.eggs/**/*.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ **/.eggs/
2
+ *.so
3
+ JarvisIR/package/agent_tools/RIDCP/.eggs
JarvisIR/.gitignore ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignored folders
2
+ datasets/*
3
+ experiments*
4
+ experiments/*
5
+ # results/*
6
+ tb_logger*
7
+ pretrained_models
8
+ wandb/*
9
+ tmp*/*
10
+ *.sh
11
+ .vscode*
12
+ .github
13
+ # ignored files
14
+ version.py
15
+
16
+ # ignored files with suffix
17
+ *.html
18
+ *.pth
19
+ *.zip
20
+
21
+ # template
22
+
23
+ # Byte-compiled / optimized / DLL files
24
+ __pycache__/
25
+ *.py[cod]
26
+ *$py.class
27
+
28
+ # C extensions
29
+ *.so
30
+
31
+ # Distribution / packaging
32
+ .Python
33
+ build/
34
+ develop-eggs/
35
+ dist/
36
+ downloads/
37
+ eggs/
38
+ .eggs/
39
+ lib/
40
+ lib64/
41
+ parts/
42
+ sdist/
43
+ var/
44
+ wheels/
45
+ *.egg-info/
46
+ .installed.cfg
47
+ *.egg
48
+ MANIFEST
49
+
50
+ # PyInstaller
51
+ # Usually these files are written by a python script from a template
52
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
53
+ *.manifest
54
+ *.spec
55
+
56
+ # Installer logs
57
+ pip-log.txt
58
+ pip-delete-this-directory.txt
59
+
60
+ # Unit test / coverage reports
61
+ htmlcov/
62
+ .tox/
63
+ .coverage
64
+ .coverage.*
65
+ .cache
66
+ nosetests.xml
67
+ coverage.xml
68
+ *.cover
69
+ .hypothesis/
70
+ .pytest_cache/
71
+
72
+ # Translations
73
+ *.mo
74
+ *.pot
75
+
76
+ # Django stuff:
77
+ *.log
78
+ local_settings.py
79
+ db.sqlite3
80
+
81
+ # Flask stuff:
82
+ instance/
83
+ .webassets-cache
84
+
85
+ # Scrapy stuff:
86
+ .scrapy
87
+
88
+ # Sphinx documentation
89
+ docs/_build/
90
+
91
+ # PyBuilder
92
+ target/
93
+
94
+ # Jupyter Notebook
95
+ .ipynb_checkpoints
96
+
97
+ # pyenv
98
+ .python-version
99
+
100
+ # celery beat schedule file
101
+ celerybeat-schedule
102
+
103
+ # SageMath parsed files
104
+ *.sage.py
105
+
106
+ # Environments
107
+ .env
108
+ .venv
109
+ env/
110
+ venv/
111
+ ENV/
112
+ env.bak/
113
+ venv.bak/
114
+
115
+ # Spyder project settings
116
+ .spyderproject
117
+ .spyproject
118
+
119
+ # Rope project settings
120
+ .ropeproject
121
+
122
+ # mkdocs documentation
123
+ /site
124
+
125
+ # mypy
126
+ .mypy_cache/
127
+
128
+ checkpoint
129
+ *.zip
130
+ *.png
131
+ *.jpg
132
+ *.jpeg
133
+ *.json
134
+ .eggs
135
+ *.pth*
136
+ *.pk
137
+ *.pkl
138
+ *.mdb
139
+ *.pt
140
+ *.log
141
+ *.bin
142
+ __pycache__
143
+ Temp
144
+ work_dirs
145
+ src/sft/config
146
+ checkpoints
JarvisIR/.gitmodules ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [submodule "dependences/BasicSR"]
2
+ path = dependences/BasicSR
3
+ url = https://github.com/XPixelGroup/BasicSR.git
4
+ [submodule "src/sft/xtuner"]
5
+ path = src/sft/xtuner
6
+ url = https://github.com/InternLM/xtuner.git
7
+ branch = v0.1.23
JarvisIR/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Yunlong Lin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
JarvisIR/README.md ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <div align="center">
3
+ <img src="assets/icon.png" alt="JarvisIR Logo" width="100px">
4
+ <h1>[CVPR' 2025] JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration</h1>
5
+ </div>
6
+
7
+ <a href="https://lyl1015.github.io/papers/CVPR2025_JarvisIR.pdf" target="_blank" rel="noopener noreferrer">
8
+ <img src="https://img.shields.io/badge/Paper-JarvisIR-b31b1b" alt="Paper PDF">
9
+ </a>
10
+ <!-- <a href="#"><img src="https://img.shields.io/badge/arXiv-即将发布-b31b1b" alt="arXiv"></a> -->
11
+ <a href="https://cvpr2025-jarvisir.github.io/"><img src="https://img.shields.io/badge/Project-Page-green" alt="Project Page"></a>
12
+ <a href="#"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Coming%20Soon-blue" alt="Demo"></a>
13
+ <a href="https://github.com/LYL1015/JarvisIR?tab=readme-ov-file/"><img src="https://img.shields.io/badge/GitHub-Code-black" alt="Code"></a>
14
+
15
+
16
+ [Yunlong Lin](https://lyl1015.github.io/)<sup>1*♣</sup>, [Zixu Lin](https://github.com/)<sup>1*♣</sup>, [Haoyu Chen](https://haoyuchen.com/)<sup>2*</sup>, [Panwang Pan](https://paulpanwang.github.io/)<sup>3*</sup>, [Chenxin Li](https://chenxinli001.github.io/)<sup>6</sup>, [Sixiang Chen](https://ephemeral182.github.io/)<sup>2</sup>, [Kairun Wen](https://kairunwen.github.io/)<sup>1</sup>, [Yeying Jin](https://jinyeying.github.io/)<sup>4</sup>, [Wenbo Li](https://fenglinglwb.github.io/)<sup>5†</sup>, [Xinghao Ding](https://scholar.google.com/citations?user=k5hVBfMAAAAJ&hl=zh-CN)<sup>1†</sup>
17
+
18
+ <sup>1</sup>Xiamen University, <sup>2</sup>The Hong Kong University of Science and Technology (Guangzhou), <sup>3</sup>Bytedance's Pico, <sup>4</sup>Tencent, <sup>5</sup>Huawei Noah's Ark Lab, <sup>6</sup>The Chinese University of Hong Kong
19
+ <!-- <sup>*</sup>Equal Contribution <sup>♣</sup>Equal Contribution <sup>†</sup>Corresponding Author -->
20
+ Accepted by CVPR 2025
21
+
22
+ <!-- <div align="center">
23
+ <video width="800" controls>
24
+ <source src="assets/demo.mp4" type="video/mp4">
25
+ Your browser does not support the video tag.
26
+ </video>
27
+ <p>JarvisIR Demo Video: Showcasing image restoration capabilities under various adverse weather conditions</p>
28
+ </div> -->
29
+ https://github.com/user-attachments/assets/d9094fba-e24c-403e-90cb-b3d2f6e48939
30
+
31
+
32
+ </div>
33
+
34
+
35
+ ## :postbox: Updates
36
+ <!-- - 2023.12.04: Add an option to speed up the inference process by adjusting the number of denoising steps. -->
37
+ <!-- - 2024.2.9: Release our demo codes and models. Have fun! :yum: -->
38
+ - 2025.4.8: This repo is created.
39
+
40
+ ## :diamonds: Overview
41
+ JarvisIR (CVPR 2025) is a VLM-powered agent designed to tackle the challenges of vision-centric perception systems under unpredictable and coupled weather degradations. It leverages the VLM as a controller to manage multiple expert restoration models, enabling robust and autonomous operation in real-world conditions. JarvisIR employs a novel two-stage framework consisting of supervised fine-tuning and human feedback alignment, allowing it to effectively fine-tune on large-scale real-world data in an unsupervised manner. Supported by CleanBench, a comprehensive dataset with 150K synthetic and 80K real instruction-response pairs, JarvisIR demonstrates superior decision-making and restoration capabilities, achieving a 50% improvement in the average of all perception metrics on CleanBench-Real.
42
+ <div align="center">
43
+ <img src="assets/teaser1.png" alt="JarvisIR Teaser" width="800px">
44
+ </div>
45
+
46
+ ## :rocket: Method
47
+
48
+ JarvisIR implements an innovative two-stage framework that leverages a Vision-Language Model (VLM) as a controller to manage multiple expert restoration models:
49
+
50
+ 1. **Supervised Fine-tuning Stage**: JarvisIR undergoes supervised fine-tuning on synthetic data from CleanBench to enable it to follow user instructions and recognize image degradation. This initial training allows the model to identify various types of image degradation and select appropriate restoration strategies.
51
+
52
+ 2. **Human Feedback Alignment Stage**: We further finetune JarvisIR on CleanBench-Real using the MRRHF algorithm to improve system robustness, reduce hallucinations, and enhance generalizability under real-world adverse weather conditions. This stage ensures the model makes decisions that align with human expectations in complex real-world scenarios.
53
+
54
+ The core advantage of JarvisIR lies in its ability to handle multiple complex, coupled weather degradations and provide stable, reliable image inputs for autonomous driving perception systems.
55
+
56
+ <div align="center">
57
+ <img src="assets/framework.png" alt="JarvisIR Method" width="800px">
58
+ <p>Two-stage training framework of JarvisIR: The VLM controller analyzes input images, selects and coordinates expert models for optimal restoration</p>
59
+ </div>
60
+
61
+ ## :bar_chart: CleanBench Dataset
62
+
63
+ To support the training and evaluation of JarvisIR, we introduce CleanBench, the first high-quality instruction-following dataset specifically curated for developing intelligent restoration systems. CleanBench contains **150K** synthetic and **80K** real instruction-response pairs, providing a comprehensive foundation for training and evaluating intelligent image restoration systems.
64
+
65
+ ### Dataset Construction
66
+
67
+ The CleanBench dataset construction workflow consists of three main steps:
68
+
69
+ 1. **Synthesis of Degraded Images**: We generate a diverse set of degraded images by applying various weather conditions and degradation types to clean images, creating realistic scenarios that autonomous driving systems might encounter.
70
+
71
+ 2. **Generation of Assessment Reasoning and Optimal Task Sequence**: For each degraded image, we generate detailed assessments of the degradation types present and determine the optimal sequence of restoration tasks needed to effectively restore the image.
72
+
73
+ 3. **Generation of Instruction-Response Pairs**: Based on the degradation assessment and restoration sequence, we create comprehensive instruction-response pairs that guide the model in understanding user requests and providing appropriate restoration solutions.
74
+
75
+ <div align="center">
76
+ <img src="assets/datasets.png" alt="CleanBench Dataset Construction" width="800px">
77
+ <p>CleanBench dataset construction workflow: from degraded image synthesis to instruction-response pair generation</p>
78
+ </div>
79
+
80
+ ### Dataset Features
81
+
82
+ - **Comprehensive Coverage**: Includes various weather conditions (rain, snow, fog, night) and their combinations
83
+ - **High-Quality Annotations**: Detailed degradation assessments and optimal restoration sequences
84
+ - **Real-World Examples**: 80K real-world examples to ensure model generalization
85
+ - **Instruction Diversity**: Multiple instruction formats to enhance model adaptability
86
+
87
+ CleanBench serves as a crucial resource for training and evaluating intelligent image restoration systems, enabling models like JarvisIR to make informed decisions about restoration strategies in complex real-world scenarios.
88
+
89
+ ## :computer: Getting Started
90
+
91
+ For SFT training and environment setup preparation, please follow:
92
+
93
+ - [SFT Training](./docs/sft_training.md)
94
+ <!-- - [Dataset Preparation](./docs/dataset_preparation.md) -->
95
+ <!--
96
+ For sft training, please follow:
97
+
98
+ - [SFT Training](./docs/sft_training.md) -->
99
+
100
+ For gradio demo running, please follow:
101
+
102
+ - [Gradio Demo](./docs/gradio_demo.md)
103
+
104
+
105
+ ## :toolbox: Expert Models
106
+
107
+ JarvisIR integrates multiple expert restoration models to handle various types of image degradation:
108
+
109
+ | Task | Model | Description |
110
+ |------|-------|-------------|
111
+ | **Super-resolution** | Real-ESRGAN | Fast GAN-based model for super-resolution, deblurring, and artifact removal |
112
+ | **Denoising** | SCUNet | Hybrid UNet-based model combining convolution and transformer blocks for robust denoising |
113
+ | **Deraining** | UDR-S2Former | Uncertainty-aware transformer model for rain streak removal |
114
+ | | Img2img-turbo-rain | Efficient SD-turbo based model for fast and effective rain removal |
115
+ | **Raindrop removal** | IDT | Transformer-based model for de-raining and raindrop removal |
116
+ | **Dehazing** | RIDCP | Efficient dehazing model utilizing high-quality codebook priors |
117
+ | | KANet | Efficient dehazing network using a localization-and-removal pipeline |
118
+ | **Desnowing** | Img2img-turbo-snow | Efficient model for removing snow artifacts while preserving natural scene details |
119
+ | | Snowmaster | Real-world image desnowing via MLLM with multi-model feedback optimization |
120
+ | **Low-light enhancement** | Retinexformer | One-stage Retinex-based Transformer for low-light image enhancement |
121
+ | | HVICIDNet | Lightweight transformer for low-light and exposure correction |
122
+ | | LightenDiff | Diffusion-based framework for low-light enhancement |
123
+
124
+
125
+ ## :circus_tent: Checklist
126
+
127
+ - [x] Release preview inference code and gradio demo
128
+ - [x] Release SFT training code
129
+ - [ ] Release Hugging Face demo
130
+ - [ ] Release CleanBench data
131
+
132
+
133
+ ## :pray: Acknowledgements
134
+
135
+ We would like to express our gratitude to [HuggingGPT](https://github.com/microsoft/JARVIS), [XTuner](https://github.com/InternLM/xtuner), and [RRHF](https://github.com/GanjinZero/RRHF) for their valuable open-source contributions which have provided important technical references for our work.
136
+
137
+
138
+
139
+
140
+ ## :love_you_gesture: Citation
141
+ ```bibtex
142
+ @inproceedings{jarvisir2025,
143
+ title={JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration},
144
+ author={Lin, Yunlong and Lin, Zixu and Chen, Haoyu and Pan, Panwang and Li, Chenxin and Chen, Sixiang and Wen, Kairun and Jin, Yeying and Li, Wenbo and Ding, Xinghao},
145
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
146
+ year={2025}
147
+ }
148
+ ```
JarvisIR/demo_gradio.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import random
4
+ import gradio as gr
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
8
+ from threading import Thread
9
+ from agent_tools import RestorationToolkit
10
+
11
+ # Set CUDA device
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
13
+
14
+ # Model configuration
15
+ # XXX: Path to the fine-tuned LLaVA model
16
+ model_id = ""
17
+
18
+ # Available image restoration tasks and their corresponding models
19
+ all_tasks = " {denoise: [scunet, restormer], lighten: [retinexformer_fivek, hvicidnet, lightdiff], \
20
+ derain: [idt, turbo_rain, s2former], defog:[ridcp, kanet], \
21
+ desnow:[turbo_snow, snowmaster], super_resolution: [real_esrgan], \
22
+ }"
23
+
24
+ # Various prompt templates for querying the LLM about image degradation and restoration tasks
25
+ prompts_query2 = [
26
+ f"Considering the image's degradation, suggest the required tasks with explanations, and identify suitable tools for each task. Options for tasks and tools include: {all_tasks}.",
27
+ f"Given the image's degradation, outline the essential tasks along with justifications, and choose the appropriate tools for each task from the following options: {all_tasks}.",
28
+ f"Please specify the tasks required due to the image's degradation, explain the reasons, and select relevant tools for each task from the provided options: {all_tasks}.",
29
+ f"Based on the image degradation, determine the necessary tasks and their reasons, along with the appropriate tools for each task. Choose from these options: {all_tasks}.",
30
+ f"Identify the tasks required to address the image's degradation, including the reasons for each, and select tools from the options: {all_tasks}.",
31
+ f"Considering the degradation observed, list the tasks needed and their justifications, then pick the most suitable tools for each task from these options: {all_tasks}.",
32
+ f"Evaluate the image degradation, and based on that, provide the necessary tasks and reasons, along with tools chosen from the options: {all_tasks}.",
33
+ f"With respect to the image degradation, outline the tasks needed and explain why, selecting tools from the following list: {all_tasks}.",
34
+ f"Given the level of degradation in the image, specify tasks to address it, include reasons, and select tools for each task from: {all_tasks}.",
35
+ f"Examine the image's degradation, propose relevant tasks and their explanations, and identify tools from the options provided: {all_tasks}.",
36
+ f"Based on observed degradation, detail the tasks required, explain your choices, and select tools from these options: {all_tasks}.",
37
+ f"Using the image's degradation as a guide, list the necessary tasks, include explanations, and pick tools from the provided choices: {all_tasks}.",
38
+ f"Assess the image degradation, provide the essential tasks and reasons, and select the appropriate tools for each task from the options: {all_tasks}.",
39
+ f"According to the image's degradation, determine which tasks are necessary and why, choosing tools for each task from: {all_tasks}.",
40
+ f"Observe the degradation in the image, specify the needed tasks with justifications, and select appropriate tools from: {all_tasks}.",
41
+ f"Taking the image degradation into account, specify tasks needed, provide reasons, and choose tools from the following: {all_tasks}.",
42
+ f"Consider the image's degradation level, outline the tasks necessary, provide reasoning, and select suitable tools from: {all_tasks}.",
43
+ f"Evaluate the degradation in the image, identify tasks required, explain your choices, and pick tools from: {all_tasks}.",
44
+ f"Analyze the image degradation and suggest tasks with justifications, choosing the best tools from these options: {all_tasks}.",
45
+ f"Review the image degradation, and based on it, specify tasks needed, provide reasons, and select tools for each task from: {all_tasks}."
46
+ ]
47
+
48
+ # Initialize models
49
+ print("Loading LLM model...")
50
+
51
+ # Initialize the image restoration toolkit
52
+ tool_engine = RestorationToolkit(score_weight=[0,0,0,0,0])
53
+ # Load the LLaVA model in half precision to reduce memory usage
54
+ model = LlavaForConditionalGeneration.from_pretrained(
55
+ model_id,
56
+ torch_dtype=torch.float16,
57
+ low_cpu_mem_usage=True,
58
+ ).to(0)
59
+ processor = AutoProcessor.from_pretrained(model_id)
60
+
61
+ print("Loading tool engine...")
62
+
63
def parse_llm_response(response):
    """
    Split a raw LLM reply into its <reason> and <answer> sections.

    Args:
        response (str): Full text produced by the LLM.

    Returns:
        tuple: (reason, answer) with surrounding whitespace stripped;
        fixed fallback strings are returned for any missing section.
    """
    def _section(tag, fallback):
        # DOTALL so sections spanning multiple lines are captured.
        found = re.search(rf'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
        return found.group(1).strip() if found else fallback

    reason = _section('reason', "No reasoning provided")
    answer = _section('answer', "No answer provided")
    return reason, answer
80
+
81
def extract_models_from_answer(answer):
    """
    Pull model names out of an answer string.

    The answer is expected to contain segments shaped like
    ``[type:xxx]:(model:yyy)``; the ``yyy`` part of each segment is
    collected in order of appearance.

    Args:
        answer (str): Answer text produced by the LLM.

    Returns:
        list: Extracted model names (empty when nothing matches).
    """
    return re.findall(r'\[type:[^\]]+\]:\(model:([^)]+)\)', answer)
95
+
96
def beautify_recommended_actions(answer, models):
    """
    Render the LLM's restoration plan as a markdown summary.

    Args:
        answer (str): Raw answer text containing ``[type:x]:(model:y)`` pairs.
        models (list): Model names already extracted from *answer*.

    Returns:
        str: Markdown-formatted display of the recommended actions.
    """
    # Emoji used to visually tag each restoration task type.
    icons = {
        'denoise': '🧹',
        'lighten': '💡',
        'derain': '🌧️',
        'defog': '🌫️',
        'desnow': '❄️',
        'super_resolution': '🔍'
    }

    steps = re.findall(r'\[type:([^\]]+)\]:\(model:([^)]+)\)', answer)

    # Fall back to showing the raw answer when it is not in the expected format.
    if not steps:
        model_list = ', '.join(models) if models else 'None'
        return f"**🎯 Recommended Actions:**\n\n{answer}\n\n**Extracted Models:** {model_list}"

    pipeline = []
    for task, model in steps:
        task = task.strip()
        model = model.strip()
        # "super_resolution" -> "Super Resolution"
        label = task.title().replace('_', ' ')
        pipeline.append(f"{icons.get(task, '🔧')} {label}:`{model}`")

    # Horizontal flow of actions followed by a short summary.
    return (
        "**🎯 Recommended Actions:**\n"
        "> "
        + " ➡ ".join(pipeline)
        + "\n\n"
        + f"**📋 Processing Pipeline:** {len(steps)} steps\n"
        + f"**🛠️ Models to use:** {' → '.join(models)}"
    )
153
+
154
def resize_image_to_original(processed_image_path, original_size):
    """
    Resize a processed image back to the original input dimensions.

    Args:
        processed_image_path (str): Path to the processed image, or a
            falsy value when processing produced nothing.
        original_size (tuple): Original (width, height) in pixels.

    Returns:
        str: Path to the resized copy, or *processed_image_path*
        unchanged when the file does not exist.
    """
    if processed_image_path and os.path.exists(processed_image_path):
        # Bug fix: Image.save does not create intermediate directories, so
        # ensure the output directory exists before writing into it.
        output_dir = 'temp_outputs'
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, 'final_result.png')

        # Context manager releases the source file handle promptly instead
        # of leaking it until garbage collection.
        with Image.open(processed_image_path) as img:
            img.resize(original_size, Image.Resampling.LANCZOS).save(output_path)
        return output_path
    return processed_image_path
174
+
175
def get_llm_response_streaming(image_path):
    """
    Kick off a streaming LLM analysis of the given image.

    One of the equivalent prompt templates is chosen at random, the image
    is encoded for multimodal input, and generation runs on a background
    thread so the caller can consume tokens incrementally.

    Args:
        image_path (str): Path to the input image.

    Returns:
        TextIteratorStreamer: Yields decoded tokens as they are generated.
    """
    # Randomly pick one of the paraphrased instruction templates.
    instruction = prompts_query2[random.randint(0, len(prompts_query2) - 1)]

    # Llama-3-style chat template with an inline <image> placeholder.
    prompt = (
        f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{instruction}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # Encode prompt + image and move tensors to GPU 0 in half precision.
    raw_image = Image.open(image_path)
    inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

    # skip_prompt drops the echoed input; skip_special_tokens cleans output text.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Generate off the main thread so iteration over the streamer does not block.
    worker = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=400, do_sample=False),
    )
    worker.start()

    return streamer
211
+
212
def process_image_with_tools(image_path, models, original_size):
    """
    Run the selected restoration models on an image and restore its size.

    Args:
        image_path (str): Path to the input image.
        models (list): Model names to apply in order; an empty list skips
            processing entirely.
        original_size (tuple): Original (width, height) to resize back to.

    Returns:
        str: Path to the final processed image, or None when no models
        were requested.
    """
    if not models:
        return None

    output_dir = 'temp_outputs'
    os.makedirs(output_dir, exist_ok=True)

    # The toolkit chains the models and reports the final output path.
    result = tool_engine.process_image(models, image_path, output_dir)

    # Restoration models may change resolution; map back to the input size.
    return resize_image_to_original(result['output_path'], original_size)
237
+
238
def process_full_pipeline(image):
    """
    Main processing pipeline with streaming UI updates.

    Streams the LLM's <reason>/<answer> response to the chat UI as it is
    generated, then runs the recommended restoration tools on the image.

    Args:
        image (str): Path to the input image (Gradio "filepath" input),
            or None when no image has been uploaded.

    Yields:
        tuple: (chat_history, processed_image) for Gradio UI updates.
    """
    if image is None:
        # BUGFIX: this is a generator, so ``return [], None`` would be
        # swallowed into StopIteration.value and Gradio would receive
        # nothing. Yield the empty state explicitly, then stop.
        yield [], None
        return

    try:
        # Get original image size for later restoration
        original_img = Image.open(image)
        original_size = original_img.size

        # Initialize chat history for UI
        chat_history = [("Image uploaded for analysis", None)]

        # Step 1: Get streaming LLM response
        streamer = get_llm_response_streaming(image)

        # Stream the response to UI with real-time updates.
        # State machine: reasoning entry is chat_history[1], answer entry
        # is chat_history[2]; both are updated in place while streaming.
        full_response = ""
        in_reason = False
        in_answer = False
        reason_displayed = False
        answer_displayed = False
        reasoning_added = False  # Track if reasoning entry was added

        for new_text in streamer:
            full_response += new_text

            # Start showing reasoning as soon as any content arrives, or
            # once the explicit <reason> tag appears.
            if ('<reason>' in full_response and not in_reason and not reason_displayed) or (not reasoning_added and not in_reason and not reason_displayed):
                in_reason = True
                reasoning_added = True

                if '<reason>' in full_response:
                    # Extract content after <reason>
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_content = full_response[reason_start:].strip()
                else:
                    # Show all content as reasoning if no tag yet
                    reason_content = full_response.strip()

                # Add reasoning entry to chat history
                chat_history.append((None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}"))
                yield chat_history, None

            # If we're in the reason section, keep updating its entry
            elif in_reason and not reason_displayed:
                if '</reason>' in full_response:
                    # Reason section complete: extract the final content
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_end = full_response.find('</reason>')
                    reason_content = full_response[reason_start:reason_end].strip()

                    chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
                    reason_displayed = True
                    in_reason = False
                    yield chat_history, None
                else:
                    # Continue streaming partial reason content
                    if '<reason>' in full_response:
                        reason_start = full_response.find('<reason>') + len('<reason>')
                        reason_content = full_response[reason_start:].strip()
                    else:
                        reason_content = full_response.strip()

                    chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
                    yield chat_history, None

            # Entering the answer section (only after reasoning is done)
            elif '<answer>' in full_response and not in_answer and not answer_displayed and reason_displayed:
                in_answer = True
                answer_start = full_response.find('<answer>') + len('<answer>')
                answer_content = full_response[answer_start:]

                # Add partial answer to chat history
                models = extract_models_from_answer(answer_content)
                beautified = beautify_recommended_actions(answer_content, models)
                chat_history.append((None, beautified))
                yield chat_history, None

            # If we're in the answer section, keep updating its entry
            elif in_answer and not answer_displayed:
                if '</answer>' in full_response:
                    # Answer section complete: parse and process it
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_end = full_response.find('</answer>')
                    answer_content = full_response[answer_start:answer_end].strip()

                    models = extract_models_from_answer(answer_content)
                    beautified = beautify_recommended_actions(answer_content, models)
                    chat_history[2] = (None, beautified)
                    answer_displayed = True
                    in_answer = False
                    yield chat_history, None

                    # Process image with the recommended tools
                    if models:
                        chat_history.append((None, "**🔄 Processing image...**"))
                        yield chat_history, None

                        processed_image = process_image_with_tools(image, models, original_size)
                        chat_history[-1] = (None, "**✅ Processing Complete!**")
                        yield chat_history, processed_image
                        return
                    else:
                        chat_history.append((None, "**❌ No valid models found in the response**"))
                        yield chat_history, None
                        return
                else:
                    # Continue streaming partial answer content
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_content = full_response[answer_start:].strip()

                    models = extract_models_from_answer(answer_content)
                    beautified = beautify_recommended_actions(answer_content, models)
                    chat_history[2] = (None, beautified)
                    yield chat_history, None

        # Fallback if streaming completes without proper tags
        if not answer_displayed:
            reason, answer = parse_llm_response(full_response)
            models = extract_models_from_answer(answer)

            chat_history = [
                ("Image uploaded for analysis", None),
                (None, f"**🤔 Analysis & Reasoning:**\n\n{reason}"),
                (None, beautify_recommended_actions(answer, models))
            ]

            if models:
                chat_history.append((None, "**🔄 Processing image...**"))
                yield chat_history, None

                processed_image = process_image_with_tools(image, models, original_size)
                chat_history[-1] = (None, "**✅ Processing Complete!**")
                yield chat_history, processed_image
            else:
                chat_history.append((None, "**❌ No valid models found in the response**"))
                yield chat_history, None

    except Exception as e:
        # Surface any failure in the chat instead of crashing the UI
        error_msg = f"Error: {str(e)}"
        chat_history = [
            ("Image uploaded for analysis", None),
            (None, f"**❌ Error occurred:**\n\n{error_msg}")
        ]
        yield chat_history, None
400
# Gradio interface construction
def create_interface():
    """
    Build and configure the Gradio web interface.

    Returns:
        gr.Blocks: Configured Gradio interface (not yet launched).
    """
    theme = gr.themes.Soft()
    with gr.Blocks(title="JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration", theme=theme) as demo:
        # Page header: logo plus title and a short tagline
        gr.Markdown("""
        # <img src="https://cvpr2025-jarvisir.github.io/imgs/icon.png" width="32" height="32" style="display: inline-block; vertical-align: middle; transform: translateY(-2px); margin-right: 1px;"/> JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration

        Upload an image and let JarvisIR analyze its degradation and recommend the best restoration tools!
        """)

        with gr.Row():
            # Left column: upload widget and trigger button
            with gr.Column(scale=1):
                input_image = gr.Image(
                    type="filepath",
                    label="📸 Upload Your Image",
                    height=400,
                )
                process_btn = gr.Button(
                    "🚀 Analyze & Process",
                    variant="primary",
                    size="lg",
                )

            # Right column: streaming analysis transcript
            with gr.Column(scale=1):
                chatbot = gr.Chatbot(
                    label="💬 AI Analysis Chat",
                    height=400,
                    show_label=True,
                    bubble_full_width=False,
                )

        # Bottom row: restored image output
        with gr.Row():
            output_image = gr.Image(
                label="✨ Processed Result",
                height=300,
            )

        # Wire the button to the streaming pipeline
        process_btn.click(
            fn=process_full_pipeline,
            inputs=[input_image],
            outputs=[chatbot, output_image],
        )

        # Usage instructions
        gr.Markdown("### 📝 Instructions:")
        gr.Markdown("""
        1. **Upload an image** that needs restoration (blurry, dark, noisy, etc.)
        2. **Click 'Analyze & Process'** to let AI analyze the image
        3. **View the chat** to see AI's reasoning and recommendations in real-time
        4. **Check the result** - processed image restored to original dimensions
        """)

    return demo
465
+
466
if __name__ == "__main__":
    print("Starting Image Restoration Assistant...")
    app = create_interface()
    # Serve on all interfaces at port 7866; no public Gradio share link.
    app.launch(server_name="0.0.0.0", server_port=7866, share=False)
JarvisIR/docs/gradio_demo.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio Demo Guide
2
+
3
+ This guide provides instructions on how to run the Gradio demo for JarvisIR.
4
+
5
+ ## Environment Setup
6
+
7
+ Please follow the environment setup instructions in the [SFT Training Guide](./sft_training.md#environment-setup) to create the conda environment and install the necessary dependencies. The same environment can be used for running the Gradio demo.
8
+
9
+ Make sure you have activated the conda environment:
10
+ ```bash
11
+ conda activate sft_jarvisir
12
+ ```
13
+
14
+
15
+ ## Download Preview Weights
16
+
17
+ To run the Gradio demo, you need to download the preview weights and place them in the correct location:
18
+
19
+ 1. Download the JarvisIR preview weights from Hugging Face
20
+ 2. Create the weights directory (if it doesn't exist):
21
+ ```bash
22
+ mkdir -p checkpoints/jarvisir-preview/
23
+ ```
24
+ 3. Place the downloaded weight files in the `checkpoints/jarvisir-preview/` directory
25
+
26
+
27
+
28
+ ## Running the Demo
29
+
30
+ Once the environment is set up and activated, you can run the Gradio demo with the following command from the root directory of the project:
31
+
32
+ ```bash
33
+ python demo_gradio.py
34
+ ```
35
+
36
+ This will launch a web interface where you can interact with the model.
JarvisIR/docs/sft_training.md ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SFT (Supervised Fine-Tuning) Training Guide
2
+
3
+ This guide provides step-by-step instructions for setting up the environment and performing supervised fine-tuning using XTuner framework.
4
+
5
+ ## Environment Setup
6
+
7
+ ### 1. Create and Activate Virtual Environment
8
+
9
+ ```bash
10
+ conda create -n sft_jarvisir python=3.10
11
+ conda activate sft_jarvisir
12
+ ```
13
+
14
+ ### 2. Install Dependencies
15
+
16
+ ```bash
17
+ pip install -r requirements.txt
18
+ cd package/agent_tools/Retinexformer && python3 setup.py develop --no_cuda_ext && cd ..
19
+ cd RIDCP && python3 setup.py develop && cd ../..
20
+ pip install -e .
21
+
22
+ cd ../dependences/BasicSR
23
+ pip install -e .
24
+
25
+ cd ../src/sft/xtuner
26
+
27
+ # Install required packages
28
+ pip install -r requirements.txt
29
+
30
+ # Install XTuner with all features
31
+ pip install -e '.[all]'
32
+ ```
33
+
34
+ ### 3. Verify Installation
35
+
36
+ ```bash
37
+ # Check XTuner installation
38
+ xtuner version
39
+
40
+ # Verify GPU availability
41
+ python -c "import torch; print(f'CUDA Available: {torch.cuda.is_available()}')"
42
+ python -c "import torch; print(f'GPU Count: {torch.cuda.device_count()}')"
43
+ ```
44
+
45
+ ## Model Weights Setup
46
+
47
+ ### 1. Create Model Directory Structure
48
+
49
+ ```bash
50
+ # Create directories for models and datasets
51
+ cd ..
52
+ mkdir -p base_models
53
+ ```
54
+
55
+ ### 2. Download Base Model Weights
56
+
57
+ Download the following models to the base_models folder:
58
+ <table>
59
+ <tr>
60
+ <th>Model</th><th>Model Weights</th>
61
+ </tr>
62
+ <tr style="border-top: 2px solid">
63
+ <td>openai/clip-vit-large-patch14-336</td><td><a href="https://huggingface.co/openai/clip-vit-large-patch14-336"> 🤗 HuggingFace</a></td>
64
+ </tr>
65
+ <tr style="border-top: 2px solid">
66
+ <td>xtuner/llava-llama-3-8b</td><td><a href="https://huggingface.co/xtuner/llava-llama-3-8b"> 🤗 HuggingFace</a></td>
67
+ </tr>
68
+ <tr style="border-top: 2px solid">
69
+ <td>xtuner/llava-llama-3-8b-pretrain</td><td><a href="https://huggingface.co/xtuner/llava-llama-3-8b-pretrain"> 🤗 HuggingFace</a></td>
70
+ </tr>
71
+ </table>
72
+
73
+ ### 3. Training with xtuner
74
+ - Modify the model and data paths in the **llava_8b_full.py** file.
75
+ - Run the following command to perform sft.
76
+ ```bash
77
+ NPROC_PER_NODE=${GPU_NUM} xtuner train llava_8b_full.py --deepspeed deepspeed_zero2
78
+ ```
79
+
80
+ ## Acknowledgements
81
+ We would like to thank the [XTuner](https://github.com/InternLM/xtuner) team for open-sourcing such an excellent framework. For more fine-tuning methods, please refer to their official documentation and code repository.
82
+
83
+
84
+
JarvisIR/package/README.md ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Image Restoration Expert Toolkit
2
+
3
+ JarvisIR integrates diverse expert tools for image restoration, targeting problems like low-light conditions, rain/snow, haze, resolution enhancement, and noise. We've optimized our toolkit for both efficiency and performance, with some variations from the original paper.
4
+
5
+ ## Key Features
6
+
7
+ - Unified framework for multiple expert models
8
+ - Simple API for testing each expert model
9
+ - Quality assessment capabilities through reward functions
10
+ - Modular design for easy extension with new models
11
+
12
+
13
+ ## Expert Model List
14
+ | Task | Tools | Model Description |
15
+ |---------|---------|------|
16
+ | **Super-resolution & Deblurring & Artifact removal** | Real-ESRGAN | Fast GAN for super-resolution, deblurring, and artifact removal. |
17
+ | **Denoising** | SCUNet | Hybrid UNet-based model combining convolution and transformer blocks, designed for robust denoising. |
18
+ | **Deraining** | UDR-S2Former | An uncertainty-aware transformer model for rain streak removal. |
19
+ | | Img2img-turbo-rain | Efficient model based on SD-turbo, designed for fast and effective rain removal in real-world images. |
20
+ | **Raindrop removal** | IDT | Transformer-based model for de-raining and raindrop removal. |
21
+ | **Dehazing** | RIDCP | Efficient dehazing model utilizing high-quality codebook priors |
22
+ | | KANet | Efficient dehazing network using a localization-and-removal pipeline. |
23
+ | **Desnowing** | Img2img-turbo-snow | Efficient model for removing snow artifacts while preserving natural scene details. |
24
+ | | Snowmaster | Real-world Image Desnowing via MLLM with Multi-Model Feedback Optimization |
25
+ | **Low-light enhancement** | Retinexformer | One-stage Retinex-based Transformer for Low-light Image Enhancement |
26
+ | | HVICIDNet | Lightweight transformer for low-light and exposure correction |
27
+ | | LightenDiff | Diffusion-based framework for low-light enhancement |
28
+
29
+ ## Usage
30
+
31
+ ### Basic Usage
32
+
33
+ ```python
34
+ from agent_tools import RestorationToolkit
35
+
36
+ # Initialize the toolkit
37
+ toolkit = RestorationToolkit(device='cuda')
38
+
39
+ # Process an image with a sequence of models
40
+ result = toolkit.process_image(
41
+ tools=['scunet', 'real_esrgan'], # Models to apply in sequence
42
+ img_path='path/to/input.jpg',
43
+ output_dir='path/to/output'
44
+ )
45
+
46
+ # Access the result
47
+ output_path = result['output_path']
48
+ quality_score = result['score']
49
+ ```
50
+
51
+ ### Running the Test API Server
52
+
53
+ ```python
54
+ from agent_tools import start_server
55
+
56
+ # Start the server with default models
57
+ start_server(host='0.0.0.0', port=5010)
58
+ ```
59
+
60
+ Or use the API with curl:
61
+
62
+ ```bash
63
+ curl -X POST http://localhost:5010/process_image \
64
+ -H "Content-Type: application/json" \
65
+ -d '{"img_path": "path/to/image.jpg", "models": ["scunet", "real_esrgan"]}'
66
+ ```
67
+
68
+ ## Cite
69
+
70
+ ```
71
+ # Real-esrgan
72
+ @inproceedings{wang2021real,
73
+ title={Real-esrgan: Training real-world blind super-resolution with pure synthetic data},
74
+ author={Wang, Xintao and Xie, Liangbin and Dong, Chao and Shan, Ying},
75
+ booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
76
+ pages={1905--1914},
77
+ year={2021}
78
+ }
79
+ # img2img-turbo
80
+ @article{parmar2024one,
81
+ title={One-step image translation with text-to-image models},
82
+ author={Parmar, Gaurav and Park, Taesung and Narasimhan, Srinivasa and Zhu, Jun-Yan},
83
+ journal={arXiv preprint arXiv:2403.12036},
84
+ year={2024}
85
+ }
86
+ # Lightendiffusion
87
+ @article{jiang2024lightendiffusion,
88
+ title={Lightendiffusion: Unsupervised low-light image enhancement with latent-retinex diffusion models},
89
+ author={Jiang, Hai and Luo, Ao and Liu, Xiaohong and Han, Songchen and Liu, Shuaicheng},
90
+ journal={arXiv preprint arXiv:2407.08939},
91
+ year={2024}
92
+ }
93
+ # KANet
94
+ @article{feng2024advancing,
95
+ title={Advancing real-world image dehazing: perspective, modules, and training},
96
+ author={Feng, Yuxin and Ma, Long and Meng, Xiaozhe and Zhou, Fan and Liu, Risheng and Su, Zhuo},
97
+ journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
98
+ year={2024},
99
+ publisher={IEEE}
100
+ }
101
+ # IDT
102
+ @article{xiao2022image,
103
+ title={Image de-raining transformer},
104
+ author={Xiao, Jie and Fu, Xueyang and Liu, Aiping and Wu, Feng and Zha, Zheng-Jun},
105
+ journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
106
+ volume={45},
107
+ number={11},
108
+ pages={12978--12995},
109
+ year={2022},
110
+ publisher={IEEE}
111
+ }
112
+ # SCUNet
113
+ @article{zhang2023practical,
114
+ author = {Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Fan, Deng-Ping and Timofte, Radu and Gool, Luc Van},
115
+ title = {Practical Blind Image Denoising via Swin-Conv-UNet and Data Synthesis},
116
+ journal = {Machine Intelligence Research},
117
+ DOI = {10.1007/s11633-023-1466-0},
118
+ url = {https://doi.org/10.1007/s11633-023-1466-0},
119
+ volume={20},
120
+ number={6},
121
+ pages={822--836},
122
+ year={2023},
123
+ publisher={Springer}
124
+ }
125
+ # S2Former
126
+ @inproceedings{chen2023sparse,
127
+ title={Sparse sampling transformer with uncertainty-driven ranking for unified removal of raindrops and rain streaks},
128
+ author={Chen, Sixiang and Ye, Tian and Bai, Jinbin and Chen, Erkang and Shi, Jun and Zhu, Lei},
129
+ booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
130
+ pages={13106--13117},
131
+ year={2023}
132
+ }
133
+ # RIDCP
134
+ @inproceedings{wu2023ridcp,
135
+ title={Ridcp: Revitalizing real image dehazing via high-quality codebook priors},
136
+ author={Wu, Rui-Qi and Duan, Zheng-Peng and Guo, Chun-Le and Chai, Zhi and Li, Chongyi},
137
+ booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
138
+ pages={22282--22291},
139
+ year={2023}
140
+ }
141
+ # HVI-CIDNet
142
+ @article{yan2024you,
143
+ title={You only need one color space: An efficient network for low-light image enhancement},
144
+ author={Yan, Qingsen and Feng, Yixu and Zhang, Cheng and Wang, Pei and Wu, Peng and Dong, Wei and Sun, Jinqiu and Zhang, Yanning},
145
+ journal={arXiv preprint arXiv:2402.05809},
146
+ year={2024}
147
+ }
148
+
149
+ # Retinexformer
150
+ @inproceedings{retinexformer,
151
+ title={Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement},
152
+ author={Yuanhao Cai and Hao Bian and Jing Lin and Haoqian Wang and Radu Timofte and Yulun Zhang},
153
+ booktitle={ICCV},
154
+ year={2023}
155
+ }
JarvisIR/package/agent_tools.egg-info/PKG-INFO ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: agent_tools
3
+ Version: 1.0.0
4
+ Summary: Agent Tools
5
+ Classifier: Programming Language :: Python :: 3
6
+ Classifier: License :: OSI Approved :: Apache Software License
7
+ Requires-Python: >=3.8
JarvisIR/package/agent_tools.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ agent_tools/__init__.py
4
+ agent_tools/iqa_reward.py
5
+ agent_tools/restoration_toolkit.py
6
+ agent_tools/tool_testing_api.py
7
+ agent_tools.egg-info/PKG-INFO
8
+ agent_tools.egg-info/SOURCES.txt
9
+ agent_tools.egg-info/dependency_links.txt
10
+ agent_tools.egg-info/top_level.txt
11
+ agent_tools/ESRGAN/__init__.py
12
+ agent_tools/ESRGAN/inference.py
13
+ agent_tools/ESRGAN/realesrgan/__init__.py
14
+ agent_tools/ESRGAN/realesrgan/train.py
15
+ agent_tools/ESRGAN/realesrgan/utils.py
16
+ agent_tools/ESRGAN/realesrgan/archs/__init__.py
17
+ agent_tools/ESRGAN/realesrgan/archs/discriminator_arch.py
18
+ agent_tools/ESRGAN/realesrgan/archs/srvgg_arch.py
19
+ agent_tools/ESRGAN/realesrgan/data/__init__.py
20
+ agent_tools/ESRGAN/realesrgan/data/realesrgan_dataset.py
21
+ agent_tools/ESRGAN/realesrgan/data/realesrgan_paired_dataset.py
22
+ agent_tools/ESRGAN/realesrgan/models/__init__.py
23
+ agent_tools/ESRGAN/realesrgan/models/realesrgan_model.py
24
+ agent_tools/ESRGAN/realesrgan/models/realesrnet_model.py
25
+ agent_tools/HVICIDNet/inference.py
26
+ agent_tools/HVICIDNet/mods.py
27
+ agent_tools/HVICIDNet/wavelet.py
28
+ agent_tools/HVICIDNet/loss/loss_utils.py
29
+ agent_tools/HVICIDNet/loss/losses.py
30
+ agent_tools/HVICIDNet/loss/niqe_utils.py
31
+ agent_tools/HVICIDNet/loss/vgg_arch.py
32
+ agent_tools/HVICIDNet/net/CIDNet.py
33
+ agent_tools/HVICIDNet/net/HVI_transform.py
34
+ agent_tools/HVICIDNet/net/LCA.py
35
+ agent_tools/HVICIDNet/net/transformer_utils.py
36
+ agent_tools/IDT/__init__.py
37
+ agent_tools/IDT/inference.py
38
+ agent_tools/IDT/analyse/cal_rf_bf.py
39
+ agent_tools/IDT/models/ICRA.py
40
+ agent_tools/IDT/models/IDT.py
41
+ agent_tools/IDT/models/Uformer.py
42
+ agent_tools/IDT/models/__init__.py
43
+ agent_tools/IDT/models/atgan.py
44
+ agent_tools/IDT/models/ddm.py
45
+ agent_tools/IDT/models/onego_genotypes_searched.py
46
+ agent_tools/IDT/models/onego_ops_derain.py
47
+ agent_tools/IDT/models/onego_se_nets.py
48
+ agent_tools/IDT/models/onego_train_model.py
49
+ agent_tools/IDT/models/restoration.py
50
+ agent_tools/IDT/models/restormer.py
51
+ agent_tools/IDT/models/transformer2d.py
52
+ agent_tools/IDT/models/unet.py
53
+ agent_tools/IDT/utils/__init__.py
54
+ agent_tools/IDT/utils/logging.py
55
+ agent_tools/IDT/utils/optimize.py
56
+ agent_tools/IDT/utils/sampling.py
57
+ agent_tools/KANet/LD_model1.py
58
+ agent_tools/KANet/base_networks.py
59
+ agent_tools/KANet/deconv.py
60
+ agent_tools/KANet/inference.py
61
+ agent_tools/KANet/transweather_model.py
62
+ agent_tools/LightenDiffusion/inference.py
63
+ agent_tools/LightenDiffusion/models/__init__.py
64
+ agent_tools/LightenDiffusion/models/ddm.py
65
+ agent_tools/LightenDiffusion/models/decom.py
66
+ agent_tools/LightenDiffusion/models/restoration.py
67
+ agent_tools/LightenDiffusion/models/unet.py
68
+ agent_tools/LightenDiffusion/utils/__init__.py
69
+ agent_tools/LightenDiffusion/utils/logging.py
70
+ agent_tools/LightenDiffusion/utils/optimize.py
71
+ agent_tools/LightenDiffusion/utils/sampling.py
72
+ agent_tools/RIDCP/__init__.py
73
+ agent_tools/RIDCP/inference.py
74
+ agent_tools/RIDCP/setup.py
75
+ agent_tools/RIDCP/basicsr_ridcp/__init__.py
76
+ agent_tools/RIDCP/basicsr_ridcp/active_codebook.py
77
+ agent_tools/RIDCP/basicsr_ridcp/test.py
78
+ agent_tools/RIDCP/basicsr_ridcp/train.py
79
+ agent_tools/RIDCP/basicsr_ridcp/archs/__init__.py
80
+ agent_tools/RIDCP/basicsr_ridcp/archs/arch_util.py
81
+ agent_tools/RIDCP/basicsr_ridcp/archs/dehaze_vq_weight_arch.py
82
+ agent_tools/RIDCP/basicsr_ridcp/archs/discriminator_arch.py
83
+ agent_tools/RIDCP/basicsr_ridcp/archs/network_swinir.py
84
+ agent_tools/RIDCP/basicsr_ridcp/archs/ridcp_utils.py
85
+ agent_tools/RIDCP/basicsr_ridcp/archs/vgg_arch.py
86
+ agent_tools/RIDCP/basicsr_ridcp/data/__init__.py
87
+ agent_tools/RIDCP/basicsr_ridcp/data/data_sampler.py
88
+ agent_tools/RIDCP/basicsr_ridcp/data/data_util.py
89
+ agent_tools/RIDCP/basicsr_ridcp/data/haze_online_dataset.py
90
+ agent_tools/RIDCP/basicsr_ridcp/data/prefetch_dataloader.py
91
+ agent_tools/RIDCP/basicsr_ridcp/data/transforms.py
92
+ agent_tools/RIDCP/basicsr_ridcp/losses/__init__.py
93
+ agent_tools/RIDCP/basicsr_ridcp/losses/loss_util.py
94
+ agent_tools/RIDCP/basicsr_ridcp/losses/losses.py
95
+ agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_fid_folder.py
96
+ agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_fid_stats_from_datasets.py
97
+ agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_lpips.py
98
+ agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_niqe.py
99
+ agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_psnr_ssim.py
100
+ agent_tools/RIDCP/basicsr_ridcp/metrics/calculate_stylegan2_fid.py
101
+ agent_tools/RIDCP/basicsr_ridcp/models/__init__.py
102
+ agent_tools/RIDCP/basicsr_ridcp/models/base_model.py
103
+ agent_tools/RIDCP/basicsr_ridcp/models/dehaze_vq_model.py
104
+ agent_tools/RIDCP/basicsr_ridcp/models/lr_scheduler.py
105
+ agent_tools/RIDCP/basicsr_ridcp/ops/__init__.py
106
+ agent_tools/RIDCP/basicsr_ridcp/ops/dcn/__init__.py
107
+ agent_tools/RIDCP/basicsr_ridcp/ops/dcn/deform_conv.py
108
+ agent_tools/RIDCP/basicsr_ridcp/utils/__init__.py
109
+ agent_tools/RIDCP/basicsr_ridcp/utils/diffjpeg.py
110
+ agent_tools/RIDCP/basicsr_ridcp/utils/dist_util.py
111
+ agent_tools/RIDCP/basicsr_ridcp/utils/download_util.py
112
+ agent_tools/RIDCP/basicsr_ridcp/utils/face_util.py
113
+ agent_tools/RIDCP/basicsr_ridcp/utils/file_client.py
114
+ agent_tools/RIDCP/basicsr_ridcp/utils/flow_util.py
115
+ agent_tools/RIDCP/basicsr_ridcp/utils/img_process_util.py
116
+ agent_tools/RIDCP/basicsr_ridcp/utils/img_util.py
117
+ agent_tools/RIDCP/basicsr_ridcp/utils/lmdb_util.py
118
+ agent_tools/RIDCP/basicsr_ridcp/utils/logger.py
119
+ agent_tools/RIDCP/basicsr_ridcp/utils/matlab_functions.py
120
+ agent_tools/RIDCP/basicsr_ridcp/utils/misc.py
121
+ agent_tools/RIDCP/basicsr_ridcp/utils/options.py
122
+ agent_tools/RIDCP/basicsr_ridcp/utils/registry.py
123
+ agent_tools/RIDCP/utils/utils_image.py
124
+ agent_tools/Retinexformer/__init__.py
125
+ agent_tools/Retinexformer/inference.py
126
+ agent_tools/Retinexformer/setup.py
127
+ agent_tools/Retinexformer/Enhancement/utils.py
128
+ agent_tools/Retinexformer/basicsr_retinexformer/test.py
129
+ agent_tools/Retinexformer/basicsr_retinexformer/train.py
130
+ agent_tools/Retinexformer/basicsr_retinexformer/version.py
131
+ agent_tools/Retinexformer/basicsr_retinexformer/data/SDSD_image_dataset.py
132
+ agent_tools/Retinexformer/basicsr_retinexformer/data/SID_image_dataset.py
133
+ agent_tools/Retinexformer/basicsr_retinexformer/data/SMID_image_dataset.py
134
+ agent_tools/Retinexformer/basicsr_retinexformer/data/__init__.py
135
+ agent_tools/Retinexformer/basicsr_retinexformer/data/data_sampler.py
136
+ agent_tools/Retinexformer/basicsr_retinexformer/data/data_util.py
137
+ agent_tools/Retinexformer/basicsr_retinexformer/data/ffhq_dataset.py
138
+ agent_tools/Retinexformer/basicsr_retinexformer/data/paired_image_dataset.py
139
+ agent_tools/Retinexformer/basicsr_retinexformer/data/prefetch_dataloader.py
140
+ agent_tools/Retinexformer/basicsr_retinexformer/data/reds_dataset.py
141
+ agent_tools/Retinexformer/basicsr_retinexformer/data/single_image_dataset.py
142
+ agent_tools/Retinexformer/basicsr_retinexformer/data/transforms.py
143
+ agent_tools/Retinexformer/basicsr_retinexformer/data/util.py
144
+ agent_tools/Retinexformer/basicsr_retinexformer/data/video_test_dataset.py
145
+ agent_tools/Retinexformer/basicsr_retinexformer/data/vimeo90k_dataset.py
146
+ agent_tools/Retinexformer/basicsr_retinexformer/metrics/__init__.py
147
+ agent_tools/Retinexformer/basicsr_retinexformer/metrics/fid.py
148
+ agent_tools/Retinexformer/basicsr_retinexformer/metrics/metric_util.py
149
+ agent_tools/Retinexformer/basicsr_retinexformer/metrics/niqe.py
150
+ agent_tools/Retinexformer/basicsr_retinexformer/metrics/psnr_ssim.py
151
+ agent_tools/Retinexformer/basicsr_retinexformer/models/__init__.py
152
+ agent_tools/Retinexformer/basicsr_retinexformer/models/base_model.py
153
+ agent_tools/Retinexformer/basicsr_retinexformer/models/image_restoration_model.py
154
+ agent_tools/Retinexformer/basicsr_retinexformer/models/lr_scheduler.py
155
+ agent_tools/Retinexformer/basicsr_retinexformer/models/archs/MST_Plus_Plus_arch.py
156
+ agent_tools/Retinexformer/basicsr_retinexformer/models/archs/RetinexFormer_arch.py
157
+ agent_tools/Retinexformer/basicsr_retinexformer/models/archs/__init__.py
158
+ agent_tools/Retinexformer/basicsr_retinexformer/models/archs/arch_util.py
159
+ agent_tools/Retinexformer/basicsr_retinexformer/models/archs/layers.py
160
+ agent_tools/Retinexformer/basicsr_retinexformer/models/losses/__init__.py
161
+ agent_tools/Retinexformer/basicsr_retinexformer/models/losses/loss_util.py
162
+ agent_tools/Retinexformer/basicsr_retinexformer/models/losses/losses.py
163
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/__init__.py
164
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/bundle_submissions.py
165
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/create_lmdb.py
166
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/dist_util.py
167
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/download_util.py
168
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/face_util.py
169
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/file_client.py
170
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/flow_util.py
171
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/img_util.py
172
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/lmdb_util.py
173
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/logger.py
174
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/matlab_functions.py
175
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/misc.py
176
+ agent_tools/Retinexformer/basicsr_retinexformer/utils/options.py
177
+ agent_tools/S2Former/UDR_S2Former.py
178
+ agent_tools/S2Former/base_net_snow.py
179
+ agent_tools/S2Former/condconv.py
180
+ agent_tools/S2Former/inference.py
181
+ agent_tools/SCUNet/__init__.py
182
+ agent_tools/SCUNet/inference.py
183
+ agent_tools/SCUNet/models/network_scunet.py
184
+ agent_tools/SCUNet/utils/utils_image.py
185
+ agent_tools/SnowMaster/inference.py
186
+ agent_tools/SnowMaster/nafnet.py
187
+ agent_tools/SnowMaster/nafnet_utils.py
188
+ agent_tools/img2img_turbo/__init__.py
189
+ agent_tools/img2img_turbo/inference.py
190
+ agent_tools/img2img_turbo/src/cyclegan_turbo.py
191
+ agent_tools/img2img_turbo/src/model.py
192
+ agent_tools/img2img_turbo/src/my_utils/training_utils.py
JarvisIR/package/agent_tools.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
JarvisIR/package/agent_tools.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ agent_tools
JarvisIR/package/agent_tools/ESRGAN/__init__.py ADDED
File without changes
JarvisIR/package/agent_tools/ESRGAN/inference.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ from basicsr.archs.rrdbnet_arch import RRDBNet
4
+
5
+ from .realesrgan import RealESRGANer
6
+
7
+
8
+ def load_esrgan_model(model_path, device):
9
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
10
+ netscale = 4
11
+
12
+
13
+ # use dni to control the denoise strength
14
+ dni_weight = None
15
+
16
+
17
+ # restorer
18
+ upsampler = RealESRGANer(
19
+ scale=netscale,
20
+ model_path=model_path,
21
+ dni_weight=dni_weight,
22
+ model=model,
23
+ tile=0,
24
+ tile_pad=10,
25
+ pre_pad=0,
26
+ half=True,
27
+ device=device)
28
+ return upsampler
29
+
30
def esrgan_predict(upsampler, input_image, output_dir, device,):
    """
    Run Real-ESRGAN on one image and resize the result back to the input size.

    Args:
        upsampler: Initialized ``RealESRGANer`` instance.
        input_image (str): Path to the input image file.
        output_dir (str): Directory where the result PNG is written.
        device: Unused here; kept for a uniform tool-call interface.

    Returns:
        str: Path of the saved output image.

    Raises:
        FileNotFoundError: If ``input_image`` cannot be read.
    """
    # determine models according to model names
    outscale = 4  # the final upsampling scale
    imgname, extension = os.path.splitext(os.path.basename(input_image))

    img = cv2.imread(input_image, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None (no exception) for unreadable paths
        raise FileNotFoundError(f'Cannot read image: {input_image}')
    h, w, _ = img.shape
    try:
        output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        # Typically CUDA out of memory. BUGFIX: the original only printed
        # here, then crashed with a NameError on the undefined ``output``.
        # Fall back to the unmodified input so the pipeline can continue.
        print('Error', error)
        print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        output = img
    # resize back to the original resolution
    output = cv2.resize(output, (w, h), interpolation=cv2.INTER_CUBIC)

    save_path = os.path.join(output_dir, f'{imgname}.png')
    cv2.imwrite(save_path, output)

    return save_path
JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus.yml ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: finetune_RealESRGANx4plus_400k
3
+ model_type: RealESRGANModel
4
+ scale: 4
5
+ num_gpu: auto
6
+ manual_seed: 0
7
+
8
+ # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
9
+ # USM the ground-truth
10
+ l1_gt_usm: True
11
+ percep_gt_usm: True
12
+ gan_gt_usm: False
13
+
14
+ # the first degradation process
15
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
16
+ resize_range: [0.15, 1.5]
17
+ gaussian_noise_prob: 0.5
18
+ noise_range: [1, 30]
19
+ poisson_scale_range: [0.05, 3]
20
+ gray_noise_prob: 0.4
21
+ jpeg_range: [30, 95]
22
+
23
+ # the second degradation process
24
+ second_blur_prob: 0.8
25
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
26
+ resize_range2: [0.3, 1.2]
27
+ gaussian_noise_prob2: 0.5
28
+ noise_range2: [1, 25]
29
+ poisson_scale_range2: [0.05, 2.5]
30
+ gray_noise_prob2: 0.4
31
+ jpeg_range2: [30, 95]
32
+
33
+ gt_size: 256
34
+ queue_size: 180
35
+
36
+ # dataset and data loader settings
37
+ datasets:
38
+ train:
39
+ name: DF2K+OST
40
+ type: RealESRGANDataset
41
+ dataroot_gt: datasets/DF2K
42
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
43
+ io_backend:
44
+ type: disk
45
+
46
+ blur_kernel_size: 21
47
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49
+ sinc_prob: 0.1
50
+ blur_sigma: [0.2, 3]
51
+ betag_range: [0.5, 4]
52
+ betap_range: [1, 2]
53
+
54
+ blur_kernel_size2: 21
55
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57
+ sinc_prob2: 0.1
58
+ blur_sigma2: [0.2, 1.5]
59
+ betag_range2: [0.5, 4]
60
+ betap_range2: [1, 2]
61
+
62
+ final_sinc_prob: 0.8
63
+
64
+ gt_size: 256
65
+ use_hflip: True
66
+ use_rot: False
67
+
68
+ # data loader
69
+ use_shuffle: true
70
+ num_worker_per_gpu: 5
71
+ batch_size_per_gpu: 12
72
+ dataset_enlarge_ratio: 1
73
+ prefetch_mode: ~
74
+
75
+ # Uncomment these for validation
76
+ # val:
77
+ # name: validation
78
+ # type: PairedImageDataset
79
+ # dataroot_gt: path_to_gt
80
+ # dataroot_lq: path_to_lq
81
+ # io_backend:
82
+ # type: disk
83
+
84
+ # network structures
85
+ network_g:
86
+ type: RRDBNet
87
+ num_in_ch: 3
88
+ num_out_ch: 3
89
+ num_feat: 64
90
+ num_block: 23
91
+ num_grow_ch: 32
92
+
93
+ network_d:
94
+ type: UNetDiscriminatorSN
95
+ num_in_ch: 3
96
+ num_feat: 64
97
+ skip_connection: True
98
+
99
+ # path
100
+ path:
101
+ # use the pre-trained Real-ESRNet model
102
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
103
+ param_key_g: params_ema
104
+ strict_load_g: true
105
+ pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
106
+ param_key_d: params
107
+ strict_load_d: true
108
+ resume_state: ~
109
+
110
+ # training settings
111
+ train:
112
+ ema_decay: 0.999
113
+ optim_g:
114
+ type: Adam
115
+ lr: !!float 1e-4
116
+ weight_decay: 0
117
+ betas: [0.9, 0.99]
118
+ optim_d:
119
+ type: Adam
120
+ lr: !!float 1e-4
121
+ weight_decay: 0
122
+ betas: [0.9, 0.99]
123
+
124
+ scheduler:
125
+ type: MultiStepLR
126
+ milestones: [400000]
127
+ gamma: 0.5
128
+
129
+ total_iter: 400000
130
+ warmup_iter: -1 # no warm up
131
+
132
+ # losses
133
+ pixel_opt:
134
+ type: L1Loss
135
+ loss_weight: 1.0
136
+ reduction: mean
137
+ # perceptual loss (content and style losses)
138
+ perceptual_opt:
139
+ type: PerceptualLoss
140
+ layer_weights:
141
+ # before relu
142
+ 'conv1_2': 0.1
143
+ 'conv2_2': 0.1
144
+ 'conv3_4': 1
145
+ 'conv4_4': 1
146
+ 'conv5_4': 1
147
+ vgg_type: vgg19
148
+ use_input_norm: true
149
+ perceptual_weight: !!float 1.0
150
+ style_weight: 0
151
+ range_norm: false
152
+ criterion: l1
153
+ # gan loss
154
+ gan_opt:
155
+ type: GANLoss
156
+ gan_type: vanilla
157
+ real_label_val: 1.0
158
+ fake_label_val: 0.0
159
+ loss_weight: !!float 1e-1
160
+
161
+ net_d_iters: 1
162
+ net_d_init_iters: 0
163
+
164
+ # Uncomment these for validation
165
+ # validation settings
166
+ # val:
167
+ # val_freq: !!float 5e3
168
+ # save_img: True
169
+
170
+ # metrics:
171
+ # psnr: # metric name
172
+ # type: calculate_psnr
173
+ # crop_border: 4
174
+ # test_y_channel: false
175
+
176
+ # logging settings
177
+ logger:
178
+ print_freq: 100
179
+ save_checkpoint_freq: !!float 5e3
180
+ use_tb_logger: true
181
+ wandb:
182
+ project: ~
183
+ resume_id: ~
184
+
185
+ # dist training settings
186
+ dist_params:
187
+ backend: nccl
188
+ port: 29500
JarvisIR/package/agent_tools/ESRGAN/options/finetune_realesrgan_x4plus_pairdata.yml ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: finetune_RealESRGANx4plus_400k_pairdata
3
+ model_type: RealESRGANModel
4
+ scale: 4
5
+ num_gpu: auto
6
+ manual_seed: 0
7
+
8
+ # USM the ground-truth
9
+ l1_gt_usm: True
10
+ percep_gt_usm: True
11
+ gan_gt_usm: False
12
+
13
+ high_order_degradation: False # do not use the high-order degradation generation process
14
+
15
+ # dataset and data loader settings
16
+ datasets:
17
+ train:
18
+ name: DIV2K
19
+ type: RealESRGANPairedDataset
20
+ dataroot_gt: datasets/DF2K
21
+ dataroot_lq: datasets/DF2K
22
+ meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
23
+ io_backend:
24
+ type: disk
25
+
26
+ gt_size: 256
27
+ use_hflip: True
28
+ use_rot: False
29
+
30
+ # data loader
31
+ use_shuffle: true
32
+ num_worker_per_gpu: 5
33
+ batch_size_per_gpu: 12
34
+ dataset_enlarge_ratio: 1
35
+ prefetch_mode: ~
36
+
37
+ # Uncomment these for validation
38
+ # val:
39
+ # name: validation
40
+ # type: PairedImageDataset
41
+ # dataroot_gt: path_to_gt
42
+ # dataroot_lq: path_to_lq
43
+ # io_backend:
44
+ # type: disk
45
+
46
+ # network structures
47
+ network_g:
48
+ type: RRDBNet
49
+ num_in_ch: 3
50
+ num_out_ch: 3
51
+ num_feat: 64
52
+ num_block: 23
53
+ num_grow_ch: 32
54
+
55
+ network_d:
56
+ type: UNetDiscriminatorSN
57
+ num_in_ch: 3
58
+ num_feat: 64
59
+ skip_connection: True
60
+
61
+ # path
62
+ path:
63
+ # use the pre-trained Real-ESRNet model
64
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
65
+ param_key_g: params_ema
66
+ strict_load_g: true
67
+ pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
68
+ param_key_d: params
69
+ strict_load_d: true
70
+ resume_state: ~
71
+
72
+ # training settings
73
+ train:
74
+ ema_decay: 0.999
75
+ optim_g:
76
+ type: Adam
77
+ lr: !!float 1e-4
78
+ weight_decay: 0
79
+ betas: [0.9, 0.99]
80
+ optim_d:
81
+ type: Adam
82
+ lr: !!float 1e-4
83
+ weight_decay: 0
84
+ betas: [0.9, 0.99]
85
+
86
+ scheduler:
87
+ type: MultiStepLR
88
+ milestones: [400000]
89
+ gamma: 0.5
90
+
91
+ total_iter: 400000
92
+ warmup_iter: -1 # no warm up
93
+
94
+ # losses
95
+ pixel_opt:
96
+ type: L1Loss
97
+ loss_weight: 1.0
98
+ reduction: mean
99
+ # perceptual loss (content and style losses)
100
+ perceptual_opt:
101
+ type: PerceptualLoss
102
+ layer_weights:
103
+ # before relu
104
+ 'conv1_2': 0.1
105
+ 'conv2_2': 0.1
106
+ 'conv3_4': 1
107
+ 'conv4_4': 1
108
+ 'conv5_4': 1
109
+ vgg_type: vgg19
110
+ use_input_norm: true
111
+ perceptual_weight: !!float 1.0
112
+ style_weight: 0
113
+ range_norm: false
114
+ criterion: l1
115
+ # gan loss
116
+ gan_opt:
117
+ type: GANLoss
118
+ gan_type: vanilla
119
+ real_label_val: 1.0
120
+ fake_label_val: 0.0
121
+ loss_weight: !!float 1e-1
122
+
123
+ net_d_iters: 1
124
+ net_d_init_iters: 0
125
+
126
+ # Uncomment these for validation
127
+ # validation settings
128
+ # val:
129
+ # val_freq: !!float 5e3
130
+ # save_img: True
131
+
132
+ # metrics:
133
+ # psnr: # metric name
134
+ # type: calculate_psnr
135
+ # crop_border: 4
136
+ # test_y_channel: false
137
+
138
+ # logging settings
139
+ logger:
140
+ print_freq: 100
141
+ save_checkpoint_freq: !!float 5e3
142
+ use_tb_logger: true
143
+ wandb:
144
+ project: ~
145
+ resume_id: ~
146
+
147
+ # dist training settings
148
+ dist_params:
149
+ backend: nccl
150
+ port: 29500
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x2plus.yml ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_RealESRGANx2plus_400k_B12G4
3
+ model_type: RealESRGANModel
4
+ scale: 2
5
+ num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6
+ manual_seed: 0
7
+
8
+ # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
9
+ # USM the ground-truth
10
+ l1_gt_usm: True
11
+ percep_gt_usm: True
12
+ gan_gt_usm: False
13
+
14
+ # the first degradation process
15
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
16
+ resize_range: [0.15, 1.5]
17
+ gaussian_noise_prob: 0.5
18
+ noise_range: [1, 30]
19
+ poisson_scale_range: [0.05, 3]
20
+ gray_noise_prob: 0.4
21
+ jpeg_range: [30, 95]
22
+
23
+ # the second degradation process
24
+ second_blur_prob: 0.8
25
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
26
+ resize_range2: [0.3, 1.2]
27
+ gaussian_noise_prob2: 0.5
28
+ noise_range2: [1, 25]
29
+ poisson_scale_range2: [0.05, 2.5]
30
+ gray_noise_prob2: 0.4
31
+ jpeg_range2: [30, 95]
32
+
33
+ gt_size: 256
34
+ queue_size: 180
35
+
36
+ # dataset and data loader settings
37
+ datasets:
38
+ train:
39
+ name: DF2K+OST
40
+ type: RealESRGANDataset
41
+ dataroot_gt: datasets/DF2K
42
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
43
+ io_backend:
44
+ type: disk
45
+
46
+ blur_kernel_size: 21
47
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49
+ sinc_prob: 0.1
50
+ blur_sigma: [0.2, 3]
51
+ betag_range: [0.5, 4]
52
+ betap_range: [1, 2]
53
+
54
+ blur_kernel_size2: 21
55
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57
+ sinc_prob2: 0.1
58
+ blur_sigma2: [0.2, 1.5]
59
+ betag_range2: [0.5, 4]
60
+ betap_range2: [1, 2]
61
+
62
+ final_sinc_prob: 0.8
63
+
64
+ gt_size: 256
65
+ use_hflip: True
66
+ use_rot: False
67
+
68
+ # data loader
69
+ use_shuffle: true
70
+ num_worker_per_gpu: 5
71
+ batch_size_per_gpu: 12
72
+ dataset_enlarge_ratio: 1
73
+ prefetch_mode: ~
74
+
75
+ # Uncomment these for validation
76
+ # val:
77
+ # name: validation
78
+ # type: PairedImageDataset
79
+ # dataroot_gt: path_to_gt
80
+ # dataroot_lq: path_to_lq
81
+ # io_backend:
82
+ # type: disk
83
+
84
+ # network structures
85
+ network_g:
86
+ type: RRDBNet
87
+ num_in_ch: 3
88
+ num_out_ch: 3
89
+ num_feat: 64
90
+ num_block: 23
91
+ num_grow_ch: 32
92
+ scale: 2
93
+
94
+ network_d:
95
+ type: UNetDiscriminatorSN
96
+ num_in_ch: 3
97
+ num_feat: 64
98
+ skip_connection: True
99
+
100
+ # path
101
+ path:
102
+ # use the pre-trained Real-ESRNet model
103
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
104
+ param_key_g: params_ema
105
+ strict_load_g: true
106
+ resume_state: ~
107
+
108
+ # training settings
109
+ train:
110
+ ema_decay: 0.999
111
+ optim_g:
112
+ type: Adam
113
+ lr: !!float 1e-4
114
+ weight_decay: 0
115
+ betas: [0.9, 0.99]
116
+ optim_d:
117
+ type: Adam
118
+ lr: !!float 1e-4
119
+ weight_decay: 0
120
+ betas: [0.9, 0.99]
121
+
122
+ scheduler:
123
+ type: MultiStepLR
124
+ milestones: [400000]
125
+ gamma: 0.5
126
+
127
+ total_iter: 400000
128
+ warmup_iter: -1 # no warm up
129
+
130
+ # losses
131
+ pixel_opt:
132
+ type: L1Loss
133
+ loss_weight: 1.0
134
+ reduction: mean
135
+ # perceptual loss (content and style losses)
136
+ perceptual_opt:
137
+ type: PerceptualLoss
138
+ layer_weights:
139
+ # before relu
140
+ 'conv1_2': 0.1
141
+ 'conv2_2': 0.1
142
+ 'conv3_4': 1
143
+ 'conv4_4': 1
144
+ 'conv5_4': 1
145
+ vgg_type: vgg19
146
+ use_input_norm: true
147
+ perceptual_weight: !!float 1.0
148
+ style_weight: 0
149
+ range_norm: false
150
+ criterion: l1
151
+ # gan loss
152
+ gan_opt:
153
+ type: GANLoss
154
+ gan_type: vanilla
155
+ real_label_val: 1.0
156
+ fake_label_val: 0.0
157
+ loss_weight: !!float 1e-1
158
+
159
+ net_d_iters: 1
160
+ net_d_init_iters: 0
161
+
162
+ # Uncomment these for validation
163
+ # validation settings
164
+ # val:
165
+ # val_freq: !!float 5e3
166
+ # save_img: True
167
+
168
+ # metrics:
169
+ # psnr: # metric name
170
+ # type: calculate_psnr
171
+ # crop_border: 4
172
+ # test_y_channel: false
173
+
174
+ # logging settings
175
+ logger:
176
+ print_freq: 100
177
+ save_checkpoint_freq: !!float 5e3
178
+ use_tb_logger: true
179
+ wandb:
180
+ project: ~
181
+ resume_id: ~
182
+
183
+ # dist training settings
184
+ dist_params:
185
+ backend: nccl
186
+ port: 29500
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrgan_x4plus.yml ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_RealESRGANx4plus_400k_B12G4
3
+ model_type: RealESRGANModel
4
+ scale: 4
5
+ num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6
+ manual_seed: 0
7
+
8
+ # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
9
+ # USM the ground-truth
10
+ l1_gt_usm: True
11
+ percep_gt_usm: True
12
+ gan_gt_usm: False
13
+
14
+ # the first degradation process
15
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
16
+ resize_range: [0.15, 1.5]
17
+ gaussian_noise_prob: 0.5
18
+ noise_range: [1, 30]
19
+ poisson_scale_range: [0.05, 3]
20
+ gray_noise_prob: 0.4
21
+ jpeg_range: [30, 95]
22
+
23
+ # the second degradation process
24
+ second_blur_prob: 0.8
25
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
26
+ resize_range2: [0.3, 1.2]
27
+ gaussian_noise_prob2: 0.5
28
+ noise_range2: [1, 25]
29
+ poisson_scale_range2: [0.05, 2.5]
30
+ gray_noise_prob2: 0.4
31
+ jpeg_range2: [30, 95]
32
+
33
+ gt_size: 256
34
+ queue_size: 180
35
+
36
+ # dataset and data loader settings
37
+ datasets:
38
+ train:
39
+ name: DF2K+OST
40
+ type: RealESRGANDataset
41
+ dataroot_gt: datasets/DF2K
42
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
43
+ io_backend:
44
+ type: disk
45
+
46
+ blur_kernel_size: 21
47
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
48
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
49
+ sinc_prob: 0.1
50
+ blur_sigma: [0.2, 3]
51
+ betag_range: [0.5, 4]
52
+ betap_range: [1, 2]
53
+
54
+ blur_kernel_size2: 21
55
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
56
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
57
+ sinc_prob2: 0.1
58
+ blur_sigma2: [0.2, 1.5]
59
+ betag_range2: [0.5, 4]
60
+ betap_range2: [1, 2]
61
+
62
+ final_sinc_prob: 0.8
63
+
64
+ gt_size: 256
65
+ use_hflip: True
66
+ use_rot: False
67
+
68
+ # data loader
69
+ use_shuffle: true
70
+ num_worker_per_gpu: 5
71
+ batch_size_per_gpu: 12
72
+ dataset_enlarge_ratio: 1
73
+ prefetch_mode: ~
74
+
75
+ # Uncomment these for validation
76
+ # val:
77
+ # name: validation
78
+ # type: PairedImageDataset
79
+ # dataroot_gt: path_to_gt
80
+ # dataroot_lq: path_to_lq
81
+ # io_backend:
82
+ # type: disk
83
+
84
+ # network structures
85
+ network_g:
86
+ type: RRDBNet
87
+ num_in_ch: 3
88
+ num_out_ch: 3
89
+ num_feat: 64
90
+ num_block: 23
91
+ num_grow_ch: 32
92
+
93
+ network_d:
94
+ type: UNetDiscriminatorSN
95
+ num_in_ch: 3
96
+ num_feat: 64
97
+ skip_connection: True
98
+
99
+ # path
100
+ path:
101
+ # use the pre-trained Real-ESRNet model
102
+ pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
103
+ param_key_g: params_ema
104
+ strict_load_g: true
105
+ resume_state: ~
106
+
107
+ # training settings
108
+ train:
109
+ ema_decay: 0.999
110
+ optim_g:
111
+ type: Adam
112
+ lr: !!float 1e-4
113
+ weight_decay: 0
114
+ betas: [0.9, 0.99]
115
+ optim_d:
116
+ type: Adam
117
+ lr: !!float 1e-4
118
+ weight_decay: 0
119
+ betas: [0.9, 0.99]
120
+
121
+ scheduler:
122
+ type: MultiStepLR
123
+ milestones: [400000]
124
+ gamma: 0.5
125
+
126
+ total_iter: 400000
127
+ warmup_iter: -1 # no warm up
128
+
129
+ # losses
130
+ pixel_opt:
131
+ type: L1Loss
132
+ loss_weight: 1.0
133
+ reduction: mean
134
+ # perceptual loss (content and style losses)
135
+ perceptual_opt:
136
+ type: PerceptualLoss
137
+ layer_weights:
138
+ # before relu
139
+ 'conv1_2': 0.1
140
+ 'conv2_2': 0.1
141
+ 'conv3_4': 1
142
+ 'conv4_4': 1
143
+ 'conv5_4': 1
144
+ vgg_type: vgg19
145
+ use_input_norm: true
146
+ perceptual_weight: !!float 1.0
147
+ style_weight: 0
148
+ range_norm: false
149
+ criterion: l1
150
+ # gan loss
151
+ gan_opt:
152
+ type: GANLoss
153
+ gan_type: vanilla
154
+ real_label_val: 1.0
155
+ fake_label_val: 0.0
156
+ loss_weight: !!float 1e-1
157
+
158
+ net_d_iters: 1
159
+ net_d_init_iters: 0
160
+
161
+ # Uncomment these for validation
162
+ # validation settings
163
+ # val:
164
+ # val_freq: !!float 5e3
165
+ # save_img: True
166
+
167
+ # metrics:
168
+ # psnr: # metric name
169
+ # type: calculate_psnr
170
+ # crop_border: 4
171
+ # test_y_channel: false
172
+
173
+ # logging settings
174
+ logger:
175
+ print_freq: 100
176
+ save_checkpoint_freq: !!float 5e3
177
+ use_tb_logger: true
178
+ wandb:
179
+ project: ~
180
+ resume_id: ~
181
+
182
+ # dist training settings
183
+ dist_params:
184
+ backend: nccl
185
+ port: 29500
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x2plus.yml ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_RealESRNetx2plus_1000k_B12G4
3
+ model_type: RealESRNetModel
4
+ scale: 2
5
+ num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6
+ manual_seed: 0
7
+
8
+ # ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
9
+ gt_usm: True # USM the ground-truth
10
+
11
+ # the first degradation process
12
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
13
+ resize_range: [0.15, 1.5]
14
+ gaussian_noise_prob: 0.5
15
+ noise_range: [1, 30]
16
+ poisson_scale_range: [0.05, 3]
17
+ gray_noise_prob: 0.4
18
+ jpeg_range: [30, 95]
19
+
20
+ # the second degradation process
21
+ second_blur_prob: 0.8
22
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
23
+ resize_range2: [0.3, 1.2]
24
+ gaussian_noise_prob2: 0.5
25
+ noise_range2: [1, 25]
26
+ poisson_scale_range2: [0.05, 2.5]
27
+ gray_noise_prob2: 0.4
28
+ jpeg_range2: [30, 95]
29
+
30
+ gt_size: 256
31
+ queue_size: 180
32
+
33
+ # dataset and data loader settings
34
+ datasets:
35
+ train:
36
+ name: DF2K+OST
37
+ type: RealESRGANDataset
38
+ dataroot_gt: datasets/DF2K
39
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
40
+ io_backend:
41
+ type: disk
42
+
43
+ blur_kernel_size: 21
44
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
45
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
46
+ sinc_prob: 0.1
47
+ blur_sigma: [0.2, 3]
48
+ betag_range: [0.5, 4]
49
+ betap_range: [1, 2]
50
+
51
+ blur_kernel_size2: 21
52
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
53
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
54
+ sinc_prob2: 0.1
55
+ blur_sigma2: [0.2, 1.5]
56
+ betag_range2: [0.5, 4]
57
+ betap_range2: [1, 2]
58
+
59
+ final_sinc_prob: 0.8
60
+
61
+ gt_size: 256
62
+ use_hflip: True
63
+ use_rot: False
64
+
65
+ # data loader
66
+ use_shuffle: true
67
+ num_worker_per_gpu: 5
68
+ batch_size_per_gpu: 12
69
+ dataset_enlarge_ratio: 1
70
+ prefetch_mode: ~
71
+
72
+ # Uncomment these for validation
73
+ # val:
74
+ # name: validation
75
+ # type: PairedImageDataset
76
+ # dataroot_gt: path_to_gt
77
+ # dataroot_lq: path_to_lq
78
+ # io_backend:
79
+ # type: disk
80
+
81
+ # network structures
82
+ network_g:
83
+ type: RRDBNet
84
+ num_in_ch: 3
85
+ num_out_ch: 3
86
+ num_feat: 64
87
+ num_block: 23
88
+ num_grow_ch: 32
89
+ scale: 2
90
+
91
+ # path
92
+ path:
93
+ pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
94
+ param_key_g: params_ema
95
+ strict_load_g: False
96
+ resume_state: ~
97
+
98
+ # training settings
99
+ train:
100
+ ema_decay: 0.999
101
+ optim_g:
102
+ type: Adam
103
+ lr: !!float 2e-4
104
+ weight_decay: 0
105
+ betas: [0.9, 0.99]
106
+
107
+ scheduler:
108
+ type: MultiStepLR
109
+ milestones: [1000000]
110
+ gamma: 0.5
111
+
112
+ total_iter: 1000000
113
+ warmup_iter: -1 # no warm up
114
+
115
+ # losses
116
+ pixel_opt:
117
+ type: L1Loss
118
+ loss_weight: 1.0
119
+ reduction: mean
120
+
121
+ # Uncomment these for validation
122
+ # validation settings
123
+ # val:
124
+ # val_freq: !!float 5e3
125
+ # save_img: True
126
+
127
+ # metrics:
128
+ # psnr: # metric name
129
+ # type: calculate_psnr
130
+ # crop_border: 4
131
+ # test_y_channel: false
132
+
133
+ # logging settings
134
+ logger:
135
+ print_freq: 100
136
+ save_checkpoint_freq: !!float 5e3
137
+ use_tb_logger: true
138
+ wandb:
139
+ project: ~
140
+ resume_id: ~
141
+
142
+ # dist training settings
143
+ dist_params:
144
+ backend: nccl
145
+ port: 29500
JarvisIR/package/agent_tools/ESRGAN/options/train_realesrnet_x4plus.yml ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: train_RealESRNetx4plus_1000k_B12G4
3
+ model_type: RealESRNetModel
4
+ scale: 4
5
+ num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
6
+ manual_seed: 0
7
+
8
+ # ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
9
+ gt_usm: True # USM the ground-truth
10
+
11
+ # the first degradation process
12
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
13
+ resize_range: [0.15, 1.5]
14
+ gaussian_noise_prob: 0.5
15
+ noise_range: [1, 30]
16
+ poisson_scale_range: [0.05, 3]
17
+ gray_noise_prob: 0.4
18
+ jpeg_range: [30, 95]
19
+
20
+ # the second degradation process
21
+ second_blur_prob: 0.8
22
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
23
+ resize_range2: [0.3, 1.2]
24
+ gaussian_noise_prob2: 0.5
25
+ noise_range2: [1, 25]
26
+ poisson_scale_range2: [0.05, 2.5]
27
+ gray_noise_prob2: 0.4
28
+ jpeg_range2: [30, 95]
29
+
30
+ gt_size: 256
31
+ queue_size: 180
32
+
33
+ # dataset and data loader settings
34
+ datasets:
35
+ train:
36
+ name: DF2K+OST
37
+ type: RealESRGANDataset
38
+ dataroot_gt: datasets/DF2K
39
+ meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
40
+ io_backend:
41
+ type: disk
42
+
43
+ blur_kernel_size: 21
44
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
45
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
46
+ sinc_prob: 0.1
47
+ blur_sigma: [0.2, 3]
48
+ betag_range: [0.5, 4]
49
+ betap_range: [1, 2]
50
+
51
+ blur_kernel_size2: 21
52
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
53
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
54
+ sinc_prob2: 0.1
55
+ blur_sigma2: [0.2, 1.5]
56
+ betag_range2: [0.5, 4]
57
+ betap_range2: [1, 2]
58
+
59
+ final_sinc_prob: 0.8
60
+
61
+ gt_size: 256
62
+ use_hflip: True
63
+ use_rot: False
64
+
65
+ # data loader
66
+ use_shuffle: true
67
+ num_worker_per_gpu: 5
68
+ batch_size_per_gpu: 12
69
+ dataset_enlarge_ratio: 1
70
+ prefetch_mode: ~
71
+
72
+ # Uncomment these for validation
73
+ # val:
74
+ # name: validation
75
+ # type: PairedImageDataset
76
+ # dataroot_gt: path_to_gt
77
+ # dataroot_lq: path_to_lq
78
+ # io_backend:
79
+ # type: disk
80
+
81
+ # network structures
82
+ network_g:
83
+ type: RRDBNet
84
+ num_in_ch: 3
85
+ num_out_ch: 3
86
+ num_feat: 64
87
+ num_block: 23
88
+ num_grow_ch: 32
89
+
90
+ # path
91
+ path:
92
+ pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
93
+ param_key_g: params_ema
94
+ strict_load_g: true
95
+ resume_state: ~
96
+
97
+ # training settings
98
+ train:
99
+ ema_decay: 0.999
100
+ optim_g:
101
+ type: Adam
102
+ lr: !!float 2e-4
103
+ weight_decay: 0
104
+ betas: [0.9, 0.99]
105
+
106
+ scheduler:
107
+ type: MultiStepLR
108
+ milestones: [1000000]
109
+ gamma: 0.5
110
+
111
+ total_iter: 1000000
112
+ warmup_iter: -1 # no warm up
113
+
114
+ # losses
115
+ pixel_opt:
116
+ type: L1Loss
117
+ loss_weight: 1.0
118
+ reduction: mean
119
+
120
+ # Uncomment these for validation
121
+ # validation settings
122
+ # val:
123
+ # val_freq: !!float 5e3
124
+ # save_img: True
125
+
126
+ # metrics:
127
+ # psnr: # metric name
128
+ # type: calculate_psnr
129
+ # crop_border: 4
130
+ # test_y_channel: false
131
+
132
+ # logging settings
133
+ logger:
134
+ print_freq: 100
135
+ save_checkpoint_freq: !!float 5e3
136
+ use_tb_logger: true
137
+ wandb:
138
+ project: ~
139
+ resume_id: ~
140
+
141
+ # dist training settings
142
+ dist_params:
143
+ backend: nccl
144
+ port: 29500
JarvisIR/package/agent_tools/ESRGAN/realesrgan/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ from .archs import *
3
+ from .data import *
4
+ from .models import *
5
+ from .utils import *
JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from basicsr.utils import scandir
3
+ from os import path as osp
4
+
5
+ # automatically scan and import arch modules for registry
6
+ # scan all the files that end with '_arch.py' under the archs folder
7
+ arch_folder = osp.dirname(osp.abspath(__file__))
8
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
9
+ # import all the arch modules
10
+ from . import srvgg_arch
11
+ from . import discriminator_arch
JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/discriminator_arch.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from basicsr.utils.registry import ARCH_REGISTRY
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.utils import spectral_norm
5
+
6
+
7
@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
    """U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution
    with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        sn = spectral_norm
        # the first convolution (no spectral norm on the stem)
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsampling path: each conv halves the spatial size
        self.conv1 = sn(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = sn(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = sn(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsampling path (bilinear interpolation happens in forward)
        self.conv4 = sn(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = sn(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = sn(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions before the 1-channel prediction head
        self.conv7 = sn(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = sn(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        def act(t):
            return F.leaky_relu(t, negative_slope=0.2, inplace=True)

        def up(t):
            return F.interpolate(t, scale_factor=2, mode='bilinear', align_corners=False)

        # encoder
        feat0 = act(self.conv0(x))
        feat1 = act(self.conv1(feat0))
        feat2 = act(self.conv2(feat1))
        feat3 = act(self.conv3(feat2))

        # decoder, optionally adding back encoder features at each scale
        dec = act(self.conv4(up(feat3)))
        if self.skip_connection:
            dec = dec + feat2
        dec = act(self.conv5(up(dec)))
        if self.skip_connection:
            dec = dec + feat1
        dec = act(self.conv6(up(dec)))
        if self.skip_connection:
            dec = dec + feat0

        # prediction head
        out = act(self.conv7(dec))
        out = act(self.conv8(out))
        out = self.conv9(out)

        return out
JarvisIR/package/agent_tools/ESRGAN/realesrgan/archs/srvgg_arch.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from basicsr.utils.registry import ARCH_REGISTRY
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+
6
@ARCH_REGISTRY.register()
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last
    layer (PixelShuffle) and no convolution is conducted on the HR feature
    space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        def make_activation():
            # A fresh module per layer: PReLU carries learnable parameters,
            # so activation instances must not be shared.
            if act_type == 'relu':
                return nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                return nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                return nn.LeakyReLU(negative_slope=0.1, inplace=True)

        layers = nn.ModuleList()
        # the first conv + activation
        layers.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        layers.append(make_activation())
        # the body structure: num_conv x (conv + activation)
        for _ in range(num_conv):
            layers.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            layers.append(make_activation())
        # the last conv expands channels for pixel-shuffle upsampling
        layers.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        self.body = layers
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        feat = x
        for layer in self.body:
            feat = layer(feat)

        feat = self.upsampler(feat)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        return feat + base
JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Registration module for the dataset package: importing it makes the dataset
# classes register themselves with basicsr's DATASET_REGISTRY.
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
# NOTE(review): `dataset_filenames` is computed but never used here — the modules
# are imported explicitly below instead of via importlib — and `importlib` itself
# appears unused. Consider removing the dead scan or restoring dynamic imports;
# TODO confirm nothing else reads `dataset_filenames` before cleaning this up.
from . import realesrgan_paired_dataset
from . import realesrgan_dataset
JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_dataset.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import time
8
+ import torch
9
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
10
+ from basicsr.data.transforms import augment
11
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
12
+ from basicsr.utils.registry import DATASET_REGISTRY
13
+ from torch.utils import data as data
14
+
15
+
16
@DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        # FileClient is created lazily in __getitem__ so it works with dataloader workers.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                logger = get_root_logger()
                # logger.warn is deprecated; use warning
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                if retry == 1:
                    # All retries exhausted: re-raise instead of falling through with
                    # `img_bytes` unbound (which produced a confusing NameError before).
                    raise
                # change another file to read; randint is inclusive on both ends, so the
                # upper bound must be len - 1 to avoid a possible IndexError
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        # use the stored attribute for consistency with __init__ (same value as opt['sinc_prob'])
        if np.random.uniform() < self.sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel to the fixed 21x21 size expected downstream
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob2:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
JarvisIR/package/agent_tools/ESRGAN/realesrgan/data/realesrgan_paired_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
3
+ from basicsr.data.transforms import augment, paired_random_crop
4
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
5
+ from basicsr.utils.registry import DATASET_REGISTRY
6
+ from torch.utils import data as data
7
+ from torchvision.transforms.functional import normalize
8
+
9
+
10
@DATASET_REGISTRY.register()
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        # FileClient is constructed lazily on first __getitem__ call.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # Optional normalization statistics applied to the output tensors.
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        if self.io_backend_opt['type'] == 'lmdb':
            # mode 1: paired lmdb databases
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info') is not None:
            # mode 2: disk backend driven by a meta_info file; each line holds
            # "<gt relative path>, <lq relative path>"
            with open(self.opt['meta_info']) as fin:
                records = [line.strip() for line in fin]
            self.paths = []
            for record in records:
                gt_name, lq_name = record.split(', ')
                self.paths.append({
                    'gt_path': os.path.join(self.gt_folder, gt_name),
                    'lq_path': os.path.join(self.lq_folder, lq_name),
                })
        else:
            # mode 3: scan the folders directly; slow for huge folders, so a
            # meta_info file is recommended instead
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']
        record = self.paths[index]
        gt_path = record['gt_path']
        lq_path = record['lq_path']

        # Decode both images to float32 BGR arrays in [0, 1], HWC layout.
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        # Training-only augmentation: random paired crop, then flip/rotate.
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # Optional in-place normalization of both tensors.
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Registration module for the model package: importing it makes the model
# classes register themselves with basicsr's MODEL_REGISTRY.
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
# NOTE(review): `model_filenames` is computed but never used — the modules are
# imported explicitly below instead of via importlib — and `importlib` itself
# appears unused. Consider removing the dead scan or restoring dynamic imports;
# TODO confirm nothing else reads `model_filenames` before cleaning this up.
from . import realesrgan_model
from . import realesrnet_model
JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrgan_model.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5
+ from basicsr.data.transforms import paired_random_crop
6
+ from basicsr.models.srgan_model import SRGANModel
7
+ from basicsr.utils import DiffJPEG, USMSharp
8
+ from basicsr.utils.img_process_util import filter2D
9
+ from basicsr.utils.registry import MODEL_REGISTRY
10
+ from collections import OrderedDict
11
+ from torch.nn import functional as F
12
+
13
+
14
@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        """Initialize the model and its on-GPU degradation helpers.

        Args:
            opt (dict): Full training option dictionary; also consumed by the
                SRGANModel base class. Reads the optional key 'queue_size'
                (default 180) for the training pair pool.
        """
        super(RealESRGANModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        Swaps ``self.lq``/``self.gt`` with samples from the pool once it is full;
        before that it only enqueues the current batch.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            # lazily allocate the pools on first use; queue_size must be a
            # multiple of the batch size so enqueue slices always fit
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

        Args:
            data (dict): During synthetic training, must contain 'gt', 'kernel1',
                'kernel2' and 'sinc_kernel'. Otherwise must contain 'lq' and may
                contain 'gt'.

        Sets ``self.lq``, ``self.gt`` and ``self.gt_usm`` as side effects.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt_usm, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
                                                                 self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run base-class validation with the synthetic degradation disabled.

        Temporarily clears ``self.is_train`` so feed_data takes the paired branch.
        """
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True

    def optimize_parameters(self, current_iter):
        """One optimization step: update the generator, then the discriminator.

        Args:
            current_iter (int): Current training iteration; generator losses are
                only applied when past ``net_d_init_iters`` and on every
                ``net_d_iters``-th iteration.
        """
        # usm sharpening: choose sharpened or plain GT per loss, as configured
        l1_gt = self.gt_usm
        percep_gt = self.gt_usm
        gan_gt = self.gt_usm
        if self.opt['l1_gt_usm'] is False:
            l1_gt = self.gt
        if self.opt['percep_gt_usm'] is False:
            percep_gt = self.gt
        if self.opt['gan_gt_usm'] is False:
            gan_gt = self.gt

        # optimize net_g (freeze the discriminator while updating the generator)
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d (un-freeze the discriminator)
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)
JarvisIR/package/agent_tools/ESRGAN/realesrgan/models/realesrnet_model.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
5
+ from basicsr.data.transforms import paired_random_crop
6
+ from basicsr.models.sr_model import SRModel
7
+ from basicsr.utils import DiffJPEG, USMSharp
8
+ from basicsr.utils.img_process_util import filter2D
9
+ from basicsr.utils.registry import MODEL_REGISTRY
10
+ from torch.nn import functional as F
11
+
12
+
13
@MODEL_REGISTRY.register()
class RealESRNetModel(SRModel):
    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It is trained without GAN losses.
    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        """Initialize the model and its on-GPU degradation helpers.

        Args:
            opt (dict): Full training option dictionary; also consumed by the
                SRModel base class. Reads the optional key 'queue_size'
                (default 180) for the training pair pool.
        """
        super(RealESRNetModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        Swaps ``self.lq``/``self.gt`` with samples from the pool once it is full;
        before that it only enqueues the current batch.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            # lazily allocate the pools on first use; queue_size must be a
            # multiple of the batch size so enqueue slices always fit
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

        Args:
            data (dict): During synthetic training, must contain 'gt', 'kernel1',
                'kernel2' and 'sinc_kernel'. Otherwise must contain 'lq' and may
                contain 'gt'.

        Sets ``self.lq`` and ``self.gt`` (and ``self.gt_usm`` in the paired
        branch) as side effects.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM sharpen the GT images
            if self.opt['gt_usm'] is True:
                self.gt = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run base-class validation with the synthetic degradation disabled.

        Temporarily clears ``self.is_train`` so feed_data takes the paired branch.
        """
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True
JarvisIR/package/agent_tools/ESRGAN/realesrgan/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# Importing these sub-packages registers the Real-ESRGAN architectures,
# datasets and models with basicsr's registries as an import side effect;
# the names themselves are intentionally unused.
import realesrgan.archs
import realesrgan.data
import realesrgan.models

if __name__ == '__main__':
    # Experiment root is two directory levels above this file (the package root).
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
JarvisIR/package/agent_tools/ESRGAN/realesrgan/utils.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import queue
6
+ import threading
7
+ import torch
8
+ from basicsr.utils.download_util import load_file_from_url
9
+ from torch.nn import functional as F
10
+
11
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str | list[str]): Path to the pretrained model. URLs are downloaded
            automatically to ``<ROOT_DIR>/weights``. A list of two paths triggers deep
            network interpolation (DNI) between the two checkpoints.
        dni_weight (list[float]): Interpolation weights for DNI; required when
            ``model_path`` is a list. Default: None.
        model (nn.Module): The defined network. Default: None.
        tile (int): As too large images result in the out of GPU memory issue, this tile
            option first crops input images into tiles and processes each of them,
            merging the results into one image. 0 disables tiling. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device | None): Device to run inference on. Default: None.
        gpu_id (int | None): Accepted for interface compatibility; not used in this
            implementation.
    """

    def __init__(self,
                 scale,
                 model_path,
                 dni_weight=None,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half
        self.device = device

        if isinstance(model_path, list):
            # Deep network interpolation between two checkpoints.
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # A URL is first downloaded into the local "weights" folder.
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
            # NOTE(review): torch.load without weights_only=True can execute
            # arbitrary code from an untrusted checkpoint -- only load trusted files.
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # Prefer the EMA weights when present.
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation.

        Linearly blends the state dicts of two checkpoints:
        ``w = dni_weight[0] * w_a + dni_weight[1] * w_b``.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a

    def pre_process(self, img):
        """Pre-process: pre-pad and mod-pad so image dims are divisible.

        Args:
            img (np.ndarray): HWC float image; stored as ``self.img`` (1, C, H, W).
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad to reduce border artifacts
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders (only needed for scale 1 and 2 networks)
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        """Run the whole (un-tiled) image through the network."""
        self.output = self.model(self.img)

    def tile_process(self):
        """Crop the input into tiles, process each, and merge the results.

        Each tile is padded by ``tile_pad`` on every side before inference to
        suppress seam artifacts; only the unpadded core is written back.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with a black canvas
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        for y in range(tiles_y):
            for x in range(tiles_x):
                # input tile area on the full image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area including padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions (without padding)
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                    # Fix: the original fell through here, so `output_tile` was
                    # either undefined (NameError) or silently reused from the
                    # previous tile. Skip this tile (it stays black) instead.
                    continue

                # output tile area on the full output image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put the unpadded core of the tile into the output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        """Crop away the (scaled) mod-pad and pre-pad borders; return the result."""
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        """Upsample a BGR/gray/BGRA uint8 or uint16 numpy image.

        Args:
            img (np.ndarray): Input image (HxW, HxWx3 or HxWx4).
            outscale (float | None): Final rescale factor; defaults to the
                network scale when None.
            alpha_upsampler (str): 'realesrgan' to run the alpha channel
                through the network, anything else uses bilinear cv2 resize.

        Returns:
            tuple[np.ndarray, str]: The enhanced image (same dtype family as
            input) and the detected image mode ('L', 'RGB' or 'RGBA').
        """
        h_input, w_input = img.shape[0:2]
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for the alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel back
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
260
+
261
+
262
class PrefetchReader(threading.Thread):
    """Background thread that reads images ahead of the consumer.

    Args:
        img_list (list[str]): Paths of the images to read.
        num_prefetch_queue (int): Maximum number of decoded images buffered.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        # Decode each image and hand it to the consumer; put() blocks when
        # the queue is full so memory use stays bounded.
        for img_path in self.img_list:
            self.que.put(cv2.imread(img_path, cv2.IMREAD_UNCHANGED))
        # Sentinel telling __next__ to raise StopIteration.
        self.que.put(None)

    def __next__(self):
        item = self.que.get()
        if item is None:
            raise StopIteration
        return item

    def __iter__(self):
        return self
290
+
291
+
292
class IOConsumer(threading.Thread):
    """Worker thread that writes result images to disk.

    Consumes ``{'output': ndarray, 'save_path': str}`` dicts from ``que``
    until the sentinel string ``'quit'`` arrives, then announces completion.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break
            cv2.imwrite(msg['save_path'], msg['output'])
        print(f'IO worker {self.qid} is done.')
JarvisIR/package/agent_tools/HVICIDNet/inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from .net.CIDNet import CIDNet
6
+ import torchvision.transforms as transforms
7
+ import torch.nn.functional as F
8
+ import platform
9
+ import argparse
10
+
11
def load_hvicidnet_model(path, device):
    """Build a CIDNet on ``device``, load weights from ``path``, set eval mode."""
    net = CIDNet().to(device)
    state = torch.load(path, map_location=lambda storage, loc: storage)
    net.load_state_dict(state)
    net.eval()
    return net
16
+
17
def hvicidnet_predict(model, input_img, output_dir, device):
    """Enhance a low-light image with HVI-CIDNet and save it as a PNG.

    Args:
        model: Loaded CIDNet model (see ``load_hvicidnet_model``).
        input_img (str | PIL.Image.Image): Path to an image file, or an
            already-opened PIL image (saved under the name "output").
        output_dir (str): Directory for the enhanced PNG (created if missing).
        device: Torch device used for inference.

    Returns:
        str: Path of the saved PNG file.
    """
    # Fix: the original rebound `input_img` to the opened PIL image, so its
    # later `isinstance(input_img, str)` resize-back branch was dead code.
    # Keep the path in a separate variable.
    input_path = input_img if isinstance(input_img, str) else None
    if input_path is not None:
        img_name = os.path.basename(input_path).split('.')[0]
        pil_img = Image.open(input_path)
    else:
        img_name = "output"  # default name when input is not a path
        pil_img = input_img

    torch.set_grad_enabled(False)
    tensor = transforms.ToTensor()(pil_img)
    # Pad H and W up to a multiple of 8 (network downsampling requirement).
    factor = 8
    h, w = tensor.shape[1], tensor.shape[2]
    padh = (factor - h % factor) % factor
    padw = (factor - w % factor) % factor
    batch = F.pad(tensor.unsqueeze(0), (0, padw, 0, padh), 'reflect')

    gamma = 1
    with torch.no_grad():
        model.trans.alpha_s = 1.0
        model.trans.alpha = 1.0
        # Fix: use the caller-supplied device instead of a hard-coded .cuda(),
        # so CPU-only inference works.
        output = model(batch.to(device) ** gamma)

    output = torch.clamp(output.to(device), 0, 1).cpu()
    output = output[:, :, :h, :w]  # drop the padding
    enhanced_img = transforms.ToPILImage()(output.squeeze(0))
    if input_path is not None:
        # Restore the original resolution (this branch was previously
        # unreachable, see note above).
        original_img = Image.open(input_path)
        enhanced_img = enhanced_img.resize(original_img.size, Image.LANCZOS)

    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, img_name + '.png')
    enhanced_img.save(save_path)
    return save_path
54
+
55
+
56
if __name__ == '__main__':
    # Smoke test: enhance one image with the pretrained generalization weights.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_hvicidnet_model("checkpoints/HVICIDNet/generalization.pth", device)
    save_path = hvicidnet_predict(model, "./Test_Input/108.png", "./output", device)
    print(f"processed image saved to: {save_path}")
64
+
65
+
JarvisIR/package/agent_tools/HVICIDNet/loss/loss_utils.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import functools
4
+ from math import exp
5
+ from torch.autograd import Variable
6
+
7
+
8
+
9
+
10
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # F._Reduction maps the string to an enum: none=0, mean=1, sum=2.
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    if mode == 1:
        return loss.mean()
    return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights; must match ``loss`` in ndim and
            be single-channel or match its channel count. Default: None.
        reduction (str): 'none', 'mean' or 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # Unweighted, or sum reduction: the plain reduction is already correct.
    if weight is None or reduction == 'sum':
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # Weighted mean: normalise by total weight rather than element count;
        # a single-channel weight map counts once per loss channel.
        if weight.size(1) > 1:
            denom = weight.sum()
        else:
            denom = weight.sum() * loss.size(1)
        loss = loss.sum() / denom
    return loss


def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    The wrapped function must have the signature
    ``loss_func(pred, target, **kwargs)`` and return an element-wise
    (unreduced) loss; the decorator adds ``weight`` and ``reduction``
    keyword arguments.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper


@weighted_loss
def l1_loss(pred, target):
    """Element-wise L1 loss (reduction handled by the decorator)."""
    return F.l1_loss(pred, target, reduction='none')


@weighted_loss
def mse_loss(pred, target):
    """Element-wise MSE loss (reduction handled by the decorator)."""
    return F.mse_loss(pred, target, reduction='none')
108
+
109
+
110
+
111
+
112
+
113
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size``, normalised to sum to 1."""
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / torch.sum(gauss)


def create_window(window_size, channel=1):
    """Create a 2-D Gaussian window for SSIM, shaped (channel, 1, ws, ws).

    The window is the outer product of a 1-D Gaussian (sigma=1.5), repeated
    per channel for use as a depthwise convolution kernel.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # Fix: drop the deprecated torch.autograd.Variable wrapper (a no-op since
    # PyTorch 0.4); expand().contiguous() already yields a plain tensor.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
123
+
124
+
125
def map_ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches with a precomputed window.

    Args:
        img1, img2 (Tensor): (N, C, H, W) images, expected in [0, 1].
        window (Tensor): (C, 1, ws, ws) depthwise Gaussian kernel.
        window_size (int): Side length of the window.
        channel (int): Number of channels C (conv groups).
        size_average (bool): Return a scalar mean when True, else a
            per-image mean of shape (N,).
    """
    pad = window_size // 2

    def blur(t):
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1 = blur(img1)
    mu2 = blur(img2)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    # Standard SSIM stabilising constants for a dynamic range of 1.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)
JarvisIR/package/agent_tools/HVICIDNet/loss/losses.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from loss.vgg_arch import VGGFeatureExtractor, Registry
5
+ from loss.loss_utils import *
6
+
7
+
8
+ _reduction_modes = ['none', 'mean', 'sum']
9
+
10
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Multiplier applied to the reduced loss. Default: 1.0.
        reduction (str): Reduction applied to the output:
            'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted, reduced L1 loss.

        Args:
            pred (Tensor): (N, C, H, W) predictions.
            target (Tensor): (N, C, H, W) ground truth.
            weight (Tensor, optional): (N, C, H, W) element-wise weights.
                Default: None.
        """
        raw = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * raw
38
+
39
+
40
+
41
class EdgeLoss(nn.Module):
    """Edge loss: MSE between the Laplacian-pyramid residuals of two images.

    A fixed 5x5 Gaussian kernel extracts the high-frequency (edge) residual
    of each input; the loss penalises differences between those residuals.

    Args:
        loss_weight (float): Multiplier applied to the loss. Default: 1.0.
        reduction (str): Accepted for interface compatibility; the underlying
            mse_loss uses its own default reduction.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        # Fix: the kernel was unconditionally moved with .cuda(), which
        # crashes on CPU-only machines. Place it on GPU only when available
        # and move it to the input's device lazily in conv_gauss.
        if torch.cuda.is_available():
            kernel = kernel.cuda()
        self.kernel = kernel

        self.weight = loss_weight

    def conv_gauss(self, img):
        """Depthwise 5x5 Gaussian blur with replicate padding (shape-preserving)."""
        kernel = self.kernel.to(img.device)
        n_channels, _, kw, kh = kernel.shape
        img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Laplacian-pyramid residual: image minus its blurred down/up-sampled self."""
        filtered = self.conv_gauss(current)
        down = filtered[:, :, ::2, ::2]
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4  # x4 compensates energy lost in upsampling
        filtered = self.conv_gauss(new_filter)
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss = mse_loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss * self.weight
66
+
67
+
68
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Example: {'conv5_4': 1.} extracts the conv5_4 feature layer
            (before relu5_4) with weight 1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: True.
        perceptual_weight (float): If > 0, the perceptual loss is calculated
            and multiplied by this weight. Default: 1.0.
        style_weight (float): If > 0, the style (Gram-matrix) loss is
            calculated and multiplied by this weight. Default: 0.
        criterion (str): Criterion used for perceptual loss:
            'l1' | 'l2' | 'mse' | 'fro'. Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=True,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # Fix: torch.nn.L2loss does not exist (AttributeError at init);
            # the L2 criterion is mean squared error.
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'mse':
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            self.criterion = None  # Frobenius norm handled inline in forward
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: (perceptual loss or None, style loss or None).
        """
        # extract vgg features (gt is detached: no gradients through the target)
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Return the normalised Gram matrix of ``x``.

        Fix: ``forward`` called this method but it was never defined, so any
        ``style_weight > 0`` configuration crashed with AttributeError.

        Args:
            x (Tensor): Feature map with shape (n, c, h, w).

        Returns:
            Tensor: Gram matrix with shape (n, c, c), normalised by c*h*w.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
162
+
163
+
164
+
165
+
166
class SSIM(torch.nn.Module):
    """Structural dissimilarity loss: ``(1 - SSIM(img1, img2)) * weight``.

    The Gaussian window is cached on the module and rebuilt only when the
    channel count or tensor type of the input changes.

    Args:
        window_size (int): Side length of the Gaussian window. Default: 11.
        size_average (bool): Average the SSIM map to a scalar. Default: True.
        weight (float): Multiplier applied to the loss. Default: 1.0.
    """

    def __init__(self, window_size=11, size_average=True, weight=1.):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)
        self.weight = weight

    def forward(self, img1, img2):
        _, channel, _, _ = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            # Channel count or dtype changed: rebuild the window, move it to
            # the input's device/dtype, and cache it for subsequent calls.
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel

        similarity = map_ssim(img1, img2, window, self.window_size, channel, self.size_average)
        return (1. - similarity) * self.weight
191
+
192
+
193
+
JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_pris_params.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
3
+ size 11850
JarvisIR/package/agent_tools/HVICIDNet/loss/niqe_utils.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ import math
4
+ import numpy as np
5
+ from scipy.ndimage import convolve
6
+ from scipy.special import gamma
7
+ import torch
8
+
9
def cubic(x):
    """Bicubic convolution kernel (a = -0.5), as used by MATLAB's imresize.

    Piecewise cubic: one polynomial for |x| <= 1, another for 1 < |x| <= 2,
    and zero elsewhere. Returns a tensor of kernel weights.
    """
    ax = torch.abs(x)
    ax2 = ax * ax
    ax3 = ax2 * ax
    inner = (1.5 * ax3 - 2.5 * ax2 + 1) * (ax <= 1).type_as(ax)
    outer = (-0.5 * ax3 + 2.5 * ax2 - 4 * ax + 2) * ((ax > 1) * (ax <= 2)).type_as(ax)
    return inner + outer
17
+
18
+
19
+
20
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel name; unused here — the cubic kernel is applied
            unconditionally (kept for interface compatibility).
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)`` — per-output-pixel
        contribution weights and source indices, both of shape
        (out_length, p), plus the symmetric padding lengths required at the
        start and end of the input.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel (scaled when antialiasing a downsample, so the
    # kernel's support widens to cover the larger source footprint)
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Indices can run off both ends of the input; sym_len_s / sym_len_e are
    # the symmetric (mirror) padding amounts the caller must add, and the
    # indices are shifted to address the padded input.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
86
+
87
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
    """
    squeeze_flag = False
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if img.ndim == 2:
            # Grayscale: add a channel axis so the (c, h, w) path works.
            img = img[:, :, None]
            squeeze_flag = True
        # (h, w, c) -> (c, h, w) tensor, as expected by the code below.
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if img.ndim == 2:
            img = img.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Get interpolation weights and source indices for each output row/column.
    # `calculate_weights_indices` is the sibling helper defined above.
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying: pad top/bottom with mirrored rows so border output
    # pixels can read a full kernel-width window.
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    # Vertical resample: each output row is a weighted sum (mv) of
    # kernel_width consecutive input rows.
    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying (same mirroring scheme, along width).
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    # Horizontal resample on the vertically-resampled image.
    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            # Back to (h, w, c) for numpy callers.
            out_2 = out_2.transpose(1, 2, 0)

    return out_2
177
+
178
+
179
+ def _convert_input_type_range(img):
180
+ """Convert the type and range of the input image.
181
+ It converts the input image to np.float32 type and range of [0, 1].
182
+ It is mainly used for pre-processing the input image in colorspace
183
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
184
+ Args:
185
+ img (ndarray): The input image. It accepts:
186
+ 1. np.uint8 type with range [0, 255];
187
+ 2. np.float32 type with range [0, 1].
188
+ Returns:
189
+ (ndarray): The converted image with type of np.float32 and range of
190
+ [0, 1].
191
+ """
192
+ img_type = img.dtype
193
+ img = img.astype(np.float32)
194
+ if img_type == np.float32:
195
+ pass
196
+ elif img_type == np.uint8:
197
+ img /= 255.
198
+ else:
199
+ raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
200
+ return img
201
+
202
+
203
+ def _convert_output_type_range(img, dst_type):
204
+ """Convert the type and range of the image according to dst_type.
205
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
206
+ images will be converted to np.uint8 type with range [0, 255]. If
207
+ `dst_type` is np.float32, it converts the image to np.float32 type with
208
+ range [0, 1].
209
+ It is mainly used for post-processing images in colorspace conversion
210
+ functions such as rgb2ycbcr and ycbcr2rgb.
211
+ Args:
212
+ img (ndarray): The image to be converted with np.float32 type and
213
+ range [0, 255].
214
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
215
+ converts the image to np.uint8 type with range [0, 255]. If
216
+ dst_type is np.float32, it converts the image to np.float32 type
217
+ with range [0, 1].
218
+ Returns:
219
+ (ndarray): The converted image with desired type and range.
220
+ """
221
+ if dst_type not in (np.uint8, np.float32):
222
+ raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
223
+ if dst_type == np.uint8:
224
+ img = img.round()
225
+ else:
226
+ img /= 255.
227
+ return img.astype(dst_type)
228
+
229
+
230
+
231
def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    Produces the same results as Matlab's `rgb2ycbcr` function: the ITU-R
    BT.601 conversion for standard-definition television. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
    It differs from cv2.cvtColor's `RGB <-> YCrCb`, which implements the JPEG
    conversion (https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion).

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same
            type and range as input image.
    """
    input_dtype = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        # Luma only: weighted sum of R, G, B channels plus offset.
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        transform = [[65.481, -37.797, 112.0],
                     [128.553, -74.203, -93.786],
                     [24.966, 112.0, -18.214]]
        out_img = np.matmul(img, transform) + [16, 128, 128]
    return _convert_output_type_range(out_img, input_dtype)
258
+
259
+
260
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr: ITU-R BT.601 conversion for
    standard-definition television. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
    It differs from cv2.cvtColor's `BGR <-> YCrCb`, which implements the JPEG
    conversion (https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion).

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same
            type and range as input image.
    """
    input_dtype = img.dtype
    img = _convert_input_type_range(img)
    if y_only:
        # Same luma weights as rgb2ycbcr but in B, G, R order.
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        transform = [[24.966, 112.0, -18.214],
                     [128.553, -74.203, -93.786],
                     [65.481, -37.797, 112.0]]
        out_img = np.matmul(img, transform) + [16, 128, 128]
    return _convert_output_type_range(out_img, input_dtype)
287
+
288
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    Produces the same results as Matlab's ycbcr2rgb function: the ITU-R
    BT.601 conversion for standard-definition television. See
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
    It differs from cv2.cvtColor's `YCrCb <-> RGB`, which implements the JPEG
    conversion (https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion).

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
            and range as input image.
    """
    input_dtype = img.dtype
    # Work in the [0, 255] domain expected by the inverse BT.601 matrix.
    img = _convert_input_type_range(img) * 255
    inverse_transform = [[0.00456621, 0.00456621, 0.00456621],
                         [0, -0.00153632, 0.00791071],
                         [0.00625893, -0.00318811, 0]]  # noqa: E126
    out_img = np.matmul(img, inverse_transform) * 255.0 + [-222.921, 135.576, -276.836]
    return _convert_output_type_range(out_img, input_dtype)
311
+
312
+
313
def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    img = img.astype(np.float32) / 255.
    is_color = img.ndim == 3 and img.shape[2] == 3
    if is_color:
        # 3-channel (BGR) input: take MATLAB-style luma and keep a trailing
        # singleton channel axis.
        img = bgr2ycbcr(img, y_only=True)[..., None]
    return img * 255.
325
+
326
+
327
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If input_order is neither 'HWC' nor 'CHW'.
    """
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # Bug fix: previously a 2D image was expanded to (h, w, 1) and then
        # transposed when input_order == 'CHW', yielding (w, 1, h) and
        # contradicting the documented contract that input_order has no
        # effect on 2D inputs. Return (h, w, 1) directly instead.
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
348
+
349
def rgb2ycbcr_pt(img, y_only=False):
    """Convert RGB images to YCbCr images (PyTorch version).

    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): Images with shape (n, 3, h, w), the range [0, 1],
            float, RGB format.
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        (Tensor): converted images with the shape (n, 3/1, h, w), the range
            [0, 1], float.
    """
    if y_only:
        weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
    else:
        weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
        bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
        out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias

    out_img = out_img / 255.
    # Bug fix: the function previously ended with a bare `return`,
    # discarding the computed result and returning None.
    return out_img
369
+
370
def tensor2img(tensor):
    """Convert a [0, 1]-range tensor to a uint8 numpy image in [0, 255].

    Args:
        tensor (Tensor): Input tensor, expected range [0, 1].

    Returns:
        ndarray: uint8 array with values clamped to [0, 255].
    """
    arr = (tensor * 255.).data.cpu().numpy()
    # Clamp out-of-range values (equivalent to the masked assignments
    # im[im > 255] = 255; im[im < 0] = 0).
    arr = np.clip(arr, 0, 255)
    return arr.astype(np.uint8)
377
+
378
def img2tensor(img):
    """Convert a [0, 255] numpy image to a float32 tensor in [0, 1].

    Args:
        img (ndarray): (h, w) grayscale or (h, w, c) color image.

    Returns:
        Tensor: shape (1, 1, h, w) for grayscale input, (1, c, h, w) for
            color input.
    """
    arr = (img / 255.).astype('float32')
    if arr.ndim == 2:
        # Grayscale: add batch and channel axes -> (1, 1, H, W).
        arr = arr[None, None, ...]
    else:
        # Color: (H, W, C) -> (1, C, H, W).
        arr = np.transpose(arr, (2, 0, 1))[None, ...]
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return torch.from_numpy(arr)
388
+
389
def estimate_aggd_param(block):
    """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.

    Moment-matching estimation as in the NIQE paper: a lookup table of
    candidate shape parameters is searched for the value whose theoretical
    moment ratio best matches the empirical ratio of the block.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
            distribution (Estimating the params in Equation 7 in the paper).
    """
    block = block.flatten()
    # Candidate shape parameters and their theoretical ratio r(gamma).
    # NOTE(review): `gamma` is presumably scipy.special.gamma imported at the
    # top of this file (not visible here) — confirm.
    gam = np.arange(0.2, 10.001, 0.001)  # len = 9801
    gam_reciprocal = np.reciprocal(gam)
    r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))

    # Empirical left/right standard deviations of the asymmetric distribution.
    left_std = np.sqrt(np.mean(block[block < 0]**2))
    right_std = np.sqrt(np.mean(block[block > 0]**2))
    gammahat = left_std / right_std
    rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
    rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
    # Index of the tabulated shape parameter closest to the empirical ratio.
    array_position = np.argmin((r_gam - rhatnorm)**2)

    alpha = gam[array_position]
    beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
    return (alpha, beta_l, beta_r)
413
+
414
+
415
def compute_feature(block):
    """Compute NIQE features for one image block.

    Two features come from the AGGD fit of the block itself and four more
    from each of the four pairwise-product orientations, for 2 + 4*4 = 18
    features in total.

    Args:
        block (ndarray): 2D Image block.

    Returns:
        list: Features with length of 18.
    """
    feat = []
    alpha, beta_l, beta_r = estimate_aggd_param(block)
    feat.extend([alpha, (beta_l + beta_r) / 2])

    # distortions disturb the fairly regular structure of natural images.
    # This deviation can be captured by analyzing the sample distribution of
    # the products of pairs of adjacent coefficients computed along
    # horizontal, vertical and diagonal orientations.
    shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
    for i in range(len(shifts)):
        # Product of each coefficient with its shifted neighbour.
        shifted_block = np.roll(block, shifts[i], axis=(0, 1))
        alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
        # Eq. 8
        mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
        feat.extend([alpha, mean, beta_l, beta_r])
    return feat
438
+
439
+
440
def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    ``Paper: Making a "Completely Blind" Image Quality Analyzer``
    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
    Note that we do not include block overlap height and width, since they are
    always 0 in the official implementation.
    For good performance, it is advisable by the official implementation to
    divide the distorted image in to the same size patched as used for the
    construction of multivariate Gaussian model.

    Args:
        img (ndarray): Input image whose quality needs to be computed. The
            image must be a gray or Y (of YCbCr) image with shape (h, w).
            Range [0, 255] with float type.
        mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
            model calculated on the pristine dataset.
        cov_pris_param (ndarray): Covariance of a pre-defined multivariate
            Gaussian model calculated on the pristine dataset.
        gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing
            the image.
        block_size_h (int): Height of the blocks in to which image is divided.
            Default: 96 (the official recommended value).
        block_size_w (int): Width of the blocks in to which image is divided.
            Default: 96 (the official recommended value).

    Returns:
        float: NIQE quality score (lower is better).
    """
    assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
    # crop image so it contains an integer number of blocks
    h, w = img.shape
    num_block_h = math.floor(h / block_size_h)
    num_block_w = math.floor(w / block_size_w)
    img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]

    distparam = []  # dist param is actually the multiscale features
    for scale in (1, 2):  # perform on two scales (1, 2)
        # Local mean / std maps for MSCN normalization.
        # NOTE(review): `convolve` is presumably scipy.ndimage.convolve
        # imported at the top of this file — confirm.
        mu = convolve(img, gaussian_window, mode='nearest')
        sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
        # normalize, as in Eq. 1 in the paper
        img_nomalized = (img - mu) / (sigma + 1)

        feat = []
        for idx_w in range(num_block_w):
            for idx_h in range(num_block_h):
                # process each block
                block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
                                      idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
                feat.append(compute_feature(block))

        distparam.append(np.array(feat))

        if scale == 1:
            # Halve the image for the second scale, using the
            # MATLAB-compatible resize defined above (operates in [0, 1]).
            img = imresize(img / 255., scale=0.5, antialiasing=True)
            img = img * 255.

    distparam = np.concatenate(distparam, axis=1)

    # fit a MVG (multivariate Gaussian) model to distorted patch features
    mu_distparam = np.nanmean(distparam, axis=0)
    # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
    distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
    cov_distparam = np.cov(distparam_no_nan, rowvar=False)

    # compute niqe quality, Eq. 10 in the paper
    invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
    quality = np.matmul(
        np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))

    quality = np.sqrt(quality)
    quality = float(np.squeeze(quality))
    return quality
509
+
510
+
511
def calculate_niqe(img, crop_border=0, input_order='HWC', convert_to='y', **kwargs):
    """Calculate NIQE (Natural Image Quality Evaluator) metric.

    ``Paper: Making a "Completely Blind" Image Quality Analyzer``
    This implementation could produce almost the same results as the official
    MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
    > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
    > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
    We use the official params estimated from the pristine dataset.
    We use the recommended block size (96, 96) without overlaps.

    Args:
        img (ndarray): Input image whose quality needs to be computed.
            The input image must be in range [0, 255] with float/int type.
            The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
            If the input order is 'HWC' or 'CHW', it will be converted to gray
            or Y (of YCbCr) image according to the ``convert_to`` argument.
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the metric calculation.
        input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
            Default: 'HWC'.
        convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
            Default: 'y'.

    Returns:
        float: NIQE result.
    """
    # ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    # we use the official params estimated from the pristine dataset.
    # NOTE(review): this path is resolved relative to the current working
    # directory, not to this module (the ROOT_DIR line above was disabled) —
    # the call will fail unless the process is started from the directory
    # containing ./loss/. Consider restoring a __file__-based path.
    niqe_pris_params = np.load('./loss/niqe_pris_params.npz')
    mu_pris_param = niqe_pris_params['mu_pris_param']
    cov_pris_param = niqe_pris_params['cov_pris_param']
    gaussian_window = niqe_pris_params['gaussian_window']

    img = img.astype(np.float32)
    if input_order != 'HW':
        # Bring the image to HWC, then reduce to a single luminance channel.
        img = reorder_image(img, input_order=input_order)
        if convert_to == 'y':
            img = to_y_channel(img)
        elif convert_to == 'gray':
            img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
        img = np.squeeze(img)

    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border]

    # round is necessary for being consistent with MATLAB's result
    img = img.round()

    niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)

    return niqe_result
JarvisIR/package/agent_tools/HVICIDNet/loss/vgg_arch.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        # Duplicate registrations are programming errors — fail loudly.
        assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
                                             f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used either as a decorator (no argument) or as a plain
        function call. See the class docstring for usage.
        """
        if obj is not None:
            # Plain function call: register immediately.
            self._do_register(obj.__name__, obj)
            return None

        # Decorator form: return a wrapper that registers and passes the
        # decorated function/class through unchanged.
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class)
            return func_or_class

        return deco

    def get(self, name):
        """Look up a registered object, raising KeyError if absent."""
        found = self._obj_map.get(name)
        if found is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return found

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        # Iterate (name, object) pairs.
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
79
+
80
+
81
# Global registries shared by the components in this module.
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

# Local checkpoint checked first by VGGFeatureExtractor; if the file is
# missing, torchvision's pretrained download is used instead.
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
# Ordered layer names for each torchvision VGG variant; the order matches
# the sequence of modules in torchvision's `vgg_net.features`.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}
112
+
113
+
114
def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    # For every 'convX_Y' entry, append a matching 'bnX_Y' right after it.
    names_with_bn = []
    for layer_name in names:
        names_with_bn.append(layer_name)
        if 'conv' in layer_name:
            names_with_bn.append('bn' + layer_name.replace('conv', ''))
    return names_with_bn
130
+
131
+
132
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        # Layer-name table for the plain variant; '_bn' variants get bn
        # entries inserted after each conv.
        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        # Prefer the local checkpoint; fall back to torchvision's download.
        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        # Rebuild the truncated feature stack under human-readable names.
        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        # NOTE(review): the module (and the buffers below) are moved to CUDA
        # unconditionally, so this class cannot run on a CPU-only machine —
        # consider leaving device placement to the caller.
        self.vgg_net = nn.Sequential(modified_net).cuda()

        if not requires_grad:
            # Frozen feature extractor: eval mode, no gradients.
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda())
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda())

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            # [-1, 1] -> [0, 1]
            x = (x + 1) / 2
        if self.use_input_norm:
            # ImageNet normalization.
            x = (x - self.mean) / self.std
        output = {}

        # Run layer by layer, capturing the requested intermediate features.
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
JarvisIR/package/agent_tools/HVICIDNet/mods.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import warnings
4
+ import math
5
+
6
+ warnings.filterwarnings("ignore", category=UserWarning)
7
+ warnings.filterwarnings("ignore", category=FutureWarning)
8
+
9
+
10
class cross_attention(nn.Module):
    """Cross-attention between two feature maps.

    Queries come from `hidden_states`; keys and values come from `ctx`.
    Projections are depthwise-separable convolutions (Depth_conv).

    NOTE(review): although `num_heads` is validated, the head-splitting
    reshape in `transpose_for_scores` is commented out, so the attention is
    effectively computed without splitting channels into heads — confirm
    whether this is intentional.
    """

    def __init__(self, dim, num_heads, dropout=0.):
        super(cross_attention, self).__init__()
        if dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (dim, num_heads)
            )
        self.num_heads = num_heads
        self.attention_head_size = int(dim / num_heads)

        # Q/K/V projections as depthwise-separable convs.
        self.query = Depth_conv(in_ch=dim, out_ch=dim)
        self.key = Depth_conv(in_ch=dim, out_ch=dim)
        self.value = Depth_conv(in_ch=dim, out_ch=dim)

        self.dropout = nn.Dropout(dropout)

    def transpose_for_scores(self, x):
        # Only permutes dimensions; the multi-head reshape below is disabled.
        '''
        new_x_shape = x.size()[:-1] + (
            self.num_heads,
            self.attention_head_size,
        )
        print(new_x_shape)
        x = x.view(*new_x_shape)
        '''
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, ctx):
        # Q from hidden_states; K, V from the context tensor.
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(ctx)
        mixed_value_layer = self.value(ctx)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot-product attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        attention_probs = self.dropout(attention_probs)

        ctx_layer = torch.matmul(attention_probs, value_layer)
        # Undo the permute applied in transpose_for_scores.
        ctx_layer = ctx_layer.permute(0, 2, 1, 3).contiguous()

        return ctx_layer
58
+
59
+
60
class Depth_conv(nn.Module):
    """Depthwise-separable convolution: a 3x3 per-channel (depthwise) conv
    followed by a 1x1 channel-mixing (pointwise) conv."""

    def __init__(self, in_ch, out_ch):
        super(Depth_conv, self).__init__()
        # 3x3 depthwise: one filter per input channel (groups == in_ch).
        self.depth_conv = nn.Conv2d(in_ch, in_ch,
                                    kernel_size=(3, 3),
                                    stride=(1, 1),
                                    padding=1,
                                    groups=in_ch)
        # 1x1 pointwise: mixes channels, maps in_ch -> out_ch.
        self.point_conv = nn.Conv2d(in_ch, out_ch,
                                    kernel_size=(1, 1),
                                    stride=(1, 1),
                                    padding=0,
                                    groups=1)

    def forward(self, input):
        return self.point_conv(self.depth_conv(input))
84
+
85
+
86
class Dilated_Resblock(nn.Module):
    """Residual block of five 3x3 convolutions with a 1-2-3-2-1 dilation
    pyramid and LeakyReLU activations, plus an identity skip connection.

    Bug fix: the original layer stack declared the middle convolutions with
    mismatched channel counts (e.g. a conv declared as out_channels ->
    in_channels receiving an out_channels tensor), which only ran when
    in_channels == out_channels. The wiring below keeps the same dilation
    pyramid with consistent channel counts, so the block also works when
    in_channels != out_channels; for in_channels == out_channels (the only
    way it is used in this file) every layer shape is unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super(Dilated_Resblock, self).__init__()

        sequence = [
            # expand: in -> out, dilation 1
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=1, dilation=(1, 1)),
            nn.LeakyReLU(),
            # dilation 2 (padding keeps spatial size)
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=2, dilation=(2, 2)),
            nn.LeakyReLU(),
            # dilation 3
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=3, dilation=(3, 3)),
            nn.LeakyReLU(),
            # dilation 2
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=2, dilation=(2, 2)),
            nn.LeakyReLU(),
            # contract back: out -> in, dilation 1 (required for the residual add)
            nn.Conv2d(out_channels, in_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=1, dilation=(1, 1))
        ]

        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        # Identity skip: model output has in_channels, matching x.
        out = self.model(x) + x

        return out
115
+
116
+
117
class HFRM(nn.Module):
    """High-frequency refinement module.

    The input batch is split into three equal groups along the batch axis
    (x[:b//3], x[b//3:2b//3], x[2b//3:]), treated here as HL, LH and HH
    components respectively — presumably wavelet high-frequency sub-bands
    stacked along the batch dimension (TODO confirm with the caller). The HH
    group is refined with cross-attention against the other two, and the
    whole module is residual.

    NOTE(review): the batch size must be divisible by 3 for the three groups
    to have equal size — confirm callers guarantee this.
    """

    def __init__(self, in_channels, out_channels):
        super(HFRM, self).__init__()

        self.conv_head = Depth_conv(in_channels, out_channels)

        self.dilated_block_LH = Dilated_Resblock(out_channels, out_channels)
        self.dilated_block_HL = Dilated_Resblock(out_channels, out_channels)

        # Cross-attention of LH/HL queries against the HH context.
        self.cross_attention0 = cross_attention(out_channels, num_heads=8)
        self.dilated_block_HH = Dilated_Resblock(out_channels, out_channels)
        # Fuses the two attended HH variants (concatenated on channels).
        self.conv_HH = nn.Conv2d(out_channels*2, out_channels, kernel_size=3, stride=1, padding=1)
        self.cross_attention1 = cross_attention(out_channels, num_heads=8)

        self.conv_tail = Depth_conv(out_channels, in_channels)

    def forward(self, x):

        b, c, h, w = x.shape

        residual = x

        x = self.conv_head(x)

        # Split the batch into the three sub-band groups.
        x_HL, x_LH, x_HH = x[:b//3, ...], x[b//3:2*b//3, ...], x[2*b//3:, ...]

        # Attend HH against LH and HL separately.
        x_HH_LH = self.cross_attention0(x_LH, x_HH)
        x_HH_HL = self.cross_attention1(x_HL, x_HH)

        x_HL = self.dilated_block_HL(x_HL)
        x_LH = self.dilated_block_LH(x_LH)

        # Fuse both attended variants, then refine.
        x_HH = self.dilated_block_HH(self.conv_HH(torch.cat((x_HH_LH, x_HH_HL), dim=1)))

        # Re-stack the groups along the batch axis and project back.
        out = self.conv_tail(torch.cat((x_HL, x_LH, x_HH), dim=0))

        return out + residual
JarvisIR/package/agent_tools/HVICIDNet/net/CIDNet.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .HVI_transform import RGB_HVI
4
+ from .transformer_utils import *
5
+ from .LCA import *
6
+ from huggingface_hub import PyTorchModelHubMixin
7
+
8
class CIDNet(nn.Module, PyTorchModelHubMixin):
    """CIDNet: dual-branch UNet-style enhancer operating in HVI color space.

    The RGB input is mapped to HVI (``RGB_HVI``); an intensity (I) branch and
    a chroma (HV) branch are encoded/decoded in parallel, exchanging
    information at each scale via Lightweight Cross Attention (``I_LCA`` /
    ``HV_LCA``).  The enhanced HVI (2 HV planes + 1 I plane) is added to the
    input HVI as a residual and mapped back to RGB.
    """

    def __init__(self,
                 channels=[36, 36, 72, 144],
                 heads=[1, 2, 4, 8],
                 norm=False
                 ):
        super(CIDNet, self).__init__()

        # channels/heads per scale (head1 is unpacked but never used below).
        [ch1, ch2, ch3, ch4] = channels
        [head1, head2, head3, head4] = heads

        # HV_ways: chroma encoder (3 -> ch1, then 3 downsamples) ...
        self.HVE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(3, ch1, 3, stride=1, padding=0,bias=False)
        )
        self.HVE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
        self.HVE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
        self.HVE_block3 = NormDownsample(ch3, ch4, use_norm = norm)

        # ... and chroma decoder, ending in a 2-channel (H, V) output.
        self.HVD_block3 = NormUpsample(ch4, ch3, use_norm = norm)
        self.HVD_block2 = NormUpsample(ch3, ch2, use_norm = norm)
        self.HVD_block1 = NormUpsample(ch2, ch1, use_norm = norm)
        self.HVD_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 2, 3, stride=1, padding=0,bias=False)
        )

        # I_ways: intensity encoder (1 -> ch1) ...
        self.IE_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(1, ch1, 3, stride=1, padding=0,bias=False),
        )
        self.IE_block1 = NormDownsample(ch1, ch2, use_norm = norm)
        self.IE_block2 = NormDownsample(ch2, ch3, use_norm = norm)
        self.IE_block3 = NormDownsample(ch3, ch4, use_norm = norm)

        # ... and intensity decoder, ending in a 1-channel (I) output.
        self.ID_block3 = NormUpsample(ch4, ch3, use_norm=norm)
        self.ID_block2 = NormUpsample(ch3, ch2, use_norm=norm)
        self.ID_block1 = NormUpsample(ch2, ch1, use_norm=norm)
        self.ID_block0 = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(ch1, 1, 3, stride=1, padding=0,bias=False),
        )

        # Cross-attention blocks: 1-3 on the encoder path, 4-6 on the decoder.
        self.HV_LCA1 = HV_LCA(ch2, head2)
        self.HV_LCA2 = HV_LCA(ch3, head3)
        self.HV_LCA3 = HV_LCA(ch4, head4)
        self.HV_LCA4 = HV_LCA(ch4, head4)
        self.HV_LCA5 = HV_LCA(ch3, head3)
        self.HV_LCA6 = HV_LCA(ch2, head2)

        self.I_LCA1 = I_LCA(ch2, head2)
        self.I_LCA2 = I_LCA(ch3, head3)
        self.I_LCA3 = I_LCA(ch4, head4)
        self.I_LCA4 = I_LCA(ch4, head4)
        self.I_LCA5 = I_LCA(ch3, head3)
        self.I_LCA6 = I_LCA(ch2, head2)

        self.trans = RGB_HVI()

    def forward(self, x):
        dtypes = x.dtype
        hvi = self.trans.HVIT(x)
        # Intensity is the third HVI plane.
        i = hvi[:,2,:,:].unsqueeze(1).to(dtypes)
        # --- encoder, scale 0 -> 1 ---
        i_enc0 = self.IE_block0(i)
        i_enc1 = self.IE_block1(i_enc0)
        hv_0 = self.HVE_block0(hvi)
        hv_1 = self.HVE_block1(hv_0)
        i_jump0 = i_enc0
        hv_jump0 = hv_0

        # --- cross attention + downsample, scale 1 -> 2 ---
        i_enc2 = self.I_LCA1(i_enc1, hv_1)
        hv_2 = self.HV_LCA1(hv_1, i_enc1)
        v_jump1 = i_enc2
        hv_jump1 = hv_2
        i_enc2 = self.IE_block2(i_enc2)
        hv_2 = self.HVE_block2(hv_2)

        # --- cross attention + downsample, scale 2 -> 3 ---
        i_enc3 = self.I_LCA2(i_enc2, hv_2)
        hv_3 = self.HV_LCA2(hv_2, i_enc2)
        v_jump2 = i_enc3
        hv_jump2 = hv_3
        # NOTE(review): unlike the scale-1 stage, these downsample the
        # pre-LCA features (i_enc2/hv_2), so the I_LCA2/HV_LCA2 outputs feed
        # only the skip connections.  Presumably intentional (matches the
        # weights this code ships with) — confirm before "fixing".
        i_enc3 = self.IE_block3(i_enc2)
        hv_3 = self.HVE_block3(hv_2)

        # --- bottleneck: two cross-attention exchanges ---
        i_enc4 = self.I_LCA3(i_enc3, hv_3)
        hv_4 = self.HV_LCA3(hv_3, i_enc3)

        i_dec4 = self.I_LCA4(i_enc4,hv_4)
        hv_4 = self.HV_LCA4(hv_4, i_enc4)

        # --- decoder, scale 3 -> 2 ---
        hv_3 = self.HVD_block3(hv_4, hv_jump2)
        i_dec3 = self.ID_block3(i_dec4, v_jump2)
        i_dec2 = self.I_LCA5(i_dec3, hv_3)
        hv_2 = self.HV_LCA5(hv_3, i_dec3)

        # --- decoder, scale 2 -> 1 ---
        hv_2 = self.HVD_block2(hv_2, hv_jump1)
        # NOTE(review): upsamples i_dec3 (pre-I_LCA5 features), overwriting the
        # I_LCA5 output assigned to i_dec2 just above — the HV branch keeps its
        # LCA5 output, the I branch discards it.  Verify against the reference
        # implementation before changing.
        i_dec2 = self.ID_block2(i_dec3, v_jump1)

        i_dec1 = self.I_LCA6(i_dec2, hv_2)
        hv_1 = self.HV_LCA6(hv_2, i_dec2)

        # --- decoder, scale 1 -> 0 ---
        i_dec1 = self.ID_block1(i_dec1, i_jump0)
        i_dec0 = self.ID_block0(i_dec1)
        hv_1 = self.HVD_block1(hv_1, hv_jump0)
        hv_0 = self.HVD_block0(hv_1)

        # Residual in HVI space ([H, V] + [I] -> 3 planes), then back to RGB.
        output_hvi = torch.cat([hv_0, i_dec0], dim=1) + hvi
        output_rgb = self.trans.PHVIT(output_hvi)

        return output_rgb

    def HVIT(self,x):
        """Expose the RGB -> HVI transform (e.g. for losses computed in HVI space)."""
        hvi = self.trans.HVIT(x)
        return hvi
127
+
128
+
129
+
JarvisIR/package/agent_tools/HVICIDNet/net/HVI_transform.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ pi = 3.141592653589793
5
+
6
class RGB_HVI(nn.Module):
    """Differentiable RGB <-> HVI color-space transform.

    ``HVIT`` maps an RGB batch to HVI (two Cartesian chroma planes H, V plus
    intensity I); ``PHVIT`` maps an HVI tensor back to RGB.
    """

    def __init__(self):
        super(RGB_HVI, self).__init__()
        self.density_k = torch.nn.Parameter(torch.full([1],0.2)) # k is reciprocal to the paper mentioned
        self.gated = False      # when True, PHVIT scales saturation by alpha_s
        self.gated2= False      # when True, PHVIT scales the final RGB by alpha
        self.alpha = 1.0
        self.alpha_s = 1.3
        self.this_k = 0         # k captured at the last HVIT call; reused by PHVIT

    def HVIT(self, img):
        """Convert an RGB batch (B, 3, H, W) in [0, 1] to HVI (B, 3, H, W)."""
        eps = 1e-8
        device = img.device
        dtypes = img.dtype
        # Uninitialized buffer: every element is overwritten below, since for
        # each pixel at least one channel equals the channel-max `value`.
        hue = torch.Tensor(img.shape[0], img.shape[2], img.shape[3]).to(device).to(dtypes)
        value = img.max(1)[0].to(dtypes)
        img_min = img.min(1)[0].to(dtypes)
        # Standard HSV hue: one branch per dominant channel (B, G, R order so
        # later assignments win on ties).
        hue[img[:,2]==value] = 4.0 + ( (img[:,0]-img[:,1]) / (value - img_min + eps)) [img[:,2]==value]
        hue[img[:,1]==value] = 2.0 + ( (img[:,2]-img[:,0]) / (value - img_min + eps)) [img[:,1]==value]
        hue[img[:,0]==value] = (0.0 + ((img[:,1]-img[:,2]) / (value - img_min + eps)) [img[:,0]==value]) % 6

        hue[img.min(1)[0]==value] = 0.0  # gray pixels (max == min): hue undefined -> 0
        hue = hue/6.0

        saturation = (value - img_min ) / (value + eps )
        saturation[value==0] = 0

        hue = hue.unsqueeze(1)
        saturation = saturation.unsqueeze(1)
        value = value.unsqueeze(1)

        k = self.density_k
        self.this_k = k.item()  # remembered for the inverse transform

        # Intensity-dependent "color sensitivity"; polar -> Cartesian chroma.
        color_sensitive = ((value * 0.5 * pi).sin() + eps).pow(k)
        ch = (2.0 * pi * hue).cos()
        cv = (2.0 * pi * hue).sin()
        H = color_sensitive * saturation * ch
        V = color_sensitive * saturation * cv
        I = value
        xyz = torch.cat([H, V, I],dim=1)
        return xyz

    def PHVIT(self, img):
        """Convert an HVI batch (B, 3, H, W) back to RGB (inverse of HVIT)."""
        eps = 1e-8
        H,V,I = img[:,0,:,:],img[:,1,:,:],img[:,2,:,:]

        # clip to the valid HVI range
        H = torch.clamp(H,-1,1)
        V = torch.clamp(V,-1,1)
        I = torch.clamp(I,0,1)

        v = I
        # NOTE: uses the k snapshot from the last HVIT call (a plain float),
        # not the live density_k parameter.
        k = self.this_k
        color_sensitive = ((v * 0.5 * pi).sin() + eps).pow(k)
        H = (H) / (color_sensitive + eps)
        V = (V) / (color_sensitive + eps)
        H = torch.clamp(H,-1,1)
        V = torch.clamp(V,-1,1)
        # Cartesian chroma -> hue angle (wrapped to [0, 1)) and saturation radius.
        h = torch.atan2(V + eps,H + eps) / (2*pi)
        h = h%1
        s = torch.sqrt(H**2 + V**2 + eps)

        if self.gated:
            s = s * self.alpha_s

        s = torch.clamp(s,0,1)
        v = torch.clamp(v,0,1)

        r = torch.zeros_like(h)
        g = torch.zeros_like(h)
        b = torch.zeros_like(h)

        # Standard HSV -> RGB sector decomposition.
        hi = torch.floor(h * 6.0)
        f = h * 6.0 - hi
        p = v * (1. - s)
        q = v * (1. - (f * s))
        t = v * (1. - ((1. - f) * s))

        hi0 = hi==0
        hi1 = hi==1
        hi2 = hi==2
        hi3 = hi==3
        hi4 = hi==4
        hi5 = hi==5

        r[hi0] = v[hi0]
        g[hi0] = t[hi0]
        b[hi0] = p[hi0]

        r[hi1] = q[hi1]
        g[hi1] = v[hi1]
        b[hi1] = p[hi1]

        r[hi2] = p[hi2]
        g[hi2] = v[hi2]
        b[hi2] = t[hi2]

        r[hi3] = p[hi3]
        g[hi3] = q[hi3]
        b[hi3] = v[hi3]

        r[hi4] = t[hi4]
        g[hi4] = p[hi4]
        b[hi4] = v[hi4]

        r[hi5] = v[hi5]
        g[hi5] = p[hi5]
        b[hi5] = q[hi5]

        r = r.unsqueeze(1)
        g = g.unsqueeze(1)
        b = b.unsqueeze(1)
        rgb = torch.cat([r, g, b], dim=1)
        if self.gated2:
            rgb = rgb * self.alpha
        return rgb
JarvisIR/package/agent_tools/HVICIDNet/net/LCA.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+ from .transformer_utils import *
5
+
6
# Cross Attention Block
class CAB(nn.Module):
    """Channel-wise cross attention: queries come from ``x``, keys/values from ``y``."""

    def __init__(self, dim, num_heads, bias):
        super(CAB, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head temperature scaling the attention logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        batch, channels, height, width = x.shape
        head_dim = channels // self.num_heads

        query = self.q_dwconv(self.q(x))
        keys, values = self.kv_dwconv(self.kv(y)).chunk(2, dim=1)

        # (b, c, h, w) -> (b, heads, c/heads, h*w); identical memory layout to
        # the einops pattern 'b (head c) h w -> b head c (h w)'.
        query = query.reshape(batch, self.num_heads, head_dim, height * width)
        keys = keys.reshape(batch, self.num_heads, head_dim, height * width)
        values = values.reshape(batch, self.num_heads, head_dim, height * width)

        # L2-normalize along the spatial axis before the channel-channel product.
        query = torch.nn.functional.normalize(query, dim=-1)
        keys = torch.nn.functional.normalize(keys, dim=-1)

        logits = (query @ keys.transpose(-2, -1)) * self.temperature
        weights = nn.functional.softmax(logits, dim=-1)

        attended = weights @ values
        # (b, heads, c/heads, h*w) -> (b, c, h, w)
        attended = attended.reshape(batch, channels, height, width)

        return self.project_out(attended)
42
+
43
+
44
# Intensity Enhancement Layer
class IEL(nn.Module):
    """Gated feed-forward block: two depthwise branches that gate each other."""

    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super(IEL, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.dwconv1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)
        self.dwconv2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, groups=hidden_features, bias=bias)

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

        self.Tanh = nn.Tanh()

    def forward(self, x):
        projected = self.project_in(x)
        branch_a, branch_b = self.dwconv(projected).chunk(2, dim=1)
        # Each branch gets a tanh-gated residual refinement ...
        branch_a = branch_a + self.Tanh(self.dwconv1(branch_a))
        branch_b = branch_b + self.Tanh(self.dwconv2(branch_b))
        # ... then the branches gate each other multiplicatively.
        return self.project_out(branch_a * branch_b)
68
+
69
+
70
# Lightweight Cross Attention
class HV_LCA(nn.Module):
    """Cross attention (CAB) followed by an IEL refinement, for the HV branch."""

    def __init__(self, dim, num_heads, bias=False):
        super(HV_LCA, self).__init__()
        self.gdfn = IEL(dim)  # IEL and CDL have same structure
        self.norm = LayerNorm(dim)
        self.ffn = CAB(dim, num_heads, bias)

    def forward(self, x, y):
        # Attention is residual; the IEL output then *replaces* x
        # (no second residual, unlike I_LCA).
        attended = x + self.ffn(self.norm(x), self.norm(y))
        return self.gdfn(self.norm(attended))
82
+
83
class I_LCA(nn.Module):
    """Cross attention (CAB) plus IEL, both applied as residual updates, for the I branch."""

    def __init__(self, dim, num_heads, bias=False):
        super(I_LCA, self).__init__()
        self.norm = LayerNorm(dim)
        self.gdfn = IEL(dim)
        self.ffn = CAB(dim, num_heads, bias=bias)

    def forward(self, x, y):
        x = x + self.ffn(self.norm(x), self.norm(y))
        return x + self.gdfn(self.norm(x))
JarvisIR/package/agent_tools/HVICIDNet/net/transformer_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two layouts: channels_last inputs of shape
    (batch, height, width, channels) or channels_first inputs of shape
    (batch, channels, height, width) — the default here.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize over dim 1 manually (biased variance).
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
30
+
31
class NormDownsample(nn.Module):
    """3x3 conv + bilinear rescale (default 0.5x) + PReLU, with optional LayerNorm."""

    def __init__(self, in_ch, out_ch, scale=0.5, use_norm=False):
        super(NormDownsample, self).__init__()
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = LayerNorm(out_ch)
        self.prelu = nn.PReLU()
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.UpsamplingBilinear2d(scale_factor=scale),
        )

    def forward(self, x):
        out = self.prelu(self.down(x))
        return self.norm(out) if self.use_norm else out
49
+
50
class NormUpsample(nn.Module):
    """Upsample (3x3 conv + bilinear, default 2x), fuse with a skip tensor via
    1x1 conv, then PReLU and optional LayerNorm."""

    def __init__(self, in_ch, out_ch, scale=2, use_norm=False):
        super(NormUpsample, self).__init__()
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = LayerNorm(out_ch)
        self.prelu = nn.PReLU()
        self.up_scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.UpsamplingBilinear2d(scale_factor=scale),
        )
        self.up = nn.Conv2d(out_ch * 2, out_ch, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x, y):
        # Upsample x to y's resolution, concatenate the skip, then fuse.
        upsampled = self.up_scale(x)
        fused = self.prelu(self.up(torch.cat([upsampled, y], dim=1)))
        return self.norm(fused) if self.use_norm else fused
71
+
JarvisIR/package/agent_tools/HVICIDNet/wavelet.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
def Normalize(x):
    """Linearly rescale *x* so its minimum maps to 0 and its maximum to 255.

    NOTE(review): divides by (max - min); a constant input would divide by
    zero — confirm callers never pass a flat tensor.
    """
    lo, hi = x.min(), x.max()
    return 255 * (x - lo) / (hi - lo)
11
+
12
+
13
def dwt_init(x):
    """One level of 2-D Haar DWT.

    Splits a (B, C, H, W) tensor into four half-resolution sub-bands and
    stacks them along the batch dimension as [LL | HL | LH | HH], giving
    a (4B, C, H/2, W/2) result.
    """
    even_rows = x[:, :, 0::2, :] / 2
    odd_rows = x[:, :, 1::2, :] / 2
    a = even_rows[:, :, :, 0::2]  # even row, even col
    b = odd_rows[:, :, :, 0::2]   # odd row, even col
    c = even_rows[:, :, :, 1::2]  # even row, odd col
    d = odd_rows[:, :, :, 1::2]   # odd row, odd col

    x_LL = a + b + c + d
    x_HL = -a - b + c + d
    x_LH = -a + b - c + d
    x_HH = a - b - c + d

    return torch.cat((x_LL, x_HL, x_LH, x_HH), 0)
27
+
28
+
29
# Inverse 2-D discrete wavelet transform using the Haar basis.
def iwt_init(x):
    """One level of inverse 2-D Haar DWT (inverse of ``dwt_init``).

    Expects a (4B, C, H, W) tensor laid out as [LL | HL | LH | HH] along the
    batch dimension and returns the reconstructed (B, C, 2H, 2W) tensor.

    Fix: the output buffer is now allocated with the input's dtype and device;
    the previous ``torch.zeros(...).float()`` silently down-cast e.g. float64
    or float16 inputs to float32.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    out_batch = in_batch // (r ** 2)
    out_height, out_width = r * in_height, r * in_width

    x1 = x[0:out_batch, :, :, :] / 2
    x2 = x[out_batch:out_batch * 2, :, :, :] / 2
    x3 = x[out_batch * 2:out_batch * 3, :, :, :] / 2
    x4 = x[out_batch * 3:out_batch * 4, :, :, :] / 2

    h = torch.zeros(out_batch, in_channel, out_height, out_width,
                    dtype=x.dtype, device=x.device)

    # Interleave the four pixel phases back into the full-resolution grid.
    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

    return h
48
+
49
+
50
class DWT(nn.Module):
    """Module wrapper around ``dwt_init`` (forward Haar transform)."""

    def __init__(self):
        super(DWT, self).__init__()
        # Pure signal processing, not a learned convolution — no gradients
        # needed.  (The module holds no parameters; this flag is informational.)
        self.requires_grad = False

    def forward(self, x):
        return dwt_init(x)
57
+
58
+
59
class IWT(nn.Module):
    """Module wrapper around ``iwt_init`` (inverse Haar transform)."""

    def __init__(self):
        super(IWT, self).__init__()
        # Fixed transform with no learnable parameters.
        self.requires_grad = False

    def forward(self, x):
        return iwt_init(x)
JarvisIR/package/agent_tools/IDT/__init__.py ADDED
File without changes
JarvisIR/package/agent_tools/IDT/analyse/cal_rf_bf.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
def getimgarr(path_1):
    """Load every file in directory *path_1* (sorted by name) as one float array.

    Returns a numpy array of shape (N, H, W, 3) in BGR channel order (as read
    by ``cv2.imread``); all images must share the same size.

    Raises:
        ValueError: when a file cannot be decoded — ``cv2.imread`` returns
            None silently, which previously crashed later inside ``np.array``
            with an opaque error.
    """
    imgs = []
    for fname in sorted(os.listdir(path_1)):
        full_path = os.path.join(path_1, fname)
        img = cv2.imread(full_path)
        if img is None:
            raise ValueError('unreadable image file: %s' % full_path)
        imgs.append(img)
    return np.array(imgs, dtype=float)
15
+
16
# --- Script: count background-focus vs raindrop-focus frames ----------------
# For each scene under <inp_path> (a .../Clear directory), compare every
# clear .png frame with its counterpart under the parallel .../Blur directory.
# Frames whose Clear and Blur images compare equal are counted as "background
# focus" (countbf); differing ones as "raindrop focus" (countrf).

#inp_path = '/disk1/release/DayRainDrop/Clear'
#inp_path = '/disk1/release/NightRainDrop/Clear'
#inp_path = '/disk1/release/DayRainDrop_Test/Clear'
inp_path = '/disk1/release/NightRainDrop_Test/Clear'

countall = 0   # total frames inspected
countbf = 0    # background-focus frames (Clear == Blur)
countrf = 0    # raindrop-focus frames (Clear != Blur)

sid = sorted(os.listdir(inp_path))
print(sid,len(sid))

for i in range(len(sid)):

    inp_path_clean = os.path.join(inp_path, sid[i])
    # NOTE: relies on the '/Clear/' path segment; assumes POSIX-style paths.
    inp_path_blur = inp_path_clean.replace('/Clear/','/Blur/')
    #print(inp_path_clean,inp_path_blur)

    # Keep only .png frames for comparison.
    dropimgs = sorted(os.listdir(inp_path_clean))
    Droplist = []
    for didx in range(len(dropimgs)):
        if dropimgs[didx].endswith('.png'):
            Droplist.append(dropimgs[didx])
    # print(Droplist)
    # print('-------------------------------------------------')

    for j in range(len(Droplist)):
        cleaninp_frame_name = os.path.join(inp_path_clean, Droplist[j])
        blurinp_frame_name = cleaninp_frame_name.replace('/Clear/','/Blur/')
        # print(cleaninp_frame_name,blurinp_frame_name)
        countall = countall + 1
        cleanimage = Image.open(cleaninp_frame_name)
        blurimage = Image.open(blurinp_frame_name)
        # NOTE(review): `==` between PIL Images depends on the Pillow version
        # defining content-based __eq__; where comparison is identity-based
        # this branch never fires — confirm, or compare .tobytes() instead.
        if cleanimage == blurimage:
            countbf = countbf + 1
            # Print "<scene>/<frame>" for each background-focus frame.
            parts = cleaninp_frame_name.split(os.sep)
            new_path = os.path.join(parts[-2], parts[-1])
            print(new_path)
        else:
            countrf = countrf + 1

# Final summary (the recorded results below were produced by this line).
print('-countall-',countall,'-background focus-',countbf,'-raindrop focus-',countrf)

#DayRainDrop -countall- 5442,-background focus- 1836,-raindrop focus- 3606,
#NightRainDrop -countall- 9744,-background focus- 4906,-raindrop focus- 4838,
#Total: 15186

#DayRainDrop_Train -countall- 8655,-background focus- 4143,-raindrop focus- 4512
#DayRainDrop_Test -countall- 729 -background focus- 261 -raindrop focus- 468
#NightRainDrop_Train -countall- 4713,-background focus- 1575,-raindrop focus- 3138
#NightRainDrop_Test -countall- 1089 -background focus- 763 -raindrop focus- 326
JarvisIR/package/agent_tools/IDT/configs/daytime_128.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "RainDrop"
3
+ #image_size: 64 # DIT, RDiffusion, onego
4
+ #image_size: 256 # ICRA, Uformer, atgan
5
+ image_size: 128 # IDT, restormer
6
+ channels: 3
7
+ num_workers: 8
8
+ data_dir: "/data4/lx_workspace/datasets/SFT-data/mini/daytime_driving_rainy/rainy/"
9
+ conditional: True
10
+
11
+ model:
12
+ in_channels: 3
13
+ out_ch: 3
14
+ ch: 128
15
+ ch_mult: [1, 1, 2, 2, 4, 4]
16
+ num_res_blocks: 2
17
+ attn_resolutions: [16, ]
18
+ dropout: 0.0
19
+ ema_rate: 0.999
20
+ ema: True
21
+ resamp_with_conv: True
22
+
23
+ diffusion:
24
+ beta_schedule: linear
25
+ beta_start: 0.0001
26
+ beta_end: 0.02
27
+ num_diffusion_timesteps: 1000
28
+
29
+ training:
30
+ #-----------DIT, RDiffusion, onego, ICRA--------------
31
+ # patch_n: 4
32
+ # batch_size: 16
33
+ # #---------IDT Uformer restormer--------------
34
+ patch_n: 1
35
+ batch_size: 4
36
+ #-----------atgan--------------
37
+ # patch_n: 1
38
+ # batch_size: 32
39
+ #-----------DIT, RDiffusion--------------
40
+ #n_epochs: 37042
41
+ n_epochs: 401
42
+ n_iters: 20000000
43
+ snapshot_freq: 50
44
+ validation_freq: 10000
45
+
46
+ sampling:
47
+ batch_size: 4
48
+ last_only: True
49
+
50
+ optim:
51
+ weight_decay: 0.000
52
+ optimizer: "Adam"
53
+ lr: 0.00002
54
+ amsgrad: False
55
+ eps: 0.00000001
JarvisIR/package/agent_tools/IDT/configs/daytime_256.yml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ dataset: "RainDrop"
3
+ #image_size: 64 # DIT, RDiffusion, onego
4
+ image_size: 256 # ICRA, Uformer, atgan
5
+ # image_size: 128 # IDT, restormer
6
+ channels: 3
7
+ num_workers: 8
8
+ data_dir: "/data4/lx_workspace/datasets/SFT-data/mini/daytime_driving_rainy/rainy/"
9
+ conditional: True
10
+
11
+ model:
12
+ in_channels: 3
13
+ out_ch: 3
14
+ ch: 128
15
+ ch_mult: [1, 1, 2, 2, 4, 4]
16
+ num_res_blocks: 2
17
+ attn_resolutions: [16, ]
18
+ dropout: 0.0
19
+ ema_rate: 0.999
20
+ ema: True
21
+ resamp_with_conv: True
22
+
23
+ diffusion:
24
+ beta_schedule: linear
25
+ beta_start: 0.0001
26
+ beta_end: 0.02
27
+ num_diffusion_timesteps: 1000
28
+
29
+ training:
30
+ #-----------DIT, RDiffusion, onego, ICRA--------------
31
+ # patch_n: 4
32
+ # batch_size: 16
33
+ # #---------IDT Uformer restormer--------------
34
+ patch_n: 1
35
+ batch_size: 4
36
+ #-----------atgan--------------
37
+ # patch_n: 1
38
+ # batch_size: 32
39
+ #-----------DIT, RDiffusion--------------
40
+ #n_epochs: 37042
41
+ n_epochs: 401
42
+ n_iters: 20000000
43
+ snapshot_freq: 50
44
+ validation_freq: 10000
45
+
46
+ sampling:
47
+ batch_size: 4
48
+ last_only: True
49
+
50
+ optim:
51
+ weight_decay: 0.000
52
+ optimizer: "Adam"
53
+ lr: 0.00002
54
+ amsgrad: False
55
+ eps: 0.00000001