krakotay committed on
Commit f2b4d53 · 1 Parent(s): a304ddb

first commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +9 -0
  3. LICENSE +201 -0
  4. app.py +277 -4
  5. assets/demo.gif +3 -0
  6. assets/metrics.png +0 -0
  7. assets/network.png +0 -0
  8. assets/title_any_image.gif +0 -0
  9. assets/title_harmon.gif +0 -0
  10. assets/title_you_want.gif +0 -0
  11. assets/visualizations.png +0 -0
  12. assets/visualizations2.png +3 -0
  13. datasets/__init__.py +0 -0
  14. datasets/build_INR_dataset.py +36 -0
  15. datasets/build_dataset.py +371 -0
  16. demo/demo_1k_composite_2.jpg +0 -0
  17. demo/demo_1k_composite_3.jpg +0 -0
  18. demo/demo_1k_mask_2.jpg +0 -0
  19. demo/demo_1k_mask_3.jpg +0 -0
  20. demo/demo_composite.jpg +0 -0
  21. demo/demo_composite_1.jpg +0 -0
  22. demo/demo_composite_2.jpg +0 -0
  23. demo/demo_composite_3.jpg +0 -0
  24. demo/demo_composite_4.jpg +0 -0
  25. demo/demo_composite_5.jpg +0 -0
  26. demo/demo_composite_6.jpg +0 -0
  27. demo/demo_mask.png +0 -0
  28. demo/demo_mask_1.png +0 -0
  29. demo/demo_mask_2.png +0 -0
  30. demo/demo_mask_3.png +0 -0
  31. demo/demo_mask_4.jpg +0 -0
  32. demo/demo_mask_5.jpg +0 -0
  33. demo/demo_mask_6.jpg +0 -0
  34. efficient_inference_for_square_image.py +356 -0
  35. hrnet_ocr.py +401 -0
  36. inference.py +236 -0
  37. inference_for_arbitrary_resolution_image.py +345 -0
  38. model/__init__.py +0 -0
  39. model/backbone.py +79 -0
  40. model/base/__init__.py +0 -0
  41. model/base/basic_blocks.py +366 -0
  42. model/base/conv_autoencoder.py +519 -0
  43. model/base/ih_model.py +88 -0
  44. model/base/ops.py +397 -0
  45. model/build_model.py +24 -0
  46. model/hrnetv2/__init__.py +0 -0
  47. model/hrnetv2/hrnet_ocr.py +405 -0
  48. model/hrnetv2/modifiers.py +11 -0
  49. model/hrnetv2/ocr.py +140 -0
  50. model/hrnetv2/resnetv1b.py +276 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
+ assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text
+ demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
+ .idea/*
+ logs/*
+ wandb/*
+ system/
+ *.bat
+ *.7z
+ .venv
+ __pycache__/
+ *.pyc
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
app.py CHANGED
@@ -1,7 +1,280 @@
+ import os
+
+ import cv2
+
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import numpy as np
+ import sys
+ import io
+ import spaces
+
+ class Logger:
+     def __init__(self):
+         self.terminal = sys.stdout
+         self.log = io.BytesIO()
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(bytes(message, encoding='utf-8'))
+
+     def flush(self):
+         self.terminal.flush()
+         self.log.flush()
+
+     def isatty(self):
+         return False
+
+
+ log = Logger()
+ sys.stdout = log
+
+
+ def read_logs():
+     out = log.log.getvalue().decode()
+     if out.count("\n") >= 30:
+         log.log = io.BytesIO()
+     sys.stdout.flush()
+     return out
+
+
+ with gr.Blocks() as app:
+
+     valid_checkpoints_dict = {"Resolution_256_iHarmony4": "Resolution_256_iHarmony4.pth",
+                               "Resolution_1024_HAdobe5K": "Resolution_1024_HAdobe5K.pth",
+                               "Resolution_2048_HAdobe5K": "Resolution_2048_HAdobe5K.pth",
+                               "Resolution_RAW_HAdobe5K": "Resolution_RAW_HAdobe5K.pth",
+                               "Resolution_RAW_iHarmony4": "Resolution_RAW_iHarmony4.pth"}
+
+     global_state = gr.State(valid_checkpoints_dict["Resolution_RAW_iHarmony4"])
+     with gr.Row():
+         with gr.Column():
+             form_composite_image = gr.Image(label='Input Composite image', type='pil')
+             gr.Examples(examples=sorted([os.path.join("demo", i) for i in os.listdir("demo") if "composite" in i]),
+                         label="Composite Examples", inputs=form_composite_image, cache_examples=False)
+         with gr.Column():
+             form_mask_image = gr.Image(label='Input Mask image', type='pil', interactive=False)
+             gr.Examples(examples=sorted([os.path.join("demo", i) for i in os.listdir("demo") if "mask" in i]),
+                         label="Mask Examples", inputs=form_mask_image, cache_examples=False)
+     with gr.Row():
+         with gr.Column(scale=4):
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown(value='Model Selection', show_label=False)
+
+                 with gr.Column():
+                     form_pretrained_dropdown = gr.Dropdown(
+                         choices=list(valid_checkpoints_dict.values()),
+                         label="Pretrained Model",
+                         value=valid_checkpoints_dict["Resolution_RAW_iHarmony4"],
+                         interactive=True
+                     )
+
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown(value='Inference Mode', show_label=False)
+
+                 with gr.Column():
+                     form_inference_mode = gr.Radio(
+                         ['Square Image', 'Arbitrary Image'],
+                         value='Arbitrary Image',
+                         interactive=False,
+                         label='Mode',
+                     )
+
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown(value='Split Parameter', show_label=False)
+
+                 with gr.Column():
+                     form_split_res = gr.Slider(
+                         minimum=0,
+                         maximum=2048,
+                         step=128,
+                         value=256,
+                         interactive=True,
+                         label="Split Resolution",
+                     )
+                     form_split_num = gr.Number(
+                         value=2,
+                         interactive=True,
+                         label="Split Number")
+             with gr.Row():
+                 form_log = gr.Textbox(read_logs, label="Logs", interactive=False, type="text", every=1)
+
+         with gr.Column(scale=4):
+             form_harmonized_image = gr.Image(label='Harmonized Result', type='numpy', interactive=False, format="png")
+             form_start_btn = gr.Button("Start Harmonization", interactive=False)
+             form_reset_btn = gr.Button("Reset", interactive=True)
+             form_stop_btn = gr.Button("Stop", interactive=True)
+
+
+     def on_change_form_composite_image(form_composite_image):
+         if form_composite_image is None:
+             return gr.update(interactive=False, value=None), gr.update(value=None)
+         return gr.update(interactive=True, value=None), gr.update(value=None)
+
+
+     def on_change_form_mask_image(form_composite_image, form_mask_image):
+         if form_mask_image is None:
+             return gr.update(interactive=False), gr.update(
+                 interactive=False if form_composite_image is None else True), gr.update(interactive=False), gr.update(
+                 interactive=False), gr.update(interactive=False), gr.update(value=None)
+
+         if form_composite_image.size[:2] != form_mask_image.size[:2]:
+             raise gr.Error("Composite image and mask image should have the same resolution!")
+         else:
+             w, h = form_composite_image.size[:2]
+             if h != w or (h % 16 != 0):
+                 return gr.update(value='Arbitrary Image', interactive=False), gr.update(interactive=True), gr.update(
+                     interactive=True), gr.update(interactive=True, visible=True), gr.update(interactive=False,
+                                                                                             value=-1, visible=False), gr.update(value=None)
+             else:
+                 return gr.update(value='Square Image', interactive=True), gr.update(interactive=True), gr.update(
+                     interactive=True), gr.update(interactive=False, visible=False), gr.update(interactive=True,
+                                                                                               value=h // 2,
+                                                                                               maximum=h,
+                                                                                               minimum=h // 16,
+                                                                                               step=h // 16, visible=True), gr.update(value=None)
+
+
+     form_composite_image.change(
+         on_change_form_composite_image,
+         inputs=[form_composite_image],
+         outputs=[form_mask_image, form_harmonized_image]
+     )
+
+     form_mask_image.change(
+         on_change_form_mask_image,
+         inputs=[form_composite_image, form_mask_image],
+         outputs=[form_inference_mode, form_mask_image, form_start_btn, form_split_num, form_split_res,
+                  form_harmonized_image]
+     )
+
+
+     def on_change_form_split_num(form_composite_image, form_split_num):
+         w, h = form_composite_image.size[:2]
+         if form_split_num < 1:
+             return gr.update(value=1)
+         elif form_split_num > min(w, h):
+             return gr.update(value=min(w, h))
+         else:
+             return gr.update(value=form_split_num)
+
+
+     form_split_num.change(
+         on_change_form_split_num,
+         inputs=[form_composite_image, form_split_num],
+         outputs=[form_split_num]
+     )
+
+
+     def on_change_form_inference_mode(form_inference_mode):
+         if form_inference_mode == "Square Image":
+             return gr.update(interactive=True, visible=True), gr.update(interactive=False, visible=False)
+         else:
+             return gr.update(interactive=False, visible=False), gr.update(interactive=True, visible=True)
+
+
+     form_inference_mode.change(on_change_form_inference_mode, inputs=[form_inference_mode],
+                                outputs=[form_split_res, form_split_num])
+
+     @spaces.GPU
+     def on_click_form_start_btn(form_composite_image, form_mask_image, form_pretrained_dropdown, form_inference_mode,
+                                 form_split_res, form_split_num):
+         log.log = io.BytesIO()
+         print(f"Harmonizing image with {form_composite_image.size[1]}*{form_composite_image.size[0]}...")
+         if form_inference_mode == "Square Image":
+             from efficient_inference_for_square_image import parse_args, main_process, global_state
+             global_state[0] = 1
+
+             opt = parse_args()
+             opt.transform_mean = [.5, .5, .5]
+             opt.transform_var = [.5, .5, .5]
+             opt.pretrained = os.path.join("./pretrained_models", form_pretrained_dropdown)
+             opt.split_resolution = form_split_res
+             opt.save_path = None
+             opt.workers = 0
+             opt.device = "gpu"
+
+             composite_image = np.asarray(form_composite_image)
+             mask = np.asarray(form_mask_image)
+
+             try:
+                 return cv2.cvtColor(
+                     main_process(opt, composite_image=composite_image, mask=mask),
+                     cv2.COLOR_BGR2RGB)
+             except Exception as e:
+                 raise gr.Error(f"Patches too big. Try to reduce the `split_res`!\nException is {e}")
+
+         else:
+             from inference_for_arbitrary_resolution_image import parse_args, main_process, global_state
+             global_state[0] = 1
+
+             opt = parse_args()
+             opt.transform_mean = [.5, .5, .5]
+             opt.transform_var = [.5, .5, .5]
+             opt.pretrained = os.path.join("./pretrained_models", form_pretrained_dropdown)
+             opt.split_num = int(form_split_num)
+             opt.save_path = None
+             opt.workers = 0
+             opt.device = "gpu"
+
+             composite_image = np.asarray(form_composite_image)
+             mask = np.asarray(form_mask_image)
+
+             try:
+                 return cv2.cvtColor(
+                     main_process(opt, composite_image=composite_image, mask=mask),
+                     cv2.COLOR_BGR2RGB)
+             except Exception as e:
+                 raise gr.Error(f"Patches too big. Try to increase the `split_num`!\nException is {e}")
+
+
+     generate = form_start_btn.click(on_click_form_start_btn,
+                                     inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown,
+                                             form_inference_mode,
+                                             form_split_res, form_split_num], outputs=[form_harmonized_image])
+
+
+     def on_click_form_reset_btn(form_inference_mode):
+         if form_inference_mode == "Square Image":
+             from efficient_inference_for_square_image import global_state
+             global_state[0] = 0
+         else:
+             from inference_for_arbitrary_resolution_image import global_state
+             global_state[0] = 0
+
+         log.log = io.BytesIO()
+         return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
+                                                                                           interactive=False), gr.update(
+             interactive=False)
+
+
+     form_reset_btn.click(on_click_form_reset_btn,
+                          inputs=[form_inference_mode],
+                          outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
+
+
+     def on_click_form_stop(form_inference_mode):
+         if form_inference_mode == "Square Image":
+             from efficient_inference_for_square_image import global_state
+             global_state[0] = 0
+         else:
+             from inference_for_arbitrary_resolution_image import global_state
+             global_state[0] = 0
+
+         log.log = io.BytesIO()
+         return gr.update(value=None), gr.update(value=None, interactive=True), gr.update(value=None,
+                                                                                           interactive=False), gr.update(
+             interactive=False)
+
+
+     form_stop_btn.click(on_click_form_stop,
+                         inputs=[form_inference_mode],
+                         outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)
+
+     gr.close_all()
+
+ app.queue()
+
+ app.launch(show_api=False)
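In app.py above, Logger tees every sys.stdout write into an in-memory buffer, and read_logs is polled by the form_log textbox once per second. Below is a minimal, self-contained sketch of that pattern; the names are illustrative and the snippet is not part of the commit:

import io
import sys

class TeeLogger:
    """Duplicate stdout writes into an in-memory buffer that a UI can poll."""
    def __init__(self):
        self.terminal = sys.stdout
        self.log = io.BytesIO()

    def write(self, message):
        self.terminal.write(message)              # still reach the real console
        self.log.write(message.encode("utf-8"))   # keep a copy for the poller

    def flush(self):
        self.terminal.flush()

    def isatty(self):
        return False  # some libraries probe this before writing

log = TeeLogger()
sys.stdout = log

print("harmonizing patch 1/4 ...")
# What a periodic poller (like read_logs) would hand to the UI:
print(log.log.getvalue().decode(), file=log.terminal)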
 
assets/demo.gif ADDED

Git LFS Details

  • SHA256: c5f136d5335252050ca723e0360a767ebc5d94fd87d6d372221575769d6528a7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
assets/metrics.png ADDED
assets/network.png ADDED
assets/title_any_image.gif ADDED
assets/title_harmon.gif ADDED
assets/title_you_want.gif ADDED
assets/visualizations.png ADDED
assets/visualizations2.png ADDED

Git LFS Details

  • SHA256: 0fa5f4c202818ab94d6faf57055a323285e169a33ccfd59200bc93a8d597a4a4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
datasets/__init__.py ADDED
File without changes
datasets/build_INR_dataset.py ADDED
@@ -0,0 +1,36 @@
+ from utils import misc
+ from albumentations import Resize
+
+
+ class Implicit2DGenerator(object):
+     def __init__(self, opt, mode):
+         if mode == 'Train':
+             sidelength = opt.INR_input_size
+         elif mode == 'Val':
+             sidelength = opt.input_size
+         else:
+             raise NotImplementedError
+
+         self.mode = mode
+
+         self.size = sidelength
+
+         if isinstance(sidelength, int):
+             sidelength = (sidelength, sidelength)
+
+         self.mgrid = misc.get_mgrid(sidelength)
+
+         self.transform = Resize(self.size, self.size)
+
+     def generator(self, torch_transforms, composite_image, real_image, mask):
+         composite_image = torch_transforms(self.transform(image=composite_image)['image'])
+         real_image = torch_transforms(self.transform(image=real_image)['image'])
+
+         fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3)
+         fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
+         bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
+
+         fg_INR_coordinates = self.mgrid
+         bg_INR_coordinates = self.mgrid
+
+         return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB
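utils.misc.get_mgrid is not included in this diff; Implicit2DGenerator only relies on it returning one normalized coordinate row per pixel, so the flattened coordinates line up with the (-1, 3) RGB tensors built above. A small sketch of that assumption, with a hypothetical get_mgrid (not the repository's implementation):

import torch

def get_mgrid(sidelength):
    """Hypothetical stand-in: (H*W, 2) coordinates spread uniformly over [-1, 1]."""
    h, w = sidelength
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
    return grid.reshape(-1, 2)

coords = get_mgrid((256, 256))               # (65536, 2): one coordinate per pixel
image = torch.rand(3, 256, 256)              # dummy RGB image in CHW layout
rgb = image.permute(1, 2, 0).reshape(-1, 3)  # (65536, 3): same flattening as the generator
assert coords.shape[0] == rgb.shape[0]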
datasets/build_dataset.py ADDED
@@ -0,0 +1,371 @@
+ import torch
+ import cv2
+ import numpy as np
+ import torchvision
+ import os
+ import random
+
+ from utils.misc import prepare_cooridinate_input, customRandomCrop
+
+ from datasets.build_INR_dataset import Implicit2DGenerator
+ import albumentations
+ from albumentations import Resize, RandomResizedCrop, HorizontalFlip
+ from torch.utils.data import DataLoader
+
+
+ class dataset_generator(torch.utils.data.Dataset):
+     def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
+         super().__init__()
+
+         self.opt = opt
+         self.root_path = opt.dataset_path
+         self.mode = mode
+
+         self.alb_transforms = alb_transforms
+         self.torch_transforms = torch_transforms
+         self.kp_t = area_keep_thresh
+
+         with open(dataset_txt, 'r') as f:
+             self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
+
+         self.INR_dataset = Implicit2DGenerator(opt, self.mode)
+
+     def __len__(self):
+         return len(self.dataset_samples)
+
+     def __getitem__(self, idx):
+         composite_image = self.dataset_samples[idx]
+
+         if self.opt.hr_train:
+             if self.opt.isFullRes:
+                 "Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \
+                 "quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \
+                 "if `opt.isFullRes` is set to True."
+                 composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
+
+         real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
+         mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'
+
+         composite_image = cv2.imread(composite_image)
+         composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
+
+         real_image = cv2.imread(real_image)
+         real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
+
+         mask = cv2.imread(mask)
+         mask = mask[:, :, 0].astype(np.float32) / 255.
+
+         """
+         If set `opt.hr_train` to True:
+
+         Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres),
+         the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size`
+         patch to feed in multiINR process. For inference, just resize it.
+
+         While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size.
+
+         BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5`
+         """
+         if self.opt.hr_train:
+             if self.mode == 'Train' and self.opt.isFullRes:
+                 if random.random() < 0.5:  # LR mix training
+                     mixTransform = albumentations.Compose(
+                         [
+                             RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
+                             HorizontalFlip()],
+                         additional_targets={'real_image': 'image', 'object_mask': 'image'}
+                     )
+                     origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+                     origin_bg_ratio = 1 - origin_fg_ratio
+
+                     "Ensure fg and bg not disappear after transformation"
+                     valid_augmentation = False
+                     transform_out = None
+                     time = 0
+                     while not valid_augmentation:
+                         time += 1
+                         # There are some extreme ratio pics, this code is to avoid being hindered by them.
+                         if time == 20:
+                             tmp_transform = albumentations.Compose(
+                                 [Resize(self.opt.base_size, self.opt.base_size)],
+                                 additional_targets={'real_image': 'image',
+                                                     'object_mask': 'image'})
+                             transform_out = tmp_transform(image=composite_image, real_image=real_image,
+                                                           object_mask=mask)
+                             valid_augmentation = True
+                         else:
+                             transform_out = mixTransform(image=composite_image, real_image=real_image,
+                                                          object_mask=mask)
+                             valid_augmentation = check_augmented_sample(transform_out['object_mask'],
+                                                                         origin_fg_ratio,
+                                                                         origin_bg_ratio,
+                                                                         self.kp_t)
+                     composite_image = transform_out['image']
+                     real_image = transform_out['real_image']
+                     mask = transform_out['object_mask']
+                 else:  # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop.
+                     if real_image.shape[0] < 256:
+                         bottom_pad = 256 - real_image.shape[0]
+                     else:
+                         bottom_pad = (4 - real_image.shape[0] % 4) % 4
+                     if real_image.shape[1] < 256:
+                         right_pad = 256 - real_image.shape[1]
+                     else:
+                         right_pad = (4 - real_image.shape[1] % 4) % 4
+                     composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
+                                                          cv2.BORDER_REPLICATE)
+                     real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
+                     mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
+
+         origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+         origin_bg_ratio = 1 - origin_fg_ratio
+
+         "Ensure fg and bg not disappear after transformation"
+         valid_augmentation = False
+         transform_out = None
+         time = 0
+
+         if self.opt.hr_train:
+             if self.mode == 'Train':
+                 if not self.opt.isFullRes:
+                     if random.random() < 0.5:  # LR mix training
+                         mixTransform = albumentations.Compose(
+                             [
+                                 RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
+                                 HorizontalFlip()],
+                             additional_targets={'real_image': 'image', 'object_mask': 'image'}
+                         )
+                         while not valid_augmentation:
+                             time += 1
+                             # There are some extreme ratio pics, this code is to avoid being hindered by them.
+                             if time == 20:
+                                 tmp_transform = albumentations.Compose(
+                                     [Resize(self.opt.base_size, self.opt.base_size)],
+                                     additional_targets={'real_image': 'image',
+                                                         'object_mask': 'image'})
+                                 transform_out = tmp_transform(image=composite_image, real_image=real_image,
+                                                               object_mask=mask)
+                                 valid_augmentation = True
+                             else:
+                                 transform_out = mixTransform(image=composite_image, real_image=real_image,
+                                                              object_mask=mask)
+                                 valid_augmentation = check_augmented_sample(transform_out['object_mask'],
+                                                                             origin_fg_ratio,
+                                                                             origin_bg_ratio,
+                                                                             self.kp_t)
+                     else:
+                         while not valid_augmentation:
+                             time += 1
+                             # There are some extreme ratio pics, this code is to avoid being hindered by them.
+                             if time == 20:
+                                 tmp_transform = albumentations.Compose(
+                                     [Resize(self.opt.input_size, self.opt.input_size)],
+                                     additional_targets={'real_image': 'image',
+                                                         'object_mask': 'image'})
+                                 transform_out = tmp_transform(image=composite_image, real_image=real_image,
+                                                               object_mask=mask)
+                                 valid_augmentation = True
+                             else:
+                                 transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
+                                                                     object_mask=mask)
+                                 valid_augmentation = check_augmented_sample(transform_out['object_mask'],
+                                                                             origin_fg_ratio,
+                                                                             origin_bg_ratio,
+                                                                             self.kp_t)
+                     composite_image = transform_out['image']
+                     real_image = transform_out['real_image']
+                     mask = transform_out['object_mask']
+
+                 origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+
+                 full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
+
+                 tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
+                                                        additional_targets={'real_image': 'image',
+                                                                            'object_mask': 'image'})
+                 transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
+                 compos_list = [self.torch_transforms(transform_out['image'])]
+                 real_list = [self.torch_transforms(transform_out['real_image'])]
+                 mask_list = [
+                     torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
+                 coord_map_list = []
+
+                 valid_augmentation = False
+                 while not valid_augmentation:
+                     # RSC strategy. To crop different resolutions.
+                     transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
+                                                                self.opt.base_size, self.opt.base_size)
+                     valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
+
+                 compos_list.append(self.torch_transforms(transform_out[0]))
+                 real_list.append(self.torch_transforms(transform_out[1]))
+                 mask_list.append(
+                     torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
+                 coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
+                 coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
+                 for n in range(2):
+                     tmp_comp = cv2.resize(composite_image, (
+                         composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
+                     tmp_real = cv2.resize(real_image,
+                                           (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
+                     tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
+                     tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
+
+                     transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
+                                                                self.opt.base_size // 2 ** (n + 1),
+                                                                self.opt.base_size // 2 ** (n + 1), c_h, c_w)
+                     compos_list.append(self.torch_transforms(transform_out[0]))
+                     real_list.append(self.torch_transforms(transform_out[1]))
+                     mask_list.append(
+                         torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
+                     coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
+                 out_comp = compos_list
+                 out_real = real_list
+                 out_mask = mask_list
+                 out_coord = coord_map_list
+
+                 fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+                     self.torch_transforms, transform_out[0], transform_out[1], mask)
+
+                 return {
+                     'file_path': self.dataset_samples[idx],
+                     'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+                     'composite_image': out_comp,
+                     'real_image': out_real,
+                     'mask': out_mask,
+                     'coordinate_map': out_coord,
+                     'composite_image0': out_comp[0],
+                     'real_image0': out_real[0],
+                     'mask0': out_mask[0],
+                     'coordinate_map0': out_coord[0],
+                     'composite_image1': out_comp[1],
+                     'real_image1': out_real[1],
+                     'mask1': out_mask[1],
+                     'coordinate_map1': out_coord[1],
+                     'composite_image2': out_comp[2],
+                     'real_image2': out_real[2],
+                     'mask2': out_mask[2],
+                     'coordinate_map2': out_coord[2],
+                     'composite_image3': out_comp[3],
+                     'real_image3': out_real[3],
+                     'mask3': out_mask[3],
+                     'coordinate_map3': out_coord[3],
+                     'fg_INR_coordinates': fg_INR_coordinates,
+                     'bg_INR_coordinates': bg_INR_coordinates,
+                     'fg_INR_RGB': fg_INR_RGB,
+                     'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+                     'bg_INR_RGB': bg_INR_RGB
+                 }
+             else:
+                 if not self.opt.isFullRes:
+                     tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
+                                                            additional_targets={'real_image': 'image',
+                                                                                'object_mask': 'image'})
+                     transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
+
+                     coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
+
+                     "Generate INR dataset."
+                     mask = (torchvision.transforms.ToTensor()(
+                         transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
+                     mask = np.bool_(mask.numpy())
+
+                     fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+                         self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
+
+                     return {
+                         'file_path': self.dataset_samples[idx],
+                         'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+                         'composite_image': self.torch_transforms(transform_out['image']),
+                         'real_image': self.torch_transforms(transform_out['real_image']),
+                         'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
+                         # Can automatically transfer to Tensor.
+                         'coordinate_map': coordinate_map,
+                         'fg_INR_coordinates': fg_INR_coordinates,
+                         'bg_INR_coordinates': bg_INR_coordinates,
+                         'fg_INR_RGB': fg_INR_RGB,
+                         'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+                         'bg_INR_RGB': bg_INR_RGB
+                     }
+                 else:
+                     coordinate_map = prepare_cooridinate_input(mask)
+
+                     "Generate INR dataset."
+                     mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
+                     mask_tmp = np.bool_(mask_tmp.numpy())
+
+                     fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+                         self.torch_transforms, composite_image, real_image, mask_tmp)
+
+                     return {
+                         'file_path': self.dataset_samples[idx],
+                         'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+                         'composite_image': self.torch_transforms(composite_image),
+                         'real_image': self.torch_transforms(real_image),
+                         'mask': mask[np.newaxis, ...].astype(np.float32),
+                         # Can automatically transfer to Tensor.
+                         'coordinate_map': coordinate_map,
+                         'fg_INR_coordinates': fg_INR_coordinates,
+                         'bg_INR_coordinates': bg_INR_coordinates,
+                         'fg_INR_RGB': fg_INR_RGB,
+                         'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+                         'bg_INR_RGB': bg_INR_RGB
+                     }
+
+         while not valid_augmentation:
+             time += 1
+             # There are some extreme ratio pics, this code is to avoid being hindered by them.
+             if time == 20:
+                 tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
+                                                        additional_targets={'real_image': 'image',
+                                                                            'object_mask': 'image'})
+                 transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
+                 valid_augmentation = True
+             else:
+                 transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
+                 valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
+                                                             origin_bg_ratio,
+                                                             self.kp_t)
+
+         coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
+
+         "Generate INR dataset."
+         mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
+         mask = np.bool_(mask.numpy())
+
+         fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+             self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
+
+         return {
+             'file_path': self.dataset_samples[idx],
+             'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
+             'composite_image': self.torch_transforms(transform_out['image']),
+             'real_image': self.torch_transforms(transform_out['real_image']),
+             'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
+             # Can automatically transfer to Tensor.
+             'coordinate_map': coordinate_map,
+             'fg_INR_coordinates': fg_INR_coordinates,
+             'bg_INR_coordinates': bg_INR_coordinates,
+             'fg_INR_RGB': fg_INR_RGB,
+             'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+             'bg_INR_RGB': bg_INR_RGB
+         }
+
+
+ def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
+     current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+     current_bg_ratio = 1 - current_fg_ratio
+
+     if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
+         return False
+
+     return True
+
+
+ def check_hr_crop_sample(mask, origin_fg_ratio):
+     current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
+
+     if current_fg_ratio < 0.8 * origin_fg_ratio:
+         return False
+
+     return True
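A hedged sketch of how dataset_generator is typically wired up for training: the opt fields below mirror the attributes the class reads above, while the concrete values and the split-file name are illustrative assumptions, not values fixed by this commit.

import argparse
import albumentations
from albumentations import HorizontalFlip, Resize
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from datasets.build_dataset import dataset_generator

# Assumed option values; only the attribute names are taken from the class above.
opt = argparse.Namespace(dataset_path="./iHarmony4", hr_train=False, isFullRes=False,
                         base_size=256, input_size=256, INR_input_size=256)

alb_transforms = albumentations.Compose(
    [Resize(opt.input_size, opt.input_size), HorizontalFlip()],
    additional_targets={'real_image': 'image', 'object_mask': 'image'})
torch_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([.5, .5, .5], [.5, .5, .5])])

# "IHD_train.txt" is a placeholder for a split file listing composite-image paths.
train_set = dataset_generator("IHD_train.txt", alb_transforms, torch_transforms, opt, mode='Train')
loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)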
demo/demo_1k_composite_2.jpg ADDED
demo/demo_1k_composite_3.jpg ADDED
demo/demo_1k_mask_2.jpg ADDED
demo/demo_1k_mask_3.jpg ADDED
demo/demo_composite.jpg ADDED
demo/demo_composite_1.jpg ADDED
demo/demo_composite_2.jpg ADDED
demo/demo_composite_3.jpg ADDED
demo/demo_composite_4.jpg ADDED
demo/demo_composite_5.jpg ADDED
demo/demo_composite_6.jpg ADDED
demo/demo_mask.png ADDED
demo/demo_mask_1.png ADDED
demo/demo_mask_2.png ADDED
demo/demo_mask_3.png ADDED
demo/demo_mask_4.jpg ADDED
demo/demo_mask_5.jpg ADDED
demo/demo_mask_6.jpg ADDED
efficient_inference_for_square_image.py ADDED
@@ -0,0 +1,356 @@
+ import argparse
+ import builtins
+ from collections import defaultdict
+
+ import torch.backends.cudnn as cudnn
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader
+
+ from model.build_model import build_model
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import OneCycleLR
+
+ import torch
+ import cv2
+ import numpy as np
+ import torchvision
+ import os
+ import tqdm
+ import time
+
+ from utils.misc import prepare_cooridinate_input, customRandomCrop
+
+ from datasets.build_INR_dataset import Implicit2DGenerator
+ import albumentations
+ from albumentations import Resize
+ # from torch.utils.data import DataLoader
+ from utils.misc import normalize
+
+ import math
+
+ global_state = [1]  # For Gradio Stop Button.
+
+ class single_image_dataset(torch.utils.data.Dataset):
+     def __init__(self, opt, composite_image=None, mask=None):
+         super().__init__()
+
+         self.opt = opt
+
+         if composite_image is None:
+             composite_image = cv2.imread(opt.composite_image)
+             composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
+         self.composite_image = composite_image
+
+         assert composite_image.shape[0] == composite_image.shape[1], "This faster script only supports square images."
+         assert composite_image.shape[
+                    0] % 256 == 0, "This faster script only supports images with resolution multiples of 256."
+         assert opt.split_resolution % (composite_image.shape[
+                                            0] // 16) == 0, f"The image resolution is {composite_image.shape[0]}, " \
+                                                            f"you should set {opt.split_resolution} to multiplies of {composite_image.shape[0] // 16}"
+
+         if mask is None:
+             mask = cv2.imread(opt.mask)
+             mask = mask[:, :, 0].astype(np.float32) / 255.
+         self.mask = mask
+
+         self.torch_transforms = transforms.Compose([transforms.ToTensor(),
+                                                     transforms.Normalize([.5, .5, .5], [.5, .5, .5])])
+         self.INR_dataset = Implicit2DGenerator(opt, 'Val')
+
+         self.split_width_resolution = self.split_height_resolution = opt.split_resolution
+
+         self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution)
+         self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution)
+
+         self.split_start_point = []
+
+         "Split the image into several parts."
+         for i in range(self.num_h):
+             for j in range(self.num_w):
+                 if i == composite_image.shape[0] // self.split_height_resolution:
+                     if j == composite_image.shape[1] // self.split_width_resolution:
+                         self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution,
+                                                        composite_image.shape[1] - self.split_width_resolution))
+                     else:
+                         self.split_start_point.append(
+                             (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution))
+                 else:
+                     if j == composite_image.shape[1] // self.split_width_resolution:
+                         self.split_start_point.append(
+                             (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution))
+                     else:
+                         self.split_start_point.append(
+                             (i * self.split_height_resolution, j * self.split_width_resolution))
+
+         assert len(self.split_start_point) == self.num_w * self.num_h
+
+         print(
+             f"The image will be split into {self.num_h} pieces in height, and {self.num_w} pieces in width. Totally {self.num_h * self.num_w} patches.")
+         print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}")
+
+     def __len__(self):
+         return self.num_w * self.num_h
+
+     def __getitem__(self, idx):
+         composite_image = self.composite_image
+
+         mask = self.mask
+
+         full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
+
+         tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
+                                                additional_targets={'object_mask': 'image'})
+         transform_out = tmp_transform(image=self.composite_image, object_mask=self.mask)
+         compos_list = [self.torch_transforms(transform_out['image'])]
+         mask_list = [
+             torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
+         coord_map_list = []
+
+         if composite_image.shape[0] != self.split_height_resolution:
+             c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution)
+         else:
+             c_h = 0
+         if composite_image.shape[1] != self.split_width_resolution:
+             c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution)
+         else:
+             c_w = 0
+         transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord],
+                                                    self.split_height_resolution, self.split_width_resolution, c_h, c_w)
+
+         compos_list.append(self.torch_transforms(transform_out[0]))
+         mask_list.append(
+             torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
+         coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
+         coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
+         for n in range(2):
+             tmp_comp = cv2.resize(composite_image, (
+                 composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
+             tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
+             tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
+
+             transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_mask, tmp_coord],
+                                                        self.split_height_resolution // 2 ** (n + 1),
+                                                        self.split_width_resolution // 2 ** (n + 1), c_h, c_w)
+             compos_list.append(self.torch_transforms(transform_out[0]))
+             mask_list.append(
+                 torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
+             coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
+         out_comp = compos_list
+         out_mask = mask_list
+         out_coord = coord_map_list
+
+         fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
+             self.torch_transforms, transform_out[0], transform_out[0], mask)
+
+         return {
+             'composite_image': out_comp,
+             'mask': out_mask,
+             'coordinate_map': out_coord,
+             'composite_image0': out_comp[0],
+             'mask0': out_mask[0],
+             'coordinate_map0': out_coord[0],
+             'composite_image1': out_comp[1],
+             'mask1': out_mask[1],
+             'coordinate_map1': out_coord[1],
+             'composite_image2': out_comp[2],
+             'mask2': out_mask[2],
+             'coordinate_map2': out_coord[2],
+             'composite_image3': out_comp[3],
+             'mask3': out_mask[3],
+             'coordinate_map3': out_coord[3],
+             'fg_INR_coordinates': fg_INR_coordinates,
+             'bg_INR_coordinates': bg_INR_coordinates,
+             'fg_INR_RGB': fg_INR_RGB,
+             'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
+             'bg_INR_RGB': bg_INR_RGB,
+             'start_point': self.split_start_point[idx],
+             'start_proportion': [self.split_start_point[idx][0] / (composite_image.shape[0]),
+                                  self.split_start_point[idx][1] / (composite_image.shape[1]),
+                                  (self.split_start_point[idx][0] + self.split_height_resolution) / (
+                                      composite_image.shape[0]),
+                                  (self.split_start_point[idx][1] + self.split_width_resolution) / (
+                                      composite_image.shape[1])],
+         }
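The start-point loop in __init__ above walks an evenly spaced grid and clamps the last row and column so every patch stays inside the image. The same logic as a small illustrative helper (an editor's rewrite for clarity, not code from this commit):

import math

def split_start_points(height, width, patch):
    """Top-left corners of patch-sized tiles covering an image, edge tiles clamped inward."""
    points = []
    for i in range(math.ceil(height / patch)):
        for j in range(math.ceil(width / patch)):
            top = min(i * patch, height - patch)
            left = min(j * patch, width - patch)
            points.append((top, left))
    return points

# A 1024x1024 image split at 512 gives a 2x2 grid of corners.
print(split_start_points(1024, 1024, 512))  # [(0, 0), (0, 512), (512, 0), (512, 512)]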
174
+
175
+
176
+ def parse_args():
177
+ parser = argparse.ArgumentParser()
178
+
179
+ parser.add_argument('--split_resolution', type=int, default=2048,
180
+ help='The resolution of the patch split.')
181
+
182
+ parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg',
183
+ help='composite image path')
184
+
185
+ parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg',
186
+ help='mask path')
187
+
188
+ parser.add_argument('--save_path', type=str, default=r'./demo/',
189
+ help='save path')
190
+
191
+ parser.add_argument('--workers', type=int, default=8,
192
+ metavar='N', help='Dataloader threads.')
193
+
194
+ parser.add_argument('--batch_size', type=int, default=1,
195
+ help='You can override model batch size by specify positive number.')
196
+
197
+ parser.add_argument('--device', type=str, default='cuda',
198
+ help="Whether use cuda, 'cuda' or 'cpu'.")
199
+
200
+ parser.add_argument('--base_size', type=int, default=256,
201
+ help='Base size. Resolution of the image input into the Encoder')
202
+
203
+ parser.add_argument('--input_size', type=int, default=256,
204
+ help='Input size. Resolution of the image that want to be generated by the Decoder')
205
+
206
+ parser.add_argument('--INR_input_size', type=int, default=256,
207
+ help='INR input size. Resolution of the image that want to be generated by the Decoder. '
208
+ 'Should be the same as `input_size`')
209
+
210
+ parser.add_argument('--INR_MLP_dim', type=int, default=32,
211
+ help='Number of channels for INR linear layer.')
212
+
213
+ parser.add_argument('--LUT_dim', type=int, default=7,
214
+ help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')
215
+
216
+ parser.add_argument('--activation', type=str, default='leakyrelu_pe',
217
+ help='INR activation layer type: leakyrelu_pe, sine')
218
+
219
+ parser.add_argument('--pretrained', type=str,
220
+ default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
221
+ help='Pretrained weight path')
222
+
223
+ parser.add_argument('--param_factorize_dim', type=int,
224
+ default=10,
225
+ help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
226
+ 'Refer to https://arxiv.org/abs/2011.12026')
227
+
228
+ parser.add_argument('--embedding_type', type=str,
229
+ default="CIPS_embed",
230
+ help='Which embedding_type to use.')
231
+
232
+ parser.add_argument('--INRDecode', action="store_false",
233
+ help='Whether INR decoder. Set it to False if you want to test the baseline '
234
+ '(https://github.com/SamsungLabs/image_harmonization)')
235
+
236
+ parser.add_argument('--isMoreINRInput', action="store_false",
237
+ help='Whether to cat RGB and mask. See Section 3.4 in the paper.')
238
+
239
+ parser.add_argument('--hr_train', action="store_false",
240
+ help='Whether use hr_train. See section 3.4 in the paper.')
241
+
242
+ parser.add_argument('--isFullRes', action="store_true",
243
+ help='Whether for original resolution. See section 3.4 in the paper.')
244
+
245
+ opt = parser.parse_args()
246
+
247
+ assert opt.batch_size == 1, 'This faster script only supports batch size 1 for inference.'
248
+
249
+ return opt
250
+
251
+
252
+ @torch.no_grad()
253
+ def inference(model, opt, composite_image=None, mask=None):
254
+ model.eval()
255
+
256
+ "dataset here is actually consisted of several patches of a single image."
257
+ singledataset = single_image_dataset(opt, composite_image, mask)
258
+
259
+ single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
260
+ num_workers=opt.workers, persistent_workers=False if composite_image is not None else True)
261
+
262
+ "Init a pure black image with the same size as the input image."
263
+ init_img = np.zeros_like(singledataset.composite_image)
264
+
265
+ time_all = 0
266
+
267
+ for step, batch in tqdm.tqdm(enumerate(single_data_loader)):
268
+ composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)]
269
+ mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)]
270
+ coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)]
271
+ start_points = batch['start_point']
272
+ start_proportion = batch['start_proportion']
273
+
274
+ if opt.batch_size == 1:
275
+ start_points = [torch.cat(start_points)]
276
+ start_proportion = [torch.cat(start_proportion)]
277
+
278
+ fg_INR_coordinates = coordinate_map[1:]
279
+
280
+ try:
281
+ if global_state[0] == 0:
282
+ print("Stop Harmonizing...!")
283
+ break
284
+
285
+ if step == 0:  # CUDA kernel warm-up: otherwise the first inference step would be quite slow.
286
+ fg_content_bg_appearance_construct, _, lut_transform_image = model(
287
+ composite_image,
288
+ mask,
289
+ fg_INR_coordinates, start_proportion[0]
290
+ )
291
+ print("Ready for harmonization...")
292
+ if opt.device == "cuda":
293
+ torch.cuda.reset_max_memory_allocated()
294
+ torch.cuda.reset_max_memory_cached()
295
+ start_time = time.time()
296
+ torch.cuda.synchronize()
297
+ fg_content_bg_appearance_construct, _, lut_transform_image = model(
298
+ composite_image,
299
+ mask,
300
+ fg_INR_coordinates, start_proportion[0]
301
+ )
302
+ if opt.device == "cuda":
303
+ torch.cuda.synchronize()
304
+ end_time = time.time()
305
+
306
+ end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2
307
+ end_memory = torch.cuda.memory_allocated() // 1024 ** 2
308
+
309
+ print(f'GPU max memory usage: {end_max_memory} MB')
310
+ print(f'GPU memory usage: {end_memory} MB')
311
+ time_all += (end_time - start_time)
312
+ print(f'progress: {step} / {len(single_data_loader)}')
313
+ except:
314
+ raise Exception(
315
+ f'The image resolution is too large. Please reduce the `split_resolution` value. Your current setting is {opt.split_resolution}')
316
+
317
+ "Assemble the every patch's harmonized result into the final whole image."
318
+ for id in range(len(fg_INR_coordinates[0])):
319
+ pred_fg_image = fg_content_bg_appearance_construct[-1][id]
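+ # Binarize the mask at 100/255: the predicted harmonized foreground fills the masked region, the original composite fills the rest.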
320
+ pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) + composite_image[1][id] * (
321
+ ~(mask[1][id] > 100 / 255.))
322
+
323
+ pred_harmonized_tmp = cv2.cvtColor(
324
+ normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_(
325
+ 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)
326
+
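+ # Paste this patch's result back into the full-resolution canvas at its recorded top-left start point.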
327
+ init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
328
+ start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
329
+
330
+ if opt.device == "cuda":
331
+ print(f'Inference time: {time_all}')
332
+ if opt.save_path is not None:
333
+ os.makedirs(opt.save_path, exist_ok=True)
334
+ cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
335
+ return init_img
336
+
337
+ def main_process(opt, composite_image=None, mask=None):
338
+ # torch.serialization.add_safe_globals([getattr, OneCycleLR, AdamW, defaultdict, builtins.dict])
339
+ cudnn.benchmark = True
340
+ opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
341
+ print("Preparing model...")
342
+ model = build_model(opt).to(opt.device)
343
+
344
+ # Replace 'gpu' with 'cuda' and add weights_only=True
345
+ load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
346
+
347
+ model.load_state_dict(load_dict, strict=False)
348
+
349
+ return inference(model, opt, composite_image, mask)
350
+
351
+
352
+ if __name__ == '__main__':
353
+ opt = parse_args()
354
+ opt.transform_mean = [.5, .5, .5]
355
+ opt.transform_var = [.5, .5, .5]
356
+ main_process(opt)
hrnet_ocr.py ADDED
@@ -0,0 +1,401 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch._utils
7
+
8
+ from .ocr import SpatialOCR_Module, SpatialGather_Module
9
+ from .resnetv1b import BasicBlockV1b, BottleneckV1b
10
+
11
+ relu_inplace = True
12
+
13
+
14
+ class HighResolutionModule(nn.Module):
15
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
16
+ num_channels, fuse_method, multi_scale_output=True,
17
+ norm_layer=nn.BatchNorm2d, align_corners=True):
18
+ super(HighResolutionModule, self).__init__()
19
+ self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
20
+
21
+ self.num_inchannels = num_inchannels
22
+ self.fuse_method = fuse_method
23
+ self.num_branches = num_branches
24
+ self.norm_layer = norm_layer
25
+ self.align_corners = align_corners
26
+
27
+ self.multi_scale_output = multi_scale_output
28
+
29
+ self.branches = self._make_branches(
30
+ num_branches, blocks, num_blocks, num_channels)
31
+ self.fuse_layers = self._make_fuse_layers()
32
+ self.relu = nn.ReLU(inplace=relu_inplace)
33
+
34
+ def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
35
+ if num_branches != len(num_blocks):
36
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
37
+ num_branches, len(num_blocks))
38
+ raise ValueError(error_msg)
39
+
40
+ if num_branches != len(num_channels):
41
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
42
+ num_branches, len(num_channels))
43
+ raise ValueError(error_msg)
44
+
45
+ if num_branches != len(num_inchannels):
46
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
47
+ num_branches, len(num_inchannels))
48
+ raise ValueError(error_msg)
49
+
50
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
51
+ stride=1):
52
+ downsample = None
53
+ if stride != 1 or \
54
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
55
+ downsample = nn.Sequential(
56
+ nn.Conv2d(self.num_inchannels[branch_index],
57
+ num_channels[branch_index] * block.expansion,
58
+ kernel_size=1, stride=stride, bias=False),
59
+ self.norm_layer(num_channels[branch_index] * block.expansion),
60
+ )
61
+
62
+ layers = []
63
+ layers.append(block(self.num_inchannels[branch_index],
64
+ num_channels[branch_index], stride,
65
+ downsample=downsample, norm_layer=self.norm_layer))
66
+ self.num_inchannels[branch_index] = \
67
+ num_channels[branch_index] * block.expansion
68
+ for i in range(1, num_blocks[branch_index]):
69
+ layers.append(block(self.num_inchannels[branch_index],
70
+ num_channels[branch_index],
71
+ norm_layer=self.norm_layer))
72
+
73
+ return nn.Sequential(*layers)
74
+
75
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
76
+ branches = []
77
+
78
+ for i in range(num_branches):
79
+ branches.append(
80
+ self._make_one_branch(i, block, num_blocks, num_channels))
81
+
82
+ return nn.ModuleList(branches)
83
+
84
+ def _make_fuse_layers(self):
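+ # Build the cross-resolution fusion layers: a 1x1 conv (plus upsampling in forward) brings coarser branches up to finer ones, and stacks of strided 3x3 convs bring finer branches down to coarser ones.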
85
+ if self.num_branches == 1:
86
+ return None
87
+
88
+ num_branches = self.num_branches
89
+ num_inchannels = self.num_inchannels
90
+ fuse_layers = []
91
+ for i in range(num_branches if self.multi_scale_output else 1):
92
+ fuse_layer = []
93
+ for j in range(num_branches):
94
+ if j > i:
95
+ fuse_layer.append(nn.Sequential(
96
+ nn.Conv2d(in_channels=num_inchannels[j],
97
+ out_channels=num_inchannels[i],
98
+ kernel_size=1,
99
+ bias=False),
100
+ self.norm_layer(num_inchannels[i])))
101
+ elif j == i:
102
+ fuse_layer.append(None)
103
+ else:
104
+ conv3x3s = []
105
+ for k in range(i - j):
106
+ if k == i - j - 1:
107
+ num_outchannels_conv3x3 = num_inchannels[i]
108
+ conv3x3s.append(nn.Sequential(
109
+ nn.Conv2d(num_inchannels[j],
110
+ num_outchannels_conv3x3,
111
+ kernel_size=3, stride=2, padding=1, bias=False),
112
+ self.norm_layer(num_outchannels_conv3x3)))
113
+ else:
114
+ num_outchannels_conv3x3 = num_inchannels[j]
115
+ conv3x3s.append(nn.Sequential(
116
+ nn.Conv2d(num_inchannels[j],
117
+ num_outchannels_conv3x3,
118
+ kernel_size=3, stride=2, padding=1, bias=False),
119
+ self.norm_layer(num_outchannels_conv3x3),
120
+ nn.ReLU(inplace=relu_inplace)))
121
+ fuse_layer.append(nn.Sequential(*conv3x3s))
122
+ fuse_layers.append(nn.ModuleList(fuse_layer))
123
+
124
+ return nn.ModuleList(fuse_layers)
125
+
126
+ def get_num_inchannels(self):
127
+ return self.num_inchannels
128
+
129
+ def forward(self, x):
130
+ if self.num_branches == 1:
131
+ return [self.branches[0](x[0])]
132
+
133
+ for i in range(self.num_branches):
134
+ x[i] = self.branches[i](x[i])
135
+
136
+ x_fuse = []
137
+ for i in range(len(self.fuse_layers)):
138
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
139
+ for j in range(1, self.num_branches):
140
+ if i == j:
141
+ y = y + x[j]
142
+ elif j > i:
143
+ width_output = x[i].shape[-1]
144
+ height_output = x[i].shape[-2]
145
+ y = y + F.interpolate(
146
+ self.fuse_layers[i][j](x[j]),
147
+ size=[height_output, width_output],
148
+ mode='bilinear', align_corners=self.align_corners)
149
+ else:
150
+ y = y + self.fuse_layers[i][j](x[j])
151
+ x_fuse.append(self.relu(y))
152
+
153
+ return x_fuse
154
+
155
+
156
+ class HighResolutionNet(nn.Module):
157
+ def __init__(self, width, num_classes, ocr_width=256, small=False,
158
+ norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
159
+ super(HighResolutionNet, self).__init__()
160
+ self.opt = opt
161
+ self.norm_layer = norm_layer
162
+ self.width = width
163
+ self.ocr_width = ocr_width
164
+ self.ocr_on = ocr_width > 0
165
+ self.align_corners = align_corners
166
+
167
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
168
+ self.bn1 = norm_layer(64)
169
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
170
+ self.bn2 = norm_layer(64)
171
+ self.relu = nn.ReLU(inplace=relu_inplace)
172
+
173
+ num_blocks = 2 if small else 4
174
+
175
+ stage1_num_channels = 64
176
+ self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
177
+ stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
178
+
179
+ self.stage2_num_branches = 2
180
+ num_channels = [width, 2 * width]
181
+ num_inchannels = [
182
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
183
+ self.transition1 = self._make_transition_layer(
184
+ [stage1_out_channel], num_inchannels)
185
+ self.stage2, pre_stage_channels = self._make_stage(
186
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
187
+ num_blocks=2 * [num_blocks], num_channels=num_channels)
188
+
189
+ self.stage3_num_branches = 3
190
+ num_channels = [width, 2 * width, 4 * width]
191
+ num_inchannels = [
192
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
193
+ self.transition2 = self._make_transition_layer(
194
+ pre_stage_channels, num_inchannels)
195
+ self.stage3, pre_stage_channels = self._make_stage(
196
+ BasicBlockV1b, num_inchannels=num_inchannels,
197
+ num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
198
+ num_blocks=3 * [num_blocks], num_channels=num_channels)
199
+
200
+ self.stage4_num_branches = 4
201
+ num_channels = [width, 2 * width, 4 * width, 8 * width]
202
+ num_inchannels = [
203
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
204
+ self.transition3 = self._make_transition_layer(
205
+ pre_stage_channels, num_inchannels)
206
+ self.stage4, pre_stage_channels = self._make_stage(
207
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
208
+ num_branches=self.stage4_num_branches,
209
+ num_blocks=4 * [num_blocks], num_channels=num_channels)
210
+
211
+ if self.ocr_on:
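+ # OCR (Object-Contextual Representations) head: gathers region context from the fused multi-resolution features and redistributes it spatially (see SpatialGather_Module / SpatialOCR_Module).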
212
+ last_inp_channels = int(np.sum(pre_stage_channels))  # np.int was removed in NumPy 1.20+; use the builtin int
213
+ ocr_mid_channels = 2 * ocr_width
214
+ ocr_key_channels = ocr_width
215
+
216
+ self.conv3x3_ocr = nn.Sequential(
217
+ nn.Conv2d(last_inp_channels, ocr_mid_channels,
218
+ kernel_size=3, stride=1, padding=1),
219
+ norm_layer(ocr_mid_channels),
220
+ nn.ReLU(inplace=relu_inplace),
221
+ )
222
+ self.ocr_gather_head = SpatialGather_Module(num_classes)
223
+
224
+ self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
225
+ key_channels=ocr_key_channels,
226
+ out_channels=ocr_mid_channels,
227
+ scale=1,
228
+ dropout=0.05,
229
+ norm_layer=norm_layer,
230
+ align_corners=align_corners, opt=opt)
231
+
232
+ def _make_transition_layer(
233
+ self, num_channels_pre_layer, num_channels_cur_layer):
234
+ num_branches_cur = len(num_channels_cur_layer)
235
+ num_branches_pre = len(num_channels_pre_layer)
236
+
237
+ transition_layers = []
238
+ for i in range(num_branches_cur):
239
+ if i < num_branches_pre:
240
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
241
+ transition_layers.append(nn.Sequential(
242
+ nn.Conv2d(num_channels_pre_layer[i],
243
+ num_channels_cur_layer[i],
244
+ kernel_size=3,
245
+ stride=1,
246
+ padding=1,
247
+ bias=False),
248
+ self.norm_layer(num_channels_cur_layer[i]),
249
+ nn.ReLU(inplace=relu_inplace)))
250
+ else:
251
+ transition_layers.append(None)
252
+ else:
253
+ conv3x3s = []
254
+ for j in range(i + 1 - num_branches_pre):
255
+ inchannels = num_channels_pre_layer[-1]
256
+ outchannels = num_channels_cur_layer[i] \
257
+ if j == i - num_branches_pre else inchannels
258
+ conv3x3s.append(nn.Sequential(
259
+ nn.Conv2d(inchannels, outchannels,
260
+ kernel_size=3, stride=2, padding=1, bias=False),
261
+ self.norm_layer(outchannels),
262
+ nn.ReLU(inplace=relu_inplace)))
263
+ transition_layers.append(nn.Sequential(*conv3x3s))
264
+
265
+ return nn.ModuleList(transition_layers)
266
+
267
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
268
+ downsample = None
269
+ if stride != 1 or inplanes != planes * block.expansion:
270
+ downsample = nn.Sequential(
271
+ nn.Conv2d(inplanes, planes * block.expansion,
272
+ kernel_size=1, stride=stride, bias=False),
273
+ self.norm_layer(planes * block.expansion),
274
+ )
275
+
276
+ layers = []
277
+ layers.append(block(inplanes, planes, stride,
278
+ downsample=downsample, norm_layer=self.norm_layer))
279
+ inplanes = planes * block.expansion
280
+ for i in range(1, blocks):
281
+ layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
282
+
283
+ return nn.Sequential(*layers)
284
+
285
+ def _make_stage(self, block, num_inchannels,
286
+ num_modules, num_branches, num_blocks, num_channels,
287
+ fuse_method='SUM',
288
+ multi_scale_output=True):
289
+ modules = []
290
+ for i in range(num_modules):
291
+ # multi_scale_output is only used by the last module
292
+ if not multi_scale_output and i == num_modules - 1:
293
+ reset_multi_scale_output = False
294
+ else:
295
+ reset_multi_scale_output = True
296
+ modules.append(
297
+ HighResolutionModule(num_branches,
298
+ block,
299
+ num_blocks,
300
+ num_inchannels,
301
+ num_channels,
302
+ fuse_method,
303
+ reset_multi_scale_output,
304
+ norm_layer=self.norm_layer,
305
+ align_corners=self.align_corners)
306
+ )
307
+ num_inchannels = modules[-1].get_num_inchannels()
308
+
309
+ return nn.Sequential(*modules), num_inchannels
310
+
311
+ def forward(self, x, mask=None, additional_features=None):
312
+ hrnet_feats = self.compute_hrnet_feats(x, additional_features)
313
+ if not self.ocr_on:
314
+ return hrnet_feats,
315
+
316
+ ocr_feats = self.conv3x3_ocr(hrnet_feats)
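+ # Resize the mask to the OCR feature resolution and use it as the region map: gather a per-region context vector, then distribute it back over the feature map.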
317
+ mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
318
+ context = self.ocr_gather_head(ocr_feats, mask)
319
+ ocr_feats = self.ocr_distri_head(ocr_feats, context)
320
+ return ocr_feats,
321
+
322
+ def compute_hrnet_feats(self, x, additional_features, return_list=False):
323
+ x = self.compute_pre_stage_features(x, additional_features)
324
+ x = self.layer1(x)
325
+
326
+ x_list = []
327
+ for i in range(self.stage2_num_branches):
328
+ if self.transition1[i] is not None:
329
+ x_list.append(self.transition1[i](x))
330
+ else:
331
+ x_list.append(x)
332
+ y_list = self.stage2(x_list)
333
+
334
+ x_list = []
335
+ for i in range(self.stage3_num_branches):
336
+ if self.transition2[i] is not None:
337
+ if i < self.stage2_num_branches:
338
+ x_list.append(self.transition2[i](y_list[i]))
339
+ else:
340
+ x_list.append(self.transition2[i](y_list[-1]))
341
+ else:
342
+ x_list.append(y_list[i])
343
+ y_list = self.stage3(x_list)
344
+
345
+ x_list = []
346
+ for i in range(self.stage4_num_branches):
347
+ if self.transition3[i] is not None:
348
+ if i < self.stage3_num_branches:
349
+ x_list.append(self.transition3[i](y_list[i]))
350
+ else:
351
+ x_list.append(self.transition3[i](y_list[-1]))
352
+ else:
353
+ x_list.append(y_list[i])
354
+ x = self.stage4(x_list)
355
+
356
+ if return_list:
357
+ return x
358
+
359
+ # Upsampling
360
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
361
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w),
362
+ mode='bilinear', align_corners=self.align_corners)
363
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w),
364
+ mode='bilinear', align_corners=self.align_corners)
365
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w),
366
+ mode='bilinear', align_corners=self.align_corners)
367
+
368
+ return torch.cat([x[0], x1, x2, x3], 1)
369
+
370
+ def compute_pre_stage_features(self, x, additional_features):
371
+ x = self.conv1(x)
372
+ x = self.bn1(x)
373
+ x = self.relu(x)
374
+ if additional_features is not None:
375
+ x = x + additional_features
376
+ x = self.conv2(x)
377
+ x = self.bn2(x)
378
+ return self.relu(x)
379
+
380
+ def load_pretrained_weights(self, pretrained_path=''):
381
+ model_dict = self.state_dict()
382
+
383
+ if not os.path.exists(pretrained_path):
384
+ print(f'\nFile "{pretrained_path}" does not exist.')
385
+ print('You need to specify the correct path to the pre-trained weights.\n'
386
+ 'You can download the weights for HRNet from the repository:\n'
387
+ 'https://github.com/HRNet/HRNet-Image-Classification')
388
+ exit(1)
389
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
390
+ pretrained_dict = torch.load(pretrained_path, map_location=device)
391
+ pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
392
+ pretrained_dict.items()}
393
+ params_count = len(pretrained_dict)
394
+
395
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
396
+ if k in model_dict.keys()}
397
+
398
+ # print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')
399
+
400
+ model_dict.update(pretrained_dict)
401
+ self.load_state_dict(model_dict)
inference.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ import argparse
3
+
4
+ import albumentations
5
+ from albumentations import Resize
6
+
7
+ import torch
8
+ import torch.backends.cudnn as cudnn
9
+ import torchvision.transforms as transforms
10
+ from torch.utils.data import DataLoader
11
+
12
+ from model.build_model import build_model
13
+ from datasets.build_dataset import dataset_generator
14
+
15
+ from utils import misc, metrics
16
+
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser()
20
+
21
+ parser.add_argument('--workers', type=int, default=1,
22
+ metavar='N', help='Dataloader threads.')
23
+
24
+ parser.add_argument('--batch_size', type=int, default=1,
25
+ help='You can override the model batch size by specifying a positive number.')
26
+
27
+ parser.add_argument('--device', type=str, default='cuda',
28
+ help="Whether use cuda, 'cuda' or 'cpu'.")
29
+
30
+ parser.add_argument('--save_path', type=str, default="./logs",
31
+ help='Where to save logs and checkpoints.')
32
+
33
+ parser.add_argument('--dataset_path', type=str, default=r".\iHarmony4",
34
+ help='Dataset path.')
35
+
36
+ parser.add_argument('--base_size', type=int, default=256,
37
+ help='Base size. Resolution of the image input into the Encoder')
38
+
39
+ parser.add_argument('--input_size', type=int, default=256,
40
+ help='Input size. Resolution of the image to be generated by the Decoder.')
41
+
42
+ parser.add_argument('--INR_input_size', type=int, default=256,
43
+ help='INR input size. Resolution of the image to be generated by the Decoder. '
44
+ 'Should be the same as `input_size`')
45
+
46
+ parser.add_argument('--INR_MLP_dim', type=int, default=32,
47
+ help='Number of channels for INR linear layer.')
48
+
49
+ parser.add_argument('--LUT_dim', type=int, default=7,
50
+ help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')
51
+
52
+ parser.add_argument('--activation', type=str, default='leakyrelu_pe',
53
+ help='INR activation layer type: leakyrelu_pe, sine')
54
+
55
+ parser.add_argument('--pretrained', type=str,
56
+ default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
57
+ help='Pretrained weight path')
58
+
59
+ parser.add_argument('--param_factorize_dim', type=int,
60
+ default=10,
61
+ help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
62
+ 'Refer to https://arxiv.org/abs/2011.12026')
63
+
64
+ parser.add_argument('--embedding_type', type=str,
65
+ default="CIPS_embed",
66
+ help='Which embedding_type to use.')
67
+
68
+ parser.add_argument('--optim', type=str,
69
+ default='adamw',
70
+ help='Which optimizer to use.')
71
+
72
+ parser.add_argument('--INRDecode', action="store_false",
73
+ help='Whether to use the INR decoder. Set it to False to test the baseline '
74
+ '(https://github.com/SamsungLabs/image_harmonization)')
75
+
76
+ parser.add_argument('--isMoreINRInput', action="store_false",
77
+ help='Whether to concatenate RGB and mask as additional input. See Section 3.4 in the paper.')
78
+
79
+ parser.add_argument('--hr_train', action="store_true",
80
+ help='Whether to use hr_train. See Section 3.4 in the paper.')
81
+
82
+ parser.add_argument('--isFullRes', action="store_true",
83
+ help='Whether to run at the original (full) resolution. See Section 3.4 in the paper.')
84
+
85
+ opt = parser.parse_args()
86
+
87
+ opt.save_path = misc.increment_path(os.path.join(opt.save_path, "test1"))
88
+
89
+ return opt
90
+
91
+
92
+ def inference(val_loader, model, logger, opt):
93
+ current_process = 10
94
+ model.eval()
95
+
96
+ metric_log = {
97
+ 'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
98
+ 'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
99
+ 'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
100
+ 'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
101
+ 'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
102
+ }
103
+
104
+ lut_metric_log = {
105
+ 'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
106
+ 'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
107
+ 'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
108
+ 'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
109
+ 'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0},
110
+ }
111
+
112
+ for step, batch in enumerate(val_loader):
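+ # Accumulate MSE / fMSE / PSNR / SSIM per sub-dataset (HAdobe5k, HCOCO, Hday2night, HFlickr) and over all samples; averages are logged after the loop.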
113
+ composite_image = batch['composite_image'].to(opt.device)
114
+ real_image = batch['real_image'].to(opt.device)
115
+ mask = batch['mask'].to(opt.device)
116
+ category = batch['category']
117
+
118
+ fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
119
+
120
+ with torch.no_grad():
121
+ fg_content_bg_appearance_construct, _, lut_transform_image = model(
122
+ composite_image,
123
+ mask,
124
+ fg_INR_coordinates,
125
+ )
126
+
127
+ if opt.INRDecode:
128
+ pred_fg_image = fg_content_bg_appearance_construct[-1]
129
+ else:
130
+ pred_fg_image = misc.lin2img(fg_content_bg_appearance_construct,
131
+ val_loader.dataset.INR_dataset.size) if fg_content_bg_appearance_construct is not None else None
132
+
133
+ if not opt.INRDecode:
134
+ pred_harmonized_image = None
135
+ else:
136
+ pred_harmonized_image = pred_fg_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.))
137
+ lut_transform_image = lut_transform_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.))
138
+
139
+ misc.visualize(real_image, composite_image, mask, pred_fg_image,
140
+ pred_harmonized_image, lut_transform_image, opt, -1, show=False,
141
+ wandb=False, isAll=True, step=step)
142
+
143
+ if opt.INRDecode:
144
+ mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
145
+ misc.normalize(real_image, opt, 'inv'), mask)
146
+
147
+ lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'),
148
+ misc.normalize(real_image, opt, 'inv'), mask)
149
+
150
+ for idx in range(len(category)):
151
+ if opt.INRDecode:
152
+ metric_log[category[idx]]['Samples'] += 1
153
+ metric_log[category[idx]]['MSE'] += mse[idx]
154
+ metric_log[category[idx]]['fMSE'] += fmse[idx]
155
+ metric_log[category[idx]]['PSNR'] += psnr[idx]
156
+ metric_log[category[idx]]['SSIM'] += ssim[idx]
157
+
158
+ metric_log['All']['Samples'] += 1
159
+ metric_log['All']['MSE'] += mse[idx]
160
+ metric_log['All']['fMSE'] += fmse[idx]
161
+ metric_log['All']['PSNR'] += psnr[idx]
162
+ metric_log['All']['SSIM'] += ssim[idx]
163
+
164
+ lut_metric_log[category[idx]]['Samples'] += 1
165
+ lut_metric_log[category[idx]]['MSE'] += lut_mse[idx]
166
+ lut_metric_log[category[idx]]['fMSE'] += lut_fmse[idx]
167
+ lut_metric_log[category[idx]]['PSNR'] += lut_psnr[idx]
168
+ lut_metric_log[category[idx]]['SSIM'] += lut_ssim[idx]
169
+
170
+ lut_metric_log['All']['Samples'] += 1
171
+ lut_metric_log['All']['MSE'] += lut_mse[idx]
172
+ lut_metric_log['All']['fMSE'] += lut_fmse[idx]
173
+ lut_metric_log['All']['PSNR'] += lut_psnr[idx]
174
+ lut_metric_log['All']['SSIM'] += lut_ssim[idx]
175
+
176
+ if (step + 1) / len(val_loader) * 100 >= current_process:
177
+ logger.info(f'Processing: {current_process}')
178
+ current_process += 10
179
+
180
+ logger.info('=========================')
181
+ for key in metric_log.keys():
182
+ if opt.INRDecode:
183
+ msg = f"{key}-'MSE': {metric_log[key]['MSE'] / metric_log[key]['Samples']:.2f}\n" \
184
+ f"{key}-'fMSE': {metric_log[key]['fMSE'] / metric_log[key]['Samples']:.2f}\n" \
185
+ f"{key}-'PSNR': {metric_log[key]['PSNR'] / metric_log[key]['Samples']:.2f}\n" \
186
+ f"{key}-'SSIM': {metric_log[key]['SSIM'] / metric_log[key]['Samples']:.4f}\n" \
187
+ f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
188
+ f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
189
+ f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
190
+ f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
191
+ else:
192
+ msg = f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
193
+ f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \
194
+ f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \
195
+ f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n"
196
+
197
+ logger.info(msg)
198
+
199
+ logger.info('=========================')
200
+
201
+
202
+ def main_process(opt):
203
+ logger = misc.create_logger(os.path.join(opt.save_path, "log.txt"))
204
+ cudnn.benchmark = True
205
+
206
+ valset_path = os.path.join(opt.dataset_path, "IHD_test.txt")
207
+
208
+ opt.transform_mean = [.5, .5, .5]
209
+ opt.transform_var = [.5, .5, .5]
210
+ torch_transform = transforms.Compose([transforms.ToTensor(),
211
+ transforms.Normalize(opt.transform_mean, opt.transform_var)])
212
+
213
+ valset_alb_transform = albumentations.Compose([Resize(opt.input_size, opt.input_size)],
214
+ additional_targets={'real_image': 'image', 'object_mask': 'image'})
215
+
216
+ valset = dataset_generator(valset_path, valset_alb_transform, torch_transform, opt, mode='Val')
217
+
218
+ val_loader = DataLoader(valset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
219
+ num_workers=opt.workers, persistent_workers=True)
220
+
221
+ model = build_model(opt).to(opt.device)
222
+ logger.info(f"Load pretrained weight from {opt.pretrained}")
223
+
224
+ load_dict = torch.load(opt.pretrained)['model']
225
+ for k in load_dict.keys():
226
+ if k not in model.state_dict().keys():
227
+ print(f"Skip {k}")
228
+ model.load_state_dict(load_dict, strict=False)
229
+
230
+ inference(val_loader, model, logger, opt)
231
+
232
+
233
+ if __name__ == '__main__':
234
+ opt = parse_args()
235
+ os.makedirs(opt.save_path, exist_ok=True)
236
+ main_process(opt)
inference_for_arbitrary_resolution_image.py ADDED
@@ -0,0 +1,345 @@
1
+ import argparse
2
+
3
+ import torch.backends.cudnn as cudnn
4
+ import torchvision.transforms as transforms
5
+ from torch.utils.data import DataLoader
6
+
7
+ from model.build_model import build_model
8
+
9
+ import torch
10
+ import cv2
11
+ import numpy as np
12
+ import torchvision
13
+ import os
14
+ import tqdm
15
+ import time
16
+
17
+ from utils.misc import prepare_cooridinate_input, customRandomCrop
18
+
19
+ from datasets.build_INR_dataset import Implicit2DGenerator
20
+ import albumentations
21
+ from albumentations import Resize
22
+ from torch.utils.data import DataLoader
23
+ from utils.misc import normalize
24
+
25
+ import math
26
+
27
+ global_state = [1] # For Gradio Stop Button.
28
+
29
+ class single_image_dataset(torch.utils.data.Dataset):
30
+ def __init__(self, opt, composite_image=None, mask=None):
31
+ super().__init__()
32
+
33
+ self.opt = opt
34
+
35
+ if composite_image is None:
36
+ composite_image = cv2.imread(opt.composite_image)
37
+ composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
38
+ self.composite_image = composite_image
39
+
40
+ if mask is None:
41
+ mask = cv2.imread(opt.mask)
42
+ mask = mask[:, :, 0].astype(np.float32) / 255.
43
+ self.mask = mask
44
+
45
+ self.torch_transforms = transforms.Compose([transforms.ToTensor(),
46
+ transforms.Normalize([.5, .5, .5], [.5, .5, .5])])
47
+ self.INR_dataset = Implicit2DGenerator(opt, 'Val')
48
+
49
+ self.split_width_resolution = composite_image.shape[1] // opt.split_num
50
+ self.split_height_resolution = composite_image.shape[0] // opt.split_num
51
+
52
+ self.split_width_resolution = self.split_height_resolution = min(self.split_width_resolution,
53
+ self.split_height_resolution)
54
+
55
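+ # Round each patch dimension up to a multiple of 4 so the 1/2- and 1/4-scale crops below keep integer sizes.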
+ if self.split_width_resolution % 4 != 0:
56
+ self.split_width_resolution = self.split_width_resolution + (4 - self.split_width_resolution % 4)
57
+
58
+ if self.split_height_resolution % 4 != 0:
59
+ self.split_height_resolution = self.split_height_resolution + (4 - self.split_height_resolution % 4)
60
+
61
+ self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution)
62
+ self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution)
63
+
64
+ self.split_start_point = []
65
+
66
+ "Split the image into several parts."
67
+ for i in range(self.num_h):
68
+ for j in range(self.num_w):
69
+ if i == composite_image.shape[0] // self.split_height_resolution:
70
+ if j == composite_image.shape[1] // self.split_width_resolution:
71
+ self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution,
72
+ composite_image.shape[1] - self.split_width_resolution))
73
+ else:
74
+ self.split_start_point.append(
75
+ (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution))
76
+ else:
77
+ if j == composite_image.shape[1] // self.split_width_resolution:
78
+ self.split_start_point.append(
79
+ (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution))
80
+ else:
81
+ self.split_start_point.append(
82
+ (i * self.split_height_resolution, j * self.split_width_resolution))
83
+
84
+ assert len(self.split_start_point) == self.num_w * self.num_h
85
+
86
+ print(
87
+ f"The image will be split into {self.num_h} pieces in height, and {self.num_w} pieces in width. Totally {self.num_h * self.num_w} patches.")
88
+ print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}")
89
+
90
+ def __len__(self):
91
+ return self.num_w * self.num_h
92
+
93
+ def __getitem__(self, idx):
94
+ composite_image = self.composite_image
95
+
96
+ mask = self.mask
97
+
98
+ full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
99
+
100
+ tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
101
+ additional_targets={'object_mask': 'image'})
102
+ transform_out = tmp_transform(image=composite_image, object_mask=mask)
103
+ compos_list = [self.torch_transforms(transform_out['image'])]
104
+ mask_list = [
105
+ torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
106
+ coord_map_list = []
107
+
108
+ if composite_image.shape[0] != self.split_height_resolution:
109
+ c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution)
110
+ else:
111
+ c_h = 0
112
+ if composite_image.shape[1] != self.split_width_resolution:
113
+ c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution)
114
+ else:
115
+ c_w = 0
116
+ transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord],
117
+ self.split_height_resolution, self.split_width_resolution, c_h, c_w)
118
+
119
+ compos_list.append(self.torch_transforms(transform_out[0]))
120
+ mask_list.append(
121
+ torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
122
+ coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
123
+ coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
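+ # Also prepare 1/2- and 1/4-scale versions of the same crop as additional multi-resolution inputs.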
124
+ for n in range(2):
125
+ tmp_comp = cv2.resize(composite_image, (
126
+ composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
127
+ tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
128
+ tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
129
+
130
+ transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_mask, tmp_coord],
131
+ self.split_height_resolution // 2 ** (n + 1),
132
+ self.split_width_resolution // 2 ** (n + 1), c_h, c_w)
133
+ compos_list.append(self.torch_transforms(transform_out[0]))
134
+ mask_list.append(
135
+ torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32)))
136
+ coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2]))
137
+ out_comp = compos_list
138
+ out_mask = mask_list
139
+ out_coord = coord_map_list
140
+
141
+ fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
142
+ self.torch_transforms, transform_out[0], transform_out[0], mask)
143
+
144
+ return {
145
+ 'composite_image': out_comp,
146
+ 'mask': out_mask,
147
+ 'coordinate_map': out_coord,
148
+ 'composite_image0': out_comp[0],
149
+ 'mask0': out_mask[0],
150
+ 'coordinate_map0': out_coord[0],
151
+ 'composite_image1': out_comp[1],
152
+ 'mask1': out_mask[1],
153
+ 'coordinate_map1': out_coord[1],
154
+ 'composite_image2': out_comp[2],
155
+ 'mask2': out_mask[2],
156
+ 'coordinate_map2': out_coord[2],
157
+ 'composite_image3': out_comp[3],
158
+ 'mask3': out_mask[3],
159
+ 'coordinate_map3': out_coord[3],
160
+ 'fg_INR_coordinates': fg_INR_coordinates,
161
+ 'bg_INR_coordinates': bg_INR_coordinates,
162
+ 'fg_INR_RGB': fg_INR_RGB,
163
+ 'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
164
+ 'bg_INR_RGB': bg_INR_RGB,
165
+ 'start_point': self.split_start_point[idx],
166
+ }
167
+
168
+
169
+ def parse_args():
170
+ parser = argparse.ArgumentParser()
171
+
172
+ parser.add_argument('--split_num', type=int, default=4,
173
+ help='How many pieces to split the image into along each of width / height.')
174
+
175
+ parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg',
176
+ help='composite image path')
177
+
178
+ parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg',
179
+ help='mask path')
180
+
181
+ parser.add_argument('--save_path', type=str, default=r'./demo/',
182
+ help='save path')
183
+
184
+ parser.add_argument('--workers', type=int, default=8,
185
+ metavar='N', help='Dataloader threads.')
186
+
187
+ parser.add_argument('--batch_size', type=int, default=1,
188
+ help='You can override the model batch size by specifying a positive number.')
189
+
190
+ parser.add_argument('--device', type=str, default='cuda',
191
+ help="Whether use cuda, 'cuda' or 'cpu'.")
192
+
193
+ parser.add_argument('--base_size', type=int, default=256,
194
+ help='Base size. Resolution of the image input into the Encoder')
195
+
196
+ parser.add_argument('--input_size', type=int, default=256,
197
+ help='Input size. Resolution of the image to be generated by the Decoder.')
198
+
199
+ parser.add_argument('--INR_input_size', type=int, default=256,
200
+ help='INR input size. Resolution of the image to be generated by the Decoder. '
201
+ 'Should be the same as `input_size`')
202
+
203
+ parser.add_argument('--INR_MLP_dim', type=int, default=32,
204
+ help='Number of channels for INR linear layer.')
205
+
206
+ parser.add_argument('--LUT_dim', type=int, default=7,
207
+ help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')
208
+
209
+ parser.add_argument('--activation', type=str, default='leakyrelu_pe',
210
+ help='INR activation layer type: leakyrelu_pe, sine')
211
+
212
+ parser.add_argument('--pretrained', type=str,
213
+ default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth',
214
+ help='Pretrained weight path')
215
+
216
+ parser.add_argument('--param_factorize_dim', type=int,
217
+ default=10,
218
+ help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
219
+ 'Refer to https://arxiv.org/abs/2011.12026')
220
+
221
+ parser.add_argument('--embedding_type', type=str,
222
+ default="CIPS_embed",
223
+ help='Which embedding_type to use.')
224
+
225
+ parser.add_argument('--INRDecode', action="store_false",
226
+ help='Whether to use the INR decoder. Set it to False to test the baseline '
227
+ '(https://github.com/SamsungLabs/image_harmonization)')
228
+
229
+ parser.add_argument('--isMoreINRInput', action="store_false",
230
+ help='Whether to concatenate RGB and mask as additional input. See Section 3.4 in the paper.')
231
+
232
+ parser.add_argument('--hr_train', action="store_false",
233
+ help='Whether to use hr_train. See Section 3.4 in the paper.')
234
+
235
+ parser.add_argument('--isFullRes', action="store_true",
236
+ help='Whether to run at the original (full) resolution. See Section 3.4 in the paper.')
237
+
238
+ opt = parser.parse_args()
239
+
240
+ return opt
241
+
242
+ @torch.no_grad()
243
+ def inference(model, opt, composite_image=None, mask=None):
244
+ model.eval()
245
+
246
+ "dataset here is actually consisted of several patches of a single image."
247
+ singledataset = single_image_dataset(opt, composite_image, mask)
248
+
249
+ single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True,
250
+ num_workers=opt.workers, persistent_workers=False if composite_image is not None else True)
251
+
252
+ "Init a pure black image with the same size as the input image."
253
+ init_img = np.zeros_like(singledataset.composite_image)
254
+
255
+ time_all = 0
256
+
257
+ for step, batch in tqdm.tqdm(enumerate(single_data_loader)):
258
+ composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)]
259
+ mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)]
260
+ coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)]
261
+ start_points = batch['start_point']
262
+
263
+ if opt.batch_size == 1:
264
+ start_points = [torch.cat(start_points)]
265
+
266
+ fg_INR_coordinates = coordinate_map[1:]
267
+
268
+ try:
269
+ if global_state[0] == 0:
270
+ print("Stop Harmonizing...!")
271
+ break
272
+
273
+ if step == 0:  # CUDA kernel warm-up: otherwise the first inference step would be quite slow.
274
+ fg_content_bg_appearance_construct, _, lut_transform_image = model(
275
+ composite_image,
276
+ mask,
277
+ fg_INR_coordinates,
278
+ )
279
+ print("Ready for harmonization...")
280
+
281
+ if opt.device == "cuda":
282
+ torch.cuda.reset_max_memory_allocated()
283
+ torch.cuda.reset_max_memory_cached()
284
+ start_time = time.time()
285
+ torch.cuda.synchronize()
286
+ fg_content_bg_appearance_construct, _, lut_transform_image = model(
287
+ composite_image,
288
+ mask,
289
+ fg_INR_coordinates,
290
+ )
291
+ if opt.device == "cuda":
292
+ torch.cuda.synchronize()
293
+ end_time = time.time()
294
+
295
+ end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2
296
+ end_memory = torch.cuda.memory_allocated() // 1024 ** 2
297
+
298
+ print(f'GPU max memory usage: {end_max_memory} MB')
299
+ print(f'GPU memory usage: {end_memory} MB')
300
+ time_all += (end_time - start_time)
301
+ print(f'progress: {step} / {len(single_data_loader)}')
302
+ except:
303
+ raise Exception(
304
+ f'The image resolution is too large. Please increase the `split_num` value. Your current setting is {opt.split_num}')
305
+
306
+ "Assemble the every patch's harmonized result into the final whole image."
307
+ for id in range(len(fg_INR_coordinates[0])):
308
+ pred_fg_image = fg_content_bg_appearance_construct[-1][id]
309
+ pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) + composite_image[1][id] * (
310
+ ~(mask[1][id] > 100 / 255.))
311
+
312
+ pred_harmonized_tmp = cv2.cvtColor(
313
+ normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_(
314
+ 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)
315
+
316
+ init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution,
317
+ start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp
318
+
319
+ if opt.device == "cuda":
320
+ print(f'Inference time: {time_all}')
321
+ if opt.save_path is not None:
322
+ os.makedirs(opt.save_path, exist_ok=True)
323
+ cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img)
324
+ return init_img
325
+
326
+
327
+ def main_process(opt, composite_image=None, mask=None):
328
+ cudnn.benchmark = True
329
+ # Replace 'gpu' with 'cuda'
330
+ opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
331
+ print("Preparing model...")
332
+ model = build_model(opt).to(opt.device)
333
+
334
+ load_dict = torch.load(opt.pretrained, map_location='cpu')['model']
335
+
336
+ model.load_state_dict(load_dict, strict=False)
337
+
338
+ return inference(model, opt, composite_image, mask)
339
+
340
+
341
+ if __name__ == '__main__':
342
+ opt = parse_args()
343
+ opt.transform_mean = [.5, .5, .5]
344
+ opt.transform_var = [.5, .5, .5]
345
+ main_process(opt)
model/__init__.py ADDED
File without changes
model/backbone.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch.nn as nn
2
+
3
+ from .hrnetv2.hrnet_ocr import HighResolutionNet
4
+ from .hrnetv2.modifiers import LRMult
5
+ from .base.basic_blocks import MaxPoolDownSize
6
+ from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization
7
+
8
+
9
+ def build_backbone(name, opt):
10
+ return eval(name)(opt)
11
+
12
+
13
+ class baseline(IHModelWithBackbone):
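+ # Wraps DeepImageHarmonization with an HRNetV2(+OCR) backbone; backbone features are concatenated into the base model from its second encoder block (backbone_from=2, backbone_mode='cat').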
14
+ def __init__(self, opt, ocr=64):
15
+ base_config = {'model': DeepImageHarmonization,
16
+ 'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}}
17
+
18
+ params = base_config['params']
19
+
20
+ backbone = HRNetV2(opt, ocr=ocr)
21
+
22
+ params.update(dict(
23
+ backbone_from=2,
24
+ backbone_channels=backbone.output_channels,
25
+ backbone_mode='cat',
26
+ opt=opt
27
+ ))
28
+ base_model = base_config['model'](**params)
29
+
30
+ super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt)
31
+
32
+
33
+ class HRNetV2(nn.Module):
34
+ def __init__(
35
+ self, opt,
36
+ cat_outputs=True,
37
+ pyramid_channels=-1, pyramid_depth=4,
38
+ width=18, ocr=128, small=False,
39
+ lr_mult=0.1, pretained=True
40
+ ):
41
+ super(HRNetV2, self).__init__()
42
+ self.opt = opt
43
+ self.cat_outputs = cat_outputs
44
+ self.ocr_on = ocr > 0 and cat_outputs
45
+ self.pyramid_on = pyramid_channels > 0 and cat_outputs
46
+
47
+ self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
48
+ self.hrnet.apply(LRMult(lr_mult))
49
+ if self.ocr_on:
50
+ self.hrnet.ocr_distri_head.apply(LRMult(1.0))
51
+ self.hrnet.ocr_gather_head.apply(LRMult(1.0))
52
+ self.hrnet.conv3x3_ocr.apply(LRMult(1.0))
53
+
54
+ hrnet_cat_channels = [width * 2 ** i for i in range(4)]
55
+ if self.pyramid_on:
56
+ self.output_channels = [pyramid_channels] * 4
57
+ elif self.ocr_on:
58
+ self.output_channels = [ocr * 2]
59
+ elif self.cat_outputs:
60
+ self.output_channels = [sum(hrnet_cat_channels)]
61
+ else:
62
+ self.output_channels = hrnet_cat_channels
63
+
64
+ if self.pyramid_on:
65
+ downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
66
+ self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
67
+
68
+ if pretained:
69
+ self.load_pretrained_weights(
70
+ "./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth")
71
+
72
+ self.output_resolution = (opt.input_size // 8) ** 2
73
+
74
+ def forward(self, image, mask, mask_features=None):
75
+ outputs = list(self.hrnet(image, mask, mask_features))
76
+ return outputs
77
+
78
+ def load_pretrained_weights(self, pretrained_path):
79
+ self.hrnet.load_pretrained_weights(pretrained_path)
model/base/__init__.py ADDED
File without changes
model/base/basic_blocks.py ADDED
@@ -0,0 +1,366 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ import numpy as np
4
+
5
+
6
+ def hyper_weight_init(m, in_features_main_net, activation):
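+ # Initialization for the hypernetwork's output layers: scale the weights down by 1e2 and choose the bias range to match the INR activation (SIREN-style bounds when activation == 'sine').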
7
+ if hasattr(m, 'weight'):
8
+ nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
9
+ m.weight.data = m.weight.data / 1.e2
10
+
11
+ if hasattr(m, 'bias'):
12
+ with torch.no_grad():
13
+ if activation == 'sine':
14
+ m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30)
15
+ elif activation == 'leakyrelu_pe':
16
+ m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net))
17
+ else:
18
+ raise NotImplementedError
19
+
20
+
21
+ class ConvBlock(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_channels, out_channels,
25
+ kernel_size=4, stride=2, padding=1,
26
+ norm_layer=nn.BatchNorm2d, activation=nn.ELU,
27
+ bias=True,
28
+ ):
29
+ super(ConvBlock, self).__init__()
30
+ self.block = nn.Sequential(
31
+ nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
32
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
33
+ activation(),
34
+ )
35
+
36
+ def forward(self, x):
37
+ return self.block(x)
38
+
39
+
40
+ class MaxPoolDownSize(nn.Module):
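+ # Small feature pyramid: a 1x1 reduction conv followed by 'depth' 3x3 convs, where each level after the first is max-pooled by 2 relative to the previous one.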
41
+ def __init__(self, in_channels, mid_channels, out_channels, depth):
42
+ super(MaxPoolDownSize, self).__init__()
43
+ self.depth = depth
44
+ self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
45
+ self.convs = nn.ModuleList([
46
+ ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
47
+ for conv_i in range(depth)
48
+ ])
49
+ self.pool2d = nn.MaxPool2d(kernel_size=2)
50
+
51
+ def forward(self, x):
52
+ outputs = []
53
+
54
+ output = self.reduce_conv(x)
55
+
56
+ for conv_i, conv in enumerate(self.convs):
57
+ output = output if conv_i == 0 else self.pool2d(output)
58
+ outputs.append(conv(output))
59
+
60
+ return outputs
61
+
62
+
63
+ class convParams(nn.Module):
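+ # Hypernetwork head: a small conv net predicts, per spatial location, the weights and biases of the INR MLP. With param_factorize_dim > 0 each weight matrix is predicted as a low-rank product of two factors (https://arxiv.org/abs/2011.12026), passed through tanh and scaled by a learned basic parameter.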
64
+ def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
65
+ super(convParams, self).__init__()
66
+ self.INR_in_out = INR_in_out
67
+ self.cont_split_weight = []
68
+ self.cont_split_bias = []
69
+ self.hidden_mlp_num = hidden_mlp_num
70
+ self.param_factorize_dim = opt.param_factorize_dim
71
+ output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
72
+ self.output_dim = output_dim
73
+ self.toRGB = toRGB
74
+ self.cont_extraction_net = nn.Sequential(
75
+ nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
76
+ # nn.BatchNorm2d(hidden_dim),
77
+ nn.ReLU(inplace=True),
78
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
79
+ # nn.BatchNorm2d(hidden_dim),
80
+ nn.ReLU(inplace=True),
81
+ nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
82
+ )
83
+
84
+ self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
85
+
86
+ self.basic_params = nn.ParameterList()
87
+ if opt.param_factorize_dim > 0:
88
+ for id in range(self.hidden_mlp_num + 1):
89
+ if id == 0:
90
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
91
+ else:
92
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
93
+ self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))
94
+
95
+ if toRGB:
96
+ self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))
97
+
98
+ def forward(self, feat, outMore=False):
99
+ cont_params = self.cont_extraction_net(feat)
100
+ out_mlp = self.to_mlp(cont_params)
101
+ if outMore:
102
+ return out_mlp, cont_params
103
+ return out_mlp
104
+
105
+ def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
106
+ cont_params = 0
107
+ start = 0
108
+ if self.param_factorize_dim == -1:
109
+ cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
110
+ self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
111
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
112
+ start = cont_params
113
+
114
+ for id in range(hidden_mlp_num):
115
+ cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
116
+ self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
117
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
118
+ start = cont_params
119
+
120
+ if toRGB:
121
+ cont_params += INR_in_out[1] * 3 + 3
122
+ self.cont_split_weight.append([start, cont_params - 3])
123
+ self.cont_split_bias.append([cont_params - 3, cont_params])
124
+
125
+ elif self.param_factorize_dim > 0:
126
+ cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
127
+ INR_in_out[1]
128
+ self.cont_split_weight.append(
129
+ [start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
130
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
131
+ start = cont_params
132
+
133
+ for id in range(hidden_mlp_num):
134
+ cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
135
+ INR_in_out[1]
136
+ self.cont_split_weight.append(
137
+ [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
138
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
139
+ start = cont_params
140
+
141
+ if toRGB:
142
+ cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
143
+ self.cont_split_weight.append(
144
+ [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
145
+ self.cont_split_bias.append([cont_params - 3, cont_params])
146
+
147
+ return cont_params
148
+
149
+ def to_mlp(self, params):
150
+ all_weight_bias = []
151
+ if self.param_factorize_dim == -1:
152
+ for id in range(self.hidden_mlp_num + 1):
153
+ if id == 0:
154
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
155
+ else:
156
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
157
+ weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
158
+ weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
159
+ inp, outp)
160
+
161
+ bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
162
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
163
+ all_weight_bias.append([weight, bias])
164
+
165
+ if self.toRGB:
166
+ inp, outp = self.INR_in_out[1], 3
167
+ weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
168
+ weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
169
+ inp, outp)
170
+
171
+ bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
172
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
173
+ all_weight_bias.append([weight, bias])
174
+
175
+ return all_weight_bias
176
+
177
+ else:
178
+ for id in range(self.hidden_mlp_num + 1):
179
+ if id == 0:
180
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
181
+ else:
182
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
183
+ weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
184
+ weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
185
+ inp, self.param_factorize_dim)
186
+
187
+ weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
188
+ weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
189
+ self.param_factorize_dim, outp)
190
+
191
+ bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
192
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
193
+
194
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
195
+
196
+ if self.toRGB:
197
+ inp, outp = self.INR_in_out[1], 3
198
+ weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
199
+ weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
200
+ inp, self.param_factorize_dim)
201
+
202
+ weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
203
+ weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
204
+ self.param_factorize_dim, outp)
205
+
206
+ bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
207
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
208
+
209
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])
210
+
211
+ return all_weight_bias
212
+
213
+
214
+ class lineParams(nn.Module):
215
+ def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False,
216
+ hidden_dim=512):
217
+ super(lineParams, self).__init__()
218
+ self.INR_in_out = INR_in_out
219
+ self.app_split_weight = []
220
+ self.app_split_bias = []
221
+ self.toRGB = toRGB
222
+ self.hidden_mlp_num = hidden_mlp_num
223
+ self.param_factorize_dim = opt.param_factorize_dim
224
+ output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
225
+ self.output_dim = output_dim
226
+
227
+ self.compress_layer = nn.Sequential(
228
+ nn.Linear(input_resolution, 64, bias=False),
229
+ nn.BatchNorm1d(input_dim),
230
+ nn.ReLU(inplace=True),
231
+ nn.Linear(64, 1, bias=True)
232
+ )
233
+
234
+ self.app_extraction_net = nn.Sequential(
235
+ nn.Linear(input_dim, hidden_dim, bias=False),
236
+ # nn.BatchNorm1d(hidden_dim),
237
+ nn.ReLU(inplace=True),
238
+ nn.Linear(hidden_dim, hidden_dim, bias=False),
239
+ # nn.BatchNorm1d(hidden_dim),
240
+ nn.ReLU(inplace=True),
241
+ nn.Linear(hidden_dim, output_dim, bias=True)
242
+ )
243
+
244
+ self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
245
+
246
+ self.basic_params = nn.ParameterList()
247
+ if opt.param_factorize_dim > 0:
248
+ for id in range(self.hidden_mlp_num + 1):
249
+ if id == 0:
250
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
251
+ else:
252
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
253
+ self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
254
+ if toRGB:
255
+ self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))
256
+
257
+ def forward(self, feat):
258
+ app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
259
+ out_mlp = self.to_mlp(app_params)
260
+ return out_mlp, app_params
261
+
262
+ def cal_params_num(self, INR_in_out, hidden_mlp_num):
263
+ app_params = 0
264
+ start = 0
265
+ if self.param_factorize_dim == -1:
266
+ app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
267
+ self.app_split_weight.append([start, app_params - INR_in_out[1]])
268
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
269
+ start = app_params
270
+
271
+ for id in range(hidden_mlp_num):
272
+ app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
273
+ self.app_split_weight.append([start, app_params - INR_in_out[1]])
274
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
275
+ start = app_params
276
+
277
+ if self.toRGB:
278
+ app_params += INR_in_out[1] * 3 + 3
279
+ self.app_split_weight.append([start, app_params - 3])
280
+ self.app_split_bias.append([app_params - 3, app_params])
281
+
282
+ elif self.param_factorize_dim > 0:
283
+ app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
284
+ INR_in_out[1]
285
+ self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim,
286
+ app_params - INR_in_out[1]])
287
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
288
+ start = app_params
289
+
290
+ for id in range(hidden_mlp_num):
291
+ app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
292
+ INR_in_out[1]
293
+ self.app_split_weight.append(
294
+ [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
295
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
296
+ start = app_params
297
+
298
+ if self.toRGB:
299
+ app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
300
+ self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim,
301
+ app_params - 3])
302
+ self.app_split_bias.append([app_params - 3, app_params])
303
+
304
+ return app_params
305
+
306
+ def to_mlp(self, params):
307
+ all_weight_bias = []
308
+ if self.param_factorize_dim == -1:
309
+ for id in range(self.hidden_mlp_num + 1):
310
+ if id == 0:
311
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
312
+ else:
313
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
314
+ weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
315
+ weight = weight.view(weight.shape[0], inp, outp)
316
+
317
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
318
+ bias = bias.view(bias.shape[0], 1, outp)
319
+
320
+ all_weight_bias.append([weight, bias])
321
+
322
+ if self.toRGB:
323
+ id = -1
324
+ inp, outp = self.INR_in_out[1], 3
325
+ weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
326
+ weight = weight.view(weight.shape[0], inp, outp)
327
+
328
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
329
+ bias = bias.view(bias.shape[0], 1, outp)
330
+
331
+ all_weight_bias.append([weight, bias])
332
+
333
+ return all_weight_bias
334
+
335
+ else:
336
+ for id in range(self.hidden_mlp_num + 1):
337
+ if id == 0:
338
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
339
+ else:
340
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
341
+ weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
342
+ weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
343
+
344
+ weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
345
+ weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
346
+
347
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
348
+ bias = bias.view(bias.shape[0], 1, outp)
349
+
350
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
351
+
352
+ if self.toRGB:
353
+ id = -1
354
+ inp, outp = self.INR_in_out[1], 3
355
+ weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
356
+ weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
357
+
358
+ weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
359
+ weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
360
+
361
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
362
+ bias = bias.view(bias.shape[0], 1, outp)
363
+
364
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
365
+
366
+ return all_weight_bias
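The factorized branch of to_mlp above packs every predicted layer as two low-rank factors plus a bias inside one flat parameter vector, then rescales the recomposed weight by a learned base weight. A minimal standalone sketch of that bookkeeping, with made-up sizes rather than the repo's actual configuration:

import torch

bs, inp, outp, rank = 2, 32, 32, 8                          # hypothetical sizes
params = torch.randn(bs, inp * rank + rank * outp + outp)   # flat prediction for one layer
basic_param = torch.randn(1, inp, outp)                     # learned base weight (cf. basic_params)

w1 = params[:, :inp * rank].view(bs, inp, rank)
w2 = params[:, inp * rank:inp * rank + rank * outp].view(bs, rank, outp)
bias = params[:, -outp:].view(bs, 1, outp)

weight = torch.tanh(torch.matmul(w1, w2)) * basic_param     # recomposed low-rank weight, bounded by tanh
x = torch.randn(bs, 64, inp)                                # 64 feature vectors per sample
y = torch.matmul(x, weight) + bias                          # shape (bs, 64, outp)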
model/base/conv_autoencoder.py ADDED
@@ -0,0 +1,519 @@
1
+ import torch
2
+ import torchvision
3
+ from torch import nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import math
7
+
8
+ from .basic_blocks import ConvBlock, lineParams, convParams
9
+ from .ops import MaskedChannelAttention, FeaturesConnector
10
+ from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed
11
+ from utils import misc
12
+ from utils.misc import lin2img
13
+ from ..lut_transformation_net import build_lut_transform
14
+
15
+
16
+ class Sine(nn.Module):
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ def forward(self, input):
21
+ return torch.sin(30 * input)
22
+
23
+
24
+ class Leaky_relu(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def forward(self, input):
29
+ return torch.nn.functional.leaky_relu(input, 0.01, inplace=True)
30
+
31
+
32
+ def select_activation(type):
33
+ if type == 'sine':
34
+ return Sine()
35
+ elif type == 'leakyrelu_pe':
36
+ return Leaky_relu()
37
+ else:
38
+ raise NotImplementedError
39
+
40
+
41
+ class ConvEncoder(nn.Module):
42
+ def __init__(
43
+ self,
44
+ depth, ch,
45
+ norm_layer, batchnorm_from, max_channels,
46
+ backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False
47
+ ):
48
+ super(ConvEncoder, self).__init__()
49
+ self.depth = depth
50
+ self.INRDecode = INRDecode
51
+ self.backbone_from = backbone_from
52
+ backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
53
+
54
+ in_channels = 4
55
+ out_channels = ch
56
+
57
+ self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
58
+ self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
59
+ self.blocks_channels = [out_channels, out_channels]
60
+
61
+ self.blocks_connected = nn.ModuleDict()
62
+ self.connectors = nn.ModuleDict()
63
+ for block_i in range(2, depth):
64
+ if block_i % 2:
65
+ in_channels = out_channels
66
+ else:
67
+ in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
68
+
69
+ if 0 <= backbone_from <= block_i and len(backbone_channels):
70
+ if INRDecode:
71
+ self.blocks_connected[f'block{block_i}_decode'] = ConvBlock(
72
+ in_channels, out_channels,
73
+ norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
74
+ padding=int(block_i < depth - 1)
75
+ )
76
+ self.blocks_channels += [out_channels]
77
+ stage_channels = backbone_channels.pop()
78
+ connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
79
+ self.connectors[f'connector{block_i}'] = connector
80
+ in_channels = connector.output_channels
81
+
82
+ self.blocks_connected[f'block{block_i}'] = ConvBlock(
83
+ in_channels, out_channels,
84
+ norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
85
+ padding=int(block_i < depth - 1)
86
+ )
87
+ self.blocks_channels += [out_channels]
88
+
89
+ def forward(self, x, backbone_features):
90
+ backbone_features = [] if backbone_features is None else backbone_features[::-1]
91
+
92
+ outputs = [self.block0(x)]
93
+ outputs += [self.block1(outputs[-1])]
94
+
95
+ for block_i in range(2, self.depth):
96
+ output = outputs[-1]
97
+ connector_name = f'connector{block_i}'
98
+ if connector_name in self.connectors:
99
+ if self.INRDecode:
100
+ block = self.blocks_connected[f'block{block_i}_decode']
101
+ outputs += [block(output)]
102
+
103
+ stage_features = backbone_features.pop()
104
+ connector = self.connectors[connector_name]
105
+ output = connector(output, stage_features)
106
+ block = self.blocks_connected[f'block{block_i}']
107
+ outputs += [block(output)]
108
+
109
+ return outputs[::-1]
110
+
111
+
112
+ class DeconvDecoder(nn.Module):
113
+ def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
114
+ super(DeconvDecoder, self).__init__()
115
+ self.image_fusion = image_fusion
116
+ self.deconv_blocks = nn.ModuleList()
117
+
118
+ in_channels = encoder_blocks_channels.pop()
119
+ out_channels = in_channels
120
+ for d in range(depth):
121
+ out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
122
+ self.deconv_blocks.append(SEDeconvBlock(
123
+ in_channels, out_channels,
124
+ norm_layer=norm_layer,
125
+ padding=0 if d == 0 else 1,
126
+ with_se=0 <= attend_from <= d
127
+ ))
128
+ in_channels = out_channels
129
+
130
+ if self.image_fusion:
131
+ self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
132
+ self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
133
+
134
+ def forward(self, encoder_outputs, image, mask=None):
135
+ output = encoder_outputs[0]
136
+ for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
137
+ output = block(output, mask)
138
+ output = output + skip_output
139
+ output = self.deconv_blocks[-1](output, mask)
140
+
141
+ if self.image_fusion:
142
+ attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
143
+ output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
144
+ else:
145
+ output = self.to_rgb(output)
146
+
147
+ return output
148
+
149
+
150
+ class SEDeconvBlock(nn.Module):
151
+ def __init__(
152
+ self,
153
+ in_channels, out_channels,
154
+ kernel_size=4, stride=2, padding=1,
155
+ norm_layer=nn.BatchNorm2d, activation=nn.ELU,
156
+ with_se=False
157
+ ):
158
+ super(SEDeconvBlock, self).__init__()
159
+ self.with_se = with_se
160
+ self.block = nn.Sequential(
161
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
162
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
163
+ activation(),
164
+ )
165
+ if self.with_se:
166
+ self.se = MaskedChannelAttention(out_channels)
167
+
168
+ def forward(self, x, mask=None):
169
+ out = self.block(x)
170
+ if self.with_se:
171
+ out = self.se(out, mask)
172
+ return out
173
+
174
+
175
+ class INRDecoder(nn.Module):
176
+ def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from):
177
+ super(INRDecoder, self).__init__()
178
+ self.INR_encoding = None
179
+ if opt.embedding_type == "PosEncodingNeRF":
180
+ self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size)
181
+ elif opt.embedding_type == "RandomFourier":
182
+ self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device)
183
+ elif opt.embedding_type == "CIPS_embed":
184
+ self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32)
185
+ elif opt.embedding_type == "INRGAN_embed":
186
+ self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size)
187
+ else:
188
+ raise NotImplementedError
189
+ encoder_blocks_channels = encoder_blocks_channels[::-1]
190
+ max_hidden_mlp_num = attend_from + 1
191
+ self.opt = opt
192
+ self.max_hidden_mlp_num = max_hidden_mlp_num
193
+ self.content_mlp_blocks = nn.ModuleDict()
194
+ for n in range(max_hidden_mlp_num):
195
+ if n != max_hidden_mlp_num - 1:
196
+ self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
197
+ [self.INR_encoding.out_dim + opt.INR_MLP_dim + (
198
+ 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
199
+ opt, n + 1)
200
+ else:
201
+ self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
202
+ [self.INR_encoding.out_dim + (
203
+ 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
204
+ opt, n + 1)
205
+
206
+ self.deconv_blocks = nn.ModuleList()
207
+
208
+ encoder_blocks_channels = encoder_blocks_channels[::-1]
209
+ in_channels = encoder_blocks_channels.pop()
210
+ out_channels = in_channels
211
+ for d in range(depth - attend_from):
212
+ out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
213
+ self.deconv_blocks.append(SEDeconvBlock(
214
+ in_channels, out_channels,
215
+ norm_layer=norm_layer,
216
+ padding=0 if d == 0 else 1,
217
+ with_se=False
218
+ ))
219
+ in_channels = out_channels
220
+
221
+ self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim],
222
+ (opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2,
223
+ opt, 2, toRGB=True)
224
+
225
+ self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim,
226
+ None, opt)
227
+
228
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
229
+
230
+ def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None):
231
+ """For full resolution, do split."""
232
+ if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt,
233
+ 'split_resolution')) and self.opt.isFullRes:
234
+ return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples)
235
+
236
+ encoder_outputs = encoder_outputs[::-1]
237
+ mlp_output = None
238
+ waitToRGB = []
239
+ for n in range(self.max_hidden_mlp_num):
240
+ if not self.opt.hr_train:
241
+ coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \
242
+ .unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
243
+ else:
244
+ if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'):
245
+ coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view(
246
+ encoder_outputs[0].shape[0], -1, 2)
247
+ else:
248
+ coord = misc.get_mgrid(
249
+ self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat(
250
+ encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
251
+
252
+ """Whether to leverage multiple input to INR decoder. See Section 3.4 in the paper."""
253
+ if self.opt.isMoreINRInput:
254
+ if not self.opt.isFullRes or (
255
+ self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
256
+ res_h = res_w = np.sqrt(coord.shape[1]).astype(int)
257
+ else:
258
+ res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1))
259
+ res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1))
260
+
261
+ res_image = torchvision.transforms.Resize([res_h, res_w])(image)
262
+ res_mask = torchvision.transforms.Resize([res_h, res_w])(mask)
263
+ coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1),
264
+ res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
265
+ else:
266
+ coord = self.INR_encoding(coord)
267
+
268
+ """============ LRIP structure, see Section 3.3 =============="""
269
+
270
+ """Local MLPs."""
271
+ if n == 0:
272
+ mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
273
+ self.opt, content_mlp=self.content_mlp_blocks[
274
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
275
+ encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion)
276
+ waitToRGB.append(mlp_output[1])
277
+ else:
278
+ mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
279
+ 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0],
280
+ content_mlp=self.content_mlp_blocks[
281
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
282
+ encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
283
+ start_proportion=start_proportion)
284
+ waitToRGB.append(mlp_output[1])
285
+
286
+ encoder_outputs = encoder_outputs[::-1]
287
+ output = encoder_outputs[0]
288
+ for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
289
+ output = block(output)
290
+ output = output + skip_output
291
+ output = self.deconv_blocks[-1](output)
292
+
293
+ """Global MLPs."""
294
+ app_mlp, app_params = self.appearance_mlps(output)
295
+ harm_out = []
296
+ for id in range(len(waitToRGB)):
297
+ output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[id],
298
+ appearance_mlp=app_mlp)
299
+ harm_out.append(output[0])
300
+
301
+ """Optional 3D LUT prediction."""
302
+ fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
303
+
304
+ return harm_out, fit_lut3d, lut_transform_image
305
+
306
+ def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None,
307
+ resolution=None, start_proportion=None):
308
+
309
+ activation = select_activation(opt.activation)
310
+
311
+ output = None
312
+
313
+ if content_mlp is not None:
314
+ if base_feat is not None:
315
+ coorinates = torch.cat([coorinates, base_feat], dim=2)
316
+ coorinates = lin2img(coorinates, resolution)
317
+
318
+ if hasattr(opt, 'split_resolution'):
319
+ """
320
+ Here we crop the needed MLPs according to the region of the split input patches.
321
+ Note that this only supports inference on square images.
322
+ """
323
+ for idx in range(len(content_mlp)):
324
+ content_mlp[idx][0] = content_mlp[idx][0][:,
325
+ (content_mlp[idx][0].shape[1] * start_proportion[0]).int():(
326
+ content_mlp[idx][0].shape[1] * start_proportion[2]).int(),
327
+ (content_mlp[idx][0].shape[2] * start_proportion[1]).int():(
328
+ content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :,
329
+ :]
330
+ content_mlp[idx][1] = content_mlp[idx][1][:,
331
+ (content_mlp[idx][1].shape[1] * start_proportion[0]).int():(
332
+ content_mlp[idx][1].shape[1] * start_proportion[2]).int(),
333
+ (content_mlp[idx][1].shape[2] * start_proportion[1]).int():(
334
+ content_mlp[idx][1].shape[2] * start_proportion[3]).int(),
335
+ :,
336
+ :]
337
+ k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
338
+ k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
339
+ bs = coorinates.shape[0]
340
+ h_lr = w_lr = content_mlp[0][0].shape[1]
341
+ nci = INR_input_dim
342
+
343
+ coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
344
+ coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
345
+ bs, h_lr, w_lr, int(k_h * k_w), nci)
346
+
347
+ for id, layer in enumerate(content_mlp):
348
+ if id == 0:
349
+ output = torch.matmul(coorinates, layer[0]) + layer[1]
350
+ output = activation(output)
351
+ else:
352
+ output = torch.matmul(output, layer[0]) + layer[1]
353
+ output = activation(output)
354
+
355
+ output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
356
+ 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
357
+
358
+ output_large = self.up(lin2img(output))
359
+
360
+ return output_large.view(bs, -1, opt.INR_MLP_dim), output
361
+
362
+ k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
363
+ k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
364
+ bs = coorinates.shape[0]
365
+ h_lr = w_lr = content_mlp[0][0].shape[1]
366
+ nci = INR_input_dim
367
+
368
+ """(evaluation or not HR training) and not fullres evaluation"""
369
+ if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not (
370
+ not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train):
371
+ coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
372
+ coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
373
+ bs, h_lr, w_lr, int(k_h * k_w), nci)
374
+
375
+ for id, layer in enumerate(content_mlp):
376
+ if id == 0:
377
+ output = torch.matmul(coorinates, layer[0]) + layer[1]
378
+ output = activation(output)
379
+ else:
380
+ output = torch.matmul(output, layer[0]) + layer[1]
381
+ output = activation(output)
382
+
383
+ output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
384
+ 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
385
+
386
+ output_large = self.up(lin2img(output))
387
+
388
+ return output_large.view(bs, -1, opt.INR_MLP_dim), output
389
+ else:
390
+ coorinates = coorinates.permute(0, 2, 3, 1)
391
+ for id, layer in enumerate(content_mlp):
392
+ weight_shape = layer[0].shape
393
+ bias_shape = layer[1].shape
394
+ layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
395
+ layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
396
+ layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True
397
+ else 'bilinear', padding_mode='border', align_corners=False)
398
+ layer[1] = F.grid_sample(layer[1], coorinates[..., :2].flip(-1), mode='nearest' if True
399
+ else 'bilinear', padding_mode='border', align_corners=False)
400
+ layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *weight_shape[-2:])
401
+ layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *bias_shape[-2:])
402
+
403
+ if id == 0:
404
+ output = torch.matmul(coorinates.unsqueeze(-2), layer[0]) + layer[1]
405
+ output = activation(output)
406
+ else:
407
+ output = torch.matmul(output, layer[0]) + layer[1]
408
+ output = activation(output)
409
+
410
+ output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim)
411
+
412
+ output_large = self.up(lin2img(output, resolution))
413
+
414
+ return output_large.view(bs, -1, opt.INR_MLP_dim), output
415
+
416
+ elif appearance_mlp is not None:
417
+ output = base_feat
418
+ genMask = None
419
+ for id, layer in enumerate(appearance_mlp):
420
+ if id != len(appearance_mlp) - 1:
421
+ output = torch.matmul(output, layer[0]) + layer[1]
422
+ output = activation(output)
423
+ else:
424
+ output = torch.matmul(output, layer[0]) + layer[1] # last layer
425
+ if opt.activation == 'leakyrelu_pe':
426
+ output = torch.tanh(output)
427
+ return lin2img(output, resolution), None
428
+
429
+ def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
430
+ encoder_outputs = encoder_outputs[::-1]
431
+ mlp_output = None
432
+ res_w = image.shape[-1]
433
+ res_h = image.shape[-2]
434
+ coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
435
+ encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
436
+
437
+ if self.opt.isMoreINRInput:
438
+ coord = torch.cat(
439
+ [self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1),
440
+ mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
441
+ else:
442
+ coord = self.INR_encoding(coord, (res_h, res_w))
443
+
444
+ total = coord.clone()
445
+
446
+ interval = 10
447
+ all_intervals = math.ceil(res_h / interval)
448
+ divisible = True
449
+ if res_h / interval != res_h // interval:
450
+ divisible = False
451
+
452
+ for n in range(self.max_hidden_mlp_num):
453
+ accum_mlp_output = []
454
+ for line in range(all_intervals):
455
+ if not divisible and line == all_intervals - 1:
456
+ coord = total[:, line * interval * res_w:, :]
457
+ else:
458
+ coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]
459
+ if n == 0:
460
+ accum_mlp_output.append(self.mlp_process(coord,
461
+ self.INR_encoding.out_dim + (
462
+ 4 if self.opt.isMoreINRInput else 0),
463
+ self.opt, content_mlp=self.content_mlp_blocks[
464
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
465
+ encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
466
+ encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
467
+ resolution=(interval,
468
+ res_w) if divisible or line != all_intervals - 1 else (
469
+ res_h - interval * (all_intervals - 1), res_w))[1])
470
+
471
+ else:
472
+ accum_mlp_output.append(self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
473
+ 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:,
474
+ line * interval * res_w: (
475
+ line + 1) * interval * res_w,
476
+ :]
477
+ if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :],
478
+ content_mlp=self.content_mlp_blocks[
479
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
480
+ encoder_outputs.pop(
481
+ self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
482
+ encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
483
+ resolution=(interval,
484
+ res_w) if divisible or line != all_intervals - 1 else (
485
+ res_h - interval * (all_intervals - 1), res_w))[1])
486
+
487
+ accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
488
+ mlp_output = [accum_mlp_output, accum_mlp_output]
489
+
490
+ encoder_outputs = encoder_outputs[::-1]
491
+ output = encoder_outputs[0]
492
+ for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
493
+ output = block(output)
494
+ output = output + skip_output
495
+ output = self.deconv_blocks[-1](output)
496
+
497
+ app_mlp, app_params = self.appearance_mlps(output)
498
+ harm_out = []
499
+
500
+ accum_mlp_output = []
501
+ for line in range(all_intervals):
502
+ if not divisible and line == all_intervals - 1:
503
+ base = mlp_output[1][:, line * interval * res_w:, :]
504
+ else:
505
+ base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]
506
+
507
+ accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base,
508
+ appearance_mlp=app_mlp,
509
+ resolution=(
510
+ interval,
511
+ res_w) if divisible or line != all_intervals - 1 else (
512
+ res_h - interval * (all_intervals - 1), res_w))[0])
513
+
514
+ accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
515
+ harm_out.append(accum_mlp_output)
516
+
517
+ fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
518
+
519
+ return harm_out, fit_lut3d, lut_transform_image
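forward_fullResInference above walks the flattened coordinate grid in strips of interval image rows so that full-resolution decoding stays within memory, with the last strip absorbing any remainder. A standalone sketch of the same indexing (sizes are illustrative, not from the repo):

import math
import torch

res_h, res_w, interval = 23, 17, 10          # deliberately not divisible by interval
coords = torch.rand(1, res_h * res_w, 2)     # flattened row-major coordinate grid

all_intervals = math.ceil(res_h / interval)
chunks = []
for line in range(all_intervals):
    start = line * interval * res_w
    end = None if line == all_intervals - 1 else (line + 1) * interval * res_w
    chunks.append(coords[:, start:end, :])   # last chunk takes the remaining rows

assert torch.cat(chunks, dim=1).shape == coords.shape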
model/base/ih_model.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torchvision
3
+ import torch.nn as nn
4
+
5
+ from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
6
+
7
+ from .ops import ScaleLayer
8
+
9
+
10
+ class IHModelWithBackbone(nn.Module):
11
+ def __init__(
12
+ self,
13
+ model, backbone,
14
+ downsize_backbone_input=False,
15
+ mask_fusion='sum',
16
+ backbone_conv1_channels=64, opt=None
17
+ ):
18
+ super(IHModelWithBackbone, self).__init__()
19
+ self.downsize_backbone_input = downsize_backbone_input
20
+ self.mask_fusion = mask_fusion
21
+
22
+ self.backbone = backbone
23
+ self.model = model
24
+ self.opt = opt
25
+
26
+ self.mask_conv = nn.Sequential(
27
+ nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
28
+ ScaleLayer(init_value=0.1, lr_mult=1)
29
+ )
30
+
31
+ def forward(self, image, mask, coord=None, start_proportion=None):
32
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
33
+ backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
34
+ backbone_mask = torch.cat(
35
+ (torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
36
+ 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
37
+ else:
38
+ backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
39
+ backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
40
+ 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
41
+
42
+ backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
43
+ backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)
44
+
45
+ output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
46
+ return output
47
+
48
+
49
+ class DeepImageHarmonization(nn.Module):
50
+ def __init__(
51
+ self,
52
+ depth,
53
+ norm_layer=nn.BatchNorm2d, batchnorm_from=0,
54
+ attend_from=-1,
55
+ image_fusion=False,
56
+ ch=64, max_channels=512,
57
+ backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
58
+ ):
59
+ super(DeepImageHarmonization, self).__init__()
60
+ self.depth = depth
61
+ self.encoder = ConvEncoder(
62
+ depth, ch,
63
+ norm_layer, batchnorm_from, max_channels,
64
+ backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
65
+ )
66
+ self.opt = opt
67
+ if opt.INRDecode:
68
+ "See Table 2 in the paper to test with different INR decoders' structures."
69
+ self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
70
+ else:
71
+ "Baseline: https://github.com/SamsungLabs/image_harmonization"
72
+ self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)
73
+
74
+ def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
75
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
76
+ x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
77
+ torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
78
+ else:
79
+ x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
80
+ torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
81
+
82
+ intermediates = self.encoder(x, backbone_features)
83
+
84
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
85
+ output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
86
+ else:
87
+ output = self.decoder(intermediates, image, mask)
88
+ return output
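DeepImageHarmonization.forward above builds the encoder input by resizing the composite image and its mask to opt.base_size and concatenating them into a 4-channel tensor; a minimal sketch with hypothetical sizes (base_size is assumed to be 256 here):

import torch
import torchvision

base_size = 256                                  # assumed value of opt.base_size
image = torch.rand(1, 3, 1024, 768)              # composite image
mask = torch.rand(1, 1, 1024, 768)               # foreground mask

resize = torchvision.transforms.Resize([base_size, base_size])
x = torch.cat((resize(image), resize(mask)), dim=1)   # (1, 4, 256, 256) encoder input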
model/base/ops.py ADDED
@@ -0,0 +1,397 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class SimpleInputFusion(nn.Module):
9
+ def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
10
+ super(SimpleInputFusion, self).__init__()
11
+
12
+ self.fusion_conv = nn.Sequential(
13
+ nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
14
+ nn.LeakyReLU(negative_slope=0.2),
15
+ norm_layer(ch),
16
+ nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
17
+ )
18
+
19
+ def forward(self, image, additional_input):
20
+ return self.fusion_conv(torch.cat((image, additional_input), dim=1))
21
+
22
+
23
+ class MaskedChannelAttention(nn.Module):
24
+ def __init__(self, in_channels, *args, **kwargs):
25
+ super(MaskedChannelAttention, self).__init__()
26
+ self.global_max_pool = MaskedGlobalMaxPool2d()
27
+ self.global_avg_pool = FastGlobalAvgPool2d()
28
+
29
+ intermediate_channels_count = max(in_channels // 16, 8)
30
+ self.attention_transform = nn.Sequential(
31
+ nn.Linear(3 * in_channels, intermediate_channels_count),
32
+ nn.ReLU(inplace=True),
33
+ nn.Linear(intermediate_channels_count, in_channels),
34
+ nn.Sigmoid(),
35
+ )
36
+
37
+ def forward(self, x, mask):
38
+ if mask.shape[2:] != x.shape[2:]:
39
+ mask = nn.functional.interpolate(
40
+ mask, size=x.size()[-2:],
41
+ mode='bilinear', align_corners=True
42
+ )
43
+ pooled_x = torch.cat([
44
+ self.global_max_pool(x, mask),
45
+ self.global_avg_pool(x)
46
+ ], dim=1)
47
+ channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
48
+
49
+ return channel_attention_weights * x
50
+
51
+
52
+ class MaskedGlobalMaxPool2d(nn.Module):
53
+ def __init__(self):
54
+ super().__init__()
55
+ self.global_max_pool = FastGlobalMaxPool2d()
56
+
57
+ def forward(self, x, mask):
58
+ return torch.cat((
59
+ self.global_max_pool(x * mask),
60
+ self.global_max_pool(x * (1.0 - mask))
61
+ ), dim=1)
62
+
63
+
64
+ class FastGlobalAvgPool2d(nn.Module):
65
+ def __init__(self):
66
+ super(FastGlobalAvgPool2d, self).__init__()
67
+
68
+ def forward(self, x):
69
+ in_size = x.size()
70
+ return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
71
+
72
+
73
+ class FastGlobalMaxPool2d(nn.Module):
74
+ def __init__(self):
75
+ super(FastGlobalMaxPool2d, self).__init__()
76
+
77
+ def forward(self, x):
78
+ in_size = x.size()
79
+ return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
80
+
81
+
82
+ class ScaleLayer(nn.Module):
83
+ def __init__(self, init_value=1.0, lr_mult=1):
84
+ super().__init__()
85
+ self.lr_mult = lr_mult
86
+ self.scale = nn.Parameter(
87
+ torch.full((1,), init_value / lr_mult, dtype=torch.float32)
88
+ )
89
+
90
+ def forward(self, x):
91
+ scale = torch.abs(self.scale * self.lr_mult)
92
+ return x * scale
93
+
94
+
95
+ class FeaturesConnector(nn.Module):
96
+ def __init__(self, mode, in_channels, feature_channels, out_channels):
97
+ super(FeaturesConnector, self).__init__()
98
+ self.mode = mode if feature_channels else ''
99
+
100
+ if self.mode == 'catc':
101
+ self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
102
+ elif self.mode == 'sum':
103
+ self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)
104
+
105
+ self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels
106
+
107
+ def forward(self, x, features):
108
+ if self.mode == 'cat':
109
+ return torch.cat((x, features), 1)
110
+ if self.mode == 'catc':
111
+ return self.reduce_conv(torch.cat((x, features), 1))
112
+ if self.mode == 'sum':
113
+ return self.reduce_conv(features) + x
114
+ return x
115
+
116
+ def extra_repr(self):
117
+ return self.mode
118
+
119
+
120
+ class PosEncodingNeRF(nn.Module):
121
+ def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
122
+ super().__init__()
123
+
124
+ self.in_features = in_features
125
+
126
+ if self.in_features == 3:
127
+ self.num_frequencies = 10
128
+ elif self.in_features == 2:
129
+ assert sidelength is not None
130
+ if isinstance(sidelength, int):
131
+ sidelength = (sidelength, sidelength)
132
+ self.num_frequencies = 4
133
+ if use_nyquist:
134
+ self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
135
+ elif self.in_features == 1:
136
+ assert fn_samples is not None
137
+ self.num_frequencies = 4
138
+ if use_nyquist:
139
+ self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
140
+
141
+ self.out_dim = in_features + 2 * in_features * self.num_frequencies
142
+
143
+ def get_num_frequencies_nyquist(self, samples):
144
+ nyquist_rate = 1 / (2 * (2 * 1 / samples))
145
+ return int(math.floor(math.log(nyquist_rate, 2)))
146
+
147
+ def forward(self, coords):
148
+ coords = coords.view(coords.shape[0], -1, self.in_features)
149
+
150
+ coords_pos_enc = coords
151
+ for i in range(self.num_frequencies):
152
+ for j in range(self.in_features):
153
+ c = coords[..., j]
154
+
155
+ sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1)
156
+ cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1)
157
+
158
+ coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1)
159
+
160
+ return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
161
+
162
+
163
+ class RandomFourier(nn.Module):
164
+ def __init__(self, std_scale, embedding_length, device):
165
+ super().__init__()
166
+
167
+ self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
168
+ self.embed = self.embed.to(device)
169
+
170
+ self.out_dim = embedding_length * 2 + 2
171
+
172
+ def forward(self, coords):
173
+ coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)),
174
+ torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1)
175
+
176
+ return torch.cat([coords, coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)], dim=-1)
177
+
178
+
179
+ class CIPS_embed(nn.Module):
180
+ def __init__(self, size, embedding_length):
181
+ super().__init__()
182
+ self.fourier_embed = ConstantInput(size, embedding_length)
183
+ self.predict_embed = Predict_embed(embedding_length)
184
+ self.out_dim = embedding_length * 2 + 2
185
+
186
+ def forward(self, coord, res=None):
187
+ x = self.predict_embed(coord)
188
+ y = self.fourier_embed(x, coord, res)
189
+
190
+ return torch.cat([coord, x, y], dim=-1)
191
+
192
+
193
+ class Predict_embed(nn.Module):
194
+ def __init__(self, embedding_length):
195
+ super(Predict_embed, self).__init__()
196
+ self.ffm = nn.Linear(2, embedding_length, bias=True)
197
+ nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))
198
+
199
+ def forward(self, x):
200
+ x = self.ffm(x)
201
+ x = torch.sin(x)
202
+ return x
203
+
204
+
205
+ class ConstantInput(nn.Module):
206
+ def __init__(self, size, channel):
207
+ super().__init__()
208
+
209
+ self.input = nn.Parameter(torch.randn(1, size ** 2, channel))
210
+
211
+ def forward(self, input, coord, resolution=None):
212
+ batch = input.shape[0]
213
+ out = self.input.repeat(batch, 1, 1)
214
+
215
+ if coord.shape[1] != self.input.shape[1]:
216
+ x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1],
217
+ int(self.input.shape[1] ** 0.5), int(self.input.shape[1] ** 0.5))
218
+
219
+ if resolution is None:
220
+ grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1])
221
+ else:
222
+ grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])
223
+
224
+ out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)
225
+
226
+ out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1])
227
+
228
+ return out
229
+
230
+
231
+ class INRGAN_embed(nn.Module):
232
+ def __init__(self, resolution: int, w_dim=None):
233
+ super().__init__()
234
+
235
+ self.resolution = resolution
236
+ self.res_cfg = {"log_emb_size": 32,
237
+ "random_emb_size": 32,
238
+ "const_emb_size": 64,
239
+ "use_cosine": True}
240
+ self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
241
+ self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
242
+ self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
243
+ self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
244
+ self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
245
+ self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
246
+ self.use_cosine = self.res_cfg.get('use_cosine', False)
247
+
248
+ if self.log_emb_size > 0:
249
+ self.register_buffer('log_basis', generate_logarithmic_basis(
250
+ resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))
251
+
252
+ if self.random_emb_size > 0:
253
+ self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))
254
+
255
+ if self.shared_emb_size > 0:
256
+ self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))
257
+
258
+ if self.predictable_emb_size > 0:
259
+ self.W_size = self.predictable_emb_size * self.cfg.coord_dim
260
+ self.b_size = self.predictable_emb_size
261
+ self.affine = nn.Linear(w_dim, self.W_size + self.b_size)
262
+
263
+ if self.const_emb_size > 0:
264
+ self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))
265
+
266
+ self.out_dim = self.get_total_dim() + 2
267
+
268
+ def sample_w_matrix(self, shape, scale: float):
269
+ return torch.randn(shape) * scale
270
+
271
+ def get_total_dim(self) -> int:
272
+ total_dim = 0
273
+ if self.log_emb_size > 0:
274
+ total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1)
275
+ total_dim += self.random_emb_size * (2 if self.use_cosine else 1)
276
+ total_dim += self.shared_emb_size * (2 if self.use_cosine else 1)
277
+ total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1)
278
+ total_dim += self.const_emb_size
279
+
280
+ return total_dim
281
+
282
+ def forward(self, raw_coords, w=None):
283
+ batch_size, img_size, in_channels = raw_coords.shape
284
+
285
+ raw_embs = []
286
+
287
+ if self.log_emb_size > 0:
288
+ log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
289
+ raw_log_embs = torch.matmul(raw_coords, log_bases)
290
+ raw_embs.append(raw_log_embs)
291
+
292
+ if self.random_emb_size > 0:
293
+ random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
294
+ raw_random_embs = torch.matmul(raw_coords, random_bases)
295
+ raw_embs.append(raw_random_embs)
296
+
297
+ if self.shared_emb_size > 0:
298
+ shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
299
+ raw_shared_embs = torch.matmul(raw_coords, shared_bases)
300
+ raw_embs.append(raw_shared_embs)
301
+
302
+ if self.predictable_emb_size > 0:
303
+ mod = self.affine(w)
304
+ W = self.fourier_scale * mod[:, :self.W_size]
305
+ W = W.view(batch_size, self.cfg.coord_dim, self.predictable_emb_size)
306
+ bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
307
+ raw_predictable_embs = (torch.matmul(raw_coords, W) + bias)
308
+ raw_embs.append(raw_predictable_embs)
309
+
310
+ if len(raw_embs) > 0:
311
+ raw_embs = torch.cat(raw_embs, dim=-1)
312
+ raw_embs = raw_embs.contiguous()
313
+ out = raw_embs.sin()
314
+
315
+ if self.use_cosine:
316
+ out = torch.cat([out, raw_embs.cos()], dim=-1)
317
+
318
+ if self.const_emb_size > 0:
319
+ const_embs = self.const_embs.repeat([batch_size, 1, 1])
320
+ const_embs = const_embs
321
+ out = torch.cat([out, const_embs], dim=-1)
322
+
323
+ return torch.cat([raw_coords, out], dim=-1)
324
+
325
+
326
+ def generate_logarithmic_basis(
327
+ resolution,
328
+ max_num_feats,
329
+ remove_lowest_freq: bool = False,
330
+ use_diagonal: bool = True):
331
+ """
332
+ Generates a directional logarithmic basis with the following directions:
333
+ - horizontal
334
+ - vertical
335
+ - main diagonal
336
+ - anti-diagonal
337
+ """
338
+ max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int)
339
+ bases = [
340
+ generate_horizontal_basis(max_num_feats_per_direction),
341
+ generate_vertical_basis(max_num_feats_per_direction),
342
+ ]
343
+
344
+ if use_diagonal:
345
+ bases.extend([
346
+ generate_diag_main_basis(max_num_feats_per_direction),
347
+ generate_anti_diag_basis(max_num_feats_per_direction),
348
+ ])
349
+
350
+ if remove_lowest_freq:
351
+ bases = [b[1:] for b in bases]
352
+
353
+ # If we do not fit into `max_num_feats`, then trying to remove the features in the order:
354
+ # 1) anti-diagonal 2) main-diagonal
355
+ # while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2):
356
+ # bases = bases[:-1]
357
+
358
+ basis = torch.cat(bases, dim=0)
359
+
360
+ # If we still do not fit, then let's remove each second feature,
361
+ # then each third, each forth and so on
362
+ # We cannot drop the whole horizontal or vertical direction since otherwise
363
+ # model won't be able to locate the position
364
+ # (unless the previously computed embeddings encode the position)
365
+ # while basis.shape[0] > max_num_feats:
366
+ # num_exceeding_feats = basis.shape[0] - max_num_feats
367
+ # basis = basis[::2]
368
+
369
+ assert basis.shape[0] <= max_num_feats, \
370
+ f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."
371
+
372
+ return basis
373
+
374
+
375
+ def generate_horizontal_basis(num_feats: int):
376
+ return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)
377
+
378
+
379
+ def generate_vertical_basis(num_feats: int):
380
+ return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)
381
+
382
+
383
+ def generate_diag_main_basis(num_feats: int):
384
+ return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
385
+
386
+
387
+ def generate_anti_diag_basis(num_feats: int):
388
+ return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
389
+
390
+
391
+ def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
392
+ period_coef = 2.0 * np.pi / period_length
393
+ basis = torch.tensor([basis_block]).repeat(num_feats, 1) # [num_feats, 2]
394
+ powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1) # [num_feats, 1]
395
+ result = basis * powers * period_coef # [num_feats, 2]
396
+
397
+ return result.float()
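As a quick check on PosEncodingNeRF above: every 2-D coordinate expands to in_features + 2 * in_features * num_frequencies channels, where num_frequencies follows the Nyquist rule used in get_num_frequencies_nyquist. For a 256-pixel sidelength:

import math

in_features, sidelength = 2, 256
nyquist_rate = 1 / (2 * (2 * 1 / sidelength))                 # sidelength / 4 = 64
num_frequencies = int(math.floor(math.log(nyquist_rate, 2)))  # 6
out_dim = in_features + 2 * in_features * num_frequencies     # 26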
model/build_model.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch.nn as nn
2
+ from .backbone import build_backbone
3
+
4
+
5
+ class build_model(nn.Module):
6
+ def __init__(self, opt):
7
+ super().__init__()
8
+
9
+ self.opt = opt
10
+ self.backbone = build_backbone('baseline', opt)
11
+
12
+ def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
13
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
14
+ """
15
+ For HR Training, due to the designed RSC strategy in Section 3.4 in the paper,
16
+ here we need to pass in the coordinates of the cropped regions.
17
+ """
18
+ extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates, start_proportion=start_proportion)
19
+ else:
20
+ extracted_features = self.backbone(composite_image, mask)
21
+
22
+ if self.opt.INRDecode:
23
+ return extracted_features
24
+ return None, None, extracted_features
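build_model.forward expects fg_INR_coordinates produced elsewhere in the repo (utils.misc.get_mgrid). A hypothetical stand-in that yields the same kind of normalized coordinate grid of shape (H*W, 2); the actual normalization and ordering live in utils/misc.py and may differ:

import torch

def make_grid(h, w):
    # row-major coordinate grid in [-1, 1], shape (h * w, 2)
    ys = torch.linspace(-1.0, 1.0, steps=h)
    xs = torch.linspace(-1.0, 1.0, steps=w)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing='ij'), dim=-1)
    return grid.view(-1, 2)

coords = make_grid(64, 64).unsqueeze(0)   # (1, 4096, 2), batch of one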
model/hrnetv2/__init__.py ADDED
File without changes
model/hrnetv2/hrnet_ocr.py ADDED
@@ -0,0 +1,405 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch._utils
7
+ from .ocr import SpatialOCR_Module, SpatialGather_Module
8
+ from .resnetv1b import BasicBlockV1b, BottleneckV1b
9
+
10
+ relu_inplace = True
11
+
12
+
13
+ class HighResolutionModule(nn.Module):
14
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
15
+ num_channels, fuse_method, multi_scale_output=True,
16
+ norm_layer=nn.BatchNorm2d, align_corners=True):
17
+ super(HighResolutionModule, self).__init__()
18
+ self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
19
+
20
+ self.num_inchannels = num_inchannels
21
+ self.fuse_method = fuse_method
22
+ self.num_branches = num_branches
23
+ self.norm_layer = norm_layer
24
+ self.align_corners = align_corners
25
+
26
+ self.multi_scale_output = multi_scale_output
27
+
28
+ self.branches = self._make_branches(
29
+ num_branches, blocks, num_blocks, num_channels)
30
+ self.fuse_layers = self._make_fuse_layers()
31
+ self.relu = nn.ReLU(inplace=relu_inplace)
32
+
33
+ def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
34
+ if num_branches != len(num_blocks):
35
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
36
+ num_branches, len(num_blocks))
37
+ raise ValueError(error_msg)
38
+
39
+ if num_branches != len(num_channels):
40
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
41
+ num_branches, len(num_channels))
42
+ raise ValueError(error_msg)
43
+
44
+ if num_branches != len(num_inchannels):
45
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
46
+ num_branches, len(num_inchannels))
47
+ raise ValueError(error_msg)
48
+
49
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
50
+ stride=1):
51
+ downsample = None
52
+ if stride != 1 or \
53
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
54
+ downsample = nn.Sequential(
55
+ nn.Conv2d(self.num_inchannels[branch_index],
56
+ num_channels[branch_index] * block.expansion,
57
+ kernel_size=1, stride=stride, bias=False),
58
+ self.norm_layer(num_channels[branch_index] * block.expansion),
59
+ )
60
+
61
+ layers = []
62
+ layers.append(block(self.num_inchannels[branch_index],
63
+ num_channels[branch_index], stride,
64
+ downsample=downsample, norm_layer=self.norm_layer))
65
+ self.num_inchannels[branch_index] = \
66
+ num_channels[branch_index] * block.expansion
67
+ for i in range(1, num_blocks[branch_index]):
68
+ layers.append(block(self.num_inchannels[branch_index],
69
+ num_channels[branch_index],
70
+ norm_layer=self.norm_layer))
71
+
72
+ return nn.Sequential(*layers)
73
+
74
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
75
+ branches = []
76
+
77
+ for i in range(num_branches):
78
+ branches.append(
79
+ self._make_one_branch(i, block, num_blocks, num_channels))
80
+
81
+ return nn.ModuleList(branches)
82
+
83
+ def _make_fuse_layers(self):
84
+ if self.num_branches == 1:
85
+ return None
86
+
87
+ num_branches = self.num_branches
88
+ num_inchannels = self.num_inchannels
89
+ fuse_layers = []
90
+ for i in range(num_branches if self.multi_scale_output else 1):
91
+ fuse_layer = []
92
+ for j in range(num_branches):
93
+ if j > i:
94
+ fuse_layer.append(nn.Sequential(
95
+ nn.Conv2d(in_channels=num_inchannels[j],
96
+ out_channels=num_inchannels[i],
97
+ kernel_size=1,
98
+ bias=False),
99
+ self.norm_layer(num_inchannels[i])))
100
+ elif j == i:
101
+ fuse_layer.append(None)
102
+ else:
103
+ conv3x3s = []
104
+ for k in range(i - j):
105
+ if k == i - j - 1:
106
+ num_outchannels_conv3x3 = num_inchannels[i]
107
+ conv3x3s.append(nn.Sequential(
108
+ nn.Conv2d(num_inchannels[j],
109
+ num_outchannels_conv3x3,
110
+ kernel_size=3, stride=2, padding=1, bias=False),
111
+ self.norm_layer(num_outchannels_conv3x3)))
112
+ else:
113
+ num_outchannels_conv3x3 = num_inchannels[j]
114
+ conv3x3s.append(nn.Sequential(
115
+ nn.Conv2d(num_inchannels[j],
116
+ num_outchannels_conv3x3,
117
+ kernel_size=3, stride=2, padding=1, bias=False),
118
+ self.norm_layer(num_outchannels_conv3x3),
119
+ nn.ReLU(inplace=relu_inplace)))
120
+ fuse_layer.append(nn.Sequential(*conv3x3s))
121
+ fuse_layers.append(nn.ModuleList(fuse_layer))
122
+
123
+ return nn.ModuleList(fuse_layers)
124
+
125
+ def get_num_inchannels(self):
126
+ return self.num_inchannels
127
+
128
+ def forward(self, x):
129
+ if self.num_branches == 1:
130
+ return [self.branches[0](x[0])]
131
+
132
+ for i in range(self.num_branches):
133
+ x[i] = self.branches[i](x[i])
134
+
135
+ x_fuse = []
136
+ for i in range(len(self.fuse_layers)):
137
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
138
+ for j in range(1, self.num_branches):
139
+ if i == j:
140
+ y = y + x[j]
141
+ elif j > i:
142
+ width_output = x[i].shape[-1]
143
+ height_output = x[i].shape[-2]
144
+ y = y + F.interpolate(
145
+ self.fuse_layers[i][j](x[j]),
146
+ size=[height_output, width_output],
147
+ mode='bilinear', align_corners=self.align_corners)
148
+ else:
149
+ y = y + self.fuse_layers[i][j](x[j])
150
+ x_fuse.append(self.relu(y))
151
+
152
+ return x_fuse
153
+
154
+
155
+ class HighResolutionNet(nn.Module):
156
+ def __init__(self, width, num_classes, ocr_width=256, small=False,
157
+ norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
158
+ super(HighResolutionNet, self).__init__()
159
+ self.opt = opt
160
+ self.norm_layer = norm_layer
161
+ self.width = width
162
+ self.ocr_width = ocr_width
163
+ self.ocr_on = ocr_width > 0
164
+ self.align_corners = align_corners
165
+
166
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
167
+ self.bn1 = norm_layer(64)
168
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
169
+ self.bn2 = norm_layer(64)
170
+ self.relu = nn.ReLU(inplace=relu_inplace)
171
+
172
+ num_blocks = 2 if small else 4
173
+
174
+ stage1_num_channels = 64
175
+ self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
176
+ stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
177
+
178
+ self.stage2_num_branches = 2
179
+ num_channels = [width, 2 * width]
180
+ num_inchannels = [
181
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
182
+ self.transition1 = self._make_transition_layer(
183
+ [stage1_out_channel], num_inchannels)
184
+ self.stage2, pre_stage_channels = self._make_stage(
185
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
186
+ num_blocks=2 * [num_blocks], num_channels=num_channels)
187
+
188
+ self.stage3_num_branches = 3
189
+ num_channels = [width, 2 * width, 4 * width]
190
+ num_inchannels = [
191
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
192
+ self.transition2 = self._make_transition_layer(
193
+ pre_stage_channels, num_inchannels)
194
+ self.stage3, pre_stage_channels = self._make_stage(
195
+ BasicBlockV1b, num_inchannels=num_inchannels,
196
+ num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
197
+ num_blocks=3 * [num_blocks], num_channels=num_channels)
198
+
199
+ self.stage4_num_branches = 4
200
+ num_channels = [width, 2 * width, 4 * width, 8 * width]
201
+ num_inchannels = [
202
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
203
+ self.transition3 = self._make_transition_layer(
204
+ pre_stage_channels, num_inchannels)
205
+ self.stage4, pre_stage_channels = self._make_stage(
206
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
207
+ num_branches=self.stage4_num_branches,
208
+ num_blocks=4 * [num_blocks], num_channels=num_channels)
209
+
210
+ if self.ocr_on:
211
+ last_inp_channels = np.int_(np.sum(pre_stage_channels))
212
+ ocr_mid_channels = 2 * ocr_width
213
+ ocr_key_channels = ocr_width
214
+
215
+ self.conv3x3_ocr = nn.Sequential(
216
+ nn.Conv2d(last_inp_channels, ocr_mid_channels,
217
+ kernel_size=3, stride=1, padding=1),
218
+ norm_layer(ocr_mid_channels),
219
+ nn.ReLU(inplace=relu_inplace),
220
+ )
221
+ self.ocr_gather_head = SpatialGather_Module(num_classes)
222
+
223
+ self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
224
+ key_channels=ocr_key_channels,
225
+ out_channels=ocr_mid_channels,
226
+ scale=1,
227
+ dropout=0.05,
228
+ norm_layer=norm_layer,
229
+ align_corners=align_corners, opt=opt)
230
+
231
+ def _make_transition_layer(
232
+ self, num_channels_pre_layer, num_channels_cur_layer):
233
+ num_branches_cur = len(num_channels_cur_layer)
234
+ num_branches_pre = len(num_channels_pre_layer)
235
+
236
+ transition_layers = []
237
+ for i in range(num_branches_cur):
238
+ if i < num_branches_pre:
239
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
240
+ transition_layers.append(nn.Sequential(
241
+ nn.Conv2d(num_channels_pre_layer[i],
242
+ num_channels_cur_layer[i],
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1,
246
+ bias=False),
247
+ self.norm_layer(num_channels_cur_layer[i]),
248
+ nn.ReLU(inplace=relu_inplace)))
249
+ else:
250
+ transition_layers.append(None)
251
+ else:
252
+ conv3x3s = []
253
+ for j in range(i + 1 - num_branches_pre):
254
+ inchannels = num_channels_pre_layer[-1]
255
+ outchannels = num_channels_cur_layer[i] \
256
+ if j == i - num_branches_pre else inchannels
257
+ conv3x3s.append(nn.Sequential(
258
+ nn.Conv2d(inchannels, outchannels,
259
+ kernel_size=3, stride=2, padding=1, bias=False),
260
+ self.norm_layer(outchannels),
261
+ nn.ReLU(inplace=relu_inplace)))
262
+ transition_layers.append(nn.Sequential(*conv3x3s))
263
+
264
+ return nn.ModuleList(transition_layers)
265
+
266
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
267
+ downsample = None
268
+ if stride != 1 or inplanes != planes * block.expansion:
269
+ downsample = nn.Sequential(
270
+ nn.Conv2d(inplanes, planes * block.expansion,
271
+ kernel_size=1, stride=stride, bias=False),
272
+ self.norm_layer(planes * block.expansion),
273
+ )
274
+
275
+ layers = []
276
+ layers.append(block(inplanes, planes, stride,
277
+ downsample=downsample, norm_layer=self.norm_layer))
278
+ inplanes = planes * block.expansion
279
+ for i in range(1, blocks):
280
+ layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
281
+
282
+ return nn.Sequential(*layers)
283
+
284
+ def _make_stage(self, block, num_inchannels,
285
+ num_modules, num_branches, num_blocks, num_channels,
286
+ fuse_method='SUM',
287
+ multi_scale_output=True):
288
+ modules = []
289
+ for i in range(num_modules):
290
+ # multi_scale_output is only used by the last module
291
+ if not multi_scale_output and i == num_modules - 1:
292
+ reset_multi_scale_output = False
293
+ else:
294
+ reset_multi_scale_output = True
295
+ modules.append(
296
+ HighResolutionModule(num_branches,
297
+ block,
298
+ num_blocks,
299
+ num_inchannels,
300
+ num_channels,
301
+ fuse_method,
302
+ reset_multi_scale_output,
303
+ norm_layer=self.norm_layer,
304
+ align_corners=self.align_corners)
305
+ )
306
+ num_inchannels = modules[-1].get_num_inchannels()
307
+
308
+ return nn.Sequential(*modules), num_inchannels
309
+
310
+ def forward(self, x, mask=None, additional_features=None):
311
+ hrnet_feats = self.compute_hrnet_feats(x, additional_features)
312
+ if not self.ocr_on:
313
+ return hrnet_feats,
314
+
315
+ ocr_feats = self.conv3x3_ocr(hrnet_feats)
316
+ mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
317
+ context = self.ocr_gather_head(ocr_feats, mask)
318
+ ocr_feats = self.ocr_distri_head(ocr_feats, context)
319
+ return ocr_feats,
320
+
321
+ def compute_hrnet_feats(self, x, additional_features, return_list=False):
322
+ x = self.compute_pre_stage_features(x, additional_features)
323
+ x = self.layer1(x)
324
+
325
+ x_list = []
326
+ for i in range(self.stage2_num_branches):
327
+ if self.transition1[i] is not None:
328
+ x_list.append(self.transition1[i](x))
329
+ else:
330
+ x_list.append(x)
331
+ y_list = self.stage2(x_list)
332
+
333
+ x_list = []
334
+ for i in range(self.stage3_num_branches):
335
+ if self.transition2[i] is not None:
336
+ if i < self.stage2_num_branches:
337
+ x_list.append(self.transition2[i](y_list[i]))
338
+ else:
339
+ x_list.append(self.transition2[i](y_list[-1]))
340
+ else:
341
+ x_list.append(y_list[i])
342
+ y_list = self.stage3(x_list)
343
+
344
+ x_list = []
345
+ for i in range(self.stage4_num_branches):
346
+ if self.transition3[i] is not None:
347
+ if i < self.stage3_num_branches:
348
+ x_list.append(self.transition3[i](y_list[i]))
349
+ else:
350
+ x_list.append(self.transition3[i](y_list[-1]))
351
+ else:
352
+ x_list.append(y_list[i])
353
+ x = self.stage4(x_list)
354
+
355
+ if return_list:
356
+ return x
357
+
358
+ # Upsampling
359
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
360
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w),
361
+ mode='bilinear', align_corners=self.align_corners)
362
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w),
363
+ mode='bilinear', align_corners=self.align_corners)
364
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w),
365
+ mode='bilinear', align_corners=self.align_corners)
366
+
367
+ return torch.cat([x[0], x1, x2, x3], 1)
368
+
369
+ def compute_pre_stage_features(self, x, additional_features):
370
+ x = self.conv1(x)
371
+ x = self.bn1(x)
372
+ x = self.relu(x)
373
+ if additional_features is not None:
374
+ x = x + additional_features
375
+ x = self.conv2(x)
376
+ x = self.bn2(x)
377
+ return self.relu(x)
378
+
379
+ def load_pretrained_weights(self, pretrained_path=''):
380
+ model_dict = self.state_dict()
381
+
382
+ if not os.path.exists(pretrained_path):
383
+ print(f'\nFile "{pretrained_path}" does not exist.')
384
+ print('You need to specify the correct path to the pre-trained weights.\n'
385
+ 'You can download the weights for HRNet from the repository:\n'
386
+ 'https://github.com/HRNet/HRNet-Image-Classification')
387
+ exit(1)
388
+
389
+ # Select the device the model will run on
390
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
391
+
392
+ # Load the weights and map them onto the selected device
393
+ pretrained_dict = torch.load(pretrained_path, map_location=device)
394
+ pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in pretrained_dict.items()}
395
+ params_count = len(pretrained_dict)
396
+
397
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
398
+
399
+ print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')
400
+
401
+ model_dict.update(pretrained_dict)
402
+ self.load_state_dict(model_dict)
403
+
404
+ # Move the model to the selected device
405
+ self.to(device)
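Note: a minimal usage sketch for the HighResolutionNet backbone added above, assuming the code is run from the repository root. The width/ocr_width values, the weights path, and the dummy tensors are illustrative only and are not part of this commit.

    import torch
    from model.hrnetv2.hrnet_ocr import HighResolutionNet

    # Illustrative configuration: width=18 corresponds to HRNetV2-W18; ocr_width > 0 enables the OCR head.
    net = HighResolutionNet(width=18, num_classes=1, ocr_width=64)
    # net.load_pretrained_weights('pretrained/hrnetv2_w18_imagenet_pretrained.pth')  # optional, hypothetical path
    net.eval()

    image = torch.randn(1, 3, 256, 256)  # dummy composite image
    mask = torch.rand(1, 1, 256, 256)    # dummy foreground mask consumed by the OCR gather head
    with torch.no_grad():
        feats, = net(image, mask)        # forward returns a one-element tuple of features
    print(feats.shape)                   # here: torch.Size([1, 128, 64, 64]), i.e. 2 * ocr_width channels at 1/4 resolution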
model/hrnetv2/modifiers.py ADDED
@@ -0,0 +1,11 @@
1
+
2
+
3
+ class LRMult(object):
4
+ def __init__(self, lr_mult=1.):
5
+ self.lr_mult = lr_mult
6
+
7
+ def __call__(self, m):
8
+ if getattr(m, 'weight', None) is not None:
9
+ m.weight.lr_mult = self.lr_mult
10
+ if getattr(m, 'bias', None) is not None:
11
+ m.bias.lr_mult = self.lr_mult
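A short, hedged sketch of how LRMult is commonly used: Module.apply() visits every submodule and tags its parameters with an lr_mult attribute that an optimizer-building step can later read. The backbone and the 0.1 multiplier below are illustrative.

    import torch.nn as nn
    from model.hrnetv2.modifiers import LRMult

    backbone = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
    backbone.apply(LRMult(0.1))  # sets weight.lr_mult / bias.lr_mult on every submodule

    for name, param in backbone.named_parameters():
        print(name, getattr(param, 'lr_mult', 1.0))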
model/hrnetv2/ocr.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch._utils
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class SpatialGather_Module(nn.Module):
8
+ """
9
+ Aggregate the context features according to the initial
10
+ predicted probability distribution.
11
+ Employ the soft-weighted method to aggregate the context.
12
+ """
13
+
14
+ def __init__(self, cls_num=0, scale=1):
15
+ super(SpatialGather_Module, self).__init__()
16
+ self.cls_num = cls_num
17
+ self.scale = scale
18
+
19
+ def forward(self, feats, probs):
20
+ batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
21
+ probs = probs.view(batch_size, c, -1)
22
+ feats = feats.view(batch_size, feats.size(1), -1)
23
+ feats = feats.permute(0, 2, 1) # batch x hw x c
24
+ probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
25
+ ocr_context = torch.matmul(probs, feats) \
26
+ .permute(0, 2, 1).unsqueeze(3).contiguous()  # batch x c x k x 1
27
+ return ocr_context
28
+
29
+
30
+ class SpatialOCR_Module(nn.Module):
31
+ """
32
+ Implementation of the OCR module:
33
+ We aggregate the global object representation to update the representation for each pixel.
34
+ """
35
+
36
+ def __init__(self,
37
+ in_channels,
38
+ key_channels,
39
+ out_channels,
40
+ scale=1,
41
+ dropout=0.1,
42
+ norm_layer=nn.BatchNorm2d,
43
+ align_corners=True, opt=None):
44
+ super(SpatialOCR_Module, self).__init__()
45
+ self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
46
+ norm_layer, align_corners)
47
+ _in_channels = 2 * in_channels
48
+ self.conv_bn_dropout = nn.Sequential(
49
+ nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
50
+ nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
51
+ nn.Dropout2d(dropout)
52
+ )
53
+
54
+ def forward(self, feats, proxy_feats):
55
+ context = self.object_context_block(feats, proxy_feats)
56
+
57
+ output = self.conv_bn_dropout(torch.cat([context, feats], 1))
58
+
59
+ return output
60
+
61
+
62
+ class ObjectAttentionBlock2D(nn.Module):
63
+ '''
64
+ The basic implementation for object context block
65
+ Input:
66
+ N X C X H X W
67
+ Parameters:
68
+ in_channels : the dimension of the input feature map
69
+ key_channels : the dimension after the key/query transform
70
+ scale : choose the scale to downsample the input feature maps (save memory cost)
71
+ norm_layer : normalization layer to use (default: nn.BatchNorm2d)
72
+ Return:
73
+ N X C X H X W
74
+ '''
75
+
76
+ def __init__(self,
77
+ in_channels,
78
+ key_channels,
79
+ scale=1,
80
+ norm_layer=nn.BatchNorm2d,
81
+ align_corners=True):
82
+ super(ObjectAttentionBlock2D, self).__init__()
83
+ self.scale = scale
84
+ self.in_channels = in_channels
85
+ self.key_channels = key_channels
86
+ self.align_corners = align_corners
87
+
88
+ self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
89
+ self.f_pixel = nn.Sequential(
90
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
91
+ kernel_size=1, stride=1, padding=0, bias=False),
92
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
93
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
94
+ kernel_size=1, stride=1, padding=0, bias=False),
95
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
96
+ )
97
+ self.f_object = nn.Sequential(
98
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
99
+ kernel_size=1, stride=1, padding=0, bias=False),
100
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
101
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
102
+ kernel_size=1, stride=1, padding=0, bias=False),
103
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
104
+ )
105
+ self.f_down = nn.Sequential(
106
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
107
+ kernel_size=1, stride=1, padding=0, bias=False),
108
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
109
+ )
110
+ self.f_up = nn.Sequential(
111
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
112
+ kernel_size=1, stride=1, padding=0, bias=False),
113
+ nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
114
+ )
115
+
116
+ def forward(self, x, proxy):
117
+ batch_size, h, w = x.size(0), x.size(2), x.size(3)
118
+ if self.scale > 1:
119
+ x = self.pool(x)
120
+
121
+ query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
122
+ query = query.permute(0, 2, 1)
123
+ key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
124
+ value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
125
+ value = value.permute(0, 2, 1)
126
+
127
+ sim_map = torch.matmul(query, key)
128
+ sim_map = (self.key_channels ** -.5) * sim_map
129
+ sim_map = F.softmax(sim_map, dim=-1)
130
+
131
+ # add bg context ...
132
+ context = torch.matmul(sim_map, value)
133
+ context = context.permute(0, 2, 1).contiguous()
134
+ context = context.view(batch_size, self.key_channels, *x.size()[2:])
135
+ context = self.f_up(context)
136
+ if self.scale > 1:
137
+ context = F.interpolate(input=context, size=(h, w),
138
+ mode='bilinear', align_corners=self.align_corners)
139
+
140
+ return context
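For orientation, a small sketch wiring the two OCR modules above together with illustrative shapes: SpatialGather_Module pools one context vector per soft region (here a single mask channel), and SpatialOCR_Module redistributes that context to every pixel.

    import torch
    from model.hrnetv2.ocr import SpatialGather_Module, SpatialOCR_Module

    feats = torch.randn(2, 128, 64, 64)  # pixel features, N x C x H x W
    probs = torch.rand(2, 1, 64, 64)     # one soft region map, e.g. a foreground mask

    gather = SpatialGather_Module(cls_num=1)
    distri = SpatialOCR_Module(in_channels=128, key_channels=64, out_channels=128)

    context = gather(feats, probs)  # 2 x 128 x 1 x 1: one context vector per region
    out = distri(feats, context)    # 2 x 128 x 64 x 64: context-augmented pixel features
    print(out.shape)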
model/hrnetv2/resnetv1b.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
4
+
5
+
6
+ class BasicBlockV1b(nn.Module):
7
+ expansion = 1
8
+
9
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
10
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
11
+ super(BasicBlockV1b, self).__init__()
12
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
13
+ padding=dilation, dilation=dilation, bias=False)
14
+ self.bn1 = norm_layer(planes)
15
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
16
+ padding=previous_dilation, dilation=previous_dilation, bias=False)
17
+ self.bn2 = norm_layer(planes)
18
+
19
+ self.relu = nn.ReLU(inplace=True)
20
+ self.downsample = downsample
21
+ self.stride = stride
22
+
23
+ def forward(self, x):
24
+ residual = x
25
+
26
+ out = self.conv1(x)
27
+ out = self.bn1(out)
28
+ out = self.relu(out)
29
+
30
+ out = self.conv2(out)
31
+ out = self.bn2(out)
32
+
33
+ if self.downsample is not None:
34
+ residual = self.downsample(x)
35
+
36
+ out = out + residual
37
+ out = self.relu(out)
38
+
39
+ return out
40
+
41
+
42
+ class BottleneckV1b(nn.Module):
43
+ expansion = 4
44
+
45
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
46
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
47
+ super(BottleneckV1b, self).__init__()
48
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49
+ self.bn1 = norm_layer(planes)
50
+
51
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52
+ padding=dilation, dilation=dilation, bias=False)
53
+ self.bn2 = norm_layer(planes)
54
+
55
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
56
+ self.bn3 = norm_layer(planes * self.expansion)
57
+
58
+ self.relu = nn.ReLU(inplace=True)
59
+ self.downsample = downsample
60
+ self.stride = stride
61
+
62
+ def forward(self, x):
63
+ residual = x
64
+
65
+ out = self.conv1(x)
66
+ out = self.bn1(out)
67
+ out = self.relu(out)
68
+
69
+ out = self.conv2(out)
70
+ out = self.bn2(out)
71
+ out = self.relu(out)
72
+
73
+ out = self.conv3(out)
74
+ out = self.bn3(out)
75
+
76
+ if self.downsample is not None:
77
+ residual = self.downsample(x)
78
+
79
+ out = out + residual
80
+ out = self.relu(out)
81
+
82
+ return out
83
+
84
+
85
+ class ResNetV1b(nn.Module):
86
+ """ ResNetV1b model that produces stride-8 feature maps at conv5 when dilated=True.
87
+
88
+ Parameters
89
+ ----------
90
+ block : Block
91
+ Class for the residual block. Options are BasicBlockV1b, BottleneckV1b.
92
+ layers : list of int
93
+ Number of residual blocks in each of the four stages.
94
+ classes : int, default 1000
95
+ Number of classification classes.
96
+ dilated : bool, default True
97
+ Whether to apply the dilation strategy to the last two stages, yielding a stride-8 model
98
+ typically used in semantic segmentation.
99
+ norm_layer : object
100
+ Normalization layer used (default: :class:`nn.BatchNorm2d`)
101
+ deep_stem : bool, default False
102
+ Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
103
+ avg_down : bool, default False
104
+ Whether to use average pooling in the projection (downsample) skip connection between stages.
105
+ final_drop : float, default 0.0
106
+ Dropout ratio before the final classification layer.
107
+
108
+ Reference:
109
+ - He, Kaiming, et al. "Deep residual learning for image recognition."
110
+ Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
111
+
112
+ - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
113
+ """
114
+ def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
115
+ avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
116
+ self.inplanes = stem_width*2 if deep_stem else 64
117
+ super(ResNetV1b, self).__init__()
118
+ if not deep_stem:
119
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
120
+ else:
121
+ self.conv1 = nn.Sequential(
122
+ nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
123
+ norm_layer(stem_width),
124
+ nn.ReLU(True),
125
+ nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
126
+ norm_layer(stem_width),
127
+ nn.ReLU(True),
128
+ nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
129
+ )
130
+ self.bn1 = norm_layer(self.inplanes)
131
+ self.relu = nn.ReLU(True)
132
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
133
+ self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
134
+ norm_layer=norm_layer)
135
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
136
+ norm_layer=norm_layer)
137
+ if dilated:
138
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
139
+ avg_down=avg_down, norm_layer=norm_layer)
140
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
141
+ avg_down=avg_down, norm_layer=norm_layer)
142
+ else:
143
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144
+ avg_down=avg_down, norm_layer=norm_layer)
145
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146
+ avg_down=avg_down, norm_layer=norm_layer)
147
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148
+ self.drop = None
149
+ if final_drop > 0.0:
150
+ self.drop = nn.Dropout(final_drop)
151
+ self.fc = nn.Linear(512 * block.expansion, classes)
152
+
153
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
154
+ avg_down=False, norm_layer=nn.BatchNorm2d):
155
+ downsample = None
156
+ if stride != 1 or self.inplanes != planes * block.expansion:
157
+ downsample = []
158
+ if avg_down:
159
+ if dilation == 1:
160
+ downsample.append(
161
+ nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
162
+ )
163
+ else:
164
+ downsample.append(
165
+ nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
166
+ )
167
+ downsample.extend([
168
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
169
+ kernel_size=1, stride=1, bias=False),
170
+ norm_layer(planes * block.expansion)
171
+ ])
172
+ downsample = nn.Sequential(*downsample)
173
+ else:
174
+ downsample = nn.Sequential(
175
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
176
+ kernel_size=1, stride=stride, bias=False),
177
+ norm_layer(planes * block.expansion)
178
+ )
179
+
180
+ layers = []
181
+ if dilation in (1, 2):
182
+ layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
183
+ previous_dilation=dilation, norm_layer=norm_layer))
184
+ elif dilation == 4:
185
+ layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
186
+ previous_dilation=dilation, norm_layer=norm_layer))
187
+ else:
188
+ raise RuntimeError("=> unknown dilation size: {}".format(dilation))
189
+
190
+ self.inplanes = planes * block.expansion
191
+ for _ in range(1, blocks):
192
+ layers.append(block(self.inplanes, planes, dilation=dilation,
193
+ previous_dilation=dilation, norm_layer=norm_layer))
194
+
195
+ return nn.Sequential(*layers)
196
+
197
+ def forward(self, x):
198
+ x = self.conv1(x)
199
+ x = self.bn1(x)
200
+ x = self.relu(x)
201
+ x = self.maxpool(x)
202
+
203
+ x = self.layer1(x)
204
+ x = self.layer2(x)
205
+ x = self.layer3(x)
206
+ x = self.layer4(x)
207
+
208
+ x = self.avgpool(x)
209
+ x = x.view(x.size(0), -1)
210
+ if self.drop is not None:
211
+ x = self.drop(x)
212
+ x = self.fc(x)
213
+
214
+ return x
215
+
216
+
217
+ def _safe_state_dict_filtering(orig_dict, model_dict_keys):
218
+ filtered_orig_dict = {}
219
+ for k, v in orig_dict.items():
220
+ if k in model_dict_keys:
221
+ filtered_orig_dict[k] = v
222
+ else:
223
+ print(f"[ERROR] Failed to load <{k}> in backbone")
224
+ return filtered_orig_dict
225
+
226
+
227
+ def resnet34_v1b(pretrained=False, **kwargs):
228
+ model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
229
+ if pretrained:
230
+ model_dict = model.state_dict()
231
+ filtered_orig_dict = _safe_state_dict_filtering(
232
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
233
+ model_dict.keys()
234
+ )
235
+ model_dict.update(filtered_orig_dict)
236
+ model.load_state_dict(model_dict)
237
+ return model
238
+
239
+
240
+ def resnet50_v1s(pretrained=False, **kwargs):
241
+ model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
242
+ if pretrained:
243
+ model_dict = model.state_dict()
244
+ filtered_orig_dict = _safe_state_dict_filtering(
245
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
246
+ model_dict.keys()
247
+ )
248
+ model_dict.update(filtered_orig_dict)
249
+ model.load_state_dict(model_dict)
250
+ return model
251
+
252
+
253
+ def resnet101_v1s(pretrained=False, **kwargs):
254
+ model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
255
+ if pretrained:
256
+ model_dict = model.state_dict()
257
+ filtered_orig_dict = _safe_state_dict_filtering(
258
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
259
+ model_dict.keys()
260
+ )
261
+ model_dict.update(filtered_orig_dict)
262
+ model.load_state_dict(model_dict)
263
+ return model
264
+
265
+
266
+ def resnet152_v1s(pretrained=False, **kwargs):
267
+ model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
268
+ if pretrained:
269
+ model_dict = model.state_dict()
270
+ filtered_orig_dict = _safe_state_dict_filtering(
271
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
272
+ model_dict.keys()
273
+ )
274
+ model_dict.update(filtered_orig_dict)
275
+ model.load_state_dict(model_dict)
276
+ return model
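Finally, a hedged sketch of the factory functions above. With pretrained=True they would download GluonCV weights through torch.hub, so the example keeps pretrained=False to stay self-contained; the input size is illustrative.

    import torch
    from model.hrnetv2.resnetv1b import resnet34_v1b

    model = resnet34_v1b(pretrained=False)  # BasicBlockV1b with layers [3, 4, 6, 3]
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000]) with the default 1000 classes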