diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..e53bdd888d3221c0b317e64532d9a40078391bc6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/demo.gif filter=lfs diff=lfs merge=lfs -text +assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text +demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text +demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2dc7cdd9b9bf6c7c5b0f8c9af23da304787baed1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.idea/* +logs/* +wandb/* +system/ +*.bat +*.7z +.venv +__pycache__/ +*.pyc \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
import os
import sys
import io

import cv2
import gradio as gr
import numpy as np
import spaces


class Logger:
    """Tee-style replacement for ``sys.stdout``.

    Every write is forwarded to the stream that was installed as
    ``sys.stdout`` when the instance was created, and a UTF-8 encoded copy is
    appended to ``self.log`` so the UI can display it.
    """

    def __init__(self):
        self.terminal = sys.stdout
        self.log = io.BytesIO()

    def write(self, message):
        """Write ``message`` to the terminal and the in-memory log.

        Returns the number of characters written, as text streams do.
        """
        self.terminal.write(message)
        self.log.write(message.encode('utf-8'))
        return len(message)

    def flush(self):
        """Flush both the terminal stream and the in-memory log."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Always report a non-interactive stream."""
        return False


log = Logger()
sys.stdout = log


def read_logs():
    """Return the captured stdout text; start a fresh buffer once it holds 30+ newlines."""
    out = log.log.getvalue().decode()
    if out.count("\n") >= 30:
        log.log = io.BytesIO()
    sys.stdout.flush()
    return out


def _demo_examples(keyword):
    """Return the sorted paths of files in ``demo/`` whose name contains ``keyword``."""
    return sorted(os.path.join("demo", name) for name in os.listdir("demo") if keyword in name)


# NOTE(review): the nesting of rows/columns below was reconstructed from a
# whitespace-collapsed source -- confirm against the rendered layout.
with gr.Blocks() as app:
    valid_checkpoints_dict = {"Resolution_256_iHarmony4": "Resolution_256_iHarmony4.pth",
                              "Resolution_1024_HAdobe5K": "Resolution_1024_HAdobe5K.pth",
                              "Resolution_2048_HAdobe5K": "Resolution_2048_HAdobe5K.pth",
                              "Resolution_RAW_HAdobe5K": "Resolution_RAW_HAdobe5K.pth",
                              "Resolution_RAW_iHarmony4": "Resolution_RAW_iHarmony4.pth"}

    # Not referenced by any handler visible in this file; kept so the set of
    # components in the app is unchanged.
    global_state = gr.State(valid_checkpoints_dict["Resolution_RAW_iHarmony4"])

    with gr.Row():
        with gr.Column():
            form_composite_image = gr.Image(label='Input Composite image', type='pil')
            gr.Examples(examples=_demo_examples("composite"),
                        label="Composite Examples", inputs=form_composite_image, cache_examples=False)
        with gr.Column():
            # The mask input stays locked until a composite image is present.
            form_mask_image = gr.Image(label='Input Mask image', type='pil', interactive=False)
            gr.Examples(examples=_demo_examples("mask"),
                        label="Mask Examples", inputs=form_mask_image, cache_examples=False)

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                with gr.Column():
                    gr.Markdown(value='Model Selection', show_label=False)

                with gr.Column():
                    form_pretrained_dropdown = gr.Dropdown(
                        choices=list(valid_checkpoints_dict.values()),
                        label="Pretrained Model",
                        value=valid_checkpoints_dict["Resolution_RAW_iHarmony4"],
                        interactive=True
                    )

            with gr.Row():
                with gr.Column():
                    gr.Markdown(value='Inference Mode', show_label=False)

                with gr.Column():
                    form_inference_mode = gr.Radio(
                        ['Square Image', 'Arbitrary Image'],
                        value='Arbitrary Image',
                        interactive=False,
                        label='Mode',
                    )

            with gr.Row():
                with gr.Column():
                    gr.Markdown(value='Split Parameter', show_label=False)

                with gr.Column():
                    form_split_res = gr.Slider(
                        minimum=0,
                        maximum=2048,
                        step=128,
                        value=256,
                        interactive=True,
                        label="Split Resolution",
                    )
                    form_split_num = gr.Number(
                        value=2,
                        interactive=True,
                        label="Split Number")

            with gr.Row():
                # Polls ``read_logs`` every second to mirror stdout in the UI.
                form_log = gr.Textbox(read_logs, label="Logs", interactive=False, type="text", every=1)

        with gr.Column(scale=4):
            form_harmonized_image = gr.Image(label='Harmonized Result', type='numpy', interactive=False,
                                             format="png")
            form_start_btn = gr.Button("Start Harmonization", interactive=False)
            form_reset_btn = gr.Button("Reset", interactive=True)
            form_stop_btn = gr.Button("Stop", interactive=True)


    def on_change_form_composite_image(form_composite_image):
        """Clear the mask and result; unlock the mask input only when a composite exists.

        Returns updates for ``(form_mask_image, form_harmonized_image)``.
        """
        has_composite = form_composite_image is not None
        return gr.update(interactive=has_composite, value=None), gr.update(value=None)


    def on_change_form_mask_image(form_composite_image, form_mask_image):
        """Validate the mask against the composite and configure the split controls.

        Returns updates for ``(form_inference_mode, form_mask_image,
        form_start_btn, form_split_num, form_split_res, form_harmonized_image)``.

        Raises:
            gr.Error: if a mask arrives without a composite image, or the two
                images differ in size.
        """
        if form_mask_image is None:
            return (
                gr.update(interactive=False),
                gr.update(interactive=form_composite_image is not None),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(interactive=False),
                gr.update(value=None),
            )

        # A mask can arrive without a composite (e.g. from the example
        # gallery); report it instead of failing on ``None.size``.
        if form_composite_image is None:
            raise gr.Error("Please provide the composite image before the mask image!")

        if form_composite_image.size[:2] != form_mask_image.size[:2]:
            raise gr.Error("Composite image and mask image should have the same resolution!")

        w, h = form_composite_image.size[:2]
        if h != w or (h % 16 != 0):
            # Non-square (or side not a multiple of 16): only the arbitrary
            # mode is offered, driven by the split number.
            return (
                gr.update(value='Arbitrary Image', interactive=False),
                gr.update(interactive=True),
                gr.update(interactive=True),
                gr.update(interactive=True, visible=True),
                gr.update(interactive=False, value=-1, visible=False),
                gr.update(value=None),
            )

        # Square image with side divisible by 16: default to square mode and
        # scale the split-resolution slider to the image side.
        return (
            gr.update(value='Square Image', interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(interactive=False, visible=False),
            gr.update(interactive=True, value=h // 2, maximum=h, minimum=h // 16, step=h // 16,
                      visible=True),
            gr.update(value=None),
        )


    form_composite_image.change(
        on_change_form_composite_image,
        inputs=[form_composite_image],
        outputs=[form_mask_image, form_harmonized_image]
    )

    form_mask_image.change(
        on_change_form_mask_image,
        inputs=[form_composite_image, form_mask_image],
        outputs=[form_inference_mode, form_mask_image, form_start_btn, form_split_num, form_split_res,
                 form_harmonized_image]
    )


    def on_change_form_split_num(form_composite_image, form_split_num):
        """Clamp the split number to ``[1, min(width, height)]`` of the composite image."""
        # Nothing to clamp against without an image, and a cleared Number
        # field arrives as None; leave the widget untouched in both cases.
        if form_composite_image is None or form_split_num is None:
            return gr.update()

        w, h = form_composite_image.size[:2]
        upper = min(w, h)
        if form_split_num < 1:
            return gr.update(value=1)
        if form_split_num > upper:
            return gr.update(value=upper)
        return gr.update(value=form_split_num)


    form_split_num.change(
        on_change_form_split_num,
        inputs=[form_composite_image, form_split_num],
        outputs=[form_split_num]
    )


    def on_change_form_inference_mode(form_inference_mode):
        """Show the control matching the mode; returns ``(form_split_res, form_split_num)`` updates."""
        square_mode = form_inference_mode == "Square Image"
        return (
            gr.update(interactive=square_mode, visible=square_mode),
            gr.update(interactive=not square_mode, visible=not square_mode),
        )


    form_inference_mode.change(on_change_form_inference_mode, inputs=[form_inference_mode],
                               outputs=[form_split_res, form_split_num])


    def _load_inference_module(form_inference_mode):
        """Import (lazily) and return the inference module for the selected mode."""
        if form_inference_mode == "Square Image":
            import efficient_inference_for_square_image as inference_module
        else:
            import inference_for_arbitrary_resolution_image as inference_module
        return inference_module


    @spaces.GPU
    def on_click_form_start_btn(form_composite_image, form_mask_image, form_pretrained_dropdown, form_inference_mode,
                                form_split_res, form_split_num):
        """Run harmonization with the selected checkpoint and return the result image.

        The output of the inference module's ``main_process`` is passed through
        ``cv2.COLOR_BGR2RGB`` before being returned.

        Raises:
            gr.Error: if an input is missing, or if inference fails for any reason.
        """
        if form_composite_image is None or form_mask_image is None:
            raise gr.Error("Both a composite image and a mask image are required!")

        square_mode = form_inference_mode == "Square Image"
        if not square_mode and form_split_num is None:
            raise gr.Error("Please set the `split_num`!")

        log.log = io.BytesIO()
        print(f"Harmonizing image with {form_composite_image.size[1]}*{form_composite_image.size[0]}...")

        inference_module = _load_inference_module(form_inference_mode)
        # presumably a run/stop flag polled by the inference loop -- confirm in the inference modules
        inference_module.global_state[0] = 1

        opt = inference_module.parse_args()
        opt.transform_mean = [.5, .5, .5]
        opt.transform_var = [.5, .5, .5]
        opt.pretrained = os.path.join("./pretrained_models", form_pretrained_dropdown)
        if square_mode:
            opt.split_resolution = form_split_res
            hint = "Patches too big. Try to reduce the `split_res`!"
        else:
            opt.split_num = int(form_split_num)
            hint = "Patches too big. Try to increase the `split_num`!"
        opt.save_path = None
        opt.workers = 0
        opt.device = "gpu"

        composite_image = np.asarray(form_composite_image)
        mask = np.asarray(form_mask_image)

        try:
            return cv2.cvtColor(
                inference_module.main_process(opt, composite_image=composite_image, mask=mask),
                cv2.COLOR_BGR2RGB)
        except Exception as e:
            # Any failure is surfaced in the UI with the split hint; the
            # original exception is chained for the server-side traceback.
            raise gr.Error(f"{hint}\nException is {e}") from e


    generate = form_start_btn.click(on_click_form_start_btn,
                                    inputs=[form_composite_image, form_mask_image, form_pretrained_dropdown,
                                            form_inference_mode,
                                            form_split_res, form_split_num], outputs=[form_harmonized_image])


    def _abort_and_clear(form_inference_mode):
        """Zero the inference module's ``global_state`` flag, drop the log, and clear the inputs.

        Returns updates for ``(form_log, form_composite_image, form_mask_image,
        form_start_btn)``.
        """
        _load_inference_module(form_inference_mode).global_state[0] = 0
        log.log = io.BytesIO()
        return (
            gr.update(value=None),
            gr.update(value=None, interactive=True),
            gr.update(value=None, interactive=False),
            gr.update(interactive=False),
        )


    def on_click_form_reset_btn(form_inference_mode):
        """Cancel any running job and reset the form."""
        return _abort_and_clear(form_inference_mode)


    form_reset_btn.click(on_click_form_reset_btn,
                         inputs=[form_inference_mode],
                         outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)


    def on_click_form_stop(form_inference_mode):
        """Cancel any running job.

        NOTE(review): this also clears both input images, exactly like Reset --
        confirm that is intended.
        """
        return _abort_and_clear(form_inference_mode)


    form_stop_btn.click(on_click_form_stop,
                        inputs=[form_inference_mode],
                        outputs=[form_log, form_composite_image, form_mask_image, form_start_btn], cancels=generate)

gr.close_all()

app.queue()

app.launch(show_api=False)
class Implicit2DGenerator(object):
    """Build per-sample INR inputs: a fixed coordinate grid plus flattened RGB values.

    The grid comes from ``misc.get_mgrid`` and the images are resized with an
    albumentations ``Resize`` before being flattened.
    """

    def __init__(self, opt, mode):
        """Select the side length for ``mode`` and prepare the grid and resize transform.

        Args:
            opt: options object; ``opt.INR_input_size`` is read for ``'Train'``
                and ``opt.input_size`` for ``'Val'``.
            mode: ``'Train'`` or ``'Val'``.

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == 'Train':
            sidelength = opt.INR_input_size
        elif mode == 'Val':
            sidelength = opt.input_size
        else:
            raise NotImplementedError(f"Unsupported mode {mode!r}; expected 'Train' or 'Val'.")

        self.mode = mode

        # Kept exactly as configured (int or pair) for any external reader.
        self.size = sidelength

        if isinstance(sidelength, int):
            height = width = sidelength
            sidelength = (sidelength, sidelength)
        elif isinstance(sidelength, (tuple, list)) and len(sidelength) == 2:
            # assumes (height, width) order -- TODO confirm against misc.get_mgrid
            height, width = sidelength
        else:
            # Unknown shape of value: pass it through unchanged.
            height = width = sidelength

        self.mgrid = misc.get_mgrid(sidelength)

        # Resize takes scalar height and width; passing the raw pair for both
        # would be wrong when the configured size is not an int.
        self.transform = Resize(height, width)

    def generator(self, torch_transforms, composite_image, real_image, mask):
        """Resize both images and flatten them to ``(-1, 3)`` RGB rows.

        Args:
            torch_transforms: callable applied to each resized image; its
                result must support ``permute(1, 2, 0)``.
            composite_image: image passed to the resize transform as ``image=``.
            real_image: image passed to the resize transform as ``image=``.
            mask: accepted but not used by this method.

        Returns:
            ``(fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB,
            fg_transfer_INR_RGB, bg_INR_RGB)``. Both coordinate entries are
            ``self.mgrid``; ``fg_INR_RGB`` comes from the composite image and
            the other two RGB entries from the real image.
        """
        composite_image = torch_transforms(self.transform(image=composite_image)['image'])
        real_image = torch_transforms(self.transform(image=real_image)['image'])

        # Move the first axis last, then flatten to one 3-value row per pixel.
        fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3)
        # Computed twice on purpose so the two results stay separate tensors.
        fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
        bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)

        fg_INR_coordinates = self.mgrid
        bg_INR_coordinates = self.mgrid

        return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB
+ composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori") + + real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg' + mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png' + + composite_image = cv2.imread(composite_image) + composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB) + + real_image = cv2.imread(real_image) + real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB) + + mask = cv2.imread(mask) + mask = mask[:, :, 0].astype(np.float32) / 255. + + """ + If set `opt.hr_train` to True: + + Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres), + the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size` + patch to feed in multiINR process. For inference, just resize it. + + While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size. + + BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5` + """ + if self.opt.hr_train: + if self.mode == 'Train' and self.opt.isFullRes: + if random.random() < 0.5: # LR mix training + mixTransform = albumentations.Compose( + [ + RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)), + HorizontalFlip()], + additional_targets={'real_image': 'image', 'object_mask': 'image'} + ) + origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + origin_bg_ratio = 1 - origin_fg_ratio + + "Ensure fg and bg not disappear after transformation" + valid_augmentation = False + transform_out = None + time = 0 + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. 
+ if time == 20: + tmp_transform = albumentations.Compose( + [Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = True + else: + transform_out = mixTransform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], + origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + composite_image = transform_out['image'] + real_image = transform_out['real_image'] + mask = transform_out['object_mask'] + else: # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop. + if real_image.shape[0] < 256: + bottom_pad = 256 - real_image.shape[0] + else: + bottom_pad = (4 - real_image.shape[0] % 4) % 4 + if real_image.shape[1] < 256: + right_pad = 256 - real_image.shape[1] + else: + right_pad = (4 - real_image.shape[1] % 4) % 4 + composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad, + cv2.BORDER_REPLICATE) + real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE) + mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE) + + origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + origin_bg_ratio = 1 - origin_fg_ratio + + "Ensure fg and bg not disappear after transformation" + valid_augmentation = False + transform_out = None + time = 0 + + if self.opt.hr_train: + if self.mode == 'Train': + if not self.opt.isFullRes: + if random.random() < 0.5: # LR mix training + mixTransform = albumentations.Compose( + [ + RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)), + HorizontalFlip()], + additional_targets={'real_image': 'image', 'object_mask': 'image'} + ) + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to 
avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose( + [Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = True + else: + transform_out = mixTransform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], + origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + else: + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose( + [Resize(self.opt.input_size, self.opt.input_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = True + else: + transform_out = self.alb_transforms(image=composite_image, real_image=real_image, + object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], + origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + composite_image = transform_out['image'] + real_image = transform_out['real_image'] + mask = transform_out['object_mask'] + + origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + + full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0) + + tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask) + compos_list = [self.torch_transforms(transform_out['image'])] + real_list = [self.torch_transforms(transform_out['real_image'])] + mask_list = [ + torchvision.transforms.ToTensor()(transform_out['object_mask'][..., 
np.newaxis].astype(np.float32))] + coord_map_list = [] + + valid_augmentation = False + while not valid_augmentation: + # RSC strategy. To crop different resolutions. + transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord], + self.opt.base_size, self.opt.base_size) + valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio) + + compos_list.append(self.torch_transforms(transform_out[0])) + real_list.append(self.torch_transforms(transform_out[1])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3])) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3])) + for n in range(2): + tmp_comp = cv2.resize(composite_image, ( + composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1))) + tmp_real = cv2.resize(real_image, + (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1))) + tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1))) + tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0) + + transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord], + self.opt.base_size // 2 ** (n + 1), + self.opt.base_size // 2 ** (n + 1), c_h, c_w) + compos_list.append(self.torch_transforms(transform_out[0])) + real_list.append(self.torch_transforms(transform_out[1])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3])) + out_comp = compos_list + out_real = real_list + out_mask = mask_list + out_coord = coord_map_list + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out[0], transform_out[1], mask) + + return { + 'file_path': 
self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': out_comp, + 'real_image': out_real, + 'mask': out_mask, + 'coordinate_map': out_coord, + 'composite_image0': out_comp[0], + 'real_image0': out_real[0], + 'mask0': out_mask[0], + 'coordinate_map0': out_coord[0], + 'composite_image1': out_comp[1], + 'real_image1': out_real[1], + 'mask1': out_mask[1], + 'coordinate_map1': out_coord[1], + 'composite_image2': out_comp[2], + 'real_image2': out_real[2], + 'mask2': out_mask[2], + 'coordinate_map2': out_coord[2], + 'composite_image3': out_comp[3], + 'real_image3': out_real[3], + 'mask3': out_mask[3], + 'coordinate_map3': out_coord[3], + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + else: + if not self.opt.isFullRes: + tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask) + + coordinate_map = prepare_cooridinate_input(transform_out['object_mask']) + + "Generate INR dataset." 
+ mask = (torchvision.transforms.ToTensor()( + transform_out['object_mask']).squeeze() > 100 / 255.).view(-1) + mask = np.bool_(mask.numpy()) + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out['image'], transform_out['real_image'], mask) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': self.torch_transforms(transform_out['image']), + 'real_image': self.torch_transforms(transform_out['real_image']), + 'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32), + # Can automatically transfer to Tensor. + 'coordinate_map': coordinate_map, + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + else: + coordinate_map = prepare_cooridinate_input(mask) + + "Generate INR dataset." + mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1) + mask_tmp = np.bool_(mask_tmp.numpy()) + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, composite_image, real_image, mask_tmp) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': self.torch_transforms(composite_image), + 'real_image': self.torch_transforms(real_image), + 'mask': mask[np.newaxis, ...].astype(np.float32), + # Can automatically transfer to Tensor. 
+ 'coordinate_map': coordinate_map, + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + + while not valid_augmentation: + time += 1 + # There are some extreme ratio pics, this code is to avoid being hindered by them. + if time == 20: + tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)], + additional_targets={'real_image': 'image', + 'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask) + valid_augmentation = True + else: + transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask) + valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio, + origin_bg_ratio, + self.kp_t) + + coordinate_map = prepare_cooridinate_input(transform_out['object_mask']) + + "Generate INR dataset." + mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1) + mask = np.bool_(mask.numpy()) + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out['image'], transform_out['real_image'], mask) + + return { + 'file_path': self.dataset_samples[idx], + 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0], + 'composite_image': self.torch_transforms(transform_out['image']), + 'real_image': self.torch_transforms(transform_out['real_image']), + 'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32), + # Can automatically transfer to Tensor. 
+ 'coordinate_map': coordinate_map, + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB + } + + +def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh): + current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + current_bg_ratio = 1 - current_fg_ratio + + if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh: + return False + + return True + + +def check_hr_crop_sample(mask, origin_fg_ratio): + current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1]) + + if current_fg_ratio < 0.8 * origin_fg_ratio: + return False + + return True diff --git a/demo/demo_1k_composite_2.jpg b/demo/demo_1k_composite_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8810507f60782158922a2bb49de79f0b58c939a1 Binary files /dev/null and b/demo/demo_1k_composite_2.jpg differ diff --git a/demo/demo_1k_composite_3.jpg b/demo/demo_1k_composite_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9a9f4ca163d1f0a86e4c88de821a692715c3089b Binary files /dev/null and b/demo/demo_1k_composite_3.jpg differ diff --git a/demo/demo_1k_mask_2.jpg b/demo/demo_1k_mask_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6bc188582d600082a3b29da2defe8cfb1bb04068 Binary files /dev/null and b/demo/demo_1k_mask_2.jpg differ diff --git a/demo/demo_1k_mask_3.jpg b/demo/demo_1k_mask_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cb24b813e3b9b4e375c42b2f69800c478147ba8f Binary files /dev/null and b/demo/demo_1k_mask_3.jpg differ diff --git a/demo/demo_composite.jpg b/demo/demo_composite.jpg new file mode 100644 index 0000000000000000000000000000000000000000..151325cbcfffeb3fc35612e561ed4c2c5e8a33fd Binary files /dev/null and b/demo/demo_composite.jpg differ diff --git 
a/demo/demo_composite_1.jpg b/demo/demo_composite_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dcff1ce9c458e547107535e8530d08cca3f2e6c1 Binary files /dev/null and b/demo/demo_composite_1.jpg differ diff --git a/demo/demo_composite_2.jpg b/demo/demo_composite_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f830ce430155057c946ce000901e182f7550f55e Binary files /dev/null and b/demo/demo_composite_2.jpg differ diff --git a/demo/demo_composite_3.jpg b/demo/demo_composite_3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd9f1a7ab4ad85bb976e36595d93318a6b3dc025 Binary files /dev/null and b/demo/demo_composite_3.jpg differ diff --git a/demo/demo_composite_4.jpg b/demo/demo_composite_4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a1667f4d3152b351127a2d1275054d28535305a8 Binary files /dev/null and b/demo/demo_composite_4.jpg differ diff --git a/demo/demo_composite_5.jpg b/demo/demo_composite_5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..73ea2d399bb347789ac96d683e1cdec840758513 Binary files /dev/null and b/demo/demo_composite_5.jpg differ diff --git a/demo/demo_composite_6.jpg b/demo/demo_composite_6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0badf7a4eb65b54e770f68a1d396c728f7a5a9d2 Binary files /dev/null and b/demo/demo_composite_6.jpg differ diff --git a/demo/demo_mask.png b/demo/demo_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..246102f188a67b79ba138b81186bff8b6da0ff8c Binary files /dev/null and b/demo/demo_mask.png differ diff --git a/demo/demo_mask_1.png b/demo/demo_mask_1.png new file mode 100644 index 0000000000000000000000000000000000000000..0709e9f06c82b8f856fdc719aa23c4bc872b51d5 Binary files /dev/null and b/demo/demo_mask_1.png differ diff --git a/demo/demo_mask_2.png b/demo/demo_mask_2.png new file mode 100644 index 
0000000000000000000000000000000000000000..acfe7581fb90d18a79b145ac493d180e3d6f5919 Binary files /dev/null and b/demo/demo_mask_2.png differ diff --git a/demo/demo_mask_3.png b/demo/demo_mask_3.png new file mode 100644 index 0000000000000000000000000000000000000000..918129753f1b25385f1d345a3bb9dc360edefa6b Binary files /dev/null and b/demo/demo_mask_3.png differ diff --git a/demo/demo_mask_4.jpg b/demo/demo_mask_4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26867aca142e6f0712a5de352edc3dd11352ab0f Binary files /dev/null and b/demo/demo_mask_4.jpg differ diff --git a/demo/demo_mask_5.jpg b/demo/demo_mask_5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc96965f2c7272dbf093ae0433312f4c9446a862 Binary files /dev/null and b/demo/demo_mask_5.jpg differ diff --git a/demo/demo_mask_6.jpg b/demo/demo_mask_6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25778519e8af5fe999aa1ad1228306c796a37f51 Binary files /dev/null and b/demo/demo_mask_6.jpg differ diff --git a/efficient_inference_for_square_image.py b/efficient_inference_for_square_image.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b4d6c1fb2ddb5a6c940a6ff30c76bee99b6c9e --- /dev/null +++ b/efficient_inference_for_square_image.py @@ -0,0 +1,356 @@ +import argparse +import builtins +from collections import defaultdict + +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from model.build_model import build_model +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR + +import torch +import cv2 +import numpy as np +import torchvision +import os +import tqdm +import time + +from utils.misc import prepare_cooridinate_input, customRandomCrop + +from datasets.build_INR_dataset import Implicit2DGenerator +import albumentations +from albumentations import Resize +# from torch.utils.data import DataLoader +from utils.misc 
import normalize + +import math + +global_state = [1] # For Gradio Stop Button. + +class single_image_dataset(torch.utils.data.Dataset): + def __init__(self, opt, composite_image=None, mask=None): + super().__init__() + + self.opt = opt + + if composite_image is None: + composite_image = cv2.imread(opt.composite_image) + composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB) + self.composite_image = composite_image + + assert composite_image.shape[0] == composite_image.shape[1], "This faster script only supports square images." + assert composite_image.shape[ + 0] % 256 == 0, "This faster script only supports images with resolution multiples of 256." + assert opt.split_resolution % (composite_image.shape[ + 0] // 16) == 0, f"The image resolution is {composite_image.shape[0]}, " \ + f"you should set {opt.split_resolution} to multiplies of {composite_image.shape[0] // 16}" + + if mask is None: + mask = cv2.imread(opt.mask) + mask = mask[:, :, 0].astype(np.float32) / 255. + self.mask = mask + + self.torch_transforms = transforms.Compose([transforms.ToTensor(), + transforms.Normalize([.5, .5, .5], [.5, .5, .5])]) + self.INR_dataset = Implicit2DGenerator(opt, 'Val') + + self.split_width_resolution = self.split_height_resolution = opt.split_resolution + + self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution) + self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution) + + self.split_start_point = [] + + "Split the image into several parts." 
+ for i in range(self.num_h): + for j in range(self.num_w): + if i == composite_image.shape[0] // self.split_height_resolution: + if j == composite_image.shape[1] // self.split_width_resolution: + self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution, + composite_image.shape[1] - self.split_width_resolution)) + else: + self.split_start_point.append( + (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution)) + else: + if j == composite_image.shape[1] // self.split_width_resolution: + self.split_start_point.append( + (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution)) + else: + self.split_start_point.append( + (i * self.split_height_resolution, j * self.split_width_resolution)) + + assert len(self.split_start_point) == self.num_w * self.num_h + + print( + f"The image will be split into {self.num_h} pieces in height, and {self.num_w} pieces in width. Totally {self.num_h * self.num_w} patches.") + print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}") + + def __len__(self): + return self.num_w * self.num_h + + def __getitem__(self, idx): + composite_image = self.composite_image + + mask = self.mask + + full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0) + + tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'object_mask': 'image'}) + transform_out = tmp_transform(image=self.composite_image, object_mask=self.mask) + compos_list = [self.torch_transforms(transform_out['image'])] + mask_list = [ + torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))] + coord_map_list = [] + + if composite_image.shape[0] != self.split_height_resolution: + c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution) + else: + c_h = 0 + if composite_image.shape[1] != 
self.split_width_resolution: + c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution) + else: + c_w = 0 + transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord], + self.split_height_resolution, self.split_width_resolution, c_h, c_w) + + compos_list.append(self.torch_transforms(transform_out[0])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2])) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2])) + for n in range(2): + tmp_comp = cv2.resize(composite_image, ( + composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1))) + tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1))) + tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0) + + transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_mask, tmp_coord], + self.split_height_resolution // 2 ** (n + 1), + self.split_width_resolution // 2 ** (n + 1), c_h, c_w) + compos_list.append(self.torch_transforms(transform_out[0])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2])) + out_comp = compos_list + out_mask = mask_list + out_coord = coord_map_list + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out[0], transform_out[0], mask) + + return { + 'composite_image': out_comp, + 'mask': out_mask, + 'coordinate_map': out_coord, + 'composite_image0': out_comp[0], + 'mask0': out_mask[0], + 'coordinate_map0': out_coord[0], + 'composite_image1': out_comp[1], + 'mask1': out_mask[1], + 'coordinate_map1': out_coord[1], + 'composite_image2': out_comp[2], + 'mask2': out_mask[2], + 
'coordinate_map2': out_coord[2], + 'composite_image3': out_comp[3], + 'mask3': out_mask[3], + 'coordinate_map3': out_coord[3], + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB, + 'start_point': self.split_start_point[idx], + 'start_proportion': [self.split_start_point[idx][0] / (composite_image.shape[0]), + self.split_start_point[idx][1] / (composite_image.shape[1]), + (self.split_start_point[idx][0] + self.split_height_resolution) / ( + composite_image.shape[0]), + (self.split_start_point[idx][1] + self.split_width_resolution) / ( + composite_image.shape[1])], + } + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--split_resolution', type=int, default=2048, + help='The resolution of the patch split.') + + parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg', + help='composite image path') + + parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg', + help='mask path') + + parser.add_argument('--save_path', type=str, default=r'./demo/', + help='save path') + + parser.add_argument('--workers', type=int, default=8, + metavar='N', help='Dataloader threads.') + + parser.add_argument('--batch_size', type=int, default=1, + help='You can override model batch size by specify positive number.') + + parser.add_argument('--device', type=str, default='cuda', + help="Whether use cuda, 'cuda' or 'cpu'.") + + parser.add_argument('--base_size', type=int, default=256, + help='Base size. Resolution of the image input into the Encoder') + + parser.add_argument('--input_size', type=int, default=256, + help='Input size. Resolution of the image that want to be generated by the Decoder') + + parser.add_argument('--INR_input_size', type=int, default=256, + help='INR input size. Resolution of the image that want to be generated by the Decoder. 
' + 'Should be the same as `input_size`') + + parser.add_argument('--INR_MLP_dim', type=int, default=32, + help='Number of channels for INR linear layer.') + + parser.add_argument('--LUT_dim', type=int, default=7, + help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076') + + parser.add_argument('--activation', type=str, default='leakyrelu_pe', + help='INR activation layer type: leakyrelu_pe, sine') + + parser.add_argument('--pretrained', type=str, + default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth', + help='Pretrained weight path') + + parser.add_argument('--param_factorize_dim', type=int, + default=10, + help='The intermediate dimensions of the factorization of the predicted MLP parameters. ' + 'Refer to https://arxiv.org/abs/2011.12026') + + parser.add_argument('--embedding_type', type=str, + default="CIPS_embed", + help='Which embedding_type to use.') + + parser.add_argument('--INRDecode', action="store_false", + help='Whether INR decoder. Set it to False if you want to test the baseline ' + '(https://github.com/SamsungLabs/image_harmonization)') + + parser.add_argument('--isMoreINRInput', action="store_false", + help='Whether to cat RGB and mask. See Section 3.4 in the paper.') + + parser.add_argument('--hr_train', action="store_false", + help='Whether use hr_train. See section 3.4 in the paper.') + + parser.add_argument('--isFullRes', action="store_true", + help='Whether for original resolution. See section 3.4 in the paper.') + + opt = parser.parse_args() + + assert opt.batch_size == 1, 'This faster script only supports batch size 1 for inference.' + + return opt + + +@torch.no_grad() +def inference(model, opt, composite_image=None, mask=None): + model.eval() + + "dataset here is actually consisted of several patches of a single image." 
+ singledataset = single_image_dataset(opt, composite_image, mask) + + single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True, + num_workers=opt.workers, persistent_workers=False if composite_image is not None else True) + + "Init a pure black image with the same size as the input image." + init_img = np.zeros_like(singledataset.composite_image) + + time_all = 0 + + for step, batch in tqdm.tqdm(enumerate(single_data_loader)): + composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)] + mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)] + coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)] + start_points = batch['start_point'] + start_proportion = batch['start_proportion'] + + if opt.batch_size == 1: + start_points = [torch.cat(start_points)] + start_proportion = [torch.cat(start_proportion)] + + fg_INR_coordinates = coordinate_map[1:] + + try: + if global_state[0] == 0: + print("Stop Harmonizing...!") + break + + if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow. 
+ fg_content_bg_appearance_construct, _, lut_transform_image = model( + composite_image, + mask, + fg_INR_coordinates, start_proportion[0] + ) + print("Ready for harmonization...") + if opt.device == "cuda": + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_max_memory_cached() + start_time = time.time() + torch.cuda.synchronize() + fg_content_bg_appearance_construct, _, lut_transform_image = model( + composite_image, + mask, + fg_INR_coordinates, start_proportion[0] + ) + if opt.device == "cuda": + torch.cuda.synchronize() + end_time = time.time() + + end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2 + end_memory = torch.cuda.memory_allocated() // 1024 ** 2 + + print(f'GPU max memory usage: {end_max_memory} MB') + print(f'GPU memory usage: {end_memory} MB') + time_all += (end_time - start_time) + print(f'progress: {step} / {len(single_data_loader)}') + except: + raise Exception( + f'The image resolution is large. Please reduce the `split_resolution` value. Your current set is {opt.split_resolution}') + + "Assemble the every patch's harmonized result into the final whole image." + for id in range(len(fg_INR_coordinates[0])): + pred_fg_image = fg_content_bg_appearance_construct[-1][id] + pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) 
+ composite_image[1][id] * ( + ~(mask[1][id] > 100 / 255.)) + + pred_harmonized_tmp = cv2.cvtColor( + normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_( + 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR) + + init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution, + start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp + + if opt.device == "cuda": + print(f'Inference time: {time_all}') + if opt.save_path is not None: + os.makedirs(opt.save_path, exist_ok=True) + cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img) + return init_img + +def main_process(opt, composite_image=None, mask=None): + # torch.serialization.add_safe_globals([getattr, OneCycleLR, AdamW, defaultdict, builtins.dict]) + cudnn.benchmark = True + opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print("Preparing model...") + model = build_model(opt).to(opt.device) + + # Заменяем 'gpu' на 'cuda' и добавляем weights_only=True + load_dict = torch.load(opt.pretrained, map_location='cpu')['model'] + + model.load_state_dict(load_dict, strict=False) + + return inference(model, opt, composite_image, mask) + + +if __name__ == '__main__': + opt = parse_args() + opt.transform_mean = [.5, .5, .5] + opt.transform_var = [.5, .5, .5] + main_process(opt) diff --git a/hrnet_ocr.py b/hrnet_ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..04c46069ea5d4b4bc4ae4f252ef898ff361cdb83 --- /dev/null +++ b/hrnet_ocr.py @@ -0,0 +1,401 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch._utils + +from .ocr import SpatialOCR_Module, SpatialGather_Module +from .resnetv1b import BasicBlockV1b, BottleneckV1b + +relu_inplace = True + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + 
num_channels, fuse_method,multi_scale_output=True, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + self.norm_layer = norm_layer + self.align_corners = align_corners + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=relu_inplace) + + def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, + downsample=downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + 
layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], + norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(in_channels=num_inchannels[j], + out_channels=num_inchannels[i], + kernel_size=1, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): 
+ if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=self.align_corners) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionNet(nn.Module): + def __init__(self, width, num_classes, ocr_width=256, small=False, + norm_layer=nn.BatchNorm2d, align_corners=True, opt=None): + super(HighResolutionNet, self).__init__() + self.opt = opt + self.norm_layer = norm_layer + self.width = width + self.ocr_width = ocr_width + self.ocr_on = ocr_width > 0 + self.align_corners = align_corners + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_inchannels) + self.stage2, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], num_channels=num_channels) + + self.stage3_num_branches = 3 + num_channels = [width, 2 * width, 4 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage3, pre_stage_channels = 
self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, + num_modules=3 if small else 4, num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], num_channels=num_channels) + + self.stage4_num_branches = 4 + num_channels = [width, 2 * width, 4 * width, 8 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage4, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + num_branches=self.stage4_num_branches, + num_blocks=4 * [num_blocks], num_channels=num_channels) + + if self.ocr_on: + last_inp_channels = np.int(np.sum(pre_stage_channels)) + ocr_mid_channels = 2 * ocr_width + ocr_key_channels = ocr_width + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=relu_inplace), + ) + self.ocr_gather_head = SpatialGather_Module(num_classes) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners, opt=opt) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in 
range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, mask=None, additional_features=None): + hrnet_feats = self.compute_hrnet_feats(x, additional_features) + if not self.ocr_on: + return hrnet_feats, + + ocr_feats = self.conv3x3_ocr(hrnet_feats) + mask = 
nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True) + context = self.ocr_gather_head(ocr_feats, mask) + ocr_feats = self.ocr_distri_head(ocr_feats, context) + return ocr_feats, + + def compute_hrnet_feats(self, x, additional_features, return_list=False): + x = self.compute_pre_stage_features(x, additional_features) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + if return_list: + return x + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + return torch.cat([x[0], x1, x2, x3], 1) + + def compute_pre_stage_features(self, x, additional_features): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + additional_features + x = self.conv2(x) + x = self.bn2(x) + return self.relu(x) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): 
+ print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pretrained_dict = torch.load(pretrained_path, map_location=device) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in + pretrained_dict.items()} + params_count = len(pretrained_dict) + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + + # print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet') + + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..bb689c81a120247bfa9afa6b217b87b25bc0de6a --- /dev/null +++ b/inference.py @@ -0,0 +1,236 @@ +import os +import argparse + +import albumentations +from albumentations import Resize + +import torch +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from model.build_model import build_model +from datasets.build_dataset import dataset_generator + +from utils import misc, metrics + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--workers', type=int, default=1, + metavar='N', help='Dataloader threads.') + + parser.add_argument('--batch_size', type=int, default=1, + help='You can override model batch size by specify positive number.') + + parser.add_argument('--device', type=str, default='cuda', + help="Whether use cuda, 'cuda' or 'cpu'.") + + parser.add_argument('--save_path', type=str, default="./logs", + help='Where to save logs and checkpoints.') + + parser.add_argument('--dataset_path', type=str, default=r".\iHarmony4", + help='Dataset 
path.') + + parser.add_argument('--base_size', type=int, default=256, + help='Base size. Resolution of the image input into the Encoder') + + parser.add_argument('--input_size', type=int, default=256, + help='Input size. Resolution of the image that want to be generated by the Decoder') + + parser.add_argument('--INR_input_size', type=int, default=256, + help='INR input size. Resolution of the image that want to be generated by the Decoder. ' + 'Should be the same as `input_size`') + + parser.add_argument('--INR_MLP_dim', type=int, default=32, + help='Number of channels for INR linear layer.') + + parser.add_argument('--LUT_dim', type=int, default=7, + help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076') + + parser.add_argument('--activation', type=str, default='leakyrelu_pe', + help='INR activation layer type: leakyrelu_pe, sine') + + parser.add_argument('--pretrained', type=str, + default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth', + help='Pretrained weight path') + + parser.add_argument('--param_factorize_dim', type=int, + default=10, + help='The intermediate dimensions of the factorization of the predicted MLP parameters. ' + 'Refer to https://arxiv.org/abs/2011.12026') + + parser.add_argument('--embedding_type', type=str, + default="CIPS_embed", + help='Which embedding_type to use.') + + parser.add_argument('--optim', type=str, + default='adamw', + help='Which optimizer to use.') + + parser.add_argument('--INRDecode', action="store_false", + help='Whether INR decoder. Set it to False if you want to test the baseline ' + '(https://github.com/SamsungLabs/image_harmonization)') + + parser.add_argument('--isMoreINRInput', action="store_false", + help='Whether to cat RGB and mask. See Section 3.4 in the paper.') + + parser.add_argument('--hr_train', action="store_true", + help='Whether use hr_train. 
See section 3.4 in the paper.') + + parser.add_argument('--isFullRes', action="store_true", + help='Whether for original resolution. See section 3.4 in the paper.') + + opt = parser.parse_args() + + opt.save_path = misc.increment_path(os.path.join(opt.save_path, "test1")) + + return opt + + +def inference(val_loader, model, logger, opt): + current_process = 10 + model.eval() + + metric_log = { + 'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + } + + lut_metric_log = { + 'HAdobe5k': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'HCOCO': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'Hday2night': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'HFlickr': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + 'All': {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}, + } + + for step, batch in enumerate(val_loader): + composite_image = batch['composite_image'].to(opt.device) + real_image = batch['real_image'].to(opt.device) + mask = batch['mask'].to(opt.device) + category = batch['category'] + + fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device) + + with torch.no_grad(): + fg_content_bg_appearance_construct, _, lut_transform_image = model( + composite_image, + mask, + fg_INR_coordinates, + ) + + if opt.INRDecode: + pred_fg_image = fg_content_bg_appearance_construct[-1] + else: + pred_fg_image = misc.lin2img(fg_content_bg_appearance_construct, + val_loader.dataset.INR_dataset.size) if fg_content_bg_appearance_construct is not None else None + + if not opt.INRDecode: + pred_harmonized_image = None + else: + pred_harmonized_image = pred_fg_image * (mask > 100 / 255.) 
+ real_image * (~(mask > 100 / 255.)) + lut_transform_image = lut_transform_image * (mask > 100 / 255.) + real_image * (~(mask > 100 / 255.)) + + misc.visualize(real_image, composite_image, mask, pred_fg_image, + pred_harmonized_image, lut_transform_image, opt, -1, show=False, + wandb=False, isAll=True, step=step) + + if opt.INRDecode: + mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'), + misc.normalize(real_image, opt, 'inv'), mask) + + lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'), + misc.normalize(real_image, opt, 'inv'), mask) + + for idx in range(len(category)): + if opt.INRDecode: + metric_log[category[idx]]['Samples'] += 1 + metric_log[category[idx]]['MSE'] += mse[idx] + metric_log[category[idx]]['fMSE'] += fmse[idx] + metric_log[category[idx]]['PSNR'] += psnr[idx] + metric_log[category[idx]]['SSIM'] += ssim[idx] + + metric_log['All']['Samples'] += 1 + metric_log['All']['MSE'] += mse[idx] + metric_log['All']['fMSE'] += fmse[idx] + metric_log['All']['PSNR'] += psnr[idx] + metric_log['All']['SSIM'] += ssim[idx] + + lut_metric_log[category[idx]]['Samples'] += 1 + lut_metric_log[category[idx]]['MSE'] += lut_mse[idx] + lut_metric_log[category[idx]]['fMSE'] += lut_fmse[idx] + lut_metric_log[category[idx]]['PSNR'] += lut_psnr[idx] + lut_metric_log[category[idx]]['SSIM'] += lut_ssim[idx] + + lut_metric_log['All']['Samples'] += 1 + lut_metric_log['All']['MSE'] += lut_mse[idx] + lut_metric_log['All']['fMSE'] += lut_fmse[idx] + lut_metric_log['All']['PSNR'] += lut_psnr[idx] + lut_metric_log['All']['SSIM'] += lut_ssim[idx] + + if (step + 1) / len(val_loader) * 100 >= current_process: + logger.info(f'Processing: {current_process}') + current_process += 10 + + logger.info('=========================') + for key in metric_log.keys(): + if opt.INRDecode: + msg = f"{key}-'MSE': {metric_log[key]['MSE'] / metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'fMSE': 
{metric_log[key]['fMSE'] / metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'PSNR': {metric_log[key]['PSNR'] / metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'SSIM': {metric_log[key]['SSIM'] / metric_log[key]['Samples']:.4f}\n" \ + f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n" + else: + msg = f"{key}-'LUT_MSE': {lut_metric_log[key]['MSE'] / lut_metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'LUT_fMSE': {lut_metric_log[key]['fMSE'] / lut_metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'LUT_PSNR': {lut_metric_log[key]['PSNR'] / lut_metric_log[key]['Samples']:.2f}\n" \ + f"{key}-'LUT_SSIM': {lut_metric_log[key]['SSIM'] / lut_metric_log[key]['Samples']:.4f}\n" + + logger.info(msg) + + logger.info('=========================') + + +def main_process(opt): + logger = misc.create_logger(os.path.join(opt.save_path, "log.txt")) + cudnn.benchmark = True + + valset_path = os.path.join(opt.dataset_path, "IHD_test.txt") + + opt.transform_mean = [.5, .5, .5] + opt.transform_var = [.5, .5, .5] + torch_transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(opt.transform_mean, opt.transform_var)]) + + valset_alb_transform = albumentations.Compose([Resize(opt.input_size, opt.input_size)], + additional_targets={'real_image': 'image', 'object_mask': 'image'}) + + valset = dataset_generator(valset_path, valset_alb_transform, torch_transform, opt, mode='Val') + + val_loader = DataLoader(valset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True, + num_workers=opt.workers, persistent_workers=True) + + model = build_model(opt).to(opt.device) + logger.info(f"Load pretrained weight from {opt.pretrained}") + + load_dict = 
torch.load(opt.pretrained)['model'] + for k in load_dict.keys(): + if k not in model.state_dict().keys(): + print(f"Skip {k}") + model.load_state_dict(load_dict, strict=False) + + inference(val_loader, model, logger, opt) + + +if __name__ == '__main__': + opt = parse_args() + os.makedirs(opt.save_path, exist_ok=True) + main_process(opt) \ No newline at end of file diff --git a/inference_for_arbitrary_resolution_image.py b/inference_for_arbitrary_resolution_image.py new file mode 100644 index 0000000000000000000000000000000000000000..2a59a1462704646b6a998ebf2fd8ff446d6d04dd --- /dev/null +++ b/inference_for_arbitrary_resolution_image.py @@ -0,0 +1,345 @@ +import argparse + +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +from model.build_model import build_model + +import torch +import cv2 +import numpy as np +import torchvision +import os +import tqdm +import time + +from utils.misc import prepare_cooridinate_input, customRandomCrop + +from datasets.build_INR_dataset import Implicit2DGenerator +import albumentations +from albumentations import Resize +from torch.utils.data import DataLoader +from utils.misc import normalize + +import math + +global_state = [1] # For Gradio Stop Button. + +class single_image_dataset(torch.utils.data.Dataset): + def __init__(self, opt, composite_image=None, mask=None): + super().__init__() + + self.opt = opt + + if composite_image is None: + composite_image = cv2.imread(opt.composite_image) + composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB) + self.composite_image = composite_image + + if mask is None: + mask = cv2.imread(opt.mask) + mask = mask[:, :, 0].astype(np.float32) / 255. 
+ self.mask = mask + + self.torch_transforms = transforms.Compose([transforms.ToTensor(), + transforms.Normalize([.5, .5, .5], [.5, .5, .5])]) + self.INR_dataset = Implicit2DGenerator(opt, 'Val') + + self.split_width_resolution = composite_image.shape[1] // opt.split_num + self.split_height_resolution = composite_image.shape[0] // opt.split_num + + self.split_width_resolution = self.split_height_resolution = min(self.split_width_resolution, + self.split_height_resolution) + + if self.split_width_resolution % 4 != 0: + self.split_width_resolution = self.split_width_resolution + (4 - self.split_width_resolution % 4) + + if self.split_height_resolution % 4 != 0: + self.split_height_resolution = self.split_height_resolution + (4 - self.split_height_resolution % 4) + + self.num_w = math.ceil(composite_image.shape[1] / self.split_width_resolution) + self.num_h = math.ceil(composite_image.shape[0] / self.split_height_resolution) + + self.split_start_point = [] + + "Split the image into several parts." + for i in range(self.num_h): + for j in range(self.num_w): + if i == composite_image.shape[0] // self.split_height_resolution: + if j == composite_image.shape[1] // self.split_width_resolution: + self.split_start_point.append((composite_image.shape[0] - self.split_height_resolution, + composite_image.shape[1] - self.split_width_resolution)) + else: + self.split_start_point.append( + (composite_image.shape[0] - self.split_height_resolution, j * self.split_width_resolution)) + else: + if j == composite_image.shape[1] // self.split_width_resolution: + self.split_start_point.append( + (i * self.split_height_resolution, composite_image.shape[1] - self.split_width_resolution)) + else: + self.split_start_point.append( + (i * self.split_height_resolution, j * self.split_width_resolution)) + + assert len(self.split_start_point) == self.num_w * self.num_h + + print( + f"The image will be split into {self.num_h} pieces in height, and {self.num_w} pieces in width. 
Totally {self.num_h * self.num_w} patches.") + print(f"The final resolution of each patch is {self.split_height_resolution} x {self.split_width_resolution}") + + def __len__(self): + return self.num_w * self.num_h + + def __getitem__(self, idx): + composite_image = self.composite_image + + mask = self.mask + + full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0) + + tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)], + additional_targets={'object_mask': 'image'}) + transform_out = tmp_transform(image=composite_image, object_mask=mask) + compos_list = [self.torch_transforms(transform_out['image'])] + mask_list = [ + torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))] + coord_map_list = [] + + if composite_image.shape[0] != self.split_height_resolution: + c_h = self.split_start_point[idx][0] / (composite_image.shape[0] - self.split_height_resolution) + else: + c_h = 0 + if composite_image.shape[1] != self.split_width_resolution: + c_w = self.split_start_point[idx][1] / (composite_image.shape[1] - self.split_width_resolution) + else: + c_w = 0 + transform_out, c_h, c_w = customRandomCrop([composite_image, mask, full_coord], + self.split_height_resolution, self.split_width_resolution, c_h, c_w) + + compos_list.append(self.torch_transforms(transform_out[0])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2])) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2])) + for n in range(2): + tmp_comp = cv2.resize(composite_image, ( + composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1))) + tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1))) + tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0) + + transform_out, c_h, c_w = 
customRandomCrop([tmp_comp, tmp_mask, tmp_coord], + self.split_height_resolution // 2 ** (n + 1), + self.split_width_resolution // 2 ** (n + 1), c_h, c_w) + compos_list.append(self.torch_transforms(transform_out[0])) + mask_list.append( + torchvision.transforms.ToTensor()(transform_out[1][..., np.newaxis].astype(np.float32))) + coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[2])) + out_comp = compos_list + out_mask = mask_list + out_coord = coord_map_list + + fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator( + self.torch_transforms, transform_out[0], transform_out[0], mask) + + return { + 'composite_image': out_comp, + 'mask': out_mask, + 'coordinate_map': out_coord, + 'composite_image0': out_comp[0], + 'mask0': out_mask[0], + 'coordinate_map0': out_coord[0], + 'composite_image1': out_comp[1], + 'mask1': out_mask[1], + 'coordinate_map1': out_coord[1], + 'composite_image2': out_comp[2], + 'mask2': out_mask[2], + 'coordinate_map2': out_coord[2], + 'composite_image3': out_comp[3], + 'mask3': out_mask[3], + 'coordinate_map3': out_coord[3], + 'fg_INR_coordinates': fg_INR_coordinates, + 'bg_INR_coordinates': bg_INR_coordinates, + 'fg_INR_RGB': fg_INR_RGB, + 'fg_transfer_INR_RGB': fg_transfer_INR_RGB, + 'bg_INR_RGB': bg_INR_RGB, + 'start_point': self.split_start_point[idx], + } + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('--split_num', type=int, default=4, + help='How many pieces do you want to split an image width / height.') + + parser.add_argument('--composite_image', type=str, default=r'./demo/demo_2k_composite.jpg', + help='composite image path') + + parser.add_argument('--mask', type=str, default=r'./demo/demo_2k_mask.jpg', + help='mask path') + + parser.add_argument('--save_path', type=str, default=r'./demo/', + help='save path') + + parser.add_argument('--workers', type=int, default=8, + metavar='N', help='Dataloader threads.') + + 
parser.add_argument('--batch_size', type=int, default=1, + help='You can override model batch size by specify positive number.') + + parser.add_argument('--device', type=str, default='cuda', + help="Whether use cuda, 'cuda' or 'cpu'.") + + parser.add_argument('--base_size', type=int, default=256, + help='Base size. Resolution of the image input into the Encoder') + + parser.add_argument('--input_size', type=int, default=256, + help='Input size. Resolution of the image that want to be generated by the Decoder') + + parser.add_argument('--INR_input_size', type=int, default=256, + help='INR input size. Resolution of the image that want to be generated by the Decoder. ' + 'Should be the same as `input_size`') + + parser.add_argument('--INR_MLP_dim', type=int, default=32, + help='Number of channels for INR linear layer.') + + parser.add_argument('--LUT_dim', type=int, default=7, + help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076') + + parser.add_argument('--activation', type=str, default='leakyrelu_pe', + help='INR activation layer type: leakyrelu_pe, sine') + + parser.add_argument('--pretrained', type=str, + default=r'.\pretrained_models\Resolution_RAW_iHarmony4.pth', + help='Pretrained weight path') + + parser.add_argument('--param_factorize_dim', type=int, + default=10, + help='The intermediate dimensions of the factorization of the predicted MLP parameters. ' + 'Refer to https://arxiv.org/abs/2011.12026') + + parser.add_argument('--embedding_type', type=str, + default="CIPS_embed", + help='Which embedding_type to use.') + + parser.add_argument('--INRDecode', action="store_false", + help='Whether INR decoder. Set it to False if you want to test the baseline ' + '(https://github.com/SamsungLabs/image_harmonization)') + + parser.add_argument('--isMoreINRInput', action="store_false", + help='Whether to cat RGB and mask. 
See Section 3.4 in the paper.') + + parser.add_argument('--hr_train', action="store_false", + help='Whether use hr_train. See section 3.4 in the paper.') + + parser.add_argument('--isFullRes', action="store_true", + help='Whether for original resolution. See section 3.4 in the paper.') + + opt = parser.parse_args() + + return opt + +@torch.no_grad() +def inference(model, opt, composite_image=None, mask=None): + model.eval() + + "dataset here is actually consisted of several patches of a single image." + singledataset = single_image_dataset(opt, composite_image, mask) + + single_data_loader = DataLoader(singledataset, opt.batch_size, shuffle=False, drop_last=False, pin_memory=True, + num_workers=opt.workers, persistent_workers=False if composite_image is not None else True) + + "Init a pure black image with the same size as the input image." + init_img = np.zeros_like(singledataset.composite_image) + + time_all = 0 + + for step, batch in tqdm.tqdm(enumerate(single_data_loader)): + composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)] + mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)] + coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)] + start_points = batch['start_point'] + + if opt.batch_size == 1: + start_points = [torch.cat(start_points)] + + fg_INR_coordinates = coordinate_map[1:] + + try: + if global_state[0] == 0: + print("Stop Harmonizing...!") + break + + if step == 0: # This is for CUDA Kernel Warm-up, or the first inference step will be quite slow. 
+ fg_content_bg_appearance_construct, _, lut_transform_image = model( + composite_image, + mask, + fg_INR_coordinates, + ) + print("Ready for harmonization...") + + if opt.device == "cuda": + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_max_memory_cached() + start_time = time.time() + torch.cuda.synchronize() + fg_content_bg_appearance_construct, _, lut_transform_image = model( + composite_image, + mask, + fg_INR_coordinates, + ) + if opt.device == "cuda": + torch.cuda.synchronize() + end_time = time.time() + + end_max_memory = torch.cuda.max_memory_allocated() // 1024 ** 2 + end_memory = torch.cuda.memory_allocated() // 1024 ** 2 + + print(f'GPU max memory usage: {end_max_memory} MB') + print(f'GPU memory usage: {end_memory} MB') + time_all += (end_time - start_time) + print(f'progress: {step} / {len(single_data_loader)}') + except: + raise Exception( + f'The image resolution is large. Please increase the `split_num` value. Your current set is {opt.split_num}') + + "Assemble the every patch's harmonized result into the final whole image." + for id in range(len(fg_INR_coordinates[0])): + pred_fg_image = fg_content_bg_appearance_construct[-1][id] + pred_harmonized_image = pred_fg_image * (mask[1][id] > 100 / 255.) 
+ composite_image[1][id] * ( + ~(mask[1][id] > 100 / 255.)) + + pred_harmonized_tmp = cv2.cvtColor( + normalize(pred_harmonized_image.unsqueeze(0), opt, 'inv')[0].permute(1, 2, 0).cpu().mul_(255.).clamp_( + 0., 255.).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR) + + init_img[start_points[id][0]:start_points[id][0] + singledataset.split_height_resolution, + start_points[id][1]:start_points[id][1] + singledataset.split_width_resolution] = pred_harmonized_tmp + + if opt.device == "cuda": + print(f'Inference time: {time_all}') + if opt.save_path is not None: + os.makedirs(opt.save_path, exist_ok=True) + cv2.imwrite(os.path.join(opt.save_path, "pred_harmonized_image.jpg"), init_img) + return init_img + + +def main_process(opt, composite_image=None, mask=None): + cudnn.benchmark = True + # Заменяем 'gpu' на 'cuda' + opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print("Preparing model...") + model = build_model(opt).to(opt.device) + + load_dict = torch.load(opt.pretrained, map_location='cpu')['model'] + + model.load_state_dict(load_dict, strict=False) + + return inference(model, opt, composite_image, mask) + + +if __name__ == '__main__': + opt = parse_args() + opt.transform_mean = [.5, .5, .5] + opt.transform_var = [.5, .5, .5] + main_process(opt) diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/backbone.py b/model/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..6ef7b61ca1bf5a22e9ac62cf9519dd1b68832cbe --- /dev/null +++ b/model/backbone.py @@ -0,0 +1,79 @@ +import torch.nn as nn + +from .hrnetv2.hrnet_ocr import HighResolutionNet +from .hrnetv2.modifiers import LRMult +from .base.basic_blocks import MaxPoolDownSize +from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization + + +def build_backbone(name, opt): + return eval(name)(opt) + + +class 
baseline(IHModelWithBackbone): + def __init__(self, opt, ocr=64): + base_config = {'model': DeepImageHarmonization, + 'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}} + + params = base_config['params'] + + backbone = HRNetV2(opt, ocr=ocr) + + params.update(dict( + backbone_from=2, + backbone_channels=backbone.output_channels, + backbone_mode='cat', + opt=opt + )) + base_model = base_config['model'](**params) + + super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt) + + +class HRNetV2(nn.Module): + def __init__( + self, opt, + cat_outputs=True, + pyramid_channels=-1, pyramid_depth=4, + width=18, ocr=128, small=False, + lr_mult=0.1, pretained=True + ): + super(HRNetV2, self).__init__() + self.opt = opt + self.cat_outputs = cat_outputs + self.ocr_on = ocr > 0 and cat_outputs + self.pyramid_on = pyramid_channels > 0 and cat_outputs + + self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt) + self.hrnet.apply(LRMult(lr_mult)) + if self.ocr_on: + self.hrnet.ocr_distri_head.apply(LRMult(1.0)) + self.hrnet.ocr_gather_head.apply(LRMult(1.0)) + self.hrnet.conv3x3_ocr.apply(LRMult(1.0)) + + hrnet_cat_channels = [width * 2 ** i for i in range(4)] + if self.pyramid_on: + self.output_channels = [pyramid_channels] * 4 + elif self.ocr_on: + self.output_channels = [ocr * 2] + elif self.cat_outputs: + self.output_channels = [sum(hrnet_cat_channels)] + else: + self.output_channels = hrnet_cat_channels + + if self.pyramid_on: + downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels) + self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth) + + if pretained: + self.load_pretrained_weights( + "./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth") + + self.output_resolution = (opt.input_size // 8) ** 2 + + def forward(self, image, mask, mask_features=None): + outputs = list(self.hrnet(image, mask, mask_features)) + return outputs + + def 
load_pretrained_weights(self, pretrained_path): + self.hrnet.load_pretrained_weights(pretrained_path) diff --git a/model/base/__init__.py b/model/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/base/basic_blocks.py b/model/base/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..cf62bec8deb6f23d23b94aaec27448adb4d03672 --- /dev/null +++ b/model/base/basic_blocks.py @@ -0,0 +1,366 @@ +import torch +from torch import nn as nn +import numpy as np + + +def hyper_weight_init(m, in_features_main_net, activation): + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') + m.weight.data = m.weight.data / 1.e2 + + if hasattr(m, 'bias'): + with torch.no_grad(): + if activation == 'sine': + m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30) + elif activation == 'leakyrelu_pe': + m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net)) + else: + raise NotImplementedError + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels, out_channels, + kernel_size=4, stride=2, padding=1, + norm_layer=nn.BatchNorm2d, activation=nn.ELU, + bias=True, + ): + super(ConvBlock, self).__init__() + self.block = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + activation(), + ) + + def forward(self, x): + return self.block(x) + + +class MaxPoolDownSize(nn.Module): + def __init__(self, in_channels, mid_channels, out_channels, depth): + super(MaxPoolDownSize, self).__init__() + self.depth = depth + self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0) + self.convs = nn.ModuleList([ + ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, 
padding=1) + for conv_i in range(depth) + ]) + self.pool2d = nn.MaxPool2d(kernel_size=2) + + def forward(self, x): + outputs = [] + + output = self.reduce_conv(x) + + for conv_i, conv in enumerate(self.convs): + output = output if conv_i == 0 else self.pool2d(output) + outputs.append(conv(output)) + + return outputs + + +class convParams(nn.Module): + def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False): + super(convParams, self).__init__() + self.INR_in_out = INR_in_out + self.cont_split_weight = [] + self.cont_split_bias = [] + self.hidden_mlp_num = hidden_mlp_num + self.param_factorize_dim = opt.param_factorize_dim + output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB) + self.output_dim = output_dim + self.toRGB = toRGB + self.cont_extraction_net = nn.Sequential( + nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False), + # nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), + # nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True), + ) + + self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation)) + + self.basic_params = nn.ParameterList() + if opt.param_factorize_dim > 0: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp))) + + if toRGB: + self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3))) + + def forward(self, feat, outMore=False): + cont_params = self.cont_extraction_net(feat) + out_mlp = self.to_mlp(cont_params) + if outMore: + return out_mlp, cont_params + return out_mlp + + def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False): + 
cont_params = 0 + start = 0 + if self.param_factorize_dim == -1: + cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1] + self.cont_split_weight.append([start, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + for id in range(hidden_mlp_num): + cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1] + self.cont_split_weight.append([start, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + if toRGB: + cont_params += INR_in_out[1] * 3 + 3 + self.cont_split_weight.append([start, cont_params - 3]) + self.cont_split_bias.append([cont_params - 3, cont_params]) + + elif self.param_factorize_dim > 0: + cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.cont_split_weight.append( + [start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + for id in range(hidden_mlp_num): + cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.cont_split_weight.append( + [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]]) + self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params]) + start = cont_params + + if toRGB: + cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3 + self.cont_split_weight.append( + [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3]) + self.cont_split_bias.append([cont_params - 3, cont_params]) + + return cont_params + + def to_mlp(self, params): + all_weight_bias = [] + if self.param_factorize_dim == -1: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + 
else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :] + weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:], + inp, outp) + + bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + all_weight_bias.append([weight, bias]) + + if self.toRGB: + inp, outp = self.INR_in_out[1], 3 + weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :] + weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:], + inp, outp) + + bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + all_weight_bias.append([weight, bias]) + + return all_weight_bias + + else: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :] + weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:], + inp, self.param_factorize_dim) + + weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :] + weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:], + self.param_factorize_dim, outp) + + bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias]) + + if self.toRGB: + inp, outp = self.INR_in_out[1], 3 + weight1 = params[:, 
self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :] + weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:], + inp, self.param_factorize_dim) + + weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :] + weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:], + self.param_factorize_dim, outp) + + bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :] + bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias]) + + return all_weight_bias + + +class lineParams(nn.Module): + def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False, + hidden_dim=512): + super(lineParams, self).__init__() + self.INR_in_out = INR_in_out + self.app_split_weight = [] + self.app_split_bias = [] + self.toRGB = toRGB + self.hidden_mlp_num = hidden_mlp_num + self.param_factorize_dim = opt.param_factorize_dim + output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num) + self.output_dim = output_dim + + self.compress_layer = nn.Sequential( + nn.Linear(input_resolution, 64, bias=False), + nn.BatchNorm1d(input_dim), + nn.ReLU(inplace=True), + nn.Linear(64, 1, bias=True) + ) + + self.app_extraction_net = nn.Sequential( + nn.Linear(input_dim, hidden_dim, bias=False), + # nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, hidden_dim, bias=False), + # nn.BatchNorm1d(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, output_dim, bias=True) + ) + + self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation)) + + self.basic_params = nn.ParameterList() + if opt.param_factorize_dim > 0: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = 
self.INR_in_out[1], self.INR_in_out[1] + self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp))) + if toRGB: + self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3))) + + def forward(self, feat): + app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1)) + out_mlp = self.to_mlp(app_params) + return out_mlp, app_params + + def cal_params_num(self, INR_in_out, hidden_mlp_num): + app_params = 0 + start = 0 + if self.param_factorize_dim == -1: + app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1] + self.app_split_weight.append([start, app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + for id in range(hidden_mlp_num): + app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1] + self.app_split_weight.append([start, app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + if self.toRGB: + app_params += INR_in_out[1] * 3 + 3 + self.app_split_weight.append([start, app_params - 3]) + self.app_split_bias.append([app_params - 3, app_params]) + + elif self.param_factorize_dim > 0: + app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim, + app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + for id in range(hidden_mlp_num): + app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \ + INR_in_out[1] + self.app_split_weight.append( + [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]]) + self.app_split_bias.append([app_params - INR_in_out[1], app_params]) + start = app_params + + if self.toRGB: + app_params += INR_in_out[1] * self.param_factorize_dim + 
self.param_factorize_dim * 3 + 3 + self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim, + app_params - 3]) + self.app_split_bias.append([app_params - 3, app_params]) + + return app_params + + def to_mlp(self, params): + all_weight_bias = [] + if self.param_factorize_dim == -1: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight = weight.view(weight.shape[0], inp, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([weight, bias]) + + if self.toRGB: + id = -1 + inp, outp = self.INR_in_out[1], 3 + weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight = weight.view(weight.shape[0], inp, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([weight, bias]) + + return all_weight_bias + + else: + for id in range(self.hidden_mlp_num + 1): + if id == 0: + inp, outp = self.INR_in_out[0], self.INR_in_out[1] + else: + inp, outp = self.INR_in_out[1], self.INR_in_out[1] + weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim) + + weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]] + weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias]) + + if self.toRGB: + id = -1 + inp, outp = self.INR_in_out[1], 3 + weight1 = params[:, 
self.app_split_weight[id][0]:self.app_split_weight[id][1]] + weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim) + + weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]] + weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp) + + bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]] + bias = bias.view(bias.shape[0], 1, outp) + + all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias]) + + return all_weight_bias diff --git a/model/base/conv_autoencoder.py b/model/base/conv_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bc30b98941037b4d6dc30ee177f5872daa1327c1 --- /dev/null +++ b/model/base/conv_autoencoder.py @@ -0,0 +1,519 @@ +import torch +import torchvision +from torch import nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from .basic_blocks import ConvBlock, lineParams, convParams +from .ops import MaskedChannelAttention, FeaturesConnector +from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed +from utils import misc +from utils.misc import lin2img +from ..lut_transformation_net import build_lut_transform + + +class Sine(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.sin(30 * input) + + +class Leaky_relu(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.leaky_relu(input, 0.01, inplace=True) + + +def select_activation(type): + if type == 'sine': + return Sine() + elif type == 'leakyrelu_pe': + return Leaky_relu() + else: + raise NotImplementedError + + +class ConvEncoder(nn.Module): + def __init__( + self, + depth, ch, + norm_layer, batchnorm_from, max_channels, + backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False + ): + super(ConvEncoder, self).__init__() + self.depth = depth + self.INRDecode = INRDecode + 
self.backbone_from = backbone_from + backbone_channels = [] if backbone_channels is None else backbone_channels[::-1] + + in_channels = 4 + out_channels = ch + + self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None) + self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None) + self.blocks_channels = [out_channels, out_channels] + + self.blocks_connected = nn.ModuleDict() + self.connectors = nn.ModuleDict() + for block_i in range(2, depth): + if block_i % 2: + in_channels = out_channels + else: + in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) + + if 0 <= backbone_from <= block_i and len(backbone_channels): + if INRDecode: + self.blocks_connected[f'block{block_i}_decode'] = ConvBlock( + in_channels, out_channels, + norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, + padding=int(block_i < depth - 1) + ) + self.blocks_channels += [out_channels] + stage_channels = backbone_channels.pop() + connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels) + self.connectors[f'connector{block_i}'] = connector + in_channels = connector.output_channels + + self.blocks_connected[f'block{block_i}'] = ConvBlock( + in_channels, out_channels, + norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, + padding=int(block_i < depth - 1) + ) + self.blocks_channels += [out_channels] + + def forward(self, x, backbone_features): + backbone_features = [] if backbone_features is None else backbone_features[::-1] + + outputs = [self.block0(x)] + outputs += [self.block1(outputs[-1])] + + for block_i in range(2, self.depth): + output = outputs[-1] + connector_name = f'connector{block_i}' + if connector_name in self.connectors: + if self.INRDecode: + block = self.blocks_connected[f'block{block_i}_decode'] + outputs += [block(output)] + + stage_features = backbone_features.pop() + connector = 
self.connectors[connector_name] + output = connector(output, stage_features) + block = self.blocks_connected[f'block{block_i}'] + outputs += [block(output)] + + return outputs[::-1] + + +class DeconvDecoder(nn.Module): + def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False): + super(DeconvDecoder, self).__init__() + self.image_fusion = image_fusion + self.deconv_blocks = nn.ModuleList() + + in_channels = encoder_blocks_channels.pop() + out_channels = in_channels + for d in range(depth): + out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 + self.deconv_blocks.append(SEDeconvBlock( + in_channels, out_channels, + norm_layer=norm_layer, + padding=0 if d == 0 else 1, + with_se=0 <= attend_from <= d + )) + in_channels = out_channels + + if self.image_fusion: + self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1) + self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1) + + def forward(self, encoder_outputs, image, mask=None): + output = encoder_outputs[0] + for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]): + output = block(output, mask) + output = output + skip_output + output = self.deconv_blocks[-1](output, mask) + + if self.image_fusion: + attention_map = torch.sigmoid(3.0 * self.conv_attention(output)) + output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output) + else: + output = self.to_rgb(output) + + return output + + +class SEDeconvBlock(nn.Module): + def __init__( + self, + in_channels, out_channels, + kernel_size=4, stride=2, padding=1, + norm_layer=nn.BatchNorm2d, activation=nn.ELU, + with_se=False + ): + super(SEDeconvBlock, self).__init__() + self.with_se = with_se + self.block = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + activation(), + ) + if self.with_se: + self.se = 
MaskedChannelAttention(out_channels) + + def forward(self, x, mask=None): + out = self.block(x) + if self.with_se: + out = self.se(out, mask) + return out + + +class INRDecoder(nn.Module): + def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from): + super(INRDecoder, self).__init__() + self.INR_encoding = None + if opt.embedding_type == "PosEncodingNeRF": + self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size) + elif opt.embedding_type == "RandomFourier": + self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device) + elif opt.embedding_type == "CIPS_embed": + self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32) + elif opt.embedding_type == "INRGAN_embed": + self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size) + else: + raise NotImplementedError + encoder_blocks_channels = encoder_blocks_channels[::-1] + max_hidden_mlp_num = attend_from + 1 + self.opt = opt + self.max_hidden_mlp_num = max_hidden_mlp_num + self.content_mlp_blocks = nn.ModuleDict() + for n in range(max_hidden_mlp_num): + if n != max_hidden_mlp_num - 1: + self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(), + [self.INR_encoding.out_dim + opt.INR_MLP_dim + ( + 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim], + opt, n + 1) + else: + self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(), + [self.INR_encoding.out_dim + ( + 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim], + opt, n + 1) + + self.deconv_blocks = nn.ModuleList() + + encoder_blocks_channels = encoder_blocks_channels[::-1] + in_channels = encoder_blocks_channels.pop() + out_channels = in_channels + for d in range(depth - attend_from): + out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 + self.deconv_blocks.append(SEDeconvBlock( + in_channels, out_channels, + norm_layer=norm_layer, + padding=0 if d == 0 else 1, + with_se=False + )) + 
in_channels = out_channels + + self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim], + (opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2, + opt, 2, toRGB=True) + + self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim, + None, opt) + + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None): + """For full resolution, do split.""" + if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, + 'split_resolution')) and self.opt.isFullRes: + return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples) + + encoder_outputs = encoder_outputs[::-1] + mlp_output = None + waitToRGB = [] + for n in range(self.max_hidden_mlp_num): + if not self.opt.hr_train: + coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \ + .unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device) + else: + if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'): + coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view( + encoder_outputs[0].shape[0], -1, 2) + else: + coord = misc.get_mgrid( + self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat( + encoder_outputs[0].shape[0], 1, 1).to(self.opt.device) + + """Whether to leverage multiple input to INR decoder. 
See Section 3.4 in the paper.""" + if self.opt.isMoreINRInput: + if not self.opt.isFullRes or ( + self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')): + res_h = res_w = np.sqrt(coord.shape[1]).astype(int) + else: + res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1)) + res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1)) + + res_image = torchvision.transforms.Resize([res_h, res_w])(image) + res_mask = torchvision.transforms.Resize([res_h, res_w])(mask) + coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1), + res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1) + else: + coord = self.INR_encoding(coord) + + """============ LRIP structure, see Section 3.3 ==============""" + + """Local MLPs.""" + if n == 0: + mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0), + self.opt, content_mlp=self.content_mlp_blocks[ + f"block{self.max_hidden_mlp_num - 1 - n}"]( + encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion) + waitToRGB.append(mlp_output[1]) + else: + mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + ( + 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0], + content_mlp=self.content_mlp_blocks[ + f"block{self.max_hidden_mlp_num - 1 - n}"]( + encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), + start_proportion=start_proportion) + waitToRGB.append(mlp_output[1]) + + encoder_outputs = encoder_outputs[::-1] + output = encoder_outputs[0] + for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]): + output = block(output) + output = output + skip_output + output = self.deconv_blocks[-1](output) + + """Global MLPs.""" + app_mlp, app_params = self.appearance_mlps(output) + harm_out = [] + for id in range(len(waitToRGB)): + output = self.mlp_process(None, self.opt.INR_MLP_dim, 
self.opt, base_feat=waitToRGB[id], + appearance_mlp=app_mlp) + harm_out.append(output[0]) + + """Optional 3D LUT prediction.""" + fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None) + + return harm_out, fit_lut3d, lut_transform_image + + def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None, + resolution=None, start_proportion=None): + + activation = select_activation(opt.activation) + + output = None + + if content_mlp is not None: + if base_feat is not None: + coorinates = torch.cat([coorinates, base_feat], dim=2) + coorinates = lin2img(coorinates, resolution) + + if hasattr(opt, 'split_resolution'): + """ + Here we crop the needed MLPs according to the region of the split input patches. + Note that this only support inferencing square images. + """ + for idx in range(len(content_mlp)): + content_mlp[idx][0] = content_mlp[idx][0][:, + (content_mlp[idx][0].shape[1] * start_proportion[0]).int():( + content_mlp[idx][0].shape[1] * start_proportion[2]).int(), + (content_mlp[idx][0].shape[2] * start_proportion[1]).int():( + content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :, + :] + content_mlp[idx][1] = content_mlp[idx][1][:, + (content_mlp[idx][1].shape[1] * start_proportion[0]).int():( + content_mlp[idx][1].shape[1] * start_proportion[2]).int(), + (content_mlp[idx][1].shape[2] * start_proportion[1]).int():( + content_mlp[idx][1].shape[2] * start_proportion[3]).int(), + :, + :] + k_h = coorinates.shape[2] // content_mlp[0][0].shape[1] + k_w = coorinates.shape[3] // content_mlp[0][0].shape[1] + bs = coorinates.shape[0] + h_lr = w_lr = content_mlp[0][0].shape[1] + nci = INR_input_dim + + coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w) + coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view( + bs, h_lr, w_lr, int(k_h * k_w), nci) + + for id, layer in enumerate(content_mlp): + if id == 0: + output = torch.matmul(coorinates, layer[0]) + layer[1] + output = 
activation(output) + else: + output = torch.matmul(output, layer[0]) + layer[1] + output = activation(output) + + output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute( + 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim) + + output_large = self.up(lin2img(output)) + + return output_large.view(bs, -1, opt.INR_MLP_dim), output + + k_h = coorinates.shape[2] // content_mlp[0][0].shape[1] + k_w = coorinates.shape[3] // content_mlp[0][0].shape[1] + bs = coorinates.shape[0] + h_lr = w_lr = content_mlp[0][0].shape[1] + nci = INR_input_dim + + """(evaluation or not HR training) and not fullres evaluation""" + if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not ( + not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train): + coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w) + coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view( + bs, h_lr, w_lr, int(k_h * k_w), nci) + + for id, layer in enumerate(content_mlp): + if id == 0: + output = torch.matmul(coorinates, layer[0]) + layer[1] + output = activation(output) + else: + output = torch.matmul(output, layer[0]) + layer[1] + output = activation(output) + + output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute( + 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim) + + output_large = self.up(lin2img(output)) + + return output_large.view(bs, -1, opt.INR_MLP_dim), output + else: + coorinates = coorinates.permute(0, 2, 3, 1) + for id, layer in enumerate(content_mlp): + weigt_shape = layer[0].shape + bias_shape = layer[1].shape + layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous() + layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous() + layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True + else 'bilinear', padding_mode='border', align_corners=False) + layer[1] = 
def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
    """Decode a full-resolution image strip by strip.

    The coordinate grid of the whole image is encoded once and then pushed
    through the content MLPs and the appearance MLP in horizontal strips of
    ``interval`` rows, so that only one strip is processed per
    ``mlp_process`` call.

    Args:
        encoder_outputs: list of encoder feature maps; it is reversed on
            entry and entries are popped from it while decoding.
        image: input image; its last two dimensions are read as
            (height, width).
        mask: mask tensor, only read when ``opt.isMoreINRInput`` is set.
        coord_samples: accepted for signature compatibility; not read in
            this method.

    Returns:
        Tuple ``(harm_out, fit_lut3d, lut_transform_image)`` where
        ``harm_out`` is a one-element list holding the strips concatenated
        along dim 2, and the other two values come from
        ``self.lut_transform``.
    """
    encoder_outputs = encoder_outputs[::-1]
    mlp_output = None
    res_w = image.shape[-1]
    res_h = image.shape[-2]
    # One coordinate grid for the full image, repeated for every batch item.
    coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
        encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)

    if self.opt.isMoreINRInput:
        # Append the flattened image and mask channels to the encoded coordinates.
        coord = torch.cat(
            [self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1),
             mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
    else:
        coord = self.INR_encoding(coord, (res_h, res_w))

    total = coord.clone()

    # Strip height in rows; the last strip is shorter when res_h is not a
    # multiple of ``interval``.
    interval = 10
    all_intervals = math.ceil(res_h / interval)
    divisible = True
    if res_h / interval != res_h // interval:
        divisible = False

    for n in range(self.max_hidden_mlp_num):
        accum_mlp_output = []
        for line in range(all_intervals):
            # Slice the flattened coordinates belonging to this strip
            # (presumably row-major, ``res_w`` entries per row - TODO confirm
            # against misc.get_mgrid).
            if not divisible and line == all_intervals - 1:
                coord = total[:, line * interval * res_w:, :]
            else:
                coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]
            # NOTE: on the last strip the encoder feature is *popped* from
            # ``encoder_outputs`` instead of indexed, so the list loses one
            # entry per pass over ``n``. Only element [1] of the
            # ``mlp_process`` result is kept here.
            if n == 0:
                accum_mlp_output.append(self.mlp_process(
                    coord,
                    self.INR_encoding.out_dim + (
                        4 if self.opt.isMoreINRInput else 0),
                    self.opt, content_mlp=self.content_mlp_blocks[
                        f"block{self.max_hidden_mlp_num - 1 - n}"](
                        encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
                        encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
                    resolution=(interval,
                                res_w) if divisible or line != all_intervals - 1 else (
                        res_h - interval * (all_intervals - 1), res_w))[1])

            else:
                # Later passes also receive the matching strip of the
                # previous pass' output as ``base_feat``.
                accum_mlp_output.append(self.mlp_process(
                    coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
                        4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:,
                                                                             line * interval * res_w: (
                                                                                 line + 1) * interval * res_w,
                                                                             :]
                    if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :],
                    content_mlp=self.content_mlp_blocks[
                        f"block{self.max_hidden_mlp_num - 1 - n}"](
                        encoder_outputs.pop(
                            self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
                        encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
                    resolution=(interval,
                                res_w) if divisible or line != all_intervals - 1 else (
                        res_h - interval * (all_intervals - 1), res_w))[1])

        # Re-assemble the strips of this pass along the flattened pixel axis.
        accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
        mlp_output = [accum_mlp_output, accum_mlp_output]

    # Restore the original ordering of whatever encoder features are left
    # and run the deconvolution branch with additive skip connections.
    encoder_outputs = encoder_outputs[::-1]
    output = encoder_outputs[0]
    for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
        output = block(output)
        output = output + skip_output
    output = self.deconv_blocks[-1](output)

    app_mlp, app_params = self.appearance_mlps(output)
    harm_out = []

    # Appearance pass, again strip by strip; element [0] of the
    # ``mlp_process`` result is kept this time.
    accum_mlp_output = []
    for line in range(all_intervals):
        if not divisible and line == all_intervals - 1:
            base = mlp_output[1][:, line * interval * res_w:, :]
        else:
            base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]

        accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base,
                                                 appearance_mlp=app_mlp,
                                                 resolution=(
                                                     interval,
                                                     res_w) if divisible or line != all_intervals - 1 else (
                                                     res_h - interval * (all_intervals - 1), res_w))[0])

    # Strips are stacked along dim 2 (presumably the image height axis of a
    # (B, C, H, W) tensor - TODO confirm against lin2img).
    accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
    harm_out.append(accum_mlp_output)

    fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)

    return harm_out, fit_lut3d, lut_transform_image
class DeepImageHarmonization(nn.Module):
    """Encoder/decoder harmonization network.

    A ``ConvEncoder`` consumes the image and mask resized to
    ``opt.base_size`` and concatenated along the channel axis; the result is
    decoded either by an ``INRDecoder`` (``opt.INRDecode`` truthy) or by the
    baseline ``DeconvDecoder``.

    Args:
        depth: forwarded to the encoder and the decoder.
        norm_layer: normalisation layer factory shared by encoder/decoder.
        batchnorm_from, ch, max_channels: forwarded to ``ConvEncoder``.
        attend_from, image_fusion: forwarded to ``DeconvDecoder`` only.
        backbone_from, backbone_channels, backbone_mode: backbone feature
            wiring forwarded to ``ConvEncoder`` (``backbone_from`` also goes
            to ``INRDecoder``).
        opt: option namespace. Required in practice: ``opt.INRDecode`` is
            read in the constructor, and ``opt.base_size`` (plus
            ``opt.hr_train`` when ``INRDecode`` is set) in ``forward``.
    """

    def __init__(
            self,
            depth,
            norm_layer=nn.BatchNorm2d, batchnorm_from=0,
            attend_from=-1,
            image_fusion=False,
            ch=64, max_channels=512,
            backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
    ):
        super(DeepImageHarmonization, self).__init__()
        self.depth = depth
        self.encoder = ConvEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
        )
        self.opt = opt
        if opt.INRDecode:
            # See Table 2 in the paper to test with different INR decoders' structures.
            self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
        else:
            # Baseline: https://github.com/SamsungLabs/image_harmonization
            self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)

    def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
        """Harmonize ``image`` under ``mask``.

        In the high-resolution INR mode (``opt.INRDecode`` and
        ``opt.hr_train`` set, and the module is training or ``opt`` carries
        ``split_num``/``split_resolution``) ``image`` and ``mask`` are
        indexable: element ``[0]`` feeds the encoder and element ``[1]`` is
        handed to the decoder together with ``coord`` and
        ``start_proportion``. Otherwise both are plain tensors.

        Returns:
            Whatever the configured decoder returns.
        """
        opt = self.opt
        # Evaluated once; same short-circuit order as the per-branch checks
        # so attributes are only touched when the earlier flags are truthy.
        hr_inputs = opt.INRDecode and opt.hr_train and (
            self.training or hasattr(opt, 'split_num') or hasattr(opt, 'split_resolution'))

        # Resize is stateless, so one instance serves both image and mask.
        to_base_size = torchvision.transforms.Resize([opt.base_size, opt.base_size])
        if hr_inputs:
            x = torch.cat((to_base_size(image[0]), to_base_size(mask[0])), dim=1)
        else:
            x = torch.cat((to_base_size(image), to_base_size(mask)), dim=1)

        intermediates = self.encoder(x, backbone_features)

        if hr_inputs:
            return self.decoder(intermediates, image[1], mask[1],
                                coord_samples=coord, start_proportion=start_proportion)
        return self.decoder(intermediates, image, mask)
class MaskedChannelAttention(nn.Module):
    """Channel attention computed from mask-aware global statistics.

    For every channel three scalars are pooled: the maximum inside the mask,
    the maximum outside the mask and the global average. A two-layer MLP
    with a sigmoid output maps them to one weight per channel which rescales
    the input.

    Args:
        in_channels: number of channels of the feature map passed to
            ``forward``.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_channels, *args, **kwargs):
        super(MaskedChannelAttention, self).__init__()
        self.global_max_pool = MaskedGlobalMaxPool2d()
        self.global_avg_pool = FastGlobalAvgPool2d()

        intermediate_channels_count = max(in_channels // 16, 8)
        self.attention_transform = nn.Sequential(
            nn.Linear(3 * in_channels, intermediate_channels_count),
            nn.ReLU(inplace=True),
            nn.Linear(intermediate_channels_count, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x, mask):
        """Return ``x`` rescaled per channel.

        Args:
            x: feature map of shape (B, C, H, W).
            mask: mask broadcastable against ``x`` after being resized to
                the spatial size of ``x``.
        """
        # Compare spatial sizes only. The previous check compared the mask's
        # (H, W) with x's (B, C): it resized needlessly in the common case
        # and skipped the resize when the two pairs happened to coincide.
        if mask.shape[2:] != x.shape[2:]:
            mask = nn.functional.interpolate(
                mask, size=x.size()[-2:],
                mode='bilinear', align_corners=True
            )
        pooled_x = torch.cat([
            self.global_max_pool(x, mask),
            self.global_avg_pool(x)
        ], dim=1)
        # (B, C) -> (B, C, 1, 1) so the weights broadcast over H and W.
        channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]

        return channel_attention_weights * x


class MaskedGlobalMaxPool2d(nn.Module):
    """Global max pooling done separately inside and outside a mask."""

    def __init__(self):
        super().__init__()
        self.global_max_pool = FastGlobalMaxPool2d()

    def forward(self, x, mask):
        """Return (B, 2C): maxima of ``x * mask`` followed by ``x * (1 - mask)``."""
        return torch.cat((
            self.global_max_pool(x * mask),
            self.global_max_pool(x * (1.0 - mask))
        ), dim=1)


class FastGlobalAvgPool2d(nn.Module):
    """Global average pooling via a flattening view: (B, C, ...) -> (B, C)."""

    def __init__(self):
        super(FastGlobalAvgPool2d, self).__init__()

    def forward(self, x):
        in_size = x.size()
        return x.view((in_size[0], in_size[1], -1)).mean(dim=2)


class FastGlobalMaxPool2d(nn.Module):
    """Global max pooling via a flattening view: (B, C, ...) -> (B, C)."""

    def __init__(self):
        super(FastGlobalMaxPool2d, self).__init__()

    def forward(self, x):
        in_size = x.size()
        # ``max`` returns (values, indices); only the values are needed.
        return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
class FeaturesConnector(nn.Module):
    """Merge a feature map with auxiliary (backbone) features.

    Supported modes:
        ``'cat'``  - concatenate along the channel axis,
        ``'catc'`` - concatenate, then reduce with a 1x1 convolution,
        ``'sum'``  - project the auxiliary features with a 1x1 convolution
                     and add them to the input,
        anything else - return the input untouched.

    The mode is forced to ``''`` when ``feature_channels`` is falsy.
    ``output_channels`` exposes the channel count produced by ``forward``.
    """

    def __init__(self, mode, in_channels, feature_channels, out_channels):
        super().__init__()
        active_mode = mode if feature_channels else ''
        self.mode = active_mode

        # Input width of the 1x1 reduction, if this mode needs one.
        conv_inputs = None
        if active_mode == 'sum':
            conv_inputs = feature_channels
        elif active_mode == 'catc':
            conv_inputs = in_channels + feature_channels
        if conv_inputs is not None:
            self.reduce_conv = nn.Conv2d(conv_inputs, out_channels, kernel_size=1)

        if active_mode == 'cat':
            self.output_channels = in_channels + feature_channels
        else:
            self.output_channels = out_channels

    def forward(self, x, features):
        mode = self.mode
        if mode not in ('cat', 'catc', 'sum'):
            return x
        if mode == 'sum':
            return self.reduce_conv(features) + x
        stacked = torch.cat((x, features), 1)
        return stacked if mode == 'cat' else self.reduce_conv(stacked)

    def extra_repr(self):
        return self.mode
class RandomFourier(nn.Module):
    """Random Fourier feature encoding of 2-D coordinates.

    A fixed (2, ``embedding_length``) projection is drawn once from
    N(0, ``std_scale``^2). ``forward`` returns the raw coordinates followed
    by the sine and cosine of the projected coordinates, i.e.
    ``out_dim = 2 + 2 * embedding_length`` channels per point.

    Args:
        std_scale: standard deviation of the random projection.
        embedding_length: number of random frequencies.
        device: device the projection matrix is moved to.
    """

    def __init__(self, std_scale, embedding_length, device):
        super().__init__()

        # Kept as a plain attribute (not a buffer/parameter) so that the
        # module's state_dict layout is unchanged.
        self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
        self.embed = self.embed.to(device)

        self.out_dim = embedding_length * 2 + 2

    def forward(self, coords):
        """Encode coordinates.

        Args:
            coords: tensor whose first dimension is the batch and whose last
                dimension holds the 2 coordinates; any dimensions in between
                are flattened into a single point axis.

        Returns:
            Tensor of shape (B, N, ``out_dim``).
        """
        coords = coords.reshape(coords.shape[0], -1, 2)
        projected = torch.matmul(2 * np.pi * coords, self.embed)
        coords_pos_enc = torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)

        # The sinusoidal part has ``out_dim - 2`` channels; the remaining two
        # are the raw coordinates prepended below. Reshaping to ``out_dim``
        # (as was done before) cannot match the element count.
        coords_pos_enc = coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim - 2)
        return torch.cat([coords, coords_pos_enc], dim=-1)
class INRGAN_embed(nn.Module):
    """Coordinate embedding assembled from several sinusoidal bases.

    Depending on the sizes in ``self.res_cfg`` the embedding concatenates
    projections of the 2-D coordinates onto a fixed logarithmic basis, a
    fixed random basis, a learnable shared basis and a basis predicted from
    a latent ``w``. The projections go through ``sin`` (and ``cos`` when
    ``use_cosine`` is set) and are followed by a learnable constant
    embedding. The raw coordinates are prepended, hence
    ``out_dim = get_total_dim() + 2``.

    Args:
        resolution: side length used to size the logarithmic basis and the
            constant embedding (``resolution ** 2`` entries).
        w_dim: width of the latent passed to ``forward`` as ``w``; only
            needed when ``predictable_emb_size`` is positive.
    """

    def __init__(self, resolution: int, w_dim=None):
        super().__init__()

        self.resolution = resolution
        # Coordinates are 2-D: every basis below has two rows and ``out_dim``
        # reserves two channels for the raw coordinates.
        self.coord_dim = 2
        self.res_cfg = {"log_emb_size": 32,
                        "random_emb_size": 32,
                        "const_emb_size": 64,
                        "use_cosine": True}
        self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
        self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
        self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
        self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
        self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
        self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
        self.use_cosine = self.res_cfg.get('use_cosine', False)

        if self.log_emb_size > 0:
            self.register_buffer('log_basis', generate_logarithmic_basis(
                resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))

        if self.random_emb_size > 0:
            self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))

        if self.shared_emb_size > 0:
            self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))

        if self.predictable_emb_size > 0:
            # Previously read ``self.cfg.coord_dim``, but no ``cfg`` attribute
            # is ever set on this module.
            self.W_size = self.predictable_emb_size * self.coord_dim
            self.b_size = self.predictable_emb_size
            self.affine = nn.Linear(w_dim, self.W_size + self.b_size)

        if self.const_emb_size > 0:
            self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))

        self.out_dim = self.get_total_dim() + 2

    def sample_w_matrix(self, shape, scale: float):
        """Return a standard-normal matrix of ``shape`` multiplied by ``scale``."""
        return torch.randn(shape) * scale

    def get_total_dim(self) -> int:
        """Return the embedding width, excluding the two raw coordinate channels."""
        # Each sinusoidal feature contributes a sin and, optionally, a cos channel.
        channels_per_feature = 2 if self.use_cosine else 1

        total_dim = 0
        if self.log_emb_size > 0:
            # The logarithmic basis may hold fewer rows than ``log_emb_size``.
            total_dim += self.log_basis.shape[0] * channels_per_feature
        total_dim += self.random_emb_size * channels_per_feature
        total_dim += self.shared_emb_size * channels_per_feature
        total_dim += self.predictable_emb_size * channels_per_feature
        total_dim += self.const_emb_size

        return total_dim

    def forward(self, raw_coords, w=None):
        """Embed coordinates.

        Args:
            raw_coords: tensor of shape (B, N, 2). When the constant
                embedding is enabled N must equal ``resolution ** 2`` because
                the constant table is concatenated point by point.
            w: latent of shape (B, ``w_dim``); only read when
                ``predictable_emb_size`` is positive.

        Returns:
            Tensor of shape (B, N, ``out_dim``): raw coordinates, sin
            features, optional cos features, optional constant embedding.
        """
        # Unpacking also enforces a 3-D input.
        batch_size, _num_points, _coord_dim = raw_coords.shape

        raw_embs = []

        if self.log_emb_size > 0:
            log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
            raw_embs.append(torch.matmul(raw_coords, log_bases))

        if self.random_emb_size > 0:
            random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_embs.append(torch.matmul(raw_coords, random_bases))

        if self.shared_emb_size > 0:
            shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
            raw_embs.append(torch.matmul(raw_coords, shared_bases))

        if self.predictable_emb_size > 0:
            mod = self.affine(w)
            W = self.fourier_scale * mod[:, :self.W_size]
            W = W.view(batch_size, self.coord_dim, self.predictable_emb_size)
            bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
            raw_embs.append(torch.matmul(raw_coords, W) + bias)

        # Collect the output pieces in order so that a configuration without
        # any sinusoidal basis no longer hits an unbound variable.
        parts = [raw_coords]
        if raw_embs:
            phases = torch.cat(raw_embs, dim=-1).contiguous()
            parts.append(phases.sin())
            if self.use_cosine:
                parts.append(phases.cos())

        if self.const_emb_size > 0:
            parts.append(self.const_embs.repeat([batch_size, 1, 1]))

        return torch.cat(parts, dim=-1)
generate_logarithmic_basis( + resolution, + max_num_feats, + remove_lowest_freq: bool = False, + use_diagonal: bool = True): + """ + Generates a directional logarithmic basis with the following directions: + - horizontal + - vertical + - main diagonal + - anti-diagonal + """ + max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int) + bases = [ + generate_horizontal_basis(max_num_feats_per_direction), + generate_vertical_basis(max_num_feats_per_direction), + ] + + if use_diagonal: + bases.extend([ + generate_diag_main_basis(max_num_feats_per_direction), + generate_anti_diag_basis(max_num_feats_per_direction), + ]) + + if remove_lowest_freq: + bases = [b[1:] for b in bases] + + # If we do not fit into `max_num_feats`, then trying to remove the features in the order: + # 1) anti-diagonal 2) main-diagonal + # while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2): + # bases = bases[:-1] + + basis = torch.cat(bases, dim=0) + + # If we still do not fit, then let's remove each second feature, + # then each third, each forth and so on + # We cannot drop the whole horizontal or vertical direction since otherwise + # model won't be able to locate the position + # (unless the previously computed embeddings encode the position) + # while basis.shape[0] > max_num_feats: + # num_exceeding_feats = basis.shape[0] - max_num_feats + # basis = basis[::2] + + assert basis.shape[0] <= max_num_feats, \ + f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}." 
def generate_horizontal_basis(num_feats: int):
    """Wavefront basis along direction (0, 1) with base period 4."""
    direction = [0.0, 1.0]
    return generate_wavefront_basis(num_feats, direction, 4.0)


def generate_vertical_basis(num_feats: int):
    """Wavefront basis along direction (1, 0) with base period 4."""
    direction = [1.0, 0.0]
    return generate_wavefront_basis(num_feats, direction, 4.0)


def generate_diag_main_basis(num_feats: int):
    """Wavefront basis along direction (-1, 1)/sqrt(2) with base period 4*sqrt(2)."""
    direction = [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)]
    period = 4.0 * np.sqrt(2)
    return generate_wavefront_basis(num_feats, direction, period)


def generate_anti_diag_basis(num_feats: int):
    """Wavefront basis along direction (1, 1)/sqrt(2) with base period 4*sqrt(2)."""
    direction = [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)]
    period = 4.0 * np.sqrt(2)
    return generate_wavefront_basis(num_feats, direction, period)


def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
    """Build ``num_feats`` frequency vectors along one direction.

    Row ``k`` equals ``basis_block * 2**k * (2*pi / period_length)``, so each
    successive row doubles the frequency of the previous one.

    Returns:
        float32 tensor of shape (num_feats, 2).
    """
    angular_step = 2.0 * np.pi / period_length
    directions = torch.tensor([basis_block]).repeat(num_feats, 1)  # [num_feats, 2]
    octaves = (2 ** torch.arange(num_feats)).unsqueeze(1)  # [num_feats, 1]
    scaled = directions * octaves * angular_step  # [num_feats, 2]

    return scaled.float()
class HighResolutionModule(nn.Module):
    """One multi-branch stage of HRNet with cross-branch fusion.

    Every branch is a stack of residual ``blocks``; after the branches run,
    each output branch receives the sum of all branches, resampled to its
    own resolution by the fuse layers.

    Args:
        num_branches: number of parallel branches.
        blocks: residual block class used in every branch.
        num_blocks: per-branch number of blocks.
        num_inchannels: per-branch input channels. The list object is kept
            by reference and updated in place while branches are built.
        num_channels: per-branch base channels (multiplied by
            ``block.expansion``).
        fuse_method: stored on the instance; not read anywhere in this class.
        multi_scale_output: when False only the first output branch is fused.
        norm_layer: normalisation layer factory.
        align_corners: forwarded to ``F.interpolate`` when upsampling.
    """

    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method,multi_scale_output=True,
                 norm_layer=nn.BatchNorm2d, align_corners=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.norm_layer = norm_layer
        self.align_corners = align_corners

        self.multi_scale_output = multi_scale_output

        # Branches must be built before the fuse layers: building them
        # rewrites ``self.num_inchannels``, which the fuse layers read.
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=relu_inplace)

    def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
        """Raise ``ValueError`` unless every per-branch list has ``num_branches`` entries."""
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        """Build the block stack of one branch and return it as ``nn.Sequential``."""
        downsample = None
        # A 1x1 projection on the shortcut is needed whenever the first block
        # changes the resolution or the channel count.
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                self.norm_layer(num_channels[branch_index] * block.expansion),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride,
                            downsample=downsample, norm_layer=self.norm_layer))
        # In-place update: the list passed to the constructor now holds this
        # branch's output width.
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index],
                                norm_layer=self.norm_layer))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build all branches and return them as an ``nn.ModuleList``."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the layers mapping branch ``j`` onto output branch ``i``.

        Returns ``None`` for a single branch. Otherwise entry ``[i][j]`` is a
        1x1 conv + norm for ``j > i`` (upsampling is done in ``forward``),
        ``None`` for ``j == i``, and a chain of ``i - j`` stride-2 3x3 convs
        for ``j < i``.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(in_channels=num_inchannels[j],
                                  out_channels=num_inchannels[i],
                                  kernel_size=1,
                                  bias=False),
                        self.norm_layer(num_inchannels[i])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            # Last step of the chain: switch to the target
                            # width and leave out the ReLU.
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          kernel_size=3, stride=2, padding=1, bias=False),
                                self.norm_layer(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          kernel_size=3, stride=2, padding=1, bias=False),
                                self.norm_layer(num_outchannels_conv3x3),
                                nn.ReLU(inplace=relu_inplace)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the per-branch channel list (as updated by branch construction)."""
        return self.num_inchannels

    def forward(self, x):
        """Run all branches and fuse them.

        Args:
            x: list with one tensor per branch. NOTE: the list is modified
                in place - each entry is replaced by its branch output.

        Returns:
            List of fused tensors, one per output branch.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            # Branch 0 contributes directly to output 0 and through its
            # downsampling chain to every other output.
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    # Project with the 1x1 conv, then resize to the spatial
                    # size of branch i.
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=self.align_corners)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_inchannels) + self.stage2, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], num_channels=num_channels) + + self.stage3_num_branches = 3 + num_channels = [width, 2 * width, 4 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage3, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, + num_modules=3 if small else 4, num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], num_channels=num_channels) + + self.stage4_num_branches = 4 + num_channels = [width, 2 * width, 4 * width, 8 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage4, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + num_branches=self.stage4_num_branches, + num_blocks=4 * [num_blocks], num_channels=num_channels) + + if self.ocr_on: + last_inp_channels = np.int_(np.sum(pre_stage_channels)) + ocr_mid_channels = 2 * ocr_width + ocr_key_channels = ocr_width + + self.conv3x3_ocr = nn.Sequential( + 
nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=relu_inplace), + ) + self.ocr_gather_head = SpatialGather_Module(num_classes) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners, opt=opt) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + 
inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, mask=None, additional_features=None): + hrnet_feats = self.compute_hrnet_feats(x, additional_features) + if not self.ocr_on: + return hrnet_feats, + + ocr_feats = self.conv3x3_ocr(hrnet_feats) + mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True) + context = self.ocr_gather_head(ocr_feats, mask) + ocr_feats = self.ocr_distri_head(ocr_feats, context) + return ocr_feats, + + def compute_hrnet_feats(self, x, additional_features, return_list=False): + x = self.compute_pre_stage_features(x, additional_features) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] 
+ for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + if return_list: + return x + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + return torch.cat([x[0], x1, x2, x3], 1) + + def compute_pre_stage_features(self, x, additional_features): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + additional_features + x = self.conv2(x) + x = self.bn2(x) + return self.relu(x) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): + print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + + # Устанавливаем устройство, на котором будет работать модель + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Загружаем веса и перемещаем на выбранное устройство + pretrained_dict = torch.load(pretrained_path, map_location=device) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in pretrained_dict.items()} + params_count = len(pretrained_dict) + + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} + + print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet') + + 
class SpatialGather_Module(nn.Module):
    """
        Aggregate the context features according to the initial
        predicted probability distribution.
        Employ the soft-weighted method to aggregate the context.

        ``forward(feats, probs)`` takes ``feats`` of shape batch x c x h x w and
        ``probs`` of shape batch x k x h x w and returns a batch x c x k x 1
        tensor: for each of the k maps, the softmax-weighted (over the h*w
        positions, logits multiplied by ``scale``) average of ``feats``.
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGather_Module, self).__init__()
        self.cls_num = cls_num  # stored only; forward() does not read it
        self.scale = scale

    def forward(self, feats, probs):
        # flatten() falls back to a copy for non-contiguous inputs, where
        # .view() would raise; for contiguous inputs it is the same view.
        probs = probs.flatten(2)  # batch x k x hw
        feats = feats.flatten(2).permute(0, 2, 1)  # batch x hw x c
        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
        ocr_context = torch.matmul(probs, feats)  # batch x k x c
        return ocr_context.permute(0, 2, 1).unsqueeze(3).contiguous()  # batch x c x k x 1
class ObjectAttentionBlock2D(nn.Module):
    '''
    The basic implementation for object context block: each pixel of ``x``
    attends over the positions of ``proxy`` and receives their weighted sum.
    Input:
        x     : N X C X H X W
        proxy : N X C X (any two trailing dims, flattened internally)
    Parameters:
        in_channels   : the dimension of the input feature map
        key_channels  : the dimension after the key/query transform
        scale         : choose the scale to downsample the input feature maps (save memory cost)
        norm_layer    : callable building the normalization layer from a channel count
        align_corners : forwarded to the bilinear upsampling used when scale > 1
    Return:
        N X C X H X W
    '''

    def __init__(self,
                 in_channels,
                 key_channels,
                 scale=1,
                 norm_layer=nn.BatchNorm2d,
                 align_corners=True):
        super(ObjectAttentionBlock2D, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.align_corners = align_corners

        def conv_norm_relu(channels_in, channels_out):
            # 1x1 conv followed by a nested (norm, ReLU) Sequential; the
            # nesting is kept so parameter names in state_dict stay the same.
            return (
                nn.Conv2d(in_channels=channels_in, out_channels=channels_out,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.Sequential(norm_layer(channels_out), nn.ReLU(inplace=True)),
            )

        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        # query transform (applied to pixels)
        self.f_pixel = nn.Sequential(
            *conv_norm_relu(self.in_channels, self.key_channels),
            *conv_norm_relu(self.key_channels, self.key_channels),
        )
        # key transform (applied to the proxy)
        self.f_object = nn.Sequential(
            *conv_norm_relu(self.in_channels, self.key_channels),
            *conv_norm_relu(self.key_channels, self.key_channels),
        )
        # value transform (applied to the proxy)
        self.f_down = nn.Sequential(*conv_norm_relu(self.in_channels, self.key_channels))
        # projection of the attended context back to in_channels
        self.f_up = nn.Sequential(*conv_norm_relu(self.key_channels, self.in_channels))

    def forward(self, x, proxy):
        batch_size = x.size(0)
        full_size = (x.size(2), x.size(3))
        downsampled = self.scale > 1
        if downsampled:
            x = self.pool(x)

        channels = self.key_channels
        query = self.f_pixel(x).view(batch_size, channels, -1).permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, channels, -1)
        value = self.f_down(proxy).view(batch_size, channels, -1).permute(0, 2, 1)

        # scaled dot-product attention of every pixel over the proxy positions
        attention = (channels ** -.5) * torch.matmul(query, key)
        attention = F.softmax(attention, dim=-1)

        # add bg context ...
        context = torch.matmul(attention, value).permute(0, 2, 1).contiguous()
        context = self.f_up(context.view(batch_size, channels, *x.size()[2:]))
        if downsampled:
            # restore the resolution the input had before pooling
            context = F.interpolate(input=context, size=full_size,
                                    mode='bilinear', align_corners=self.align_corners)

        return context
class ResNetV1b(nn.Module):
    """ ResNetV1b classifier/backbone; with ``dilated=True`` layer3 and layer4 keep
    stride 1 and use dilation instead, so conv5 features have stride 8.

    Parameters
    ----------
    block : Block
        Class for the residual block. It must expose an ``expansion`` class
        attribute and accept ``(inplanes, planes, stride, dilation=...,
        downsample=..., previous_dilation=..., norm_layer=...)``.
    layers : list of int
        Numbers of blocks in each of the four stages.
    classes : int, default 1000
        Number of classification classes.
    dilated : bool, default True
        Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
        typically used in Semantic Segmentation.
    deep_stem : bool, default False
        Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
    stem_width : int, default 32
        Width of the first two deep-stem convolutions; the stem outputs
        ``2 * stem_width`` channels. Only used when ``deep_stem`` is true.
    avg_down : bool, default False
        Whether to use average pooling for projection skip connection between stages/downsample.
    final_drop : float, default 0.0
        Dropout ratio before the final classification layer.
    norm_layer : object
        Normalization layer used (default: :class:`nn.BatchNorm2d`)

    Reference:
        - He, Kaiming, et al. "Deep residual learning for image recognition."
        Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
    """
    def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
                 avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
        super(ResNetV1b, self).__init__()
        # Running channel count; _make_layer reads and updates it.
        self.inplanes = stem_width*2 if deep_stem else 64

        if deep_stem:
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
                norm_layer(stem_width),
                nn.ReLU(True),
                nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
                norm_layer(stem_width),
                nn.ReLU(True),
                nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
            )
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        shared = dict(avg_down=avg_down, norm_layer=norm_layer)
        if dilated:
            # trade the stride of the last two stages for dilation
            layer3_cfg = dict(stride=1, dilation=2)
            layer4_cfg = dict(stride=1, dilation=4)
        else:
            layer3_cfg = dict(stride=2)
            layer4_cfg = dict(stride=2)

        self.layer1 = self._make_layer(block, 64, layers[0], **shared)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **shared)
        self.layer3 = self._make_layer(block, 256, layers[2], **layer3_cfg, **shared)
        self.layer4 = self._make_layer(block, 512, layers[3], **layer4_cfg, **shared)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None
        self.fc = nn.Linear(512 * block.expansion, classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    avg_down=False, norm_layer=nn.BatchNorm2d):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            if avg_down:
                # Pool first, then project with a stride-1 1x1 conv; in a
                # dilated stage the pooling degenerates to a 1x1 window.
                pool_size = stride if dilation == 1 else 1
                downsample = nn.Sequential(
                    nn.AvgPool2d(kernel_size=pool_size, stride=pool_size,
                                 ceil_mode=True, count_include_pad=False),
                    nn.Conv2d(self.inplanes, out_channels=out_planes,
                              kernel_size=1, stride=1, bias=False),
                    norm_layer(out_planes)
                )
            else:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, out_channels=out_planes,
                              kernel_size=1, stride=stride, bias=False),
                    norm_layer(out_planes)
                )

        # The first block of a stage uses a smaller dilation than the rest.
        if dilation in (1, 2):
            first_dilation = 1
        elif dilation == 4:
            first_dilation = 2
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        stage = [block(self.inplanes, planes, stride, dilation=first_dilation, downsample=downsample,
                       previous_dilation=dilation, norm_layer=norm_layer)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, dilation=dilation,
                  previous_dilation=dilation, norm_layer=norm_layer)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits of shape (batch, classes)."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        if self.drop is not None:
            x = self.drop(x)

        return self.fc(x)
def _load_gluon_pretrained(model, hub_entry):
    """Copy the weights of the torch.hub model ``hub_entry`` into ``model`` in place.

    Only entries whose key exists in ``model.state_dict()`` are copied (the
    rest are reported by ``_safe_state_dict_filtering``); parameters missing
    from the hub checkpoint keep their current values.
    """
    model_dict = model.state_dict()
    hub_state = torch.hub.load(GLUON_RESNET_TORCH_HUB, hub_entry, pretrained=True).state_dict()
    model_dict.update(_safe_state_dict_filtering(hub_state, model_dict.keys()))
    model.load_state_dict(model_dict)
    return model


def resnet50_v1s(pretrained=False, **kwargs):
    """Build a deep-stem ResNet-50 (v1s); optionally load Gluon weights via torch.hub."""
    model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
    if pretrained:
        _load_gluon_pretrained(model, 'gluon_resnet50_v1s')
    return model


def resnet101_v1s(pretrained=False, **kwargs):
    """Build a deep-stem ResNet-101 (v1s); optionally load Gluon weights via torch.hub."""
    model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
    if pretrained:
        _load_gluon_pretrained(model, 'gluon_resnet101_v1s')
    return model


def resnet152_v1s(pretrained=False, **kwargs):
    """Build a deep-stem ResNet-152 (v1s); optionally load Gluon weights via torch.hub."""
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
    if pretrained:
        _load_gluon_pretrained(model, 'gluon_resnet152_v1s')
    return model
class build_lut_transform(nn.Module):
    """Predict a per-sample 3D colour LUT from appearance features and apply it.

    Args:
        input_dim: size of the appearance feature vector fed to the MLP.
        lut_dim: number of LUT nodes per colour axis; the MLP outputs
            ``3 * lut_dim ** 3`` values.
        input_resolution: accepted for interface compatibility; not used here.
        opt: options object forwarded to ``normalize``.
    """

    def __init__(self, input_dim, lut_dim, input_resolution, opt):
        super().__init__()

        self.lut_dim = lut_dim
        self.opt = opt

        lut_size = 3 * lut_dim ** 3
        self.transform_layers = nn.Sequential(
            nn.Linear(input_dim, lut_size, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(lut_size, lut_size, bias=True),
        )
        # Only the output layer gets the small-weight / uniform-bias init.
        self.transform_layers[-1].apply(hyper_weight_init)

    def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
        """Return ``(fit_3DLUT, lut_transform_image)``.

        ``fit_3DLUT`` has shape (batch, 3, lut_dim, lut_dim, lut_dim). Only
        ``fg_appearance_features`` drives the LUT; ``bg_appearance_features``
        is accepted but not read.
        """
        # presumably maps the image back to [0, 1] before the LUT lookup -- verify against utils.misc.normalize
        composite_image = normalize(composite_image, self.opt, 'inv')

        lut_params = self.transform_layers(fg_appearance_features)

        fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim)

        # Each sample has its own LUT, so the lookup is done one image at a time.
        lut_transform_image = torch.stack(
            [TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0)

        return fit_3DLUT, normalize(lut_transform_image, self.opt)


def TrilinearInterpolation(LUT, img):
    """Look up every pixel of ``img`` (3 x H x W, values in [0, 1]) in ``LUT``.

    ``LUT`` has shape (3, D, D, D) indexed as [out_channel, c0, c1, c2] where
    c0..c2 are the three input channels in order. Returns a 1 x 3 x H x W tensor;
    out-of-range colours are clamped to the LUT border.
    """
    # grid_sample expects coordinates in [-1, 1]
    img = (img - 0.5) * 2.

    # 3 x H x W -> 1 x 1 x H x W x 3 sampling grid
    img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)

    # Note that the coordinates in the grid_sample are inverse to LUT DHW, i.e., xyz is to WHD not DHW.
    # (hence the flip of the channel axis above)
    LUT = LUT[None]

    # grid sample ('bilinear' on a 5-D input performs trilinear interpolation)
    result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)

    # drop the singleton depth dimension added for the grid
    return result[:, :, 0]


def hyper_weight_init(m):
    """Initialise ``m`` in place: Kaiming-normal weight scaled by 1e-2, bias ~ U(0, 1).

    Modules without a weight or bias (attribute missing or ``None``, e.g. a
    layer built with ``bias=False``) are left untouched for that part.
    """
    if getattr(m, 'weight', None) is not None:
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        m.weight.data = m.weight.data / 1.e2

    if getattr(m, 'bias', None) is not None:
        with torch.no_grad():
            m.bias.uniform_(0., 1.)
diff --git a/pretrained_models/Resolution_1024_HAdobe5K.pth b/pretrained_models/Resolution_1024_HAdobe5K.pth new file mode 100644 index 0000000000000000000000000000000000000000..be7712c1c3f781748eff3e713e314dbc101c503b --- /dev/null +++ b/pretrained_models/Resolution_1024_HAdobe5K.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4917e99cc20c2530b6d248d530368929c1784113d20365085b96bbb10860a2f8 +size 477235439 diff --git a/pretrained_models/Resolution_2048_HAdobe5K.pth b/pretrained_models/Resolution_2048_HAdobe5K.pth new file mode 100644 index 0000000000000000000000000000000000000000..ef4d0b31b61d544bd97be5cb5d97608c34497473 --- /dev/null +++ b/pretrained_models/Resolution_2048_HAdobe5K.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa3d076b5cbf653f17fbf02b95f45b95a0d38e6cb53eaeed71cac1fb22af6f69 +size 477235439 diff --git a/pretrained_models/Resolution_256_iHarmony4.pth b/pretrained_models/Resolution_256_iHarmony4.pth new file mode 100644 index 0000000000000000000000000000000000000000..4d364bd0939ce0ba3d393e6f3461aff8e31b2646 --- /dev/null +++ b/pretrained_models/Resolution_256_iHarmony4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70a7df7a5b8ba502b69d8dba3b9f47cb99e95f5ebb64b289f9f0454613d6f5b6 +size 477528743 diff --git a/pretrained_models/Resolution_RAW_HAdobe5K.pth b/pretrained_models/Resolution_RAW_HAdobe5K.pth new file mode 100644 index 0000000000000000000000000000000000000000..ebecc42d63627ca3dd76d91a50fbc99ce2eb66d9 --- /dev/null +++ b/pretrained_models/Resolution_RAW_HAdobe5K.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1829182a1a03e9bb5116ac166a9debfb6fdb97c8547790cc0d94bfe313e0c80 +size 953285076 diff --git a/pretrained_models/Resolution_RAW_iHarmony4.pth b/pretrained_models/Resolution_RAW_iHarmony4.pth new file mode 100644 index 0000000000000000000000000000000000000000..164d6467d05514225f26561b499b551aefc1ef40 --- /dev/null +++ 
b/pretrained_models/Resolution_RAW_iHarmony4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5475d7a2a77260f8c9503b02e46a7b456bd1a7a7d7b0fc3c85ef534b473eef8 +size 477235439 diff --git a/pretrained_models/hrnetv2_w18_imagenet_pretrained.pth b/pretrained_models/hrnetv2_w18_imagenet_pretrained.pth new file mode 100644 index 0000000000000000000000000000000000000000..5d6c001ff1443b030fdd11e72d60b9722b8f027d --- /dev/null +++ b/pretrained_models/hrnetv2_w18_imagenet_pretrained.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00eb200687c1ed1fc1042767e5b965052f4c7338823ba4d36a790979a078b36b +size 85758673 diff --git a/processing.py b/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..f4628a7a3c742c2de86d4424545f0fc43c5443d1 --- /dev/null +++ b/processing.py @@ -0,0 +1,308 @@ +import os +import time +import datetime + +import torch +import torchvision + +from utils import misc, metrics + +best_psnr = 0 + + +def train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, logger, opt): + total_step = opt.epochs * len(train_loader) + + step_time_log = misc.AverageMeter() + loss_log = misc.AverageMeter(':6f') + loss_fg_content_bg_appearance_construct_log = misc.AverageMeter(':6f') + loss_lut_transform_image_log = misc.AverageMeter(':6f') + loss_lut_regularize_log = misc.AverageMeter(':6f') + + start_epoch = 0 + + "Load pretrained checkpoints" + if opt.pretrained is not None: + logger.info(f"Load pretrained weight from {opt.pretrained}") + load_state = torch.load(opt.pretrained) + model = model.cpu() + model.load_state_dict(load_state['model']) + model = model.to(opt.device) + optimizer.load_state_dict(load_state['optimizer']) + scheduler.load_state_dict(load_state['scheduler']) + start_epoch = load_state['last_epoch'] + 1 + + for epoch in range(start_epoch, opt.epochs): + model.train() + time_ckp = time.time() + for step, batch in enumerate(train_loader): + current_step = epoch * 
len(train_loader) + step + 1 + + if opt.INRDecode and opt.hr_train: + "List with 4 elements: [Input to Encoder, three different resolutions' crop to INR Decoder]" + composite_image = [batch[f'composite_image{name}'].to(opt.device) for name in range(4)] + real_image = [batch[f'real_image{name}'].to(opt.device) for name in range(4)] + mask = [batch[f'mask{name}'].to(opt.device) for name in range(4)] + coordinate_map = [batch[f'coordinate_map{name}'].to(opt.device) for name in range(4)] + + fg_INR_coordinates = coordinate_map[1:] + + else: + composite_image = batch['composite_image'].to(opt.device) + real_image = batch['real_image'].to(opt.device) + mask = batch['mask'].to(opt.device) + + fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device) + + fg_content_bg_appearance_construct, fit_lut3d, lut_transform_image = model( + composite_image, mask, fg_INR_coordinates) + + if opt.INRDecode: + loss_fg_content_bg_appearance_construct = 0 + """ + Our LRIP module requires three different resolution layers, thus here + `loss_fg_content_bg_appearance_construct` is calculated in multiple layers. + Besides, when leverage `hr_train`, i.e. use RSC strategy (See Section 3.4), the `real_image` + and `mask` are list type, corresponding different resolutions' crop. 
+ """ + if opt.hr_train: + for n in range(3): + loss_fg_content_bg_appearance_construct += loss_fn['masked_mse'] \ + (fg_content_bg_appearance_construct[n], real_image[3 - n], mask[3 - n]) + loss_fg_content_bg_appearance_construct /= 3 + loss_lut_transform_image = loss_fn['masked_mse'](lut_transform_image, real_image[1], mask[1]) + else: + for n in range(3): + loss_fg_content_bg_appearance_construct += loss_fn['MaskWeightedMSE'] \ + (fg_content_bg_appearance_construct[n], + torchvision.transforms.Resize(opt.INR_input_size // 2 ** (3 - n - 1))(real_image), + torchvision.transforms.Resize(opt.INR_input_size // 2 ** (3 - n - 1))(mask)) + loss_fg_content_bg_appearance_construct /= 3 + loss_lut_transform_image = loss_fn['masked_mse'](lut_transform_image, real_image, mask) + loss_lut_regularize = loss_fn['regularize_LUT'](fit_lut3d) + + else: + loss_fg_content_bg_appearance_construct = 0 + loss_lut_transform_image = loss_fn['masked_mse'](lut_transform_image, real_image, mask) + loss_lut_regularize = 0 + + loss = loss_fg_content_bg_appearance_construct + loss_lut_transform_image + loss_lut_regularize + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + step_time_log.update(time.time() - time_ckp) + + loss_fg_content_bg_appearance_construct_log.update(0 if isinstance(loss_fg_content_bg_appearance_construct, + int) else loss_fg_content_bg_appearance_construct.item()) + loss_lut_transform_image_log.update( + 0 if isinstance(loss_lut_transform_image, int) else loss_lut_transform_image.item()) + loss_lut_regularize_log.update(0 if isinstance(loss_lut_regularize, int) else loss_lut_regularize.item()) + loss_log.update(loss.item()) + + if current_step % opt.print_freq == 0: + remain_secs = (total_step - current_step) * step_time_log.avg + remain_time = datetime.timedelta(seconds=round(remain_secs)) + finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs)) + + log_msg = f'Epoch: [{epoch}/{opt.epochs}]\t' \ + 
f'Step: [{step}/{len(train_loader)}]\t' \ + f'StepTime {step_time_log.val:.3f} ({step_time_log.avg:.3f})\t' \ + f'lr {optimizer.param_groups[0]["lr"]}\t' \ + f'Loss {loss_log.val:.4f} ({loss_log.avg:.4f})\t' \ + f'Loss_fg_bg_cons {loss_fg_content_bg_appearance_construct_log.val:.4f} ({loss_fg_content_bg_appearance_construct_log.avg:.4f})\t' \ + f'Loss_lut_trans {loss_lut_transform_image_log.val:.4f} ({loss_lut_transform_image_log.avg:.4f})\t' \ + f'Loss_lut_reg {loss_lut_regularize_log.val:.4f} ({loss_lut_regularize_log.avg:.4f})\t' \ + f'Remaining Time {remain_time} ({finish_time})' + logger.info(log_msg) + + if opt.wandb: + import wandb + wandb.log( + {'Train/Epoch': epoch, 'Train/lr': optimizer.param_groups[0]['lr'], 'Train/Step': current_step, + 'Train/Loss': loss_log.val, + 'Train/Loss_fg_bg_cons': loss_fg_content_bg_appearance_construct_log.val, + 'Train/Loss_lut_trans': loss_lut_transform_image_log.val, + 'Train/Loss_lut_reg': loss_lut_regularize_log.val, + }) + + time_ckp = time.time() + + state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'last_epoch': epoch, + 'scheduler': scheduler.state_dict()} + + """ + As the validation of original resolution Harmonization will have no consistent resolution among images + (so fail to form a batch) and also may lead to out-of-memory problem when combined with training phase, + we here only save the model when `opt.isFullRes` is True, leaving the evaluation in `inference.py`. 
def val(val_loader, model, logger, opt, state):
    """Run one validation pass, log per-dataset metrics and save a checkpoint.

    For every batch the model output is composited with the real image outside
    the foreground mask, visualised through ``misc.visualize`` and scored with
    ``metrics.calc_metrics``. Metrics are accumulated per ``batch['category']``
    and under ``'All'``, then averaged and logged (and sent to wandb when
    ``opt.wandb`` is set). Categories without any validation sample are
    reported as skipped instead of being divided by zero.

    ``state`` is written to ``best.pth`` under ``opt.save_path`` when the
    average PSNR over all samples beats the module-level ``best_psnr``
    (which is then updated), otherwise to ``last.pth``. The PSNR of the INR
    prediction is used when ``opt.INRDecode`` is set, the LUT PSNR otherwise.
    """
    global best_psnr

    metric_names = ('MSE', 'fMSE', 'PSNR', 'SSIM')
    # Format spec used for each metric in the text summary.
    precision = {'MSE': '.2f', 'fMSE': '.2f', 'PSNR': '.2f', 'SSIM': '.4f'}

    def _new_metric_log():
        # One accumulator per dataset plus the overall one.
        return {
            name: {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}
            for name in ('HAdobe5k', 'HCOCO', 'Hday2night', 'HFlickr', 'All')
        }

    def _accumulate(log, category_name, values):
        # Add one sample's metrics to its category and to 'All'.
        for bucket in (log[category_name], log['All']):
            bucket['Samples'] += 1
            for metric_name, value in zip(metric_names, values):
                bucket[metric_name] += value

    def _average(log, key):
        # Per-sample mean of every metric; caller guarantees Samples > 0.
        samples = log[key]['Samples']
        return {metric_name: log[key][metric_name] / samples for metric_name in metric_names}

    current_process = 10
    model.eval()

    metric_log = _new_metric_log()
    lut_metric_log = _new_metric_log()

    for step, batch in enumerate(val_loader):
        composite_image = batch['composite_image'].to(opt.device)
        real_image = batch['real_image'].to(opt.device)
        mask = batch['mask'].to(opt.device)
        category = batch['category']

        fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
        bg_INR_coordinates = batch['bg_INR_coordinates'].to(opt.device)
        fg_transfer_INR_RGB = batch['fg_transfer_INR_RGB'].to(opt.device)

        with torch.no_grad():
            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image,
                mask,
                fg_INR_coordinates,
                bg_INR_coordinates)

        if opt.INRDecode:
            # The last entry is used as the INR prediction.
            pred_fg_image = fg_content_bg_appearance_construct[-1]
        else:
            pred_fg_image = None
        fg_transfer_INR_RGB = misc.lin2img(fg_transfer_INR_RGB,
                                           val_loader.dataset.INR_dataset.size) if fg_transfer_INR_RGB is not None else None

        # For INR
        mask_INR = torchvision.transforms.Resize(opt.INR_input_size)(mask)

        # Keep the prediction inside the foreground and the real image elsewhere.
        fg_region = mask > 100 / 255.
        if not opt.INRDecode:
            pred_harmonized_image = None
        else:
            pred_harmonized_image = pred_fg_image * fg_region + real_image * (~fg_region)
        lut_transform_image = lut_transform_image * fg_region + real_image * (~fg_region)

        # Save the output images. For every 10 epochs, save more results, otherwise, save little. Thus save storage.
        if state['last_epoch'] % 10 == 0:
            misc.visualize(real_image, composite_image, mask, pred_fg_image,
                           pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False,
                           wandb=opt.wandb, isAll=True, step=step)
        elif step == 0:
            misc.visualize(real_image, composite_image, mask, pred_fg_image,
                           pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False,
                           wandb=opt.wandb, step=step)

        if opt.INRDecode:
            mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
                                                         misc.normalize(fg_transfer_INR_RGB, opt, 'inv'), mask_INR)

        lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'),
                                                                     misc.normalize(real_image, opt, 'inv'), mask)

        for idx in range(len(category)):
            if opt.INRDecode:
                _accumulate(metric_log, category[idx], (mse[idx], fmse[idx], psnr[idx], ssim[idx]))
            _accumulate(lut_metric_log, category[idx],
                        (lut_mse[idx], lut_fmse[idx], lut_psnr[idx], lut_ssim[idx]))

        if (step + 1) / len(val_loader) * 100 >= current_process:
            logger.info(f'Processing: {current_process}')
            current_process += 10

    if opt.wandb:
        # Optional dependency: imported only when logging to wandb is requested.
        import wandb

    logger.info('=========================')
    for key in metric_log.keys():
        if lut_metric_log[key]['Samples'] == 0:
            # Nothing was accumulated for this dataset; averaging would divide by zero.
            logger.info(f"{key}: no validation samples, metrics skipped")
            continue

        # (label shown in logs, metric name, averaged value)
        rows = []
        if opt.INRDecode:
            averages = _average(metric_log, key)
            rows.extend((name, name, averages[name]) for name in metric_names)
        lut_averages = _average(lut_metric_log, key)
        rows.extend((f'LUT_{name}', name, lut_averages[name]) for name in metric_names)

        msg = ''.join(f"{key}-'{label}': {format(value, precision[name])}\n" for label, name, value in rows)
        logger.info(msg)

        if opt.wandb:
            payload = {f'Val/{key}/Epoch': state['last_epoch']}
            payload.update({f'Val/{key}/{label}': value for label, _, value in rows})
            wandb.log(payload)

    logger.info('=========================')

    primary = (metric_log if opt.INRDecode else lut_metric_log)['All']
    average_psnr = primary['PSNR'] / primary['Samples'] if primary['Samples'] else None
    if average_psnr is not None and average_psnr > best_psnr:
        logger.info("Best Save!")
        best_psnr = average_psnr
        torch.save(state, os.path.join(opt.save_path, "best.pth"))
    else:
        logger.info("Last Save!")
        torch.save(state, os.path.join(opt.save_path, "last.pth"))
import os
import shutil

# root = r"G:\Datasets\Images Harmonization\LR_real_composite_images_99_DIH\all"
#
# all_path = os.listdir(os.path.join(root, "image"))
#
# with open(os.path.join(root, "dataset.txt"), mode='w') as f:
#     for im in all_path:
#         f.write(os.path.join('image', im) + "\n")
#
# print("Done!")


# Re-order dataset: copy every saved result into "reorder", prefixed with the
# name of the dataset entry whose index is the leading number of the file name.
with open(r"G:\Datasets\Images Harmonization\iHarmony4\IHD_test.txt", mode="r") as f:
    # Keep only the bare file name (no folder, nothing after the first dot).
    names = [line.strip().split("/")[-1].split(".")[0] for line in f.readlines()]

root = r"G:\ComputerPrograms\Image_Harmonization\Supervised_Harmonization\logs\HINet_2048×2048_HAdobe5k\figs\-1"
os.makedirs(os.path.join(root, "reorder"), exist_ok=True)
allFiles = os.listdir(root)
for file in allFiles:
    if "pred_harmonized_image" not in file:
        continue

    name = names[int(file.split("_")[0])] + "_" + file
    source = os.path.join(root, file)
    target = os.path.join(root, "reorder", name)
    # The prediction itself first, then its companion mask / real / composite files,
    # which share the file name up to the "pred_harmonized_image" tag.
    for kind in ("pred_harmonized_image", "mask", "real", "composite"):
        shutil.copy(source.replace("pred_harmonized_image", kind),
                    target.replace("pred_harmonized_image", kind))

print("Done!")
import shutil
from pathlib import Path

import cv2
from tqdm import tqdm

# Longest side (in pixels) allowed in the resized copy of the dataset.
max_size = 1024
input_dataset_path = r'.\iHarmony4\HAdobe5k'
output_path = f'{input_dataset_path}_resized{max_size}'

input_dataset_path = Path(input_dataset_path)
output_path = Path(output_path)

# Explicit check instead of `assert`, which is stripped under `python -O`.
if output_path.exists():
    raise FileExistsError(f'Output folder already exists: {output_path}')

output_path.mkdir()
for subfolder in ['composite_images', 'masks', 'real_images']:
    (output_path / subfolder).mkdir()

# Annotation lists are copied unchanged.
for annotation_path in input_dataset_path.glob('*.txt'):
    shutil.copy(annotation_path, output_path / annotation_path.name)

images_list = sorted(input_dataset_path.rglob('*.jpg'))
images_list.extend(sorted(input_dataset_path.rglob('*.png')))

for x in tqdm(images_list):
    image = cv2.imread(str(x), cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f'Could not read image: {x}')

    new_path = output_path / x.relative_to(input_dataset_path)
    # Images found outside the three standard subfolders still need a destination folder.
    new_path.parent.mkdir(parents=True, exist_ok=True)

    height, width = image.shape[:2]
    if max(height, width) <= max_size:
        # Already small enough: keep the original bytes.
        shutil.copy(x, new_path)
        continue

    # Scale so that the longest side becomes `max_size`, keeping the aspect ratio.
    new_width = max_size
    new_height = max_size
    scale = max_size / max(height, width)
    if height > width:
        new_width = int(round(scale * width))
    else:
        new_height = int(round(scale * height))

    image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
    if x.suffix == '.jpg':
        cv2.imwrite(str(new_path), image, [cv2.IMWRITE_JPEG_QUALITY, 90])
    else:
        cv2.imwrite(str(new_path), image)
def parse_args():
    """Parse the training command line and prepare the run folder.

    Two attributes are derived after parsing: ``save_path`` is passed through
    ``misc.increment_path`` (starting from ``<save_path>/exp1``), and ``wandb``
    records whether Weights & Biases could be imported and initialised.

    Returns:
        argparse.Namespace: The parsed options plus the derived attributes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--workers', type=int, default=8,
                        metavar='N', help='Dataloader threads.')

    parser.add_argument('--batch_size', type=int, default=16,
                        help='You can override model batch size by specify positive number.')

    parser.add_argument('--device', type=str, default='cuda',
                        help="Whether use cuda, 'cuda' or 'cpu'.")

    parser.add_argument('--epochs', type=int, default=60,
                        help='Epochs number.')

    # Must be a float: with type=int any value given on the command line (e.g. 1e-4) is rejected.
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate.')

    parser.add_argument('--save_path', type=str, default="./logs",
                        help='Where to save logs and checkpoints.')

    parser.add_argument('--dataset_path', type=str, default=r".\iHarmony4",
                        help='Dataset path.')

    parser.add_argument('--print_freq', type=int, default=100,
                        help='Number of iterations then print.')

    parser.add_argument('--base_size', type=int, default=256,
                        help='Base size. Resolution of the image input into the Encoder')

    parser.add_argument('--input_size', type=int, default=256,
                        help='Input size. Resolution of the image that want to be generated by the Decoder')

    parser.add_argument('--INR_input_size', type=int, default=256,
                        help='INR input size. Resolution of the image that want to be generated by the Decoder. '
                             'Should be the same as `input_size`')

    parser.add_argument('--INR_MLP_dim', type=int, default=32,
                        help='Number of channels for INR linear layer.')

    parser.add_argument('--LUT_dim', type=int, default=7,
                        help='Dim of the output LUT. Refer to https://ieeexplore.ieee.org/abstract/document/9206076')

    parser.add_argument('--activation', type=str, default='leakyrelu_pe',
                        help='INR activation layer type: leakyrelu_pe, sine')

    parser.add_argument('--pretrained', type=str,
                        default=None,
                        help='Pretrained weight path')

    parser.add_argument('--param_factorize_dim', type=int,
                        default=10,
                        help='The intermediate dimensions of the factorization of the predicted MLP parameters. '
                             'Refer to https://arxiv.org/abs/2011.12026')

    parser.add_argument('--embedding_type', type=str,
                        default="CIPS_embed",
                        help='Which embedding_type to use.')

    parser.add_argument('--optim', type=str,
                        default='adamw',
                        help='Which optimizer to use.')

    # store_false: these two options default to True and passing the flag turns them off.
    parser.add_argument('--INRDecode', action="store_false",
                        help='Whether INR decoder. Set it to False if you want to test the baseline '
                             '(https://github.com/SamsungLabs/image_harmonization)')

    parser.add_argument('--isMoreINRInput', action="store_false",
                        help='Whether to cat RGB and mask. See Section 3.4 in the paper.')

    parser.add_argument('--hr_train', action="store_true",
                        help='Whether use hr_train. See section 3.4 in the paper.')

    parser.add_argument('--isFullRes', action="store_true",
                        help='Whether for original resolution. See section 3.4 in the paper.')

    opt = parser.parse_args()

    opt.save_path = misc.increment_path(os.path.join(opt.save_path, "exp1"))

    # wandb is optional: any failure to import or initialise it just disables the logging.
    # `except Exception` (not a bare except) so Ctrl-C and SystemExit still propagate.
    try:
        import wandb
        opt.wandb = True
        wandb.init(config=opt, project="INR_Harmonization", name=os.path.basename(opt.save_path))
    except Exception:
        opt.wandb = False

    return opt
def loss_generator(ignore: list = None):
    """Return the mapping from loss name to loss function.

    Args:
        ignore: Optional names of losses to leave out of the returned mapping.
            Names that are not registered are skipped.

    Returns:
        dict: ``{name: callable}`` for every loss that was not ignored.
    """
    loss_fn = {'mse': mse,
               'lut_mse': lut_mse,
               'masked_mse': masked_mse,
               'sample_weighted_mse': sample_weighted_mse,
               'regularize_LUT': regularize_LUT,
               'MaskWeightedMSE': MaskWeightedMSE}

    # Remove from the mapping, not from `ignore` itself (list.pop with a string raised TypeError).
    for fn in ignore or ():
        loss_fn.pop(fn, None)

    return loss_fn


def _non_batch_dims(tensor):
    """Return every dimension index of `tensor` except the first (batch) one."""
    return list(range(1, tensor.dim()))


def mse(pred, gt):
    """Mean squared error over all elements."""
    return torch.mean((pred - gt) ** 2)


def masked_mse(pred, gt, mask):
    """Squared error restricted to the mask, normalised per sample by the mask area.

    The error is summed where ``mask > 100 / 255`` and divided by the mask sum of
    that sample, clamped from below at 100; the per-sample values are then averaged.
    """
    dims = _non_batch_dims(mask)
    # Follow the prediction's device instead of forcing CUDA, so CPU runs work too.
    delimin = torch.clamp_min(torch.sum(mask, dim=dims), 100).to(pred.device)
    out = torch.sum((mask > 100 / 255.) * (pred - gt) ** 2, dim=dims)
    out = out / delimin
    return torch.mean(out)


def sample_weighted_mse(pred, gt, mask):
    """Per-sample MSE weighted by each sample's share of the (clamped) mask area.

    The weights are the per-sample mask sums, clamped from below at 100 and
    normalised to sum to one over the batch.
    """
    dims = _non_batch_dims(mask)
    # Follow the prediction's device instead of forcing CUDA, so CPU runs work too.
    multi_factor = torch.clamp_min(torch.sum(mask, dim=dims), 100).to(pred.device)
    multi_factor = multi_factor / (multi_factor.sum())
    out = torch.mean((pred - gt) ** 2, dim=dims)
    out = out * multi_factor
    return torch.sum(out)


def regularize_LUT(lut):
    """Penalise LUT entries outside [0, 1] by their mean squared overshoot.

    Returns the sum of the penalty for entries below 0 and the one for entries
    above 1; a side with no offending entry contributes the plain number 0.
    """
    st = lut[lut < 0.]
    reg_st = (st ** 2).mean() if st.numel() != 0 else 0

    lt = lut[lut > 1.]
    reg_lt = ((lt - 1.) ** 2).mean() if lt.numel() != 0 else 0

    return reg_lt + reg_st


def lut_mse(feat, lut_batch):
    """Sum of pairwise MSEs inside each consecutive group of `lut_batch` rows.

    Every ordered pair of rows within a group is compared (self-pairs included);
    the accumulated value is divided by `lut_batch`. Rows beyond the last full
    group are not used.
    """
    loss = 0
    for group in range(feat.shape[0] // lut_batch):
        rows = feat[group * lut_batch: group * lut_batch + lut_batch]
        for i in rows:
            for j in rows:
                loss += mse(i, j)

    return loss / lut_batch


def MaskWeightedMSE(pred, label, mask):
    """Per-sample squared error normalised by channels times the clamped mask area.

    The squared error is summed over every non-batch dimension and divided by
    ``pred.size(1)`` times the mask sum (clamped from below at 100); the result
    is averaged over the batch.
    """
    label = label.view(pred.size())
    reduce_dims = get_dims_with_exclusion(label.dim(), 0)

    loss = (pred - label) ** 2
    delimeter = pred.size(1) * torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100)
    loss = torch.sum(loss, dim=reduce_dims) / delimeter

    return torch.mean(loss)


def get_dims_with_exclusion(dim, exclude=None):
    """Return ``[0, ..., dim - 1]`` without `exclude` (when one is given)."""
    dims = list(range(dim))
    if exclude is not None:
        dims.remove(exclude)

    return dims
def SSIM(pred, target_image, mask):
    """SSIM between the target and the prediction pasted onto the target inside the mask."""
    blended = pred * mask + (target_image) * (1 - mask)
    batched_pred = blended.unsqueeze(0)
    batched_target = target_image.unsqueeze(0)
    return ssim(batched_pred, batched_target)


def MSE(pred, target_image, mask):
    """Mean over all elements of the mask-weighted squared error, as a Python float."""
    squared_error = (pred - target_image) ** 2
    weighted = mask * squared_error
    return weighted.mean().item()


def fMSE(pred, target_image, mask):
    """Mask-weighted squared error divided by (first-dimension size x mask sum), as a Python float."""
    diff = mask * ((pred - target_image) ** 2)
    # 1e-6 keeps the division defined when the mask is empty.
    denominator = diff.size(0) * mask.sum() + 1e-6
    return (diff.sum() / denominator).item()


def PSNR(pred, target_image, mask):
    """PSNR in dB from the mask-weighted MSE, with the target's maximum as peak value."""
    mse = (mask * (pred - target_image) ** 2).mean().item()
    peak = target_image.max().item()
    squared_max = peak ** 2

    # 1e-6 avoids a division by zero for identical images.
    ratio = squared_max / (mse + 1e-6)
    return 10 * math.log10(ratio)
_logger = None


def increment_path(path):
    """Return `path` if it is free, otherwise the same path with the next unused run number.

    Example: with ``runs/exp1`` and ``runs/exp2`` on disk,
    ``increment_path("runs/exp1")`` gives ``runs/exp3``.

    Args:
        path (str): Candidate path whose last component carries the initial
            run number, e.g. ``logs/exp1``.

    Returns:
        str: `path` itself when nothing exists there, else the numbered sibling
        one above the highest number already present.

    Raises:
        SystemExit: If the last component of `path` contains no number.
    """
    separators = os.sep + (os.altsep or "")
    trimmed = path.rstrip(separators)
    leaf = os.path.basename(trimmed)
    # Look for the number in the last component only: digits in a parent folder
    # (e.g. "/home/user1/logs/exp1") must not be mistaken for the run number.
    res = re.search(r"\d+", leaf)
    if res is None:
        print("Set initial exp number!")
        raise SystemExit(1)

    if not Path(path).exists():
        return str(path)

    leaf_prefix = leaf[:res.start()]
    prefix = trimmed[:len(trimmed) - len(leaf) + res.start()]
    # Escape the prefix for both glob and regex so special characters are taken literally.
    pattern = re.compile(re.escape(leaf_prefix) + r"(\d+)")
    dirs = glob.glob(f"{glob.escape(prefix)}*")  # similar paths
    matches = (pattern.match(os.path.basename(d)) for d in dirs)
    indices = [int(m.group(1)) for m in matches if m]
    return f"{prefix}{max(indices) + 1}"


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, fmt=':f'):
        # Kept for interface compatibility; nothing in this class reads it.
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, counted `n` times in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def create_logger(log_file, level=logging.INFO):
    """Configure the root logger to write to `log_file` and to the console.

    Note that every call adds a new pair of handlers to the root logger.

    Returns:
        logging.Logger: The root logger, also stored in the module-level `_logger`.
    """
    global _logger
    _logger = logging.getLogger()
    formatter = logging.Formatter(
        '[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s')
    fh = logging.FileHandler(log_file)
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    _logger.setLevel(level)
    _logger.addHandler(fh)
    _logger.addHandler(sh)

    return _logger


def get_mgrid(sidelen, dim=2):
    """Generate a flattened grid of (x, y, ...) coordinates in the range -1 to 1.

    Args:
        sidelen: Side length shared by all axes (int) or one length per axis.
        dim (int): Number of axes, 2 or 3.

    Returns:
        torch.Tensor: Tensor of shape ``(prod(sidelen), dim)``. An axis of
        length 1 is placed at -1.

    Raises:
        NotImplementedError: If `dim` is neither 2 nor 3.
    """
    if isinstance(sidelen, int):
        sidelen = dim * (sidelen,)

    if dim not in (2, 3):
        raise NotImplementedError('Not implemented for dim=%d' % dim)

    axes = tuple(slice(0, sidelen[axis]) for axis in range(dim))
    pixel_coords = np.stack(np.mgrid[axes], axis=-1)[None, ...].astype(np.float32)
    for axis in range(dim):
        # max(..., 1) avoids 0/0 (NaN) for an axis of length 1.
        pixel_coords[..., axis] = pixel_coords[..., axis] / max(sidelen[axis] - 1, 1)

    # Map [0, 1] to [-1, 1].
    pixel_coords -= 0.5
    pixel_coords *= 2.
    pixel_coords = torch.Tensor(pixel_coords).view(-1, dim)
    return pixel_coords


def lin2img(tensor, image_resolution=None):
    """Reshape ``(batch, samples, channels)`` into ``(batch, channels, height, width)``.

    Args:
        tensor: Three-dimensional tensor.
        image_resolution: ``(height, width)`` or a single int for a square
            image; when omitted a square of side ``sqrt(samples)`` is used.
    """
    batch_size, num_samples, channels = tensor.shape
    if image_resolution is None:
        width = np.sqrt(num_samples).astype(int)
        height = width
    else:
        if isinstance(image_resolution, int):
            image_resolution = (image_resolution, image_resolution)
        height = image_resolution[0]
        width = image_resolution[1]

    return tensor.permute(0, 2, 1).contiguous().view(batch_size, channels, height, width)


def normalize(x, opt, mode='normal'):
    """Apply or undo the per-channel normalisation described by `opt`.

    Args:
        x: Tensor with channels on dimension 1 (the statistics are broadcast as ``(1, C, 1, 1)``).
        opt: Object with ``transform_mean`` and ``transform_var`` sequences.
        mode: ``'normal'`` for ``(x - mean) / var``, ``'inv'`` for ``x * var + mean``.

    Raises:
        ValueError: If `mode` is neither ``'normal'`` nor ``'inv'``.
    """
    mean = torch.as_tensor(opt.transform_mean, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    var = torch.as_tensor(opt.transform_var, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    if mode == 'normal':
        return (x - mean) / var
    if mode == 'inv':
        return x * var + mean
    # Used to fall through and return None, which only failed later at the call site.
    raise ValueError("mode must be 'normal' or 'inv', got %r" % (mode,))
def _to_uint8_image(tensor, to_bgr=True):
    """Convert one channel-first float image in [0, 1] to a channel-last uint8 array (BGR on request)."""
    # Out-of-place ops on purpose: in-place mul_/clamp_ after .cpu() would modify the
    # caller's tensor whenever it already lives on the CPU (.cpu() returns it as is).
    array = (tensor.detach().permute(1, 2, 0).cpu() * 255.).clamp(0., 255.).numpy().astype(np.uint8)
    if to_bgr:
        array = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
    return array


def visualize(real, composite, mask, pred_fg, pred_harmonized, lut_transform_image, opt, epoch,
              show=False, wandb=True, isAll=False, step=None):
    """Save (and optionally display / upload) the first sample of a batch as JPEG files.

    Files go to ``<opt.save_path>/figs/<epoch>``. With ``isAll`` they are named
    ``<step>_<index>_<kind>.jpg``, otherwise ``<kind>_<step>_<index>.jpg``. The
    ``pred_fg`` and ``pred_harmonized`` outputs are only handled when
    ``opt.INRDecode`` is set.

    Args:
        real, composite, pred_fg, pred_harmonized, lut_transform_image: Image
            batches that are de-normalised with ``normalize(..., 'inv')``
            (assumes batch-first, channel-first layout - TODO confirm with callers).
        mask: Mask batch, saved without de-normalisation.
        opt: Options; ``save_path``, ``INRDecode`` and the normalisation statistics are read.
        epoch: Used as sub-folder name and as wandb caption.
        show (bool): Display the images with OpenCV and wait for a key press.
        wandb (bool): Upload the images to Weights & Biases when ``step == 0``.
        isAll (bool): Selects the file naming scheme.
        step: Step number used in the file names.
    """
    save_path = os.path.join(opt.save_path, "figs", str(epoch))
    os.makedirs(save_path, exist_ok=True)

    # Only the first image of each batch is saved, with or without `isAll`.
    # Use `len(real)` here instead if you want to save all the results.
    final_index = 1

    for idx in range(final_index):
        if show:
            # Predictions that are absent (None) are skipped instead of crashing the preview.
            for title, batch in (("pred_fg", pred_fg), ("real", real),
                                 ("lut_transform", lut_transform_image), ("composite", composite)):
                if batch is not None:
                    cv2.imshow(title, normalize(batch, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.imshow("mask", mask[idx].permute(1, 2, 0).cpu().numpy())
            if pred_harmonized is not None:
                cv2.imshow("pred_harmonized_image",
                           normalize(pred_harmonized, opt, 'inv')[idx].permute(1, 2, 0).cpu().numpy())
            cv2.waitKey()

        real_tmp = _to_uint8_image(normalize(real, opt, 'inv')[idx])
        composite_tmp = _to_uint8_image(normalize(composite, opt, 'inv')[idx])
        lut_transform_image_tmp = _to_uint8_image(normalize(lut_transform_image, opt, 'inv')[idx])
        mask_tmp = _to_uint8_image(mask[idx], to_bgr=False)
        if opt.INRDecode:
            pred_fg_tmp = _to_uint8_image(normalize(pred_fg, opt, 'inv')[idx])
            pred_harmonized_tmp = _to_uint8_image(normalize(pred_harmonized, opt, 'inv')[idx])

        # (file name, image) pairs, in the order in which they are written.
        if isAll:
            outputs = [(f"{step}_{idx}_composite.jpg", composite_tmp),
                       (f"{step}_{idx}_real.jpg", real_tmp)]
            if opt.INRDecode:
                outputs.append((f"{step}_{idx}_pred_harmonized_image.jpg", pred_harmonized_tmp))
            outputs.append((f"{step}_{idx}_lut_transform_image.jpg", lut_transform_image_tmp))
            outputs.append((f"{step}_{idx}_mask.jpg", mask_tmp))
        else:
            outputs = []
            if opt.INRDecode:
                outputs.append((f"pred_fg_{step}_{idx}.jpg", pred_fg_tmp))
            outputs.append((f"real_{step}_{idx}.jpg", real_tmp))
            outputs.append((f"composite_{step}_{idx}.jpg", composite_tmp))
            outputs.append((f"mask_{step}_{idx}.jpg", mask_tmp))
            if opt.INRDecode:
                outputs.append((f"pred_harmonized_image_{step}_{idx}.jpg", pred_harmonized_tmp))
            outputs.append((f"lut_transform_image_{step}_{idx}.jpg", lut_transform_image_tmp))

        for filename, image in outputs:
            cv2.imwrite(os.path.join(save_path, filename), image)

        # Only upload the first image of step 0, to save storage.
        if wandb and idx == 0 and step == 0:
            # Imported under another name so the boolean `wandb` parameter is not shadowed.
            import wandb as wandb_lib
            payload = {}
            if opt.INRDecode:
                payload["pic/pred_fg"] = wandb_lib.Image(pred_fg_tmp, caption=epoch)
            payload["pic/real"] = wandb_lib.Image(real_tmp, caption=epoch)
            payload["pic/composite"] = wandb_lib.Image(composite_tmp, caption=epoch)
            payload["pic/mask"] = wandb_lib.Image(mask_tmp, caption=epoch)
            payload["pic/lut_trans"] = wandb_lib.Image(lut_transform_image_tmp, caption=epoch)
            if opt.INRDecode:
                payload["pic/pred_harmonized"] = wandb_lib.Image(pred_harmonized_tmp, caption=epoch)
            payload["pic/epoch"] = epoch
            wandb_lib.log(payload)
            # Kept from the original; presumably closes the current wandb step - TODO confirm.
            wandb_lib.log({})


def get_optimizer(model, opt_name, opt_kwargs):
    """Build an optimizer with one parameter group per model parameter.

    A trainable parameter carrying an ``lr_mult`` attribute different from 1
    gets ``opt_kwargs['lr'] * lr_mult`` as its own learning rate.

    Args:
        model: Module whose ``named_parameters()`` are optimised.
        opt_name (str): ``'sgd'``, ``'adam'``, ``'adamw'`` or ``'adamp'`` (case-insensitive).
        opt_kwargs (dict): Keyword arguments for the optimizer; must contain ``'lr'``.

    Raises:
        KeyError: If `opt_name` is not one of the supported names.
    """
    params = []
    base_lr = opt_kwargs['lr']
    for name, param in model.named_parameters():
        param_group = {'params': [param]}
        # Frozen parameters keep the default learning rate.
        if param.requires_grad and not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
            param_group['lr'] = base_lr * param.lr_mult
        params.append(param_group)

    key = opt_name.lower()
    if key == 'adamp':
        # Resolved only when requested, so the other optimizers do not depend on it.
        optimizer_cls = AdamP
    else:
        optimizer_cls = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'adamw': torch.optim.AdamW,
        }[key]

    return optimizer_cls(params, **opt_kwargs)
class LRMult(object):
    """Callable that tags a module's weight and bias with a learning-rate multiplier."""

    def __init__(self, lr_mult=1.):
        self.lr_mult = lr_mult

    def __call__(self, m):
        """Set ``lr_mult`` on ``m.weight`` and ``m.bias`` when they exist and are not None."""
        if getattr(m, 'weight', None) is not None:
            m.weight.lr_mult = self.lr_mult
        if getattr(m, 'bias', None) is not None:
            m.bias.lr_mult = self.lr_mult


def customRandomCrop(objects, crop_height, crop_width, h_start=None, w_start=None):
    """Crop one image, or several images at the same location.

    Args:
        objects: A single array, or a list/tuple of arrays cropped identically.
        crop_height, crop_width: Size of the crop.
        h_start, w_start: Relative offsets in [0, 1); drawn at random when None.

    Returns:
        tuple: ``(cropped, h_start, w_start)``, where ``cropped`` is a list when
        a list/tuple was given and a single array otherwise.
    """
    if h_start is None:
        h_start = random.random()
    if w_start is None:
        w_start = random.random()

    # Tuples are accepted too; they used to be treated as a single image and failed on `.shape`.
    if isinstance(objects, (list, tuple)):
        out = [random_crop(obj, crop_height, crop_width, h_start, w_start) for obj in objects]
    else:
        out = random_crop(objects, crop_height, crop_width, h_start, w_start)

    return out, h_start, w_start


def get_random_crop_coords(height: int, width: int, crop_height: int, crop_width: int, h_start: float,
                           w_start: float):
    """Return ``(x1, y1, x2, y2)`` of the crop for the given relative offsets."""
    y1 = int((height - crop_height) * h_start)
    y2 = y1 + crop_height
    x1 = int((width - crop_width) * w_start)
    x2 = x1 + crop_width
    return x1, y1, x2, y2


def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: float, w_start: float):
    """Crop `img` (height and width on its first two axes) at the given relative offsets.

    Raises:
        ValueError: If the requested crop is larger than the image.
    """
    height, width = img.shape[:2]
    if height < crop_height or width < crop_width:
        raise ValueError(
            "Requested crop size ({crop_height}, {crop_width}) is "
            "larger than the image size ({height}, {width})".format(
                crop_height=crop_height, crop_width=crop_width, height=height, width=width
            )
        )
    x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start)
    return img[y1:y2, x1:x2]


class PadToDivisor:
    """Zero-pad the last two dimensions up to a multiple of `divisor`, and undo it."""

    def __init__(self, divisor):
        super().__init__()
        self.divisor = divisor
        # Set by transform(); initialised here so inv_transform() can detect a missing call.
        self._pads = None
        self.pad_operation = None

    def transform(self, images):
        """Pad every tensor in `images` with the padding computed from the first one."""
        # Order (left, right, top, bottom), as expected by nn.ZeroPad2d.
        self._pads = (*self._get_dim_padding(images[0].shape[-1]), *self._get_dim_padding(images[0].shape[-2]))
        self.pad_operation = nn.ZeroPad2d(padding=self._pads)

        return [self.pad_operation(im) for im in images]

    def inv_transform(self, image):
        """Remove the padding added by the last transform() call from every tensor in `image`.

        Raises:
            AssertionError: If transform() has not been called yet.
        """
        # Raised explicitly so the check also survives `python -O`.
        if self._pads is None:
            raise AssertionError(
                'Something went wrong, inv_transform(...) should be called after transform(...)')
        return self._remove_padding(image)

    def _get_dim_padding(self, dim_size):
        """Split the padding needed for `dim_size` into (before, after); the extra pixel goes after."""
        pad = (self.divisor - dim_size % self.divisor) % self.divisor
        pad_upper = pad // 2
        pad_lower = pad - pad_upper

        return pad_upper, pad_lower

    def _remove_padding(self, tensors):
        """Slice the stored padding off the last two dimensions of every tensor."""
        tensor_h, tensor_w = tensors[0].shape[-2:]
        left, right, top, bottom = self._pads
        return [t[..., top:tensor_h - bottom, left:tensor_w - right] for t in tensors]