Enes Bol committed on
Commit fd4bbc8 · 1 Parent(s): 5f9fb3b
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +201 -0
  2. __pycache__/config.cpython-39.pyc +0 -0
  3. __pycache__/dataloader.cpython-39.pyc +0 -0
  4. __pycache__/inference.cpython-39.pyc +0 -0
  5. __pycache__/trainer.cpython-39.pyc +0 -0
  6. app.py +207 -0
  7. config.py +47 -0
  8. data/custom_dataset/images.jpg +0 -0
  9. data/processed/110000026240767.jpg +0 -0
  10. data/processed/200630_colekt_pack0937__la_chambre__bottle_50ml__final__16x9-copy-scaled.jpg +0 -0
  11. data/processed/images (1).jpg +0 -0
  12. data/processed/images.jpg +0 -0
  13. data/processed/indir (1).jpg +0 -0
  14. data/processed/photo-1541643600914-78b084683601.jpg +0 -0
  15. data/processed/smart-collection-512-narciso-rodriguez-black-600x800-0.jpg +0 -0
  16. dataloader.py +147 -0
  17. demo_run.sh +11 -0
  18. edge_generator.py +38 -0
  19. img/Poster.png +0 -0
  20. inference.py +89 -0
  21. main.py +55 -0
  22. mask/custom_dataset/110000026240767.png +0 -0
  23. mask/custom_dataset/200630_colekt_pack0937__la_chambre__bottle_50ml__final__16x9-copy-scaled.png +0 -0
  24. mask/custom_dataset/TOMBUL.png +0 -0
  25. mask/custom_dataset/images (1).png +0 -0
  26. mask/custom_dataset/images.png +0 -0
  27. mask/custom_dataset/indir (1).png +0 -0
  28. mask/custom_dataset/indir.png +0 -0
  29. mask/custom_dataset/photo-1541643600914-78b084683601.png +0 -0
  30. model/EfficientNet.py +356 -0
  31. model/TRACER.py +58 -0
  32. model/__pycache__/EfficientNet.cpython-39.pyc +0 -0
  33. model/__pycache__/TRACER.cpython-39.pyc +0 -0
  34. modules/__pycache__/att_modules.cpython-39.pyc +0 -0
  35. modules/__pycache__/conv_modules.cpython-39.pyc +0 -0
  36. modules/att_modules.py +297 -0
  37. modules/conv_modules.py +56 -0
  38. object/custom_dataset/110000026240767.png +0 -0
  39. object/custom_dataset/200630_colekt_pack0937__la_chambre__bottle_50ml__final__16x9-copy-scaled.png +0 -0
  40. object/custom_dataset/43a71f50-b839-4ab8-8be8-5be346ffe8be.png +0 -0
  41. object/custom_dataset/TOMBUL.png +0 -0
  42. object/custom_dataset/images (1).png +0 -0
  43. object/custom_dataset/images.png +0 -0
  44. object/custom_dataset/indir (1).png +0 -0
  45. object/custom_dataset/indir (2).png +0 -0
  46. object/custom_dataset/indir.png +0 -0
  47. object/custom_dataset/kabe-a-1912087.png +0 -0
  48. object/custom_dataset/photo-1541643600914-78b084683601.png +0 -0
  49. requirements.txt +28 -0
  50. trainer.py +293 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021 Min Seok (Karel) Lee
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
__pycache__/config.cpython-39.pyc ADDED
Binary file (1.85 kB).
 
__pycache__/dataloader.cpython-39.pyc ADDED
Binary file (4.51 kB).
 
__pycache__/inference.cpython-39.pyc ADDED
Binary file (3.43 kB).
 
__pycache__/trainer.cpython-39.pyc ADDED
Binary file (7.99 kB).
 
app.py ADDED
@@ -0,0 +1,207 @@
+ import streamlit as st
+ import os
+ import subprocess
+ from PIL import Image, ImageOps
+ import torch
+ from diffusers import StableDiffusionInpaintPipeline
+ import transformers
+ import matplotlib.pyplot as plt
+ import matplotlib.image as mpimg
+ import cv2
+ import diffusers
+ import accelerate
+ import warnings
+ import numpy as np
+ import os
+ import shutil
+ warnings.filterwarnings("ignore")
+
+ st.title('Background Generation')
+
+ st.write('This app generates new backgrounds for images.')
+
+ # set environment variable for dll
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+ @st.cache_data
+ def mode(width, height):
+     output_width = np.floor_divide(width, 8) * 8
+     output_height = np.floor_divide(height, 8) * 8
+     return output_width, output_height
+
+ def get_prompt():
+     prompt = st.text_input('Enter your prompt here:', placeholder="Imagine our perfume bottle amidst a lush garden, surrounded by blooming flowers and vibrant colors.")
+     return prompt
+
+ def get_negative_prompt():
+     negative_prompt = st.text_input('Enter your negative prompt here:', placeholder="low quality, out of frame, watermark.. etc.")
+     return negative_prompt
+
+ def get_user_input():
+     st.subheader("Upload an image file, then press the Clean Background button.")
+     uploaded_file = st.file_uploader("Upload a JPG image file", type=["jpg", "jpeg"])
+
+     if uploaded_file is not None:
+         user_file_path = os.path.join("data/custom_dataset/", uploaded_file.name)
+
+         # Open the uploaded image
+         uploaded_image = Image.open(uploaded_file)
+
+         # Check if the width is larger than 640
+         if uploaded_image.width > 640:
+             # Calculate the proportional height based on the desired width of 640 pixels
+             aspect_ratio = uploaded_image.width / uploaded_image.height
+             resized_height = int(640 / aspect_ratio)
+             # Resize the image to a width of 640 pixels and proportional height
+             resized_image = uploaded_image.resize((640, resized_height))
+         else:
+             resized_image = uploaded_image
+
+         return resized_image, user_file_path
+
+     return None, None
+
+
+ def clean_files(directory):
+     files = os.listdir(directory)
+     for file in files:
+         file_path = os.path.join(directory, file)
+         if os.path.isfile(file_path):
+             os.remove(file_path)
+
+ uploaded_file, user_file_path = get_user_input()  # uploaded_file is the resized PIL image (or None)
+ button_1 = st.button("Clean Background")
+
+ button_1_clicked = False  # Variable to track button state
+
+ def run_subprocess():
+     mask_created = False
+     command = "python main.py inference --dataset custom_dataset/ --arch 7 --img_size 640 --save_map True"
+     subprocess.run(command, shell=True)
+     mask_created = True
+
+
+ # Perform the necessary actions when the "Clean Background" button is clicked
+ st.write(button_1)
+
+ # Log data for analyzing the app later.
+ def log(copy=False):
+     custom_dataset_directory = "data/custom_dataset/"
+     processed_directory = "data/processed"
+     for filename in os.listdir(custom_dataset_directory):
+         file_path = os.path.join(custom_dataset_directory, filename)
+
+         if copy:
+             shutil.copy(file_path, processed_directory)  # Copy files
+         else:
+             shutil.move(file_path, processed_directory)  # Move files
+
+
+ def load_images():
+     x = user_file_path.split('/')[-1]
+     uploaded_file_name = os.path.basename(user_file_path)
+     image_path = os.path.join("data/custom_dataset/", x)
+     dif_image = Image.open(image_path)
+
+     mask_path = os.path.join("mask/custom_dataset/", x.replace('.jpg', '.png'))
+     png_image = Image.open(mask_path)
+     inverted_image = ImageOps.invert(png_image)
+     return dif_image, inverted_image
+
+ if button_1:
+     button_1_clicked = True
+     # Copy items from data/custom_dataset/ to data/processed for logging
+     log(copy=True)
+     clean_files("data/custom_dataset/")
+     if uploaded_file is not None:
+         uploaded_file.save(user_file_path)
+         run_subprocess()
+         st.success("Background cleaned.")
+         log(copy=True)
+         dif_image, inverted_image = load_images()
+
+
+ st.subheader("Type your prompt, choose the parameters, then press the Run Model button.")
+
+ # Create a two-column layout
+ col1, col2 = st.columns(2)
+
+ # Get user input for prompts
+ with col1:
+     input_prompt = st.text_area('Enter Prompt', height=80)
+ with col2:
+     input_negative_prompt = st.text_area('Enter Negative Prompt', height=80)
+
+ num_inference_steps = st.slider('Number of Inference Steps:', min_value=5, max_value=50, value=10)
+ num_images_per_prompt = st.slider('Image Count to be Produced:', min_value=1, max_value=2, value=1)
+
+ # use seed with torch generator
+ torch.manual_seed(0)
+ # seed
+ seed = st.slider('Seed:', min_value=0, max_value=100, value=1)
+ generator = [torch.Generator(device="cuda").manual_seed(seed) for _ in range(num_images_per_prompt)]
+
+ # generator = torch.Generator(device="cuda").manual_seed(0)
+ run_model_button = st.button("Run Model")
+
+ @st.cache_resource
+ def initialize_pipe():
+     pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",
+                                                           revision="fp16",
+                                                           torch_dtype=torch.float16,
+                                                           safety_checker=None,
+                                                           requires_safety_checker=False).to("cuda")
+
+     pipe.safety_checker = None
+     pipe.requires_safety_checker = False
+     return pipe
+
+ def image_resize(dif_image):
+     output_width, output_height = mode(dif_image.width, dif_image.height)
+     while output_height > 800:
+         output_height = output_height // 1.5
+         output_width = output_width // 1.5
+     output_width, output_height = mode(output_width, output_height)
+     return output_width, output_height
+
+
+ def show_output(x5):
+     if len(x5) == 1:
+         col1, col2 = st.columns(2)
+         with col1:
+             st.image(inverted_image, width=256, caption='Generated Mask', use_column_width=False)
+         with col2:
+             st.image(x5[0], width=256, caption='Generated Image', use_column_width=False)
+
+     elif len(x5) == 2:
+         col1, col2, col3 = st.columns(3)
+         with col1:
+             col1.image(inverted_image, width=256, caption='Generated Mask', use_column_width=False)
+         with col2:
+             col2.image(x5[0], width=256, caption='Generated Image', use_column_width=False)
+         with col3:
+             col3.image(x5[1], width=256, caption='Generated Image-2', use_column_width=False)
+
+ # Check if the button is clicked and all inputs are provided
+ if run_model_button and input_prompt:
+     st.write("Running the model...")
+     dif_image, inverted_image = load_images()
+     output_width, output_height = image_resize(dif_image)
+     base_prompt = "high resolution, high quality, use mask. Do not distort the shape of the object. make the object stand out, show it clearly and vividly, preserving the shape of the object, use the mask"
+     prompt = input_prompt + " " + base_prompt
+
+     st.write("Pipe working with {0} inference steps; {1} image(s) will be created for the prompt".format(num_inference_steps, num_images_per_prompt))
+
+     pipe = initialize_pipe()
+
+     output_height = 128  # NOTE: hard-coded override of the size computed by image_resize()
+     output_width = 128
+
+     x5 = pipe(image=dif_image, mask_image=inverted_image, num_inference_steps=num_inference_steps, generator=generator,
+               num_images_per_prompt=num_images_per_prompt, prompt=prompt, negative_prompt=input_negative_prompt,
+               height=output_height, width=output_width).images
+
+     show_output(x5)
+     torch.cuda.empty_cache()
+ else:
+     st.write("Please provide a prompt and click the 'Run Model' button to proceed.")
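A note on the mask handling in app.py above: TRACER writes its saliency maps with the object in white on a black background, while StableDiffusionInpaintPipeline repaints the white region of mask_image. Inverting the mask therefore tells the pipeline to regenerate the background and leave the object untouched. A minimal sketch of that convention, using one of the sample masks committed in this change:

import os
from PIL import Image, ImageOps

# TRACER's prediction: salient object in white, background in black.
mask = Image.open(os.path.join("mask/custom_dataset/", "images.png"))

# The inpainting pipeline repaints *white* pixels of mask_image,
# so inverting makes the background the region to regenerate.
background_mask = ImageOps.invert(mask)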
config.py ADDED
@@ -0,0 +1,47 @@
+ import argparse
+
+ def getConfig():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('action', type=str, default='train', help='Model Training or Testing options')
+     parser.add_argument('--exp_num', default=0, type=str, help='experiment_number')
+     parser.add_argument('--dataset', type=str, default='DUTS', help='DUTS')
+     parser.add_argument('--data_path', type=str, default='data/')
+
+     # Model parameter settings
+     parser.add_argument('--arch', type=str, default='0', help='Backbone Architecture')
+     parser.add_argument('--channels', type=list, default=[24, 40, 112, 320])
+     parser.add_argument('--RFB_aggregated_channel', type=int, nargs='*', default=[32, 64, 128])
+     parser.add_argument('--frequency_radius', type=int, default=16, help='Frequency radius r in FFT')
+     parser.add_argument('--denoise', type=float, default=0.93, help='Denoising background ratio')
+     parser.add_argument('--gamma', type=float, default=0.1, help='Confidence ratio')
+
+     # Training parameter settings
+     parser.add_argument('--img_size', type=int, default=320)
+     parser.add_argument('--batch_size', type=int, default=32)
+     parser.add_argument('--epochs', type=int, default=100)
+     parser.add_argument('--lr', type=float, default=5e-5)
+     parser.add_argument('--optimizer', type=str, default='Adam')
+     parser.add_argument('--weight_decay', type=float, default=1e-4)
+     parser.add_argument('--criterion', type=str, default='API', help='API or bce')
+     parser.add_argument('--scheduler', type=str, default='Reduce', help='Reduce or Step')
+     parser.add_argument('--aug_ver', type=int, default=2, help='1=Normal, 2=Hard')
+     parser.add_argument('--lr_factor', type=float, default=0.1)
+     parser.add_argument('--clipping', type=float, default=2, help='Gradient clipping')
+     parser.add_argument('--patience', type=int, default=5, help="Scheduler ReduceLROnPlateau's parameter & Early Stopping(+5)")
+     parser.add_argument('--model_path', type=str, default='results/')
+     parser.add_argument('--seed', type=int, default=42)
+     parser.add_argument('--save_map', type=bool, default=None, help='Save prediction map')
+
+
+     # Hardware settings
+     parser.add_argument('--multi_gpu', type=bool, default=True)
+     parser.add_argument('--num_workers', type=int, default=4)
+     cfg = parser.parse_args()
+
+     return cfg
+
+
+ if __name__ == '__main__':
+     cfg = getConfig()
+     cfg = vars(cfg)
+     print(cfg)
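One caveat with the parser above: argparse applies type=bool by calling bool() on the raw argument string, and every non-empty string is truthy, so --save_map False (and likewise --multi_gpu False) still parses as True. The flag can effectively only be None (omitted) or True, which is consistent with inference.py testing whether args.save_map is not None rather than its value. A quick demonstration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_map', type=bool, default=None)

# bool('False') is True because any non-empty string is truthy.
print(parser.parse_args(['--save_map', 'False']).save_map)  # True
print(parser.parse_args([]).save_map)                       # None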
data/custom_dataset/images.jpg ADDED
data/processed/110000026240767.jpg ADDED
data/processed/200630_colekt_pack0937__la_chambre__bottle_50ml__final__16x9-copy-scaled.jpg ADDED
data/processed/images (1).jpg ADDED
data/processed/images.jpg ADDED
data/processed/indir (1).jpg ADDED
data/processed/photo-1541643600914-78b084683601.jpg ADDED
data/processed/smart-collection-512-narciso-rodriguez-black-600x800-0.jpg ADDED
dataloader.py ADDED
@@ -0,0 +1,147 @@
+ import cv2
+ import glob
+ import torch
+ import numpy as np
+ import albumentations as albu
+ from pathlib import Path
+ from albumentations.pytorch.transforms import ToTensorV2
+ from torch.utils.data import Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+
+
+ class DatasetGenerate(Dataset):
+     def __init__(self, img_folder, gt_folder, edge_folder, phase: str = 'train', transform=None, seed=None):
+         self.images = sorted(glob.glob(img_folder + '/*'))
+         self.gts = sorted(glob.glob(gt_folder + '/*'))
+         self.edges = sorted(glob.glob(edge_folder + '/*'))
+         self.transform = transform
+
+         train_images, val_images, train_gts, val_gts, train_edges, val_edges = train_test_split(self.images, self.gts,
+                                                                                                 self.edges,
+                                                                                                 test_size=0.05,
+                                                                                                 random_state=seed)
+         if phase == 'train':
+             self.images = train_images
+             self.gts = train_gts
+             self.edges = train_edges
+         elif phase == 'val':
+             self.images = val_images
+             self.gts = val_gts
+             self.edges = val_edges
+         else:  # Testset
+             pass
+
+     def __getitem__(self, idx):
+         image = cv2.imread(self.images[idx])
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         mask = cv2.imread(self.gts[idx])
+         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+         edge = cv2.imread(self.edges[idx])
+         edge = cv2.cvtColor(edge, cv2.COLOR_BGR2GRAY)
+
+         if self.transform is not None:
+             augmented = self.transform(image=image, masks=[mask, edge])
+             image = augmented['image']
+             mask = np.expand_dims(augmented['masks'][0], axis=0)  # (1, H, W)
+             mask = mask / 255.0
+             edge = np.expand_dims(augmented['masks'][1], axis=0)  # (1, H, W)
+             edge = edge / 255.0
+
+         return image, mask, edge
+
+     def __len__(self):
+         return len(self.images)
+
+
+ class Test_DatasetGenerate(Dataset):
+     def __init__(self, img_folder, gt_folder=None, transform=None):
+         self.images = sorted(glob.glob(img_folder + '/*'))
+         self.gts = sorted(glob.glob(gt_folder + '/*')) if gt_folder is not None else None
+         self.transform = transform
+
+     def __getitem__(self, idx):
+         image_name = Path(self.images[idx]).stem
+         image = cv2.imread(self.images[idx])
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         original_size = image.shape[:2]
+
+         if self.transform is not None:
+             augmented = self.transform(image=image)
+             image = augmented['image']
+
+         if self.gts is not None:
+             return image, self.gts[idx], original_size, image_name
+         else:
+             return image, original_size, image_name
+
+     def __len__(self):
+         return len(self.images)
+
+
+ def get_loader(img_folder, gt_folder, edge_folder, phase: str, batch_size, shuffle,
+                num_workers, transform, seed=None):
+     if phase == 'test':
+         dataset = Test_DatasetGenerate(img_folder, gt_folder, transform)
+         data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+     else:
+         dataset = DatasetGenerate(img_folder, gt_folder, edge_folder, phase, transform, seed)
+         data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+                                  drop_last=True)
+
+     print(f'{phase} length : {len(dataset)}')
+
+     return data_loader
+
+
+ def get_train_augmentation(img_size, ver):
+     if ver == 1:
+         transforms = albu.Compose([
+             albu.Resize(img_size, img_size, always_apply=True),
+             albu.Normalize([0.485, 0.456, 0.406],
+                            [0.229, 0.224, 0.225]),
+             ToTensorV2(),
+         ])
+     if ver == 2:
+         transforms = albu.Compose([
+             albu.OneOf([
+                 albu.HorizontalFlip(),
+                 albu.VerticalFlip(),
+                 albu.RandomRotate90()
+             ], p=0.5),
+             albu.OneOf([
+                 albu.RandomContrast(),
+                 albu.RandomGamma(),
+                 albu.RandomBrightness(),
+             ], p=0.5),
+             albu.OneOf([
+                 albu.MotionBlur(blur_limit=5),
+                 albu.MedianBlur(blur_limit=5),
+                 albu.GaussianBlur(blur_limit=5),
+                 albu.GaussNoise(var_limit=(5.0, 20.0)),
+             ], p=0.5),
+             albu.Resize(img_size, img_size, always_apply=True),
+             albu.Normalize([0.485, 0.456, 0.406],
+                            [0.229, 0.224, 0.225]),
+             ToTensorV2(),
+         ])
+     return transforms
+
+
+ def get_test_augmentation(img_size):
+     transforms = albu.Compose([
+         albu.Resize(img_size, img_size, always_apply=True),
+         albu.Normalize([0.485, 0.456, 0.406],
+                        [0.229, 0.224, 0.225]),
+         ToTensorV2(),
+     ])
+     return transforms
+
+
+ def gt_to_tensor(gt):
+     gt = cv2.imread(gt)
+     gt = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY) / 255.0
+     gt = np.where(gt > 0.5, 1.0, 0.0)
+     gt = torch.tensor(gt, device='cuda', dtype=torch.float32)
+     gt = gt.unsqueeze(0).unsqueeze(1)
+
+     return gt
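For reference, a minimal sketch of driving the test path of get_loader, mirroring the call made in inference.py; the folder is the repository's own custom_dataset directory, while the batch size and worker count here are illustrative:

from dataloader import get_test_augmentation, get_loader

transform = get_test_augmentation(img_size=640)
loader = get_loader('data/custom_dataset', gt_folder=None, edge_folder=None,
                    phase='test', batch_size=1, shuffle=False,
                    num_workers=4, transform=transform)

# With gt_folder=None, Test_DatasetGenerate yields
# (image, original_size, image_name) per sample.
for image, original_size, image_name in loader:
    print(image_name[0], image.shape)  # e.g. torch.Size([1, 3, 640, 640])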
demo_run.sh ADDED
@@ -0,0 +1,11 @@
+ #TRACER
+ #├── data
+ #│   ├── custom_dataset
+ #│   │   ├── sample_image1.png
+ #│   │   ├── sample_image2.png
+ #            .
+ #            .
+ #            .
+
+ # For testing TRACER with a pre-trained model (e.g.)
+ python main.py inference --dataset custom_dataset/ --arch 7 --img_size 640 --save_map True
edge_generator.py ADDED
@@ -0,0 +1,38 @@
+ """
+ Author: Min Seok Lee and Wooseok Shin
+ TRACER: Extreme Attention Guided Salient Object Tracing Network
+ git repo: https://github.com/Karel911/TRACER
+ """
+ import os
+ import cv2
+ import numpy as np
+ from tqdm import tqdm
+
+ # Append custom datasets to the list below
+ dataset_list = ['DUTS', 'DUT-O', 'HKU-IS', 'ECSSD', 'PASCAL-S']
+
+
+ def edge_generator(dataset):
+     if dataset == 'DUTS':
+         mask_path = os.path.join('data/', dataset, 'Train/masks/')
+     else:
+         mask_path = os.path.join('data/', dataset, 'Test/masks/')
+     save_path = os.path.join('data/', dataset, 'Train/edges/')
+     os.makedirs(save_path, exist_ok=True)
+     mask_list = os.listdir(mask_path)
+
+     for i, img_name in tqdm(enumerate(mask_list)):
+         mask = cv2.imread(mask_path + img_name)
+         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+         mask = np.int64(mask > 128)
+
+         [gy, gx] = np.gradient(mask)  # nonzero gradient marks boundary pixels
+         tmp_edge = gy * gy + gx * gx
+         tmp_edge[tmp_edge != 0] = 1
+         bound = np.uint8(tmp_edge * 255)
+         cv2.imwrite(os.path.join(save_path, f'{img_name}'), bound)
+
+
+ if __name__ == '__main__':
+     for dataset in dataset_list:
+         edge_generator(dataset)
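To make the gradient trick concrete: np.gradient takes central differences, which are nonzero exactly where the binary mask changes value, so squaring and summing the two components leaves a thin band along each object boundary. A self-contained toy example (the array values are illustrative):

import numpy as np

# A 6x6 binary mask with a 2x2 "object" in the middle.
mask = np.zeros((6, 6), dtype=np.int64)
mask[2:4, 2:4] = 1

gy, gx = np.gradient(mask)
edge = np.uint8((gy * gy + gx * gx != 0) * 255)
print(np.argwhere(edge > 0))  # a band of pixels ringing the square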
img/Poster.png ADDED
inference.py ADDED
@@ -0,0 +1,89 @@
+ """
+ author: Min Seok Lee and Wooseok Shin
+ """
+ import os
+ import cv2
+ import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.transforms import transforms
+ from tqdm import tqdm
+ from dataloader import get_test_augmentation, get_loader
+ from model.TRACER import TRACER
+ from util.utils import load_pretrained
+
+
+ class Inference():
+     def __init__(self, args, save_path):
+         super(Inference, self).__init__()
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.test_transform = get_test_augmentation(img_size=args.img_size)
+         self.args = args
+         self.save_path = save_path
+
+         # Network
+         self.model = TRACER(args).to(self.device)
+         if args.multi_gpu:
+             self.model = nn.DataParallel(self.model).to(self.device)
+
+         path = load_pretrained(f'TE-{args.arch}')
+         self.model.load_state_dict(path)
+         print('###### pre-trained Model restored #####')
+
+         te_img_folder = os.path.join(args.data_path, args.dataset)
+         te_gt_folder = None
+
+         self.test_loader = get_loader(te_img_folder, te_gt_folder, edge_folder=None, phase='test',
+                                       batch_size=args.batch_size, shuffle=False,
+                                       num_workers=args.num_workers, transform=self.test_transform)
+
+         if args.save_map is not None:
+             os.makedirs(os.path.join('mask', self.args.dataset), exist_ok=True)
+             os.makedirs(os.path.join('object', self.args.dataset), exist_ok=True)
+
+     def test(self):
+         self.model.eval()
+         t = time.time()
+
+         with torch.no_grad():
+             for i, (images, original_size, image_name) in enumerate(tqdm(self.test_loader)):
+                 images = torch.tensor(images, device=self.device, dtype=torch.float32)
+
+                 outputs, edge_mask, ds_map = self.model(images)
+                 H, W = original_size
+
+                 for j in range(images.size(0)):
+                     h, w = H[j].item(), W[j].item()
+                     output = F.interpolate(outputs[j].unsqueeze(0), size=(h, w), mode='bilinear')
+
+                     # Save prediction map
+                     if self.args.save_map is not None:
+                         output = (output.squeeze().detach().cpu().numpy() * 255.0).astype(np.uint8)
+
+                         salient_object = self.post_processing(images[j], output, h, w)
+                         cv2.imwrite(os.path.join('mask', self.args.dataset, image_name[j] + '.png'), output)
+                         cv2.imwrite(os.path.join('object', self.args.dataset, image_name[j] + '.png'), salient_object)
+
+         print(f'time: {time.time() - t:.3f}s')
+
+     def post_processing(self, original_image, output_image, height, width, threshold=200):
+         invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
+                                                             std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
+                                        transforms.Normalize(mean=[-0.485, -0.456, -0.406],
+                                                             std=[1., 1., 1.]),
+                                        ])
+         original_image = invTrans(original_image)
+
+         original_image = F.interpolate(original_image.unsqueeze(0), size=(height, width), mode='bilinear')
+         original_image = (original_image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8)
+
+         rgba_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2BGRA)
+         output_rgba_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2BGRA)
+
+         output_rgba_image[:, :, 3] = output_image  # Use the prediction map as the alpha channel
+         edge_y, edge_x, _ = np.where(output_rgba_image <= threshold)  # Low-confidence (background) coordinates
+
+         rgba_image[edge_y, edge_x, 3] = 0  # Make the background fully transparent
+         return cv2.cvtColor(rgba_image, cv2.COLOR_RGBA2BGRA)
main.py ADDED
@@ -0,0 +1,55 @@
+ import os
+ import pprint
+ import random
+ import warnings
+ import torch
+ import numpy as np
+ from trainer import Trainer, Tester
+ from inference import Inference
+
+ from config import getConfig
+ warnings.filterwarnings('ignore')
+ args = getConfig()
+
+
+ def main(args):
+     print('<---- Training Params ---->')
+     pprint.pprint(args)
+
+     # Random Seed
+     seed = args.seed
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)  # if use multi-GPU
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     if args.action == 'train':
+         save_path = os.path.join(args.model_path, args.dataset, f'TE{args.arch}_{str(args.exp_num)}')
+
+         # Create model directory
+         os.makedirs(save_path, exist_ok=True)
+         Trainer(args, save_path)
+
+     elif args.action == 'test':
+         save_path = os.path.join(args.model_path, args.dataset, f'TE{args.arch}_{str(args.exp_num)}')
+         datasets = ['DUTS', 'DUT-O', 'HKU-IS', 'ECSSD', 'PASCAL-S']
+
+         for dataset in datasets:
+             args.dataset = dataset
+             test_loss, test_mae, test_maxf, test_avgf, test_s_m = Tester(args, save_path).test()
+
+             print(f'Test Loss:{test_loss:.3f} | MAX_F:{test_maxf:.4f} '
+                   f'| AVG_F:{test_avgf:.4f} | MAE:{test_mae:.4f} | S_Measure:{test_s_m:.4f}')
+     else:
+         save_path = os.path.join(args.model_path, args.dataset, f'TE{args.arch}_{str(args.exp_num)}')
+
+         print('<----- Initializing inference mode ----->')
+         Inference(args, save_path).test()
+
+
+ if __name__ == '__main__':
+     main(args)
mask/custom_dataset/110000026240767.png ADDED
mask/custom_dataset/200630_colekt_pack0937__la_chambre__bottle_50ml__final__16x9-copy-scaled.png ADDED
mask/custom_dataset/TOMBUL.png ADDED
mask/custom_dataset/images (1).png ADDED
mask/custom_dataset/images.png ADDED
mask/custom_dataset/indir (1).png ADDED
mask/custom_dataset/indir.png ADDED
mask/custom_dataset/photo-1541643600914-78b084683601.png ADDED
model/EfficientNet.py ADDED
@@ -0,0 +1,356 @@
+ """
+ Original author: lukemelas (github username)
+ Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
+ With adjustments and added comments by workingcoder (github username).
+
+ Reimplemented: Min Seok Lee and Wooseok Shin
+ """
+
+
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from util.effi_utils import (
+     get_model_shape,
+     round_filters,
+     round_repeats,
+     drop_connect,
+     get_same_padding_conv2d,
+     get_model_params,
+     efficientnet_params,
+     load_pretrained_weights,
+     Swish,
+     MemoryEfficientSwish,
+     calculate_output_image_size
+ )
+ from modules.att_modules import Frequency_Edge_Module
+ from config import getConfig
+
+ cfg = getConfig()
+
+ VALID_MODELS = (
+     'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
+     'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
+     'efficientnet-b8',
+
+     # Support the construction of 'efficientnet-l2' without pretrained weights
+     'efficientnet-l2'
+ )
+
+
+ class MBConvBlock(nn.Module):
+     """Mobile Inverted Residual Bottleneck Block.
+
+     Args:
+         block_args (namedtuple): BlockArgs, defined in utils.py.
+         global_params (namedtuple): GlobalParam, defined in utils.py.
+         image_size (tuple or list): [image_height, image_width].
+
+     References:
+         [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
+         [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
+         [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
+     """
+
+     def __init__(self, block_args, global_params, image_size=None):
+         super().__init__()
+         self._block_args = block_args
+         self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
+         self._bn_eps = global_params.batch_norm_epsilon
+         self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
+         self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect
+
+         # Expansion phase (Inverted Bottleneck)
+         inp = self._block_args.input_filters  # number of input channels
+         oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
+         if self._block_args.expand_ratio != 1:
+             Conv2d = get_same_padding_conv2d(image_size=image_size)
+             self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
+             self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+             # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+
+         # Depthwise convolution phase
+         k = self._block_args.kernel_size
+         s = self._block_args.stride
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._depthwise_conv = Conv2d(
+             in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
+             kernel_size=k, stride=s, bias=False)
+         self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
+         image_size = calculate_output_image_size(image_size, s)
+
+         # Squeeze and Excitation layer, if desired
+         if self.has_se:
+             Conv2d = get_same_padding_conv2d(image_size=(1, 1))
+             num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
+             self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
+             self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
+
+         # Pointwise convolution phase
+         final_oup = self._block_args.output_filters
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
+         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
+         self._swish = MemoryEfficientSwish()
+
+     def forward(self, inputs, drop_connect_rate=None):
+         """MBConvBlock's forward function.
+
+         Args:
+             inputs (tensor): Input tensor.
+             drop_connect_rate (bool): Drop connect rate (float, between 0 and 1).
+
+         Returns:
+             Output of this block after processing.
+         """
+
+         # Expansion and Depthwise Convolution
+         x = inputs
+         if self._block_args.expand_ratio != 1:
+             x = self._expand_conv(inputs)
+             x = self._bn0(x)
+             x = self._swish(x)
+
+         x = self._depthwise_conv(x)
+         x = self._bn1(x)
+         x = self._swish(x)
+
+         # Squeeze and Excitation
+         if self.has_se:
+             x_squeezed = F.adaptive_avg_pool2d(x, 1)
+             x_squeezed = self._se_reduce(x_squeezed)
+             x_squeezed = self._swish(x_squeezed)
+             x_squeezed = self._se_expand(x_squeezed)
+             x = torch.sigmoid(x_squeezed) * x
+
+         # Pointwise Convolution
+         x = self._project_conv(x)
+         x = self._bn2(x)
+
+         # Skip connection and drop connect
+         input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
+         if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
+             # The combination of skip connection and drop connect brings about stochastic depth.
+             if drop_connect_rate:
+                 x = drop_connect(x, p=drop_connect_rate, training=self.training)
+             x = x + inputs  # skip connection
+         return x
+
+     def set_swish(self, memory_efficient=True):
+         """Sets swish function as memory efficient (for training) or standard (for export).
+
+         Args:
+             memory_efficient (bool): Whether to use memory-efficient version of swish.
+         """
+         self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+
+
+ class EfficientNet(nn.Module):
+     def __init__(self, blocks_args=None, global_params=None):
+         super().__init__()
+         assert isinstance(blocks_args, list), 'blocks_args should be a list'
+         assert len(blocks_args) > 0, 'block args must be greater than 0'
+         self._global_params = global_params
+         self._blocks_args = blocks_args
+         self.block_idx, self.channels = get_model_shape()
+         self.Frequency_Edge_Module1 = Frequency_Edge_Module(radius=cfg.frequency_radius,
+                                                             channel=self.channels[0])
+         # Batch norm parameters
+         bn_mom = 1 - self._global_params.batch_norm_momentum
+         bn_eps = self._global_params.batch_norm_epsilon
+
+         # Get stem static or dynamic convolution depending on image size
+         image_size = global_params.image_size
+         Conv2d = get_same_padding_conv2d(image_size=image_size)
+
+         # Stem
+         in_channels = 3  # rgb
+         out_channels = round_filters(32, self._global_params)  # number of output channels
+         self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
+         self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
+         image_size = calculate_output_image_size(image_size, 2)
+
+         # Build blocks
+         self._blocks = nn.ModuleList([])
+         for block_args in self._blocks_args:
+
+             # Update block input and output filters based on depth multiplier.
+             block_args = block_args._replace(
+                 input_filters=round_filters(block_args.input_filters, self._global_params),
+                 output_filters=round_filters(block_args.output_filters, self._global_params),
+                 num_repeat=round_repeats(block_args.num_repeat, self._global_params)
+             )
+
+             # The first block needs to take care of stride and filter size increase.
+             self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+             image_size = calculate_output_image_size(image_size, block_args.stride)
+             if block_args.num_repeat > 1:  # modify block_args to keep same output size
+                 block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
+             for _ in range(block_args.num_repeat - 1):
+                 self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size))
+                 # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1
+
+         self._swish = MemoryEfficientSwish()
+
+     def set_swish(self, memory_efficient=True):
+         """Sets swish function as memory efficient (for training) or standard (for export).
+
+         Args:
+             memory_efficient (bool): Whether to use memory-efficient version of swish.
+
+         """
+         self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
+         for block in self._blocks:
+             block.set_swish(memory_efficient)
+
+     def extract_endpoints(self, inputs):
+         endpoints = dict()
+
+         # Stem
+         x = self._swish(self._bn0(self._conv_stem(inputs)))
+         prev_x = x
+
+         # Blocks
+         for idx, block in enumerate(self._blocks):
+             drop_connect_rate = self._global_params.drop_connect_rate
+             if drop_connect_rate:
+                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
+             x = block(x, drop_connect_rate=drop_connect_rate)
+             if prev_x.size(2) > x.size(2):
+                 endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x
+             prev_x = x
+
+         # Head
+         x = self._swish(self._bn1(self._conv_head(x)))  # NOTE: _conv_head/_bn1 are not defined in this trimmed backbone
+         endpoints['reduction_{}'.format(len(endpoints) + 1)] = x
+
+         return endpoints
+
+
+     def initial_conv(self, inputs):
+         # Stem
+         x = self._swish(self._bn0(self._conv_stem(inputs)))
+
+         return x
+
+
+     def get_blocks(self, x, H, W):
+         # Blocks
+         for idx, block in enumerate(self._blocks):
+             drop_connect_rate = self._global_params.drop_connect_rate
+             if drop_connect_rate:
+                 drop_connect_rate *= float(idx) / len(self._blocks)  # scale drop connect_rate
+
+             x = block(x, drop_connect_rate=drop_connect_rate)
+
+             if idx == self.block_idx[0]:
+                 x, edge = self.Frequency_Edge_Module1(x)
+                 edge = F.interpolate(edge, size=(H, W), mode='bilinear')
+                 x1 = x.clone()
+             if idx == self.block_idx[1]:
+                 x2 = x.clone()
+             if idx == self.block_idx[2]:
+                 x3 = x.clone()
+             if idx == self.block_idx[3]:
+                 x4 = x.clone()
+
+         return (x1, x2, x3, x4), edge
+
+
+     @classmethod
+     def from_name(cls, model_name, in_channels=3, **override_params):
+         """create an efficientnet model according to name.
+
+         Args:
+             model_name (str): Name for efficientnet.
+             in_channels (int): Input data's channel number.
+             override_params (other key word params):
+                 Params to override model's global_params.
+                 Optional key:
+                     'width_coefficient', 'depth_coefficient',
+                     'image_size', 'dropout_rate',
+                     'num_classes', 'batch_norm_momentum',
+                     'batch_norm_epsilon', 'drop_connect_rate',
+                     'depth_divisor', 'min_depth'
+
+         Returns:
+             An efficientnet model.
+         """
+         cls._check_model_name_is_valid(model_name)
+         blocks_args, global_params = get_model_params(model_name, override_params)
+         model = cls(blocks_args, global_params)
+         model._change_in_channels(in_channels)
+         return model
+
+     @classmethod
+     def from_pretrained(cls, model_name, weights_path=None, advprop=False,
+                         in_channels=3, num_classes=1000, **override_params):
+         """create an efficientnet model according to name.
+
+         Args:
+             model_name (str): Name for efficientnet.
+             weights_path (None or str):
+                 str: path to pretrained weights file on the local disk.
+                 None: use pretrained weights downloaded from the Internet.
+             advprop (bool):
+                 Whether to load pretrained weights
+                 trained with advprop (valid when weights_path is None).
+             in_channels (int): Input data's channel number.
+             num_classes (int):
+                 Number of categories for classification.
+                 It controls the output size for final linear layer.
+             override_params (other key word params):
+                 Params to override model's global_params.
+                 Optional key:
+                     'width_coefficient', 'depth_coefficient',
+                     'image_size', 'dropout_rate',
+                     'batch_norm_momentum',
+                     'batch_norm_epsilon', 'drop_connect_rate',
+                     'depth_divisor', 'min_depth'
+
+         Returns:
+             A pretrained TRACER-EfficientNet model.
+         """
+         model = cls.from_name(model_name, num_classes=num_classes, **override_params)
+         load_pretrained_weights(model, model_name, weights_path=weights_path, advprop=advprop)
+         model._change_in_channels(in_channels)
+         return model
+
+     @classmethod
+     def get_image_size(cls, model_name):
+         """Get the input image size for a given efficientnet model.
+
+         Args:
+             model_name (str): Name for efficientnet.
+
+         Returns:
+             Input image size (resolution).
+         """
+         cls._check_model_name_is_valid(model_name)
+         _, _, res, _ = efficientnet_params(model_name)
+         return res
+
+     @classmethod
+     def _check_model_name_is_valid(cls, model_name):
+         """Validates model name.
+
+         Args:
+             model_name (str): Name for efficientnet.
+
+         Returns:
+             bool: Is a valid name or not.
+         """
+         if model_name not in VALID_MODELS:
+             raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))
+
+     def _change_in_channels(self, in_channels):
+         """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.
+
+         Args:
+             in_channels (int): Input data's channel number.
+         """
+         if in_channels != 3:
+             Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
+             out_channels = round_filters(32, self._global_params)
+             self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
model/TRACER.py ADDED
@@ -0,0 +1,58 @@
+ """
+ author: Min Seok Lee and Wooseok Shin
+ Github repo: https://github.com/Karel911/TRACER
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model.EfficientNet import EfficientNet
+ from util.effi_utils import get_model_shape
+ from modules.att_modules import RFB_Block, aggregation, ObjectAttention
+
+
+ class TRACER(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.model = EfficientNet.from_pretrained(f'efficientnet-b{cfg.arch}', advprop=True)
+         self.block_idx, self.channels = get_model_shape()
+
+         # Receptive Field Blocks
+         channels = [int(arg_c) for arg_c in cfg.RFB_aggregated_channel]
+         self.rfb2 = RFB_Block(self.channels[1], channels[0])
+         self.rfb3 = RFB_Block(self.channels[2], channels[1])
+         self.rfb4 = RFB_Block(self.channels[3], channels[2])
+
+         # Multi-level aggregation
+         self.agg = aggregation(channels)
+
+         # Object Attention
+         self.ObjectAttention2 = ObjectAttention(channel=self.channels[1], kernel_size=3)
+         self.ObjectAttention1 = ObjectAttention(channel=self.channels[0], kernel_size=3)
+
+     def forward(self, inputs):
+         B, C, H, W = inputs.size()
+
+         # EfficientNet backbone Encoder
+         x = self.model.initial_conv(inputs)
+         features, edge = self.model.get_blocks(x, H, W)
+
+         x3_rfb = self.rfb2(features[1])
+         x4_rfb = self.rfb3(features[2])
+         x5_rfb = self.rfb4(features[3])
+
+         D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
+
+         ds_map0 = F.interpolate(D_0, scale_factor=8, mode='bilinear')
+
+         D_1 = self.ObjectAttention2(D_0, features[1])
+         ds_map1 = F.interpolate(D_1, scale_factor=8, mode='bilinear')
+
+         ds_map = F.interpolate(D_1, scale_factor=2, mode='bilinear')
+         D_2 = self.ObjectAttention1(ds_map, features[0])
+         ds_map2 = F.interpolate(D_2, scale_factor=4, mode='bilinear')
+
+         final_map = (ds_map2 + ds_map1 + ds_map0) / 3
+
+         return torch.sigmoid(final_map), torch.sigmoid(edge), \
+                (torch.sigmoid(ds_map0), torch.sigmoid(ds_map1), torch.sigmoid(ds_map2))
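A hedged smoke test for the forward pass above (not part of the repository): note that att_modules.py and EfficientNet.py call getConfig() at import time, so sys.argv must satisfy the parser, including the positional action argument, before the model modules are imported. Internet access is assumed for the advprop backbone weights, and a CUDA device is required because Frequency_Edge_Module moves its radial mask with .cuda():

import sys
# Satisfy the argument parser before the model modules are imported,
# since getConfig() runs at import time in att_modules/EfficientNet.
sys.argv = ['smoke_test.py', 'inference', '--arch', '7', '--img_size', '640']

import torch
from config import getConfig
from model.TRACER import TRACER

cfg = getConfig()
model = TRACER(cfg).to('cuda').eval()  # downloads EfficientNet-B7 advprop weights

x = torch.randn(1, 3, cfg.img_size, cfg.img_size, device='cuda')
with torch.no_grad():
    saliency, edge, ds_maps = model(x)
print(saliency.shape)  # expected torch.Size([1, 1, 640, 640]) after the x8/x4 upsampling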
model/__pycache__/EfficientNet.cpython-39.pyc ADDED
Binary file (10.7 kB).
 
model/__pycache__/TRACER.cpython-39.pyc ADDED
Binary file (2.16 kB).
 
modules/__pycache__/att_modules.cpython-39.pyc ADDED
Binary file (9.33 kB).
 
modules/__pycache__/conv_modules.cpython-39.pyc ADDED
Binary file (2.28 kB).
 
modules/att_modules.py ADDED
@@ -0,0 +1,297 @@
+ """
+ author: Min Seok Lee and Wooseok Shin
+ """
+ import numpy as np
+ import torch.nn as nn
+ from torch.fft import fft2, fftshift, ifft2, ifftshift
+ from util.utils import *
+ import torch.nn.functional as F
+ from config import getConfig
+ from modules.conv_modules import BasicConv2d, DWConv, DWSConv
+
+ cfg = getConfig()
+
+
+ class Frequency_Edge_Module(nn.Module):
+     def __init__(self, radius, channel):
+         super(Frequency_Edge_Module, self).__init__()
+         self.radius = radius
+         self.UAM = UnionAttentionModule(channel, only_channel_tracing=True)
+
+         # DWS + DWConv
+         self.DWSConv = DWSConv(channel, channel, kernel=3, padding=1, kernels_per_layer=1)
+         self.DWConv1 = nn.Sequential(
+             DWConv(channel, channel, kernel=1, padding=0, dilation=1),
+             BasicConv2d(channel, channel // 4, 1),
+         )
+         self.DWConv2 = nn.Sequential(
+             DWConv(channel, channel, kernel=3, padding=1, dilation=1),
+             BasicConv2d(channel, channel // 4, 1),
+         )
+         self.DWConv3 = nn.Sequential(
+             DWConv(channel, channel, kernel=3, padding=3, dilation=3),
+             BasicConv2d(channel, channel // 4, 1),
+         )
+         self.DWConv4 = nn.Sequential(
+             DWConv(channel, channel, kernel=3, padding=5, dilation=5),
+             BasicConv2d(channel, channel // 4, 1),
+         )
+         self.conv = BasicConv2d(channel, 1, 1)
+
+     def distance(self, i, j, imageSize, r):
+         dis = np.sqrt((i - imageSize / 2) ** 2 + (j - imageSize / 2) ** 2)
+         if dis < r:
+             return 1.0
+         else:
+             return 0
+
+     def mask_radial(self, img, r):
+         batch, channels, rows, cols = img.shape
+         mask = torch.zeros((rows, cols), dtype=torch.float32)
+         for i in range(rows):
+             for j in range(cols):
+                 mask[i, j] = self.distance(i, j, imageSize=rows, r=r)
+         return mask
+
+     def forward(self, x):
+         """
+         Input:
+             The first encoder block representation: (B, C, H, W)
+         Returns:
+             Edge refined representation: X + edge (B, C, H, W)
+         """
+         x_fft = fft2(x, dim=(-2, -1))
+         x_fft = fftshift(x_fft)
+
+         # Mask -> low, high separate
+         mask = self.mask_radial(img=x, r=self.radius).cuda()
+         high_frequency = x_fft * (1 - mask)
+         x_fft = ifftshift(high_frequency)
+         x_fft = ifft2(x_fft, dim=(-2, -1))
+         x_H = torch.abs(x_fft)
+
+         x_H, _ = self.UAM.Channel_Tracer(x_H)
+         edge_masks = self.DWSConv(x_H)
+         skip = edge_masks.clone()
+
+         edge_masks = torch.cat([self.DWConv1(edge_masks), self.DWConv2(edge_masks),
+                                 self.DWConv3(edge_masks), self.DWConv4(edge_masks)], dim=1) + skip
+         edge = torch.relu(self.conv(edge_masks))
+
+         x = x + edge  # Feature + Masked Edge information
+
+         return x, edge
+
+
+ class RFB_Block(nn.Module):
+     def __init__(self, in_channel, out_channel):
+         super(RFB_Block, self).__init__()
+         self.relu = nn.ReLU(True)
+         self.branch0 = nn.Sequential(
+             BasicConv2d(in_channel, out_channel, 1),
+         )
+         self.branch1 = nn.Sequential(
+             BasicConv2d(in_channel, out_channel, 1),
+             BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
+             BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
+             BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
+         )
+         self.branch2 = nn.Sequential(
+             BasicConv2d(in_channel, out_channel, 1),
+             BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
+             BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
+             BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
+         )
+         self.branch3 = nn.Sequential(
+             BasicConv2d(in_channel, out_channel, 1),
+             BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
108
+ BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
109
+ BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
110
+ )
111
+ self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
112
+ self.conv_res = BasicConv2d(in_channel, out_channel, 1)
113
+
114
+ def forward(self, x):
115
+ x0 = self.branch0(x)
116
+ x1 = self.branch1(x)
117
+ x2 = self.branch2(x)
118
+ x3 = self.branch3(x)
119
+ x_cat = torch.cat((x0, x1, x2, x3), 1)
120
+ x_cat = self.conv_cat(x_cat)
121
+
122
+ x = self.relu(x_cat + self.conv_res(x))
123
+ return x
124
+
125
+
126
+ class GlobalAvgPool(nn.Module):
127
+ def __init__(self, flatten=False):
128
+ super(GlobalAvgPool, self).__init__()
129
+ self.flatten = flatten
130
+
131
+ def forward(self, x):
132
+ if self.flatten:
133
+ in_size = x.size()
134
+ return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
135
+ else:
136
+ return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
137
+
138
+
139
+ class UnionAttentionModule(nn.Module):
140
+ def __init__(self, n_channels, only_channel_tracing=False):
141
+ super(UnionAttentionModule, self).__init__()
142
+ self.GAP = GlobalAvgPool()
143
+ self.confidence_ratio = cfg.gamma
144
+ self.bn = nn.BatchNorm2d(n_channels)
145
+ self.norm = nn.Sequential(
146
+ nn.BatchNorm2d(n_channels),
147
+ nn.Dropout3d(self.confidence_ratio)
148
+ )
149
+ self.channel_q = nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=1, stride=1,
150
+ padding=0, bias=False)
151
+ self.channel_k = nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=1, stride=1,
152
+ padding=0, bias=False)
153
+ self.channel_v = nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=1, stride=1,
154
+ padding=0, bias=False)
155
+
156
+ self.fc = nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=1, stride=1,
157
+ padding=0, bias=False)
158
+
159
+ if only_channel_tracing == False:
160
+ self.spatial_q = nn.Conv2d(in_channels=n_channels, out_channels=1, kernel_size=1, stride=1,
161
+ padding=0, bias=False)
162
+ self.spatial_k = nn.Conv2d(in_channels=n_channels, out_channels=1, kernel_size=1, stride=1,
163
+ padding=0, bias=False)
164
+ self.spatial_v = nn.Conv2d(in_channels=n_channels, out_channels=1, kernel_size=1, stride=1,
165
+ padding=0, bias=False)
166
+ self.sigmoid = nn.Sigmoid()
167
+
168
+ def masking(self, x, mask):
169
+ mask = mask.squeeze(3).squeeze(2)
170
+ threshold = torch.quantile(mask, self.confidence_ratio, dim=-1, keepdim=True)
171
+ mask[mask <= threshold] = 0.0
172
+ mask = mask.unsqueeze(2).unsqueeze(3)
173
+ mask = mask.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()
174
+ masked_x = x * mask
175
+
176
+ return masked_x
177
+
178
+ def Channel_Tracer(self, x):
179
+ avg_pool = self.GAP(x)
180
+ x_norm = self.norm(avg_pool)
181
+
182
+ q = self.channel_q(x_norm).squeeze(-1)
183
+ k = self.channel_k(x_norm).squeeze(-1)
184
+ v = self.channel_v(x_norm).squeeze(-1)
185
+
186
+ # softmax(Q*K^T)
187
+ QK_T = torch.matmul(q, k.transpose(1, 2))
188
+ alpha = F.softmax(QK_T, dim=-1)
189
+
190
+ # a*v
191
+ att = torch.matmul(alpha, v).unsqueeze(-1)
192
+ att = self.fc(att)
193
+ att = self.sigmoid(att)
194
+
195
+ output = (x * att) + x
196
+ alpha_mask = att.clone()
197
+
198
+ return output, alpha_mask
199
+
200
+ def forward(self, x):
201
+ X_c, alpha_mask = self.Channel_Tracer(x)
202
+ X_c = self.bn(X_c)
203
+ x_drop = self.masking(X_c, alpha_mask)
204
+
205
+ q = self.spatial_q(x_drop).squeeze(1)
206
+ k = self.spatial_k(x_drop).squeeze(1)
207
+ v = self.spatial_v(x_drop).squeeze(1)
208
+
209
+ # softmax(Q*K^T)
210
+ QK_T = torch.matmul(q, k.transpose(1, 2))
211
+ alpha = F.softmax(QK_T, dim=-1)
212
+
213
+ output = torch.matmul(alpha, v).unsqueeze(1) + v.unsqueeze(1)
214
+
215
+ return output
216
+
217
+
218
+ class aggregation(nn.Module):
219
+ def __init__(self, channel):
220
+ super(aggregation, self).__init__()
221
+ self.relu = nn.ReLU(True)
222
+
223
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
224
+ self.conv_upsample1 = BasicConv2d(channel[2], channel[1], 3, padding=1)
225
+ self.conv_upsample2 = BasicConv2d(channel[2], channel[0], 3, padding=1)
226
+ self.conv_upsample3 = BasicConv2d(channel[1], channel[0], 3, padding=1)
227
+ self.conv_upsample4 = BasicConv2d(channel[2], channel[2], 3, padding=1)
228
+ self.conv_upsample5 = BasicConv2d(channel[2] + channel[1], channel[2] + channel[1], 3, padding=1)
229
+
230
+ self.conv_concat2 = BasicConv2d((channel[2] + channel[1]), (channel[2] + channel[1]), 3, padding=1)
231
+ self.conv_concat3 = BasicConv2d((channel[0] + channel[1] + channel[2]),
232
+ (channel[0] + channel[1] + channel[2]), 3, padding=1)
233
+
234
+ self.UAM = UnionAttentionModule(channel[0] + channel[1] + channel[2])
235
+
236
+ def forward(self, e4, e3, e2):
237
+ e4_1 = e4
238
+ e3_1 = self.conv_upsample1(self.upsample(e4)) * e3
239
+ e2_1 = self.conv_upsample2(self.upsample(self.upsample(e4))) \
240
+ * self.conv_upsample3(self.upsample(e3)) * e2
241
+
242
+ e3_2 = torch.cat((e3_1, self.conv_upsample4(self.upsample(e4_1))), 1)
243
+ e3_2 = self.conv_concat2(e3_2)
244
+
245
+ e2_2 = torch.cat((e2_1, self.conv_upsample5(self.upsample(e3_2))), 1)
246
+ x = self.conv_concat3(e2_2)
247
+
248
+ output = self.UAM(x)
249
+
250
+ return output
251
+
252
+
253
+ class ObjectAttention(nn.Module):
254
+ def __init__(self, channel, kernel_size):
255
+ super(ObjectAttention, self).__init__()
256
+ self.channel = channel
257
+ self.DWSConv = DWSConv(channel, channel // 2, kernel=kernel_size, padding=1, kernels_per_layer=1)
258
+ self.DWConv1 = nn.Sequential(
259
+ DWConv(channel // 2, channel // 2, kernel=1, padding=0, dilation=1),
260
+ BasicConv2d(channel // 2, channel // 8, 1),
261
+ )
262
+ self.DWConv2 = nn.Sequential(
263
+ DWConv(channel // 2, channel // 2, kernel=3, padding=1, dilation=1),
264
+ BasicConv2d(channel // 2, channel // 8, 1),
265
+ )
266
+ self.DWConv3 = nn.Sequential(
267
+ DWConv(channel // 2, channel // 2, kernel=3, padding=3, dilation=3),
268
+ BasicConv2d(channel // 2, channel // 8, 1),
269
+ )
270
+ self.DWConv4 = nn.Sequential(
271
+ DWConv(channel // 2, channel // 2, kernel=3, padding=5, dilation=5),
272
+ BasicConv2d(channel // 2, channel // 8, 1),
273
+ )
274
+ self.conv1 = BasicConv2d(channel // 2, 1, 1)
275
+
276
+ def forward(self, decoder_map, encoder_map):
277
+ """
278
+ Args:
279
+ decoder_map: decoder representation (B, 1, H, W).
280
+ encoder_map: encoder block output (B, C, H, W).
281
+ Returns:
282
+ decoder representation: (B, 1, H, W)
283
+ """
284
+ mask_bg = -1 * torch.sigmoid(decoder_map) + 1 # Sigmoid & Reverse
285
+ mask_ob = torch.sigmoid(decoder_map) # object attention
286
+ x = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)
287
+
288
+ edge = mask_bg.clone()
289
+ edge[edge > cfg.denoise] = 0
290
+ x = x + (edge * encoder_map)
291
+
292
+ x = self.DWSConv(x)
293
+ skip = x.clone()
294
+ x = torch.cat([self.DWConv1(x), self.DWConv2(x), self.DWConv3(x), self.DWConv4(x)], dim=1) + skip
295
+ x = torch.relu(self.conv1(x))
296
+
297
+ return x + decoder_map
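The core trick in Frequency_Edge_Module is a radial high-pass filter in the Fourier domain: frequencies inside a disc of the given radius are discarded, and the magnitude of the inverse transform is treated as edge evidence. Below is a self-contained sketch of just that step, vectorized instead of the per-pixel Python loop in mask_radial; the feature shape is an arbitrary example.

import torch
from torch.fft import fft2, fftshift, ifft2, ifftshift

def high_pass(x, radius):
    # x: (B, C, H, W). Zero out the centered low-frequency disc, invert,
    # and keep the magnitude; what survives is fine, edge-like detail.
    B, C, H, W = x.shape
    yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    dist = torch.sqrt((yy - H / 2) ** 2 + (xx - W / 2) ** 2)
    low = (dist < radius).float().to(x.device)  # 1 inside the disc, 0 outside
    spec = fftshift(fft2(x, dim=(-2, -1)), dim=(-2, -1))
    return torch.abs(ifft2(ifftshift(spec * (1 - low), dim=(-2, -1)), dim=(-2, -1)))

feat = torch.randn(2, 24, 80, 80)
print(high_pass(feat, radius=16).shape)  # torch.Size([2, 24, 80, 80])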
modules/conv_modules.py ADDED
@@ -0,0 +1,56 @@
+ """
+ author: Min Seok Lee and Wooseok Shin
+ """
+ import torch.nn as nn
+
+
+ class BasicConv2d(nn.Module):
+     def __init__(self, in_channel, out_channel, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
+         super(BasicConv2d, self).__init__()
+         self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding,
+                               dilation=dilation, bias=False)
+         self.bn = nn.BatchNorm2d(out_channel)
+         self.selu = nn.SELU()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.selu(x)
+
+         return x
+
+
+ class DWConv(nn.Module):
+     def __init__(self, in_channel, out_channel, kernel, dilation, padding):
+         super(DWConv, self).__init__()
+         self.out_channel = out_channel
+         self.DWConv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel, padding=padding, groups=in_channel,
+                                 dilation=dilation, bias=False)
+         self.bn = nn.BatchNorm2d(out_channel)
+         self.selu = nn.SELU()
+
+     def forward(self, x):
+         x = self.DWConv(x)
+         out = self.selu(self.bn(x))
+
+         return out
+
+
+ class DWSConv(nn.Module):
+     def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
+         super(DWSConv, self).__init__()
+         self.out_channel = out_channel
+         self.DWConv = nn.Conv2d(in_channel, in_channel * kernels_per_layer, kernel_size=kernel, padding=padding,
+                                 groups=in_channel, bias=False)
+         self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer)
+         self.selu = nn.SELU()
+         self.PWConv = nn.Conv2d(in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(out_channel)
+
+     def forward(self, x):
+         x = self.DWConv(x)
+         x = self.selu(self.bn(x))
+         out = self.PWConv(x)
+         out = self.selu(self.bn2(out))
+
+         return out
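The parameter savings that motivate DWSConv over a dense convolution are easy to verify. The sanity-check sketch below counts parameters for 64-in/64-out at kernel 3, assuming this commit's modules package is importable; the DWSConv total includes its two BatchNorm layers.

import torch.nn as nn
from modules.conv_modules import DWSConv

def n_params(m):
    return sum(p.numel() for p in m.parameters())

dense = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
dws = DWSConv(64, 64, kernel=3, padding=1, kernels_per_layer=1)
print(n_params(dense))  # 36864: 64*64*3*3 dense weights
print(n_params(dws))    # 4928: 64*3*3 depthwise + 64*64 pointwise + 2 BatchNorms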
object/custom_dataset/110000026240767.png ADDED
object/custom_dataset/200630_colekt_pack0937__la_chambre__bottle_50ml__final__16x9-copy-scaled.png ADDED
object/custom_dataset/43a71f50-b839-4ab8-8be8-5be346ffe8be.png ADDED
object/custom_dataset/TOMBUL.png ADDED
object/custom_dataset/images (1).png ADDED
object/custom_dataset/images.png ADDED
object/custom_dataset/indir (1).png ADDED
object/custom_dataset/indir (2).png ADDED
object/custom_dataset/indir.png ADDED
object/custom_dataset/kabe-a-1912087.png ADDED
object/custom_dataset/photo-1541643600914-78b084683601.png ADDED
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ albumentations==1.0.3
+ certifi==2022.12.7
+ colorama==0.4.4
+ cycler==0.10.0
+ imageio==2.9.0
+ joblib==1.2.0
+ kiwisolver==1.3.2
+ matplotlib==3.4.3
+ networkx==2.6.2
+ opencv-python-headless==4.5.3.56
+ Pillow
+ pyparsing==2.4.7
+ python-dateutil==2.8.2
+ PyWavelets==1.1.1
+ PyYAML==5.4.1
+ scikit-image==0.18.3
+ scikit-learn==0.24.2
+ scipy==1.7.1
+ six==1.16.0
+ threadpoolctl==2.2.0
+ tifffile==2021.8.30
+ torch
+ torchvision
+ tqdm
+ wincertstore==0.2
+ transformers
+ streamlit
trainer.py ADDED
@@ -0,0 +1,293 @@
+ """
+ author: Min Seok Lee and Wooseok Shin
+ """
+ import os
+ import cv2
+ import time
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from dataloader import get_train_augmentation, get_test_augmentation, get_loader, gt_to_tensor
+ from util.utils import AvgMeter
+ from util.metrics import Evaluation_metrics
+ from util.losses import Optimizer, Scheduler, Criterion
+ from model.TRACER import TRACER
+
+
+ class Trainer():
+     def __init__(self, args, save_path):
+         super(Trainer, self).__init__()
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.size = args.img_size
+
+         self.tr_img_folder = os.path.join(args.data_path, args.dataset, 'Train/images/')
+         self.tr_gt_folder = os.path.join(args.data_path, args.dataset, 'Train/masks/')
+         self.tr_edge_folder = os.path.join(args.data_path, args.dataset, 'Train/edges/')
+
+         self.train_transform = get_train_augmentation(img_size=args.img_size, ver=args.aug_ver)
+         self.test_transform = get_test_augmentation(img_size=args.img_size)
+
+         self.train_loader = get_loader(self.tr_img_folder, self.tr_gt_folder, self.tr_edge_folder, phase='train',
+                                        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
+                                        transform=self.train_transform, seed=args.seed)
+         self.val_loader = get_loader(self.tr_img_folder, self.tr_gt_folder, self.tr_edge_folder, phase='val',
+                                      batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,
+                                      transform=self.test_transform, seed=args.seed)
+
+         # Network
+         self.model = TRACER(args).to(self.device)
+
+         if args.multi_gpu:
+             self.model = nn.DataParallel(self.model).to(self.device)
+
+         # Loss and Optimizer
+         self.criterion = Criterion(args)
+         self.optimizer = Optimizer(args, self.model)
+         self.scheduler = Scheduler(args, self.optimizer)
+
+         # Train / Validate
+         min_loss = 1000
+         early_stopping = 0
+         t = time.time()
+         for epoch in range(1, args.epochs + 1):
+             self.epoch = epoch
+             train_loss, train_mae = self.training(args)
+             val_loss, val_mae = self.validate()
+
+             if args.scheduler == 'Reduce':
+                 self.scheduler.step(val_loss)
+             else:
+                 self.scheduler.step()
+
+             # Save models
+             if val_loss < min_loss:
+                 early_stopping = 0
+                 best_epoch = epoch
+                 best_mae = val_mae
+                 min_loss = val_loss
+                 torch.save(self.model.state_dict(), os.path.join(save_path, 'best_model.pth'))
+                 print(f'-----------------SAVE:{best_epoch}epoch----------------')
+             else:
+                 early_stopping += 1
+
+             if early_stopping == args.patience + 5:
+                 break
+
+         print(f'\nBest Val Epoch:{best_epoch} | Val Loss:{min_loss:.3f} | Val MAE:{best_mae:.3f} '
+               f'time: {(time.time() - t) / 60:.3f}M')
+
+         # Test time
+         datasets = ['DUTS', 'DUT-O', 'HKU-IS', 'ECSSD', 'PASCAL-S']
+         for dataset in datasets:
+             args.dataset = dataset
+             test_loss, test_mae, test_maxf, test_avgf, test_s_m = self.test(args, os.path.join(save_path))
+
+             print(
+                 f'Test Loss:{test_loss:.3f} | MAX_F:{test_maxf:.3f} | AVG_F:{test_avgf:.3f} | MAE:{test_mae:.3f} '
+                 f'| S_Measure:{test_s_m:.3f}, time: {time.time() - t:.3f}s')
+
+         end = time.time()
+         print(f'Total Process time:{(end - t) / 60:.3f}Minute')
+
+     def training(self, args):
+         self.model.train()
+         train_loss = AvgMeter()
+         train_mae = AvgMeter()
+
+         for images, masks, edges in tqdm(self.train_loader):
+             images = torch.tensor(images, device=self.device, dtype=torch.float32)
+             masks = torch.tensor(masks, device=self.device, dtype=torch.float32)
+             edges = torch.tensor(edges, device=self.device, dtype=torch.float32)
+
+             self.optimizer.zero_grad()
+             outputs, edge_mask, ds_map = self.model(images)
+             loss1 = self.criterion(outputs, masks)
+             loss2 = self.criterion(ds_map[0], masks)
+             loss3 = self.criterion(ds_map[1], masks)
+             loss4 = self.criterion(ds_map[2], masks)
+
+             loss_mask = self.criterion(edge_mask, edges)
+             loss = loss1 + loss2 + loss3 + loss4 + loss_mask
+
+             loss.backward()
+             nn.utils.clip_grad_norm_(self.model.parameters(), args.clipping)
+             self.optimizer.step()
+
+             # Metric
+             mae = torch.mean(torch.abs(outputs - masks))
+
+             # log
+             train_loss.update(loss.item(), n=images.size(0))
+             train_mae.update(mae.item(), n=images.size(0))
+
+         print(f'Epoch:[{self.epoch:03d}/{args.epochs:03d}]')
+         print(f'Train Loss:{train_loss.avg:.3f} | MAE:{train_mae.avg:.3f}')
+
+         return train_loss.avg, train_mae.avg
+
+     def validate(self):
+         self.model.eval()
+         val_loss = AvgMeter()
+         val_mae = AvgMeter()
+
+         with torch.no_grad():
+             for images, masks, edges in tqdm(self.val_loader):
+                 images = torch.tensor(images, device=self.device, dtype=torch.float32)
+                 masks = torch.tensor(masks, device=self.device, dtype=torch.float32)
+                 edges = torch.tensor(edges, device=self.device, dtype=torch.float32)
+
+                 outputs, edge_mask, ds_map = self.model(images)
+                 loss1 = self.criterion(outputs, masks)
+                 loss2 = self.criterion(ds_map[0], masks)
+                 loss3 = self.criterion(ds_map[1], masks)
+                 loss4 = self.criterion(ds_map[2], masks)
+
+                 loss_mask = self.criterion(edge_mask, edges)
+                 loss = loss1 + loss2 + loss3 + loss4 + loss_mask
+
+                 # Metric
+                 mae = torch.mean(torch.abs(outputs - masks))
+
+                 # log
+                 val_loss.update(loss.item(), n=images.size(0))
+                 val_mae.update(mae.item(), n=images.size(0))
+
+         print(f'Valid Loss:{val_loss.avg:.3f} | MAE:{val_mae.avg:.3f}')
+         return val_loss.avg, val_mae.avg
+
+     def test(self, args, save_path):
+         path = os.path.join(save_path, 'best_model.pth')
+         self.model.load_state_dict(torch.load(path))
+         print('###### pre-trained Model restored #####')
+
+         te_img_folder = os.path.join(args.data_path, args.dataset, 'Test/images/')
+         te_gt_folder = os.path.join(args.data_path, args.dataset, 'Test/masks/')
+         test_loader = get_loader(te_img_folder, te_gt_folder, edge_folder=None, phase='test',
+                                  batch_size=args.batch_size, shuffle=False,
+                                  num_workers=args.num_workers, transform=self.test_transform)
+
+         self.model.eval()
+         test_loss = AvgMeter()
+         test_mae = AvgMeter()
+         test_maxf = AvgMeter()
+         test_avgf = AvgMeter()
+         test_s_m = AvgMeter()
+
+         Eval_tool = Evaluation_metrics(args.dataset, self.device)
+
+         with torch.no_grad():
+             for i, (images, masks, original_size, image_name) in enumerate(tqdm(test_loader)):
+                 images = torch.tensor(images, device=self.device, dtype=torch.float32)
+
+                 outputs, edge_mask, ds_map = self.model(images)
+                 H, W = original_size
+
+                 for j in range(images.size(0)):
+                     mask = gt_to_tensor(masks[j])
+
+                     h, w = H[j].item(), W[j].item()
+
+                     output = F.interpolate(outputs[j].unsqueeze(0), size=(h, w), mode='bilinear')
+
+                     loss = self.criterion(output, mask)
+
+                     # Metric
+                     mae, max_f, avg_f, s_score = Eval_tool.cal_total_metrics(output, mask)
+
+                     # log
+                     test_loss.update(loss.item(), n=1)
+                     test_mae.update(mae, n=1)
+                     test_maxf.update(max_f, n=1)
+                     test_avgf.update(avg_f, n=1)
+                     test_s_m.update(s_score, n=1)
+
+             test_loss = test_loss.avg
+             test_mae = test_mae.avg
+             test_maxf = test_maxf.avg
+             test_avgf = test_avgf.avg
+             test_s_m = test_s_m.avg
+
+         return test_loss, test_mae, test_maxf, test_avgf, test_s_m
+
+
+ class Tester():
+     def __init__(self, args, save_path):
+         super(Tester, self).__init__()
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.test_transform = get_test_augmentation(img_size=args.img_size)
+         self.args = args
+         self.save_path = save_path
+
+         # Network
+         self.model = TRACER(args).to(self.device)
+         if args.multi_gpu:
+             self.model = nn.DataParallel(self.model).to(self.device)
+
+         path = os.path.join(save_path, 'best_model.pth')
+         self.model.load_state_dict(torch.load(path))
+         print('###### pre-trained Model restored #####')
+
+         self.criterion = Criterion(args)
+
+         te_img_folder = os.path.join(args.data_path, args.dataset, 'Test/images/')
+         te_gt_folder = os.path.join(args.data_path, args.dataset, 'Test/masks/')
+
+         self.test_loader = get_loader(te_img_folder, te_gt_folder, edge_folder=None, phase='test',
+                                       batch_size=args.batch_size, shuffle=False,
+                                       num_workers=args.num_workers, transform=self.test_transform)
+
+         if args.save_map is not None:
+             os.makedirs(os.path.join('mask', 'exp' + str(self.args.exp_num), self.args.dataset), exist_ok=True)
+
+     def test(self):
+         self.model.eval()
+         test_loss = AvgMeter()
+         test_mae = AvgMeter()
+         test_maxf = AvgMeter()
+         test_avgf = AvgMeter()
+         test_s_m = AvgMeter()
+         t = time.time()
+
+         Eval_tool = Evaluation_metrics(self.args.dataset, self.device)
+
+         with torch.no_grad():
+             for i, (images, masks, original_size, image_name) in enumerate(tqdm(self.test_loader)):
+                 images = torch.tensor(images, device=self.device, dtype=torch.float32)
+
+                 outputs, edge_mask, ds_map = self.model(images)
+                 H, W = original_size
+
+                 for j in range(images.size(0)):
+                     mask = gt_to_tensor(masks[j])
+                     h, w = H[j].item(), W[j].item()
+
+                     output = F.interpolate(outputs[j].unsqueeze(0), size=(h, w), mode='bilinear')
+                     loss = self.criterion(output, mask)
+
+                     # Metric
+                     mae, max_f, avg_f, s_score = Eval_tool.cal_total_metrics(output, mask)
+
+                     # Save prediction map
+                     if self.args.save_map is not None:
+                         output = (output.squeeze().detach().cpu().numpy() * 255.0).astype(np.uint8)  # convert uint8 type
+                         cv2.imwrite(os.path.join('mask', 'exp' + str(self.args.exp_num), self.args.dataset, image_name[j] + '.png'), output)
+
+                     # log
+                     test_loss.update(loss.item(), n=1)
+                     test_mae.update(mae, n=1)
+                     test_maxf.update(max_f, n=1)
+                     test_avgf.update(avg_f, n=1)
+                     test_s_m.update(s_score, n=1)
+
+             test_loss = test_loss.avg
+             test_mae = test_mae.avg
+             test_maxf = test_maxf.avg
+             test_avgf = test_avgf.avg
+             test_s_m = test_s_m.avg
+
+         print(f'Test Loss:{test_loss:.4f} | MAX_F:{test_maxf:.4f} | MAE:{test_mae:.4f} '
+               f'| S_Measure:{test_s_m:.4f}, time: {time.time() - t:.3f}s')
+
+         return test_loss, test_mae, test_maxf, test_avgf, test_s_m
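Both Trainer.training and Trainer.validate combine five loss terms the same way: the final map, the three deep-supervision maps, and the edge map each pass through the criterion, and the results are summed. The sketch below isolates that combination; plain BCE stands in for util.losses.Criterion, which is an assumption (the repo's criterion may weight pixels differently).

import torch
import torch.nn.functional as F

def tracer_loss(outputs, ds_map, edge_mask, masks, edges):
    # Five-term sum mirroring trainer.py: final map + ds_map[0..2] + edge map.
    loss = F.binary_cross_entropy(outputs, masks)
    for ds in ds_map:
        loss = loss + F.binary_cross_entropy(ds, masks)
    return loss + F.binary_cross_entropy(edge_mask, edges)

pred = torch.rand(2, 1, 320, 320)
gt = (torch.rand(2, 1, 320, 320) > 0.5).float()
print(tracer_loss(pred, [pred, pred, pred], pred, gt, gt).item())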