umzi committed on
Commit
4f763cc
·
verified ·
1 Parent(s): ea4a206

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figs/FIDSR.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
.idea/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
.idea/FIGSR3.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="Python 3.13 virtualenv at /run/media/umzi/H/resseltrr/.venv" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.13 virtualenv at /run/media/umzi/H/resseltrr/.venv" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 virtualenv at /run/media/umzi/H/resseltrr/.venv" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/FIGSR3.iml" filepath="$PROJECT_DIR$/.idea/FIGSR3.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ChangeListManager">
4
+ <list default="true" id="bd19aad7-30f1-41f2-8970-657e29732baf" name="Changes" comment="" />
5
+ <option name="SHOW_DIALOG" value="false" />
6
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
7
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
8
+ <option name="LAST_RESOLUTION" value="IGNORE" />
9
+ </component>
10
+ <component name="Git.Settings">
11
+ <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
12
+ </component>
13
+ <component name="ProjectColorInfo"><![CDATA[{
14
+ "associatedIndex": 2
15
+ }]]></component>
16
+ <component name="ProjectId" id="39WMPB6Zm6p9HESSvY2vaN0KQUO" />
17
+ <component name="ProjectViewState">
18
+ <option name="hideEmptyMiddlePackages" value="true" />
19
+ <option name="showLibraryContents" value="true" />
20
+ </component>
21
+ <component name="PropertiesComponent"><![CDATA[{
22
+ "keyToString": {
23
+ "RunOnceActivity.ShowReadmeOnStart": "true",
24
+ "RunOnceActivity.git.unshallow": "true",
25
+ "git-widget-placeholder": "main",
26
+ "last_opened_file_path": "/run/media/umzi/H/FIGSR3/weights"
27
+ }
28
+ }]]></component>
29
+ <component name="RecentsManager">
30
+ <key name="CopyFile.RECENT_KEYS">
31
+ <recent name="$PROJECT_DIR$/weights" />
32
+ <recent name="$PROJECT_DIR$" />
33
+ </key>
34
+ <key name="MoveFile.RECENT_KEYS">
35
+ <recent name="$PROJECT_DIR$/weights" />
36
+ </key>
37
+ </component>
38
+ <component name="SharedIndexes">
39
+ <attachedChunks>
40
+ <set>
41
+ <option value="bundled-python-sdk-164cda30dcd9-0af03a5fa574-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-252.26830.99" />
42
+ </set>
43
+ </attachedChunks>
44
+ </component>
45
+ <component name="TaskManager">
46
+ <task active="true" id="Default" summary="Default task">
47
+ <changelist id="bd19aad7-30f1-41f2-8970-657e29732baf" name="Changes" comment="" />
48
+ <created>1770807346896</created>
49
+ <option name="number" value="Default" />
50
+ <option name="presentableId" value="Default" />
51
+ <updated>1770807346896</updated>
52
+ </task>
53
+ <servers />
54
+ </component>
55
+ </project>
README.md CHANGED
@@ -1,3 +1,109 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fourier Inception Gated Super Resolution
2
+
3
+ The main idea of the model is to integrate the [FourierUnit](https://github.com/deng-ai-lab/SFHformer/blob/1f7994112b9ced9153edc7187e320e0383a9dfd3/models/SFHformer.py#L143) into the [GatedCNN](https://github.com/yuweihao/MambaOut/blob/main/models/mambaout.py#L119) pipeline in order to strengthen the model’s global perception with minimal computational overhead.
4
+
5
+ The FourierUnit adds feature processing in the frequency domain, expanding the effective receptive field, while the GatedCNN provides efficient local modeling and control of information flow through a gating mechanism. Their combination allows merging global context and computational efficiency within a compact SISR architecture.
6
+
7
+ ---
8
+ # TODO:
9
+ + [ ] Fix trt inference
10
+ ---
11
+ ## Showcase:
12
+ [show pics](https://slow.pics/s/fPvcS3P0?image-fit=contain)
13
+
14
+ [gdrive](https://drive.google.com/drive/u/1/folders/1ofJo5CCgrOtLdVm9psmlJv15Z3aP4Aiz)
15
+
16
+ ---
17
+ ## Model structure:
18
+
19
+ ### figsr
20
+
21
+ <img src="figs/figsr.png" width="600"/>
22
+
23
+ ### GDB FU
24
+
25
+ <img src="figs/gdb_and_FU.png" width="600"/>
26
+
27
+ ---
28
+
29
+ ### Main blocks and their changes relative to the originals:
30
+
31
+ * [GatedCNN](https://github.com/yuweihao/MambaOut/blob/main/models/mambaout.py#L119) — borrowed from the [MambaOut](https://github.com/yuweihao/MambaOut/blob/main/models/mambaout.py#L119) repository with the following changes:
32
+
33
+ * `Linear` replaced with `Conv` to avoid unnecessary `permute` operations;
34
+ * one of the linear layers replaced with a `Conv 3×3`, which improves quality without a significant increase in computational cost;
35
+ * `LayerNorm` replaced with `RMSNorm` for speed and greater stability;
36
+ * `DConv` replaced with `InceptionConv`.
37
+
38
+ * [InceptionConv](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L627) — a modified version of the block from [InceptionNeXt](https://github.com/sail-sg/inceptionnext/blob/main/models/inceptionnext.py#L19):
39
+
40
+ * `DConv` replaced with standard convolutions;
41
+ * kernel sizes increased following the findings of [PLKSR](https://github.com/dslisleedh/PLKSR);
42
+ * the shortcut replaced with `FourierUnit`, which improves convergence because a residual connection is already present inside `GatedCNN`.
43
+
44
+ * [FourierUnit](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L585) — a modified version of the block from [SFHformer](https://github.com/deng-ai-lab/SFHformer/blob/1f7994112b9ced9153edc7187e320e0383a9dfd3/models/SFHformer.py#L143):
45
+
46
+ * `BatchNorm` replaced with `RMSNorm`, which works better with the small batch sizes typical for SISR;
47
+ * structural changes made for correct export to ONNX;
48
+ * post-normalization added, since without it training instability and `NaN` values were observed in the context of `GatedCNN`.
49
+
50
+ ---
51
+
52
+ ## Metrics:
53
+ * Metrics were computed using [PyIQA](https://github.com/chaofengc/IQA-PyTorch/tree/main), except for those starting with “bs”, which were calculated using BasicSR.
54
+ ### [Esrgan DF2K](https://drive.google.com/file/d/1mSJ6Z40weL-dnPvi390xDd3uZBCFMeqr/view?usp=sharing):
55
+ | Dataset | SSIM-Y | PSNR-Y | TOPIQ | bs_ssim_y | bs_psnr_y |
56
+ | ------------- | ------ | ------ | ------ | --------- | --------- |
57
+ | BHI100 | 0.7150 | 22.84 | 0.5694 | 0.7279 | 24.1636 |
58
+ | psisrd_val125 | 0.7881 | 27.01 | 0.6043 | 0.8034 | 28.3273 |
59
+ | set14 | 0.7730 | 27.67 | 0.6905 | 0.7915 | 28.9969 |
60
+ | urban100 | 0.8025 | 25.71 | 0.6701 | 0.8152 | 27.0282 |
61
+ ### [FIGSR BHI](https://huggingface.co/enhancr-dev/figsr/blob/main/weight/v1.0.0):
62
+ | Dataset | SSIM-Y | PSNR-Y | TOPIQ | bs_ssim_y | bs_psnr_y |
63
+ | ------------- | ------ | ------ | ------ | --------- | --------- |
64
+ | BHI100 | 0.7196 | 22.83 | 0.5723 | 0.7327 | 24.1549 |
65
+ | psisrd_val125 | 0.7911 | 26.97 | 0.6095 | 0.8065 | 28.2946 |
66
+ | set14 | 0.7769 | 27.70 | 0.7036 | 0.7952 | 29.0221 |
67
+ | urban100 | 0.8056 | 25.80 | 0.6725 | 0.8185 | 27.1170 |
68
+
69
+ ---
70
+
71
+ ## Performance 3060 12gb:
72
+ | Model | input_size | params ↓ | avg_inference ↓ | fps ↑ | memory_use ↓ |
73
+ |--------| ---------- | -------- |-----------------| ------------------ | ------------ |
74
+ | ESRGAN | 1024x1024 | ~16.6m | ~2.8s | 0.3483220866736526 | 8.29GB |
75
+ | FIGSR | 1024x1024 | ~4.4m | ~1.64s | 0.6081749253740837 | 2.26GB |
76
+
77
+ ## Training
78
+
79
+ To train, choose one of the frameworks and place the model file in the `archs` folder:
80
+
81
+ * **[NeoSR](https://github.com/neosr-project/neosr)** — `figsr_arch.py` → `neosr/archs/figsr_arch.py`. [Config](configs/neosr.toml)
82
+
83
+ * Uncomment lines [14–17](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L14-L17), [694](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L694) and [705](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L705).
84
+ * Comment out line [703](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L703).
85
+
86
+ * **[traiNNer-redux](https://github.com/the-database/traiNNer-redux)** — `figsr_arch.py` → `traiNNer/archs/figsr_arch.py`. [Config](configs/trainner-redux.yml)
87
+
88
+ * Uncomment lines [11](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L11) and [694](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L694).
89
+
90
+ * **[BasicSR](https://github.com/XPixelGroup/BasicSR/tree/master/basicsr/archs)** — `figsr_arch.py` → `basicsr/archs/figsr_arch.py`. [Config](configs/basicsr.yml)
91
+
92
+ * Uncomment lines [19](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L19) and [694](https://huggingface.co/enhancr-dev/figsr/blob/main/figsr_arch.py#L694).
93
+
94
+ ---
95
+
96
+ ## Inference:
97
+ ### Resselt install
98
+ ```shell
99
+ uv venv --python=3.12
100
+ source .venv/bin/activate
101
+ uv pip install "resselt==1.3.1" "pepeline==1.2.3"
102
+ ```
103
+ ### main.py
104
+ ```shell
105
+ python main.py --input_dir urban/x4 --output_dir urban/x4_scale --weights 4x_FIGSR.safetensors
106
+ ```
107
+ ---
108
+ ## Contacts:
109
+ [discord](https://discord.gg/xwZfWWMwBq)
configs/basicsr.yml ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: 4x_figsr
3
+ model_type: SRModel
4
+ scale: 4
5
+ num_gpu: 1 # set num_gpu: 0 for cpu mode
6
+ manual_seed: 1024
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ name: BHI
12
+ type: PairedImageDataset
13
+ dataroot_gt: datasets/BHI
14
+ dataroot_lq: datasets/BHI_lq
15
+
16
+ filename_tmpl: '{}'
17
+ io_backend:
18
+ type: disk
19
+
20
+ gt_size: 256
21
+ use_hflip: true
22
+ use_rot: true
23
+
24
+ # data loader
25
+ num_worker_per_gpu: 6
26
+ batch_size_per_gpu: 64
27
+ dataset_enlarge_ratio: 1
28
+ prefetch_mode: ~
29
+
30
+ val:
31
+ name: Set5
32
+ type: PairedImageDataset
33
+ dataroot_gt: datasets/Set5/GTmod12
34
+ dataroot_lq: datasets/Set5/LRbicx4
35
+ io_backend:
36
+ type: disk
37
+
38
+ # network structures
39
+ network_g:
40
+ type: FIGSR
41
+
42
+ # path
43
+ path:
44
+ # pretrain_network_g: ""
45
+ strict_load_g: false
46
+ resume_state: ~
47
+
48
+ # training settings
49
+ train:
50
+ ema_decay: 0.999
51
+ optim_g:
52
+ type: Adam
53
+ lr: !!float 5e-4
54
+ weight_decay: 0
55
+ betas: [0.9, 0.99]
56
+
57
+ scheduler:
58
+ type: MultiStepLR
59
+ milestones: [200000,400000,600000,800000]
60
+ gamma: 0.5
61
+
62
+ total_iter: 1000000
63
+ warmup_iter: -1 # no warm up
64
+
65
+ # losses
66
+ pixel_opt:
67
+ type: CharbonnierLoss
68
+ loss_weight: 1.0
69
+ reduction: mean
70
+
71
+ # validation settings
72
+ val:
73
+ val_freq: !!float 5e3
74
+ save_img: true
75
+
76
+ metrics:
77
+ psnr: # metric name, can be arbitrary
78
+ type: calculate_psnr
79
+ crop_border: 4
80
+ test_y_channel: false
81
+
82
+ # logging settings
83
+ logger:
84
+ print_freq: 100
85
+ save_checkpoint_freq: !!float 5e3
86
+ use_tb_logger: true
87
+ wandb:
88
+ project: ~
89
+ resume_id: ~
90
+
91
+ # dist training settings
92
+ dist_params:
93
+ backend: nccl
94
+ port: 29500
configs/neosr.toml ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ name = "4x_figsr"
3
+ model_type = "image"
4
+ scale = 4
5
+ use_amp = true
6
+ bfloat16 = false
7
+ fast_matmul = false
8
+ #compile = true
9
+ manual_seed = 1024
10
+
11
+ [datasets.train]
12
+ type = "paired"
13
+ dataroot_gt = 'datasets/BHI'
14
+ dataroot_lq = 'datasets/BHI_lq'
15
+ patch_size = 64
16
+ batch_size = 64
17
+ #accumulate = 1
18
+
19
+ [datasets.val]
20
+ name = "val"
21
+ type = "paired"
22
+ dataroot_gt = 'C:\datasets\val\gt\'
23
+ dataroot_lq = 'C:\datasets\val\lq\'
24
+ [val]
25
+ val_freq = 5000
26
+ #tile = 200
27
+ [val.metrics.psnr]
28
+ type = "calculate_psnr"
29
+ [val.metrics.ssim]
30
+ type = "calculate_ssim"
31
+ #[val.metrics.dists]
32
+ #type = "calculate_dists"
33
+ #better = "lower"
34
+ #[val.metrics.topiq]
35
+ #type = "calculate_topiq"
36
+
37
+ [path]
38
+ #pretrain_network_g = 'experiments\pretrain_g.pth'
39
+ #pretrain_network_d = 'experiments\pretrain_d.pth'
40
+
41
+ [network_g]
42
+ type = "FIGSR"
43
+
44
+
45
+ [train]
46
+ grad_clip = false
47
+ ema = 0.999
48
+ wavelet_guided = false
49
+ #wavelet_init = 80000
50
+ #sam = "fsam"
51
+ #sam_init = 1000
52
+ #eco = true
53
+ #eco_init = 15000
54
+ #match_lq_colors = true
55
+
56
+ [train.optim_g]
57
+ type = "adamw"
58
+ lr = 5e-4
59
+ betas = [0.9, 0.99]
60
+ weight_decay = 0.01
61
+
62
+ [train.scheduler]
63
+ type = "multisteplr"
64
+ milestones = [200000,400000,600000,800000]
65
+ gamma = 0.5
66
+
67
+
68
+ # losses
69
+ [train.mssim_opt]
70
+ type = "huber_loss"
71
+ loss_weight = 1.0
72
+
73
+ [logger]
74
+ total_iter = 1000000
75
+ save_checkpoint_freq = 5000
76
+ use_tb_logger = true
77
+ #save_tb_img = true
78
+ #print_freq = 100
configs/trainner-redux.yml ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # yaml-language-server: $schema=https://raw.githubusercontent.com/the-database/traiNNer-redux/refs/heads/master/schemas/redux-config.schema.json
2
+ #########################################################################################
3
+ # General Settings
4
+ # https://trainner-redux.readthedocs.io/en/latest/config_reference.html#top-level-options
5
+ #########################################################################################
6
+ name: 4x_figsr
7
+ scale: 4 # 1, 2, 3, 4, 8
8
+ use_amp: true # Speed up training and reduce VRAM usage. NVIDIA only.
9
+ amp_bf16: false # Use bf16 instead of fp16 for AMP, RTX 3000 series or newer only. Only recommended if fp16 doesn't work.
10
+ use_channels_last: true # Enable channels last memory format while using AMP. Reduces VRAM and speeds up training for most architectures, but some architectures are slower with channels last.
11
+ fast_matmul: false # Trade precision for performance.
12
+ use_compile: false # Enable torch.compile for generator. Takes time on startup to compile the model, but can speed up training after the model is compiled.
13
+ compile_mode: default # Mode to use with torch.compile. See https://docs.pytorch.org/docs/stable/generated/torch.compile.html for more info.
14
+ num_gpu: auto
15
+ # manual_seed: 1024 # Random seed for training, useful for removing randomness when testing the effect of different settings.
16
+
17
+
18
+ ########################################################################################################################
19
+ # Dataset and Dataloader Settings
20
+ # https://trainner-redux.readthedocs.io/en/latest/config_reference.html#dataset-options-datasets-train-and-datasets-val
21
+ ########################################################################################################################
22
+ datasets:
23
+ # Settings for the training dataset.
24
+ train:
25
+ name: Train Dataset
26
+ type: pairedimagedataset
27
+ # Path to the HR (high res) images in your training dataset. Specify one or multiple folders, separated by commas.
28
+ dataroot_gt: [datasets/BHI]
29
+ dataroot_lq: [datasets/BHI_lq]
30
+ # meta_info: data/meta_info/dataset1.txt
31
+
32
+
33
+ lq_size: 64 # During training, a square of this size is cropped from LR images. Larger is usually better but uses more VRAM. Previously gt_size, use lq_size = gt_size / scale to convert. Use multiple of 8 for best performance with AMP.
34
+ use_hflip: true # Randomly flip the images horizontally.
35
+ use_rot: true # Randomly rotate the images.
36
+
37
+ num_worker_per_gpu: 8
38
+ batch_size_per_gpu: 64 # recommended: 64 # Increasing stabilizes training but with diminishing returns. Use multiple of 8 for best performance with AMP.
39
+ accum_iter: 1 # Using values larger than 1 simulates higher batch size by trading performance for reduced VRAM usage. If accum_iter = 4 and batch_size_per_gpu = 6 then effective batch size = 4 * 6 = 24 but performance may be as much as 4 times as slow.
40
+ # Settings for your validation dataset (optional). These settings will
41
+ # be ignored if val_enabled is false in the Validation section below.
42
+ val:
43
+ name: Val Dataset
44
+ type: pairedimagedataset
45
+ dataroot_gt: [
46
+ datasets/val/dataset1/hr,
47
+ datasets/val/dataset1/hr2,
48
+ ]
49
+ dataroot_lq: [
50
+ datasets/val/dataset1/lr,
51
+ datasets/val/dataset1/lr2
52
+ ]
53
+
54
+ #####################################################################
55
+ # Network Settings
56
+ # https://trainner-redux.readthedocs.io/en/latest/arch_reference.html
57
+ #####################################################################
58
+ # Generator model settings
59
+ network_g:
60
+ type: FIGSR
61
+
62
+ #########################################################################################
63
+ # Pretrain and Resume Paths
64
+ # https://trainner-redux.readthedocs.io/en/latest/config_reference.html#path-options-path
65
+ #########################################################################################
66
+ path:
67
+ # pretrain_network_g: experiments/pretrained_models/pretrain.pth
68
+ param_key_g: ~
69
+ strict_load_g: true # Disable strict loading to partially load a pretrain model with a different scale
70
+ resume_state: ~
71
+
72
+ ###########################################################################################
73
+ # Training Settings
74
+ # https://trainner-redux.readthedocs.io/en/latest/config_reference.html#train-options-train
75
+ ###########################################################################################
76
+ train:
77
+ ema_decay: 0.999
78
+ ema_power: 0.75 # Gradually warm up ema decay when training from scratch
79
+ grad_clip: false # Gradient clipping allows more stable training when using higher learning rates.
80
+ # Optimizer for generator model
81
+ optim_g:
82
+ type: AdamW
83
+ lr: !!float 5e-4
84
+ weight_decay: 0
85
+ betas: [0.9, 0.99]
86
+
87
+ scheduler:
88
+ type: MultiStepLR
89
+ milestones: [200000, 400000, 600000, 800000]
90
+ gamma: 0.5
91
+
92
+ total_iter: 1000000 # Total number of iterations.
93
+ warmup_iter: -1 # Gradually ramp up learning rates until this iteration, to stabilize early training. Use -1 to disable.
94
+
95
+ # Losses - for any loss set the loss_weight to 0 to disable it.
96
+ # https://trainner-redux.readthedocs.io/en/latest/loss_reference.html
97
+ losses:
98
+ # Charbonnier loss
99
+ - type: charbonnierloss
100
+ loss_weight: 1.0
101
+
102
+ ##############################################################################################
103
+ # Validation
104
+ # https://trainner-redux.readthedocs.io/en/latest/config_reference.html#validation-options-val
105
+ ##############################################################################################
106
+ val:
107
+ val_enabled: true # Whether to enable validations. If disabled, all validation settings below are ignored.
108
+ val_freq: 5000 # How often to run validations, in iterations.
109
+ save_img: true # Whether to save the validation images during validation, in the experiments/<name>/visualization folder.
110
+ tile_size: 0 # Tile size of input, reduce VRAM usage but slower inference. 0 to disable.
111
+ tile_overlap: 8 # Number of pixels to overlap tiles by, larger is slower but reduces tile seams.
112
+
113
+ metrics_enabled: true # Whether to run metrics calculations during validation.
114
+ metrics:
115
+ psnr:
116
+ type: calculate_psnr
117
+ crop_border: 4
118
+ test_y_channel: true
119
+ ssim:
120
+ type: calculate_ssim
121
+ crop_border: 4 # Whether to crop border during validation.
122
+ test_y_channel: true # Whether to convert to Y(CbCr) for validation.
123
+ #topiq:
124
+ #type: calculate_topiq
125
+ #lpips:
126
+ #type: calculate_lpips
127
+ #better: lower
128
+ #dists:
129
+ #type: calculate_dists
130
+ #better: lower
131
+
132
+ ##############################################################################################
133
+ # Logging
134
+ # https://trainner-redux.readthedocs.io/en/latest/config_reference.html#logging-options-logger
135
+ ##############################################################################################
136
+ logger:
137
+ print_freq: 100
138
+ save_checkpoint_freq: 5000
139
+ save_checkpoint_format: safetensors
140
+ use_tb_logger: true
figs/FIDSR.png ADDED

Git LFS Details

  • SHA256: 4312b4443e8848f950a36c65521dcecaade497247c0351ef60d5335f75aecc46
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
figs/gdb_and_FU.png ADDED
figsr_arch.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Literal
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch import Tensor, nn
9
+
10
+ # trainner-redux https://github.com/the-database/traiNNer-redux
11
+ # from traiNNer.utils.registry import ARCH_REGISTRY
12
+
13
+ # neosr https://github.com/neosr-project/neosr/tree/master
14
+ # from neosr.archs.arch_util import net_opt
15
+ # from neosr.utils.registry import ARCH_REGISTRY
16
+ #
17
+ # upscale, __ = net_opt()
18
+ # basic sr https://github.com/XPixelGroup/BasicSR/tree/master
19
+ # from basicsr.utils.registry import ARCH_REGISTRY
20
# Closed set of upsampler-variant names. Presumably consumed as the
# `upsampler`/sampler selection argument elsewhere in this architecture file —
# the consuming code is outside this chunk, so confirm against the full file.
SampleMods = Literal[
    "conv",
    "pixelshuffledirect",
    "pixelshuffle",
    "nearest+conv",
    "dysample",
    "transpose+conv",
    "lda",
    "pa_up",
]
30
+
31
+
32
def ICNR(tensor, initializer, upscale_factor=2, *args, **kwargs):
    """ICNR-style weight construction for sub-pixel (PixelShuffle) convs.

    A reduced kernel with ``tensor.shape[0] // upscale_factor**2`` output
    channels is initialised with ``initializer`` and then each of its
    kernels is replicated ``upscale_factor**2`` times consecutively along
    dim 0, so channel groups feeding one upsampled pixel start identical.

    Args:
        tensor: template weight tensor; only its shape is used. Its first
            dimension must be divisible by ``upscale_factor**2``.
        initializer: in-place init function applied to the reduced kernel
            (e.g. ``nn.init.kaiming_normal_``); ``*args``/``**kwargs`` are
            forwarded to it.
        upscale_factor: the PixelShuffle ratio ``r``.

    Returns:
        A freshly allocated tensor with ``tensor``'s shape holding the
        replicated weights.
    """
    factor_sq = upscale_factor * upscale_factor
    out_channels = tensor.shape[0]
    assert out_channels % factor_sq == 0, (
        "The size of the first dimension: "
        f"tensor.shape[0] = {out_channels}"
        " is not divisible by square of upscale_factor: "
        f"upscale_factor = {upscale_factor}"
    )
    # Build and initialise the reduced kernel (1/r^2 of the output channels).
    reduced = torch.empty(out_channels // factor_sq, *tensor.shape[1:])
    reduced = initializer(reduced, *args, **kwargs)
    # Expand each reduced kernel into r^2 consecutive identical copies.
    return reduced.repeat_interleave(factor_sq, dim=0)
45
+
46
+
47
class DySample(nn.Module):
    """Adapted from 'Learning to Upsample by Learning to Sample':
    https://arxiv.org/abs/2308.15085
    https://github.com/tiny-smart/dysample

    Content-aware upsampler: predicts per-pixel sampling offsets, gated by
    a learned "scope", and resamples the input with grid_sample.
    """

    def __init__(
        self,
        in_channels: int = 64,
        out_ch: int = 3,
        scale: int = 2,
        groups: int = 4,
        end_convolution: bool = True,
        end_kernel=1,
    ) -> None:
        super().__init__()

        if in_channels <= groups or in_channels % groups != 0:
            msg = "Incorrect in_channels and groups values."
            raise ValueError(msg)

        # Two offset coordinates per group per upsampled sub-pixel.
        out_channels = 2 * groups * scale**2
        self.scale = scale
        self.groups = groups
        self.end_convolution = end_convolution
        if end_convolution:
            self.end_conv = nn.Conv2d(
                in_channels, out_ch, end_kernel, 1, end_kernel // 2
            )
        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        self.scope = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        # A freshly constructed module is in training mode, so this runs at
        # construction; the guard only skips re-init for eval-mode rebuilds.
        if self.training:
            nn.init.trunc_normal_(self.offset.weight, std=0.02)
            nn.init.constant_(self.scope.weight, val=0)

        self.register_buffer("init_pos", self._init_pos())

    def _init_pos(self) -> Tensor:
        """Static sub-pixel offsets of the scale x scale output grid,
        repeated per group, shaped (1, 2*groups*scale^2, 1, 1)."""
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return (
            torch.stack(torch.meshgrid([h, h], indexing="ij"))
            .transpose(1, 2)
            .repeat(1, self.groups, 1)
            .reshape(1, -1, 1, 1)
        )

    def forward(self, x: Tensor) -> Tensor:
        # Predicted offsets, modulated by a sigmoid "scope" gate, around the
        # static sub-pixel positions.
        offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        # Pixel-center coordinates of the input grid.
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5

        coords = (
            torch.stack(torch.meshgrid([coords_w, coords_h], indexing="ij"))
            .transpose(1, 2)
            .unsqueeze(1)
            .unsqueeze(0)
            .type(x.dtype)
            .to(x.device, non_blocking=True)
        )
        # FIX: removed pin_memory=True here. Pinned memory is valid only for
        # CPU tensors; torch.tensor(..., device=cuda, pin_memory=True) raises
        # a RuntimeError, and pinning a 2-element tensor has no benefit.
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(
            1, 2, 1, 1, 1
        )
        # Normalize sampling positions to grid_sample's [-1, 1] range.
        coords = 2 * (coords + offset) / normalizer - 1

        coords = (
            F.pixel_shuffle(coords.reshape(B, -1, H, W), self.scale)
            .view(B, 2, -1, self.scale * H, self.scale * W)
            .permute(0, 2, 3, 4, 1)
            .contiguous()
            .flatten(0, 1)
        )
        output = F.grid_sample(
            x.reshape(B * self.groups, -1, H, W),
            coords,
            mode="bilinear",
            align_corners=False,
            padding_mode="border",
        ).view(B, -1, self.scale * H, self.scale * W)

        if self.end_convolution:
            output = self.end_conv(output)

        return output
132
+
133
+
134
class LayerNorm(nn.Module):
    """LayerNorm over the channel dimension of NCHW feature maps.

    Takes the fused F.layer_norm fast path when the input is stored in
    channels_last memory format; otherwise normalizes dim 1 manually.
    """

    def __init__(self, dim: int = 64, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps
        self.dim = (dim,)

    def forward(self, x):
        if x.is_contiguous(memory_format=torch.channels_last):
            # NHWC view lets the fused kernel normalize the last axis.
            nhwc = x.permute(0, 2, 3, 1)
            normed = F.layer_norm(nhwc, self.dim, self.weight, self.bias, self.eps)
            return normed.permute(0, 3, 1, 2)
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
151
+
152
+
153
class LDA_AQU(nn.Module):
    """Local Deformable Attention upsampler (LDA-AQU).

    Each output pixel attends over a k_u x k_u local neighborhood whose
    sampling locations are shifted by learned, content-dependent offsets.
    Queries come from a bilinearly upsampled projection of the input;
    keys/values are gathered with grid_sample at the deformed positions.
    """

    def __init__(
        self,
        in_channels=48,
        reduction_factor=4,
        nh=1,
        scale_factor=2.0,
        k_e=3,
        k_u=3,
        n_groups=2,
        range_factor=11,
        rpb=True,
    ) -> None:
        super().__init__()
        self.k_u = k_u
        self.num_head = nh
        self.scale_factor = scale_factor
        self.n_groups = n_groups
        self.offset_range_factor = range_factor

        self.attn_dim = in_channels // (reduction_factor * self.num_head)
        self.scale = self.attn_dim**-0.5  # attention logit scaling, 1/sqrt(d)
        self.rpb = rpb  # whether keys get a learned relative position bias
        self.hidden_dim = in_channels // reduction_factor
        self.proj_q = nn.Conv2d(
            in_channels, self.hidden_dim, kernel_size=1, stride=1, padding=0, bias=False
        )

        self.proj_k = nn.Conv2d(
            in_channels, self.hidden_dim, kernel_size=1, stride=1, padding=0, bias=False
        )

        self.group_channel = in_channels // (reduction_factor * self.n_groups)
        # Offset head: depthwise 3x3 -> LayerNorm -> SiLU -> k_e x k_e conv
        # predicting 2 (y, x) offsets for each of the k_u^2 neighbors.
        self.conv_offset = nn.Sequential(
            nn.Conv2d(
                self.group_channel,
                self.group_channel,
                3,
                1,
                1,
                groups=self.group_channel,
                bias=False,
            ),
            LayerNorm(self.group_channel),
            nn.SiLU(),
            nn.Conv2d(self.group_channel, 2 * k_u**2, k_e, 1, k_e // 2),
        )
        # FIX: removed leftover debug `print(2 * k_u**2)`.
        self.layer_norm = LayerNorm(in_channels)

        # Base (undeformed) k_u x k_u neighborhood offsets, (y, x) interleaved.
        self.pad = int((self.k_u - 1) / 2)
        base = np.arange(-self.pad, self.pad + 1).astype(np.float32)
        base_y = np.repeat(base, self.k_u)
        base_x = np.tile(base, self.k_u)
        base_offset = np.stack([base_y, base_x], axis=1).flatten()
        base_offset = torch.tensor(base_offset).view(1, -1, 1, 1)
        self.register_buffer("base_offset", base_offset, persistent=False)

        if self.rpb:
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    1, self.num_head, 1, self.k_u**2, self.hidden_dim // self.num_head
                )
            )
            nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def init_weights(self) -> None:
        """Re-initialize convs (Xavier) and zero the offset predictor."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # FIX: was `nn.init.xavier_uniform(m)`, which passed the
                # module instead of its weight tensor (raising at call time)
                # and used the deprecated non-inplace name.
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        # Zero-init offsets so sampling starts at the undeformed grid.
        nn.init.constant_(self.conv_offset[-1].weight, 0)
        nn.init.constant_(self.conv_offset[-1].bias, 0)

    def get_offset(self, offset, Hout, Wout):
        """Turn per-pixel (y, x) offsets into a normalized grid_sample grid
        of shape (B, k_u*Hout, k_u*Wout, 2) in [-1, 1], (x, y) order."""
        B, _, _, _ = offset.shape
        device = offset.device
        row_indices = torch.arange(Hout, device=device)
        col_indices = torch.arange(Wout, device=device)
        # indexing="ij" is the historical default; stated explicitly to keep
        # behavior identical while silencing the deprecation warning.
        row_indices, col_indices = torch.meshgrid(
            row_indices, col_indices, indexing="ij"
        )
        index_tensor = torch.stack((row_indices, col_indices), dim=-1).view(
            1, Hout, Wout, 2
        )
        offset = rearrange(
            offset, "b (kh kw d) h w -> b kh h kw w d", kh=self.k_u, kw=self.k_u
        )
        # Add absolute output-grid positions to the relative offsets.
        offset = offset + index_tensor.view(1, 1, Hout, 1, Wout, 2)
        offset = offset.contiguous().view(B, self.k_u * Hout, self.k_u * Wout, 2)

        # Normalize to [-1, 1] and flip (y, x) -> (x, y) for grid_sample.
        offset[..., 0] = 2 * offset[..., 0] / (Hout - 1) - 1
        offset[..., 1] = 2 * offset[..., 1] / (Wout - 1) - 1
        offset = offset.flip(-1)
        return offset

    def extract_feats(self, x, offset, ks=3):
        """Sample ks*ks neighborhood features at the deformed locations,
        returning (b, ks*ks, c, h, w)."""
        out = nn.functional.grid_sample(
            x, offset, mode="bilinear", padding_mode="zeros", align_corners=True
        )
        out = rearrange(out, "b c (ksh h) (ksw w) -> b (ksh ksw) c h w", ksh=ks, ksw=ks)
        return out

    def forward(self, x):
        B, C, H, W = x.shape
        out_H, out_W = int(H * self.scale_factor), int(W * self.scale_factor)
        v = x  # values are the raw (pre-norm) features
        x = self.layer_norm(x)
        q = self.proj_q(x)
        k = self.proj_k(x)

        # Queries live on the output-resolution grid.
        q = torch.nn.functional.interpolate(
            q, (out_H, out_W), mode="bilinear", align_corners=True
        )
        q_off = q.view(B * self.n_groups, -1, out_H, out_W)
        pred_offset = self.conv_offset(q_off)
        # tanh bounds learned offsets to +/- offset_range_factor pixels.
        offset = pred_offset.tanh().mul(self.offset_range_factor) + self.base_offset.to(
            x.dtype
        )

        k = k.view(B * self.n_groups, self.hidden_dim // self.n_groups, H, W)
        v = v.view(B * self.n_groups, C // self.n_groups, H, W)
        offset = self.get_offset(offset, out_H, out_W)
        # FIX: pass ks=self.k_u; the grid is built with k_u but extract_feats
        # previously always used its hard-coded default of 3, which broke any
        # configuration with k_u != 3. Identical behavior for k_u == 3.
        k = self.extract_feats(k, offset=offset, ks=self.k_u)
        v = self.extract_feats(v, offset=offset, ks=self.k_u)

        q = rearrange(q, "b (nh c) h w -> b nh (h w) () c", nh=self.num_head)
        k = rearrange(k, "(b g) n c h w -> b (h w) n (g c)", g=self.n_groups)
        v = rearrange(v, "(b g) n c h w -> b (h w) n (g c)", g=self.n_groups)
        k = rearrange(k, "b n1 n (nh c) -> b nh n1 n c", nh=self.num_head)
        v = rearrange(v, "b n1 n (nh c) -> b nh n1 n c", nh=self.num_head)

        if self.rpb:
            k = k + self.relative_position_bias_table

        q = q * self.scale
        attn = q @ k.transpose(-1, -2)
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = rearrange(out, "b nh (h w) t c -> b (nh c) (t h) w", h=out_H)
        return out
296
+
297
+
298
class PA(nn.Module):
    """Pixel attention: gates the input by a sigmoid of a 1x1 conv of itself."""

    def __init__(self, dim) -> None:
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(dim, dim, 1), nn.Sigmoid())

    def forward(self, x):
        gate = self.conv(x)
        return x * gate
305
+
306
+
307
class UniUpsampleV3(nn.Sequential):
    """Unified upsampling head.

    Builds one of the SampleMods variants as a Sequential pipeline taking
    in_dim channels to out_dim channels at the given integer scale. A uint8
    ``MetaUpsample`` buffer records the configuration inside the state dict
    so external loaders/converters can identify how the head was built.

    NOTE(review): the metadata is stored as uint8, so in_dim/mid_dim values
    above 255 would wrap — confirm against the loaders that consume it.
    """

    def __init__(
        self,
        upsample: SampleMods = "pa_up",
        scale: int = 2,
        in_dim: int = 48,
        out_dim: int = 3,
        mid_dim: int = 48,
        group: int = 4,  # Only DySample
        dysample_end_kernel=1,  # needed only for compatibility with version 2
    ) -> None:
        # Layers are collected in `m` and handed to nn.Sequential at the end.
        m = []

        if scale == 1 or upsample == "conv":
            # No spatial upsampling: plain 3x3 projection.
            m.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1))
        elif upsample == "pixelshuffledirect":
            # Single conv straight to out_dim * scale^2, then one shuffle.
            m.extend(
                [nn.Conv2d(in_dim, out_dim * scale**2, 3, 1, 1), nn.PixelShuffle(scale)]
            )
        elif upsample == "pixelshuffle":
            # Project to mid_dim, then shuffle up in x2 (or one x3) stages.
            m.extend([nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)])
            if (scale & (scale - 1)) == 0:  # scale = 2^n
                for _ in range(int(math.log2(scale))):
                    m.extend(
                        [nn.Conv2d(mid_dim, 4 * mid_dim, 3, 1, 1), nn.PixelShuffle(2)]
                    )
            elif scale == 3:
                m.extend([nn.Conv2d(mid_dim, 9 * mid_dim, 3, 1, 1), nn.PixelShuffle(3)])
            else:
                raise ValueError(
                    f"scale {scale} is not supported. Supported scales: 2^n and 3."
                )
            m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1))
        elif upsample == "nearest+conv":
            # Nearest-neighbor upsampling interleaved with convs.
            if (scale & (scale - 1)) == 0:
                for _ in range(int(math.log2(scale))):
                    m.extend(
                        (
                            nn.Conv2d(in_dim, in_dim, 3, 1, 1),
                            nn.Upsample(scale_factor=2),
                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
                        )
                    )
                m.extend(
                    (
                        nn.Conv2d(in_dim, in_dim, 3, 1, 1),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    )
                )
            elif scale == 3:
                m.extend(
                    (
                        nn.Conv2d(in_dim, in_dim, 3, 1, 1),
                        nn.Upsample(scale_factor=scale),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                        nn.Conv2d(in_dim, in_dim, 3, 1, 1),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    )
                )
            else:
                raise ValueError(
                    f"scale {scale} is not supported. Supported scales: 2^n and 3."
                )
            m.append(nn.Conv2d(in_dim, out_dim, 3, 1, 1))
        elif upsample == "dysample":
            # Optional projection so DySample runs at mid_dim width.
            if mid_dim != in_dim:
                m.extend(
                    [nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)]
                )
            m.append(
                DySample(mid_dim, out_dim, scale, group, end_kernel=dysample_end_kernel)
            )
            # m.append(nn.Conv2d(mid_dim, out_dim, dysample_end_kernel, 1, dysample_end_kernel//2)) # kernel 1 causes chromatic artifacts
        elif upsample == "transpose+conv":
            # Transposed convolutions; only scales 2, 3, 4 are expressible.
            if scale == 2:
                m.append(nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1))
            elif scale == 3:
                m.append(nn.ConvTranspose2d(in_dim, out_dim, 3, 3, 0))
            elif scale == 4:
                m.extend(
                    [
                        nn.ConvTranspose2d(in_dim, in_dim, 4, 2, 1),
                        nn.GELU(),
                        nn.ConvTranspose2d(in_dim, out_dim, 4, 2, 1),
                    ]
                )
            else:
                raise ValueError(
                    f"scale {scale} is not supported. Supported scales: 2, 3, 4"
                )
            # Final 3x3 smoothing conv at the output width.
            m.append(nn.Conv2d(out_dim, out_dim, 3, 1, 1))
        elif upsample == "lda":
            if mid_dim != in_dim:
                m.extend(
                    [nn.Conv2d(in_dim, mid_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)]
                )
            m.append(LDA_AQU(mid_dim, scale_factor=scale))
            m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1))
        elif upsample == "pa_up":
            # Upsample + conv + pixel-attention stages.
            if (scale & (scale - 1)) == 0:
                for _ in range(int(math.log2(scale))):
                    m.extend(
                        [
                            nn.Upsample(scale_factor=2),
                            nn.Conv2d(in_dim, mid_dim, 3, 1, 1),
                            PA(mid_dim),
                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
                            nn.Conv2d(mid_dim, mid_dim, 3, 1, 1),
                            nn.LeakyReLU(negative_slope=0.2, inplace=True),
                        ]
                    )
                    # After the first stage, subsequent stages read mid_dim.
                    in_dim = mid_dim
            elif scale == 3:
                m.extend(
                    [
                        nn.Upsample(scale_factor=3),
                        nn.Conv2d(in_dim, mid_dim, 3, 1, 1),
                        PA(mid_dim),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                        nn.Conv2d(mid_dim, mid_dim, 3, 1, 1),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    ]
                )
            else:
                raise ValueError(
                    f"scale {scale} is not supported. Supported scales: 2^n and 3."
                )
            m.append(nn.Conv2d(mid_dim, out_dim, 3, 1, 1))
        else:
            raise ValueError(
                f"An invalid Upsample was selected. Please choose one of {SampleMods}"
            )
        super().__init__(*m)

        # Configuration metadata carried in the state dict (not a weight).
        self.register_buffer(
            "MetaUpsample",
            torch.tensor(
                [
                    3,  # Block version, if you change something, please number from the end so that you can distinguish between authorized changes and third parties
                    list(SampleMods.__args__).index(upsample),  # UpSample method index
                    scale,
                    in_dim,
                    out_dim,
                    mid_dim,
                    group,
                ],
                dtype=torch.uint8,
            ),
        )
456
+
457
+
458
class RMSNorm(nn.Module):
    """RMS normalization over the channel dim of NCHW tensors.

    ``eps`` and the 1/sqrt(dim) factor are stored as frozen parameters so
    they travel with the state dict and device moves.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.offset = nn.Parameter(torch.zeros(dim))
        self.eps = nn.Parameter(torch.Tensor(torch.ones(1) * eps), requires_grad=False)
        self.rms = nn.Parameter(
            torch.Tensor(torch.ones(1) * (dim**-0.5)), requires_grad=False
        )

    def forward(self, x: Tensor) -> Tensor:
        # Denominator: eps + ||x||_2 * dim**-0.5 per spatial position.
        l2 = x.norm(2, dim=1, keepdim=True)
        norm_x = self.eps + l2 * self.rms
        normalized = x / norm_x
        # Affine: offset + normalized * scale, broadcast over H and W.
        return self.offset[:, None, None] + normalized * self.scale[:, None, None]
473
+
474
+
475
class CustomRFFT2(torch.autograd.Function):
    """Forward-only rfft2 with a custom ONNX ``symbolic`` lowering.

    forward: orthonormal 2-D real FFT over dims (2, 3), returned as a real
    tensor with a trailing size-2 (real, imag) axis.
    symbolic: equivalent graph built from ONNX ``DFT`` ops — a onesided DFT
    along W followed by a full DFT along H, each divided by sqrt(N) to
    reproduce norm="ortho". No backward is defined (inference/export only).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        y = torch.fft.rfft2(x, dim=(2, 3), norm="ortho")
        # view_as_real appends a size-2 axis holding (real, imag).
        return torch.view_as_real(y)

    @staticmethod
    def symbolic(g, x: torch.Value):
        # H and W are taken dynamically from the input shape.
        shp = g.op("Shape", x)
        iH = g.op("Constant", value_t=torch.tensor([2], dtype=torch.int64))
        iW = g.op("Constant", value_t=torch.tensor([3], dtype=torch.int64))
        nH = g.op("Gather", shp, iH, axis_i=0)
        nW = g.op("Gather", shp, iW, axis_i=0)

        # Append an explicit zero imaginary channel: [..., (re, im=0)].
        axes_last = g.op("Constant", value_t=torch.tensor([4], dtype=torch.int64))
        x_u = g.op("Unsqueeze", x, axes_last)
        zero = g.op("Sub", x_u, x_u)  # zeros with x's dtype and shape
        x_c = g.op("Concat", x_u, zero, axis_i=4)

        Hf = g.op("Cast", nH, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        Wf = g.op("Cast", nW, to_i=torch.onnx.TensorProtoDataType.FLOAT)

        # Onesided DFT along W, scaled by 1/sqrt(W) for "ortho".
        y = g.op("DFT", x_c, nW, axis_i=3, onesided_i=1)
        y = g.op("Div", y, g.op("Sqrt", Wf))

        # Full DFT along H, scaled by 1/sqrt(H).
        y = g.op("DFT", y, nH, axis_i=2, onesided_i=0)
        y = g.op("Div", y, g.op("Sqrt", Hf))

        return y
504
+
505
+
506
class CustomIRFFT2(torch.autograd.Function):
    """Forward-only irfft2 with a custom ONNX ``symbolic`` lowering.

    forward: inverse of CustomRFFT2 — interprets the trailing size-2 axis
    as (real, imag) and applies orthonormal irfft2 over dims (2, 3).
    symbolic: inverse DFT along H, Hermitian mirroring to rebuild the full
    W spectrum from the onesided half, inverse DFT along W, then the real
    channel is sliced out. No backward is defined (inference/export only).
    """

    @staticmethod
    def forward(ctx, x_ri: torch.Tensor):
        x_c = torch.view_as_complex(x_ri)
        return torch.fft.irfft2(x_c, dim=(2, 3), norm="ortho")

    @staticmethod
    def symbolic(g, x: torch.Value):
        # nWr is the onesided width (W//2 + 1); full width W = (nWr - 1) * 2.
        shp = g.op("Shape", x)
        iH = g.op("Constant", value_t=torch.tensor([2], dtype=torch.int64))
        iWr = g.op("Constant", value_t=torch.tensor([3], dtype=torch.int64))
        nH = g.op("Gather", shp, iH, axis_i=0)
        nWr = g.op("Gather", shp, iWr, axis_i=0)

        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))
        nW = g.op("Mul", g.op("Sub", nWr, one), two)
        Hf = g.op("Cast", nH, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        Wf = g.op("Cast", nW, to_i=torch.onnx.TensorProtoDataType.FLOAT)

        # Inverse DFT along H; multiply by sqrt(H) to undo "ortho" scaling.
        yH = g.op("DFT", x, nH, axis_i=2, inverse_i=1, onesided_i=0)
        yH = g.op("Mul", yH, g.op("Sqrt", Hf))

        # Indices nWr-2 .. 1 select the interior bins to mirror.
        start = g.op("Sub", nWr, two)
        start = g.op(
            "Squeeze",
            start,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )
        limit = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        step = g.op("Constant", value_t=torch.tensor(-1, dtype=torch.int64))
        idx_r = g.op("Range", start, limit, step)

        # Hermitian symmetry: mirrored bins are complex conjugates, so the
        # imaginary part is negated via the (1, -1) mask on the last axis.
        mirW = g.op("Gather", yH, idx_r, axis_i=3)
        maskW = g.op("Constant", value_t=torch.tensor([1.0, -1.0], dtype=torch.float32))
        maskW = g.op(
            "Unsqueeze",
            maskW,
            g.op("Constant", value_t=torch.tensor([0, 1, 2, 3], dtype=torch.int64)),
        )
        mirWc = g.op("Mul", mirW, maskW)
        x_full = g.op("Concat", yH, mirWc, axis_i=3)

        # Inverse DFT along the reconstructed full W axis.
        y = g.op("DFT", x_full, nW, axis_i=3, inverse_i=1, onesided_i=0)
        y = g.op("Mul", y, g.op("Sqrt", Wf))

        # Keep only the real component and drop the trailing axis.
        s0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
        s1 = g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
        axC = g.op("Constant", value_t=torch.tensor([4], dtype=torch.int64))
        y = g.op("Slice", y, s0, s1, axC)
        y = g.op("Squeeze", y, axC)

        return y
559
+
560
+
561
class CustomRfft2Wrap(nn.Module):
    """rfft2 wrapper: native torch.fft in training mode, the
    ONNX-exportable CustomRFFT2 autograd function otherwise."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        if not self.training:
            return CustomRFFT2().apply(x)
        spectrum = torch.fft.rfft2(x, dim=(2, 3), norm="ortho")
        return torch.view_as_real(spectrum)
571
+
572
+
573
class CustomIrfft2Wrap(nn.Module):
    """irfft2 wrapper: native torch.fft in training mode, the
    ONNX-exportable CustomIRFFT2 autograd function otherwise."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        if not self.training:
            return CustomIRFFT2().apply(x)
        # Trailing size-2 axis is (real, imag): [B,C,H,Wr,2] -> [B,C,H,W]
        spectrum = torch.view_as_complex(x)
        return torch.fft.irfft2(spectrum, dim=(2, 3), norm="ortho")
583
+
584
+
585
class FourierUnit(nn.Module):
    """Spectral unit: rfft2 -> convs on stacked (real, imag) channels -> irfft2.

    The FFTs always run in float32; the frequency-domain convs run in the
    caller's dtype (presumably to stay safe under autocast/half precision —
    not verified here).

    NOTE(review): self.rn is sized out_channels * 2 but is applied to a
    tensor with in_channels * 2 channels — only consistent when
    in_channels == out_channels (the only usage in this file).
    """

    def __init__(self, in_channels: int = 48, out_channels: int = 48) -> None:
        super().__init__()
        self.rn = RMSNorm(out_channels * 2)  # norm over stacked re/im channels
        self.post_norm = RMSNorm(out_channels)  # norm back in the spatial domain

        # 1x1 conv mixing frequency-domain channels.
        self.fdc = nn.Conv2d(
            in_channels=in_channels * 2,
            out_channels=out_channels * 2,
            kernel_size=1,
            bias=True,
        )

        # Depthwise 3x3 conv acting as a positional encoding in frequency space.
        self.fpe = nn.Conv2d(
            in_channels=in_channels * 2,
            out_channels=in_channels * 2,
            kernel_size=3,
            padding=1,
            groups=in_channels * 2,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.irfft2 = CustomIrfft2Wrap()
        self.rfft2 = CustomRfft2Wrap()

    def forward(self, x: Tensor) -> Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)  # FFT computed in float32
        b, c, h, w = x.shape
        ffted = self.rfft2(x)  # (b, c, h, w//2+1, 2) — trailing axis is (re, im)
        # Move the (re, im) axis next to channels and fold: (b, 2*c, h, wr).
        ffted = ffted.permute(0, 4, 1, 2, 3).contiguous()
        ffted = ffted.view(b, c * 2, h, -1).to(orig_dtype)
        ffted = self.rn(ffted)
        ffted = self.fpe(ffted) + ffted  # residual depthwise conv
        ffted = self.fdc(ffted)
        ffted = self.gelu(ffted)
        # Unfold back to (b, c, h, wr, 2) for the inverse FFT, in float32.
        ffted = ffted.view(b, c, 2, h, -1).permute(0, 1, 3, 4, 2).contiguous().float()
        out = self.irfft2(ffted)
        out = self.post_norm(out.to(orig_dtype))
        return out
625
+
626
+
627
class InceptionConv2d(nn.Module):
    """Inception convolution: four parallel branches — a Fourier unit, a
    square conv, a horizontal band conv, and a vertical band conv — each
    applied to its own channel slice."""

    def __init__(
        self,
        fu_dim: int = 24,
        gc: int = 8,
        square_kernel_size: int = 13,
        band_kernel_size: int = 17,
    ) -> None:
        super().__init__()
        square_pad = square_kernel_size // 2
        band_pad = band_kernel_size // 2

        self.fu = FourierUnit(fu_dim, fu_dim)
        self.convhw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_pad)
        self.convw = nn.Conv2d(
            gc,
            gc,
            kernel_size=(1, band_kernel_size),
            padding=(0, band_pad),
        )
        self.convh = nn.Conv2d(
            gc,
            gc,
            kernel_size=(band_kernel_size, 1),
            padding=(band_pad, 0),
        )

    def forward(
        self, x: Tensor, x_hw: Tensor, x_w: Tensor, xh: Tensor
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        fu_out = self.fu(x)
        hw_out = self.convhw(x_hw)
        w_out = self.convw(x_w)
        h_out = self.convh(xh)
        return fu_out, hw_out, w_out, h_out
660
+
661
+
662
class GatedCNNBlock(nn.Module):
    """Gated CNN block with inception-style branch convolutions.

    fc1 expands to 2*hidden channels split into a gate and a value part;
    only ``dim`` of the value channels go through InceptionConv2d, the
    remaining ``hidden - dim`` pass through untouched.
    """

    def __init__(
        self,
        dim: int = 64,
        expansion_ratio: float = 8 / 3,
        gc: int = 8,
        square_kernel_size: int = 13,
        band_kernel_size: int = 17,
    ) -> None:
        super().__init__()
        # Hidden width, rounded down to a multiple of 8.
        hidden = int(expansion_ratio * dim) // 8 * 8
        self.norm = RMSNorm(dim)
        self.fc1 = nn.Conv2d(dim, hidden * 2, 3, 1, 1)
        self.act = nn.SiLU()
        # Split of fc1's 2*hidden output channels:
        #   gate (hidden) + identity (hidden - dim) + fourier (dim - 3*gc)
        #   + three conv branches (gc each) = 2*hidden.
        # Implicitly requires hidden >= dim and dim >= 3*gc.
        self.split_indices = [hidden, hidden - dim, dim - gc * 3, gc, gc, gc]
        self.conv = InceptionConv2d(
            dim - gc * 3, gc, square_kernel_size, band_kernel_size
        )
        self.fc2 = nn.Conv2d(hidden, dim, 3, 1, 1)

    def gated_forward(self, x: Tensor) -> Tensor:
        """Norm -> expand -> split -> branch convs -> gate -> project."""
        x = self.norm(x)
        x = self.fc1(x)
        g, i, c, c_hw, c_w, c_h = torch.split(x, self.split_indices, dim=1)
        c, c_hw, c_w, c_h = self.conv(c, c_hw, c_w, c_h)
        # SiLU(gate) elementwise-multiplies the recombined value channels.
        x = self.fc2(self.act(g) * torch.cat((i, c, c_hw, c_w, c_h), dim=1))
        return x

    def forward(self, x: Tensor) -> Tensor:
        # Residual connection around the gated body.
        return self.gated_forward(x) + x
692
+
693
+
694
+ # @ARCH_REGISTRY.register()
695
class FIGSR(nn.Module):
    """Fourier Inception Gated Super Resolution.

    Pipeline: learnable input normalization -> reflect pad to a multiple of
    2 -> shallow conv -> two halves of GatedCNNBlock stacks -> 1x1 fusion of
    [deep, shallow, mid] features -> UniUpsampleV3 head -> crop -> denorm.
    """

    def __init__(
        self,
        in_nc: int = 3,
        dim: int = 48,
        expansion_ratio: float = 8 / 3,
        scale: int = 4,
        # neosr style:
        # scale=upscale
        out_nc: int = 3,
        upsampler: SampleMods = "pixelshuffledirect",
        mid_dim: int = 32,
        n_blocks: int = 24,
        gc: int = 8,
        square_kernel_size: int = 13,
        band_kernel_size: int = 17,
        **kwargs,
    ) -> None:
        super().__init__()
        self.in_to_dim = nn.Conv2d(in_nc, dim, 3, 1, 1)
        # Input H/W are reflect-padded to a multiple of this in forward().
        self.pad = 2
        # First half of the residual body.
        self.gfisr_body_half = nn.Sequential(
            *[
                GatedCNNBlock(
                    dim, expansion_ratio, gc, square_kernel_size, band_kernel_size
                )
                for _ in range(n_blocks // 2)
            ]
        )
        # Second half, followed by a 3x3 fusion conv.
        self.gfisr_body_half_2 = nn.Sequential(
            *[
                GatedCNNBlock(
                    dim, expansion_ratio, gc, square_kernel_size, band_kernel_size
                )
                for _ in range(n_blocks - n_blocks // 2)
            ]
            + [nn.Conv2d(dim, dim, 3, 1, 1)]
        )
        # Fuses [deep x1, shallow x, mid x0] features back to dim channels.
        self.cat_to_dim = nn.Conv2d(dim * 3, dim, 1)
        self.upscale = UniUpsampleV3(
            upsampler, scale, dim, out_nc, mid_dim, dysample_end_kernel=3
        )
        if upsampler == "pixelshuffledirect":
            # ICNR init of the sub-pixel conv to reduce checkerboard artifacts.
            weight = ICNR(
                self.upscale[0].weight,
                initializer=nn.init.kaiming_normal_,
                upscale_factor=scale,
            )
            self.upscale[0].weight.data.copy_(weight)

        self.scale = scale
        # Learnable input normalization: x -> (x - shift) / scale_norm,
        # inverted at the output.
        self.shift = nn.Parameter(torch.ones(1, 3, 1, 1) * 0.5, requires_grad=True)
        self.scale_norm = nn.Parameter(torch.ones(1, 3, 1, 1) / 6, requires_grad=True)

    def load_state_dict(self, state_dict, strict=True, assign=True):
        # Overwrite the incoming metadata buffer with this instance's own so
        # checkpoints whose MetaUpsample differs (e.g. older head versions)
        # still load. NOTE(review): `assign` defaults to True here, unlike
        # nn.Module's default of False — confirm this is intentional.
        state_dict["upscale.MetaUpsample"] = self.upscale.MetaUpsample
        return super().load_state_dict(state_dict, strict, assign)

    def forward(self, x: Tensor) -> Tensor:
        x = (x - self.shift) / self.scale_norm

        # Reflect-pad H and W up to a multiple of self.pad.
        _, _, H, W = x.shape
        mod_pad_h = (self.pad - H % self.pad) % self.pad
        mod_pad_w = (self.pad - W % self.pad) % self.pad
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")

        x = self.in_to_dim(x)
        x0 = self.gfisr_body_half(x)
        x1 = self.gfisr_body_half_2(x0)

        x = self.cat_to_dim(torch.cat([x1, x, x0], dim=1))
        # Upscale, then crop away the padding's contribution.
        x = self.upscale(x)[:, :, : H * self.scale, : W * self.scale]
        return x * self.scale_norm + self.shift
inference.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import torch
4
+
5
+ from resselt import load_from_file
6
+ from pepeline import read, save, ImgColor, ImgFormat
7
+
8
+
9
def parse_args():
    """Parse command-line arguments for the batch upscaling script."""
    parser = argparse.ArgumentParser(description="Batch image upscaling script")
    required_opts = (
        ("--input_dir", "Path to input images"),
        ("--output_dir", "Path to save results"),
        ("--weights", "Path to model weights"),
    )
    for flag, help_text in required_opts:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
    return parser.parse_args()
18
+
19
+
20
def load_model(weights_path: str, device: torch.device):
    """Load a model from `weights_path`, move it to `device`, return it in eval mode."""
    model = load_from_file(weights_path).to(
        device,
        memory_format=torch.preserve_format,
        non_blocking=True,
    )
    return model.eval()
28
+
29
+
30
def process_image(model, img_path: str, device: torch.device):
    """Upscale one image; returns the result as an HWC float numpy array."""
    # Read as RGB float32 and reorder HWC -> CHW.
    chw = read(img_path, ImgColor.RGB, ImgFormat.F32).transpose(2, 0, 1)
    batch = (
        torch.tensor(chw)
        .to(
            device,
            memory_format=torch.preserve_format,
            non_blocking=True,
        )
        .unsqueeze(0)
    )

    # fp16 autocast + inference mode for speed; gradients are never needed.
    with torch.autocast(device.type, torch.float16), torch.inference_mode():
        output = model(batch)

    # NCHW -> HWC, drop the batch dim, move to host memory.
    return output.permute(0, 2, 3, 1).detach().cpu().numpy()[0]
48
+
49
+
50
def main():
    """Upscale every image in --input_dir and write results to --output_dir."""
    args = parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs(args.output_dir, exist_ok=True)

    model = load_model(args.weights, device)

    names = os.listdir(args.input_dir)
    total = len(names)

    for index, img_name in enumerate(names, start=1):
        # Single-line progress indicator (carriage return, no newline).
        print(
            f"\rProcessing {index}/{total} | {img_name}",
            end="",
            flush=True,
        )

        src_path = os.path.join(args.input_dir, img_name)
        result = process_image(model, src_path, device)

        save(result.copy(), os.path.join(args.output_dir, img_name))

    print("\nDone.")
77
+
78
+
79
# Script entry point.
if __name__ == "__main__":
    main()
weights/4x_FIGSR.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86a88ec488726cb6b95642dc846cc7aeff583ed93bc5036d56ff08fd5ac9fb1f
3
+ size 18504754
weights/4x_FIGSR.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d48ee05490c6b63f043cc4b5c7aa546b0661016a7d61fc0deec14c3019e0e5c1
3
+ size 17913930
weights/4x_FIGSR.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98b239859ff3ca726ddc570eb0fb2c1a2f618b86c2d065a2f8cc9eee8e289c82
3
+ size 17763135