dxue321 committed
Commit c3a7f7f · 1 Parent(s): 15268d6

initial upload
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ model-40 filter=lfs diff=lfs merge=lfs -text
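The four added patterns route image files and the `model-40` blob through Git LFS. As a rough illustration of which names the new patterns catch, Python's `fnmatch` can approximate the glob matching (this is an approximation, not git's exact matcher — real gitattributes matching has additional path and character-class rules):

```python
from fnmatch import fnmatch

# Patterns newly tracked by Git LFS in this commit.
new_lfs_patterns = ["*.jpg", "*.jpeg", "*.png", "model-40"]

def tracked_by_lfs(filename):
    # Rough approximation of gitattributes glob matching via fnmatch.
    return any(fnmatch(filename, p) for p in new_lfs_patterns)
```

For example, `tracked_by_lfs("examples/flower/001.jpg")` holds, while `tracked_by_lfs("app.py")` does not.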
WB_sRGB/LICENSE.md ADDED
@@ -0,0 +1,438 @@
+ Attribution-NonCommercial-ShareAlike 4.0 International
+
+ =======================================================================
+
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
+ does not provide legal services or legal advice. Distribution of
+ Creative Commons public licenses does not create a lawyer-client or
+ other relationship. Creative Commons makes its licenses and related
+ information available on an "as-is" basis. Creative Commons gives no
+ warranties regarding its licenses, any material licensed under their
+ terms and conditions, or any related information. Creative Commons
+ disclaims all liability for damages resulting from their use to the
+ fullest extent possible.
+
+ Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and
+ conditions that creators and other rights holders may use to share
+ original works of authorship and other material subject to copyright
+ and certain other rights specified in the public license below. The
+ following considerations are for informational purposes only, are not
+ exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+ =======================================================================
+
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+ Public License
+
+ By exercising the Licensed Rights (defined below), You accept and agree
+ to be bound by the terms and conditions of this Creative Commons
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
+ ("Public License"). To the extent this Public License may be
+ interpreted as a contract, You are granted the Licensed Rights in
+ consideration of Your acceptance of these terms and conditions, and the
+ Licensor grants You such rights in consideration of benefits the
+ Licensor receives from making the Licensed Material available under
+ these terms and conditions.
+
+
+ Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+ Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+ Section 3 -- License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the
+ following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+ Section 4 -- Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that
+ apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not
+ replace Your obligations under this Public License where the Licensed
+ Rights include other Copyright and Similar Rights.
+
+
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+ Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+ Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+ Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+ =======================================================================
+
+ Creative Commons is not a party to its public
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
+ its public licenses to material it publishes and in those instances
+ will be considered the "Licensor." The text of the Creative Commons
+ public licenses is dedicated to the public domain under the CC0 Public
+ Domain Dedication. Except for the limited purpose of indicating that
+ material is shared under a Creative Commons public license or as
+ otherwise permitted by the Creative Commons policies published at
+ creativecommons.org/policies, Creative Commons does not authorize the
+ use of the trademark "Creative Commons" or any other trademark or logo
+ of Creative Commons without its prior written consent including,
+ without limitation, in connection with any unauthorized modifications
+ to any of its public licenses or any other arrangements,
+ understandings, or agreements concerning use of licensed material. For
+ the avoidance of doubt, this paragraph does not form part of the
+ public licenses.
+
+ Creative Commons may be contacted at creativecommons.org.
+
WB_sRGB/classes/WBsRGB.py ADDED
@@ -0,0 +1,164 @@
+ ## White-balance model class
+ #
+ # Copyright (c) 2018-present, Mahmoud Afifi
+ # York University, Canada
+ # mafifi@eecs.yorku.ca | m.3afifi@gmail.com
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # All rights reserved.
+ #
+ # Please cite the following work if this program is used:
+ # Mahmoud Afifi, Brian Price, Scott Cohen, and Michael S. Brown,
+ # "When color constancy goes wrong: Correcting improperly white-balanced
+ # images", CVPR 2019.
+ #
+ ##########################################################################
+
+
+ import numpy as np
+ import numpy.matlib
+ import cv2
+
+
+ class WBsRGB:
+   def __init__(self, gamut_mapping=2, upgraded=0):
+     if upgraded == 1:
+       self.features = np.load('WB_sRGB/models/features+.npy')  # encoded features
+       self.mappingFuncs = np.load('WB_sRGB/models/mappingFuncs+.npy')  # correction funcs
+       self.encoderWeights = np.load('WB_sRGB/models/encoderWeights+.npy')  # PCA matrix
+       self.encoderBias = np.load('WB_sRGB/models/encoderBias+.npy')  # PCA bias
+       self.K = 75  # K value for nearest neighbor searching
+     else:
+       self.features = np.load('WB_sRGB/models/features.npy')  # encoded features
+       self.mappingFuncs = np.load('WB_sRGB/models/mappingFuncs.npy')  # correction funcs
+       self.encoderWeights = np.load('WB_sRGB/models/encoderWeights.npy')  # PCA matrix
+       self.encoderBias = np.load('WB_sRGB/models/encoderBias.npy')  # PCA bias
+       self.K = 25  # K value for nearest neighbor searching
+
+     self.sigma = 0.25  # fall-off factor for KNN blending
+     self.h = 60  # histogram bin width
+     # Our results are reported with gamut_mapping=2; however, gamut_mapping=1
+     # gives more compelling results with over-saturated examples.
+     self.gamut_mapping = gamut_mapping  # options: 1 = scaling, 2 = clipping
+
+   def encode(self, hist):
+     """ Generates a compacted feature of a given RGB-uv histogram tensor. """
+     histR_reshaped = np.reshape(np.transpose(hist[:, :, 0]),
+                                 (1, int(hist.size / 3)), order="F")
+     histG_reshaped = np.reshape(np.transpose(hist[:, :, 1]),
+                                 (1, int(hist.size / 3)), order="F")
+     histB_reshaped = np.reshape(np.transpose(hist[:, :, 2]),
+                                 (1, int(hist.size / 3)), order="F")
+     hist_reshaped = np.append(histR_reshaped,
+                               [histG_reshaped, histB_reshaped])
+     feature = np.dot(hist_reshaped - self.encoderBias.transpose(),
+                      self.encoderWeights)
+     return feature
+
+   def rgb_uv_hist(self, I):
+     """ Computes an RGB-uv histogram tensor. """
+     sz = np.shape(I)  # get size of current image
+     if sz[0] * sz[1] > 202500:  # resize if it is larger than 450 * 450
+       factor = np.sqrt(202500 / (sz[0] * sz[1]))  # rescale factor
+       newH = int(np.floor(sz[0] * factor))
+       newW = int(np.floor(sz[1] * factor))
+       I = cv2.resize(I, (newW, newH), interpolation=cv2.INTER_NEAREST)
+     I_reshaped = I[(I > 0).all(axis=2)]
+     eps = 6.4 / self.h
+     hist = np.zeros((self.h, self.h, 3))  # histogram will be stored here
+     Iy = np.linalg.norm(I_reshaped, axis=1)  # intensity vector
+     for i in range(3):  # for each histogram layer, do
+       r = []  # excluded channels will be stored here
+       for j in range(3):  # for each color channel, do
+         if j != i:
+           r.append(j)
+       Iu = np.log(I_reshaped[:, i] / I_reshaped[:, r[1]])
+       Iv = np.log(I_reshaped[:, i] / I_reshaped[:, r[0]])
+       hist[:, :, i], _, _ = np.histogram2d(
+         Iu, Iv, bins=self.h, range=((-3.2 - eps / 2, 3.2 - eps / 2),) * 2, weights=Iy)
+       norm_ = hist[:, :, i].sum()
+       hist[:, :, i] = np.sqrt(hist[:, :, i] / norm_)  # (hist/norm)^(1/2)
+     return hist
+
+   def correctImage(self, I):
+     """ White balances a given image I. """
+     # I = I[..., ::-1]  # convert from BGR to RGB  #donna
+     I = im2double(I)  # convert to double
+     # Converting I to float32 may speed up the process.
+     feature = self.encode(self.rgb_uv_hist(I))
+     # Equivalently:
+     #   feature_diff = self.features - feature
+     #   D_sq = np.einsum('ij,ij->i', feature_diff, feature_diff)[:, None]
+     D_sq = np.einsum(
+       'ij, ij ->i', self.features, self.features)[:, None] + np.einsum(
+       'ij, ij ->i', feature, feature) - 2 * self.features.dot(feature.T)
+
+     # get the smallest K distances
+     idH = D_sq.argpartition(self.K, axis=0)[:self.K]
+     mappingFuncs = np.squeeze(self.mappingFuncs[idH, :])
+     dH = np.sqrt(
+       np.take_along_axis(D_sq, idH, axis=0))
+     weightsH = np.exp(-(np.power(dH, 2)) /
+                       (2 * np.power(self.sigma, 2)))  # compute weights
+     weightsH = weightsH / sum(weightsH)  # normalize blending weights
+     mf = sum(np.matlib.repmat(weightsH, 1, 33) *
+              mappingFuncs, 0)  # compute the mapping function
+     mf = mf.reshape(11, 3, order="F")  # reshape it to be 11 * 3
+     I_corr = self.colorCorrection(I, mf)  # apply it!
+     return I_corr
+
+   def colorCorrection(self, input, m):
+     """ Applies a mapping function m to a given input image. """
+     sz = np.shape(input)  # get size of input image
+     I_reshaped = np.reshape(input, (int(input.size / 3), 3), order="F")
+     kernel_out = kernelP(I_reshaped)
+     out = np.dot(kernel_out, m)
+     if self.gamut_mapping == 1:
+       # scaling based on input image energy
+       out = normScaling(I_reshaped, out)
+     elif self.gamut_mapping == 2:
+       # clip out-of-gamut pixels
+       out = outOfGamutClipping(out)
+     else:
+       raise Exception('Wrong gamut_mapping value')
+     # reshape output image back to the original image shape
+     out = out.reshape(sz[0], sz[1], sz[2], order="F")
+     out = out.astype('float32')  # donna
+     # out = out.astype('float32')[..., ::-1]  # convert from BGR to RGB
+     return out
+
+
+ def normScaling(I, I_corr):
+   """ Scales each pixel based on original image energy. """
+   norm_I_corr = np.sqrt(np.sum(np.power(I_corr, 2), 1))
+   inds = norm_I_corr != 0
+   norm_I_corr = norm_I_corr[inds]
+   norm_I = np.sqrt(np.sum(np.power(I[inds, :], 2), 1))
+   I_corr[inds, :] = I_corr[inds, :] / np.tile(
+     norm_I_corr[:, np.newaxis], 3) * np.tile(norm_I[:, np.newaxis], 3)
+   return I_corr
+
+
+ def kernelP(rgb):
+   """ Kernel function: kernel(r, g, b) -> (r, g, b, rg, rb, gb, r^2, g^2, b^2, rgb, 1)
+       Ref: Hong, et al., "A study of digital camera colorimetric
+       characterization based on polynomial modeling." Color Research &
+       Application, 2001. """
+   r, g, b = np.split(rgb, 3, axis=1)
+   return np.concatenate(
+     [rgb, r * g, r * b, g * b, rgb ** 2, r * g * b, np.ones_like(r)], axis=1)
+
+
+ def outOfGamutClipping(I):
+   """ Clips out-of-gamut pixels. """
+   I[I > 1] = 1  # clip any pixel above 1 to 1
+   I[I < 0] = 0  # clip any pixel below 0 to 0
+   return I
+
+
+ def im2double(im):
+   """ Returns a double image [0,1] of the uint8 im [0,255]. """
+   return cv2.normalize(im.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX)
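The `D_sq` computation in `correctImage` above relies on the expansion ||a - b||² = ||a||² + ||b||² - 2·a·b, so the K nearest training features can be found without forming every pairwise difference explicitly. A small self-contained sketch of that identity, using random stand-in arrays (not the real model features, whose dimensions differ):

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.standard_normal((100, 8))  # stand-in for self.features
feature = rng.standard_normal((1, 8))     # stand-in for the query feature

# Expanded form, as in WBsRGB.correctImage.
D_sq = (np.einsum('ij,ij->i', features, features)[:, None]
        + np.einsum('ij,ij->i', feature, feature)
        - 2 * features.dot(feature.T))

# Naive reference: squared distance computed row by row.
D_sq_naive = np.sum((features - feature) ** 2, axis=1)[:, None]
assert np.allclose(D_sq, D_sq_naive)

# argpartition then selects the K smallest distances without a full sort.
K = 5
idH = D_sq.argpartition(K, axis=0)[:K]
```

`argpartition` runs in O(n) rather than O(n log n), which is why the class uses it instead of `argsort` for the K-nearest-neighbor lookup.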
WB_sRGB/models/encoderBias.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f8d77ec0baab6e45a5602cf9f394e7fde3720b8ca76f59c45d7807e0c49b7070
+ size 43328
WB_sRGB/models/encoderWeights.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff8e684dd18f0e82fb5b1f398b26eb356f7ba9acd8c81dc43e791f86f105659a
+ size 2376128
WB_sRGB/models/features.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76ec7f16aaf1cda53dfb1b3812fae88cb56e0157e3fe7e84138227ac1108a9d6
+ size 13757828
WB_sRGB/models/mappingFuncs.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d677ecbe0e958ca1f6d56deb9a3ae9ce66cfbb606e891c8e06e60d807de68b6b
+ size 8254748
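Each `.npy` entry above is committed as a Git LFS pointer file (the `version`/`oid`/`size` lines shown) rather than the binary weights themselves; the real blobs are fetched by `oid` at checkout. A minimal sketch of reading such a pointer, using the actual `encoderBias.npy` pointer content from this commit:

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into a dict of its 'key value' lines."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

# Pointer content of WB_sRGB/models/encoderBias.npy from this commit.
pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:f8d77ec0baab6e45a5602cf9f394e7fde3720b8ca76f59c45d7807e0c49b7070\n"
    "size 43328\n"
)
info = parse_lfs_pointer(pointer)
```

The `size` field is the byte length of the real file, so tools can verify a fetched blob against both `size` and the sha256 `oid`.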
app.py ADDED
@@ -0,0 +1,145 @@
+ import gradio as gr
+ from PIL import Image
+
+ from multi_image_process import compute_inp_palette, recolor_single_image, recolor_group_images
+
+
+ example_images = ["./examples/flower/001.jpg",
+                   "./examples/flower/002.jpg",
+                   "./examples/flower/003.jpg",
+                   "./examples/flower/004.jpg",
+                   "./examples/flower/005.jpg",
+                   ]
+
+
+ def swap_to_gallery(images):
+     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
+
+ def remove_back_to_files():
+     return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+
+ def show_palette(*colors):
+     # build HTML; each chosen color is rendered as a bordered block
+     color_blocks = ""
+     for color in colors:
+         if color:  # a color was chosen
+             color_blocks += f'<div style="width:40px;height:40px;background:{color};display:inline-block;margin:5px;border:1px solid #000;"></div>'
+     return color_blocks
+
+
+ def load_example_images():
+     images = [Image.open(p) for p in example_images]
+     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=example_images, visible=False)
+
+
+ if __name__ == "__main__":
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# Image recoloring with palette")
+
+         # multiple-image recoloring for color consistency
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### Inputs")
+                 image_input = gr.File(
+                     label="Drag (or select) more than one photo",
+                     file_types=["image"],
+                     file_count="multiple"
+                 )
+                 uploaded_files = gr.Gallery(label="Input images", visible=False, columns=7, rows=1, height=200)
+
+                 with gr.Column(visible=False) as clear_button:
+                     remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=image_input, size="sm")
+
+                 image_input.upload(fn=swap_to_gallery, inputs=image_input, outputs=[uploaded_files, clear_button, image_input])
+                 remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, image_input])
+
+         gr.Markdown("### Select the parameters for recoloring")
+         with gr.Row():
+             with gr.Group():
+                 gr.Markdown("### Recoloring without other techniques")
+                 num_center_grp = gr.Dropdown(choices=[1, 2, 3, 4, 5], value=3, label="Number of group palettes")
+             with gr.Group():
+                 gr.Markdown("### Recoloring with other techniques")
+                 with gr.Row():
+                     with gr.Group():
72
+ # gr.Markdown("### white balance")
73
+ checkbox_input_wb = gr.Checkbox(value=False, label="Apply white balance correction")
74
+ with gr.Group():
75
+ # gr.Markdown("### saliency detection")
76
+ checkbox_input_sal = gr.Checkbox(value=False, label="Apply saliency")
77
+ checkbox_input_recolor_sal = gr.Checkbox(value=False, label="Recolor salient part only")
78
+ checkbox_input_recolor_nonsal = gr.Checkbox(value=False, label="Recolor non-salient part only")
79
+ num_center_sal = gr.Dropdown(choices=[1, 2, 3], value=1, label="Number of salient palettes")
80
+ num_center_nonsal = gr.Dropdown(choices=[1, 2, 3], value=1, label="Number of non-salient palettes")
81
+ with gr.Group():
82
+ # gr.Markdown("### color naming")
83
+ checkbox_input_cn = gr.Checkbox(value=False, label="Apply color naming")
84
+ naming_thres = gr.Textbox(value=0.8, label="Threshold of color naming", placeholder=0.8)
85
+
86
+ with gr.Column():
87
+ gr.Markdown("### Outputs")
88
+ output_gallery_palette_in = gr.Gallery(label="Input image palettes", columns=7, rows=1, height=100)
89
+ output_gallery_palette_group= gr.Gallery(label="Group palette", columns=2, rows=1, height=100)
90
+ output_gallery_palette_out = gr.Gallery(label="Output image palettes", columns=7, rows=1, height=100)
91
+ output_gallery_recolor = gr.Gallery(label="Recolored images", columns=7, rows=1, height=300)
92
+
93
+ with gr.Row():
94
+ example_btn = gr.Button("Load Example Images")
95
+ example_btn.click(fn=load_example_images, outputs=[uploaded_files, clear_button, image_input])
96
+
97
+ palette_btn = gr.Button("Compute palette").click(
98
+ compute_inp_palette,
99
+ inputs=[image_input],
100
+ outputs=[output_gallery_palette_in]
101
+ )
102
+
103
+ # recoloring_multi_btn = gr.Button("Recoloring multiple images").click(
104
+ # multi_img_color_consist,
105
+ # inputs=[image_input, num_center_grp, checkbox_input_wb, checkbox_input_sal, checkbox_input_cn, naming_thres],
106
+ # outputs=[output_gallery_recolor, output_gallery_palette_in, output_gallery_palette_out, output_gallery_palette_group]
107
+ # )
108
+
109
+ recoloring_multi_btn = gr.Button("Recoloring multiple images").click(
110
+ recolor_group_images,
111
+ inputs=[image_input, num_center_grp, num_center_sal, num_center_nonsal,
112
+ checkbox_input_wb,
113
+ checkbox_input_sal, checkbox_input_recolor_sal, checkbox_input_recolor_nonsal,
114
+ checkbox_input_cn, naming_thres],
115
+ outputs=[output_gallery_recolor, output_gallery_palette_in, output_gallery_palette_out, output_gallery_palette_group]
116
+ )
117
+
118
+
119
+
120
+ # # single image recoloring with user-defined palette
121
+ # with gr.Row():
122
+ # with gr.Column():
123
+ # gr.Markdown("### Image color categorization")
124
+ # with gr.Row():
125
+ # color1 = gr.ColorPicker(label="Color 1", value="#ff0000")
126
+ # color2 = gr.ColorPicker(label="Color 2", value="#00ff00")
127
+ # color3 = gr.ColorPicker(label="Color 3", value="#0000ff")
128
+ # color4 = gr.ColorPicker(label="Color 4", value=None)
129
+ # color5 = gr.ColorPicker(label="Color 5", value=None)
130
+ # palette = gr.HTML()
131
+ # with gr.Row():
132
+ # btn = gr.Button("Show Picked Palette")
133
+ # btn.click(fn=show_palette, inputs=[color1, color2, color3, color4, color5], outputs=palette)
134
+
135
+ # with gr.Column():
136
+ # gr.Markdown("### Output image")
137
+ # output_gallery_recolor_single = gr.Gallery(label="Recolored image", columns=1, rows=1, height=100)
138
+ # recoloring_single_btn = gr.Button("Recoloring single images").click(
139
+ # recolor_single_image,
140
+ # inputs=[image_input, color1, color2, color3, color4, color5],
141
+ # outputs=[output_gallery_recolor_single]
142
+ # )
143
+
144
+
145
+ demo.launch()
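`show_palette` is plain string building, so it can be exercised outside Gradio; unset pickers (`None` or an empty string) are simply skipped:

```python
def show_palette(*colors):
    # each truthy color becomes one bordered inline HTML block
    color_blocks = ""
    for color in colors:
        if color:
            color_blocks += (f'<div style="width:40px;height:40px;background:{color};'
                             f'display:inline-block;margin:5px;border:1px solid #000;"></div>')
    return color_blocks

html = show_palette("#ff0000", None, "#00ff00")
print(html.count("<div"))  # 2: the unset picker is skipped
```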
color_naming/colornaming.py ADDED
@@ -0,0 +1,100 @@
+ import numpy as np
+ # import cv2
+ # from collections import Counter
+ from skimage.color import rgb2lab, lab2rgb
+
+
+ COLOR_NAME = ['black', 'brown', 'blue', 'gray', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
+
+
+ def im2c(img_lab, w2c):
+     """
+     Convert a Lab image to its color-name representation using the
+     color-name lookup matrix w2c (one probability row per quantized RGB cell).
+
+     Returns:
+         name_idx_img: per-pixel index into COLOR_NAME
+         color_nam: color names of the last processed row of pixels
+         color_img: image with each pixel replaced by its prototype color
+         prob_map: per-pixel probability over the 11 color names
+     """
+     img_lab = np.expand_dims(img_lab, axis=0)
+     im = lab2rgb(img_lab) * 255
+
+     # Prototype RGB values for the 11 color names
+     color_values = np.array([[  0,   0,   0],
+                              [165,  81,  43],
+                              [  0,   0, 255],
+                              [127, 127, 127],
+                              [  0, 255,   0],
+                              [255, 127,   0],
+                              [255, 165, 216],
+                              [191,   0, 191],
+                              [255,   0,   0],
+                              [255, 255, 255],
+                              [255, 255,   0]], dtype=np.uint8)
+
+     # Compute the index into the w2c lookup table (8-unit bins per channel)
+     index_im = ((im[:, :, 0].flatten() // 8) + 32 * (im[:, :, 1].flatten() // 8) + 32 * 32 * (im[:, :, 2].flatten() // 8)).astype(np.int32)
+
+     prob_map = w2c[index_im, :].reshape((im.shape[0], im.shape[1], w2c.shape[1]))
+     name_idx_img = np.argmax(prob_map, axis=2)
+
+     color_img = np.zeros_like(im).astype(np.uint8)
+     color_nam = [0 for i in range(np.size(im, 1))]
+
+     for jj in range(im.shape[0]):
+         for ii in range(im.shape[1]):
+             color_img[jj, ii, :] = np.array(color_values[name_idx_img[jj, ii]])
+             color_nam[ii] = COLOR_NAME[name_idx_img[jj, ii]]
+
+     return name_idx_img, color_nam, color_img, prob_map
+
+
+ def compare_color_name(img_src, img_tgt, w2c, threshold=0.9):
+     color_label_org, color_nam_org, _, prob_map_org = im2c(img_src, w2c)
+     color_label_new, color_nam_new, _, prob_map_new = im2c(img_tgt, w2c)
+
+     if threshold == 0:
+         is_same_color = (color_label_org == color_label_new)
+     else:
+         diff = np.zeros_like(color_label_org).astype(np.float64)
+         for jj in range(np.size(color_label_org, 0)):
+             for ii in range(np.size(color_label_org, 1)):
+                 # the difference could also be the L1/L2 distance or the KL
+                 # divergence between the two probability distributions
+                 diff[jj, ii] = np.linalg.norm(prob_map_org[jj, ii, :] - prob_map_new[jj, ii, :])
+         is_same_color = (np.abs(diff) < threshold)
+
+     return is_same_color
+
+
+ # if __name__ == "__main__":
+ #     w2c = np.load('w2c11_joost_c.npy').astype(np.float16)
+
+ #     image_path = './test.jpg'
+ #     img = cv2.imread(image_path).astype(np.float32)
+ #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ #     name_idx_img, color_nam, color_img, prob_map = im2c(img, w2c)
+
+ #     filtered_counts = Counter(name_idx_img[name_idx_img <= 10])
+ #     sorted_counts = sorted(filtered_counts.items(), key=lambda x: x[1], reverse=True)
+ #     top_3_values = [num for num, count in sorted_counts[:3]]
+ #     top_3_colors = [COLOR_NAME[i] for i in top_3_values]
+
+ #     print("Top 3 colors:", top_3_colors)
+
+ #     cv2.imwrite('./colormap_joost.jpg', cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB))
color_naming/w2c11_joost_c.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bc5ede48fed749d552edf1fd5d058266c713f20af3d4a484310103be91766856
+ size 2883712
examples/flower/001.jpg ADDED

Git LFS Details

  • SHA256: 92b662971220120496c4ebcc7136b0f0dc6a045ef2ed7be58322b09c46df0131
  • Pointer size: 131 Bytes
  • Size of remote file: 223 kB
examples/flower/002.jpg ADDED

Git LFS Details

  • SHA256: 299f685f4b50ef9cf67b670a17695865707be5dc00bd9e89b889863fc15c4ef2
  • Pointer size: 130 Bytes
  • Size of remote file: 88.5 kB
examples/flower/003.jpg ADDED

Git LFS Details

  • SHA256: cae11d1127737a21574c77403eb84419bde3d6e76a24a86b9ff6df6b2cb7911d
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
examples/flower/004.jpg ADDED

Git LFS Details

  • SHA256: d0a690494b6b22e4f2222777333406a6dd38a1dfebd5b803d11bc0173596d716
  • Pointer size: 130 Bytes
  • Size of remote file: 92.6 kB
examples/flower/005.jpg ADDED

Git LFS Details

  • SHA256: 05a2e76fe802b74ed20b9922020af71a4af17e3c8e526a0fa0c5ffad3281236b
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
examples/landmark/01.jpg ADDED

Git LFS Details

  • SHA256: 5cf61c53c134a43394210da33d621bb09fe2698978641e6babea83cce4ac4cd0
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
examples/landmark/02.jpg ADDED

Git LFS Details

  • SHA256: e342a4bf188dadbc6031f57a22031fd9a956dc8d871e9cbd49af4f558648c838
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
examples/landmark/03.jpg ADDED

Git LFS Details

  • SHA256: fde2ee9aa26c9641fa19c56dd33520e22f1faa11bd43a2bee1b00c5e92faa9da
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
examples/landmark/04.jpg ADDED

Git LFS Details

  • SHA256: bf4b014c7ed2e06b8e076c2e30f25444521dd2f08d23249873f64084f35ad6bf
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
examples/landmark/05.jpg ADDED

Git LFS Details

  • SHA256: 14ba883f33ea8937459636d250e3cc10728abd0fbe180f0816845a61de943b6c
  • Pointer size: 131 Bytes
  • Size of remote file: 199 kB
examples/landmark/06.jpg ADDED

Git LFS Details

  • SHA256: b21bc8d57ea549a717764bbbbf3dbcb113358afcf516abee9892d17f1b831e3c
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/landmark/07.jpg ADDED

Git LFS Details

  • SHA256: 12f80822ff53c70930f9503bcc2bbf489f1f147a16475ed6220bad24999b83d6
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
examples/portrait/image-00000.png ADDED

Git LFS Details

  • SHA256: ae769840f058bb07e409969bba1380eae8540e529752542fd9668f8920a5f411
  • Pointer size: 131 Bytes
  • Size of remote file: 312 kB
examples/portrait/image-00002.png ADDED

Git LFS Details

  • SHA256: b2d68a2fff9bee3f4da9fbd451ea169d2f4517e6689287fc26c2d3c58126cb46
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
examples/portrait/image-00004.png ADDED

Git LFS Details

  • SHA256: 54b7fb4732eb36aaffe58fbf27b2201b3029cce074c1f1298341ea98137e21da
  • Pointer size: 131 Bytes
  • Size of remote file: 770 kB
examples/portrait/image-00006.png ADDED

Git LFS Details

  • SHA256: 4741c26129164cd0a3a003125113e7860d03431fd4956684dd48ac0d921ac8c3
  • Pointer size: 131 Bytes
  • Size of remote file: 378 kB
examples/portrait/image-00014.png ADDED

Git LFS Details

  • SHA256: a35b8688c77a04710e552d4846da6c0d51320113de95a67218183afad0a9f70b
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
extract_palette.py ADDED
@@ -0,0 +1,139 @@
+ import numpy as np
+ # import pandas as pd
+
+ from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances
+
+ from collections import Counter
+
+
+ def histogram(img_lab, bin, mode=2, mask=None):
+     if mask is None:
+         mask = np.ones_like(img_lab[:, :, 0])
+
+     if img_lab.ndim != 2:
+         img_lab = img_lab.reshape(-1, 3)
+
+     mask = mask.flatten()
+     img_lab_masked = img_lab[mask == 1]
+
+     if mode == 3:
+         hist, edges = np.histogramdd(img_lab_masked, bins=bin)
+         xpos, ypos, zpos = np.meshgrid(edges[0][:-1], edges[1][:-1], edges[2][:-1], indexing="ij")
+         hist_samples = np.concatenate((xpos.reshape((bin*bin*bin, 1)), ypos.reshape((bin*bin*bin, 1)), zpos.reshape((bin*bin*bin, 1))), axis=1)
+         hist_counts = hist.reshape(bin*bin*bin)
+
+     elif mode == 2:
+         hist, xedges, yedges = np.histogram2d(img_lab_masked[:, 1], img_lab_masked[:, 2], bins=bin, range=None)
+         xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing="ij")
+         hist_samples = np.concatenate((xpos.reshape((bin*bin, 1)), ypos.reshape((bin*bin, 1))), axis=1)
+         hist_counts = hist.reshape(bin*bin)
+
+     return hist_samples, hist_counts
+
+
+ def palette_extraction(img_lab, hist_samples, hist_counts, mode=2, threshold=0.93, num_clusters=5, mask=None):
+
+     if mask is None:
+         mask = np.ones_like(img_lab[:, :, 0])
+
+     if img_lab.ndim != 2:
+         img_lab = img_lab.reshape(-1, 3)
+
+     mask = mask.flatten()
+
+     hist_densities = hist_counts / np.sum(hist_counts)
+
+     # ---------------------- palette extraction ----------------------
+     # initial cluster centers
+     index = np.argwhere(hist_densities != 0)
+     index = np.squeeze(index, axis=(1,))
+     num_nonzero = np.size(index)
+
+     # clustering method ported from the MATLAB code
+     inits_all = []
+     Cold = np.mean(hist_samples[index, :], 0)
+     distortion = np.zeros((num_clusters, 1))
+
+     dist = pairwise_distances(hist_samples[index, :], np.expand_dims(Cold, axis=0), metric='euclidean')
+     distortion[0] = np.sum(hist_densities[index] * np.squeeze(dist**2, axis=1), 0)
+
+     inits_all.append(Cold)
+
+     for k in range(1, num_clusters):
+         # initialize the cluster centers with a density-weighted
+         # farthest-point heuristic
+         k = k + 1
+         cinits = np.zeros((k, mode))
+         cw = hist_densities[index]
+         for i in range(k):
+             id = np.argmax(cw)
+             cinits[i, :] = hist_samples[index, :][id, :]
+             d2 = cinits[i, :] * np.ones((num_nonzero, 1)) - hist_samples[index, :]
+             d2 = np.sum(np.square(d2), axis=1)
+             d2 = d2 / np.max(d2)
+             cw = cw * (d2**2)
+
+         inits_all.append(cinits)
+         kmeans = KMeans(n_clusters=k, init=cinits, n_init=1).fit(
+             hist_samples[index, :], y=None, sample_weight=hist_densities[index])
+
+         dist_point = pairwise_distances(hist_samples[index, :], kmeans.cluster_centers_, metric='euclidean')
+         distortion[k-1] = np.sum(hist_densities[index] * np.min(dist_point, axis=1)**2)
+
+     # pick the smallest number of clusters that explains the requested
+     # fraction of the total distortion reduction
+     variance = distortion[:-1] - distortion[1:]
+     distortion_percent = np.cumsum(variance) / (distortion[0] - distortion[-1])
+
+     r = np.argwhere(distortion_percent > threshold)
+     num_clusters_opt = np.min(r) + 2
+
+     kmeans_f = KMeans(n_clusters=num_clusters_opt, init=inits_all[num_clusters_opt-1], n_init=1).fit(
+         hist_samples[index, :], y=None, sample_weight=hist_densities[index])
+     cluster_centers = kmeans_f.cluster_centers_
+
+     if mode == 3:
+         img_labels = kmeans_f.predict(img_lab)
+     elif mode == 2:
+         img_labels = kmeans_f.predict(img_lab[:, 1:3])
+
+     hist_labels = kmeans_f.predict(hist_samples)
+
+     img_labels[mask == 0] = 255
+     c_densities = np.zeros(num_clusters_opt)
+
+     counts = Counter(img_labels)
+     for key in np.unique(img_labels):
+         if key == 255:
+             continue
+         c_densities[key] = counts.get(key)
+
+     c_densities = c_densities / np.sum(c_densities)
+
+     return cluster_centers, c_densities, img_labels, hist_labels
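In `mode=2` the histogram lives on the (a*, b*) chroma plane: each histogram cell contributes one sample at its lower bin corner plus a count. That pairing can be sketched in isolation with toy data (`bins=4` here, versus the module's default of 16):

```python
import numpy as np

# four toy (a*, b*) chroma values, two of them identical
ab = np.array([[0.0, 0.0], [10.0, 10.0], [10.0, 10.0], [-5.0, 3.0]])
bins = 4

# same recipe as histogram(mode=2): bin the two channels jointly,
# then flatten the lower bin corners into an (bins*bins, 2) sample list
hist, xedges, yedges = np.histogram2d(ab[:, 0], ab[:, 1], bins=bins)
xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing="ij")
samples = np.stack([xpos.ravel(), ypos.ravel()], axis=1)
counts = hist.ravel()

print(samples.shape, int(counts.sum()))  # (16, 2) 4
```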
image.py ADDED
@@ -0,0 +1,242 @@
+ # import cv2
+ import os
+ import numpy as np
+
+ from PIL import Image
+ from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
+ from WB_sRGB.classes import WBsRGB as wb_srgb
+ from extract_palette import histogram, palette_extraction
+ from saliency.LDF.infer import Saliency_LDF
+ from saliency.fast_saliency import get_saliency_ft, get_saliency_mbd
+ from utils import color_difference
+
+
+ class BaseImage:
+     def __init__(self, filepath):
+         self.filename = os.path.basename(filepath.name)
+         self.image = Image.open(filepath)
+         self.img_rgb = np.asarray(self.image).astype(dtype=np.uint8)
+         self.img_lab = rgb2lab(self.img_rgb)
+
+         self.bin_size = 16
+         self.mode = 2
+         self.hist_harmonization = False
+         self.template = 'L'
+         self.distortion_threshold = 0.93
+         self.num_center_ind = 7
+         self.lightness = 70
+         self.cdiff_threshold = 30
+         self.sal_threshold = 0.9
+         self.applied_wb = False
+
+     def inital_info(self, if_correct_wb, if_saliency, wb_thres, sal_thres, valid_class):
+         self.hist_value, self.hist_count, \
+             self.c_center, self.c_density, \
+             self.c_img_label, self.sal_links = self.extract_salient_palette(if_wb=if_correct_wb,
+                                                                            if_saliency=if_saliency,
+                                                                            wb_thres=wb_thres,
+                                                                            sal_thres=sal_thres,
+                                                                            valid_class=valid_class)
+
+         self.label_colored = self.cal_color_segment()
+
+     def get_rgb_image(self):
+         return self.img_rgb
+
+     def get_lab_image(self):
+         return self.img_lab
+
+     def get_wb_image(self):
+         self.img_wb = self.white_balance_correction()
+         return self.img_wb
+
+     def get_saliency(self):
+         self.sal_map = self.saliency_detection(self.img_rgb)
+         return self.sal_map
+
+     def get_color_segment(self):
+         return self.label_colored
+
+     def get_label(self):
+         return self.colorlabel
+
+     def cal_color_segment(self):
+         label_colored = np.zeros_like(self.img_rgb, dtype=np.float64)
+         for id_color in range(np.size(self.center, 0)):
+             label_colored[self.colorlabel == id_color] = self.center[id_color, :]
+         label_colored = lab2rgb(label_colored)
+         label_colored = np.round(label_colored * 255).astype(np.uint8)
+         return label_colored
+
+     def white_balance_correction(self):
+         # use upgraded_model = 1 to load the model upgraded with new
+         # training examples
+         upgraded_model = 2
+         # use gamut_mapping = 1 for scaling, 2 for clipping (the WB-sRGB
+         # paper reports results with clipping); if the image is
+         # over-saturated, scaling is recommended
+         gamut_mapping = 2
+         # create an instance of the WB model and white-balance the image
+         wbModel = wb_srgb.WBsRGB(gamut_mapping=gamut_mapping,
+                                  upgraded=upgraded_model)
+         img_wb = wbModel.correctImage(self.img_rgb)
+         image_wb = (img_wb * 255).astype(np.uint8)
+         return image_wb
+
+     def saliency_detection(self, img_rgb, method='LDF'):
+         if method == 'LDF':
+             get_saliency_LDF = Saliency_LDF()
+             sal_map = get_saliency_LDF.inference(img_rgb)
+         elif method == 'ft':
+             sal_map = get_saliency_ft(img_rgb)
+         elif method == 'rbd':
+             sal_map = get_saliency_mbd(img_rgb)
+         return sal_map
+
+     def solve_ind_palette(self, img_rgb, mask_binary=None):
+         w, h, c = img_rgb.shape
+         img_lab = rgb2lab(img_rgb)
+
+         hist_value, hist_count = histogram(img_lab, self.bin_size, mode=self.mode, mask=mask_binary)
+
+         # extract the palette
+         c_center, c_density, c_img_label, histlabel = palette_extraction(img_lab, hist_value, hist_count,
+                                                                         threshold=self.distortion_threshold,
+                                                                         num_clusters=self.num_center_ind,
+                                                                         mode=self.mode,
+                                                                         mask=mask_binary)
+
+         if self.mode == 2:
+             c_center = np.insert(c_center, 0, values=self.lightness, axis=1)
+
+         c_img_label = np.reshape(c_img_label, (w, h))
+
+         return hist_value, hist_count, c_center, c_density, c_img_label, histlabel
+
+     def extract_salient_palette(self, if_wb=False, if_saliency=False, wb_thres=5, sal_thres=0.9, valid_class=[0, 1]):
+
+         img_rgb = self.img_rgb.copy()
+         if if_wb:
+             self.img_wb = self.white_balance_correction()
+             img_wb = self.img_wb
+             dE = color_difference(img_rgb, img_wb)
+             if dE > wb_thres:
+                 self.applied_wb = True
+                 img_rgb = img_wb
+                 print('use white balance correction on {}'.format(self.filename))
+
+         hist_value, hist_count, center, density, colorlabel, histlabel = self.solve_ind_palette(img_rgb, mask_binary=None)
+         self.center = center
+         self.colorlabel = colorlabel
+
+         sal_links = [i for i in range(np.size(center, axis=0))]
+
+         if not if_saliency:
+             return hist_value, hist_count, center, density, colorlabel, sal_links
+
+         else:
+             self.sal_map = self.saliency_detection(self.img_rgb)
+             label_sem = np.zeros_like(self.img_rgb[:, :, 0])
+             label_sem[self.sal_map > sal_thres] = 1
+
+             # assign each palette color to the class (salient / non-salient)
+             # in which it occurs most frequently
+             p_feq = np.zeros((len(valid_class), np.size(center, axis=0)))
+
+             for id_cls, cls in enumerate(valid_class):
+                 label_binary = np.zeros_like(label_sem)
+                 label_binary[label_sem == cls] = 1
+                 colorlabel_cls = colorlabel[label_binary == 1]
+                 value, count = np.unique(colorlabel_cls, return_counts=True)
+                 p_feq[id_cls, value] = count / count.sum()
+
+             palettelabel = np.argmax(p_feq, axis=0)
+
+             class_num = len(valid_class)
+             c_center = [np.array([]) for i in range(class_num)]
+             c_density = [np.array([]) for i in range(class_num)]
+             c_img_label = [np.array([]) for i in range(class_num)]
+             hist_samples = [np.array([]) for i in range(class_num)]
+             hist_counts = [np.array([]) for i in range(class_num)]
+             mapping = [np.array([]) for i in range(class_num)]
+
+             for id_cls, cls in enumerate(valid_class):
+                 mapping[id_cls] = np.argwhere(palettelabel == id_cls).flatten()
+                 c_center[id_cls] = center[mapping[id_cls], :]
+                 c_density[id_cls] = density[mapping[id_cls]]
+                 hist_samples[id_cls] = hist_value.copy()
+                 hist_counts[id_cls] = hist_count.copy()
+                 hist_counts[id_cls][histlabel != id_cls] = 0
+
+                 for idx, label in enumerate(mapping[id_cls]):
+                     labels = np.zeros_like(colorlabel)
+                     labels[colorlabel == label] = idx
+                     c_img_label[id_cls] = labels
+
+             sal_links = np.hstack((mapping[1], mapping[0]))
+
+             return hist_samples, hist_counts, c_center, c_density, c_img_label, sal_links
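The core of `extract_salient_palette` is the `p_feq` / `argmax` step: each palette color is assigned to the class (non-salient 0 or salient 1) in which it occurs most frequently, relative to that class's pixel count. A toy-sized sketch of just that assignment (all arrays made up for illustration):

```python
import numpy as np

colorlabel = np.array([[0, 0, 1], [1, 2, 2]])   # per-pixel palette index
label_sem  = np.array([[0, 0, 1], [1, 1, 1]])   # per-pixel class: 0 = non-salient, 1 = salient
valid_class = [0, 1]
num_colors = 3

# relative frequency of each palette color within each class
p_feq = np.zeros((len(valid_class), num_colors))
for id_cls, cls in enumerate(valid_class):
    labels_in_cls = colorlabel[label_sem == cls]
    value, count = np.unique(labels_in_cls, return_counts=True)
    p_feq[id_cls, value] = count / count.sum()

# each palette color goes to the class where it is relatively most frequent
palettelabel = np.argmax(p_feq, axis=0)
print(palettelabel.tolist())  # [0, 1, 1]
```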
multi_image_process.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+
5
+ from skimage.color import rgb2lab
6
+
7
+ from solve_group_palette import solve_group_palette
8
+ from recolor import lab_transfer
9
+ from color_naming.colornaming import compare_color_name
10
+ from utils import visualize_palette
11
+ from image import BaseImage
12
+
13
+
14
+ # recolor single image with given palette TO DO
15
+ def recolor_single_image(image, save_dir='./results/testing'):
16
+ image = image
17
+
18
+ return image
19
+
20
+
21
+ # compute and save the palette for each image
22
+ def compute_inp_palette(images, save_dir='./results/testing'):
23
+ palette_all = [i for i in range(len(images))]
24
+
25
+ for img_id, img in enumerate(images):
26
+
27
+ img_name = os.path.basename(img.name)
28
+ print('processing image {}: {}...'.format(img_id, img_name))
29
+ image = BaseImage(img)
30
+ _, _, c_center, _, _, _ = image.solve_ind_palette(image.img_rgb, mask_binary=None)
31
+ # print(c_center)
32
+
33
+ if not os.path.exists(save_dir):
34
+ os.makedirs(save_dir)
35
+ imwrite_path = os.path.join(save_dir, 'palette_'+ img_name[:-4]+'.png')
36
+ img_palette = visualize_palette(c_center, patch_size=20)
37
+ img_palette = np.round(img_palette*255).astype(np.uint8)
38
+ cv2.imwrite(imwrite_path, cv2.cvtColor(img_palette, cv2.COLOR_RGB2BGR))
39
+
40
+ palette_all[img_id] = img_palette.copy()
41
+ return palette_all
42
+
43
+
44
+ # match the palette with the group palette
45
+ def match_palette(palette_ind, palette_grp, L_idx, if_cn, naming_thres):
46
+ # print(L_idx)
47
+ # print(palette_ind)
48
+ palette_mapped = palette_ind.copy()
49
+ valid = (L_idx > 0)
50
+ # print('L_idx:', L_idx)
51
+ if L_idx.size == 0:
52
+ return palette_mapped, L_idx
53
+ valid = valid.flatten()
54
+ idx = np.argwhere(L_idx.flatten() > 0).flatten()
55
+
56
+ palette_mapped[idx, :] = palette_grp[L_idx[valid].flatten()-1, :]
57
+
58
+ if if_cn:
59
+ # print("Check the matching colors with color naming")
60
+ # print(L_idx, palette_mapped)
61
+ w2c = np.load('./color_naming/w2c11_joost_c.npy').astype(np.float16)
62
+ is_the_same = compare_color_name(palette_ind, palette_mapped, w2c, threshold=naming_thres)
63
+ mask = (is_the_same == True).T
64
+ idx = np.argwhere(is_the_same == False).flatten()
65
+ L_idx = L_idx * mask
66
+ palette_mapped[idx, :] = palette_ind[idx, :]
67
+ # print(L_idx, palette_mapped)
68
+ return palette_mapped, L_idx
69
+
70
+
+
+ # solve the group palette without saliency
+ def solve_grp_palette(images, mode, bin_size,
+                       if_wb=False, wb_thres=20, num_center=5,
+                       lightness=70., eta=1e10, gamma=0, iteration=10,
+                       if_cn=False, naming_thres=0.5):
+
+     num_img = len(images)
+     c_center = [np.array([]) for i in range(num_img)]
+     c_density = [np.array([]) for i in range(num_img)]
+     palette_map = [np.array([]) for i in range(num_img)]
+     L_idx = [np.array([]) for i in range(num_img)]
+
+     if mode == 3:
+         hist_samples_all = np.zeros((bin_size**3, 3))
+         hist_counts_all = np.zeros(bin_size**3)
+     elif mode == 2:
+         hist_samples_all = np.zeros((bin_size**2, 2))
+         hist_counts_all = np.zeros(bin_size**2)
+
+     for img_id, image in enumerate(images):
+         image.inital_info(if_wb,
+                           False,
+                           wb_thres,
+                           0.9,
+                           [0, 1])
+
+         # accumulate the count-weighted histogram over all images
+         density = np.tile(image.hist_count, (mode, 1))
+         hist_counts_all = hist_counts_all + image.hist_count
+         hist_samples_all = hist_samples_all + density.T * image.hist_value
+
+     # keep only non-empty bins and turn the sums into weighted means
+     index = np.argwhere(hist_counts_all != 0)
+     index = np.squeeze(index, axis=(1,))
+
+     hist_counts_all = hist_counts_all[index]
+     hist_samples_all = hist_samples_all[index, :] / np.expand_dims(hist_counts_all, axis=1)
+
+     print('Solving the group palette...')
+     # take the number of palettes of each input image as the reference
+     reference = np.zeros((num_img, 1))
+     if np.sum(reference) == 0:
+         reference = reference + 1
+
+     num_palettes = 0
+     for i in range(num_img):
+         if reference[i]:
+             num_palettes = num_palettes + np.size(images[i].c_center, 1)
+
+     # calculate the initial group centers with k-means
+     m = np.minimum(int(num_center), num_palettes)
+
+     c_center = [image.c_center for image in images]
+     c_density = [image.c_density for image in images]
+
+     M, matching = solve_group_palette(hist_samples_all, hist_counts_all,
+                                       c_center, c_density, reference, m,
+                                       lightness=lightness, eta=eta,
+                                       gamma=gamma, iteration=iteration)
+
+     for img_id in range(num_img):
+         palette_map[img_id], L_idx[img_id] = match_palette(images[img_id].c_center, M, matching[img_id], if_cn, naming_thres)
+
+     return c_center, M, palette_map, L_idx
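The histogram-pooling step inside `solve_grp_palette` can be checked on toy data (variable names here are illustrative): per-image bin means are weighted by their counts, summed across images, empty bins are dropped, and the sums are renormalized into count-weighted mean colors.

```python
import numpy as np

# Two images, two bins, mode = 2 channels; per-bin mean colors and counts.
hist_value = [np.array([[0.1, 0.2], [0.5, 0.6]]),
              np.array([[0.3, 0.4], [0.0, 0.0]])]
hist_count = [np.array([2., 1.]),
              np.array([2., 0.])]

mode = 2
samples_all = np.zeros((2, mode))
counts_all = np.zeros(2)
for value, count in zip(hist_value, hist_count):
    density = np.tile(count, (mode, 1))   # broadcast counts over channels
    counts_all += count
    samples_all += density.T * value      # count-weighted sum of bin means

# keep non-empty bins, then divide to get the pooled weighted mean
index = np.squeeze(np.argwhere(counts_all != 0), axis=(1,))
counts_all = counts_all[index]
samples_all = samples_all[index, :] / np.expand_dims(counts_all, axis=1)
```

Bin 0 pools 2 pixels from each image, bin 1 only exists in the first image, so its mean is untouched.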
+
+
+
+ # solve the group palette with saliency
+ def solve_grp_palette_wsal(images, mode, bin_size,
+                            lightness=70., eta=1e10, gamma=0, iteration=10,
+                            if_wb=False, wb_thres=30,
+                            if_saliency=True, sal_thres=0.9, valid_class=[0, 1],
+                            sal_center=1, nonsal_center=1,
+                            recolor_nonsal_only=False, recolor_sal_only=False,
+                            if_cn=False, naming_thres=0.5):
+
+     class_num = len(valid_class)
+     num_img = len(images)
+
+     c_center = [np.array([]) for i in range(num_img)]
+     c_density = [np.array([]) for i in range(num_img)]
+
+     M = [np.array([]) for i in range(class_num)]
+     matching = [np.array([]) for i in range(class_num)]
+
+     palette_map = [[np.array([]) for i in range(num_img)] for i in range(class_num)]
+     L_idx = [[np.array([]) for i in range(num_img)] for i in range(class_num)]
+     p_src_sal_size = [np.array([]) for i in range(num_img)]
+
+     if mode == 3:
+         hist_samples_all = [np.zeros((bin_size**3, 3)) for i in range(class_num)]
+         hist_counts_all = [np.zeros(bin_size**3) for i in range(class_num)]
+     elif mode == 2:
+         hist_samples_all = [np.zeros((bin_size**2, 2)) for i in range(class_num)]
+         hist_counts_all = [np.zeros(bin_size**2) for i in range(class_num)]
+
+     for img_id, image in enumerate(images):
+         image.inital_info(if_wb,
+                           if_saliency,
+                           wb_thres,
+                           sal_thres,
+                           valid_class)
+
+         for id_cls, cls in enumerate(valid_class):
+             density = np.tile(image.hist_count[id_cls], (mode, 1))
+             hist_counts_all[id_cls] = hist_counts_all[id_cls] + image.hist_count[id_cls]
+             hist_samples_all[id_cls] = hist_samples_all[id_cls] + density.T * image.hist_value[id_cls]
+
+     for id_cls, cls in enumerate(valid_class):
+         index = np.argwhere(hist_counts_all[id_cls] != 0)
+         index = np.squeeze(index, axis=(1,))
+
+         hist_counts_all[id_cls] = hist_counts_all[id_cls][index]
+         hist_samples_all[id_cls] = hist_samples_all[id_cls][index, :] / np.expand_dims(hist_counts_all[id_cls], axis=1)
+
+         print('Solving the group palette...')
+         # take the number of palettes of each input image as the reference
+         reference = np.zeros((num_img, 1))
+         if np.sum(reference) == 0:
+             reference = reference + 1
+
+         num_palettes = 0
+         for i in range(num_img):
+             if reference[i] and images[i].c_center[id_cls].size != 0:
+                 num_palettes = num_palettes + np.size(images[i].c_center[id_cls], 1)
+
+         # calculate the initial group centers with k-means
+         m = [1, 1]
+         m[0] = np.minimum(int(sal_center), num_palettes)
+         m[1] = np.minimum(int(nonsal_center), num_palettes)
+
+         c_center = [image.c_center[id_cls] for image in images]
+         c_density = [image.c_density[id_cls] for image in images]
+
+         M[id_cls], matching[id_cls] = solve_group_palette(hist_samples_all[id_cls], hist_counts_all[id_cls],
+                                                           c_center, c_density, reference, m[id_cls],
+                                                           lightness=lightness, eta=eta,
+                                                           gamma=gamma, iteration=iteration)
+
+         for img_id in range(num_img):
+             palette_map[id_cls][img_id], L_idx[id_cls][img_id] = match_palette(images[img_id].c_center[id_cls], M[id_cls], matching[id_cls][img_id], if_cn, naming_thres)
+             if id_cls == 1:
+                 p_src_sal_size[img_id] = np.size(palette_map[1][img_id], axis=0)
+
+     p_grp_sal_size = np.size(M[1], axis=0)
+     p_nsal_size = np.size(M[0], axis=0)
+
+     cls_keep = None
+     if recolor_nonsal_only:
+         print('only recolor the non-salient region')
+         cls_keep = 0
+     elif recolor_sal_only:
+         print('only recolor the salient region')
+         cls_keep = 1
+
+     # for the class that should not be recolored, reset its palette to the
+     # original per-image colors and mark the indices as unmatched
+     if cls_keep is not None:
+         for img_id in range(num_img):
+             if L_idx[cls_keep][img_id].size == 0:
+                 palette_map[cls_keep][img_id] = images[img_id].c_center[cls_keep]
+                 L_idx[cls_keep][img_id] = 0
+             else:
+                 for idx_nsal in range(len(L_idx[cls_keep][img_id])):
+                     palette_map[cls_keep][img_id][idx_nsal, :] = images[img_id].c_center[cls_keep][idx_nsal, :]
+                     L_idx[cls_keep][img_id][idx_nsal] = 0
+
+     L_idx_img = [0 for i in range(num_img)]
+     palette_map_img = [0 for i in range(num_img)]
+     center_img = [0 for i in range(num_img)]
+
+     for img_id in range(num_img):
+         # stack salient (class 1) entries first, then non-salient (class 0)
+         center_img[img_id] = np.vstack([images[img_id].c_center[1], images[img_id].c_center[0]])
+         palette_map_img[img_id] = np.vstack([palette_map[1][img_id], palette_map[0][img_id]])
+
+         if L_idx[1][img_id].size == 0 and L_idx[0][img_id].size == 0:
+             L_idx_img[img_id] = np.array([])
+         elif L_idx[1][img_id].size == 0:
+             L_idx_img[img_id] = L_idx[0][img_id] + p_grp_sal_size
+         elif L_idx[0][img_id].size == 0:
+             L_idx_img[img_id] = L_idx[1][img_id]
+         else:
+             L_idx_img[img_id] = np.vstack([L_idx[1][img_id], L_idx[0][img_id] + p_grp_sal_size])
+
+     M = np.vstack([M[1], M[0]])
+     return center_img, M, palette_map_img, L_idx_img, p_src_sal_size, p_grp_sal_size
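The stacking at the end of `solve_grp_palette_wsal` can be checked on toy arrays (values are illustrative): salient entries come first in the stacked group palette, so non-salient 1-based indices must be shifted by the salient palette size before the two index arrays are concatenated.

```python
import numpy as np

M_sal = np.array([[60., 10., 10.], [40., -10., 0.]])   # salient group palette
M_nonsal = np.array([[80., 0., 0.]])                   # non-salient group palette
L_sal = np.array([[2]])       # 1-based index into M_sal
L_nonsal = np.array([[1]])    # 1-based index into M_nonsal

p_grp_sal_size = np.size(M_sal, axis=0)
M = np.vstack([M_sal, M_nonsal])                       # salient rows first
L_img = np.vstack([L_sal, L_nonsal + p_grp_sal_size])  # shift non-salient indices
```

After the shift, every entry of `L_img` indexes the correct row of the stacked `M`.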
+
+
+
+
+ # recolor multiple images for color consistency
+ def recolor_group_images(inp_images,
+                          num_center_grp, num_center_sal, num_center_nonsal,
+                          if_wb,
+                          if_sal, recolor_nonsal_only, recolor_sal_only,
+                          if_cn, naming_thres,
+                          save_dir='./results/testing'):
+
+     mode = 2
+     bin_size = 16
+
+     num_img = len(inp_images)
+     images = [None for i in range(num_img)]
+     recolored_images = [None for i in range(num_img)]
+     img_p_src = [None for i in range(num_img)]
+     img_p_tgt = [None for i in range(num_img)]
+
+     link_lists = []
+
+     for img_id, image in enumerate(inp_images):
+         img_name = os.path.basename(image.name)
+         print('processing image {}: {}...'.format(img_id, img_name))
+         images[img_id] = BaseImage(image)
+
+     if if_sal:
+         palette_src, palette_grp, palette_tgt, links, src_sal_size, grp_sal_size = solve_grp_palette_wsal(
+             images, mode, bin_size,
+             lightness=70., eta=1e10, gamma=0, iteration=10,
+             if_wb=if_wb, wb_thres=30,
+             if_saliency=True, sal_thres=0.5, valid_class=[0, 1],
+             sal_center=num_center_sal, nonsal_center=num_center_nonsal,
+             recolor_nonsal_only=recolor_nonsal_only, recolor_sal_only=recolor_sal_only,
+             if_cn=if_cn, naming_thres=float(naming_thres))
+     else:
+         src_sal_size = [0 for i in range(num_img)]
+         grp_sal_size = 0
+         palette_src, palette_grp, palette_tgt, links = solve_grp_palette(
+             images, mode, bin_size,
+             if_wb=if_wb, wb_thres=30, num_center=num_center_grp,
+             lightness=70., eta=1e10, gamma=0, iteration=10,
+             if_cn=if_cn, naming_thres=float(naming_thres))
+
+     for img_id, image in enumerate(images):
+         if if_wb:
+             img_wb = image.get_wb_image()
+             img_lab = rgb2lab(img_wb)
+         else:
+             img_lab = image.get_lab_image()
+         img_rgb_out, _ = lab_transfer(img_lab, palette_src[img_id], palette_tgt[img_id], mask=None, mode=2)
+         img_rgb_out = np.round(img_rgb_out * 255).astype(np.uint8)
+
+         out_img_path = os.path.join(save_dir, 'recolor_' + image.filename)
+         img_bgr_out = cv2.cvtColor(img_rgb_out, cv2.COLOR_RGB2BGR)
+         cv2.imwrite(out_img_path, img_bgr_out)
+
+         recolored_images[img_id] = img_rgb_out
+
+         # convert 1-based match indices to 0-based, with None for unmatched
+         link = (links[img_id] - 1).flatten()
+         link = link.tolist()
+         link = [None if item == -1 else item for item in link]
+         link_lists.append(link)
+
+     img_p_grp = [visualize_palette(palette_grp, patch_size=20)]
+
+     for img_id in range(num_img):
+         img_p_src[img_id] = visualize_palette(palette_src[img_id], patch_size=20)
+         img_p_tgt[img_id] = visualize_palette(palette_tgt[img_id], patch_size=20)
+
+     return recolored_images, img_p_src, img_p_tgt, img_p_grp
+
recolor.py ADDED
@@ -0,0 +1,135 @@
+ import numpy as np
+ from skimage.color import rgb2lab, lab2rgb
+
+
+ def rgb_transfer(Iin, C_src, C_tgt, mask=None):
+
+     Iin = np.array(Iin / 255.).astype(np.float32)
+
+     # the palettes arrive in Lab; convert them to RGB to match Iin
+     C_src = lab2rgb(np.expand_dims(C_src, axis=0))
+     C_tgt = lab2rgb(np.expand_dims(C_tgt, axis=0))
+
+     C_src = np.squeeze(C_src, axis=0)
+     C_tgt = np.squeeze(C_tgt, axis=0)
+
+     if mask is None:
+         mask = np.ones_like(Iin[:, :, 0])
+
+     m, n, b = Iin.shape
+
+     Iin = np.reshape(Iin, (m * n, b))
+     mask = mask.flatten()
+
+     Iout = ab_transfer(Iin, C_src, C_tgt, mask=mask)
+     Iout = np.reshape(Iout, (m, n, b))
+
+     return Iout
+
+
+ def lab_transfer(Iin, C_src, C_tgt, mask=None, mode=2):
+     # Iin, C_src and C_tgt are expected to be in Lab space already
+     if mask is None:
+         mask = np.ones_like(Iin[:, :, 0])
+
+     Pout = C_tgt.copy()
+     m, n, b = Iin.shape
+
+     Iin = np.reshape(Iin, (m * n, b))
+     Iout = Iin.copy()
+     mask = mask.flatten()
+
+     if mode == 2:
+         # transfer only the chromatic (a, b) channels
+         Iout[:, 1:] = ab_transfer(Iin[:, 1:], C_src[:, 1:], C_tgt[:, 1:], mask=mask)
+     else:
+         # transfer all three (L, a, b) channels
+         Iout[:, 0:] = ab_transfer(Iin[:, 0:], C_src[:, 0:], C_tgt[:, 0:], mask=mask)
+
+     # convert Lab back to RGB
+     Iout = np.reshape(Iout, (m, n, b))
+     Iout = lab2rgb(Iout)
+
+     return Iout, Pout
+
+
+
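Since `lab_transfer` assumes its image and palettes are already in Lab, the caller is responsible for the conversion. A small round-trip check of the skimage convention it relies on (float RGB in [0, 1], Lab lightness in [0, 100]; the pixel values below are arbitrary):

```python
import numpy as np
from skimage.color import rgb2lab, lab2rgb

# 1 x 2 x 3 float RGB image in [0, 1]
img_rgb = np.array([[[0.2, 0.4, 0.6], [1.0, 0.0, 0.0]]])
img_lab = rgb2lab(img_rgb)     # L in [0, 100], a/b roughly in [-128, 127]
img_back = lab2rgb(img_lab)    # back to float RGB in [0, 1]
```

Passing 8-bit RGB straight into `lab_transfer` would silently break this convention, which is why `recolor_group_images` converts with `rgb2lab` first.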
+ def ab_transfer(I_src, C_src, C_tgt, mask=None):
+     if mask is None:
+         mask = np.ones_like(I_src[:, 0])
+
+     I_tgt = np.zeros_like(I_src)
+     m, b = I_src.shape
+
+     # inverse-distance weight of each pixel to each source palette color
+     k = np.size(C_src, 0)
+     eps = 0.0001
+     W = np.zeros((m, k))
+     for i in range(k):
+         D = np.zeros(m)
+         for j in range(b):
+             D = D + (I_src[:, j] - C_src[i, j])**2
+         W[:, i] = 1. / (D + eps)
+
+     # normalize the weights so they sum to one per pixel
+     sumW = np.sum(W, 1)
+     for j in range(k):
+         W[:, j] = W[:, j] / sumW
+
+     # apply the weighted palette shifts
+     for i in range(k):
+         for j in range(b):
+             I_tgt[:, j] = I_tgt[:, j] + W[:, i] * (I_src[:, j] + C_tgt[i, j] - C_src[i, j])
+
+     # keep masked-out pixels unchanged
+     idx = np.argwhere(mask == 0)
+     I_tgt[idx, :] = I_src[idx, :]
+
+     return I_tgt
+
+
+
+
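Because the per-pixel weights sum to one, the loops in `ab_transfer` are equivalent to `I_src + W @ (C_tgt - C_src)`; the vectorized rewrite below is my own sketch of that identity, on toy values. A pixel that coincides with a source palette color gets weight ≈ 1 for it, so it is shifted (almost) exactly by that color's offset.

```python
import numpy as np

I_src = np.array([[10., 10.], [50., 50.]])   # two "pixels" in (a, b)
C_src = np.array([[10., 10.], [50., 50.]])   # source palette
C_tgt = np.array([[20., 10.], [50., 40.]])   # target palette

eps = 0.0001
# squared distance of every pixel to every source palette color: (m, k)
D = ((I_src[:, None, :] - C_src[None, :, :])**2).sum(axis=2)
W = 1. / (D + eps)
W = W / W.sum(axis=1, keepdims=True)         # rows sum to one
# weights sum to 1, so the loop form collapses to a single shift term
I_tgt = I_src + W @ (C_tgt - C_src)
```

Pixel 0 lands (almost exactly) on the first target color's offset, pixel 1 on the second.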
+ def lab_transfer_cls(Iin, C_src, C_tgt, mask=None, valid_class=None):
+     # convert RGB to Lab and recolor each class with its own palette pair
+     if mask is None:
+         mask = np.ones_like(Iin[:, :, 0])
+
+     Pout = C_tgt.copy()
+     m, n, b = Iin.shape
+
+     Iin = rgb2lab(Iin)
+     Iin = np.reshape(Iin, (m * n, b))
+     Iout = Iin.copy()
+     Iout_cls = np.zeros_like(Iin)
+
+     mask = mask.flatten()
+     mask_bin = np.zeros_like(mask)
+
+     for id_cls, cls in enumerate(valid_class):
+         if C_src[id_cls].size == 0:
+             continue
+         mask_bin[mask == cls] = 1
+         Iout_cls[:, 1:] = ab_transfer(Iin[:, 1:], C_src[id_cls][:, 1:], C_tgt[id_cls][:, 1:], mask=mask_bin)
+         idx = np.argwhere(mask_bin == 1)
+         Iout[idx, 1:] = Iout_cls[idx, 1:].copy()
+
+     # convert Lab back to RGB
+     Iout = np.reshape(np.round(Iout), (m, n, b))
+     Iout = lab2rgb(Iout)
+
+     return Iout, Pout
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ numpy
+ opencv-python
+ torch
+ scikit-learn
+ scikit-image
+ Pillow
+ scipy
+ networkx
+ libsvm
saliency/LDF/dataset.py ADDED
@@ -0,0 +1,137 @@
+ #!/usr/bin/python3
+ # coding=utf-8
+
+ import os
+ import cv2
+ import torch
+ import numpy as np
+ from torch.utils.data import Dataset
+
+ ########################### Data Augmentation ###########################
+ class Normalize(object):
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, image, mask=None, body=None, detail=None):
+         image = (image - self.mean) / self.std
+         if mask is None:
+             return image
+         return image, mask / 255, body / 255, detail / 255
+
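A quick numeric check of `Normalize` (same arithmetic, reproduced standalone): the image is standardized per channel, while mask/body/detail targets are scaled from [0, 255] to [0, 1].

```python
import numpy as np

# per-channel mean/std, same shapes as in Config below
mean = np.array([[[124.55, 118.90, 102.94]]])
std = np.array([[[56.77, 55.97, 57.50]]])

# an image whose pixels equal the dataset mean normalizes to all zeros
image = np.ones((2, 2, 3), dtype=np.float32) * mean
normalized = (image - mean) / std

mask = np.full((2, 2), 255.0)
mask_scaled = mask / 255  # ground-truth maps become [0, 1] floats
```

This is why the saliency predictions downstream are interpreted on a [0, 1] scale.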
+ class RandomCrop(object):
+     def __call__(self, image, mask=None, body=None, detail=None):
+         H, W, _ = image.shape
+         randw = np.random.randint(W / 8)
+         randh = np.random.randint(H / 8)
+         offseth = 0 if randh == 0 else np.random.randint(randh)
+         offsetw = 0 if randw == 0 else np.random.randint(randw)
+         p0, p1, p2, p3 = offseth, H + offseth - randh, offsetw, W + offsetw - randw
+         if mask is None:
+             return image[p0:p1, p2:p3, :]
+         return image[p0:p1, p2:p3, :], mask[p0:p1, p2:p3], body[p0:p1, p2:p3], detail[p0:p1, p2:p3]
+
+ class RandomFlip(object):
+     def __call__(self, image, mask=None, body=None, detail=None):
+         if np.random.randint(2) == 0:
+             if mask is None:
+                 return image[:, ::-1, :].copy()
+             return image[:, ::-1, :].copy(), mask[:, ::-1].copy(), body[:, ::-1].copy(), detail[:, ::-1].copy()
+         else:
+             if mask is None:
+                 return image
+             return image, mask, body, detail
+
+ class Resize(object):
+     def __init__(self, H, W):
+         self.H = H
+         self.W = W
+
+     def __call__(self, image, mask=None, body=None, detail=None):
+         image = cv2.resize(image, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR)
+         if mask is None:
+             return image
+         mask = cv2.resize(mask, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR)
+         body = cv2.resize(body, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR)
+         detail = cv2.resize(detail, dsize=(self.W, self.H), interpolation=cv2.INTER_LINEAR)
+         return image, mask, body, detail
+
+ class ToTensor(object):
+     def __call__(self, image, mask=None, body=None, detail=None):
+         image = torch.from_numpy(image)
+         image = image.permute(2, 0, 1)
+         if mask is None:
+             return image
+         mask = torch.from_numpy(mask)
+         body = torch.from_numpy(body)
+         detail = torch.from_numpy(detail)
+         return image, mask, body, detail
+
+
+ ########################### Config File ###########################
+ class Config(object):
+     def __init__(self, **kwargs):
+         self.kwargs = kwargs
+         self.mean = np.array([[[124.55, 118.90, 102.94]]])
+         self.std = np.array([[[56.77, 55.97, 57.50]]])
+
+     def __getattr__(self, name):
+         if name in self.kwargs:
+             return self.kwargs[name]
+         else:
+             return None
+
+
+ ########################### Dataset Class ###########################
+ class Data(Dataset):
+     def __init__(self, cfg):
+         self.cfg = cfg
+         self.normalize = Normalize(mean=cfg.mean, std=cfg.std)
+         self.randomcrop = RandomCrop()
+         self.randomflip = RandomFlip()
+         self.resize = Resize(352, 352)
+         self.totensor = ToTensor()
+
+         with open(cfg.datapath + '/' + cfg.mode + '.txt', 'r') as lines:
+             self.samples = []
+             for line in lines:
+                 self.samples.append(line.strip())
+
+     def __getitem__(self, idx):
+         name = self.samples[idx]
+         image = cv2.imread(self.cfg.datapath + '/image/' + name + '.jpg')[:, :, ::-1].astype(np.float32)
+
+         if self.cfg.mode == 'train':
+             mask = cv2.imread(self.cfg.datapath + '/mask/' + name + '.png', 0).astype(np.float32)
+             body = cv2.imread(self.cfg.datapath + '/body/' + name + '.png', 0).astype(np.float32)
+             detail = cv2.imread(self.cfg.datapath + '/detail/' + name + '.png', 0).astype(np.float32)
+             image, mask, body, detail = self.normalize(image, mask, body, detail)
+             image, mask, body, detail = self.randomcrop(image, mask, body, detail)
+             image, mask, body, detail = self.randomflip(image, mask, body, detail)
+             return image, mask, body, detail
+         else:
+             shape = image.shape[:2]
+             image = self.normalize(image)
+             image = self.resize(image)
+             image = self.totensor(image)
+             return image, shape, name
+
+     def __len__(self):
+         return len(self.samples)
+
+     def collate(self, batch):
+         size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
+         image, mask, body, detail = [list(item) for item in zip(*batch)]
+         for i in range(len(batch)):
+             image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
+             mask[i] = cv2.resize(mask[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
+             body[i] = cv2.resize(body[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
+             detail[i] = cv2.resize(detail[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
+         image = torch.from_numpy(np.stack(image, axis=0)).permute(0, 3, 1, 2)
+         mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1)
+         body = torch.from_numpy(np.stack(body, axis=0)).unsqueeze(1)
+         detail = torch.from_numpy(np.stack(detail, axis=0)).unsqueeze(1)
+         return image, mask, body, detail
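The `RandomCrop` bound arithmetic is worth a standalone check (pure numpy, no cv2): because each offset is drawn in `[0, rand)` the slice stays inside the image, and the output is always `(H - randh) x (W - randw)`.

```python
import numpy as np

np.random.seed(0)  # arbitrary seed, just for reproducibility
H, W = 352, 352
randw = np.random.randint(W // 8)
randh = np.random.randint(H // 8)
offseth = 0 if randh == 0 else np.random.randint(randh)
offsetw = 0 if randw == 0 else np.random.randint(randw)

image = np.zeros((H, W, 3), dtype=np.float32)
# same slice bounds as RandomCrop: offset < rand guarantees the upper
# bound H + offseth - randh never exceeds H
crop = image[offseth:H + offseth - randh, offsetw:W + offsetw - randw, :]
```

So at most 1/8 of each dimension is ever cropped away, which keeps most of the saliency mask intact.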
saliency/LDF/infer.py ADDED
@@ -0,0 +1,40 @@
+ import sys
+ sys.dont_write_bytecode = True
+
+ import torch
+
+ import saliency.LDF.dataset as dataset
+ from saliency.LDF.net import LDF
+
+
+ ## Implementation of saliency detection with the LDF model.
+ class Saliency_LDF:
+     def __init__(self, pretrained_model='./saliency/LDF/model-40'):
+         self.cfg = dataset.Config(snapshot=pretrained_model, mode='test')
+         self.normalize = dataset.Normalize(mean=self.cfg.mean, std=self.cfg.std)
+         self.resize = dataset.Resize(352, 352)
+         self.totensor = dataset.ToTensor()
+         ## network
+         self.net = LDF(self.cfg)
+         self.net.train(False)
+         self.net.cuda()
+
+     def inference(self, img_rgb):
+         shape = img_rgb.shape[:2]
+         image = self.normalize(img_rgb)
+         image = self.resize(image)
+         image = self.totensor(image)
+
+         with torch.no_grad():
+             image = image.unsqueeze(0)
+             image = image.cuda().float()
+             outb1, outd1, out1, outb2, outd2, out2 = self.net(image, shape)
+             out = out2
+             pred = torch.sigmoid(out[0, 0]).cpu().numpy()  # values in [0, 1]
+
+         return pred
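`inference` returns sigmoid-activated logits, so the saliency map lives in [0, 1]. The same mapping with plain numpy (the logit values below are arbitrary), plus the kind of hard threshold the recoloring pipeline applies via `sal_thres`:

```python
import numpy as np

logits = np.array([-4.0, 0.0, 4.0])          # raw network outputs
pred = 1.0 / (1.0 + np.exp(-logits))          # numpy equivalent of torch.sigmoid
binary = (pred > 0.5).astype(np.uint8)        # hard salient / non-salient mask
```

A logit of 0 maps exactly to 0.5, so the threshold splits the map at "no evidence either way".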
saliency/LDF/model-40 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:972a77654afdad80f99ce29968e8c02d5f14cad0aa52e21b46b1c47e1468d68e
+ size 100920708
saliency/LDF/net.py ADDED
@@ -0,0 +1,216 @@
+ #!/usr/bin/python3
+ # coding=utf-8
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def weight_init(module):
+     for n, m in module.named_children():
+         print('initialize: ' + n)
+         if isinstance(m, nn.Conv2d):
+             nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
+             nn.init.ones_(m.weight)
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.Linear):
+             nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
+             if m.bias is not None:
+                 nn.init.zeros_(m.bias)
+         elif isinstance(m, nn.Sequential):
+             weight_init(m)
+         elif isinstance(m, nn.ReLU):
+             pass
+         else:
+             m.initialize()
+
+ class Bottleneck(nn.Module):
+     def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3 * dilation - 1) // 2, bias=False, dilation=dilation)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * 4)
+         self.downsample = downsample
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)), inplace=True)
+         out = F.relu(self.bn2(self.conv2(out)), inplace=True)
+         out = self.bn3(self.conv3(out))
+         if self.downsample is not None:
+             x = self.downsample(x)
+         return F.relu(out + x, inplace=True)
+
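The `padding=(3 * dilation - 1) // 2` in `conv2` keeps the spatial size for a 3x3 kernel at stride 1 for the dilations this code actually uses (1 here; 2 is also size-preserving). A quick check with the standard conv output-size formula, using an arbitrary input size:

```python
# out = floor((n + 2p - d*(k-1) - 1) / s) + 1  (standard conv arithmetic)
def conv_out(n, k, s, p, d):
    return (n + 2 * p - d * (k - 1) - 1) // s + 1

for dilation in (1, 2):
    p = (3 * dilation - 1) // 2
    assert conv_out(88, 3, 1, p, dilation) == 88  # spatial size preserved
```

That is why the residual addition `out + x` in `forward` needs no extra resizing when `stride == 1`.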
+ class ResNet(nn.Module):
+     def __init__(self):
+         super(ResNet, self).__init__()
+         self.inplanes = 64
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
+         self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
+         self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
+         self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
+
+     def make_layer(self, planes, blocks, stride, dilation):
+         downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * 4, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * 4))
+         layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)]
+         self.inplanes = planes * 4
+         for _ in range(1, blocks):
+             layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
+         out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
+         out2 = self.layer1(out1)
+         out3 = self.layer2(out2)
+         out4 = self.layer3(out3)
+         out5 = self.layer4(out4)
+         return out1, out2, out3, out4, out5
+
+     def initialize(self):
+         self.load_state_dict(torch.load('../res/resnet50-19c8e357.pth'), strict=False)
+
+ class Decoder(nn.Module):
+     def __init__(self):
+         super(Decoder, self).__init__()
+         self.conv0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn0 = nn.BatchNorm2d(64)
+         self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn2 = nn.BatchNorm2d(64)
+         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn3 = nn.BatchNorm2d(64)
+
+     def forward(self, input1, input2=[0, 0, 0, 0]):
+         out0 = F.relu(self.bn0(self.conv0(input1[0] + input2[0])), inplace=True)
+         out0 = F.interpolate(out0, size=input1[1].size()[2:], mode='bilinear')
+         out1 = F.relu(self.bn1(self.conv1(input1[1] + input2[1] + out0)), inplace=True)
+         out1 = F.interpolate(out1, size=input1[2].size()[2:], mode='bilinear')
+         out2 = F.relu(self.bn2(self.conv2(input1[2] + input2[2] + out1)), inplace=True)
+         out2 = F.interpolate(out2, size=input1[3].size()[2:], mode='bilinear')
+         out3 = F.relu(self.bn3(self.conv3(input1[3] + input2[3] + out2)), inplace=True)
+         return out3
+
+     def initialize(self):
+         weight_init(self)
+
+
+ class Encoder(nn.Module):
+     def __init__(self):
+         super(Encoder, self).__init__()
+         self.conv1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn2 = nn.BatchNorm2d(64)
+         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn3 = nn.BatchNorm2d(64)
+         self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn4 = nn.BatchNorm2d(64)
+
+         self.conv1b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn1b = nn.BatchNorm2d(64)
+         self.conv2b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn2b = nn.BatchNorm2d(64)
+         self.conv3b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn3b = nn.BatchNorm2d(64)
+         self.conv4b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn4b = nn.BatchNorm2d(64)
+
+         self.conv1d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn1d = nn.BatchNorm2d(64)
+         self.conv2d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn2d = nn.BatchNorm2d(64)
+         self.conv3d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn3d = nn.BatchNorm2d(64)
+         self.conv4d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+         self.bn4d = nn.BatchNorm2d(64)
+
+     def forward(self, out1):
+         out1 = F.relu(self.bn1(self.conv1(out1)), inplace=True)
+         out2 = F.max_pool2d(out1, kernel_size=2, stride=2)
+         out2 = F.relu(self.bn2(self.conv2(out2)), inplace=True)
+         out3 = F.max_pool2d(out2, kernel_size=2, stride=2)
+         out3 = F.relu(self.bn3(self.conv3(out3)), inplace=True)
+         out4 = F.max_pool2d(out3, kernel_size=2, stride=2)
+         out4 = F.relu(self.bn4(self.conv4(out4)), inplace=True)
+
+         out1b = F.relu(self.bn1b(self.conv1b(out1)), inplace=True)
+         out2b = F.relu(self.bn2b(self.conv2b(out2)), inplace=True)
+         out3b = F.relu(self.bn3b(self.conv3b(out3)), inplace=True)
+         out4b = F.relu(self.bn4b(self.conv4b(out4)), inplace=True)
+
+         out1d = F.relu(self.bn1d(self.conv1d(out1)), inplace=True)
+         out2d = F.relu(self.bn2d(self.conv2d(out2)), inplace=True)
+         out3d = F.relu(self.bn3d(self.conv3d(out3)), inplace=True)
+         out4d = F.relu(self.bn4d(self.conv4d(out4)), inplace=True)
+         return (out4b, out3b, out2b, out1b), (out4d, out3d, out2d, out1d)
+
+     def initialize(self):
+         weight_init(self)
+
+
+ class LDF(nn.Module):
+     def __init__(self, cfg):
+         super(LDF, self).__init__()
+         self.cfg = cfg
+         self.bkbone = ResNet()
+         self.conv5b = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+         self.conv4b = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+         self.conv3b = nn.Sequential(nn.Conv2d(512, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+         self.conv2b = nn.Sequential(nn.Conv2d(256, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+
+         self.conv5d = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+         self.conv4d = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+         self.conv3d = nn.Sequential(nn.Conv2d(512, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+         self.conv2d = nn.Sequential(nn.Conv2d(256, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+
+         self.encoder = Encoder()
+         self.decoderb = Decoder()
+         self.decoderd = Decoder()
+         self.linearb = nn.Conv2d(64, 1, kernel_size=3, padding=1)
+         self.lineard = nn.Conv2d(64, 1, kernel_size=3, padding=1)
+         self.linear = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 1, kernel_size=3, padding=1))
+         self.initialize()
+
+     def forward(self, x, shape=None):
+         out1, out2, out3, out4, out5 = self.bkbone(x)
+         out2b, out3b, out4b, out5b = self.conv2b(out2), self.conv3b(out3), self.conv4b(out4), self.conv5b(out5)
+         out2d, out3d, out4d, out5d = self.conv2d(out2), self.conv3d(out3), self.conv4d(out4), self.conv5d(out5)
+
+         outb1 = self.decoderb([out5b, out4b, out3b, out2b])
+         outd1 = self.decoderd([out5d, out4d, out3d, out2d])
+         out1 = torch.cat([outb1, outd1], dim=1)
+         outb2, outd2 = self.encoder(out1)
+         outb2 = self.decoderb([out5b, out4b, out3b, out2b], outb2)
+         outd2 = self.decoderd([out5d, out4d, out3d, out2d], outd2)
+         out2 = torch.cat([outb2, outd2], dim=1)
+
+         if shape is None:
+             shape = x.size()[2:]
+         out1 = F.interpolate(self.linear(out1), size=shape, mode='bilinear')
+         outb1 = F.interpolate(self.linearb(outb1), size=shape, mode='bilinear')
+         outd1 = F.interpolate(self.lineard(outd1), size=shape, mode='bilinear')
+
+         out2 = F.interpolate(self.linear(out2), size=shape, mode='bilinear')
+         outb2 = F.interpolate(self.linearb(outb2), size=shape, mode='bilinear')
+         outd2 = F.interpolate(self.lineard(outd2), size=shape, mode='bilinear')
+         return outb1, outd1, out1, outb2, outd2, out2
+
+     def initialize(self):
+         if self.cfg.snapshot:
+             self.load_state_dict(torch.load(self.cfg.snapshot))
+         else:
+             weight_init(self)
saliency/fast_saliency.py ADDED
@@ -0,0 +1,590 @@
1
+ import math
+ import cv2
+ import numpy as np
+ import scipy.spatial.distance
+ import scipy.signal
+ import skimage
+ import skimage.io
+ import time
+ from skimage.segmentation import slic
+ from skimage.util import img_as_float
+ import networkx as nx
+ #import matplotlib.pyplot as plt
+
+
+ def S(x1,x2,geodesic,sigma_clr=10):
+     return math.exp(-pow(geodesic[x1,x2],2)/(2*sigma_clr*sigma_clr))
+
+ def compute_saliency_cost(smoothness,w_bg,wCtr):
+     n = len(w_bg)
+     A = np.zeros((n,n))
+     b = np.zeros((n))
+
+     for x in range(0,n):
+         A[x,x] = 2 * w_bg[x] + 2 * (wCtr[x])
+         b[x] = 2 * wCtr[x]
+         for y in range(0,n):
+             A[x,x] += 2 * smoothness[x,y]
+             A[x,y] -= 2 * smoothness[x,y]
+
+     x = np.linalg.solve(A, b)
+
+     return x
+
+ def path_length(path,G):
+     dist = 0.0
+     for i in range(1,len(path)):
+         dist += G[path[i - 1]][path[i]]['weight']
+     return dist
+
+ def make_graph(grid):
+     # get unique labels
+     vertices = np.unique(grid)
+
+     # map unique labels to [1,...,num_labels]
+     reverse_dict = dict(zip(vertices,np.arange(len(vertices))))
+     grid = np.array([reverse_dict[x] for x in grid.flat]).reshape(grid.shape)
+
+     # create edges
+     down = np.c_[grid[:-1, :].ravel(), grid[1:, :].ravel()]
+     right = np.c_[grid[:, :-1].ravel(), grid[:, 1:].ravel()]
+     all_edges = np.vstack([right, down])
+     all_edges = all_edges[all_edges[:, 0] != all_edges[:, 1], :]
+     all_edges = np.sort(all_edges,axis=1)
+     num_vertices = len(vertices)
+     edge_hash = all_edges[:,0] + num_vertices * all_edges[:, 1]
+     # find unique connections
+     edges = np.unique(edge_hash)
+     # undo hashing
+     edges = [[vertices[x%num_vertices],
+               vertices[x//num_vertices]] for x in edges]
+
+     return vertices, edges
+
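The edge-hashing step in `make_graph` above deduplicates undirected edges by sorting each pair and hashing it as `a + num_vertices * b`, then recovering the pair with a modulus and an integer division. A minimal pure-Python sketch of just that trick (the helper name `dedup_edges` is hypothetical, not part of the file):

```python
def dedup_edges(pairs, num_vertices):
    # Sort each pair so (a, b) and (b, a) hash identically, skip self-loops,
    # and collapse duplicates via a set of scalar hashes a + n*b.
    hashes = {min(a, b) + num_vertices * max(a, b) for a, b in pairs if a != b}
    # Undo the hashing: a = h % n, b = h // n.
    return sorted((h % num_vertices, h // num_vertices) for h in hashes)

edges = dedup_edges([(1, 0), (0, 1), (2, 2), (1, 2), (2, 1)], num_vertices=3)
# edges == [(0, 1), (1, 2)]: duplicates and the self-loop are gone
```

The file does the same thing vectorized with `np.sort` and `np.unique` over the hash array.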
+
+ # Saliency map calculation based on:
+ # Wangjiang Zhu, Shuang Liang, Yichen Wei and Jian Sun,
+ # Saliency Optimization from Robust Background Detection, (CVPR), 2014
+
+ # based on the assumption that the salient region has smaller 'boundary connectivity':
+ # object (salient) regions are much less connected to the image boundaries than background ones
+ # operates on superpixels
+
+ def get_saliency_rbd(img):
+
+     img_lab = img_as_float(skimage.color.rgb2lab(img))
+
+     img_rgb = img_as_float(img)
+
+     img_gray = img_as_float(skimage.color.rgb2gray(img))
+
+     segments_slic = slic(img_rgb, n_segments=250, compactness=10, sigma=1, enforce_connectivity=False)
+
+     # num_segments = len(np.unique(segments_slic))
+
+     nrows, ncols = segments_slic.shape
+     max_dist = math.sqrt(nrows*nrows + ncols*ncols)
+
+     grid = segments_slic
+
+     (vertices,edges) = make_graph(grid)
+
+     gridx, gridy = np.mgrid[:grid.shape[0], :grid.shape[1]]
+
+     centers = dict()
+     colors = dict()
+     # distances = dict()
+     boundary = dict()
+
+     for v in vertices:
+         centers[v] = [gridy[grid == v].mean(), gridx[grid == v].mean()]
+         colors[v] = np.mean(img_lab[grid==v],axis=0)
+
+         x_pix = gridx[grid == v]
+         y_pix = gridy[grid == v]
+
+         if np.any(x_pix == 0) or np.any(y_pix == 0) or np.any(x_pix == nrows - 1) or np.any(y_pix == ncols - 1):
+             boundary[v] = 1
+         else:
+             boundary[v] = 0
+
+     G = nx.Graph()
+
+     # build the graph
+     for edge in edges:
+         pt1 = edge[0]
+         pt2 = edge[1]
+         color_distance = scipy.spatial.distance.euclidean(colors[pt1],colors[pt2])
+         G.add_edge(pt1, pt2, weight=color_distance)
+
+     # add an edge between every pair of boundary superpixels
+     for v1 in vertices:
+         if boundary[v1] == 1:
+             for v2 in vertices:
+                 if boundary[v2] == 1:
+                     color_distance = scipy.spatial.distance.euclidean(colors[v1],colors[v2])
+                     G.add_edge(v1,v2,weight=color_distance)
+
+     geodesic = np.zeros((len(vertices),len(vertices)),dtype=float)
+     spatial = np.zeros((len(vertices),len(vertices)),dtype=float)
+     smoothness = np.zeros((len(vertices),len(vertices)),dtype=float)
+     adjacency = np.zeros((len(vertices),len(vertices)),dtype=float)
+
+     sigma_clr = 10.0
+     sigma_bndcon = 1.0
+     sigma_spa = 0.25
+     mu = 0.1
+
+     all_shortest_paths_color = nx.shortest_path(G,source=None,target=None,weight='weight')
+
+     for v1 in vertices:
+         for v2 in vertices:
+             if v1 == v2:
+                 geodesic[v1,v2] = 0
+                 spatial[v1,v2] = 0
+                 smoothness[v1,v2] = 0
+             else:
+                 geodesic[v1,v2] = path_length(all_shortest_paths_color[v1][v2],G)
+                 spatial[v1,v2] = scipy.spatial.distance.euclidean(centers[v1],centers[v2]) / max_dist
+                 smoothness[v1,v2] = math.exp(-(geodesic[v1,v2] * geodesic[v1,v2])/(2.0*sigma_clr*sigma_clr)) + mu
+
+     for edge in edges:
+         pt1 = edge[0]
+         pt2 = edge[1]
+         adjacency[pt1,pt2] = 1
+         adjacency[pt2,pt1] = 1
+
+     for v1 in vertices:
+         for v2 in vertices:
+             smoothness[v1,v2] = adjacency[v1,v2] * smoothness[v1,v2]
+
+     area = dict()
+     len_bnd = dict()
+     bnd_con = dict()
+     w_bg = dict()
+     ctr = dict()
+     wCtr = dict()
+
+     for v1 in vertices:
+         area[v1] = 0
+         len_bnd[v1] = 0
+         ctr[v1] = 0
+         for v2 in vertices:
+             d_app = geodesic[v1,v2]
+             d_spa = spatial[v1,v2]
+             w_spa = math.exp(-((d_spa)*(d_spa))/(2.0*sigma_spa*sigma_spa))
+             area_i = S(v1,v2,geodesic)
+             area[v1] += area_i
+             len_bnd[v1] += area_i * boundary[v2]
+             ctr[v1] += d_app * w_spa
+         bnd_con[v1] = len_bnd[v1] / math.sqrt(area[v1])
+         w_bg[v1] = 1.0 - math.exp(-(bnd_con[v1]*bnd_con[v1])/(2*sigma_bndcon*sigma_bndcon))
+
+     for v1 in vertices:
+         wCtr[v1] = 0
+         for v2 in vertices:
+             d_app = geodesic[v1,v2]
+             d_spa = spatial[v1,v2]
+             w_spa = math.exp(-(d_spa*d_spa)/(2.0*sigma_spa*sigma_spa))
+             wCtr[v1] += d_app * w_spa * w_bg[v2]
+
+     # normalise value for wCtr
+
+     min_value = min(wCtr.values())
+     max_value = max(wCtr.values())
+
+     minVal = [key for key, value in wCtr.items() if value == min_value]
+     maxVal = [key for key, value in wCtr.items() if value == max_value]
+
+     for v in vertices:
+         wCtr[v] = (wCtr[v] - min_value)/(max_value - min_value)
+
+     img_disp1 = img_gray.copy()
+     # img_disp2 = img_gray.copy()
+
+     x = compute_saliency_cost(smoothness,w_bg,wCtr)
+
+     for v in vertices:
+         img_disp1[grid == v] = x[v]
+
+     img_disp2 = img_disp1.copy()
+     sal = np.zeros((img_disp1.shape[0],img_disp1.shape[1],3))
+
+     sal = img_disp2
+     sal_max = np.max(sal)
+     sal_min = np.min(sal)
+     sal = ((sal - sal_min) / (sal_max - sal_min))
+
+     return sal
+
+
+ # Saliency map calculation based on:
+ # R. Achanta, S. Hemami, F. Estrada and S. Süsstrunk,
+ # Frequency-tuned Salient Region Detection, (CVPR 2009), pp. 1597-1604, 2009
+ # a frequency-tuned approach to estimate center-surround contrast using color and luminance features
+ # combines several filters to remove unwanted high-frequency components
+
+ def get_saliency_ft(img):
+
+     img_rgb = img_as_float(img)
+
+     img_lab = skimage.color.rgb2lab(img_rgb)
+
+     # mean Lab vector of the image (compared against the blurred Lab image below)
+     mean_val = np.mean(img_lab,axis=(0,1))
+
+     kernel_h = (1.0/16.0) * np.array([[1,4,6,4,1]])
+     kernel_w = kernel_h.transpose()
+
+     blurred_l = scipy.signal.convolve2d(img_lab[:,:,0],kernel_h,mode='same')
+     blurred_a = scipy.signal.convolve2d(img_lab[:,:,1],kernel_h,mode='same')
+     blurred_b = scipy.signal.convolve2d(img_lab[:,:,2],kernel_h,mode='same')
+
+     blurred_l = scipy.signal.convolve2d(blurred_l,kernel_w,mode='same')
+     blurred_a = scipy.signal.convolve2d(blurred_a,kernel_w,mode='same')
+     blurred_b = scipy.signal.convolve2d(blurred_b,kernel_w,mode='same')
+
+     im_blurred = np.dstack([blurred_l,blurred_a,blurred_b])
+
+     sal = np.linalg.norm(mean_val - im_blurred,axis=2)
+     sal_max = np.max(sal)
+     sal_min = np.min(sal)
+     sal = ((sal - sal_min) / (sal_max - sal_min))
+     return sal
+
+
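`get_saliency_ft` above smooths each Lab channel with the separable binomial kernel [1, 4, 6, 4, 1]/16, applied along rows and then along columns. A hypothetical 1D sketch of one pass of that smoothing, using edge padding rather than the zero-padded `mode='same'` behaviour of `scipy.signal.convolve2d`:

```python
def binomial_blur_1d(xs):
    # 5-tap binomial kernel, normalised by 16 (sum of taps).
    k = [1, 4, 6, 4, 1]
    # Replicate the edge samples so the output has the same length as the input.
    pad = [xs[0]] * 2 + list(xs) + [xs[-1]] * 2
    return [sum(k[j] * pad[i + j] for j in range(5)) / 16.0
            for i in range(len(xs))]

smoothed = binomial_blur_1d([0, 0, 16, 0, 0])
# an impulse spreads into the kernel shape: [1.0, 4.0, 6.0, 4.0, 1.0]
```

Applying this kernel twice (rows then columns) is what makes the 2D blur separable and cheap; the file relies on the same factorisation via two `convolve2d` calls.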
+ def raster_scan(img,L,U,D):
+     n_rows = len(img)
+     n_cols = len(img[0])
+
+     for x in range(1, n_rows - 1):
+         for y in range(1, n_cols - 1):
+             ix = img[x][y]
+             d = D[x][y]
+
+             u1 = U[x-1][y]
+             l1 = L[x-1][y]
+
+             u2 = U[x][y-1]
+             l2 = L[x][y-1]
+
+             b1 = max(u1,ix) - min(l1,ix)
+             b2 = max(u2,ix) - min(l2,ix)
+
+             if d <= b1 and d <= b2:
+                 continue
+             elif b1 < d and b1 <= b2:
+                 D[x][y] = b1
+                 U[x][y] = max(u1, ix)
+                 L[x][y] = min(l1, ix)
+             else:
+                 D[x][y] = b2
+                 U[x][y] = max(u2, ix)
+                 L[x][y] = min(l2, ix)
+     return True
+
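The raster scan above maintains, per pixel, the running maximum `U`, minimum `L`, and barrier cost `D = U - L` of the best path found so far, and adopts a neighbour's path whenever doing so lowers `D`. A hypothetical 1D sketch of the same update rule, seeded at the left border the way `mbd()` below seeds the image borders:

```python
def barrier_scan_1d(vals):
    # U/L start as the pixel values themselves; D is infinite except at the seed.
    n = len(vals)
    U = list(vals)
    L = list(vals)
    D = [float('inf')] * n
    D[0] = 0.0  # border pixel acts as the background seed
    for i in range(1, n):
        # Extending the neighbour's path: barrier cost = running max - running min.
        u, l = max(U[i - 1], vals[i]), min(L[i - 1], vals[i])
        if u - l < D[i]:
            U[i], L[i], D[i] = u, l, u - l
    return D

D = barrier_scan_1d([5, 3, 7, 5])  # [0.0, 2, 4, 4]: the 3-to-7 range dominates
```

Pixels whose every path to the border must cross a large value range get a high barrier cost, which is exactly why salient blobs light up in the MBD map.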
+ def raster_scan_inv(img,L,U,D):
+     n_rows = len(img)
+     n_cols = len(img[0])
+
+     for x in range(n_rows - 2, 1, -1):
+         for y in range(n_cols - 2, 1, -1):
+
+             ix = img[x][y]
+             d = D[x][y]
+
+             u1 = U[x+1][y]
+             l1 = L[x+1][y]
+
+             u2 = U[x][y+1]
+             l2 = L[x][y+1]
+
+             b1 = max(u1,ix) - min(l1,ix)
+             b2 = max(u2,ix) - min(l2,ix)
+
+             if d <= b1 and d <= b2:
+                 continue
+             elif b1 < d and b1 <= b2:
+                 D[x][y] = b1
+                 U[x][y] = max(u1,ix)
+                 L[x][y] = min(l1,ix)
+             else:
+                 D[x][y] = b2
+                 U[x][y] = max(u2,ix)
+                 L[x][y] = min(l2,ix)
+     return True
+
+
+ def mbd(img, num_iters):
+     if len(img.shape) != 2:
+         print('did not get 2d np array to fast mbd')
+         return None
+     if (img.shape[0] <= 3 or img.shape[1] <= 3):
+         print('image is too small')
+         return None
+
+     L = np.copy(img)
+     U = np.copy(img)
+     D = float('Inf') * np.ones(img.shape)
+     D[0,:] = 0
+     D[-1,:] = 0
+     D[:,0] = 0
+     D[:,-1] = 0
+
+     # unfortunately, iterating over numpy arrays is very slow
+     img_list = img.tolist()
+     L_list = L.tolist()
+     U_list = U.tolist()
+     D_list = D.tolist()
+
+     # start_time = time.time()
+     for x in range(0, num_iters):
+         if x%2 == 1:
+             raster_scan(img_list, L_list, U_list, D_list)
+         else:
+             raster_scan_inv(img_list, L_list, U_list, D_list)
+
+     # end_time = time.time()
+     # print('mbd function: ', end_time-start_time)
+     return np.array(D_list)
+
+
+ # Saliency map calculation based on:
+ # Minimum Barrier Salient Object Detection at 80 FPS
+ # based on the image boundary connectivity cue, which assumes that
+ # background regions are usually connected to the image borders
+ # cons:
+ # does not consider spatial information, so distant objects may be detected as one object
+ # fails when the contrast between foreground and background is small
+
+ def get_saliency_mbd(img):
+
+     img_mean = np.mean(img, axis=(2))
+     # start_time = time.time()
+     sal = mbd(img_mean,3)
+     # end_time = time.time()
+     # print('mbd function: ', end_time-start_time)
+
+     # get the background map
+     # paper uses 30px for an image of size 300px, so we use 10%
+     (n_rows, n_cols, n_channels) = img.shape
+     img_size = math.sqrt(n_rows * n_cols)
+     border_thickness = int(math.floor(0.1 * img_size))
+
+     img_lab = img_as_float(skimage.color.rgb2lab(img))
+
+     px_left = img_lab[0:border_thickness,:,:]
+     px_right = img_lab[n_rows - border_thickness-1:-1,:,:]
+
+     px_top = img_lab[:,0:border_thickness,:]
+     px_bottom = img_lab[:,n_cols - border_thickness-1:-1,:]
+
+     px_mean_left = np.mean(px_left,axis=(0,1))
+     px_mean_right = np.mean(px_right,axis=(0,1))
+     px_mean_top = np.mean(px_top,axis=(0,1))
+     px_mean_bottom = np.mean(px_bottom,axis=(0,1))
+
+     px_left = px_left.reshape((n_cols*border_thickness,3))
+     px_right = px_right.reshape((n_cols*border_thickness,3))
+
+     px_top = px_top.reshape((n_rows*border_thickness,3))
+     px_bottom = px_bottom.reshape((n_rows*border_thickness,3))
+
+     cov_left = np.cov(px_left.T)
+     cov_right = np.cov(px_right.T)
+
+     cov_top = np.cov(px_top.T)
+     cov_bottom = np.cov(px_bottom.T)
+
+     cov_left = np.linalg.inv(cov_left)
+     cov_right = np.linalg.inv(cov_right)
+
+     cov_top = np.linalg.inv(cov_top)
+     cov_bottom = np.linalg.inv(cov_bottom)
+
+     u_left = np.zeros(sal.shape)
+     u_right = np.zeros(sal.shape)
+     u_top = np.zeros(sal.shape)
+     u_bottom = np.zeros(sal.shape)
+
+     u_final = np.zeros(sal.shape)
+     img_lab_unrolled = img_lab.reshape(img_lab.shape[0]*img_lab.shape[1],3)
+
+     px_mean_left_2 = np.zeros((1,3))
+     px_mean_left_2[0,:] = px_mean_left
+
+     u_left = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_left_2,'mahalanobis', VI=cov_left)
+     u_left = u_left.reshape((img_lab.shape[0],img_lab.shape[1]))
+
+     px_mean_right_2 = np.zeros((1,3))
+     px_mean_right_2[0,:] = px_mean_right
+
+     u_right = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_right_2,'mahalanobis', VI=cov_right)
+     u_right = u_right.reshape((img_lab.shape[0],img_lab.shape[1]))
+
+     px_mean_top_2 = np.zeros((1,3))
+     px_mean_top_2[0,:] = px_mean_top
+
+     u_top = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_top_2,'mahalanobis', VI=cov_top)
+     u_top = u_top.reshape((img_lab.shape[0],img_lab.shape[1]))
+
+     px_mean_bottom_2 = np.zeros((1,3))
+     px_mean_bottom_2[0,:] = px_mean_bottom
+
+     u_bottom = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_bottom_2,'mahalanobis', VI=cov_bottom)
+     u_bottom = u_bottom.reshape((img_lab.shape[0],img_lab.shape[1]))
+
+     max_u_left = np.max(u_left)
+     max_u_right = np.max(u_right)
+     max_u_top = np.max(u_top)
+     max_u_bottom = np.max(u_bottom)
+
+     u_left = u_left / max_u_left
+     u_right = u_right / max_u_right
+     u_top = u_top / max_u_top
+     u_bottom = u_bottom / max_u_bottom
+
+     u_max = np.maximum(np.maximum(np.maximum(u_left,u_right),u_top),u_bottom)
+
+     u_final = (u_left + u_right + u_top + u_bottom) - u_max
+
+     u_max_final = np.max(u_final)
+     sal_max = np.max(sal)
+     sal = sal / sal_max + u_final / u_max_final
+
+     # postprocessing
+     # apply centredness map
+     sal = sal / np.max(sal)
+
+     s = np.mean(sal)
+     alpha = 50.0
+     delta = alpha * math.sqrt(s)
+
+     xv,yv = np.meshgrid(np.arange(sal.shape[1]),np.arange(sal.shape[0]))
+     (w,h) = sal.shape
+     w2 = w/2.0
+     h2 = h/2.0
+
+     C = 1 - np.sqrt(np.power(xv - h2,2) + np.power(yv - w2,2)) / math.sqrt(np.power(w2,2) + np.power(h2,2))
+     sal = sal * C
+
+     # increase bg/fg contrast
+     def f(x):
+         b = 10.0
+         return 1.0 / (1.0 + math.exp(-b*(x - 0.5)))
+
+     fv = np.vectorize(f)
+     sal = sal / np.max(sal)
+     sal = fv(sal)
+     return sal
+
+
+ def binarise_saliency_map(saliency_map, method='adaptive',threshold=0.5):
+
+     # check if input is a numpy array
+     if type(saliency_map).__module__ != np.__name__:
+         print('Expected numpy array')
+         return None
+
+     # check if input is 2D
+     if len(saliency_map.shape) != 2:
+         print('Saliency map must be 2D')
+         return None
+
+     if method == 'fixed':
+         return (saliency_map > threshold)
+
+     elif method == 'adaptive':
+         adaptive_threshold = 2.0 * saliency_map.mean()
+         return (saliency_map > adaptive_threshold)
+
+     elif method == 'clustering':
+         print('Not yet implemented')
+         return None
+
+     else:
+         print("Method must be one of fixed, adaptive or clustering")
+         return None
+
+
+ if __name__ == '__main__':
+     # path to the image
+     filename = '../images/flower/19569518092_2db12519fd_c.jpg'
+     # filename = './images/landmark_04/12004354405_dc546d53ce_c.jpg'
+
+     img = skimage.io.imread(filename)
+
+     if len(img.shape) != 3:  # got a grayscale image
+         img = skimage.color.gray2rgb(img)
+
+     # get the saliency maps using the 3 implemented methods
+     start_time = time.time()
+     rbd = get_saliency_rbd(img)
+     end_time = time.time()
+     print('rbd:', end_time-start_time)
+
+     start_time = time.time()
+     ft = get_saliency_ft(img)
+     end_time = time.time()
+     print('ft:', end_time-start_time)
+
+     start_time = time.time()
+     mbd_img = get_saliency_mbd(img)
+     end_time = time.time()
+     print('mbd:', end_time-start_time)
+
+     # often, it is desirable to have a binary saliency map
+     binary_sal = binarise_saliency_map(mbd_img, method='adaptive')
+
+     img = cv2.imread(filename)
+
+     # print(mbd.max())
+
+     cv2.imshow('img', img)
+     cv2.imshow('rbd', rbd)
+     cv2.imshow('ft', ft)
+     cv2.imshow('mbd', mbd_img)
+
+     # OpenCV cannot display a boolean array, so convert to uint8 and scale
+     cv2.imshow('binary', 255 * binary_sal.astype('uint8'))
+
+     cv2.waitKey(0)
solve_group_palette.py ADDED
@@ -0,0 +1,240 @@
+
+ import numpy as np
+ from sklearn.cluster import KMeans
+ from sklearn.metrics import pairwise_distances
+
+ from utils import stack_list
+
+
+ def solve_group_palette(hist_sample_all, hist_counts_all, center, density, reference, m, lightness=70, eta=0, gamma=1e10, iteration=10):
+     num_img = len(center)
+     Lout = [i for i in range(num_img)]
+
+     if num_img > 1:
+         gamma = gamma / ((num_img-1)/2)
+     lbd = gamma / 50
+
+     old_val = 0
+     init_min = np.Inf
+
+     if m == 1:
+         M = np.mean(hist_sample_all, axis=0)
+     else:
+         cinits = np.zeros((m, np.size(hist_sample_all, 1)))
+         cw = hist_counts_all
+         for i in range(m):
+             id = np.argmax(cw)
+             cinits[i,:] = hist_sample_all[id,:] * np.ones((1, 1))
+             d2 = cinits[i,:] * np.ones((np.size(hist_sample_all, 0), 1)) - hist_sample_all
+             d2 = np.sum(np.square(d2), axis=1)
+             d2 = d2/np.max(d2)
+             cw = cw * (d2**2)
+
+         kmeans_grp = KMeans(n_clusters=m, init=cinits, n_init=1).fit(
+             hist_sample_all, y=None, sample_weight=hist_counts_all)
+         M = kmeans_grp.cluster_centers_
+
+     if np.size(hist_sample_all, 1) == 2:
+         # print(M.shape)
+         if M.ndim == 1:
+             M = np.expand_dims(M, axis=0)
+         M = np.insert(M, 0, values=lightness, axis=1)
+         # print(M.shape)
+
+     ## choose the nearest cluster center from all the individual palettes as the initial group palette
+
+     # center_r0 = delete_num(center, 0)
+     # density_r0 = delete_num(density, 0)
+     # center_all = center_r0[0]
+     # for i in range(len(center_r0)-1):
+     #     center_all = np.vstack([center_all, center_r0[i+1]])
+
+     center_r0 = delete_num(center)
+     density_r0 = delete_num(density)
+     center_all = stack_list(center)
+
+     if M.ndim == 1:
+         M = M.reshape(1, -1)
+
+     D = pairwise_distances(M, center_all, metric='euclidean')
+     idx = np.argmin(D, 1)
+
+     # center_all
+
+     M = center_all[idx,:]
+
+     ## solve the group palette according to the requirement (gamma and eta)
+     for t in range(iteration):
+         sum_val = 0
+         # solve for the assignment (matching)
+         for i in range(num_img):
+             # print(center[i])
+             if center[i].size != 0:
+                 # if center[i] is not 0:
+                 # print(center[i])
+                 Lout[i], val = solve_optimal_ind_palette(center[i], density[i], M, lbd, gamma, eta, init_min)
+                 sum_val = sum_val + val
+             else:
+                 Lout[i] = np.array([])
+                 sum_val = 0
+
+         # re-compute the group color theme (mean colors)
+         Lout_r0 = delete_num(Lout)
+         idx = detect_num(Lout)
+         reference_r0 = reference[idx]
+         M = solve_mean(center_r0, reference_r0, density_r0, Lout_r0, m, len(reference_r0), lbd)
+         print('Iteration {}, val: {}'.format(t, sum_val))
+         if np.abs(old_val - sum_val) < 10:
+             break
+         else:
+             old_val = sum_val
+
+     return M, Lout
+
+
+ def delete_num(list_org):
+     list_new = []
+     for i in list_org:
+         if i.size != 0:
+             list_new.append(i)
+     return list_new
+
+ def detect_num(list_org):
+     list_idx = []
+     for i in range(len(list_org)):
+         if list_org[i].size != 0:
+             list_idx.append(i)
+     return list_idx
+
+
+ def solve_optimal_ind_palette(center, density, center_mean, lambd, gamma, eta, init_min):
+
+     n = np.size(center, 0)
+     m = np.size(center_mean, 0)
+     # print(eta)
+
+     # brute-force all possible cases
+     min_obj_func = init_min
+     D1 = pairwise_distances(center, center, metric='euclidean')
+     D2 = psub2(center, center_mean)
+
+     dist = pairwise_distances(center, center_mean, metric='euclidean')
+     dist = np.tile(np.expand_dims(density, axis=1), (1,m)) * (dist**2)
+
+     D3 = np.insert(dist, 0, values=np.zeros(n), axis=1)
+
+     num_of_pairs = (m+1)**n
+     label = np.zeros((n,1)).astype(np.int32)
+     min_label_com = label.copy()
+     label[n-1, :] = -1
+     for idx in range(num_of_pairs):
+         label[n-1] = label[n-1] + 1
+         curId = n-1
+         while label[curId] > m:
+             label[curId] = 0
+             curId = curId - 1
+             label[curId] = label[curId] + 1
+         term4 = np.sum((label==0) * np.expand_dims(density, axis=1))
+         val = eta * term4
+         if val >= min_obj_func:
+             continue
+         for i in range(n):
+             val = val + D3[i, label[i]]  # the first term
+         # prune assignments that cannot beat the current best
+         if val >= min_obj_func:
+             continue
+         term2 = 0
+         term3 = 0
+         for ii in range(n-1):
+             for jj in range(ii+1, n):
+                 term2 = term2 + D2[ii, jj, label[ii], label[jj]]
+                 if label[ii] == label[jj] and label[ii] > 0:
+                     term3 = term3 + D1[ii, jj]
+
+         val = val + lambd * term2 + gamma * term3
+         if val < min_obj_func:
+             min_obj_func = val.copy()
+             min_label_com = label.copy()
+
+     return min_label_com, min_obj_func
+
+
+ def solve_mean(center, reference, density, L, m, n, lambd):
+     A = np.zeros((m, m))
+     B = np.zeros((m, 3))
+     M = np.zeros((m, 3))
+
+     for i in range(n):
+         if reference[i] == 0:
+             continue
+         Pi = center[i]
+         Wi = density[i]
+         Li = L[i]
+         ni = np.size(Pi, 0)
+         # first term
+         for j in range(ni):
+             if Li[j] > 0:
+                 A[Li[j]-1, Li[j]-1] = A[Li[j]-1, Li[j]-1] + Wi[j]
+                 B[Li[j]-1, :] = B[Li[j]-1, :] + np.expand_dims(Wi[j] * Pi[j, :], axis=0)
+
+         # second term
+         for j1 in range(ni-1):
+             for j2 in range(j1+1, ni):
+                 if Li[j1] > 0 and Li[j2] > 0:
+                     A[Li[j1]-1, Li[j1]-1] = A[Li[j1]-1, Li[j1]-1] + lambd
+                     A[Li[j1]-1, Li[j2]-1] = A[Li[j1]-1, Li[j2]-1] - lambd
+                     A[Li[j2]-1, Li[j2]-1] = A[Li[j2]-1, Li[j2]-1] + lambd
+                     A[Li[j2]-1, Li[j1]-1] = A[Li[j2]-1, Li[j1]-1] - lambd
+                     B[Li[j1]-1, :] = B[Li[j1]-1, :] + lambd * (Pi[j1,:]-Pi[j2,:])
+                     B[Li[j2]-1, :] = B[Li[j2]-1, :] + lambd * (Pi[j2,:]-Pi[j1,:])
+
+     # solve least squares
+     # print(A)
+     M = np.dot(np.linalg.pinv(A), B)
+
+     # M[np.isnan[M]] = 0
+
+     # k-median, take the nearest from the individual palettes
+     M = choose_mediod(center, M)
+
+     return M
+
+
+ def concat_list(input, axis=0):
+     list_cat = input[0]
+     for i in range(len(input)-1):
+         list_cat = np.concatenate((list_cat, input[i+1]), axis=axis)
+         # list_cat = np.vstack([list_cat, input[i+1]])
+     return list_cat
+
+
+ def choose_mediod(Pin, M):
+     P = concat_list(Pin)
+     D = pairwise_distances(M, P)
+     idx = np.argmin(D, axis=1)
+     M = P[idx,:]
+     return M
+
+
+ def psub2(P, M):
+     p = np.size(P, 0)
+     m = np.size(M, 0)
+     D = np.zeros((p, p, m + 1, m + 1))
+     for i1 in range(p-1):
+         for i2 in range(i1+1, p):
+             for i3 in range(m):
+                 for i4 in range(m):
+                     D[i1,i2,i3+1,i4+1] = np.sum((P[i1,:] - P[i2,:] - M[i3,:] + M[i4,:])**2)
+     return D
+
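`solve_optimal_ind_palette` above enumerates all `(m+1)**n` label assignments of `n` palette colors to `m` group colors (label 0 meaning unassigned) by treating `label` as a base-(m+1) counter with carries. A standalone sketch of just that counter (the name `enumerate_labels` is hypothetical, not part of the file):

```python
def enumerate_labels(n, m):
    # label is a mixed-radix counter with n digits, each in {0, ..., m}.
    label = [0] * n
    label[-1] = -1  # pre-decrement so the first increment yields all zeros
    out = []
    for _ in range((m + 1) ** n):
        label[-1] += 1
        cur = n - 1
        while label[cur] > m:  # propagate the carry into the next digit
            label[cur] = 0
            cur -= 1
            label[cur] += 1
        out.append(tuple(label))
    return out

labels = enumerate_labels(2, 1)
# labels == [(0, 0), (0, 1), (1, 0), (1, 1)]: all 4 assignments of 2 colors
```

This is why the brute-force search is exponential in the palette size, and why the early `continue` pruning on partial costs matters in the real function.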
utils.py ADDED
@@ -0,0 +1,124 @@
+ # import cv2
+ import numpy as np
+ # import matplotlib.pyplot as plt
+ from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
+ from PIL import Image
+
+
+ def stack_list(x):
+     stack = []
+     for id, val in enumerate(x):
+         if np.size(val) != 0:
+             if stack == []:
+                 stack = val
+             else:
+                 stack = np.vstack([stack, val])
+     return stack
+
+ def rgb_to_hex(r, g, b):
+     return '#{:02x}{:02x}{:02x}'.format(r, g, b)
+
+ def hex_to_rgb(hex):
+     # print(hex)
+     rgb = []
+     for i in (1, 3, 5):
+         decimal = int(hex[i:i+2], 16)
+         rgb.append(decimal)
+     return tuple(rgb)
+
+
+ def image_resize(img, c_w, c_h):
+     # img : PIL Image
+     if type(img) is np.ndarray:
+         img = Image.fromarray(img)
+     h, w = img.size
+     h_factor = c_h / w
+     w_factor = c_w / h
+     # factor = h_factor
+     factor = np.minimum(h_factor, w_factor)
+     # print(w*factor, h*factor)
+     img = img.resize((np.round(h*factor).astype(np.int64),
+                       np.round(w*factor).astype(np.int64)),
+                      Image.BILINEAR)
+     return img
+
+ def get_palette(num_cls):
+     """ Returns the color map for visualizing the segmentation mask.
+     Args:
+         num_cls: Number of classes
+     Returns:
+         The color map
+     """
+     n = num_cls
+     palette = [0] * (n * 3)
+     for j in range(0, n):
+         lab = j
+         palette[j * 3 + 0] = 0
+         palette[j * 3 + 1] = 0
+         palette[j * 3 + 2] = 0
+         i = 0
+         while lab:
+             palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
+             palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
+             palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
+             i += 1
+             lab >>= 3
+     return palette
+
+
+ def visualize_palette(palette_lab, patch_size=20):
+     # print(palette_lab)
+     if palette_lab is None:
+         return np.ones((patch_size, patch_size, 3)) * [1.,1.,1.]
+     palette_lab = np.expand_dims(palette_lab, axis=0)
+     # palette_lab = np.sort(palette_lab, axis=1)
+
+     # # lab transfer by lookuptable
+     # # cluster_cen_rgb = lab2rgb_lut(cluster_cen_lab)
+     palette_rgb = lab2rgb(palette_lab)
+     palette_rgb = np.squeeze(palette_rgb, axis=0)
+
+     for id in range(np.size(palette_rgb, 0)):
+         rgb = np.expand_dims(palette_rgb[id,:], axis=(0, 1))
+         if id == 0:
+             img_palette = np.ones((patch_size, patch_size, 3)) * rgb
+         else:
+             img_palette = np.append(img_palette, np.ones((patch_size, patch_size, 3)) * rgb, axis=1)
+
+     return img_palette
+
+
+ def visualize_palette_rgb(palette_rgb, patch_size=20):
+     # print(palette_lab)
+     if palette_rgb == 0:
+         return np.ones((patch_size, patch_size, 3)) * [1.,1.,1.]
+
+     for id in range(np.size(palette_rgb, 0)):
+         rgb = np.expand_dims(palette_rgb[id,:], axis=(0, 1))
+         if id == 0:
+             img_palette = np.ones((patch_size, patch_size, 3)) * rgb
+         else:
+             img_palette = np.append(img_palette, np.ones((patch_size, patch_size, 3)) * rgb, axis=1)
+
+     return img_palette
+
+
+ # def vis_consistency(img_rgb_all, img_rgb_out_all, label_colored_all, c_center, L_idx):
+
+
+ def color_difference(img1, img2):
+     h, w, c = img1.shape
+     img1_lab = rgb2lab(img1)
+     img2_lab = rgb2lab(img2)
+
+     diff = img1_lab - img2_lab
+
+     dE = np.sqrt(diff[:,:,0]**2 + diff[:,:,1]**2 + diff[:,:,2]**2)
+     # dE = np.sqrt(diff[:,:,0]**2 + diff[:,:,0]**2)
+     dE = np.sum(dE)/(h*w)
+
+     return dE
+
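As a quick sanity check of the two color helpers defined in utils.py above, they round-trip exactly: converting an RGB triple to a hex string and back returns the original values. The functions are restated here so the sketch is self-contained:

```python
def rgb_to_hex(r, g, b):
    # Pack each 0-255 channel as two lowercase hex digits.
    return '#{:02x}{:02x}{:02x}'.format(r, g, b)

def hex_to_rgb(h):
    # Read two hex digits per channel, skipping the leading '#'.
    return tuple(int(h[i:i + 2], 16) for i in (1, 3, 5))

assert hex_to_rgb(rgb_to_hex(18, 52, 86)) == (18, 52, 86)
# rgb_to_hex(18, 52, 86) == '#123456'
```

Because `hex_to_rgb` indexes fixed positions (1, 3, 5), it assumes the 7-character `#rrggbb` form that `rgb_to_hex` produces; shorthand forms like `#fff` would not parse.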