Spaces: Runtime error

krakotay committed
Commit · f2b4d53
Parent(s): a304ddb
first commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitattributes +4 -0
- .gitignore +9 -0
- LICENSE +201 -0
- app.py +277 -4
- assets/demo.gif +3 -0
- assets/metrics.png +0 -0
- assets/network.png +0 -0
- assets/title_any_image.gif +0 -0
- assets/title_harmon.gif +0 -0
- assets/title_you_want.gif +0 -0
- assets/visualizations.png +0 -0
- assets/visualizations2.png +3 -0
- datasets/__init__.py +0 -0
- datasets/build_INR_dataset.py +36 -0
- datasets/build_dataset.py +371 -0
- demo/demo_1k_composite_2.jpg +0 -0
- demo/demo_1k_composite_3.jpg +0 -0
- demo/demo_1k_mask_2.jpg +0 -0
- demo/demo_1k_mask_3.jpg +0 -0
- demo/demo_composite.jpg +0 -0
- demo/demo_composite_1.jpg +0 -0
- demo/demo_composite_2.jpg +0 -0
- demo/demo_composite_3.jpg +0 -0
- demo/demo_composite_4.jpg +0 -0
- demo/demo_composite_5.jpg +0 -0
- demo/demo_composite_6.jpg +0 -0
- demo/demo_mask.png +0 -0
- demo/demo_mask_1.png +0 -0
- demo/demo_mask_2.png +0 -0
- demo/demo_mask_3.png +0 -0
- demo/demo_mask_4.jpg +0 -0
- demo/demo_mask_5.jpg +0 -0
- demo/demo_mask_6.jpg +0 -0
- efficient_inference_for_square_image.py +356 -0
- hrnet_ocr.py +401 -0
- inference.py +236 -0
- inference_for_arbitrary_resolution_image.py +345 -0
- model/__init__.py +0 -0
- model/backbone.py +79 -0
- model/base/__init__.py +0 -0
- model/base/basic_blocks.py +366 -0
- model/base/conv_autoencoder.py +519 -0
- model/base/ih_model.py +88 -0
- model/base/ops.py +397 -0
- model/build_model.py +24 -0
- model/hrnetv2/__init__.py +0 -0
- model/hrnetv2/hrnet_ocr.py +405 -0
- model/hrnetv2/modifiers.py +11 -0
- model/hrnetv2/ocr.py +140 -0
- model/hrnetv2/resnetv1b.py +276 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/demo.gif filter=lfs diff=lfs merge=lfs -text
+assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text
+demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text
+demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,9 @@
+.idea/*
+logs/*
+wandb/*
+system/
+*.bat
+*.7z
+.venv
+__pycache__/
+*.pyc
LICENSE
ADDED
@@ -0,0 +1,201 @@
+Apache License, Version 2.0, January 2004
+http://www.apache.org/licenses/
+(full standard Apache-2.0 license text: 201 lines of unmodified boilerplate, including the
+"Copyright [yyyy] [name of copyright owner]" appendix template)
app.py
CHANGED
@@ -1,7 +1,280 @@
+import os
+
+import cv2
+
 import gradio as gr
-
-    return "Hello " + name + "!!"
-
-demo.launch()
+import numpy as np
+import sys
+import io
+import spaces
+
+class Logger:
+    def __init__(self):
+        self.terminal = sys.stdout
+        self.log = io.BytesIO()
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(bytes(message, encoding='utf-8'))
+
+    def flush(self):
+        self.terminal.flush()
+        self.log.flush()
+
+    def isatty(self):
+        return False
+
+
+log = Logger()
+sys.stdout = log
+
+
+def read_logs():
+    out = log.log.getvalue().decode()
+    if out.count("\n") >= 30:
+        log.log = io.BytesIO()
+    sys.stdout.flush()
+    return out
+
+
+with gr.Blocks() as app:
+
+    valid_checkpoints_dict = {"Resolution_256_iHarmony4": "Resolution_256_iHarmony4.pth",
+                              "Resolution_1024_HAdobe5K": "Resolution_1024_HAdobe5K.pth",
+                              "Resolution_2048_HAdobe5K": "Resolution_2048_HAdobe5K.pth",
+                              "Resolution_RAW_HAdobe5K": "Resolution_RAW_HAdobe5K.pth",
+                              "Resolution_RAW_iHarmony4": "Resolution_RAW_iHarmony4.pth"}
+
+    global_state = gr.State(valid_checkpoints_dict["Resolution_RAW_iHarmony4"])
+    with gr.Row():
+        with gr.Column():
+            form_composite_image = gr.Image(label='Input Composite image', type='pil')
+            gr.Examples(examples=sorted([os.path.join("demo", i) for i in os.listdir("demo") if "composite" in i]),
+                        label="Composite Examples", inputs=form_composite_image, cache_examples=False)
+        with gr.Column():
+            form_mask_image = gr.Image(label='Input Mask image', type='pil', interactive=False)
+            gr.Examples(examples=sorted([os.path.join("demo", i) for i in os.listdir("demo") if "mask" in i]),
+                        label="Mask Examples", inputs=form_mask_image, cache_examples=False)
+    with gr.Row():
+        with gr.Column(scale=4):
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown(value='Model Selection', show_label=False)
+
+                with gr.Column():
+                    form_pretrained_dropdown = gr.Dropdown(
+                        choices=list(valid_checkpoints_dict.values()),
+                        label="Pretrained Model",
+                        value=valid_checkpoints_dict["Resolution_RAW_iHarmony4"],
+                        interactive=True
+                    )
+
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown(value='Inference Mode', show_label=False)
+
+                with gr.Column():
+                    form_inference_mode = gr.Radio(
+                        ['Square Image', 'Arbitrary Image'],
+                        value='Arbitrary Image',
+                        interactive=False,
+                        label='Mode',
+                    )
+
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown(value='Split Parameter', show_label=False)
+
+                with gr.Column():
+                    form_split_res = gr.Slider(
+                        minimum=0,
+                        maximum=2048,
+                        step=128,
+                        value=256,
+                        interactive=True,
+                        label="Split Resolution",
+                    )
+                    form_split_num = gr.Number(
+                        value=2,
+                        interactive=True,
+                        label="Split Number")
+            with gr.Row():
+                form_log = gr.Textbox(read_logs, label="Logs", interactive=False, type="text", every=1)
+
+        with gr.Column(scale=4):
+            form_harmonized_image = gr.Image(label='Harmonized Result', type='numpy', interactive=False, format="png")
+            form_start_btn = gr.Button("Start Harmonization", interactive=False)
+            form_reset_btn = gr.Button("Reset", interactive=True)
+            form_stop_btn = gr.Button("Stop", interactive=True)
+
+
+    def on_change_form_composite_image(form_composite_image):
+        if form_composite_image is None:
+            return gr.update(interactive=False, value=None), gr.update(value=None)
+        return gr.update(interactive=True, value=None), gr.update(value=None)
+
+
+    def on_change_form_mask_image(form_composite_image, form_mask_image):
+        if form_mask_image is None:
+            return gr.update(interactive=False), gr.update(
+                interactive=False if form_composite_image is None else True), gr.update(interactive=False), gr.update(
+                interactive=False), gr.update(interactive=False), gr.update(value=None)
+
+        if form_composite_image.size[:2] != form_mask_image.size[:2]:
+            raise gr.Error("Composite image and mask image should have the same resolution!")
+        else:
+            w, h = form_composite_image.size[:2]
+            if h != w or (h % 16 != 0):
+                return gr.update(value='Arbitrary Image', interactive=False), gr.update(interactive=True), gr.update(
+                    interactive=True), gr.update(interactive=True, visible=True), gr.update(
+                    interactive=False, value=-1, visible=False), gr.update(value=None)
+            else:
+                return gr.update(value='Square Image', interactive=True), gr.update(interactive=True), gr.update(
+                    interactive=True), gr.update(interactive=False, visible=False), gr.update(
+                    interactive=True, value=h // 2, maximum=h, minimum=h // 16,
+                    step=h // 16, visible=True), gr.update(value=None)
+
+
+    form_composite_image.change(
+        on_change_form_composite_image,
+        inputs=[form_composite_image],
+        outputs=[form_mask_image, form_harmonized_image]
+    )
+
+    form_mask_image.change(
+        on_change_form_mask_image,
+        inputs=[form_composite_image, form_mask_image],
+        outputs=[form_inference_mode, form_mask_image, form_start_btn, form_split_num, form_split_res,
+                 form_harmonized_image]
+    )
+
+
+    def on_change_form_split_num(form_composite_image, form_split_num):
+        w, h = form_composite_image.size[:2]
+        if form_split_num < 1:
+            return gr.update(value=1)
+        elif form_split_num > min(w, h):
+            return gr.update(value=min(w, h))
+        else:
+            return gr.update(value=form_split_num)
+
+
+    form_split_num.change(
+        on_change_form_split_num,
+        inputs=[form_composite_image, form_split_num],
+        outputs=[form_split_num]
+    )
+
+
+    def on_change_form_inference_mode(form_inference_mode):
+        if form_inference_mode == "Square Image":
+            return gr.update(interactive=True, visible=True), gr.update(interactive=False, visible=False)
+        else:
+            return gr.update(interactive=False, visible=False), gr.update(interactive=True, visible=True)
+
+
+    form_inference_mode.change(on_change_form_inference_mode, inputs=[form_inference_mode],
+                               outputs=[form_split_res, form_split_num])
+
+    @spaces.GPU
+    def on_click_form_start_btn(form_composite_image, form_mask_image, form_pretrained_dropdown, form_inference_mode,
+                                form_split_res, form_split_num):
+        log.log = io.BytesIO()
+        print(f"Harmonizing image with {form_composite_image.size[1]}*{form_composite_image.size[0]}...")
+        if form_inference_mode == "Square Image":
+            from efficient_inference_for_square_image import parse_args, main_process, global_state
+            global_state[0] = 1
+
+            opt = parse_args()
+            opt.transform_mean = [.5, .5, .5]
+            opt.transform_var = [.5, .5, .5]
+            opt.pretrained = os.path.join("./pretrained_models", form_pretrained_dropdown)
+            opt.split_resolution = form_split_res
+            opt.save_path = None
+            opt.workers = 0
+            opt.device = "gpu"
+
+            composite_image = np.asarray(form_composite_image)
+            mask = np.asarray(form_mask_image)
+
+            try:
+                return cv2.cvtColor(
+                    main_process(opt, composite_image=composite_image, mask=mask),
+                    cv2.COLOR_BGR2RGB)
+            except Exception as e:
+                raise gr.Error(f"Patches too big. Try to reduce the `split_res`!\nException is {e}")
+
+        else:
+            from inference_for_arbitrary_resolution_image import parse_args, main_process, global_state
+            global_state[0] = 1
+
+            opt = parse_args()
+            opt.transform_mean = [.5, .5, .5]
+            opt.transform_var = [.5, .5, .5]
+            opt.pretrained = os.path.join("./pretrained_models", form_pretrained_dropdown)
+            opt.split_num = int(form_split_num)
+            opt.save_path = None
+            opt.workers = 0
+            opt.device = "gpu"
+
+            composite_image = np.asarray(form_composite_image)
+            mask = np.asarray(form_mask_image)
+
+            try:
+                return cv2.cvtColor(
+                    main_process(opt, composite_image=composite_image, mask=mask),
+                    cv2.COLOR_BGR2RGB)
+            except Exception as e:
+                raise gr.Error(f"Patches too big. Try to increase the `split_num`!\nException is {e}")
+
+
+    generate = form_start_btn.click(on_click_form_start_btn,
+                                    inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown,
+                                            form_inference_mode,
+                                            form_split_res, form_split_num], outputs=[form_harmonized_image])
+
+
+    def on_click_form_reset_btn(form_inference_mode):
+        if form_inference_mode == "Square Image":
+            from efficient_inference_for_square_image import global_state
+            global_state[0] = 0
+        else:
+            from inference_for_arbitrary_resolution_image import global_state
+            global_state[0] = 0
+
+        log.log = io.BytesIO()
+        return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(
+            value=None, interactive=False), gr.update(interactive=False)
+
+
+    form_reset_btn.click(on_click_form_reset_btn,
+                         inputs=[form_inference_mode],
+                         outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
+
+
+    def on_click_form_stop(form_inference_mode):
+        if form_inference_mode == "Square Image":
+            from efficient_inference_for_square_image import global_state
+            global_state[0] = 0
+        else:
+            from inference_for_arbitrary_resolution_image import global_state
+            global_state[0] = 0
+
+        log.log = io.BytesIO()
+        return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(
+            value=None, interactive=False), gr.update(interactive=False)
+
+
+    form_stop_btn.click(on_click_form_stop,
+                        inputs=[form_inference_mode],
+                        outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
+
+gr.close_all()
+
+app.queue()
+
+app.launch(show_api=False)
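The app above tees `sys.stdout` into an in-memory buffer so that a `gr.Textbox` whose value is the `read_logs` callable, re-polled with `every=1`, can stream inference progress into the UI. A minimal self-contained sketch of the same pattern, using `io.StringIO` instead of the repo's `BytesIO`-plus-manual-encoding; the `TeeLogger` name and the demo layout are illustrative, not from the repository:

import io
import sys
import gradio as gr

class TeeLogger:
    """Write to the real stdout and to an in-memory buffer the UI can poll."""
    def __init__(self):
        self.terminal = sys.stdout
        self.buffer = io.StringIO()

    def write(self, message):
        self.terminal.write(message)
        self.buffer.write(message)

    def flush(self):
        self.terminal.flush()

    def isatty(self):
        return False  # keep libraries from attempting terminal control codes

log = TeeLogger()
sys.stdout = log

def read_logs():
    # Called by Gradio once per second; returns everything printed so far.
    return log.buffer.getvalue()

with gr.Blocks() as demo:
    gr.Textbox(read_logs, label="Logs", every=1)

The repository additionally resets the buffer once it exceeds 30 lines and on reset/stop clicks, which keeps the textbox from growing without bound during long harmonization runs.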
assets/demo.gif
ADDED (Git LFS)

assets/metrics.png
ADDED

assets/network.png
ADDED

assets/title_any_image.gif
ADDED

assets/title_harmon.gif
ADDED

assets/title_you_want.gif
ADDED

assets/visualizations.png
ADDED

assets/visualizations2.png
ADDED (Git LFS)

datasets/__init__.py
ADDED (file without changes)
datasets/build_INR_dataset.py
ADDED
@@ -0,0 +1,36 @@
+from utils import misc
+from albumentations import Resize
+
+
+class Implicit2DGenerator(object):
+    def __init__(self, opt, mode):
+        if mode == 'Train':
+            sidelength = opt.INR_input_size
+        elif mode == 'Val':
+            sidelength = opt.input_size
+        else:
+            raise NotImplementedError
+
+        self.mode = mode
+
+        self.size = sidelength
+
+        if isinstance(sidelength, int):
+            sidelength = (sidelength, sidelength)
+
+        self.mgrid = misc.get_mgrid(sidelength)
+
+        self.transform = Resize(self.size, self.size)
+
+    def generator(self, torch_transforms, composite_image, real_image, mask):
+        composite_image = torch_transforms(self.transform(image=composite_image)['image'])
+        real_image = torch_transforms(self.transform(image=real_image)['image'])
+
+        fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3)
+        fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
+        bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
+
+        fg_INR_coordinates = self.mgrid
+        bg_INR_coordinates = self.mgrid
+
+        return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB
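`Implicit2DGenerator` pairs each flattened per-pixel RGB target with coordinates from `misc.get_mgrid`, whose source is not included in this 50-file view. A minimal sketch of what such a coordinate-grid helper conventionally computes for INR training, offered as an assumption about the missing helper rather than the repository's actual code:

import torch

def get_mgrid(sidelength):
    # Hypothetical helper: an (H*W, 2) grid of (y, x) coordinates normalized
    # to [-1, 1], row-major, matching the (-1, 3) flattening of RGB above.
    h, w = sidelength
    ys = torch.linspace(-1, 1, steps=h)
    xs = torch.linspace(-1, 1, steps=w)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing='ij'), dim=-1)
    return grid.reshape(-1, 2)

Under that assumption, row i of the grid lines up with row i of `fg_INR_RGB`, so the INR can be queried pointwise: coordinates in, RGB out.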
datasets/build_dataset.py
ADDED
@@ -0,0 +1,371 @@
+import torch
+import cv2
+import numpy as np
+import torchvision
+import os
+import random
+
+from utils.misc import prepare_cooridinate_input, customRandomCrop
+
+from datasets.build_INR_dataset import Implicit2DGenerator
+import albumentations
+from albumentations import Resize, RandomResizedCrop, HorizontalFlip
+from torch.utils.data import DataLoader
+
+
+class dataset_generator(torch.utils.data.Dataset):
+    def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
+        super().__init__()
+
+        self.opt = opt
+        self.root_path = opt.dataset_path
+        self.mode = mode
+
+        self.alb_transforms = alb_transforms
+        self.torch_transforms = torch_transforms
+        self.kp_t = area_keep_thresh
+
+        with open(dataset_txt, 'r') as f:
+            self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
+
+        self.INR_dataset = Implicit2DGenerator(opt, self.mode)
+
+    def __len__(self):
+        return len(self.dataset_samples)
+
+    def __getitem__(self, idx):
+        composite_image = self.dataset_samples[idx]
+
+        if self.opt.hr_train:
+            if self.opt.isFullRes:
+                "Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \
+                "quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \
+                "if `opt.isFullRes` is set to True."
+                composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
+
+        real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
+        mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'
+
+        composite_image = cv2.imread(composite_image)
+        composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
+
+        real_image = cv2.imread(real_image)
+        real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
+
+        mask = cv2.imread(mask)
+        mask = mask[:, :, 0].astype(np.float32) / 255.
+
+        """
+        If set `opt.hr_train` to True:
+
+        Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres),
+        the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size`
+        patch to feed in multiINR process. For inference, just resize it.
+
+        While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size.
+
+        BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5`
+        """
+        if self.opt.hr_train:
+            if self.mode == 'Train' and self.opt.isFullRes:
+                if random.random() < 0.5:  # LR mix training
+                    mixTransform = albumentations.Compose(
+                        [
+                            RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
+                            HorizontalFlip()],
+                        additional_targets={'real_image': 'image', 'object_mask': 'image'}
+                    )
+                    origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+                    origin_bg_ratio = 1 - origin_fg_ratio
+
+                    "Ensure fg and bg not disappear after transformation"
+                    valid_augmentation = False
+                    transform_out = None
+                    time = 0
+                    while not valid_augmentation:
+                        time += 1
+                        # There are some extreme ratio pics, this code is to avoid being hindered by them.
+                        if time == 20:
+                            tmp_transform = albumentations.Compose(
+                                [Resize(self.opt.base_size, self.opt.base_size)],
+                                additional_targets={'real_image': 'image',
+                                                    'object_mask': 'image'})
+                            transform_out = tmp_transform(image=composite_image, real_image=real_image,
+                                                          object_mask=mask)
+                            valid_augmentation = True
+                        else:
+                            transform_out = mixTransform(image=composite_image, real_image=real_image,
+                                                         object_mask=mask)
+                            valid_augmentation = check_augmented_sample(transform_out['object_mask'],
+                                                                        origin_fg_ratio,
+                                                                        origin_bg_ratio,
+                                                                        self.kp_t)
+                    composite_image = transform_out['image']
+                    real_image = transform_out['real_image']
+                    mask = transform_out['object_mask']
+                else:  # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop.
+                    if real_image.shape[0] < 256:
+                        bottom_pad = 256 - real_image.shape[0]
+                    else:
+                        bottom_pad = (4 - real_image.shape[0] % 4) % 4
+                    if real_image.shape[1] < 256:
+                        right_pad = 256 - real_image.shape[1]
+                    else:
+                        right_pad = (4 - real_image.shape[1] % 4) % 4
+                    composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
+                                                         cv2.BORDER_REPLICATE)
+                    real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
+                    mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
+
+        origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+        origin_bg_ratio = 1 - origin_fg_ratio
+
+        "Ensure fg and bg not disappear after transformation"
+        valid_augmentation = False
+        transform_out = None
+        time = 0
+
+        if self.opt.hr_train:
+            if self.mode == 'Train':
+                if not self.opt.isFullRes:
+                    if random.random() < 0.5:  # LR mix training
+                        mixTransform = albumentations.Compose(
+                            [
+                                RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
+                                HorizontalFlip()],
+                            additional_targets={'real_image': 'image', 'object_mask': 'image'}
+                        )
+                        while not valid_augmentation:
+                            time += 1
+                            # There are some extreme ratio pics, this code is to avoid being hindered by them.
+                            if time == 20:
+                                tmp_transform = albumentations.Compose(
+                                    [Resize(self.opt.base_size, self.opt.base_size)],
+                                    additional_targets={'real_image': 'image',
+                                                        'object_mask': 'image'})
+                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
+                                                              object_mask=mask)
+                                valid_augmentation = True
+                            else:
+                                transform_out = mixTransform(image=composite_image, real_image=real_image,
+                                                             object_mask=mask)
+                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
+                                                                            origin_fg_ratio,
+                                                                            origin_bg_ratio,
+                                                                            self.kp_t)
+                    else:
+                        while not valid_augmentation:
+                            time += 1
+                            # There are some extreme ratio pics, this code is to avoid being hindered by them.
+                            if time == 20:
+                                tmp_transform = albumentations.Compose(
+                                    [Resize(self.opt.input_size, self.opt.input_size)],
+                                    additional_targets={'real_image': 'image',
+                                                        'object_mask': 'image'})
+                                transform_out = tmp_transform(image=composite_image, real_image=real_image,
+                                                              object_mask=mask)
+                                valid_augmentation = True
+                            else:
+                                transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
+                                                                    object_mask=mask)
+                                valid_augmentation = check_augmented_sample(transform_out['object_mask'],
+                                                                            origin_fg_ratio,
+                                                                            origin_bg_ratio,
+                                                                            self.kp_t)
+                    composite_image = transform_out['image']
+                    real_image = transform_out['real_image']
+                    mask = transform_out['object_mask']
+
+                origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+
+                full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
+
+                tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
+                                                       additional_targets={'real_image': 'image',
+                                                                           'object_mask': 'image'})
+                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
+                compos_list = [self.torch_transforms(transform_out['image'])]
+                real_list = [self.torch_transforms(transform_out['real_image'])]
+                mask_list = [
+                    torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
+                coord_map_list = []
+
+                valid_augmentation = False
+                while not valid_augmentation:
+                    # RSC strategy. To crop different resolutions.
+                    transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
+                                                               self.opt.base_size, self.opt.base_size)
+                    valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
+
+                compos_list.append(self.torch_transforms(transform_out[0]))
+                real_list.append(self.torch_transforms(transform_out[1]))
+                mask_list.append(
+                    torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
+                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
+                coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
+                for n in range(2):
+                    tmp_comp = cv2.resize(composite_image, (
+                        composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
+                    tmp_real = cv2.resize(real_image,
+                                          (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
+                    tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
+                    tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
+
+                    transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
+                                                               self.opt.base_size // 2 ** (n + 1),
+                                                               self.opt.base_size // 2 ** (n + 1), c_h, c_w)
+                    compos_list.append(self.torch_transforms(transform_out[0]))
+                    real_list.append(self.torch_transforms(transform_out[1]))
+                    mask_list.append(
+                        torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
+                    coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
+                out_comp = compos_list
+                out_real = real_list
+                out_mask = mask_list
+                out_coord = coord_map_list
+
+                fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+                    self.torch_transforms, transform_out[0], transform_out[1], mask)
+
+                return {
+                    'file_path': self.dataset_samples[idx],
+                    'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+                    'composite_image': out_comp,
+                    'real_image': out_real,
+                    'mask': out_mask,
+                    'coordinate_map': out_coord,
+                    'composite_image0': out_comp[0],
+                    'real_image0': out_real[0],
+                    'mask0': out_mask[0],
+                    'coordinate_map0': out_coord[0],
+                    'composite_image1': out_comp[1],
+                    'real_image1': out_real[1],
+                    'mask1': out_mask[1],
+                    'coordinate_map1': out_coord[1],
+                    'composite_image2': out_comp[2],
+                    'real_image2': out_real[2],
+                    'mask2': out_mask[2],
+                    'coordinate_map2': out_coord[2],
+                    'composite_image3': out_comp[3],
+                    'real_image3': out_real[3],
+                    'mask3': out_mask[3],
+                    'coordinate_map3': out_coord[3],
+                    'fg_INR_coordinates': fg_INR_coordinates,
+                    'bg_INR_coordinates': bg_INR_coordinates,
+                    'fg_INR_RGB': fg_INR_RGB,
+                    'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+                    'bg_INR_RGB': bg_INR_RGB
+                }
+            else:
+                if not self.opt.isFullRes:
+                    tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
+                                                           additional_targets={'real_image': 'image',
+                                                                               'object_mask': 'image'})
+                    transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
+
+                    coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
+
+                    "Generate INR dataset."
+                    mask = (torchvision.transforms.ToTensor()(
+                        transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
+                    mask = np.bool_(mask.numpy())
+
+                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+                        self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
+
+                    return {
+                        'file_path': self.dataset_samples[idx],
+                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+                        'composite_image': self.torch_transforms(transform_out['image']),
+                        'real_image': self.torch_transforms(transform_out['real_image']),
+                        'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
+                        # Can automatically transfer to Tensor.
+                        'coordinate_map': coordinate_map,
+                        'fg_INR_coordinates': fg_INR_coordinates,
+                        'bg_INR_coordinates': bg_INR_coordinates,
+                        'fg_INR_RGB': fg_INR_RGB,
+                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+                        'bg_INR_RGB': bg_INR_RGB
+                    }
+                else:
+                    coordinate_map = prepare_cooridinate_input(mask)
+
+                    "Generate INR dataset."
+                    mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
+                    mask_tmp = np.bool_(mask_tmp.numpy())
+
+                    fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+                        self.torch_transforms, composite_image, real_image, mask_tmp)
+
+                    return {
+                        'file_path': self.dataset_samples[idx],
+                        'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+                        'composite_image': self.torch_transforms(composite_image),
+                        'real_image': self.torch_transforms(real_image),
+                        'mask': mask[np.newaxis, ...].astype(np.float32),
+                        # Can automatically transfer to Tensor.
+                        'coordinate_map': coordinate_map,
+                        'fg_INR_coordinates': fg_INR_coordinates,
+                        'bg_INR_coordinates': bg_INR_coordinates,
+                        'fg_INR_RGB': fg_INR_RGB,
+                        'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+                        'bg_INR_RGB': bg_INR_RGB
+                    }
+
+        while not valid_augmentation:
+            time += 1
+            # There are some extreme ratio pics, this code is to avoid being hindered by them.
+            if time == 20:
+                tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
+                                                       additional_targets={'real_image': 'image',
+                                                                           'object_mask': 'image'})
+                transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
+                valid_augmentation = True
+            else:
+                transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
+                valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
+                                                            origin_bg_ratio,
+                                                            self.kp_t)
+
+        coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
+
+        "Generate INR dataset."
+        mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
+        mask = np.bool_(mask.numpy())
+
+        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+            self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
+
+        return {
+            'file_path': self.dataset_samples[idx],
+            'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+            'composite_image': self.torch_transforms(transform_out['image']),
+            'real_image': self.torch_transforms(transform_out['real_image']),
+            'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
+            # Can automatically transfer to Tensor.
+            'coordinate_map': coordinate_map,
+            'fg_INR_coordinates': fg_INR_coordinates,
+            'bg_INR_coordinates': bg_INR_coordinates,
+            'fg_INR_RGB': fg_INR_RGB,
+            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+            'bg_INR_RGB': bg_INR_RGB
+        }
+
+
+def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
+    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+    current_bg_ratio = 1 - current_fg_ratio
+
+    if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
+        return False
+
+    return True
+
+
+def check_hr_crop_sample(mask, origin_fg_ratio):
+    current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+
+    if current_fg_ratio < 0.8 * origin_fg_ratio:
+        return False
+
+    return True
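Operationally, the augmentation blocks in `__getitem__` all follow one retry-until-valid pattern: re-sample a random crop until `check_augmented_sample` confirms that enough of the foreground and background survived, and fall back to a deterministic resize after 20 attempts so extreme-aspect-ratio images cannot stall the loader. A compact standalone sketch of that pattern, mirroring the older albumentations constructor style used above; `augment_with_retry` and `make_pipeline` are hypothetical helpers, not part of this repository:

import albumentations as A

def make_pipeline(transforms):
    # 'object_mask' is treated as an extra image target, as in dataset_generator.
    return A.Compose(transforms, additional_targets={'object_mask': 'image'})

def augment_with_retry(image, mask, size=256, keep=0.2, max_tries=20):
    """Re-sample a random crop until enough fg/bg area survives; after
    max_tries, apply the deterministic fallback resize instead."""
    aug = make_pipeline([A.RandomResizedCrop(size, size, scale=(0.5, 1.0)), A.HorizontalFlip()])
    fallback = make_pipeline([A.Resize(size, size)])
    fg = mask.sum() / mask.size  # original foreground ratio
    for _ in range(max_tries - 1):
        out = aug(image=image, object_mask=mask)
        m = out['object_mask']
        cur_fg = m.sum() / m.size
        # Same criterion as check_augmented_sample: both ratios must keep
        # at least `keep` of their original share.
        if cur_fg >= fg * keep and (1 - cur_fg) >= (1 - fg) * keep:
            return out
    return fallback(image=image, object_mask=mask)

The fallback resize trades augmentation diversity for a guaranteed terminating `__getitem__`, which matters because a DataLoader worker stuck in an infinite rejection loop would hang the whole training run.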
demo/demo_1k_composite_2.jpg
ADDED

demo/demo_1k_composite_3.jpg
ADDED

demo/demo_1k_mask_2.jpg
ADDED

demo/demo_1k_mask_3.jpg
ADDED

demo/demo_composite.jpg
ADDED

demo/demo_composite_1.jpg
ADDED

demo/demo_composite_2.jpg
ADDED

demo/demo_composite_3.jpg
ADDED

demo/demo_composite_4.jpg
ADDED

demo/demo_composite_5.jpg
ADDED

demo/demo_composite_6.jpg
ADDED

demo/demo_mask.png
ADDED

demo/demo_mask_1.png
ADDED

demo/demo_mask_2.png
ADDED

demo/demo_mask_3.png
ADDED

demo/demo_mask_4.jpg
ADDED

demo/demo_mask_5.jpg
ADDED

demo/demo_mask_6.jpg
ADDED
efficient_inference_for_square_image.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import builtins
from collections import defaultdict

import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from model.build_model import build_model
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import torch
import cv2
import numpy as np
import torchvision
import os
import tqdm
import time

from utils.misc import prepare_cooridinate_input, customRandomCrop

from datasets.build_INR_dataset import Implicit2DGenerator
import albumentations
from albumentations import Resize
# from torch.utils.data import DataLoader
from utils.misc import normalize

import math

global_state = [1]  # For Gradio Stop Button.


class single_image_dataset(torch.utils.data.Dataset):
    def __init__(self, opt, composite_image=None, mask=None):
        super().__init__()

        self.opt = opt

        if composite_image is None:
            composite_image = cv2.imread(opt.composite_image)
            composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
        self.composite_image = composite_image

        assert composite_image.shape[0] == composite_image.shape[1], "This faster script only supports square images."
        assert composite_image.shape[0] % 256 == 0, "This faster script only supports images whose resolution is a multiple of 256."
        assert opt.split_resolution % (composite_image.shape[0] // 16) == 0, \
            f"The image resolution is {composite_image.shape[0]}, " \
            f"you should set {opt.split_resolution} to a multiple of {composite_image.shape[0] // 16}"

        if mask is None:
            mask = cv2.imread(opt.mask)
            mask = mask[:, :, 0].astype(np.float32) / 255.
        self.mask = mask

        self.torch_transforms = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize([.5, .5, .5], [.5, .5, .5])])
        self.INR_dataset = Implicit2DGenerator(opt, 'Val')

        self.split_width_resolution = self.split_height_resolution = opt.split_resolution

        self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution)
        self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution)

        self.split_start_point = []

        # Split the image into several parts. Patches in the last row/column are
        # anchored at the image border so every patch stays fully inside the image.
        for i in range(self.num_h):
            for j in range(self.num_w):
                if i == composite_image.shape[0] // self.split_height_resolution:
                    if j == composite_image.shape[1] // self.split_width_resolution:
                        self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution,
                                                       composite_image.shape[1] - self.split_width_resolution))
                    else:
                        self.split_start_point.append(
                            (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution))
                else:
                    if j == composite_image.shape[1] // self.split_width_resolution:
                        self.split_start_point.append(
                            (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution))
                    else:
                        self.split_start_point.append(
                            (i * self.split_height_resolution, j * self.split_width_resolution))

        assert len(self.split_start_point) == self.num_w * self.num_h

        print(
            f"The image will be split into {self.num_h} pieces in height and {self.num_w} pieces in width, {self.num_h * self.num_w} patches in total.")
        print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}")

    def __len__(self):
        return self.num_w * self.num_h

    def __getitem__(self, idx):
        composite_image = self.composite_image

        mask = self.mask

        full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)

        tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
                                               additional_targets={'object_mask': 'image'})
        transform_out = tmp_transform(image=self.composite_image, object_mask=self.mask)
        compos_list = [self.torch_transforms(transform_out['image'])]
        mask_list = [
            torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
        coord_map_list = []

        if composite_image.shape[0] != self.split_height_resolution:
            c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution)
        else:
            c_h = 0
        if composite_image.shape[1] != self.split_width_resolution:
            c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution)
        else:
            c_w = 0
        transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord],
                                                   self.split_height_resolution, self.split_width_resolution, c_h, c_w)

        compos_list.append(self.torch_transforms(transform_out[0]))
        mask_list.append(
            torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
        coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        for n in range(2):
            tmp_comp = cv2.resize(composite_image, (
                composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
            tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
            tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)

            transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_mask, tmp_coord],
                                                       self.split_height_resolution // 2 ** (n + 1),
                                                       self.split_width_resolution // 2 ** (n + 1), c_h, c_w)
            compos_list.append(self.torch_transforms(transform_out[0]))
            mask_list.append(
                torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
            coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        out_comp = compos_list
        out_mask = mask_list
        out_coord = coord_map_list

        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
            self.torch_transforms, transform_out[0], transform_out[0], mask)

        return {
            'composite_image': out_comp,
            'mask': out_mask,
            'coordinate_map': out_coord,
            'composite_image0': out_comp[0],
            'mask0': out_mask[0],
            'coordinate_map0': out_coord[0],
            'composite_image1': out_comp[1],
            'mask1': out_mask[1],
            'coordinate_map1': out_coord[1],
            'composite_image2': out_comp[2],
            'mask2': out_mask[2],
            'coordinate_map2': out_coord[2],
            'composite_image3': out_comp[3],
            'mask3': out_mask[3],
            'coordinate_map3': out_coord[3],
            'fg_INR_coordinates': fg_INR_coordinates,
            'bg_INR_coordinates': bg_INR_coordinates,
            'fg_INR_RGB': fg_INR_RGB,
            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
            'bg_INR_RGB': bg_INR_RGB,
            'start_point': self.split_start_point[idx],
            'start_proportion': [self.split_start_point[idx][0] / (composite_image.shape[0]),
                                 self.split_start_point[idx][1] / (composite_image.shape[1]),
                                 (self.split_start_point[idx][0] + self.split_height_resolution) / (
                                     composite_image.shape[0]),
                                 (self.split_start_point[idx][1] + self.split_width_resolution) / (
                                     composite_image.shape[1])],
        }


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--split_resolution', type=int, default=2048,
                        help='The resolution of the patch split.')

    parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg',
                        help='composite image path')

    parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg',
                        help='mask path')

    parser.add_argument('--save_path', type=str, default=r'./demo/',
                        help='save path')

    parser.add_argument('--workers', type=int, default=8,
                        metavar='N', help='Dataloader threads.')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='You can override the model batch size by specifying a positive number.')

    parser.add_argument('--device', type=str, default='cuda',
                        help="Whether to use cuda, 'cuda' or 'cpu'.")

    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size. Resolution of the image input into the Encoder.')

    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size. Resolution of the image to be generated by the Decoder.')

    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size. Resolution of the image to be generated by the Decoder. '
                             'Should be the same as `input_size`.')

    parser.add_argument('--INR_MLP_dim', type=int, default=32,
                        help='Number of channels for the INR linear layers.')

    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')

    parser.add_argument('--activation', type=str, default='leakyrelu_pe',
                        help='INR activation layer type: leakyrelu_pe, sine')

    parser.add_argument('--pretrained', type=str,
                        default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
                        help='Pretrained weight path')

    parser.add_argument('--param_factorize_dim', type=int,
                        default=10,
                        help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
                             'Refer to https://arxiv.org/abs/2011.12026')

    parser.add_argument('--embedding_type', type=str,
                        default="CIPS_embed",
                        help='Which embedding_type to use.')

    parser.add_argument('--INRDecode', action="store_false",
                        help='Whether to use the INR decoder. Set it to False if you want to test the baseline '
                             '(https://github.com/SamsungLabs/image_harmonization)')

    parser.add_argument('--isMoreINRInput', action="store_false",
                        help='Whether to cat RGB and mask. See Section 3.4 in the paper.')

    parser.add_argument('--hr_train', action="store_false",
                        help='Whether to use hr_train. See Section 3.4 in the paper.')

    parser.add_argument('--isFullRes', action="store_true",
                        help='Whether to run at the original resolution. See Section 3.4 in the paper.')

    opt = parser.parse_args()

    assert opt.batch_size == 1, 'This faster script only supports batch size 1 for inference.'

    return opt


@torch.no_grad()
def inference(model, opt, composite_image=None, mask=None):
    model.eval()

    # The "dataset" here actually consists of the patches of a single image.
    singledataset = single_image_dataset(opt, composite_image, mask)

    single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
                                    num_workers=opt.workers, persistent_workers=False if composite_image is not None else True)

    # Init a pure black image with the same size as the input image.
    init_img = np.zeros_like(singledataset.composite_image)

    time_all = 0

    for step, batch in tqdm.tqdm(enumerate(single_data_loader)):
        composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)]
        mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)]
        coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)]
        start_points = batch['start_point']
        start_proportion = batch['start_proportion']

        if opt.batch_size == 1:
            start_points = [torch.cat(start_points)]
            start_proportion = [torch.cat(start_proportion)]

        fg_INR_coordinates = coordinate_map[1:]

        try:
            if global_state[0] == 0:
                print("Stop Harmonizing...!")
                break

            if step == 0:  # CUDA kernel warm-up; otherwise the first timed step would be quite slow.
                fg_content_bg_appearance_construct, _, lut_transform_image = model(
                    composite_image,
                    mask,
                    fg_INR_coordinates, start_proportion[0]
                )
                print("Ready for harmonization...")
            if opt.device == "cuda":
                torch.cuda.reset_max_memory_allocated()
                torch.cuda.reset_max_memory_cached()
                start_time = time.time()
                torch.cuda.synchronize()
            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image,
                mask,
                fg_INR_coordinates, start_proportion[0]
            )
            if opt.device == "cuda":
                torch.cuda.synchronize()
                end_time = time.time()

                end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2
                end_memory = torch.cuda.memory_allocated() // 1024 ** 2

                print(f'GPU max memory usage: {end_max_memory} MB')
                print(f'GPU memory usage: {end_memory} MB')
                time_all += (end_time - start_time)
                print(f'progress: {step} / {len(single_data_loader)}')
        except Exception as e:
            raise Exception(
                f'The image resolution is too large. Please reduce the `split_resolution` value. Your current setting is {opt.split_resolution}') from e

        # Assemble every patch's harmonized result into the final whole image.
        for id in range(len(fg_INR_coordinates[0])):
            pred_fg_image = fg_content_bg_appearance_construct[-1][id]
            pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) + composite_image[1][id] * (
                ~(mask[1][id] > 100 / 255.))

            pred_harmonized_tmp = cv2.cvtColor(
                normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_(
                    0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)

            init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
                     start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp

    if opt.device == "cuda":
        print(f'Inference time: {time_all}')
    if opt.save_path is not None:
        os.makedirs(opt.save_path, exist_ok=True)
        cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
    return init_img


def main_process(opt, composite_image=None, mask=None):
    # torch.serialization.add_safe_globals([getattr, OneCycleLR, AdamW, defaultdict, builtins.dict])
    cudnn.benchmark = True
    opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Preparing model...")
    model = build_model(opt).to(opt.device)

    # Replace 'gpu' with 'cuda' and add weights_only=True.
    load_dict = torch.load(opt.pretrained, map_location='cpu')['model']

    model.load_state_dict(load_dict, strict=False)

    return inference(model, opt, composite_image, mask)


if __name__ == '__main__':
    opt = parse_args()
    opt.transform_mean = [.5, .5, .5]
    opt.transform_var = [.5, .5, .5]
    main_process(opt)
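For reference, a minimal programmatic driver for the script above, mirroring what a Gradio handler would do. This is a sketch under stated assumptions: the demo paths, the 1024x1024 size of the 1k demo images, and the split_resolution/pretrained values are illustrative, not fixed by the repo.

# Hypothetical driver sketch; paths and values are illustrative assumptions.
import cv2
import numpy as np
opt = parse_args()                       # or build an equivalent namespace
opt.transform_mean = [.5, .5, .5]        # same normalization as the __main__ block
opt.transform_var = [.5, .5, .5]
opt.split_resolution = 512               # for a 1024px image: 512 % (1024 // 16) == 0, as asserted
opt.pretrained = './pretrained_models/Resolution_RAW_iHarmony4.pth'  # placeholder path
comp = cv2.cvtColor(cv2.imread('./demo/demo_1k_composite_2.jpg'), cv2.COLOR_BGR2RGB)
msk = cv2.imread('./demo/demo_1k_mask_2.jpg')[:, :, 0].astype(np.float32) / 255.
harmonized_bgr = main_process(opt, composite_image=comp, mask=msk)   # also saved under opt.save_path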
hrnet_ocr.py
ADDED
@@ -0,0 +1,401 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._utils

from .ocr import SpatialOCR_Module, SpatialGather_Module
from .resnetv1b import BasicBlockV1b, BottleneckV1b

relu_inplace = True


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True,
                 norm_layer=nn.BatchNorm2d, align_corners=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.norm_layer = norm_layer
        self.align_corners = align_corners

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=relu_inplace)

    def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                self.norm_layer(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride,
                            downsample=downsample, norm_layer=self.norm_layer))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index],
                                norm_layer=self.norm_layer))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(in_channels=num_inchannels[j],
                                  out_channels=num_inchannels[i],
                                  kernel_size=1,
                                  bias=False),
                        self.norm_layer(num_inchannels[i])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          kernel_size=3, stride=2, padding=1, bias=False),
                                self.norm_layer(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          kernel_size=3, stride=2, padding=1, bias=False),
                                self.norm_layer(num_outchannels_conv3x3),
                                nn.ReLU(inplace=relu_inplace)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=self.align_corners)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


class HighResolutionNet(nn.Module):
    def __init__(self, width, num_classes, ocr_width=256, small=False,
                 norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
        super(HighResolutionNet, self).__init__()
        self.opt = opt
        self.norm_layer = norm_layer
        self.width = width
        self.ocr_width = ocr_width
        self.ocr_on = ocr_width > 0
        self.align_corners = align_corners

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = norm_layer(64)
        self.relu = nn.ReLU(inplace=relu_inplace)

        num_blocks = 2 if small else 4

        stage1_num_channels = 64
        self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
        stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels

        self.stage2_num_branches = 2
        num_channels = [width, 2 * width]
        num_inchannels = [
            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_inchannels)
        self.stage2, pre_stage_channels = self._make_stage(
            BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
            num_blocks=2 * [num_blocks], num_channels=num_channels)

        self.stage3_num_branches = 3
        num_channels = [width, 2 * width, 4 * width]
        num_inchannels = [
            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_inchannels)
        self.stage3, pre_stage_channels = self._make_stage(
            BasicBlockV1b, num_inchannels=num_inchannels,
            num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
            num_blocks=3 * [num_blocks], num_channels=num_channels)

        self.stage4_num_branches = 4
        num_channels = [width, 2 * width, 4 * width, 8 * width]
        num_inchannels = [
            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_inchannels)
        self.stage4, pre_stage_channels = self._make_stage(
            BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
            num_branches=self.stage4_num_branches,
            num_blocks=4 * [num_blocks], num_channels=num_channels)

        if self.ocr_on:
            # np.int was removed in NumPy >= 1.24; the builtin int is equivalent here.
            last_inp_channels = int(np.sum(pre_stage_channels))
            ocr_mid_channels = 2 * ocr_width
            ocr_key_channels = ocr_width

            self.conv3x3_ocr = nn.Sequential(
                nn.Conv2d(last_inp_channels, ocr_mid_channels,
                          kernel_size=3, stride=1, padding=1),
                norm_layer(ocr_mid_channels),
                nn.ReLU(inplace=relu_inplace),
            )
            self.ocr_gather_head = SpatialGather_Module(num_classes)

            self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
                                                     key_channels=ocr_key_channels,
                                                     out_channels=ocr_mid_channels,
                                                     scale=1,
                                                     dropout=0.05,
                                                     norm_layer=norm_layer,
                                                     align_corners=align_corners, opt=opt)

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False),
                        self.norm_layer(num_channels_cur_layer[i]),
                        nn.ReLU(inplace=relu_inplace)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels,
                                  kernel_size=3, stride=2, padding=1, bias=False),
                        self.norm_layer(outchannels),
                        nn.ReLU(inplace=relu_inplace)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                self.norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(inplanes, planes, stride,
                            downsample=downsample, norm_layer=self.norm_layer))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes, norm_layer=self.norm_layer))

        return nn.Sequential(*layers)

    def _make_stage(self, block, num_inchannels,
                    num_modules, num_branches, num_blocks, num_channels,
                    fuse_method='SUM',
                    multi_scale_output=True):
        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used by the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output,
                                     norm_layer=self.norm_layer,
                                     align_corners=self.align_corners)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x, mask=None, additional_features=None):
        hrnet_feats = self.compute_hrnet_feats(x, additional_features)
        if not self.ocr_on:
            return hrnet_feats,

        ocr_feats = self.conv3x3_ocr(hrnet_feats)
        mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
        context = self.ocr_gather_head(ocr_feats, mask)
        ocr_feats = self.ocr_distri_head(ocr_feats, context)
        return ocr_feats,

    def compute_hrnet_feats(self, x, additional_features, return_list=False):
        x = self.compute_pre_stage_features(x, additional_features)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_num_branches):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_num_branches):
            if self.transition2[i] is not None:
                if i < self.stage2_num_branches:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_num_branches):
            if self.transition3[i] is not None:
                if i < self.stage3_num_branches:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        if return_list:
            return x

        # Upsampling
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(x0_h, x0_w),
                           mode='bilinear', align_corners=self.align_corners)
        x2 = F.interpolate(x[2], size=(x0_h, x0_w),
                           mode='bilinear', align_corners=self.align_corners)
        x3 = F.interpolate(x[3], size=(x0_h, x0_w),
                           mode='bilinear', align_corners=self.align_corners)

        return torch.cat([x[0], x1, x2, x3], 1)

    def compute_pre_stage_features(self, x, additional_features):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if additional_features is not None:
            x = x + additional_features
        x = self.conv2(x)
        x = self.bn2(x)
        return self.relu(x)

    def load_pretrained_weights(self, pretrained_path=''):
        model_dict = self.state_dict()

        if not os.path.exists(pretrained_path):
            print(f'\nFile "{pretrained_path}" does not exist.')
            print('You need to specify the correct path to the pre-trained weights.\n'
                  'You can download the weights for HRNet from the repository:\n'
                  'https://github.com/HRNet/HRNet-Image-Classification')
            exit(1)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        pretrained_dict = torch.load(pretrained_path, map_location=device)
        pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
                           pretrained_dict.items()}
        params_count = len(pretrained_dict)

        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict.keys()}

        # print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')

        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
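A quick shape sanity-check for the backbone above. This is a sketch with illustrative values (HRNetV2-W18-style width, OCR head disabled via ocr_width=0), not the repo's configured ones, and it assumes the relative imports (.ocr, .resnetv1b) resolve, i.e. the module is run in package context:

# Illustrative shape check; width=18 and 256px input are assumptions.
import torch
net = HighResolutionNet(width=18, num_classes=1, ocr_width=0, small=True)
x = torch.randn(1, 3, 256, 256)
feats, = net(x)        # single-element tuple of HRNet features at 1/4 input resolution
print(feats.shape)     # torch.Size([1, 270, 64, 64]); 270 = (1 + 2 + 4 + 8) * 18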
inference.py
ADDED
@@ -0,0 +1,236 @@
import os
import argparse

import albumentations
from albumentations import Resize

import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from model.build_model import build_model
from datasets.build_dataset import dataset_generator

from utils import misc, metrics


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--workers', type=int, default=1,
                        metavar='N', help='Dataloader threads.')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='You can override the model batch size by specifying a positive number.')

    parser.add_argument('--device', type=str, default='cuda',
                        help="Whether to use cuda, 'cuda' or 'cpu'.")

    parser.add_argument('--save_path', type=str, default="./logs",
                        help='Where to save logs and checkpoints.')

    parser.add_argument('--dataset_path', type=str, default=r".\iHarmony4",
                        help='Dataset path.')

    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size. Resolution of the image input into the Encoder.')

    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size. Resolution of the image to be generated by the Decoder.')

    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size. Resolution of the image to be generated by the Decoder. '
                             'Should be the same as `input_size`.')

    parser.add_argument('--INR_MLP_dim', type=int, default=32,
                        help='Number of channels for the INR linear layers.')

    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')

    parser.add_argument('--activation', type=str, default='leakyrelu_pe',
                        help='INR activation layer type: leakyrelu_pe, sine')

    parser.add_argument('--pretrained', type=str,
                        default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
                        help='Pretrained weight path')

    parser.add_argument('--param_factorize_dim', type=int,
                        default=10,
                        help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
                             'Refer to https://arxiv.org/abs/2011.12026')

    parser.add_argument('--embedding_type', type=str,
                        default="CIPS_embed",
                        help='Which embedding_type to use.')

    parser.add_argument('--optim', type=str,
                        default='adamw',
                        help='Which optimizer to use.')

    parser.add_argument('--INRDecode', action="store_false",
                        help='Whether to use the INR decoder. Set it to False if you want to test the baseline '
                             '(https://github.com/SamsungLabs/image_harmonization)')

    parser.add_argument('--isMoreINRInput', action="store_false",
                        help='Whether to cat RGB and mask. See Section 3.4 in the paper.')

    parser.add_argument('--hr_train', action="store_true",
                        help='Whether to use hr_train. See Section 3.4 in the paper.')

    parser.add_argument('--isFullRes', action="store_true",
                        help='Whether to run at the original resolution. See Section 3.4 in the paper.')

    opt = parser.parse_args()

    opt.save_path = misc.increment_path(os.path.join(opt.save_path, "test1"))

    return opt


def inference(val_loader, model, logger, opt):
    current_process = 10
    model.eval()

    metric_log = {
        'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
    }

    lut_metric_log = {
        'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
        'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
    }

    for step, batch in enumerate(val_loader):
        composite_image = batch['composite_image'].to(opt.device)
        real_image = batch['real_image'].to(opt.device)
        mask = batch['mask'].to(opt.device)
        category = batch['category']

        fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)

        with torch.no_grad():
            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image,
                mask,
                fg_INR_coordinates,
            )

        if opt.INRDecode:
            pred_fg_image = fg_content_bg_appearance_construct[-1]
        else:
            pred_fg_image = misc.lin2img(fg_content_bg_appearance_construct,
                                         val_loader.dataset.INR_dataset.size) if fg_content_bg_appearance_construct is not None else None

        if not opt.INRDecode:
            pred_harmonized_image = None
        else:
            pred_harmonized_image = pred_fg_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.))
        lut_transform_image = lut_transform_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.))

        misc.visualize(real_image, composite_image, mask, pred_fg_image,
                       pred_harmonized_image, lut_transform_image, opt, -1, show=False,
                       wandb=False, isAll=True, step=step)

        if opt.INRDecode:
            mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
                                                         misc.normalize(real_image, opt, 'inv'), mask)

        lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'),
                                                                     misc.normalize(real_image, opt, 'inv'), mask)

        for idx in range(len(category)):
            if opt.INRDecode:
                metric_log[category[idx]]['Samples'] += 1
                metric_log[category[idx]]['MSE'] += mse[idx]
                metric_log[category[idx]]['fMSE'] += fmse[idx]
                metric_log[category[idx]]['PSNR'] += psnr[idx]
                metric_log[category[idx]]['SSIM'] += ssim[idx]

                metric_log['All']['Samples'] += 1
                metric_log['All']['MSE'] += mse[idx]
                metric_log['All']['fMSE'] += fmse[idx]
                metric_log['All']['PSNR'] += psnr[idx]
                metric_log['All']['SSIM'] += ssim[idx]

            lut_metric_log[category[idx]]['Samples'] += 1
            lut_metric_log[category[idx]]['MSE'] += lut_mse[idx]
            lut_metric_log[category[idx]]['fMSE'] += lut_fmse[idx]
            lut_metric_log[category[idx]]['PSNR'] += lut_psnr[idx]
            lut_metric_log[category[idx]]['SSIM'] += lut_ssim[idx]

            lut_metric_log['All']['Samples'] += 1
            lut_metric_log['All']['MSE'] += lut_mse[idx]
            lut_metric_log['All']['fMSE'] += lut_fmse[idx]
            lut_metric_log['All']['PSNR'] += lut_psnr[idx]
            lut_metric_log['All']['SSIM'] += lut_ssim[idx]

        if (step + 1) / len(val_loader) * 100 >= current_process:
            logger.info(f'Processing: {current_process}')
            current_process += 10

    logger.info('=========================')
    for key in metric_log.keys():
        if opt.INRDecode:
            msg = f"{key}-'MSE': {metric_log[key]['MSE'] / metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'fMSE': {metric_log[key]['fMSE'] / metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'PSNR': {metric_log[key]['PSNR'] / metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'SSIM': {metric_log[key]['SSIM'] / metric_log[key]['Samples']:.4f}\n" \
                  f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
        else:
            msg = f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
                  f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"

        logger.info(msg)

    logger.info('=========================')


def main_process(opt):
    logger = misc.create_logger(os.path.join(opt.save_path, "log.txt"))
    cudnn.benchmark = True

    valset_path = os.path.join(opt.dataset_path, "IHD_test.txt")

    opt.transform_mean = [.5, .5, .5]
    opt.transform_var = [.5, .5, .5]
    torch_transform = transforms.Compose([transforms.ToTensor(),
                                          transforms.Normalize(opt.transform_mean, opt.transform_var)])

    valset_alb_transform = albumentations.Compose([Resize(opt.input_size, opt.input_size)],
                                                  additional_targets={'real_image': 'image', 'object_mask': 'image'})

    valset = dataset_generator(valset_path, valset_alb_transform, torch_transform, opt, mode='Val')

    val_loader = DataLoader(valset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
                            num_workers=opt.workers, persistent_workers=True)

    model = build_model(opt).to(opt.device)
    logger.info(f"Load pretrained weight from {opt.pretrained}")

    load_dict = torch.load(opt.pretrained)['model']
    for k in load_dict.keys():
        if k not in model.state_dict().keys():
            print(f"Skip {k}")
    model.load_state_dict(load_dict, strict=False)

    inference(val_loader, model, logger, opt)


if __name__ == '__main__':
    opt = parse_args()
    os.makedirs(opt.save_path, exist_ok=True)
    main_process(opt)
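A minimal sketch of launching this evaluation programmatically rather than via the CLI. The dataset and checkpoint paths are placeholders; it assumes an iHarmony4 layout containing IHD_test.txt, as main_process above expects:

# Hypothetical launch sketch; paths below are placeholders, not repo-fixed values.
import sys
sys.argv = ['inference.py',
            '--dataset_path', '/data/iHarmony4',
            '--pretrained', './pretrained_models/Resolution_RAW_iHarmony4.pth',
            '--batch_size', '4']
opt = parse_args()
os.makedirs(opt.save_path, exist_ok=True)
main_process(opt)   # per-dataset MSE/fMSE/PSNR/SSIM are written to <save_path>/log.txt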
inference_for_arbitrary_resolution_image.py
ADDED
@@ -0,0 +1,345 @@
import argparse

import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from model.build_model import build_model

import torch
import cv2
import numpy as np
import torchvision
import os
import tqdm
import time

from utils.misc import prepare_cooridinate_input, customRandomCrop

from datasets.build_INR_dataset import Implicit2DGenerator
import albumentations
from albumentations import Resize
from torch.utils.data import DataLoader
from utils.misc import normalize

import math

global_state = [1]  # For Gradio Stop Button.


class single_image_dataset(torch.utils.data.Dataset):
    def __init__(self, opt, composite_image=None, mask=None):
        super().__init__()

        self.opt = opt

        if composite_image is None:
            composite_image = cv2.imread(opt.composite_image)
            composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
        self.composite_image = composite_image

        if mask is None:
            mask = cv2.imread(opt.mask)
            mask = mask[:, :, 0].astype(np.float32) / 255.
        self.mask = mask

        self.torch_transforms = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize([.5, .5, .5], [.5, .5, .5])])
        self.INR_dataset = Implicit2DGenerator(opt, 'Val')

        self.split_width_resolution = composite_image.shape[1] // opt.split_num
        self.split_height_resolution = composite_image.shape[0] // opt.split_num

        self.split_width_resolution = self.split_height_resolution = min(self.split_width_resolution,
                                                                         self.split_height_resolution)

        # Round the patch size up to a multiple of 4.
        if self.split_width_resolution % 4 != 0:
            self.split_width_resolution = self.split_width_resolution + (4 - self.split_width_resolution % 4)

        if self.split_height_resolution % 4 != 0:
            self.split_height_resolution = self.split_height_resolution + (4 - self.split_height_resolution % 4)

        self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution)
        self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution)

        self.split_start_point = []

        # Split the image into several parts. Patches in the last row/column are
        # anchored at the image border so every patch stays fully inside the image.
        for i in range(self.num_h):
            for j in range(self.num_w):
                if i == composite_image.shape[0] // self.split_height_resolution:
                    if j == composite_image.shape[1] // self.split_width_resolution:
                        self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution,
                                                       composite_image.shape[1] - self.split_width_resolution))
                    else:
                        self.split_start_point.append(
                            (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution))
                else:
                    if j == composite_image.shape[1] // self.split_width_resolution:
                        self.split_start_point.append(
                            (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution))
                    else:
                        self.split_start_point.append(
                            (i * self.split_height_resolution, j * self.split_width_resolution))

        assert len(self.split_start_point) == self.num_w * self.num_h

        print(
            f"The image will be split into {self.num_h} pieces in height and {self.num_w} pieces in width, {self.num_h * self.num_w} patches in total.")
        print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}")

    def __len__(self):
        return self.num_w * self.num_h

    def __getitem__(self, idx):
        composite_image = self.composite_image

        mask = self.mask

        full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)

        tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
                                               additional_targets={'object_mask': 'image'})
        transform_out = tmp_transform(image=composite_image, object_mask=mask)
        compos_list = [self.torch_transforms(transform_out['image'])]
        mask_list = [
            torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
        coord_map_list = []

        if composite_image.shape[0] != self.split_height_resolution:
            c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution)
        else:
            c_h = 0
        if composite_image.shape[1] != self.split_width_resolution:
            c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution)
        else:
            c_w = 0
        transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord],
                                                   self.split_height_resolution, self.split_width_resolution, c_h, c_w)

        compos_list.append(self.torch_transforms(transform_out[0]))
        mask_list.append(
            torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
        coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        for n in range(2):
            tmp_comp = cv2.resize(composite_image, (
                composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
            tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
            tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)

            transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_mask, tmp_coord],
                                                       self.split_height_resolution // 2 ** (n + 1),
                                                       self.split_width_resolution // 2 ** (n + 1), c_h, c_w)
            compos_list.append(self.torch_transforms(transform_out[0]))
            mask_list.append(
                torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
            coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
        out_comp = compos_list
        out_mask = mask_list
        out_coord = coord_map_list

        fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
            self.torch_transforms, transform_out[0], transform_out[0], mask)

        return {
            'composite_image': out_comp,
            'mask': out_mask,
            'coordinate_map': out_coord,
            'composite_image0': out_comp[0],
            'mask0': out_mask[0],
            'coordinate_map0': out_coord[0],
            'composite_image1': out_comp[1],
            'mask1': out_mask[1],
            'coordinate_map1': out_coord[1],
            'composite_image2': out_comp[2],
            'mask2': out_mask[2],
            'coordinate_map2': out_coord[2],
            'composite_image3': out_comp[3],
            'mask3': out_mask[3],
            'coordinate_map3': out_coord[3],
            'fg_INR_coordinates': fg_INR_coordinates,
            'bg_INR_coordinates': bg_INR_coordinates,
            'fg_INR_RGB': fg_INR_RGB,
            'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
            'bg_INR_RGB': bg_INR_RGB,
            'start_point': self.split_start_point[idx],
        }


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--split_num', type=int, default=4,
                        help='Into how many pieces to split the image width / height.')

    parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg',
                        help='composite image path')

    parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg',
                        help='mask path')

    parser.add_argument('--save_path', type=str, default=r'./demo/',
                        help='save path')

    parser.add_argument('--workers', type=int, default=8,
                        metavar='N', help='Dataloader threads.')

    parser.add_argument('--batch_size', type=int, default=1,
                        help='You can override the model batch size by specifying a positive number.')

    parser.add_argument('--device', type=str, default='cuda',
                        help="Whether to use cuda, 'cuda' or 'cpu'.")

    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size. Resolution of the image input into the Encoder.')

    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size. Resolution of the image to be generated by the Decoder.')

    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size. Resolution of the image to be generated by the Decoder. '
                             'Should be the same as `input_size`.')

    parser.add_argument('--INR_MLP_dim', type=int, default=32,
                        help='Number of channels for the INR linear layers.')

    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')

    parser.add_argument('--activation', type=str, default='leakyrelu_pe',
                        help='INR activation layer type: leakyrelu_pe, sine')

    parser.add_argument('--pretrained', type=str,
|
| 213 |
+
default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
|
| 214 |
+
help='Pretrained weight path')
|
| 215 |
+
|
| 216 |
+
parser.add_argument('--param_factorize_dim', type=int,
|
| 217 |
+
default=10,
|
| 218 |
+
help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
|
| 219 |
+
'Refer to https://arxiv.org/abs/2011.12026')
|
| 220 |
+
|
| 221 |
+
parser.add_argument('--embedding_type', type=str,
|
| 222 |
+
default="CIPS_embed",
|
| 223 |
+
help='Which embedding_type to use.')
|
| 224 |
+
|
| 225 |
+
parser.add_argument('--INRDecode', action="store_false",
|
| 226 |
+
help='Whether INR decoder. Set it to False if you want to test the baseline '
|
| 227 |
+
'(https://github.com/SamsungLabs/image_harmonization)')
|
| 228 |
+
|
| 229 |
+
parser.add_argument('--isMoreINRInput', action="store_false",
|
| 230 |
+
help='Whether to cat RGB and mask. See Section 3.4 in the paper.')
|
| 231 |
+
|
| 232 |
+
parser.add_argument('--hr_train', action="store_false",
|
| 233 |
+
help='Whether use hr_train. See section 3.4 in the paper.')
|
| 234 |
+
|
| 235 |
+
parser.add_argument('--isFullRes', action="store_true",
|
| 236 |
+
help='Whether for original resolution. See section 3.4 in the paper.')
|
| 237 |
+
|
| 238 |
+
opt = parser.parse_args()
|
| 239 |
+
|
| 240 |
+
return opt
|
| 241 |
+
|
| 242 |
+
@torch.no_grad()
|
| 243 |
+
def inference(model, opt, composite_image=None, mask=None):
|
| 244 |
+
model.eval()
|
| 245 |
+
|
| 246 |
+
"dataset here is actually consisted of several patches of a single image."
|
| 247 |
+
singledataset = single_image_dataset(opt, composite_image, mask)
|
| 248 |
+
|
| 249 |
+
single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
|
| 250 |
+
num_workers=opt.workers, persistent_workers=False if composite_image is not None else True)
|
| 251 |
+
|
| 252 |
+
"Init a pure black image with the same size as the input image."
|
| 253 |
+
init_img = np.zeros_like(singledataset.composite_image)
|
| 254 |
+
|
| 255 |
+
time_all = 0
|
| 256 |
+
|
| 257 |
+
for step, batch in tqdm.tqdm(enumerate(single_data_loader)):
|
| 258 |
+
composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)]
|
| 259 |
+
mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)]
|
| 260 |
+
coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)]
|
| 261 |
+
start_points = batch['start_point']
|
| 262 |
+
|
| 263 |
+
if opt.batch_size == 1:
|
| 264 |
+
start_points = [torch.cat(start_points)]
|
| 265 |
+
|
| 266 |
+
fg_INR_coordinates = coordinate_map[1:]
|
| 267 |
+
|
| 268 |
+
try:
|
| 269 |
+
if global_state[0] == 0:
|
| 270 |
+
print("Stop Harmonizing...!")
|
| 271 |
+
break
|
| 272 |
+
|
| 273 |
+
if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow.
|
| 274 |
+
fg_content_bg_appearance_construct, _, lut_transform_image = model(
|
| 275 |
+
composite_image,
|
| 276 |
+
mask,
|
| 277 |
+
fg_INR_coordinates,
|
| 278 |
+
)
|
| 279 |
+
print("Ready for harmonization...")
|
| 280 |
+
|
| 281 |
+
if opt.device == "cuda":
|
| 282 |
+
torch.cuda.reset_max_memory_allocated()
|
| 283 |
+
torch.cuda.reset_max_memory_cached()
|
| 284 |
+
start_time = time.time()
|
| 285 |
+
torch.cuda.synchronize()
|
| 286 |
+
fg_content_bg_appearance_construct, _, lut_transform_image = model(
|
| 287 |
+
composite_image,
|
| 288 |
+
mask,
|
| 289 |
+
fg_INR_coordinates,
|
| 290 |
+
)
|
| 291 |
+
if opt.device == "cuda":
|
| 292 |
+
torch.cuda.synchronize()
|
| 293 |
+
end_time = time.time()
|
| 294 |
+
|
| 295 |
+
end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2
|
| 296 |
+
end_memory = torch.cuda.memory_allocated() // 1024 ** 2
|
| 297 |
+
|
| 298 |
+
print(f'GPU max memory usage: {end_max_memory} MB')
|
| 299 |
+
print(f'GPU memory usage: {end_memory} MB')
|
| 300 |
+
time_all += (end_time - start_time)
|
| 301 |
+
print(f'progress: {step} / {len(single_data_loader)}')
|
| 302 |
+
except:
|
| 303 |
+
raise Exception(
|
| 304 |
+
f'The image resolution is large. Please increase the `split_num` value. Your current set is {opt.split_num}')
|
| 305 |
+
|
| 306 |
+
"Assemble the every patch's harmonized result into the final whole image."
|
| 307 |
+
for id in range(len(fg_INR_coordinates[0])):
|
| 308 |
+
pred_fg_image = fg_content_bg_appearance_construct[-1][id]
|
| 309 |
+
pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) + composite_image[1][id] * (
|
| 310 |
+
~(mask[1][id] > 100 / 255.))
|
| 311 |
+
|
| 312 |
+
pred_harmonized_tmp = cv2.cvtColor(
|
| 313 |
+
normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_(
|
| 314 |
+
0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)
|
| 315 |
+
|
| 316 |
+
init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
|
| 317 |
+
start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
|
| 318 |
+
|
| 319 |
+
if opt.device == "cuda":
|
| 320 |
+
print(f'Inference time: {time_all}')
|
| 321 |
+
if opt.save_path is not None:
|
| 322 |
+
os.makedirs(opt.save_path, exist_ok=True)
|
| 323 |
+
cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
|
| 324 |
+
return init_img
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def main_process(opt, composite_image=None, mask=None):
|
| 328 |
+
cudnn.benchmark = True
|
| 329 |
+
# Заменяем 'gpu' на 'cuda'
|
| 330 |
+
opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 331 |
+
print("Preparing model...")
|
| 332 |
+
model = build_model(opt).to(opt.device)
|
| 333 |
+
|
| 334 |
+
load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
|
| 335 |
+
|
| 336 |
+
model.load_state_dict(load_dict, strict=False)
|
| 337 |
+
|
| 338 |
+
return inference(model, opt, composite_image, mask)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
if __name__ == '__main__':
|
| 342 |
+
opt = parse_args()
|
| 343 |
+
opt.transform_mean = [.5, .5, .5]
|
| 344 |
+
opt.transform_var = [.5, .5, .5]
|
| 345 |
+
main_process(opt)
|
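Since everything above is driven by parse_args, the script can also be invoked from Python rather than the command line; the optional composite_image / mask arguments of inference() exist so that a UI (presumably app.py in this commit) can pass arrays in directly. A minimal sketch, assuming the default demo files and the pretrained weights exist at the paths given in parse_args:

    from efficient_inference_for_square_image import parse_args, main_process

    opt = parse_args()                 # with no CLI args this just takes the defaults
    opt.transform_mean = [.5, .5, .5]  # normalization stats expected by normalize()
    opt.transform_var = [.5, .5, .5]
    harmonized = main_process(opt)     # BGR numpy array, also written to opt.save_path
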
model/__init__.py
ADDED
File without changes
model/backbone.py
ADDED
@@ -0,0 +1,79 @@
import torch.nn as nn

from .hrnetv2.hrnet_ocr import HighResolutionNet
from .hrnetv2.modifiers import LRMult
from .base.basic_blocks import MaxPoolDownSize
from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization


def build_backbone(name, opt):
    return eval(name)(opt)


class baseline(IHModelWithBackbone):
    def __init__(self, opt, ocr=64):
        base_config = {'model': DeepImageHarmonization,
                       'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}}

        params = base_config['params']

        backbone = HRNetV2(opt, ocr=ocr)

        params.update(dict(
            backbone_from=2,
            backbone_channels=backbone.output_channels,
            backbone_mode='cat',
            opt=opt
        ))
        base_model = base_config['model'](**params)

        super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt)


class HRNetV2(nn.Module):
    def __init__(
            self, opt,
            cat_outputs=True,
            pyramid_channels=-1, pyramid_depth=4,
            width=18, ocr=128, small=False,
            lr_mult=0.1, pretrained=True
    ):
        super(HRNetV2, self).__init__()
        self.opt = opt
        self.cat_outputs = cat_outputs
        self.ocr_on = ocr > 0 and cat_outputs
        self.pyramid_on = pyramid_channels > 0 and cat_outputs

        self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
        self.hrnet.apply(LRMult(lr_mult))
        if self.ocr_on:
            self.hrnet.ocr_distri_head.apply(LRMult(1.0))
            self.hrnet.ocr_gather_head.apply(LRMult(1.0))
            self.hrnet.conv3x3_ocr.apply(LRMult(1.0))

        hrnet_cat_channels = [width * 2 ** i for i in range(4)]
        if self.pyramid_on:
            self.output_channels = [pyramid_channels] * 4
        elif self.ocr_on:
            self.output_channels = [ocr * 2]
        elif self.cat_outputs:
            self.output_channels = [sum(hrnet_cat_channels)]
        else:
            self.output_channels = hrnet_cat_channels

        if self.pyramid_on:
            downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
            self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)

        if pretrained:
            self.load_pretrained_weights(
                "./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth")

        self.output_resolution = (opt.input_size // 8) ** 2

    def forward(self, image, mask, mask_features=None):
        outputs = list(self.hrnet(image, mask, mask_features))
        return outputs

    def load_pretrained_weights(self, pretrained_path):
        self.hrnet.load_pretrained_weights(pretrained_path)
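Because build_backbone resolves the class by name with eval, the "registry" is just this module's namespace. A minimal sketch of constructing the baseline model (the opt namespace is assumed to carry the argparse options used throughout the repo, e.g. input_size and param_factorize_dim):

    from model.backbone import build_backbone

    net = build_backbone('baseline', opt)
    # 'baseline' wraps DeepImageHarmonization (depth 7, image fusion) around an
    # HRNetV2-W18 backbone with OCR; by default HRNetV2 loads
    # ./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth.
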
model/base/__init__.py
ADDED
File without changes
model/base/basic_blocks.py
ADDED
@@ -0,0 +1,366 @@
import torch
from torch import nn as nn
import numpy as np


def hyper_weight_init(m, in_features_main_net, activation):
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        m.weight.data = m.weight.data / 1.e2

    if hasattr(m, 'bias'):
        with torch.no_grad():
            if activation == 'sine':
                m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30)
            elif activation == 'leakyrelu_pe':
                m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net))
            else:
                raise NotImplementedError


class ConvBlock(nn.Module):
    def __init__(
            self,
            in_channels, out_channels,
            kernel_size=4, stride=2, padding=1,
            norm_layer=nn.BatchNorm2d, activation=nn.ELU,
            bias=True,
    ):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
            activation(),
        )

    def forward(self, x):
        return self.block(x)


class MaxPoolDownSize(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, depth):
        super(MaxPoolDownSize, self).__init__()
        self.depth = depth
        self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
        self.convs = nn.ModuleList([
            ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
            for conv_i in range(depth)
        ])
        self.pool2d = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        outputs = []

        output = self.reduce_conv(x)

        for conv_i, conv in enumerate(self.convs):
            output = output if conv_i == 0 else self.pool2d(output)
            outputs.append(conv(output))

        return outputs


class convParams(nn.Module):
    def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
        super(convParams, self).__init__()
        self.INR_in_out = INR_in_out
        self.cont_split_weight = []
        self.cont_split_bias = []
        self.hidden_mlp_num = hidden_mlp_num
        self.param_factorize_dim = opt.param_factorize_dim
        output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
        self.output_dim = output_dim
        self.toRGB = toRGB
        self.cont_extraction_net = nn.Sequential(
            nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
        )

        self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))

        self.basic_params = nn.ParameterList()
        if opt.param_factorize_dim > 0:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))

            if toRGB:
                self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))

    def forward(self, feat, outMore=False):
        cont_params = self.cont_extraction_net(feat)
        out_mlp = self.to_mlp(cont_params)
        if outMore:
            return out_mlp, cont_params
        return out_mlp

    def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
        cont_params = 0
        start = 0
        if self.param_factorize_dim == -1:
            cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
            self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
            self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
            start = cont_params

            for id in range(hidden_mlp_num):
                cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
                self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
                self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
                start = cont_params

            if toRGB:
                cont_params += INR_in_out[1] * 3 + 3
                self.cont_split_weight.append([start, cont_params - 3])
                self.cont_split_bias.append([cont_params - 3, cont_params])

        elif self.param_factorize_dim > 0:
            cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
                           INR_in_out[1]
            self.cont_split_weight.append(
                [start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
            self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
            start = cont_params

            for id in range(hidden_mlp_num):
                cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
                               INR_in_out[1]
                self.cont_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
                self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
                start = cont_params

            if toRGB:
                cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
                self.cont_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
                self.cont_split_bias.append([cont_params - 3, cont_params])

        return cont_params

    def to_mlp(self, params):
        all_weight_bias = []
        if self.param_factorize_dim == -1:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
                weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
                                                                      inp, outp)

                bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
                all_weight_bias.append([weight, bias])

            if self.toRGB:
                inp, outp = self.INR_in_out[1], 3
                weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
                weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
                                                                      inp, outp)

                bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
                all_weight_bias.append([weight, bias])

            return all_weight_bias

        else:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
                weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
                                                                        inp, self.param_factorize_dim)

                weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
                weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
                                                                        self.param_factorize_dim, outp)

                bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)

                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])

            if self.toRGB:
                inp, outp = self.INR_in_out[1], 3
                weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
                weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
                                                                        inp, self.param_factorize_dim)

                weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
                weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
                                                                        self.param_factorize_dim, outp)

                bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
                bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)

                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])

            return all_weight_bias


class lineParams(nn.Module):
    def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False,
                 hidden_dim=512):
        super(lineParams, self).__init__()
        self.INR_in_out = INR_in_out
        self.app_split_weight = []
        self.app_split_bias = []
        self.toRGB = toRGB
        self.hidden_mlp_num = hidden_mlp_num
        self.param_factorize_dim = opt.param_factorize_dim
        output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
        self.output_dim = output_dim

        self.compress_layer = nn.Sequential(
            nn.Linear(input_resolution, 64, bias=False),
            nn.BatchNorm1d(input_dim),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1, bias=True)
        )

        self.app_extraction_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=False),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim, bias=True)
        )

        self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))

        self.basic_params = nn.ParameterList()
        if opt.param_factorize_dim > 0:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
            if toRGB:
                self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))

    def forward(self, feat):
        app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
        out_mlp = self.to_mlp(app_params)
        return out_mlp, app_params

    def cal_params_num(self, INR_in_out, hidden_mlp_num):
        app_params = 0
        start = 0
        if self.param_factorize_dim == -1:
            app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
            self.app_split_weight.append([start, app_params - INR_in_out[1]])
            self.app_split_bias.append([app_params - INR_in_out[1], app_params])
            start = app_params

            for id in range(hidden_mlp_num):
                app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
                self.app_split_weight.append([start, app_params - INR_in_out[1]])
                self.app_split_bias.append([app_params - INR_in_out[1], app_params])
                start = app_params

            if self.toRGB:
                app_params += INR_in_out[1] * 3 + 3
                self.app_split_weight.append([start, app_params - 3])
                self.app_split_bias.append([app_params - 3, app_params])

        elif self.param_factorize_dim > 0:
            app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
                          INR_in_out[1]
            self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim,
                                          app_params - INR_in_out[1]])
            self.app_split_bias.append([app_params - INR_in_out[1], app_params])
            start = app_params

            for id in range(hidden_mlp_num):
                app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
                              INR_in_out[1]
                self.app_split_weight.append(
                    [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
                self.app_split_bias.append([app_params - INR_in_out[1], app_params])
                start = app_params

            if self.toRGB:
                app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
                self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim,
                                              app_params - 3])
                self.app_split_bias.append([app_params - 3, app_params])

        return app_params

    def to_mlp(self, params):
        all_weight_bias = []
        if self.param_factorize_dim == -1:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight = weight.view(weight.shape[0], inp, outp)

                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)

                all_weight_bias.append([weight, bias])

            if self.toRGB:
                id = -1
                inp, outp = self.INR_in_out[1], 3
                weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight = weight.view(weight.shape[0], inp, outp)

                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)

                all_weight_bias.append([weight, bias])

            return all_weight_bias

        else:
            for id in range(self.hidden_mlp_num + 1):
                if id == 0:
                    inp, outp = self.INR_in_out[0], self.INR_in_out[1]
                else:
                    inp, outp = self.INR_in_out[1], self.INR_in_out[1]
                weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)

                weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
                weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)

                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)

                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])

            if self.toRGB:
                id = -1
                inp, outp = self.INR_in_out[1], 3
                weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
                weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)

                weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
                weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)

                bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
                bias = bias.view(bias.shape[0], 1, outp)

                all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])

            return all_weight_bias
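convParams is a hypernetwork: a small conv net predicts, for every cell of a low-resolution grid, the weights and biases of an MLP, which the decoder's mlp_process (below) then applies with torch.matmul. A minimal sketch of the shapes involved (the dimensions and feature map here are invented for illustration):

    import torch
    from types import SimpleNamespace
    from model.base.basic_blocks import convParams

    opt = SimpleNamespace(param_factorize_dim=10, activation='leakyrelu_pe')
    hyper = convParams(input_dim=64, INR_in_out=[34, 32], opt=opt, hidden_mlp_num=1)

    feat = torch.randn(2, 64, 16, 16)   # encoder feature map (stride-2 conv -> 8x8 grid)
    mlps = hyper(feat)                  # list of [weight, bias] pairs, one per MLP layer
    w0, b0 = mlps[0]                    # w0: (2, 8, 8, 34, 32); b0: (2, 8, 8, 1, 32)
    x = torch.randn(2, 8, 8, 5, 34)     # 5 coordinate samples per grid cell
    h = torch.matmul(x, w0) + b0        # one per-cell MLP layer, as in INRDecoder.mlp_process
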
model/base/conv_autoencoder.py
ADDED
@@ -0,0 +1,519 @@
import torch
import torchvision
from torch import nn as nn
import torch.nn.functional as F
import numpy as np
import math

from .basic_blocks import ConvBlock, lineParams, convParams
from .ops import MaskedChannelAttention, FeaturesConnector
from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed
from utils import misc
from utils.misc import lin2img
from ..lut_transformation_net import build_lut_transform


class Sine(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.sin(30 * input)


class Leaky_relu(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.nn.functional.leaky_relu(input, 0.01, inplace=True)


def select_activation(type):
    if type == 'sine':
        return Sine()
    elif type == 'leakyrelu_pe':
        return Leaky_relu()
    else:
        raise NotImplementedError


class ConvEncoder(nn.Module):
    def __init__(
            self,
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False
    ):
        super(ConvEncoder, self).__init__()
        self.depth = depth
        self.INRDecode = INRDecode
        self.backbone_from = backbone_from
        backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]

        in_channels = 4
        out_channels = ch

        self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
        self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
        self.blocks_channels = [out_channels, out_channels]

        self.blocks_connected = nn.ModuleDict()
        self.connectors = nn.ModuleDict()
        for block_i in range(2, depth):
            if block_i % 2:
                in_channels = out_channels
            else:
                in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)

            if 0 <= backbone_from <= block_i and len(backbone_channels):
                if INRDecode:
                    self.blocks_connected[f'block{block_i}_decode'] = ConvBlock(
                        in_channels, out_channels,
                        norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
                        padding=int(block_i < depth - 1)
                    )
                    self.blocks_channels += [out_channels]
                stage_channels = backbone_channels.pop()
                connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
                self.connectors[f'connector{block_i}'] = connector
                in_channels = connector.output_channels

            self.blocks_connected[f'block{block_i}'] = ConvBlock(
                in_channels, out_channels,
                norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
                padding=int(block_i < depth - 1)
            )
            self.blocks_channels += [out_channels]

    def forward(self, x, backbone_features):
        backbone_features = [] if backbone_features is None else backbone_features[::-1]

        outputs = [self.block0(x)]
        outputs += [self.block1(outputs[-1])]

        for block_i in range(2, self.depth):
            output = outputs[-1]
            connector_name = f'connector{block_i}'
            if connector_name in self.connectors:
                if self.INRDecode:
                    block = self.blocks_connected[f'block{block_i}_decode']
                    outputs += [block(output)]

                stage_features = backbone_features.pop()
                connector = self.connectors[connector_name]
                output = connector(output, stage_features)
            block = self.blocks_connected[f'block{block_i}']
            outputs += [block(output)]

        return outputs[::-1]


class DeconvDecoder(nn.Module):
    def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
        super(DeconvDecoder, self).__init__()
        self.image_fusion = image_fusion
        self.deconv_blocks = nn.ModuleList()

        in_channels = encoder_blocks_channels.pop()
        out_channels = in_channels
        for d in range(depth):
            out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
            self.deconv_blocks.append(SEDeconvBlock(
                in_channels, out_channels,
                norm_layer=norm_layer,
                padding=0 if d == 0 else 1,
                with_se=0 <= attend_from <= d
            ))
            in_channels = out_channels

        if self.image_fusion:
            self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
        self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)

    def forward(self, encoder_outputs, image, mask=None):
        output = encoder_outputs[0]
        for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
            output = block(output, mask)
            output = output + skip_output
        output = self.deconv_blocks[-1](output, mask)

        if self.image_fusion:
            attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
            output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
        else:
            output = self.to_rgb(output)

        return output


class SEDeconvBlock(nn.Module):
    def __init__(
            self,
            in_channels, out_channels,
            kernel_size=4, stride=2, padding=1,
            norm_layer=nn.BatchNorm2d, activation=nn.ELU,
            with_se=False
    ):
        super(SEDeconvBlock, self).__init__()
        self.with_se = with_se
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
            norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
            activation(),
        )
        if self.with_se:
            self.se = MaskedChannelAttention(out_channels)

    def forward(self, x, mask=None):
        out = self.block(x)
        if self.with_se:
            out = self.se(out, mask)
        return out


class INRDecoder(nn.Module):
    def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from):
        super(INRDecoder, self).__init__()
        self.INR_encoding = None
        if opt.embedding_type == "PosEncodingNeRF":
            self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size)
        elif opt.embedding_type == "RandomFourier":
            self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device)
        elif opt.embedding_type == "CIPS_embed":
            self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32)
        elif opt.embedding_type == "INRGAN_embed":
            self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size)
        else:
            raise NotImplementedError
        encoder_blocks_channels = encoder_blocks_channels[::-1]
        max_hidden_mlp_num = attend_from + 1
        self.opt = opt
        self.max_hidden_mlp_num = max_hidden_mlp_num
        self.content_mlp_blocks = nn.ModuleDict()
        for n in range(max_hidden_mlp_num):
            if n != max_hidden_mlp_num - 1:
                self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
                                                                  [self.INR_encoding.out_dim + opt.INR_MLP_dim + (
                                                                      4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
                                                                  opt, n + 1)
            else:
                self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
                                                                  [self.INR_encoding.out_dim + (
                                                                      4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
                                                                  opt, n + 1)

        self.deconv_blocks = nn.ModuleList()

        encoder_blocks_channels = encoder_blocks_channels[::-1]
        in_channels = encoder_blocks_channels.pop()
        out_channels = in_channels
        for d in range(depth - attend_from):
            out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
            self.deconv_blocks.append(SEDeconvBlock(
                in_channels, out_channels,
                norm_layer=norm_layer,
                padding=0 if d == 0 else 1,
                with_se=False
            ))
            in_channels = out_channels

        self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim],
                                          (opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2,
                                          opt, 2, toRGB=True)

        self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim,
                                                 None, opt)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None):
        """For full resolution, do split."""
        if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt,
                                                                                                 'split_resolution')) and self.opt.isFullRes:
            return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples)

        encoder_outputs = encoder_outputs[::-1]
        mlp_output = None
        waitToRGB = []
        for n in range(self.max_hidden_mlp_num):
            if not self.opt.hr_train:
                coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \
                    .unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
            else:
                if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'):
                    coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view(
                        encoder_outputs[0].shape[0], -1, 2)
                else:
                    coord = misc.get_mgrid(
                        self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat(
                        encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)

            """Whether to leverage multiple inputs to the INR decoder. See Section 3.4 in the paper."""
            if self.opt.isMoreINRInput:
                if not self.opt.isFullRes or (
                        self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
                    res_h = res_w = np.sqrt(coord.shape[1]).astype(int)
                else:
                    res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1))
                    res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1))

                res_image = torchvision.transforms.Resize([res_h, res_w])(image)
                res_mask = torchvision.transforms.Resize([res_h, res_w])(mask)
                coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1),
                                   res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
            else:
                coord = self.INR_encoding(coord)

            """============ LRIP structure, see Section 3.3 =============="""

            """Local MLPs."""
            if n == 0:
                mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
                                              self.opt, content_mlp=self.content_mlp_blocks[
                        f"block{self.max_hidden_mlp_num - 1 - n}"](
                        encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion)
                waitToRGB.append(mlp_output[1])
            else:
                mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
                    4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0],
                                              content_mlp=self.content_mlp_blocks[
                                                  f"block{self.max_hidden_mlp_num - 1 - n}"](
                                                  encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
                                              start_proportion=start_proportion)
                waitToRGB.append(mlp_output[1])

        encoder_outputs = encoder_outputs[::-1]
        output = encoder_outputs[0]
        for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
            output = block(output)
            output = output + skip_output
        output = self.deconv_blocks[-1](output)

        """Global MLPs."""
        app_mlp, app_params = self.appearance_mlps(output)
        harm_out = []
        for id in range(len(waitToRGB)):
            output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[id],
                                      appearance_mlp=app_mlp)
            harm_out.append(output[0])

        """Optional 3D LUT prediction."""
        fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)

        return harm_out, fit_lut3d, lut_transform_image

    def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None,
                    resolution=None, start_proportion=None):

        activation = select_activation(opt.activation)

        output = None

        if content_mlp is not None:
            if base_feat is not None:
                coorinates = torch.cat([coorinates, base_feat], dim=2)
            coorinates = lin2img(coorinates, resolution)

            if hasattr(opt, 'split_resolution'):
                """
                Here we crop the needed MLPs according to the region of the split input patches.
                Note that this only supports inference on square images.
                """
                for idx in range(len(content_mlp)):
                    content_mlp[idx][0] = content_mlp[idx][0][:,
                                          (content_mlp[idx][0].shape[1] * start_proportion[0]).int():(
                                                  content_mlp[idx][0].shape[1] * start_proportion[2]).int(),
                                          (content_mlp[idx][0].shape[2] * start_proportion[1]).int():(
                                                  content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :,
                                          :]
                    content_mlp[idx][1] = content_mlp[idx][1][:,
                                          (content_mlp[idx][1].shape[1] * start_proportion[0]).int():(
                                                  content_mlp[idx][1].shape[1] * start_proportion[2]).int(),
                                          (content_mlp[idx][1].shape[2] * start_proportion[1]).int():(
                                                  content_mlp[idx][1].shape[2] * start_proportion[3]).int(),
                                          :,
                                          :]
                k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
                k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
                bs = coorinates.shape[0]
                h_lr = w_lr = content_mlp[0][0].shape[1]
                nci = INR_input_dim

                coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
                coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
                    bs, h_lr, w_lr, int(k_h * k_w), nci)

                for id, layer in enumerate(content_mlp):
                    if id == 0:
                        output = torch.matmul(coorinates, layer[0]) + layer[1]
                        output = activation(output)
                    else:
                        output = torch.matmul(output, layer[0]) + layer[1]
                        output = activation(output)

                output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
                    0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)

                output_large = self.up(lin2img(output))

                return output_large.view(bs, -1, opt.INR_MLP_dim), output

            k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
            k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
            bs = coorinates.shape[0]
            h_lr = w_lr = content_mlp[0][0].shape[1]
            nci = INR_input_dim

            """(evaluation, or not HR training) and not full-res evaluation"""
            if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not (
                    not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train):
                coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
                coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
                    bs, h_lr, w_lr, int(k_h * k_w), nci)

                for id, layer in enumerate(content_mlp):
                    if id == 0:
                        output = torch.matmul(coorinates, layer[0]) + layer[1]
                        output = activation(output)
                    else:
                        output = torch.matmul(output, layer[0]) + layer[1]
                        output = activation(output)

                output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
                    0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)

                output_large = self.up(lin2img(output))

                return output_large.view(bs, -1, opt.INR_MLP_dim), output
            else:
                coorinates = coorinates.permute(0, 2, 3, 1)
                for id, layer in enumerate(content_mlp):
                    weight_shape = layer[0].shape
                    bias_shape = layer[1].shape
                    layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
                    layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
                    layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True
                                             else 'bilinear', padding_mode='border', align_corners=False)
                    layer[1] = F.grid_sample(layer[1], coorinates[..., :2].flip(-1), mode='nearest' if True
                                             else 'bilinear', padding_mode='border', align_corners=False)
                    layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *weight_shape[-2:])
                    layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *bias_shape[-2:])

                    if id == 0:
                        output = torch.matmul(coorinates.unsqueeze(-2), layer[0]) + layer[1]
                        output = activation(output)
                    else:
                        output = torch.matmul(output, layer[0]) + layer[1]
                        output = activation(output)

                output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim)

                output_large = self.up(lin2img(output, resolution))

                return output_large.view(bs, -1, opt.INR_MLP_dim), output

        elif appearance_mlp is not None:
            output = base_feat
            genMask = None
            for id, layer in enumerate(appearance_mlp):
                if id != len(appearance_mlp) - 1:
                    output = torch.matmul(output, layer[0]) + layer[1]
                    output = activation(output)
                else:
                    output = torch.matmul(output, layer[0]) + layer[1]  # last layer
                    if opt.activation == 'leakyrelu_pe':
                        output = torch.tanh(output)
            return lin2img(output, resolution), None

    def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
        encoder_outputs = encoder_outputs[::-1]
        mlp_output = None
        res_w = image.shape[-1]
        res_h = image.shape[-2]
        coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
            encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)

        if self.opt.isMoreINRInput:
            coord = torch.cat(
                [self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1),
                 mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
        else:
            coord = self.INR_encoding(coord, (res_h, res_w))

        total = coord.clone()

        interval = 10
        all_intervals = math.ceil(res_h / interval)
        divisible = True
        if res_h / interval != res_h // interval:
            divisible = False

        for n in range(self.max_hidden_mlp_num):
            accum_mlp_output = []
            for line in range(all_intervals):
                if not divisible and line == all_intervals - 1:
                    coord = total[:, line * interval * res_w:, :]
                else:
                    coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]
                if n == 0:
                    accum_mlp_output.append(self.mlp_process(coord,
                                                             self.INR_encoding.out_dim + (
                                                                 4 if self.opt.isMoreINRInput else 0),
                                                             self.opt, content_mlp=self.content_mlp_blocks[
                            f"block{self.max_hidden_mlp_num - 1 - n}"](
                            encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
                            encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
                                                             resolution=(interval,
                                                                         res_w) if divisible or line != all_intervals - 1 else (
                                                                 res_h - interval * (all_intervals - 1), res_w))[1])

                else:
                    accum_mlp_output.append(self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
                        4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:,
                                                                                  line * interval * res_w: (
                                                                                      line + 1) * interval * res_w,
                                                                                  :]
                    if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :],
                                                             content_mlp=self.content_mlp_blocks[
                                                                 f"block{self.max_hidden_mlp_num - 1 - n}"](
                                                                 encoder_outputs.pop(
                                                                     self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
                                                                 encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
                                                             resolution=(interval,
                                                                         res_w) if divisible or line != all_intervals - 1 else (
                                                                 res_h - interval * (all_intervals - 1), res_w))[1])

            accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
            mlp_output = [accum_mlp_output, accum_mlp_output]

        encoder_outputs = encoder_outputs[::-1]
        output = encoder_outputs[0]
        for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
            output = block(output)
            output = output + skip_output
        output = self.deconv_blocks[-1](output)

        app_mlp, app_params = self.appearance_mlps(output)
        harm_out = []

        accum_mlp_output = []
        for line in range(all_intervals):
            if not divisible and line == all_intervals - 1:
                base = mlp_output[1][:, line * interval * res_w:, :]
            else:
                base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]

            accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base,
                                                     appearance_mlp=app_mlp,
                                                     resolution=(
                                                         interval,
+
res_w) if divisible or line != all_intervals - 1 else (
|
| 512 |
+
res_h - interval * (all_intervals - 1), res_w))[0])
|
| 513 |
+
|
| 514 |
+
accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
|
| 515 |
+
harm_out.append(accum_mlp_output)
|
| 516 |
+
|
| 517 |
+
fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
|
| 518 |
+
|
| 519 |
+
return harm_out, fit_lut3d, lut_transform_image
|
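The interval-based loop in forward_fullResInference is the memory trick that makes full-resolution decoding feasible: the flattened row-major coordinate grid is evaluated `interval` rows at a time and the per-strip outputs are concatenated back together. A minimal sketch of that pattern, with a hypothetical `decode_fn` standing in for a single mlp_process call:

# Sketch only: strip-wise evaluation as used above; decode_fn is a stand-in.
import math
import torch

def eval_in_strips(coords, decode_fn, res_h, res_w, interval=10):
    # coords: (B, res_h * res_w, C), flattened row-major coordinate features
    outputs = []
    num_strips = math.ceil(res_h / interval)
    for s in range(num_strips):
        start = s * interval * res_w
        end = None if s == num_strips - 1 else (s + 1) * interval * res_w
        outputs.append(decode_fn(coords[:, start:end, :]))  # decode one band of rows
    return torch.cat(outputs, dim=1)  # reassemble the full-resolution result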
model/base/ih_model.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torchvision
import torch.nn as nn

from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder

from .ops import ScaleLayer


class IHModelWithBackbone(nn.Module):
    def __init__(
            self,
            model, backbone,
            downsize_backbone_input=False,
            mask_fusion='sum',
            backbone_conv1_channels=64, opt=None
    ):
        super(IHModelWithBackbone, self).__init__()
        self.downsize_backbone_input = downsize_backbone_input
        self.mask_fusion = mask_fusion

        self.backbone = backbone
        self.model = model
        self.opt = opt

        self.mask_conv = nn.Sequential(
            nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
            ScaleLayer(init_value=0.1, lr_mult=1)
        )

    def forward(self, image, mask, coord=None, start_proportion=None):
        # During HR training (or split inference), image/mask arrive as pairs:
        # element 0 is the full view for the backbone, element 1 the cropped view.
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
            backbone_mask = torch.cat(
                (torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
                 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
        else:
            backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
            backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
                                       1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)

        backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
        backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)

        output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
        return output


class DeepImageHarmonization(nn.Module):
    def __init__(
            self,
            depth,
            norm_layer=nn.BatchNorm2d, batchnorm_from=0,
            attend_from=-1,
            image_fusion=False,
            ch=64, max_channels=512,
            backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
    ):
        super(DeepImageHarmonization, self).__init__()
        self.depth = depth
        self.encoder = ConvEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
        )
        self.opt = opt
        if opt.INRDecode:
            # See Table 2 in the paper to test with different INR decoder structures.
            self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
        else:
            # Baseline: https://github.com/SamsungLabs/image_harmonization
            self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)

    def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
                           torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
        else:
            x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
                           torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)

        intermediates = self.encoder(x, backbone_features)

        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
        else:
            output = self.decoder(intermediates, image, mask)
        return output
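For orientation, a minimal usage sketch of the module above. The `opt` namespace here is a stand-in carrying only the fields this file reads (INRDecode, hr_train, base_size); a real run needs the full option set from the repo's config:

# Sketch only, assuming the baseline (non-INR) decoder path.
from types import SimpleNamespace
import torch

opt = SimpleNamespace(INRDecode=False, hr_train=False, base_size=256)
net = DeepImageHarmonization(depth=4, opt=opt)   # falls back to DeconvDecoder

image = torch.rand(1, 3, 256, 256)   # composite image in [0, 1]
mask = torch.rand(1, 1, 256, 256)    # foreground mask
output = net(image, mask)            # resize -> ConvEncoder -> decoder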
model/base/ops.py
ADDED
@@ -0,0 +1,397 @@
import torch
from torch import nn as nn
import numpy as np
import math
import torch.nn.functional as F


class SimpleInputFusion(nn.Module):
    def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
        super(SimpleInputFusion, self).__init__()

        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.2),
            norm_layer(ch),
            nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
        )

    def forward(self, image, additional_input):
        return self.fusion_conv(torch.cat((image, additional_input), dim=1))


class MaskedChannelAttention(nn.Module):
    def __init__(self, in_channels, *args, **kwargs):
        super(MaskedChannelAttention, self).__init__()
        self.global_max_pool = MaskedGlobalMaxPool2d()
        self.global_avg_pool = FastGlobalAvgPool2d()

        intermediate_channels_count = max(in_channels // 16, 8)
        self.attention_transform = nn.Sequential(
            nn.Linear(3 * in_channels, intermediate_channels_count),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_channels_count, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x, mask):
        # Resize the mask to the spatial size of the features before pooling.
        # (The original comparison `mask.shape[2:] != x.shape[:2]` mixed spatial
        # and batch/channel dims; the spatial dims are what must match here.)
        if mask.shape[2:] != x.shape[2:]:
            mask = nn.functional.interpolate(
                mask, size=x.size()[-2:],
                mode='bilinear', align_corners=True
            )
        pooled_x = torch.cat([
            self.global_max_pool(x, mask),
            self.global_avg_pool(x)
        ], dim=1)
        channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]

        return channel_attention_weights * x


class MaskedGlobalMaxPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.global_max_pool = FastGlobalMaxPool2d()

    def forward(self, x, mask):
        return torch.cat((
            self.global_max_pool(x * mask),
            self.global_max_pool(x * (1.0 - mask))
        ), dim=1)


class FastGlobalAvgPool2d(nn.Module):
    def __init__(self):
        super(FastGlobalAvgPool2d, self).__init__()

    def forward(self, x):
        in_size = x.size()
        return x.view((in_size[0], in_size[1], -1)).mean(dim=2)


class FastGlobalMaxPool2d(nn.Module):
    def __init__(self):
        super(FastGlobalMaxPool2d, self).__init__()

    def forward(self, x):
        in_size = x.size()
        return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]


class ScaleLayer(nn.Module):
    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        self.scale = nn.Parameter(
            torch.full((1,), init_value / lr_mult, dtype=torch.float32)
        )

    def forward(self, x):
        scale = torch.abs(self.scale * self.lr_mult)
        return x * scale
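ScaleLayer implements the usual learning-rate-multiplier trick: the parameter is stored as init_value / lr_mult and re-multiplied in forward, so the effective scale starts at init_value while gradient steps on the raw parameter are effectively amplified by lr_mult; torch.abs keeps the scale non-negative. For example:

# Sketch: the effective scale at init equals init_value regardless of lr_mult.
import torch

layer = ScaleLayer(init_value=0.1, lr_mult=10)
x = torch.ones(1, 3, 4, 4)
y = layer(x)                 # y == 0.1 * x at initialization
print(layer.scale.item())    # raw parameter is 0.01; the optimizer updates this value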


class FeaturesConnector(nn.Module):
    # Fuses backbone features into the encoder stream; mode is one of
    # 'cat' (plain concat), 'catc' (concat + 1x1 reduce), 'sum' (project + add).
    def __init__(self, mode, in_channels, feature_channels, out_channels):
        super(FeaturesConnector, self).__init__()
        self.mode = mode if feature_channels else ''

        if self.mode == 'catc':
            self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
        elif self.mode == 'sum':
            self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)

        self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels

    def forward(self, x, features):
        if self.mode == 'cat':
            return torch.cat((x, features), 1)
        if self.mode == 'catc':
            return self.reduce_conv(torch.cat((x, features), 1))
        if self.mode == 'sum':
            return self.reduce_conv(features) + x
        return x

    def extra_repr(self):
        return self.mode


class PosEncodingNeRF(nn.Module):
    # NeRF-style positional encoding; the number of frequency bands for 2-D/1-D
    # inputs is derived from the sampling rate (Nyquist) unless disabled.
    def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
        super().__init__()

        self.in_features = in_features

        if self.in_features == 3:
            self.num_frequencies = 10
        elif self.in_features == 2:
            assert sidelength is not None
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
        elif self.in_features == 1:
            assert fn_samples is not None
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)

        self.out_dim = in_features + 2 * in_features * self.num_frequencies

    def get_num_frequencies_nyquist(self, samples):
        # nyquist_rate = samples / 4, so this is floor(log2(samples / 4)).
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        return int(math.floor(math.log(nyquist_rate, 2)))

    def forward(self, coords):
        coords = coords.view(coords.shape[0], -1, self.in_features)

        coords_pos_enc = coords
        for i in range(self.num_frequencies):
            for j in range(self.in_features):
                c = coords[..., j]

                sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1)
                cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1)

                coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1)

        return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
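To make the Nyquist rule above concrete: for 2-D coordinates on a 256x256 grid, get_num_frequencies_nyquist(256) returns floor(log2(256 / 4)) = 6 frequency bands, hence out_dim = 2 + 2 * 2 * 6 = 26. A quick check:

import torch

pe = PosEncodingNeRF(in_features=2, sidelength=256)
coords = torch.rand(1, 256 * 256, 2) * 2 - 1   # coordinates normalized to [-1, 1]
enc = pe(coords)
print(pe.num_frequencies, enc.shape)           # 6  torch.Size([1, 65536, 26])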


class RandomFourier(nn.Module):
    def __init__(self, std_scale, embedding_length, device):
        super().__init__()

        self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
        self.embed = self.embed.to(device)

        self.out_dim = embedding_length * 2 + 2

    def forward(self, coords):
        coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)),
                                    torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1)

        return torch.cat([coords, coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)], dim=-1)


class CIPS_embed(nn.Module):
    def __init__(self, size, embedding_length):
        super().__init__()
        self.fourier_embed = ConstantInput(size, embedding_length)
        self.predict_embed = Predict_embed(embedding_length)
        self.out_dim = embedding_length * 2 + 2

    def forward(self, coord, res=None):
        x = self.predict_embed(coord)
        y = self.fourier_embed(x, coord, res)

        return torch.cat([coord, x, y], dim=-1)


class Predict_embed(nn.Module):
    def __init__(self, embedding_length):
        super(Predict_embed, self).__init__()
        self.ffm = nn.Linear(2, embedding_length, bias=True)
        nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))

    def forward(self, x):
        x = self.ffm(x)
        x = torch.sin(x)
        return x


class ConstantInput(nn.Module):
    def __init__(self, size, channel):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, size ** 2, channel))

    def forward(self, input, coord, resolution=None):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1)

        # When the query grid differs from the stored one, resample the learned
        # embedding at the query coordinates.
        if coord.shape[1] != self.input.shape[1]:
            x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1],
                                                       int(self.input.shape[1] ** 0.5), int(self.input.shape[1] ** 0.5))

            if resolution is None:
                grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1])
            else:
                grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])

            out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)

            out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1])

        return out


class INRGAN_embed(nn.Module):
    def __init__(self, resolution: int, w_dim=None):
        super().__init__()

        self.resolution = resolution
        self.res_cfg = {"log_emb_size": 32,
                        "random_emb_size": 32,
                        "const_emb_size": 64,
                        "use_cosine": True}
        self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
        self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
        self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
        self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
        self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
        self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
        self.use_cosine = self.res_cfg.get('use_cosine', False)

        if self.log_emb_size > 0:
            self.register_buffer('log_basis', generate_logarithmic_basis(
                resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))

        if self.random_emb_size > 0:
            self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))

        if self.shared_emb_size > 0:
            self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))

        if self.predictable_emb_size > 0:
            # Note: only reachable when 'predictable_emb_size' is configured;
            # it relies on a `self.cfg.coord_dim` that the default config never sets.
            self.W_size = self.predictable_emb_size * self.cfg.coord_dim
            self.b_size = self.predictable_emb_size
            self.affine = nn.Linear(w_dim, self.W_size + self.b_size)

        if self.const_emb_size > 0:
            self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))

        self.out_dim = self.get_total_dim() + 2

    def sample_w_matrix(self, shape, scale: float):
        return torch.randn(shape) * scale

    def get_total_dim(self) -> int:
        total_dim = 0
        if self.log_emb_size > 0:
            total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1)
        total_dim += self.random_emb_size * (2 if self.use_cosine else 1)
        total_dim += self.shared_emb_size * (2 if self.use_cosine else 1)
        total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1)
        total_dim += self.const_emb_size

        return total_dim

    def forward(self, raw_coords, w=None):
        batch_size, img_size, in_channels = raw_coords.shape

        raw_embs = []

        if self.log_emb_size > 0:
            log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
            raw_log_embs = torch.matmul(raw_coords, log_bases)
            raw_embs.append(raw_log_embs)

        if self.random_emb_size > 0:
            random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_random_embs = torch.matmul(raw_coords, random_bases)
            raw_embs.append(raw_random_embs)

        if self.shared_emb_size > 0:
            shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_shared_embs = torch.matmul(raw_coords, shared_bases)
            raw_embs.append(raw_shared_embs)

        if self.predictable_emb_size > 0:
            mod = self.affine(w)
            W = self.fourier_scale * mod[:, :self.W_size]
            W = W.view(batch_size, self.cfg.coord_dim, self.predictable_emb_size)
            bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
            raw_predictable_embs = (torch.matmul(raw_coords, W) + bias)
            raw_embs.append(raw_predictable_embs)

        if len(raw_embs) > 0:
            raw_embs = torch.cat(raw_embs, dim=-1)
            raw_embs = raw_embs.contiguous()
            out = raw_embs.sin()

            if self.use_cosine:
                out = torch.cat([out, raw_embs.cos()], dim=-1)

        if self.const_emb_size > 0:
            const_embs = self.const_embs.repeat([batch_size, 1, 1])
            out = torch.cat([out, const_embs], dim=-1)

        return torch.cat([raw_coords, out], dim=-1)


def generate_logarithmic_basis(
        resolution,
        max_num_feats,
        remove_lowest_freq: bool = False,
        use_diagonal: bool = True):
    """
    Generates a directional logarithmic basis with the following directions:
        - horizontal
        - vertical
        - main diagonal
        - anti-diagonal
    """
    max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int)
    bases = [
        generate_horizontal_basis(max_num_feats_per_direction),
        generate_vertical_basis(max_num_feats_per_direction),
    ]

    if use_diagonal:
        bases.extend([
            generate_diag_main_basis(max_num_feats_per_direction),
            generate_anti_diag_basis(max_num_feats_per_direction),
        ])

    if remove_lowest_freq:
        bases = [b[1:] for b in bases]

    # If we do not fit into `max_num_feats`, try removing features in the order:
    # 1) anti-diagonal 2) main diagonal
    # while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2):
    #     bases = bases[:-1]

    basis = torch.cat(bases, dim=0)

    # If we still do not fit, remove every second feature,
    # then every third, every fourth and so on.
    # We cannot drop the whole horizontal or vertical direction since otherwise
    # the model won't be able to locate the position
    # (unless the previously computed embeddings encode the position).
    # while basis.shape[0] > max_num_feats:
    #     num_exceeding_feats = basis.shape[0] - max_num_feats
    #     basis = basis[::2]

    assert basis.shape[0] <= max_num_feats, \
        f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."

    return basis


def generate_horizontal_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)


def generate_vertical_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)


def generate_diag_main_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))


def generate_anti_diag_basis(num_feats: int):
    return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))


def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
    period_coef = 2.0 * np.pi / period_length
    basis = torch.tensor([basis_block]).repeat(num_feats, 1)  # [num_feats, 2]
    powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1)  # [num_feats, 1]
    result = basis * powers * period_coef  # [num_feats, 2]

    return result.float()
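Each wavefront basis above stacks one direction at geometrically growing frequencies: row k equals basis_block * 2^k * (2 * pi / period_length), so every feature doubles the frequency of the previous one. For instance:

basis = generate_horizontal_basis(3)
print(basis)
# tensor([[0.0000, 1.5708],
#         [0.0000, 3.1416],
#         [0.0000, 6.2832]])   # 2*pi/4, doubled at each row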
model/build_model.py
ADDED
@@ -0,0 +1,24 @@
import torch.nn as nn
from .backbone import build_backbone


class build_model(nn.Module):
    def __init__(self, opt):
        super().__init__()

        self.opt = opt
        self.backbone = build_backbone('baseline', opt)

    def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            """
            For HR training, due to the RSC strategy designed in Section 3.4 of the paper,
            we need to pass in the coordinates of the cropped regions here.
            """
            extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates, start_proportion=start_proportion)
        else:
            extracted_features = self.backbone(composite_image, mask)

        if self.opt.INRDecode:
            return extracted_features
        return None, None, extracted_features
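A sketch of how the wrapper is driven, mirroring the two branches above; the `opt` fields and tensor names are the ones this file references, not a complete configuration:

# Sketch only.
model = build_model(opt)
if opt.INRDecode:
    # INR path: the backbone/decoder output is passed straight through.
    result = model(composite_image, mask, fg_INR_coordinates)
else:
    # Baseline path: only the last slot of the 3-tuple carries the prediction.
    _, _, features = model(composite_image, mask, fg_INR_coordinates)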
model/hrnetv2/__init__.py
ADDED
File without changes
model/hrnetv2/hrnet_ocr.py
ADDED
@@ -0,0 +1,405 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch._utils
from .ocr import SpatialOCR_Module, SpatialGather_Module
from .resnetv1b import BasicBlockV1b, BottleneckV1b

relu_inplace = True


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True,
                 norm_layer=nn.BatchNorm2d, align_corners=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.norm_layer = norm_layer
        self.align_corners = align_corners

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=relu_inplace)

    def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                self.norm_layer(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride,
                            downsample=downsample, norm_layer=self.norm_layer))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index],
                                norm_layer=self.norm_layer))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(in_channels=num_inchannels[j],
                                  out_channels=num_inchannels[i],
                                  kernel_size=1,
                                  bias=False),
                        self.norm_layer(num_inchannels[i])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          kernel_size=3, stride=2, padding=1, bias=False),
                                self.norm_layer(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          kernel_size=3, stride=2, padding=1, bias=False),
                                self.norm_layer(num_outchannels_conv3x3),
                                nn.ReLU(inplace=relu_inplace)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=self.align_corners)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


class HighResolutionNet(nn.Module):
    def __init__(self, width, num_classes, ocr_width=256, small=False,
                 norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
        super(HighResolutionNet, self).__init__()
        self.opt = opt
        self.norm_layer = norm_layer
        self.width = width
        self.ocr_width = ocr_width
        self.ocr_on = ocr_width > 0
        self.align_corners = align_corners

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = norm_layer(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = norm_layer(64)
        self.relu = nn.ReLU(inplace=relu_inplace)

        num_blocks = 2 if small else 4

        stage1_num_channels = 64
        self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
        stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels

        self.stage2_num_branches = 2
        num_channels = [width, 2 * width]
        num_inchannels = [
            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channel], num_inchannels)
        self.stage2, pre_stage_channels = self._make_stage(
            BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
            num_blocks=2 * [num_blocks], num_channels=num_channels)

        self.stage3_num_branches = 3
        num_channels = [width, 2 * width, 4 * width]
        num_inchannels = [
            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(
            pre_stage_channels, num_inchannels)
        self.stage3, pre_stage_channels = self._make_stage(
            BasicBlockV1b, num_inchannels=num_inchannels,
            num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
            num_blocks=3 * [num_blocks], num_channels=num_channels)

        self.stage4_num_branches = 4
        num_channels = [width, 2 * width, 4 * width, 8 * width]
        num_inchannels = [
            num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(
            pre_stage_channels, num_inchannels)
        self.stage4, pre_stage_channels = self._make_stage(
            BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
            num_branches=self.stage4_num_branches,
            num_blocks=4 * [num_blocks], num_channels=num_channels)

        if self.ocr_on:
            last_inp_channels = np.int_(np.sum(pre_stage_channels))
            ocr_mid_channels = 2 * ocr_width
            ocr_key_channels = ocr_width

            self.conv3x3_ocr = nn.Sequential(
                nn.Conv2d(last_inp_channels, ocr_mid_channels,
                          kernel_size=3, stride=1, padding=1),
                norm_layer(ocr_mid_channels),
                nn.ReLU(inplace=relu_inplace),
            )
            self.ocr_gather_head = SpatialGather_Module(num_classes)

            self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
                                                     key_channels=ocr_key_channels,
                                                     out_channels=ocr_mid_channels,
                                                     scale=1,
                                                     dropout=0.05,
                                                     norm_layer=norm_layer,
                                                     align_corners=align_corners, opt=opt)

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  kernel_size=3,
                                  stride=1,
                                  padding=1,
                                  bias=False),
                        self.norm_layer(num_channels_cur_layer[i]),
                        nn.ReLU(inplace=relu_inplace)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels,
                                  kernel_size=3, stride=2, padding=1, bias=False),
                        self.norm_layer(outchannels),
                        nn.ReLU(inplace=relu_inplace)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                self.norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(inplanes, planes, stride,
                            downsample=downsample, norm_layer=self.norm_layer))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes, norm_layer=self.norm_layer))

        return nn.Sequential(*layers)

    def _make_stage(self, block, num_inchannels,
                    num_modules, num_branches, num_blocks, num_channels,
                    fuse_method='SUM',
                    multi_scale_output=True):
        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used by the last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches,
                                     block,
                                     num_blocks,
                                     num_inchannels,
                                     num_channels,
                                     fuse_method,
                                     reset_multi_scale_output,
                                     norm_layer=self.norm_layer,
                                     align_corners=self.align_corners)
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x, mask=None, additional_features=None):
        hrnet_feats = self.compute_hrnet_feats(x, additional_features)
        if not self.ocr_on:
            return hrnet_feats,  # returned as a 1-tuple to match callers

        ocr_feats = self.conv3x3_ocr(hrnet_feats)
        mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
        context = self.ocr_gather_head(ocr_feats, mask)
        ocr_feats = self.ocr_distri_head(ocr_feats, context)
        return ocr_feats,

    def compute_hrnet_feats(self, x, additional_features, return_list=False):
        x = self.compute_pre_stage_features(x, additional_features)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_num_branches):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_num_branches):
            if self.transition2[i] is not None:
                if i < self.stage2_num_branches:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_num_branches):
            if self.transition3[i] is not None:
                if i < self.stage3_num_branches:
                    x_list.append(self.transition3[i](y_list[i]))
                else:
                    x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        if return_list:
            return x

        # Upsample all branches to the highest resolution and concatenate.
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(x0_h, x0_w),
                           mode='bilinear', align_corners=self.align_corners)
        x2 = F.interpolate(x[2], size=(x0_h, x0_w),
                           mode='bilinear', align_corners=self.align_corners)
        x3 = F.interpolate(x[3], size=(x0_h, x0_w),
                           mode='bilinear', align_corners=self.align_corners)

        return torch.cat([x[0], x1, x2, x3], 1)

    def compute_pre_stage_features(self, x, additional_features):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if additional_features is not None:
            x = x + additional_features
        x = self.conv2(x)
        x = self.bn2(x)
        return self.relu(x)

    def load_pretrained_weights(self, pretrained_path=''):
        model_dict = self.state_dict()

        if not os.path.exists(pretrained_path):
            print(f'\nFile "{pretrained_path}" does not exist.')
            print('You need to specify the correct path to the pre-trained weights.\n'
                  'You can download the weights for HRNet from the repository:\n'
                  'https://github.com/HRNet/HRNet-Image-Classification')
            exit(1)

        # Pick the device the model will run on
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the weights and map them onto the chosen device
        pretrained_dict = torch.load(pretrained_path, map_location=device)
        pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in pretrained_dict.items()}
        params_count = len(pretrained_dict)

        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}

        print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')

        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

        # Move the model to the device
        self.to(device)
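A minimal instantiation sketch for the network above; width=18 matches the common HRNetV2-W18 configuration, and num_classes=2 matches the two-channel fg/bg mask that ih_model.py feeds in. Sizes are illustrative:

import torch

net = HighResolutionNet(width=18, num_classes=2, ocr_width=64)
img = torch.rand(1, 3, 256, 256)
mask = torch.rand(1, 2, 64, 64)   # fg/bg probabilities; resized internally anyway
feats, = net(img, mask)           # note: forward returns a 1-tuple
print(feats.shape)                # torch.Size([1, 128, 64, 64]) == (1, 2 * ocr_width, H/4, W/4)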
model/hrnetv2/modifiers.py
ADDED
@@ -0,0 +1,11 @@
class LRMult(object):
    def __init__(self, lr_mult=1.):
        self.lr_mult = lr_mult

    def __call__(self, m):
        if getattr(m, 'weight', None) is not None:
            m.weight.lr_mult = self.lr_mult
        if getattr(m, 'bias', None) is not None:
            m.bias.lr_mult = self.lr_mult
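LRMult is meant to be used with nn.Module.apply, tagging every parameter with an lr_mult attribute; the side that consumes the tag (the optimizer builder) lives elsewhere in the repo, so the second half below is an assumed sketch:

backbone.apply(LRMult(0.1))   # tag all backbone weights and biases

# Assumed consumer: build per-parameter groups that honor the tag.
param_groups = [
    {'params': [p], 'lr': base_lr * getattr(p, 'lr_mult', 1.0)}
    for p in backbone.parameters()
]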
model/hrnetv2/ocr.py
ADDED
@@ -0,0 +1,140 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch._utils
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SpatialGather_Module(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Aggregate the context features according to the initial
|
| 10 |
+
predicted probability distribution.
|
| 11 |
+
Employ the soft-weighted method to aggregate the context.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, cls_num=0, scale=1):
|
| 15 |
+
super(SpatialGather_Module, self).__init__()
|
| 16 |
+
self.cls_num = cls_num
|
| 17 |
+
self.scale = scale
|
| 18 |
+
|
| 19 |
+
def forward(self, feats, probs):
|
| 20 |
+
batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
|
| 21 |
+
probs = probs.view(batch_size, c, -1)
|
| 22 |
+
feats = feats.view(batch_size, feats.size(1), -1)
|
| 23 |
+
feats = feats.permute(0, 2, 1) # batch x hw x c
|
| 24 |
+
probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
|
| 25 |
+
ocr_context = torch.matmul(probs, feats) \
|
| 26 |
+
.permute(0, 2, 1).unsqueeze(3).contiguous() # batch x k x c
|
| 27 |
+
return ocr_context
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SpatialOCR_Module(nn.Module):
|
| 31 |
+
"""
|
| 32 |
+
Implementation of the OCR module:
|
| 33 |
+
We aggregate the global object representation to update the representation for each pixel.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self,
|
| 37 |
+
in_channels,
|
| 38 |
+
key_channels,
|
| 39 |
+
out_channels,
|
| 40 |
+
scale=1,
|
| 41 |
+
dropout=0.1,
|
| 42 |
+
norm_layer=nn.BatchNorm2d,
|
| 43 |
+
align_corners=True, opt=None):
|
| 44 |
+
super(SpatialOCR_Module, self).__init__()
|
| 45 |
+
self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
|
| 46 |
+
norm_layer, align_corners)
|
| 47 |
+
_in_channels = 2 * in_channels
|
| 48 |
+
self.conv_bn_dropout = nn.Sequential(
|
| 49 |
+
nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
|
| 50 |
+
nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
|
| 51 |
+
nn.Dropout2d(dropout)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(self, feats, proxy_feats):
|
| 55 |
+
context = self.object_context_block(feats, proxy_feats)
|
| 56 |
+
|
| 57 |
+
output = self.conv_bn_dropout(torch.cat([context, feats], 1))
|
| 58 |
+
|
| 59 |
+
return output
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ObjectAttentionBlock2D(nn.Module):
|
| 63 |
+
'''
|
| 64 |
+
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : choose the scale to downsample the input feature maps (save memory cost)
        norm_layer : the normalization layer to use
    Return:
        N X C X H X W
    '''

    def __init__(self,
                 in_channels,
                 key_channels,
                 scale=1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(ObjectAttentionBlock2D, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.align_corners = align_corners

        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_object = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_down = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
        )
        self.f_up = nn.Sequential(
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
        )

    def forward(self, x, proxy):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)

        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            context = F.interpolate(input=context, size=(h, w),
                                    mode='bilinear', align_corners=self.align_corners)

        return context
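For orientation: the forward pass above is scaled dot-product attention from the H*W pixel positions of x to the K object-region descriptors in proxy, followed by a 1x1 projection back to in_channels. A minimal shape check (not part of the diff), assuming the class above is in scope and using hypothetical sizes (512 channels, 19 regions):

import torch

block = ObjectAttentionBlock2D(in_channels=512, key_channels=256)
x = torch.randn(2, 512, 32, 32)      # pixel features: N x C x H x W
proxy = torch.randn(2, 512, 19, 1)   # region descriptors: N x C x K x 1
out = block(x, proxy)                # each pixel attends over the K regions
print(out.shape)                     # torch.Size([2, 512, 32, 32]) -- same N x C x H x W as input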
model/hrnetv2/resnetv1b.py
ADDED
@@ -0,0 +1,276 @@
import torch
import torch.nn as nn

GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'


class BasicBlockV1b(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 previous_dilation=1, norm_layer=nn.BatchNorm2d):
        super(BasicBlockV1b, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=previous_dilation, dilation=previous_dilation, bias=False)
        self.bn2 = norm_layer(planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out


class BottleneckV1b(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
                 previous_dilation=1, norm_layer=nn.BatchNorm2d):
        super(BottleneckV1b, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = norm_layer(planes)

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out
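Both blocks share the same residual contract: whenever the output shape differs from the input, the caller must supply a downsample module to project the skip path. For BottleneckV1b that means expanding channels by expansion = 4. A small sketch (not part of the diff), assuming the classes above are in scope and picking arbitrary sizes:

import torch
import torch.nn as nn

inplanes, planes = 64, 64
# 1x1 projection so `residual` matches the planes * expansion output channels
downsample = nn.Sequential(
    nn.Conv2d(inplanes, planes * BottleneckV1b.expansion, kernel_size=1, bias=False),
    nn.BatchNorm2d(planes * BottleneckV1b.expansion),
)
block = BottleneckV1b(inplanes, planes, downsample=downsample)
x = torch.randn(2, inplanes, 56, 56)
print(block(x).shape)  # torch.Size([2, 256, 56, 56]) -- channels expanded 4x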
class ResNetV1b(nn.Module):
    """ Pre-trained ResNetV1b model, which produces stride-8 feature maps at conv5.

    Parameters
    ----------
    block : Block
        Class for the residual block. Options are BasicBlockV1b and BottleneckV1b.
    layers : list of int
        Numbers of layers in each block.
    classes : int, default 1000
        Number of classification classes.
    dilated : bool, default True
        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
        typically used in semantic segmentation.
    norm_layer : object
        Normalization layer used (default: :class:`nn.BatchNorm2d`).
    deep_stem : bool, default False
        Whether to replace the 7x7 conv1 with three 3x3 convolution layers.
    avg_down : bool, default False
        Whether to use average pooling for the projection skip connection between stages/downsample.
    final_drop : float, default 0.0
        Dropout ratio before the final classification layer.

    Reference:
        - He, Kaiming, et al. "Deep residual learning for image recognition."
          Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """
    def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
                 avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
        self.inplanes = stem_width*2 if deep_stem else 64
        super(ResNetV1b, self).__init__()
        if not deep_stem:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
                norm_layer(stem_width),
                nn.ReLU(True),
                nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
                norm_layer(stem_width),
                nn.ReLU(True),
                nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
            )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
                                       norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
                                       norm_layer=norm_layer)
        if dilated:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
                                           avg_down=avg_down, norm_layer=norm_layer)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
                                           avg_down=avg_down, norm_layer=norm_layer)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           avg_down=avg_down, norm_layer=norm_layer)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           avg_down=avg_down, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = None
        if final_drop > 0.0:
            self.drop = nn.Dropout(final_drop)
        self.fc = nn.Linear(512 * block.expansion, classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    avg_down=False, norm_layer=nn.BatchNorm2d):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = []
            if avg_down:
                if dilation == 1:
                    downsample.append(
                        nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
                    )
                else:
                    downsample.append(
                        nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
                    )
                downsample.extend([
                    nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
                              kernel_size=1, stride=1, bias=False),
                    norm_layer(planes * block.expansion)
                ])
                downsample = nn.Sequential(*downsample)
            else:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    norm_layer(planes * block.expansion)
                )

        layers = []
        if dilation in (1, 2):
            layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
                                previous_dilation=dilation, norm_layer=norm_layer))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
                                previous_dilation=dilation, norm_layer=norm_layer))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation,
                                previous_dilation=dilation, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if self.drop is not None:
            x = self.drop(x)
        x = self.fc(x)

        return x
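The dilated=True branch is what makes this a segmentation backbone: layer3 and layer4 trade their usual stride-2 for dilations of 2 and 4, so the conv5 output stays at 1/8 of the input resolution instead of 1/32. A sketch (not part of the diff), assuming the classes above are in scope:

import torch

net = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], dilated=True)
net.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)
    x = net.maxpool(net.relu(net.bn1(net.conv1(x))))        # stride 4: 56x56
    x = net.layer4(net.layer3(net.layer2(net.layer1(x))))   # dilation keeps 28x28
    print(x.shape)  # torch.Size([1, 2048, 28, 28]) -- 224 / 8 = 28, i.e. stride 8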
def _safe_state_dict_filtering(orig_dict, model_dict_keys):
    filtered_orig_dict = {}
    for k, v in orig_dict.items():
        if k in model_dict_keys:
            filtered_orig_dict[k] = v
        else:
            print(f"[ERROR] Failed to load <{k}> in backbone")
    return filtered_orig_dict


def resnet34_v1b(pretrained=False, **kwargs):
    model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        filtered_orig_dict = _safe_state_dict_filtering(
            torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
            model_dict.keys()
        )
        model_dict.update(filtered_orig_dict)
        model.load_state_dict(model_dict)
    return model


def resnet50_v1s(pretrained=False, **kwargs):
    model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        filtered_orig_dict = _safe_state_dict_filtering(
            torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
            model_dict.keys()
        )
        model_dict.update(filtered_orig_dict)
        model.load_state_dict(model_dict)
    return model


def resnet101_v1s(pretrained=False, **kwargs):
    model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        filtered_orig_dict = _safe_state_dict_filtering(
            torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
            model_dict.keys()
        )
        model_dict.update(filtered_orig_dict)
        model.load_state_dict(model_dict)
    return model


def resnet152_v1s(pretrained=False, **kwargs):
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        filtered_orig_dict = _safe_state_dict_filtering(
            torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
            model_dict.keys()
        )
        model_dict.update(filtered_orig_dict)
        model.load_state_dict(model_dict)
    return model
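These factory functions mirror the Gluon model zoo: the v1s variants use the deep 3x3 stem (deep_stem=True, stem_width=64), and _safe_state_dict_filtering drops, and reports, any checkpoint keys absent from the local model instead of raising. A hypothetical usage sketch, assuming torch.hub can reach the rwightman/pytorch-pretrained-gluonresnet repo at run time:

backbone = resnet50_v1s(pretrained=True, dilated=True)
backbone.eval()  # stride-8 feature extractor with pretrained Gluon weights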