Spaces:
Sleeping
Sleeping
initial upload
Browse files- .gitattributes +4 -0
- WB_sRGB/LICENSE.md +438 -0
- WB_sRGB/classes/WBsRGB.py +164 -0
- WB_sRGB/models/encoderBias.npy +3 -0
- WB_sRGB/models/encoderWeights.npy +3 -0
- WB_sRGB/models/features.npy +3 -0
- WB_sRGB/models/mappingFuncs.npy +3 -0
- app.py +145 -0
- color_naming/colornaming.py +100 -0
- color_naming/w2c11_joost_c.npy +3 -0
- examples/flower/001.jpg +3 -0
- examples/flower/002.jpg +3 -0
- examples/flower/003.jpg +3 -0
- examples/flower/004.jpg +3 -0
- examples/flower/005.jpg +3 -0
- examples/landmark/01.jpg +3 -0
- examples/landmark/02.jpg +3 -0
- examples/landmark/03.jpg +3 -0
- examples/landmark/04.jpg +3 -0
- examples/landmark/05.jpg +3 -0
- examples/landmark/06.jpg +3 -0
- examples/landmark/07.jpg +3 -0
- examples/portrait/image-00000.png +3 -0
- examples/portrait/image-00002.png +3 -0
- examples/portrait/image-00004.png +3 -0
- examples/portrait/image-00006.png +3 -0
- examples/portrait/image-00014.png +3 -0
- extract_palette.py +139 -0
- image.py +242 -0
- multi_image_process.py +365 -0
- recolor.py +135 -0
- requirements.txt +9 -0
- saliency/LDF/dataset.py +137 -0
- saliency/LDF/infer.py +40 -0
- saliency/LDF/model-40 +3 -0
- saliency/LDF/net.py +216 -0
- saliency/fast_saliency.py +590 -0
- solve_group_palette.py +240 -0
- utils.py +124 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
model-40 filter=lfs diff=lfs merge=lfs -text
|
WB_sRGB/LICENSE.md
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Attribution-NonCommercial-ShareAlike 4.0 International
|
| 2 |
+
|
| 3 |
+
=======================================================================
|
| 4 |
+
|
| 5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 6 |
+
does not provide legal services or legal advice. Distribution of
|
| 7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 8 |
+
other relationship. Creative Commons makes its licenses and related
|
| 9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 10 |
+
warranties regarding its licenses, any material licensed under their
|
| 11 |
+
terms and conditions, or any related information. Creative Commons
|
| 12 |
+
disclaims all liability for damages resulting from their use to the
|
| 13 |
+
fullest extent possible.
|
| 14 |
+
|
| 15 |
+
Using Creative Commons Public Licenses
|
| 16 |
+
|
| 17 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 18 |
+
conditions that creators and other rights holders may use to share
|
| 19 |
+
original works of authorship and other material subject to copyright
|
| 20 |
+
and certain other rights specified in the public license below. The
|
| 21 |
+
following considerations are for informational purposes only, are not
|
| 22 |
+
exhaustive, and do not form part of our licenses.
|
| 23 |
+
|
| 24 |
+
Considerations for licensors: Our public licenses are
|
| 25 |
+
intended for use by those authorized to give the public
|
| 26 |
+
permission to use material in ways otherwise restricted by
|
| 27 |
+
copyright and certain other rights. Our licenses are
|
| 28 |
+
irrevocable. Licensors should read and understand the terms
|
| 29 |
+
and conditions of the license they choose before applying it.
|
| 30 |
+
Licensors should also secure all rights necessary before
|
| 31 |
+
applying our licenses so that the public can reuse the
|
| 32 |
+
material as expected. Licensors should clearly mark any
|
| 33 |
+
material not subject to the license. This includes other CC-
|
| 34 |
+
licensed material, or material used under an exception or
|
| 35 |
+
limitation to copyright. More considerations for licensors:
|
| 36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 37 |
+
|
| 38 |
+
Considerations for the public: By using one of our public
|
| 39 |
+
licenses, a licensor grants the public permission to use the
|
| 40 |
+
licensed material under specified terms and conditions. If
|
| 41 |
+
the licensor's permission is not necessary for any reason--for
|
| 42 |
+
example, because of any applicable exception or limitation to
|
| 43 |
+
copyright--then that use is not regulated by the license. Our
|
| 44 |
+
licenses grant only permissions under copyright and certain
|
| 45 |
+
other rights that a licensor has authority to grant. Use of
|
| 46 |
+
the licensed material may still be restricted for other
|
| 47 |
+
reasons, including because others have copyright or other
|
| 48 |
+
rights in the material. A licensor may make special requests,
|
| 49 |
+
such as asking that all changes be marked or described.
|
| 50 |
+
Although not required by our licenses, you are encouraged to
|
| 51 |
+
respect those requests where reasonable. More considerations
|
| 52 |
+
for the public:
|
| 53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 54 |
+
|
| 55 |
+
=======================================================================
|
| 56 |
+
|
| 57 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
| 58 |
+
Public License
|
| 59 |
+
|
| 60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 61 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 62 |
+
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
| 63 |
+
("Public License"). To the extent this Public License may be
|
| 64 |
+
interpreted as a contract, You are granted the Licensed Rights in
|
| 65 |
+
consideration of Your acceptance of these terms and conditions, and the
|
| 66 |
+
Licensor grants You such rights in consideration of benefits the
|
| 67 |
+
Licensor receives from making the Licensed Material available under
|
| 68 |
+
these terms and conditions.
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Section 1 -- Definitions.
|
| 72 |
+
|
| 73 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 74 |
+
Rights that is derived from or based upon the Licensed Material
|
| 75 |
+
and in which the Licensed Material is translated, altered,
|
| 76 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 77 |
+
permission under the Copyright and Similar Rights held by the
|
| 78 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 79 |
+
Material is a musical work, performance, or sound recording,
|
| 80 |
+
Adapted Material is always produced where the Licensed Material is
|
| 81 |
+
synched in timed relation with a moving image.
|
| 82 |
+
|
| 83 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 84 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 85 |
+
accordance with the terms and conditions of this Public License.
|
| 86 |
+
|
| 87 |
+
c. BY-NC-SA Compatible License means a license listed at
|
| 88 |
+
creativecommons.org/compatiblelicenses, approved by Creative
|
| 89 |
+
Commons as essentially the equivalent of this Public License.
|
| 90 |
+
|
| 91 |
+
d. Copyright and Similar Rights means copyright and/or similar rights
|
| 92 |
+
closely related to copyright including, without limitation,
|
| 93 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 94 |
+
Rights, without regard to how the rights are labeled or
|
| 95 |
+
categorized. For purposes of this Public License, the rights
|
| 96 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 97 |
+
Rights.
|
| 98 |
+
|
| 99 |
+
e. Effective Technological Measures means those measures that, in the
|
| 100 |
+
absence of proper authority, may not be circumvented under laws
|
| 101 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 102 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 103 |
+
agreements.
|
| 104 |
+
|
| 105 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 106 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 107 |
+
that applies to Your use of the Licensed Material.
|
| 108 |
+
|
| 109 |
+
g. License Elements means the license attributes listed in the name
|
| 110 |
+
of a Creative Commons Public License. The License Elements of this
|
| 111 |
+
Public License are Attribution, NonCommercial, and ShareAlike.
|
| 112 |
+
|
| 113 |
+
h. Licensed Material means the artistic or literary work, database,
|
| 114 |
+
or other material to which the Licensor applied this Public
|
| 115 |
+
License.
|
| 116 |
+
|
| 117 |
+
i. Licensed Rights means the rights granted to You subject to the
|
| 118 |
+
terms and conditions of this Public License, which are limited to
|
| 119 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 120 |
+
Licensed Material and that the Licensor has authority to license.
|
| 121 |
+
|
| 122 |
+
j. Licensor means the individual(s) or entity(ies) granting rights
|
| 123 |
+
under this Public License.
|
| 124 |
+
|
| 125 |
+
k. NonCommercial means not primarily intended for or directed towards
|
| 126 |
+
commercial advantage or monetary compensation. For purposes of
|
| 127 |
+
this Public License, the exchange of the Licensed Material for
|
| 128 |
+
other material subject to Copyright and Similar Rights by digital
|
| 129 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 130 |
+
no payment of monetary compensation in connection with the
|
| 131 |
+
exchange.
|
| 132 |
+
|
| 133 |
+
l. Share means to provide material to the public by any means or
|
| 134 |
+
process that requires permission under the Licensed Rights, such
|
| 135 |
+
as reproduction, public display, public performance, distribution,
|
| 136 |
+
dissemination, communication, or importation, and to make material
|
| 137 |
+
available to the public including in ways that members of the
|
| 138 |
+
public may access the material from a place and at a time
|
| 139 |
+
individually chosen by them.
|
| 140 |
+
|
| 141 |
+
m. Sui Generis Database Rights means rights other than copyright
|
| 142 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 143 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 144 |
+
as amended and/or succeeded, as well as other essentially
|
| 145 |
+
equivalent rights anywhere in the world.
|
| 146 |
+
|
| 147 |
+
n. You means the individual or entity exercising the Licensed Rights
|
| 148 |
+
under this Public License. Your has a corresponding meaning.
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
Section 2 -- Scope.
|
| 152 |
+
|
| 153 |
+
a. License grant.
|
| 154 |
+
|
| 155 |
+
1. Subject to the terms and conditions of this Public License,
|
| 156 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 157 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 158 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 159 |
+
|
| 160 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 161 |
+
in part, for NonCommercial purposes only; and
|
| 162 |
+
|
| 163 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 164 |
+
NonCommercial purposes only.
|
| 165 |
+
|
| 166 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 167 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 168 |
+
License does not apply, and You do not need to comply with
|
| 169 |
+
its terms and conditions.
|
| 170 |
+
|
| 171 |
+
3. Term. The term of this Public License is specified in Section
|
| 172 |
+
6(a).
|
| 173 |
+
|
| 174 |
+
4. Media and formats; technical modifications allowed. The
|
| 175 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 176 |
+
all media and formats whether now known or hereafter created,
|
| 177 |
+
and to make technical modifications necessary to do so. The
|
| 178 |
+
Licensor waives and/or agrees not to assert any right or
|
| 179 |
+
authority to forbid You from making technical modifications
|
| 180 |
+
necessary to exercise the Licensed Rights, including
|
| 181 |
+
technical modifications necessary to circumvent Effective
|
| 182 |
+
Technological Measures. For purposes of this Public License,
|
| 183 |
+
simply making modifications authorized by this Section 2(a)
|
| 184 |
+
(4) never produces Adapted Material.
|
| 185 |
+
|
| 186 |
+
5. Downstream recipients.
|
| 187 |
+
|
| 188 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 189 |
+
recipient of the Licensed Material automatically
|
| 190 |
+
receives an offer from the Licensor to exercise the
|
| 191 |
+
Licensed Rights under the terms and conditions of this
|
| 192 |
+
Public License.
|
| 193 |
+
|
| 194 |
+
b. Additional offer from the Licensor -- Adapted Material.
|
| 195 |
+
Every recipient of Adapted Material from You
|
| 196 |
+
automatically receives an offer from the Licensor to
|
| 197 |
+
exercise the Licensed Rights in the Adapted Material
|
| 198 |
+
under the conditions of the Adapter's License You apply.
|
| 199 |
+
|
| 200 |
+
c. No downstream restrictions. You may not offer or impose
|
| 201 |
+
any additional or different terms or conditions on, or
|
| 202 |
+
apply any Effective Technological Measures to, the
|
| 203 |
+
Licensed Material if doing so restricts exercise of the
|
| 204 |
+
Licensed Rights by any recipient of the Licensed
|
| 205 |
+
Material.
|
| 206 |
+
|
| 207 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 208 |
+
may be construed as permission to assert or imply that You
|
| 209 |
+
are, or that Your use of the Licensed Material is, connected
|
| 210 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 211 |
+
the Licensor or others designated to receive attribution as
|
| 212 |
+
provided in Section 3(a)(1)(A)(i).
|
| 213 |
+
|
| 214 |
+
b. Other rights.
|
| 215 |
+
|
| 216 |
+
1. Moral rights, such as the right of integrity, are not
|
| 217 |
+
licensed under this Public License, nor are publicity,
|
| 218 |
+
privacy, and/or other similar personality rights; however, to
|
| 219 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 220 |
+
assert any such rights held by the Licensor to the limited
|
| 221 |
+
extent necessary to allow You to exercise the Licensed
|
| 222 |
+
Rights, but not otherwise.
|
| 223 |
+
|
| 224 |
+
2. Patent and trademark rights are not licensed under this
|
| 225 |
+
Public License.
|
| 226 |
+
|
| 227 |
+
3. To the extent possible, the Licensor waives any right to
|
| 228 |
+
collect royalties from You for the exercise of the Licensed
|
| 229 |
+
Rights, whether directly or through a collecting society
|
| 230 |
+
under any voluntary or waivable statutory or compulsory
|
| 231 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 232 |
+
reserves any right to collect such royalties, including when
|
| 233 |
+
the Licensed Material is used other than for NonCommercial
|
| 234 |
+
purposes.
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
Section 3 -- License Conditions.
|
| 238 |
+
|
| 239 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 240 |
+
following conditions.
|
| 241 |
+
|
| 242 |
+
a. Attribution.
|
| 243 |
+
|
| 244 |
+
1. If You Share the Licensed Material (including in modified
|
| 245 |
+
form), You must:
|
| 246 |
+
|
| 247 |
+
a. retain the following if it is supplied by the Licensor
|
| 248 |
+
with the Licensed Material:
|
| 249 |
+
|
| 250 |
+
i. identification of the creator(s) of the Licensed
|
| 251 |
+
Material and any others designated to receive
|
| 252 |
+
attribution, in any reasonable manner requested by
|
| 253 |
+
the Licensor (including by pseudonym if
|
| 254 |
+
designated);
|
| 255 |
+
|
| 256 |
+
ii. a copyright notice;
|
| 257 |
+
|
| 258 |
+
iii. a notice that refers to this Public License;
|
| 259 |
+
|
| 260 |
+
iv. a notice that refers to the disclaimer of
|
| 261 |
+
warranties;
|
| 262 |
+
|
| 263 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 264 |
+
extent reasonably practicable;
|
| 265 |
+
|
| 266 |
+
b. indicate if You modified the Licensed Material and
|
| 267 |
+
retain an indication of any previous modifications; and
|
| 268 |
+
|
| 269 |
+
c. indicate the Licensed Material is licensed under this
|
| 270 |
+
Public License, and include the text of, or the URI or
|
| 271 |
+
hyperlink to, this Public License.
|
| 272 |
+
|
| 273 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 274 |
+
reasonable manner based on the medium, means, and context in
|
| 275 |
+
which You Share the Licensed Material. For example, it may be
|
| 276 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 277 |
+
hyperlink to a resource that includes the required
|
| 278 |
+
information.
|
| 279 |
+
3. If requested by the Licensor, You must remove any of the
|
| 280 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 281 |
+
reasonably practicable.
|
| 282 |
+
|
| 283 |
+
b. ShareAlike.
|
| 284 |
+
|
| 285 |
+
In addition to the conditions in Section 3(a), if You Share
|
| 286 |
+
Adapted Material You produce, the following conditions also apply.
|
| 287 |
+
|
| 288 |
+
1. The Adapter's License You apply must be a Creative Commons
|
| 289 |
+
license with the same License Elements, this version or
|
| 290 |
+
later, or a BY-NC-SA Compatible License.
|
| 291 |
+
|
| 292 |
+
2. You must include the text of, or the URI or hyperlink to, the
|
| 293 |
+
Adapter's License You apply. You may satisfy this condition
|
| 294 |
+
in any reasonable manner based on the medium, means, and
|
| 295 |
+
context in which You Share Adapted Material.
|
| 296 |
+
|
| 297 |
+
3. You may not offer or impose any additional or different terms
|
| 298 |
+
or conditions on, or apply any Effective Technological
|
| 299 |
+
Measures to, Adapted Material that restrict exercise of the
|
| 300 |
+
rights granted under the Adapter's License You apply.
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
Section 4 -- Sui Generis Database Rights.
|
| 304 |
+
|
| 305 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 306 |
+
apply to Your use of the Licensed Material:
|
| 307 |
+
|
| 308 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 309 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 310 |
+
portion of the contents of the database for NonCommercial purposes
|
| 311 |
+
only;
|
| 312 |
+
|
| 313 |
+
b. if You include all or a substantial portion of the database
|
| 314 |
+
contents in a database in which You have Sui Generis Database
|
| 315 |
+
Rights, then the database in which You have Sui Generis Database
|
| 316 |
+
Rights (but not its individual contents) is Adapted Material,
|
| 317 |
+
including for purposes of Section 3(b); and
|
| 318 |
+
|
| 319 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 320 |
+
all or a substantial portion of the contents of the database.
|
| 321 |
+
|
| 322 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 323 |
+
replace Your obligations under this Public License where the Licensed
|
| 324 |
+
Rights include other Copyright and Similar Rights.
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 328 |
+
|
| 329 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 330 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 331 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 332 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 333 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 334 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 335 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 336 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 337 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 338 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 339 |
+
|
| 340 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 341 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 342 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 343 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 344 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 345 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 346 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 347 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 348 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 349 |
+
|
| 350 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 351 |
+
above shall be interpreted in a manner that, to the extent
|
| 352 |
+
possible, most closely approximates an absolute disclaimer and
|
| 353 |
+
waiver of all liability.
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
Section 6 -- Term and Termination.
|
| 357 |
+
|
| 358 |
+
a. This Public License applies for the term of the Copyright and
|
| 359 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 360 |
+
this Public License, then Your rights under this Public License
|
| 361 |
+
terminate automatically.
|
| 362 |
+
|
| 363 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 364 |
+
Section 6(a), it reinstates:
|
| 365 |
+
|
| 366 |
+
1. automatically as of the date the violation is cured, provided
|
| 367 |
+
it is cured within 30 days of Your discovery of the
|
| 368 |
+
violation; or
|
| 369 |
+
|
| 370 |
+
2. upon express reinstatement by the Licensor.
|
| 371 |
+
|
| 372 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 373 |
+
right the Licensor may have to seek remedies for Your violations
|
| 374 |
+
of this Public License.
|
| 375 |
+
|
| 376 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 377 |
+
Licensed Material under separate terms or conditions or stop
|
| 378 |
+
distributing the Licensed Material at any time; however, doing so
|
| 379 |
+
will not terminate this Public License.
|
| 380 |
+
|
| 381 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 382 |
+
License.
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
Section 7 -- Other Terms and Conditions.
|
| 386 |
+
|
| 387 |
+
a. The Licensor shall not be bound by any additional or different
|
| 388 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 389 |
+
|
| 390 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 391 |
+
Licensed Material not stated herein are separate from and
|
| 392 |
+
independent of the terms and conditions of this Public License.
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
Section 8 -- Interpretation.
|
| 396 |
+
|
| 397 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 398 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 399 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 400 |
+
be made without permission under this Public License.
|
| 401 |
+
|
| 402 |
+
b. To the extent possible, if any provision of this Public License is
|
| 403 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 404 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 405 |
+
cannot be reformed, it shall be severed from this Public License
|
| 406 |
+
without affecting the enforceability of the remaining terms and
|
| 407 |
+
conditions.
|
| 408 |
+
|
| 409 |
+
c. No term or condition of this Public License will be waived and no
|
| 410 |
+
failure to comply consented to unless expressly agreed to by the
|
| 411 |
+
Licensor.
|
| 412 |
+
|
| 413 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 414 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 415 |
+
that apply to the Licensor or You, including from the legal
|
| 416 |
+
processes of any jurisdiction or authority.
|
| 417 |
+
|
| 418 |
+
=======================================================================
|
| 419 |
+
|
| 420 |
+
Creative Commons is not a party to its public
|
| 421 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 422 |
+
its public licenses to material it publishes and in those instances
|
| 423 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
| 424 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 425 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 426 |
+
material is shared under a Creative Commons public license or as
|
| 427 |
+
otherwise permitted by the Creative Commons policies published at
|
| 428 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 429 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 430 |
+
of Creative Commons without its prior written consent including,
|
| 431 |
+
without limitation, in connection with any unauthorized modifications
|
| 432 |
+
to any of its public licenses or any other arrangements,
|
| 433 |
+
understandings, or agreements concerning use of licensed material. For
|
| 434 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 435 |
+
public licenses.
|
| 436 |
+
|
| 437 |
+
Creative Commons may be contacted at creativecommons.org.
|
| 438 |
+
|
WB_sRGB/classes/WBsRGB.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## White-balance model class
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2018-present, Mahmoud Afifi
|
| 4 |
+
# York University, Canada
|
| 5 |
+
# mafifi@eecs.yorku.ca | m.3afifi@gmail.com
|
| 6 |
+
#
|
| 7 |
+
# This source code is licensed under the license found in the
|
| 8 |
+
# LICENSE file in the root directory of this source tree.
|
| 9 |
+
# All rights reserved.
|
| 10 |
+
#
|
| 11 |
+
# Please cite the following work if this program is used:
|
| 12 |
+
# Mahmoud Afifi, Brian Price, Scott Cohen, and Michael S. Brown,
|
| 13 |
+
# "When color constancy goes wrong: Correcting improperly white-balanced
|
| 14 |
+
# images", CVPR 2019.
|
| 15 |
+
#
|
| 16 |
+
##########################################################################
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import numpy.matlib
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class WBsRGB:
|
| 25 |
+
def __init__(self, gamut_mapping=2, upgraded=0):
    """Load the white-balance model data from disk.

    Args:
      gamut_mapping: out-of-gamut handling strategy -- 1 = scaling,
        2 = clipping.
      upgraded: when 1, load the larger "+" model variant; when 0,
        load the base model.
    """
    # The upgraded model files carry a '+' suffix and use a larger K
    # for the nearest-neighbor search; otherwise the loading is identical.
    suffix = '+' if upgraded == 1 else ''
    model_dir = 'WB_sRGB/models/'
    # encoded features of the training histograms
    self.features = np.load(model_dir + 'features' + suffix + '.npy')
    # per-example color-correction functions
    self.mappingFuncs = np.load(model_dir + 'mappingFuncs' + suffix + '.npy')
    # PCA projection matrix
    self.encoderWeights = np.load(model_dir + 'encoderWeights' + suffix + '.npy')
    # PCA bias
    self.encoderBias = np.load(model_dir + 'encoderBias' + suffix + '.npy')
    # K value for nearest-neighbor searching
    self.K = 75 if upgraded == 1 else 25

    self.sigma = 0.25  # fall-off factor for KNN blending
    self.h = 60  # histogram bin width
    # Reported results use gamut_mapping=2; gamut_mapping=1 gives more
    # compelling results on over-saturated examples.
    self.gamut_mapping = gamut_mapping  # options: 1 scaling, 2 clipping
|
| 44 |
+
|
| 45 |
+
def encode(self, hist):
|
| 46 |
+
""" Generates a compacted feature of a given RGB-uv histogram tensor."""
|
| 47 |
+
histR_reshaped = np.reshape(np.transpose(hist[:, :, 0]),
|
| 48 |
+
(1, int(hist.size / 3)), order="F")
|
| 49 |
+
histG_reshaped = np.reshape(np.transpose(hist[:, :, 1]),
|
| 50 |
+
(1, int(hist.size / 3)), order="F")
|
| 51 |
+
histB_reshaped = np.reshape(np.transpose(hist[:, :, 2]),
|
| 52 |
+
(1, int(hist.size / 3)), order="F")
|
| 53 |
+
hist_reshaped = np.append(histR_reshaped,
|
| 54 |
+
[histG_reshaped, histB_reshaped])
|
| 55 |
+
feature = np.dot(hist_reshaped - self.encoderBias.transpose(),
|
| 56 |
+
self.encoderWeights)
|
| 57 |
+
return feature
|
| 58 |
+
|
| 59 |
+
def rgb_uv_hist(self, I):
|
| 60 |
+
""" Computes an RGB-uv histogram tensor. """
|
| 61 |
+
sz = np.shape(I) # get size of current image
|
| 62 |
+
if sz[0] * sz[1] > 202500: # resize if it is larger than 450*450
|
| 63 |
+
factor = np.sqrt(202500 / (sz[0] * sz[1])) # rescale factor
|
| 64 |
+
newH = int(np.floor(sz[0] * factor))
|
| 65 |
+
newW = int(np.floor(sz[1] * factor))
|
| 66 |
+
I = cv2.resize(I, (newW, newH), interpolation=cv2.INTER_NEAREST)
|
| 67 |
+
I_reshaped = I[(I > 0).all(axis=2)]
|
| 68 |
+
eps = 6.4 / self.h
|
| 69 |
+
hist = np.zeros((self.h, self.h, 3)) # histogram will be stored here
|
| 70 |
+
Iy = np.linalg.norm(I_reshaped, axis=1) # intensity vector
|
| 71 |
+
for i in range(3): # for each histogram layer, do
|
| 72 |
+
r = [] # excluded channels will be stored here
|
| 73 |
+
for j in range(3): # for each color channel do
|
| 74 |
+
if j != i:
|
| 75 |
+
r.append(j)
|
| 76 |
+
Iu = np.log(I_reshaped[:, i] / I_reshaped[:, r[1]])
|
| 77 |
+
Iv = np.log(I_reshaped[:, i] / I_reshaped[:, r[0]])
|
| 78 |
+
hist[:, :, i], _, _ = np.histogram2d(
|
| 79 |
+
Iu, Iv, bins=self.h, range=((-3.2 - eps / 2, 3.2 - eps / 2),) * 2, weights=Iy)
|
| 80 |
+
norm_ = hist[:, :, i].sum()
|
| 81 |
+
hist[:, :, i] = np.sqrt(hist[:, :, i] / norm_) # (hist/norm)^(1/2)
|
| 82 |
+
return hist
|
| 83 |
+
|
| 84 |
+
def correctImage(self, I):
|
| 85 |
+
""" White balance a given image I. """
|
| 86 |
+
# I = I[..., ::-1] # convert from BGR to RGB #donna
|
| 87 |
+
I = im2double(I) # convert to double
|
| 88 |
+
# Convert I to float32 may speed up the process.
|
| 89 |
+
feature = self.encode(self.rgb_uv_hist(I))
|
| 90 |
+
# Do
|
| 91 |
+
# ```python
|
| 92 |
+
# feature_diff = self.features - feature
|
| 93 |
+
# D_sq = np.einsum('ij,ij->i', feature_diff, feature_diff)[:, None]
|
| 94 |
+
# ```
|
| 95 |
+
D_sq = np.einsum(
|
| 96 |
+
'ij, ij ->i', self.features, self.features)[:, None] + np.einsum(
|
| 97 |
+
'ij, ij ->i', feature, feature) - 2 * self.features.dot(feature.T)
|
| 98 |
+
|
| 99 |
+
# get smallest K distances
|
| 100 |
+
idH = D_sq.argpartition(self.K, axis=0)[:self.K]
|
| 101 |
+
mappingFuncs = np.squeeze(self.mappingFuncs[idH, :])
|
| 102 |
+
dH = np.sqrt(
|
| 103 |
+
np.take_along_axis(D_sq, idH, axis=0))
|
| 104 |
+
weightsH = np.exp(-(np.power(dH, 2)) /
|
| 105 |
+
(2 * np.power(self.sigma, 2))) # compute weights
|
| 106 |
+
weightsH = weightsH / sum(weightsH) # normalize blending weights
|
| 107 |
+
mf = sum(np.matlib.repmat(weightsH, 1, 33) *
|
| 108 |
+
mappingFuncs, 0) # compute the mapping function
|
| 109 |
+
mf = mf.reshape(11, 3, order="F") # reshape it to be 9 * 3
|
| 110 |
+
I_corr = self.colorCorrection(I, mf) # apply it!
|
| 111 |
+
return I_corr
|
| 112 |
+
|
| 113 |
+
def colorCorrection(self, input, m):
|
| 114 |
+
""" Applies a mapping function m to a given input image. """
|
| 115 |
+
sz = np.shape(input) # get size of input image
|
| 116 |
+
I_reshaped = np.reshape(input, (int(input.size / 3), 3), order="F")
|
| 117 |
+
kernel_out = kernelP(I_reshaped)
|
| 118 |
+
out = np.dot(kernel_out, m)
|
| 119 |
+
if self.gamut_mapping == 1:
|
| 120 |
+
# scaling based on input image energy
|
| 121 |
+
out = normScaling(I_reshaped, out)
|
| 122 |
+
elif self.gamut_mapping == 2:
|
| 123 |
+
# clip out-of-gamut pixels
|
| 124 |
+
out = outOfGamutClipping(out)
|
| 125 |
+
else:
|
| 126 |
+
raise Exception('Wrong gamut_mapping value')
|
| 127 |
+
# reshape output image back to the original image shape
|
| 128 |
+
out = out.reshape(sz[0], sz[1], sz[2], order="F")
|
| 129 |
+
out = out.astype('float32') #donna
|
| 130 |
+
# out = out.astype('float32')[..., ::-1] # convert from BGR to RGB
|
| 131 |
+
return out
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def normScaling(I, I_corr):
    """Rescale each corrected pixel so its L2 norm matches the original pixel's.

    Pixels whose corrected norm is zero are left untouched to avoid dividing
    by zero.  ``I_corr`` is modified in place and also returned.
    """
    corrected_norms = np.linalg.norm(I_corr, axis=1)
    nonzero = corrected_norms != 0
    original_norms = np.linalg.norm(I[nonzero, :], axis=1)
    scale = (original_norms / corrected_norms[nonzero])[:, np.newaxis]
    I_corr[nonzero, :] = I_corr[nonzero, :] * scale
    return I_corr
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def kernelP(rgb):
    """Polynomial kernel: (r, g, b) -> (r, g, b, rg, rb, gb, r^2, g^2, b^2, rgb, 1).

    Ref: Hong, et al., "A study of digital camera colorimetric
    characterization based on polynomial modeling." Color Research &
    Application, 2001.
    """
    r = rgb[:, 0:1]
    g = rgb[:, 1:2]
    b = rgb[:, 2:3]
    cross_terms = [r * g, r * b, g * b]
    tail_terms = [np.square(rgb), r * g * b, np.ones_like(r)]
    return np.concatenate([rgb] + cross_terms + tail_terms, axis=1)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def outOfGamutClipping(I):
    """Clip pixel values into the valid [0, 1] gamut, in place, and return I."""
    np.clip(I, 0, 1, out=I)
    return I
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def im2double(im):
    """Return a float image in [0, 1] from a uint8 image in [0, 255].

    The previous implementation used ``cv2.normalize(..., NORM_MINMAX)``,
    which stretches each image by its *own* min/max (a uniform image maps to
    all zeros, and any image not spanning the full range has its contrast
    altered) — that contradicts the documented fixed [0, 255] -> [0, 1]
    mapping.  Dividing by 255 matches MATLAB's ``im2double`` for uint8 input
    and removes the cv2 dependency from this helper.
    """
    return im.astype(np.float64) / 255.0
|
WB_sRGB/models/encoderBias.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f8d77ec0baab6e45a5602cf9f394e7fde3720b8ca76f59c45d7807e0c49b7070
|
| 3 |
+
size 43328
|
WB_sRGB/models/encoderWeights.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ff8e684dd18f0e82fb5b1f398b26eb356f7ba9acd8c81dc43e791f86f105659a
|
| 3 |
+
size 2376128
|
WB_sRGB/models/features.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76ec7f16aaf1cda53dfb1b3812fae88cb56e0157e3fe7e84138227ac1108a9d6
|
| 3 |
+
size 13757828
|
WB_sRGB/models/mappingFuncs.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d677ecbe0e958ca1f6d56deb9a3ae9ce66cfbb606e891c8e06e60d807de68b6b
|
| 3 |
+
size 8254748
|
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
from multi_image_process import compute_inp_palette, recolor_single_image, recolor_group_images
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Default demo inputs loaded by the "Load Example Images" button.
example_images = ["./examples/flower/001.jpg",
                  "./examples/flower/002.jpg",
                  "./examples/flower/003.jpg",
                  "./examples/flower/004.jpg",
                  "./examples/flower/005.jpg",
                  ]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def swap_to_gallery(images):
    """Show the uploaded files in the gallery and hide the file uploader."""
    gallery_update = gr.update(value=images, visible=True)
    clear_button_update = gr.update(visible=True)
    uploader_update = gr.update(visible=False)
    return gallery_update, clear_button_update, uploader_update
|
| 19 |
+
|
| 20 |
+
def remove_back_to_files():
    """Hide the gallery and clear-button, and show the file uploader again."""
    gallery_update = gr.update(visible=False)
    clear_button_update = gr.update(visible=False)
    uploader_update = gr.update(visible=True)
    return gallery_update, clear_button_update, uploader_update
|
| 22 |
+
|
| 23 |
+
def show_palette(*colors):
    """Render the picked colors as a row of inline HTML swatches.

    Falsy entries (e.g. an unset color picker) are skipped.
    """
    swatch_template = ('<div style="width:40px;height:40px;background:{};'
                       'display:inline-block;margin:5px;border:1px solid #000;"></div>')
    return "".join(swatch_template.format(color) for color in colors if color)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_example_images():
    """Open the bundled example images and show them in the gallery."""
    previews = [Image.open(path) for path in example_images]
    return (gr.update(value=previews, visible=True),
            gr.update(visible=True),
            gr.update(value=example_images, visible=False))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":

    # Gradio demo: recolor a group of images so they share a consistent
    # palette, with optional white balance, saliency, and color-naming steps.
    # NOTE(review): the original file's indentation was lost in transit; the
    # layout nesting below was reconstructed — verify against the live app.
    with gr.Blocks() as demo:
        gr.Markdown("# Image recoloring with palette")

        # Multiple-image recoloring for cross-image color consistency.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Inputs")
                image_input = gr.File(
                    label="Drag (Select) more than one photos",
                    file_types=["image"],
                    file_count="multiple"
                )
                uploaded_files = gr.Gallery(label="Input images", visible=False, columns=7, rows=1, height=200)

                with gr.Column(visible=False) as clear_button:
                    remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=image_input, size="sm")

                # Swap uploader <-> gallery when files arrive / are removed.
                image_input.upload(fn=swap_to_gallery, inputs=image_input, outputs=[uploaded_files, clear_button, image_input])
                remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, image_input])

                gr.Markdown("### Select the parameters for recoloring")
                with gr.Row():
                    with gr.Group():
                        # typo fixed: "techiques" -> "techniques"
                        gr.Markdown("### Recoloring without other techniques")
                        num_center_grp = gr.Dropdown(choices=[1, 2, 3, 4, 5], value=3, label="Number of group palettes")
                    with gr.Group():
                        gr.Markdown("### Recoloring with other techniques")
                        with gr.Row():
                            with gr.Group():
                                checkbox_input_wb = gr.Checkbox(value=False, label="Apply white balance correction")
                            with gr.Group():
                                checkbox_input_sal = gr.Checkbox(value=False, label="Apply saliency")
                                checkbox_input_recolor_sal = gr.Checkbox(value=False, label="Recolor salient part only")
                                checkbox_input_recolor_nonsal = gr.Checkbox(value=False, label="Recolor non-salient part only")
                                num_center_sal = gr.Dropdown(choices=[1, 2, 3], value=1, label="Number of salient palettes")
                                num_center_nonsal = gr.Dropdown(choices=[1, 2, 3], value=1, label="Number of non-salient palettes")
                            with gr.Group():
                                checkbox_input_cn = gr.Checkbox(value=False, label="Apply color naming")
                                naming_thres = gr.Textbox(value=0.8, label="Threshold of color naming", placeholder=0.8)

            with gr.Column():
                gr.Markdown("### Outputs")
                output_gallery_palette_in = gr.Gallery(label="Input image palettes", columns=7, rows=1, height=100)
                output_gallery_palette_group = gr.Gallery(label="Group palette", columns=2, rows=1, height=100)
                output_gallery_palette_out = gr.Gallery(label="Output image palettes", columns=7, rows=1, height=100)
                output_gallery_recolor = gr.Gallery(label="Recolored images", columns=7, rows=1, height=300)

        with gr.Row():
            example_btn = gr.Button("Load Example Images")
            example_btn.click(fn=load_example_images, outputs=[uploaded_files, clear_button, image_input])

            palette_btn = gr.Button("Compute palette").click(
                compute_inp_palette,
                inputs=[image_input],
                outputs=[output_gallery_palette_in]
            )

            recoloring_multi_btn = gr.Button("Recoloring multiple images").click(
                recolor_group_images,
                inputs=[image_input, num_center_grp, num_center_sal, num_center_nonsal,
                        checkbox_input_wb,
                        checkbox_input_sal, checkbox_input_recolor_sal, checkbox_input_recolor_nonsal,
                        checkbox_input_cn, naming_thres],
                outputs=[output_gallery_recolor, output_gallery_palette_in, output_gallery_palette_out, output_gallery_palette_group]
            )

    demo.launch()
|
color_naming/colornaming.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
# import cv2
|
| 3 |
+
# from collections import Counter
|
| 4 |
+
from skimage.color import rgb2lab, lab2rgb
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# The 11 basic color terms; index order matches the columns of the w2c
# color-name matrix (see im2c, which argmaxes over those columns).
COLOR_NAME = ['black', 'brown', 'blue', 'gray', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
|
| 8 |
+
|
| 9 |
+
def im2c(img_lab, w2c):
    """Convert Lab colors to their color-name representation via a w2c matrix.

    The input is expanded with a leading axis (so a flat (N, 3) palette
    becomes a (1, N, 3) image), converted to RGB in [0, 255], and each pixel
    is looked up in the w2c matrix (RGB cube quantized to 32x32x32 = 32768
    rows, one probability column per color term).

    Args:
        img_lab: Lab colors; ``np.expand_dims(..., axis=0)`` is applied first.
        w2c: (32768, 11) color-name probability matrix.

    Returns:
        tuple: ``(name_idx_img, color_nam, color_img, prob_map)`` where
            ``name_idx_img`` is the per-pixel argmax color-term index,
            ``color_nam`` lists the color names of the *last* row's pixels
            (the original per-row loop overwrote earlier rows; that behavior
            is reproduced exactly),
            ``color_img`` maps each pixel to its term's prototype RGB color,
            ``prob_map`` holds the full per-pixel term probabilities.
    """
    img_lab = np.expand_dims(img_lab, axis=0)
    im = lab2rgb(img_lab) * 255

    # Prototype RGB value displayed for each of the 11 color terms
    # (same order as COLOR_NAME).
    color_values = np.array([[  0,   0,   0],
                             [165,  81,  43],
                             [  0,   0, 255],
                             [127, 127, 127],
                             [  0, 255,   0],
                             [255, 127,   0],
                             [255, 165, 216],
                             [191,   0, 191],
                             [255,   0,   0],
                             [255, 255, 255],
                             [255, 255,   0]], dtype=np.uint8)

    # Quantize each RGB channel to 32 levels and flatten to a w2c row index.
    index_im = ((im[:, :, 0].flatten() // 8) + 32 * (im[:, :, 1].flatten() // 8)
                + 32 * 32 * (im[:, :, 2].flatten() // 8)).astype(np.int32)

    prob_map = w2c[index_im, :].reshape((im.shape[0], im.shape[1], w2c.shape[1]))
    name_idx_img = np.argmax(prob_map, axis=2)

    # Vectorized replacement for the original O(H*W) Python double loop.
    color_img = color_values[name_idx_img]
    # The original loop rebuilt color_nam on every row, so only the last
    # row's names survived; reproduce exactly that.
    color_nam = [COLOR_NAME[idx] for idx in name_idx_img[-1]]

    return name_idx_img, color_nam, color_img, prob_map
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def compare_color_name(img_src, img_tgt, w2c, threshold=0.9):
    """Per-pixel comparison of color-name assignments between two inputs.

    With ``threshold == 0`` pixels match iff their argmax color term is the
    same; otherwise the L2 distance between the two 11-term probability
    vectors must fall below ``threshold``.

    Args:
        img_src: source Lab colors (passed straight to :func:`im2c`).
        img_tgt: target Lab colors.
        w2c: (32768, 11) color-name probability matrix.
        threshold: 0 for exact-term comparison, else a distance cut-off.

    Returns:
        Boolean array marking pixels whose color name is considered unchanged.
    """
    color_label_org, color_nam_org, _, prob_map_org = im2c(img_src, w2c)
    color_label_new, color_nam_new, _, prob_map_new = im2c(img_tgt, w2c)

    if threshold == 0:
        is_same_color = (color_label_org == color_label_new)
    else:
        # Vectorized over all pixels (replaces the original per-pixel double
        # loop).  Other distances (L1, KL divergence between the two
        # probability distributions) would also be reasonable here.
        diff = np.linalg.norm(prob_map_org - prob_map_new, axis=2)
        is_same_color = (np.abs(diff) < threshold)

    print(is_same_color)  # NOTE(review): debug output retained from the original
    return is_same_color
|
| 80 |
+
|
| 81 |
+
# if __name__ == "__main__":
|
| 82 |
+
# w2c = np.load('w2c11_joost_c.npy').astype(np.float16)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# image_path = './test.jpg'
|
| 86 |
+
# img = cv2.imread(image_path).astype(np.float32)
|
| 87 |
+
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 88 |
+
|
| 89 |
+
# prob_map, max_prob_img, name_idx_img, color_img = im2c(img, w2c)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# filtered_counts = Counter(name_idx_img[name_idx_img <= 10])
|
| 93 |
+
# sorted_counts = sorted(filtered_counts.items(), key=lambda x: x[1], reverse=True)
|
| 94 |
+
# top_3_values = [num for num, count in sorted_counts[:3]]
|
| 95 |
+
# top_3_colors = [COLOR_NAME[i] for i in top_3_values]
|
| 96 |
+
|
| 97 |
+
# print("Top 3 colors:", top_3_colors)
|
| 98 |
+
|
| 99 |
+
# cv2.imwrite('./colormap_joost.jpg', cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB))
|
| 100 |
+
|
color_naming/w2c11_joost_c.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bc5ede48fed749d552edf1fd5d058266c713f20af3d4a484310103be91766856
|
| 3 |
+
size 2883712
|
examples/flower/001.jpg
ADDED
|
Git LFS Details
|
examples/flower/002.jpg
ADDED
|
Git LFS Details
|
examples/flower/003.jpg
ADDED
|
Git LFS Details
|
examples/flower/004.jpg
ADDED
|
Git LFS Details
|
examples/flower/005.jpg
ADDED
|
Git LFS Details
|
examples/landmark/01.jpg
ADDED
|
Git LFS Details
|
examples/landmark/02.jpg
ADDED
|
Git LFS Details
|
examples/landmark/03.jpg
ADDED
|
Git LFS Details
|
examples/landmark/04.jpg
ADDED
|
Git LFS Details
|
examples/landmark/05.jpg
ADDED
|
Git LFS Details
|
examples/landmark/06.jpg
ADDED
|
Git LFS Details
|
examples/landmark/07.jpg
ADDED
|
Git LFS Details
|
examples/portrait/image-00000.png
ADDED
|
Git LFS Details
|
examples/portrait/image-00002.png
ADDED
|
Git LFS Details
|
examples/portrait/image-00004.png
ADDED
|
Git LFS Details
|
examples/portrait/image-00006.png
ADDED
|
Git LFS Details
|
examples/portrait/image-00014.png
ADDED
|
Git LFS Details
|
extract_palette.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
# import pandas as pd
|
| 3 |
+
|
| 4 |
+
from sklearn.cluster import KMeans
|
| 5 |
+
from sklearn.metrics import pairwise_distances
|
| 6 |
+
|
| 7 |
+
from collections import Counter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def histogram(img_lab, bin, mode=2, mask=None):
    """Build a 2-D (a*b*) or 3-D (Lab) histogram of the masked pixels.

    Args:
        img_lab: Lab image, (H, W, 3), or an already-flattened (N, 3) array
            (when flattened, ``mask`` must be supplied).
        bin: number of bins per histogram axis.
        mode: 2 -> 2-D histogram over (a*, b*); 3 -> 3-D over (L*, a*, b*).
        mask: optional (H, W) 0/1 mask; only pixels where mask == 1 count.

    Returns:
        tuple: ``(hist_samples, hist_counts)`` — the lower-edge coordinate of
        every bin, shape (bin**mode, mode), and the matching raw counts.

    Raises:
        ValueError: if ``mode`` is neither 2 nor 3 (previously this fell
            through and crashed later with a NameError).
    """
    if mask is None:
        mask = np.ones_like(img_lab[:, :, 0])

    if img_lab.ndim != 2:
        img_lab = img_lab.reshape(-1, 3)

    mask = mask.flatten()
    img_lab_masked = img_lab[mask == 1]

    if mode == 3:
        hist, edges = np.histogramdd(img_lab_masked, bins=bin)
        xpos, ypos, zpos = np.meshgrid(edges[0][:-1], edges[1][:-1], edges[2][:-1], indexing="ij")
        hist_samples = np.concatenate((xpos.reshape((bin * bin * bin, 1)),
                                       ypos.reshape((bin * bin * bin, 1)),
                                       zpos.reshape((bin * bin * bin, 1))), axis=1)
        hist_counts = hist.reshape(bin * bin * bin)
    elif mode == 2:
        hist, xedges, yedges = np.histogram2d(img_lab_masked[:, 1], img_lab_masked[:, 2], bins=bin)
        xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing="ij")
        hist_samples = np.concatenate((xpos.reshape((bin * bin, 1)),
                                       ypos.reshape((bin * bin, 1))), axis=1)
        hist_counts = hist.reshape(bin * bin)
    else:
        raise ValueError(f"Unsupported histogram mode {mode!r}; expected 2 or 3.")

    return hist_samples, hist_counts
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def palette_extraction(img_lab, hist_samples, hist_counts, mode=2, threshold=0.93, num_clusters=5, mask=None):
    """Extract a color palette from a histogram via weighted k-means.

    Runs k-means for k = 2..num_clusters with a density-aware farthest-point
    initialization, picks the smallest k whose cumulative distortion drop
    exceeds `threshold`, then labels every pixel and histogram bin with its
    nearest palette entry.

    Args:
        img_lab: Lab image, (H, W, 3) or flattened (N, 3).
        hist_samples: bin coordinates from `histogram`, (bin**mode, mode).
        hist_counts: matching raw bin counts.
        mode: 2 -> cluster (a*, b*); 3 -> cluster full Lab.
        threshold: fraction of the total distortion reduction required.
        num_clusters: maximum palette size tried.
        mask: optional (H, W) 0/1 mask; masked-out pixels get label 255.

    Returns:
        tuple: (cluster_centers, c_densities, img_labels, hist_labels).
    """
    if mask is None:
        mask = np.ones_like(img_lab[:,:,0])

    if img_lab.ndim != 2:
        img_lab = img_lab.reshape(-1, 3)

    mask = mask.flatten()
    # img_lab = img_lab[mask==1]

    hist_densities = hist_counts /np.sum(hist_counts)
    ########################### palette extraction ###########################
    # inital cluster center
    # Only non-empty bins participate in the clustering.
    index = np.argwhere(hist_densities!=0)
    index = np.squeeze(index, axis=(1,))
    num_nonzero = np.size(index)

    # ## directly clustering
    # num_clusters_opt = num_clusters
    # kmeans_f = KMeans(n_clusters=num_clusters_opt, init='k-means++', random_state=0).fit(
    #     hist_samples[index, :], y=None, sample_weight=hist_densities[index])

    ## clustering method from matlab code
    # distortion[0] is the k=1 case: a single center at the weighted mean.
    inits_all = []
    Cold = np.mean(hist_samples[index, :], 0)
    distortion=np.zeros((num_clusters,1))

    dist = pairwise_distances(hist_samples[index, :], np.expand_dims(Cold, axis=0), metric='euclidean')
    distortion[0] = np.sum(hist_densities[index] * np.squeeze(dist**2, axis=1), 0)

    inits_all.append(Cold)

    for k in range(1, num_clusters):
        # Initialize the cluster centers
        # (k is bumped so this iteration fits k+1 clusters; distortion[k-1]
        # therefore records the (k)-cluster fit of the *incremented* k.)
        k = k+1
        cinits = np.zeros((k, mode))
        # Farthest-point seeding weighted by bin density: repeatedly take the
        # highest-weight bin, then down-weight bins near the chosen centers.
        cw = hist_densities[index]
        for i in range(k):
            id = np.argmax(cw)
            cinits[i,:] = hist_samples[index, :][id,:]
            d2 = cinits[i,:]* np.ones((num_nonzero, 1)) - hist_samples[index, :]
            d2 = np.sum(np.square(d2), axis=1)
            d2 = d2/np.max(d2)
            cw = cw * (d2**2)

        inits_all.append(cinits)
        kmeans = KMeans(n_clusters=k, init=cinits, n_init=1).fit(
            hist_samples[index, :], y=None, sample_weight=hist_densities[index])

        dist_point = pairwise_distances(hist_samples[index, :], kmeans.cluster_centers_, metric='euclidean')
        distortion[k-1] = np.sum(hist_densities[index] * np.min(dist_point, axis=1)**2)

    # Elbow criterion: smallest k whose cumulative share of the distortion
    # reduction exceeds `threshold` (+2 converts the variance index to k).
    variance = distortion[:-1] - distortion[1:]
    distortion_percent = np.cumsum(variance)/(distortion[0]-distortion[-1])

    r=np.argwhere(distortion_percent > threshold)
    num_clusters_opt = np.min(r)+2

    # Refit with the stored seeds for the chosen k.
    kmeans_f = KMeans(n_clusters=num_clusters_opt, init=inits_all[num_clusters_opt-1], n_init=1).fit(
        hist_samples[index, :], y=None, sample_weight=hist_densities[index])
    cluster_centers = kmeans_f.cluster_centers_

    # print(cluster_centers.shape)

    if mode ==3:
        img_labels = kmeans_f.predict(img_lab)
    elif mode == 2:
        img_labels = kmeans_f.predict(img_lab[:, 1:3])

    hist_labels = kmeans_f.predict(hist_samples)

    # print(cluster_centers.shape)

    # # lab to rgb
    # cluster_cen_rgb = lab2rgb(np.expand_dims(cluster_centers, axis=0))
    # cluster_cen_rgb = np.squeeze(cluster_cen_rgb, axis=0)

    # 255 marks masked-out pixels; per-cluster pixel shares are computed over
    # the remaining labels.  (Note: `dict` and `id` shadow builtins here.)
    img_labels[mask==0] = 255
    c_densities = np.zeros(num_clusters_opt)

    dict=Counter(img_labels)
    for key in np.unique(img_labels):
        if key == 255:
            continue
        c_densities[key] = dict.get(key)

    c_densities = c_densities / np.sum(c_densities)

    return cluster_centers, c_densities, img_labels, hist_labels
|
image.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import cv2
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
|
| 7 |
+
from WB_sRGB.classes import WBsRGB as wb_srgb
|
| 8 |
+
from extract_palette import histogram, palette_extraction
|
| 9 |
+
from saliency.LDF.infer import Saliency_LDF
|
| 10 |
+
from saliency.fast_saliency import get_saliency_ft, get_saliency_mbd
|
| 11 |
+
from utils import color_difference
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseImage:
|
| 15 |
+
    def __init__(self, filepath):
        """Load one image and set the default palette/recoloring parameters.

        Args:
            filepath: uploaded-file object exposing ``.name`` (e.g. a gradio
                file wrapper); PIL opens it directly.
        """
        self.filename = os.path.basename(filepath.name)
        self.image = Image.open(filepath)
        self.img_rgb = np.asarray(self.image).astype(dtype=np.uint8)
        self.img_lab = rgb2lab(self.img_rgb)

        # Palette-extraction defaults.
        self.bin_size = 16  # histogram bins per color axis
        self.mode = 2  # 2 -> a*b* histogram, 3 -> full Lab (see extract_palette.histogram)
        self.hist_harmonization = False
        self.template = 'L'  # NOTE(review): presumably a harmonization template id — confirm
        self.distortion_threshold = 0.93  # k-means distortion cut-off for palette size
        self.num_center_ind = 7  # NOTE(review): presumably the per-image palette size limit — confirm
        self.lightness = 70
        # self.if_correct_wb = if_correct_wb
        # self.if_saliency = if_saliency
        # self.saliency_threshold = sal_thres
        self.cdiff_threshold = 30  # NOTE(review): presumably a color-difference cut-off — confirm
        self.sal_threshold = 0.9  # saliency-map threshold
        self.applied_wb = False  # NOTE(review): presumably set once WB is applied — confirm writer
        # self.valid_class = [0,1]


        # self.hist_value, self.hist_count, \
        # self.c_center, self.c_density, \
        # self.c_img_label = self.extract_palette(if_wb=self.if_correct_wb,
        #                                         if_saliency=self.if_saliency,
        #                                         sal_thres=self.saliency_threshold)

        # self.inital_info(self.if_correct_wb,
        #                  self.if_saliency,
        #                  self.saliency_threshold)


        # self.hist_value, self.hist_count, self.c_center, self.c_density, self.c_img_label = self.extract_palette(if_wb=self.if_correct_wb, if_saliency=False)
        # self.hist_value_sal, self.hist_count_sal, self.c_center_sal, self.c_density_sal, self.c_img_label_sal = self.extract_palette(if_wb=self.if_correct_wb, if_saliency=True, sal_thres=self.saliency_threshold)
|
| 50 |
+
|
| 51 |
+
    def inital_info(self, if_correct_wb, if_saliency, wb_thres, sal_thres, valid_class):
        """Compute and cache the palette / saliency information for the image.

        Delegates to ``extract_salient_palette`` (defined elsewhere in this
        class — not visible in this chunk) and then renders the colored
        segmentation map.
        """
        self.hist_value, self.hist_count, \
        self.c_center, self.c_density, \
        self.c_img_label, self.sal_links = self.extract_salient_palette(if_wb=if_correct_wb,
                                                                        if_saliency=if_saliency,
                                                                        wb_thres=wb_thres,
                                                                        sal_thres=sal_thres,
                                                                        valid_class=valid_class)

        self.label_colored = self.cal_color_segment()
|
| 61 |
+
|
| 62 |
+
    def get_rgb_image(self):
        """Return the image as a uint8 RGB array."""
        return self.img_rgb
|
| 64 |
+
|
| 65 |
+
    def get_lab_image(self):
        """Return the image converted to CIE Lab."""
        return self.img_lab
|
| 67 |
+
|
| 68 |
+
    def get_wb_image(self):
        """Compute (on every call) and cache the white-balanced image."""
        self.img_wb = self.white_balance_correction()
        return self.img_wb
|
| 71 |
+
|
| 72 |
+
    def get_saliency(self):
        """Compute (on every call) and cache the saliency map of the RGB image."""
        self.sal_map = self.saliency_detection(self.img_rgb)
        return self.sal_map
|
| 75 |
+
|
| 76 |
+
def get_color_segment(self):
    """Return the colorized segmentation image produced by cal_color_segment()."""
    return self.label_colored
|
| 78 |
+
|
| 79 |
+
def get_label(self):
    """Return the per-pixel palette-label map computed during palette extraction."""
    return self.colorlabel
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def cal_color_segment(self):
    """Render the per-pixel palette labels as a uint8 RGB visualization.

    Each pixel is painted with its palette's Lab center color, then the
    whole canvas is converted Lab -> RGB and scaled to [0, 255].
    """
    canvas = np.zeros_like(self.img_rgb, dtype=np.float64)
    for idx, lab_color in enumerate(self.center):
        canvas[self.colorlabel == idx] = lab_color
    rgb = lab2rgb(canvas)
    return np.round(rgb * 255).astype(np.uint8)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# def cal_salient_segment(self, palettelabel):
|
| 98 |
+
# label_colored = np.zeros_like(self.img_rgb, dtype=np.float64)
|
| 99 |
+
# valid_label = np.argwhere(palettelabel==1).flatten()
|
| 100 |
+
# for id_color in valid_label:
|
| 101 |
+
# label_colored[self.colorlabel == id_color] = self.center[id_color, :]
|
| 102 |
+
# label_colored = lab2rgb(label_colored)
|
| 103 |
+
# label_colored = np.round(label_colored*255).astype(np.uint8)
|
| 104 |
+
# return label_colored
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def white_balance_correction(self):
    """Run the WB-sRGB model on self.img_rgb and return a uint8 RGB image.

    Model settings:
      * upgraded=2 -> the model variant trained with the newer examples.
      * gamut_mapping=2 -> clipping, as reported in the WB-sRGB paper
        (1 = scaling, recommended for over-saturated inputs).
    """
    model = wb_srgb.WBsRGB(gamut_mapping=2, upgraded=2)
    corrected = model.correctImage(self.img_rgb)  # returns float image in [0, 1]
    return (corrected * 255).astype(np.uint8)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def saliency_detection(self, img_rgb, method='LDF'):
    """Compute a saliency map for an RGB image.

    Args:
        img_rgb: RGB image array.
        method: 'LDF' (deep model), 'ft' (frequency-tuned), or 'rbd'
                (robust background detection / minimum barrier).

    Returns:
        The saliency map produced by the selected detector.

    Raises:
        ValueError: if `method` is not one of the supported names.
    """
    if method == 'LDF':
        detector = Saliency_LDF()
        sal_map = detector.inference(img_rgb)
    elif method == 'ft':
        sal_map = get_saliency_ft(img_rgb)
    elif method == 'rbd':
        sal_map = get_saliency_mbd(img_rgb)
    else:
        # Previously an unknown method fell through and raised a confusing
        # NameError on the undefined `sal_map`; fail fast with a clear error.
        raise ValueError('unknown saliency method: {!r}'.format(method))
    return sal_map
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def solve_ind_palette(self, img_rgb, mask_binary=None):
    """Extract the palette of a single image.

    Args:
        img_rgb: RGB image array of shape (w, h, c).
        mask_binary: optional binary mask restricting which pixels contribute
            to the histogram/palette (None means all pixels).

    Returns:
        (hist_value, hist_count, c_center, c_density, c_img_label, histlabel):
        histogram bin values and counts, palette centers (Lab), per-center
        densities, the per-pixel palette-label map reshaped to (w, h), and
        the per-bin palette label.
    """
    w, h, c = img_rgb.shape
    img_lab = rgb2lab(img_rgb)  # work in Lab space for perceptual clustering

    # Binned color histogram over Lab (or ab-only, depending on self.mode).
    hist_value, hist_count = histogram(img_lab, self.bin_size, mode=self.mode, mask=mask_binary)

    # Cluster the histogram into at most self.num_center_ind palette colors.
    c_center, c_density, c_img_label, histlabel = palette_extraction(img_lab, hist_value, hist_count,
                                                                     threshold=self.distortion_threshold,
                                                                     num_clusters=self.num_center_ind,
                                                                     mode=self.mode,
                                                                     mask=mask_binary)

    # In ab-only mode (mode == 2) the centers have no L channel; prepend the
    # fixed lightness so centers are full Lab triples.
    if self.mode == 2:
        c_center = np.insert(c_center, 0, values=self.lightness, axis=1)

    c_img_label = np.reshape(c_img_label, (w, h))

    return hist_value, hist_count, c_center, c_density, c_img_label, histlabel
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def extract_salient_palette(self, if_wb=False, if_saliency=False, wb_thres=5,
                            sal_thres=0.9, valid_class=None):
    """Extract the image palette, optionally split into salient / non-salient parts.

    Args:
        if_wb: if True, run white-balance correction and use the corrected
            image when its color difference from the original exceeds wb_thres.
        if_saliency: if True, split histogram/palette data per saliency class.
        wb_thres: deltaE threshold for adopting the white-balanced image.
        sal_thres: saliency-map threshold for the salient mask.
        valid_class: saliency class ids (default [0, 1]: non-salient, salient).

    Returns:
        (hist_value, hist_count, center, density, colorlabel, sal_links);
        when if_saliency is True each of the first five is a per-class list.
    """
    # Avoid a shared mutable default argument.
    if valid_class is None:
        valid_class = [0, 1]

    img_rgb = self.img_rgb.copy()
    if if_wb:
        self.img_wb = self.white_balance_correction()
        img_wb = self.img_wb
        dE = color_difference(img_rgb, img_wb)
        print(dE)
        # Only adopt the corrected image when it is noticeably different.
        if dE > wb_thres:
            self.applied_wb = True
            img_rgb = img_wb
            print('use white balance correction on {}'.format(self.filename.split('/')[-1]))

    hist_value, hist_count, center, density, colorlabel, histlabel = self.solve_ind_palette(img_rgb, mask_binary=None)
    self.center = center
    self.colorlabel = colorlabel

    # Identity mapping by default (palette i links to itself).
    sal_links = [i for i in range(np.size(center, axis=0))]

    if not if_saliency:
        return hist_value, hist_count, center, density, colorlabel, sal_links

    # --- saliency-aware split -------------------------------------------
    self.sal_map = self.saliency_detection(self.img_rgb)
    label_sem = np.zeros_like(self.img_rgb[:, :, 0])
    label_sem[self.sal_map > sal_thres] = 1

    # p_feq[c, p]: fraction of class-c pixels carrying palette label p; each
    # palette color is assigned to the class where it is most frequent.
    p_feq = np.zeros((len(valid_class), np.size(center, axis=0)))
    for id_cls, cls in enumerate(valid_class):
        label_binary = np.zeros_like(label_sem)
        label_binary[label_sem == cls] = 1
        colorlabel_cls = colorlabel[label_binary == 1]
        value, count = np.unique(colorlabel_cls, return_counts=True)
        p_feq[id_cls, value] = count / count.sum()

    palettelabel = np.argmax(p_feq, axis=0)

    class_num = len(valid_class)
    c_center = [np.array([]) for i in range(class_num)]
    c_density = [np.array([]) for i in range(class_num)]
    c_img_label = [np.array([]) for i in range(class_num)]
    hist_samples = [np.array([]) for i in range(class_num)]
    hist_counts = [np.array([]) for i in range(class_num)]
    mapping = [np.array([]) for i in range(class_num)]

    for id_cls, cls in enumerate(valid_class):
        mapping[id_cls] = np.argwhere(palettelabel == id_cls).flatten()
        c_center[id_cls] = center[mapping[id_cls], :]
        c_density[id_cls] = density[mapping[id_cls]]
        hist_samples[id_cls] = hist_value.copy()
        hist_counts[id_cls] = hist_count.copy()
        hist_counts[id_cls][histlabel != id_cls] = 0

        # BUG FIX: `labels` used to be re-zeroed on every loop iteration, so
        # only the last mapped color survived and all earlier colors collapsed
        # onto label 0. Initialize once, relabel all mapped colors, then store.
        if mapping[id_cls].size:
            labels = np.zeros_like(colorlabel)
            for new_idx, old_label in enumerate(mapping[id_cls]):
                labels[colorlabel == old_label] = new_idx
            c_img_label[id_cls] = labels

    # Salient palettes first, then non-salient (assumes two classes —
    # consistent with the [0, 1] default; verify if valid_class is extended).
    sal_links = np.hstack((mapping[1], mapping[0]))

    return hist_samples, hist_counts, c_center, c_density, c_img_label, sal_links
|
| 241 |
+
|
| 242 |
+
|
multi_image_process.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from skimage.color import rgb2lab
|
| 6 |
+
|
| 7 |
+
from solve_group_palette import solve_group_palette
|
| 8 |
+
from recolor import lab_transfer
|
| 9 |
+
from color_naming.colornaming import compare_color_name
|
| 10 |
+
from utils import visualize_palette
|
| 11 |
+
from image import BaseImage
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# recolor single image with given palette TO DO
|
| 15 |
+
def recolor_single_image(image, save_dir='./results/testing'):
    """Recolor one image with a given palette (TODO: not implemented).

    Currently an identity placeholder: the input is returned unchanged and
    `save_dir` is unused.
    """
    return image
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# compute and save the palette for each image
|
| 22 |
+
def compute_inp_palette(images, save_dir='./results/testing'):
    """Extract and save a palette visualization for each input image.

    Args:
        images: iterable of uploaded-file objects; each must expose `.name`
            (a path) and be accepted by BaseImage.
        save_dir: directory where `palette_<stem>.png` files are written
            (created if missing).

    Returns:
        List of palette visualizations (uint8 RGB arrays), one per image.
    """
    palette_all = [None] * len(images)

    # Create the output directory once, race-safely.
    os.makedirs(save_dir, exist_ok=True)

    for img_id, img in enumerate(images):
        img_name = os.path.basename(img.name)
        print('processing image {}: {}...'.format(img_id, img_name))
        image = BaseImage(img)
        _, _, c_center, _, _, _ = image.solve_ind_palette(image.img_rgb, mask_binary=None)

        # Use splitext so extensions of any length (.jpeg, .tiff, ...) are
        # handled correctly, instead of the old hard-coded img_name[:-4].
        stem = os.path.splitext(img_name)[0]
        imwrite_path = os.path.join(save_dir, 'palette_' + stem + '.png')

        img_palette = visualize_palette(c_center, patch_size=20)
        img_palette = np.round(img_palette * 255).astype(np.uint8)
        # visualize_palette yields RGB; OpenCV writes BGR, so convert first.
        cv2.imwrite(imwrite_path, cv2.cvtColor(img_palette, cv2.COLOR_RGB2BGR))

        palette_all[img_id] = img_palette.copy()
    return palette_all
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# match the palette with the group palette
|
| 45 |
+
def match_palette(palette_ind, palette_grp, L_idx, if_cn, naming_thres):
    """Map an image's individual palette onto the group palette.

    Args:
        palette_ind: (k, 3) individual palette colors.
        palette_grp: group palette colors.
        L_idx: per-color 1-based indices into palette_grp; entries <= 0 mean
            "keep the original color".
        if_cn: if True, veto matches whose color *name* differs too much.
        naming_thres: color-naming similarity threshold for the veto.

    Returns:
        (palette_mapped, L_idx): the mapped palette and the (possibly masked)
        index array.
    """
    palette_mapped = palette_ind.copy()
    valid = (L_idx > 0)
    # Nothing to match for this image.
    if L_idx.size == 0:
        return palette_mapped, L_idx
    valid = valid.flatten()
    idx = np.argwhere(L_idx.flatten() > 0).flatten()

    # Replace matched colors with their group-palette targets (1-based -> 0-based).
    palette_mapped[idx, :] = palette_grp[L_idx[valid].flatten() - 1, :]

    if if_cn:
        # Veto matches that change the perceived color *name* (e.g. red -> blue)
        # using the precomputed w2c color-naming table.
        w2c = np.load('./color_naming/w2c11_joost_c.npy').astype(np.float16)
        is_the_same = compare_color_name(palette_ind, palette_mapped, w2c, threshold=naming_thres)
        mask = (is_the_same == True).T
        idx = np.argwhere(is_the_same == False).flatten()
        # Zero out vetoed matches and restore the original colors for them.
        L_idx = L_idx * mask
        palette_mapped[idx, :] = palette_ind[idx, :]
    return palette_mapped, L_idx
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# solve the group palette without saliency
|
| 72 |
+
def solve_grp_palette(images, mode, bin_size,
                      if_wb=False, wb_thres=20, num_center=5,
                      lightness=70., eta=1e10, gamma=0, iteration=10,
                      if_cn=False, naming_thres=0.5):
    """Solve a shared (group) palette across images, without saliency split.

    Args:
        images: list of BaseImage instances (inital_info is called on each).
        mode: 3 for full Lab histograms, 2 for ab-only.
        bin_size: histogram bins per channel.
        if_wb / wb_thres: white-balance options forwarded to inital_info.
        num_center: maximum number of group-palette colors.
        lightness, eta, gamma, iteration: solver hyper-parameters forwarded
            to solve_group_palette.
        if_cn / naming_thres: color-naming veto options for match_palette.

    Returns:
        (c_center, M, palette_map, L_idx): per-image source palettes, the
        group palette, per-image mapped target palettes, and match indices.
    """
    num_img = len(images)
    c_center = [np.array([]) for i in range(num_img)]
    c_density = [np.array([]) for i in range(num_img)]
    palette_map = [np.array([]) for i in range(num_img)]
    L_idx = [np.array([]) for i in range(num_img)]

    # Accumulators for the merged histogram over all images.
    if mode == 3:
        hist_samples_all = np.zeros((bin_size**3, 3))
        hist_counts_all = np.zeros(bin_size**3)
    elif mode == 2:
        hist_samples_all = np.zeros((bin_size**2, 2))
        hist_counts_all = np.zeros(bin_size**2)

    for img_id, image in enumerate(images):
        # if_saliency=False: single-palette path; sal_thres/valid_class unused here.
        image.inital_info(if_wb,
                          False,
                          wb_thres,
                          0.9,
                          [0, 1])

        density = np.tile(image.hist_count, (mode, 1))
        hist_counts_all = hist_counts_all + image.hist_count
        # Count-weighted sum of bin values; divided back out below.
        hist_samples_all = hist_samples_all + density.T * image.hist_value

    # Drop empty bins, then convert the weighted sums to weighted means.
    index = np.argwhere(hist_counts_all != 0)
    index = np.squeeze(index, axis=(1,))
    hist_counts_all = hist_counts_all[index]
    hist_samples_all = hist_samples_all[index, :] / np.expand_dims(hist_counts_all, axis=1)

    print('Solving the group palette...')
    # Every image acts as a reference (all-ones vector).
    reference = np.zeros((num_img, 1))
    if np.sum(reference) == 0:
        reference = reference + 1

    num_palettes = 0
    for i in range(num_img):
        if reference[i]:
            # NOTE(review): axis=1 counts palette *channels*, not palette rows
            # (axis=0 would count colors) — confirm this is intended.
            num_palettes = num_palettes + np.size(images[i].c_center, 1)

    # Group-palette size: requested count, capped by available palettes.
    m = np.minimum(int(num_center), num_palettes)

    c_center = [image.c_center for image in images]
    c_density = [image.c_density for image in images]

    M, matching = solve_group_palette(hist_samples_all, hist_counts_all,
                                      c_center, c_density, reference, m,
                                      lightness=lightness, eta=eta,
                                      gamma=gamma, iteration=iteration)

    for img_id in range(num_img):
        palette_map[img_id], L_idx[img_id] = match_palette(images[img_id].c_center, M, matching[img_id], if_cn, naming_thres)

    return c_center, M, palette_map, L_idx
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# solve the group palette with saliency
|
| 143 |
+
def solve_grp_palette_wsal(images, mode, bin_size,
                           lightness=70., eta=1e10, gamma=0, iteration=10,
                           if_wb=False, wb_thres=30,
                           if_saliency=True, sal_thres=0.9, valid_class=[0, 1],
                           sal_center=1, nonsal_center=1,
                           recolor_nonsal_only=False, recolor_sal_only=False,
                           if_cn=False, naming_thres=0.5):
    """Solve a group palette per saliency class (salient vs non-salient).

    For each class in valid_class a separate group palette is solved; the
    per-class results are then stacked (salient first) into single per-image
    arrays.

    Returns:
        (center_img, M, palette_map_img, L_idx_img, p_src_sal_size,
        p_grp_sal_size): stacked per-image source palettes, the stacked group
        palette, stacked mapped palettes, stacked match indices (non-salient
        indices offset by the salient group size), per-image salient palette
        sizes, and the salient group-palette size.
    """
    class_num = len(valid_class)
    num_img = len(images)

    c_center = [np.array([]) for i in range(num_img)]
    c_density = [np.array([]) for i in range(num_img)]

    M = [np.array([]) for i in range(class_num)]
    matching = [np.array([]) for i in range(class_num)]

    palette_map = [[np.array([]) for i in range(num_img)] for i in range(class_num)]
    L_idx = [[np.array([]) for i in range(num_img)] for i in range(class_num)]
    p_src_sal_size = [np.array([]) for i in range(num_img)]

    # Per-class merged-histogram accumulators.
    if mode == 3:
        hist_samples_all = [np.zeros((bin_size**3, 3)) for i in range(class_num)]
        hist_counts_all = [np.zeros(bin_size**3) for i in range(class_num)]
    elif mode == 2:
        hist_samples_all = [np.zeros((bin_size**2, 2)) for i in range(class_num)]
        hist_counts_all = [np.zeros(bin_size**2) for i in range(class_num)]

    for img_id, image in enumerate(images):
        image.inital_info(if_wb,
                          if_saliency,
                          wb_thres,
                          sal_thres,
                          valid_class)

        for id_cls, cls in enumerate(valid_class):
            density = np.tile(image.hist_count[id_cls], (mode, 1))
            hist_counts_all[id_cls] = hist_counts_all[id_cls] + image.hist_count[id_cls]
            hist_samples_all[id_cls] = hist_samples_all[id_cls] + density.T * image.hist_value[id_cls]

    # Solve one group palette per saliency class.
    for id_cls, cls in enumerate(valid_class):
        # Drop empty bins; convert count-weighted sums to weighted means.
        index = np.argwhere(hist_counts_all[id_cls] != 0)
        index = np.squeeze(index, axis=(1,))

        hist_counts_all[id_cls] = hist_counts_all[id_cls][index]
        hist_samples_all[id_cls] = hist_samples_all[id_cls][index, :] / np.expand_dims(hist_counts_all[id_cls], axis=1)

        print('Solving the group palette...')
        # Every image acts as a reference (all-ones vector).
        reference = np.zeros((num_img, 1))
        if np.sum(reference) == 0:
            reference = reference + 1

        num_palettes = 0
        for i in range(num_img):
            if reference[i] and images[i].c_center[id_cls].size != 0:
                # NOTE(review): axis=1 counts channels, not palette rows — verify.
                num_palettes = num_palettes + np.size(images[i].c_center[id_cls], 1)

        # m[0]: non-salient group size, m[1]: salient group size,
        # each capped by the available palette count.
        m = [1, 1]
        m[0] = np.minimum(int(sal_center), num_palettes)
        m[1] = np.minimum(int(nonsal_center), num_palettes)

        c_center = [image.c_center[id_cls] for image in images]
        c_density = [image.c_density[id_cls] for image in images]

        M[id_cls], matching[id_cls] = solve_group_palette(hist_samples_all[id_cls], hist_counts_all[id_cls],
                                                          c_center, c_density, reference, m[id_cls],
                                                          lightness=lightness, eta=eta,
                                                          gamma=gamma, iteration=iteration)

        for img_id in range(num_img):
            palette_map[id_cls][img_id], L_idx[id_cls][img_id] = match_palette(images[img_id].c_center[id_cls], M[id_cls], matching[id_cls][img_id], if_cn, naming_thres)
            if id_cls == 1:
                p_src_sal_size[img_id] = np.size(palette_map[1][img_id], axis=0)

    p_grp_sal_size = np.size(M[1], axis=0)
    p_nsal_size = np.size(M[0], axis=0)

    # Optionally freeze one class: its colors keep their original values and
    # their match indices are zeroed (meaning "no remap").
    cls_keep = None
    if recolor_nonsal_only:
        print('only recolor non-salient region')
        cls_keep = 0
    elif recolor_sal_only:
        print('only recolor salient region')
        cls_keep = 1

    if cls_keep is not None:
        for img_id in range(num_img):
            if L_idx[cls_keep][img_id].size == 0:
                palette_map[cls_keep][img_id] = images[img_id].c_center[cls_keep]
                L_idx[cls_keep][img_id] = 0
            else:
                for idx_nsal in range(len(L_idx[cls_keep][img_id])):
                    palette_map[cls_keep][img_id][idx_nsal, :] = images[img_id].c_center[cls_keep][idx_nsal, :]
                    L_idx[cls_keep][img_id][idx_nsal] = 0

    L_idx_img = [0 for i in range(num_img)]
    palette_map_img = [0 for i in range(num_img)]
    center_img = [0 for i in range(num_img)]

    # Stack per-class results: salient (class 1) first, then non-salient.
    for img_id in range(num_img):
        center_img[img_id] = np.vstack([images[img_id].c_center[1], images[img_id].c_center[0]])
        palette_map_img[img_id] = np.vstack([palette_map[1][img_id], palette_map[0][img_id]])

        # Non-salient indices are shifted past the salient group palette.
        if L_idx[1][img_id].size == 0 and L_idx[0][img_id].size == 0:
            L_idx_img[img_id] = np.array([])
        elif L_idx[1][img_id].size == 0:
            L_idx_img[img_id] = L_idx[0][img_id] + p_grp_sal_size
        elif L_idx[0][img_id].size == 0:
            L_idx_img[img_id] = L_idx[1][img_id]
        else:
            L_idx_img[img_id] = np.vstack([L_idx[1][img_id], L_idx[0][img_id] + p_grp_sal_size])

    M = np.vstack([M[1], M[0]])
    return center_img, M, palette_map_img, L_idx_img, p_src_sal_size, p_grp_sal_size
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# recolor multiple images for color consistency
|
| 288 |
+
def recolor_group_images(inp_images,
                         num_center_grp, num_center_sal, num_center_nonsal,
                         if_wb,
                         if_sal, recolor_nonsal_only, recolor_sal_only,
                         if_cn, naming_thres,
                         save_dir='./results/testing'):
    """Recolor a set of images toward a shared group palette.

    Args:
        inp_images: uploaded-file objects exposing `.name` (paths).
        num_center_grp: group-palette size (non-saliency path).
        num_center_sal / num_center_nonsal: per-class sizes (saliency path).
        if_wb: apply white-balance correction before recoloring.
        if_sal: use the saliency-aware group-palette solver.
        recolor_nonsal_only / recolor_sal_only: restrict recoloring to one class.
        if_cn / naming_thres: color-naming veto options.
        save_dir: directory where 'recolor_<name>' outputs are written.

    Returns:
        (recolored_images, img_p_src, img_p_tgt, img_p_grp): recolored uint8
        RGB arrays plus palette visualizations (source, target, group).
    """
    mode = 2          # ab-only transfer
    bin_size = 16     # histogram bins per channel

    num_img = len(inp_images)
    images = [None for i in range(num_img)]
    recolored_images = [None for i in range(num_img)]
    img_p_src = [None for i in range(num_img)]
    img_p_tgt = [None for i in range(num_img)]

    links = []

    for img_id, image in enumerate(inp_images):
        img_name = os.path.basename(image.name)
        print('processing image {}: {}...'.format(img_id, img_name))
        images[img_id] = BaseImage(image)

    if if_sal:
        palette_src, palette_grp, palette_tgt, links, src_sal_size, grp_sal_size = solve_grp_palette_wsal(images, mode, bin_size,
                                                                                                          lightness=70., eta=1e10, gamma=0, iteration=10,
                                                                                                          if_wb=if_wb, wb_thres=30,
                                                                                                          if_saliency=True, sal_thres=0.5, valid_class=[0, 1],
                                                                                                          sal_center=num_center_sal, nonsal_center=num_center_nonsal,
                                                                                                          recolor_nonsal_only=recolor_nonsal_only, recolor_sal_only=recolor_sal_only,
                                                                                                          if_cn=if_cn, naming_thres=float(naming_thres))
    else:
        src_sal_size = [0 for i in range(num_img)]
        grp_sal_size = 0
        palette_src, palette_grp, palette_tgt, links = solve_grp_palette(images, mode, bin_size,
                                                                         if_wb=if_wb, wb_thres=30, num_center=num_center_grp,
                                                                         lightness=70., eta=1e10, gamma=0, iteration=10,
                                                                         if_cn=if_cn, naming_thres=float(naming_thres))

    for img_id, image in enumerate(images):
        # Recolor in Lab space; start from the WB image if WB is enabled.
        if if_wb:
            img_wb = image.get_wb_image()
            img_lab = rgb2lab(img_wb)
        else:
            img_lab = image.get_lab_image()
        img_rgb_out, _ = lab_transfer(img_lab, palette_src[img_id], palette_tgt[img_id], mask=None, mode=2)
        img_rgb_out = np.round(img_rgb_out * 255).astype(np.uint8)

        out_img_path = os.path.join(save_dir, 'recolor_' + image.filename)
        img_bgr_out = cv2.cvtColor(img_rgb_out, cv2.COLOR_RGB2BGR)
        cv2.imwrite(out_img_path, img_bgr_out)

        recolored_images[img_id] = img_rgb_out

        # NOTE(review): `links` holds the solver's per-image index arrays AND
        # is appended to here, so the list mixes ndarray entries with the
        # converted lists appended below — confirm downstream readers expect
        # this (convert 1-based indices to 0-based, -1 -> None).
        link = (links[img_id] - 1).flatten()
        link = link.tolist()
        link = [None if item == -1 else item for item in link]
        links.append(link)

    img_p_grp = [visualize_palette(palette_grp, patch_size=20)]

    for img_id in range(num_img):
        img_p_src[img_id] = visualize_palette(palette_src[img_id], patch_size=20)
        img_p_tgt[img_id] = visualize_palette(palette_tgt[img_id], patch_size=20)

    return recolored_images, img_p_src, img_p_tgt, img_p_grp
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
|
recolor.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from skimage.color import rgb2lab, lab2rgb
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def rgb_transfer(Iin, C_src, C_tgt, mask=None):
    """Recolor an 8-bit RGB image by shifting pixels from a source palette
    toward a target palette, working directly in RGB space.

    Args:
        Iin: uint8-range RGB image (H, W, 3); scaled to [0, 1] internally.
        C_src, C_tgt: Lab palettes; converted to RGB before the transfer.
        mask: optional (H, W) mask — pixels where mask == 0 are left unchanged.

    Returns:
        Recolored float RGB image of shape (H, W, 3).
    """
    img = np.array(Iin / 255.).astype(np.float32)

    # lab2rgb expects an image-shaped array, so wrap/unwrap a length-1 axis.
    src_rgb = np.squeeze(lab2rgb(np.expand_dims(C_src, axis=0)), axis=0)
    tgt_rgb = np.squeeze(lab2rgb(np.expand_dims(C_tgt, axis=0)), axis=0)

    if mask is None:
        mask = np.ones_like(img[:, :, 0])

    rows, cols, chans = img.shape
    flat = np.reshape(img, (rows * cols, chans))

    result = ab_transfer(flat, src_rgb, tgt_rgb, mask=mask.flatten())
    return np.reshape(result, (rows, cols, chans))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def lab_transfer(Iin, C_src, C_tgt, mask=None, mode=2):
    """Recolor a Lab image by shifting its colors from C_src toward C_tgt.

    Args:
        Iin: Lab image of shape (H, W, 3).
        C_src, C_tgt: (k, 3) Lab palettes (source and target).
        mask: optional (H, W) mask — pixels where mask == 0 are left unchanged.
        mode: 2 transfers only the a/b (chroma) channels; any other value
            transfers all channels including lightness.

    Returns:
        (rgb_image, palette_copy): the recolored image converted to float RGB,
        and a copy of the target palette.
    """
    if mask is None:
        mask = np.ones_like(Iin[:, :, 0])

    Pout = C_tgt.copy()
    rows, cols, chans = Iin.shape

    flat_in = np.reshape(Iin, (rows * cols, chans))
    flat_out = flat_in.copy()
    flat_mask = mask.flatten()

    # Channel offset: skip L (index 0) in chroma-only mode.
    start = 1 if mode == 2 else 0
    flat_out[:, start:] = ab_transfer(flat_in[:, start:], C_src[:, start:], C_tgt[:, start:], mask=flat_mask)

    out_rgb = lab2rgb(np.reshape(flat_out, (rows, cols, chans)))
    return out_rgb, Pout
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def ab_transfer(I_src, C_src, C_tgt, mask=None):
    """Weighted palette-based color transfer on flattened pixel data.

    Each pixel receives, for every palette color i, the translation
    (C_tgt[i] - C_src[i]) weighted by the inverse squared distance of the
    pixel to C_src[i] (weights normalized to sum to 1 per pixel).

    Args:
        I_src: (N, C) array of pixel values.
        C_src, C_tgt: (K, C) source / target palettes.
        mask: optional length-N array; pixels where mask == 0 keep their
            original values.

    Returns:
        (N, C) array of transferred pixel values.
    """
    if mask is None:
        mask = np.ones_like(I_src[:, 0])

    n_px, n_ch = I_src.shape
    n_colors = np.size(C_src, 0)
    eps = 0.0001  # avoids division by zero for pixels exactly on a palette color

    # Inverse-squared-distance weight of every pixel to every source color.
    W = np.zeros((n_px, n_colors))
    for ci in range(n_colors):
        dist_sq = np.zeros(n_px)
        for ch in range(n_ch):
            dist_sq = dist_sq + (I_src[:, ch] - C_src[ci, ch]) ** 2
        W[:, ci] = 1. / (dist_sq + eps)

    # Normalize so each pixel's weights sum to 1.
    total = np.sum(W, 1)
    for ci in range(n_colors):
        W[:, ci] = W[:, ci] / total

    # Accumulate the weighted per-color translations.
    I_tgt = np.zeros_like(I_src)
    for ci in range(n_colors):
        for ch in range(n_ch):
            I_tgt[:, ch] = I_tgt[:, ch] + W[:, ci] * (I_src[:, ch] + C_tgt[ci, ch] - C_src[ci, ch])

    # Frozen pixels (mask == 0) keep their original values.
    frozen = np.argwhere(mask == 0)
    I_tgt[frozen, :] = I_src[frozen, :]

    return I_tgt
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def lab_transfer_cls(Iin, C_src, C_tgt, mask=None, valid_class=None):
    """Per-class chroma transfer: recolor each class region with its own palettes.

    Args:
        Iin: RGB image of shape (H, W, 3); converted to Lab internally
            (unlike lab_transfer, which expects Lab input).
        C_src, C_tgt: per-class lists of (k, 3) Lab palettes, indexed by
            position in valid_class.
        mask: (H, W) class-label map; pixels are transferred with the
            palettes of their own class.
        valid_class: iterable of class labels present in `mask`.

    Returns:
        (rgb_image, palette_copy): recolored float RGB image and a copy of
        the target palettes.
    """
    if mask is None:
        mask = np.ones_like(Iin[:, :, 0])

    Pout = C_tgt.copy()
    m, n, b = Iin.shape

    Iin = rgb2lab(Iin)
    Iin = np.reshape(Iin, (m*n, b))
    Iout = Iin.copy()
    Iout_cls = np.zeros_like(Iin)

    mask = mask.flatten()
    # NOTE(review): mask_bin is never reset between classes, so each
    # iteration adds its class's pixels to the previous ones — confirm this
    # accumulation is intended rather than a per-class binary mask.
    mask_bin = np.zeros_like(mask)

    for id_cls, cls in enumerate(valid_class):
        # Skip classes with no palette (empty after the saliency split).
        if C_src[id_cls].size == 0:
            continue
        mask_bin[mask == cls] = 1
        # Transfer only the a/b (chroma) channels for this class's pixels.
        Iout_cls[:, 1:] = ab_transfer(Iin[:, 1:], C_src[id_cls][:, 1:], C_tgt[id_cls][:, 1:], mask=mask_bin)
        idx = np.argwhere(mask_bin == 1)
        Iout[idx, 1:] = Iout_cls[idx, 1:].copy()

    # NOTE(review): np.round here (absent in lab_transfer) quantizes Lab
    # values before conversion — confirm intentional.
    Iout = np.reshape(np.round(Iout), (m, n, b))
    Iout = lab2rgb(Iout)

    return Iout, Pout
|
| 135 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
opencv-python
|
| 3 |
+
torch
|
| 4 |
+
scikit-learn
|
| 5 |
+
scikit-image
|
| 6 |
+
Pillow
|
| 7 |
+
SciPy
|
| 8 |
+
networkx
|
| 9 |
+
libsvm
|
saliency/LDF/dataset.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
#coding=utf-8
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
########################### Data Augmentation ###########################
|
| 11 |
+
class Normalize(object):
    """Channel-wise (image - mean) / std normalization for the LDF pipeline.

    When a mask (plus body/detail maps) is supplied, those targets are
    rescaled from [0, 255] to [0, 1] and returned alongside the image.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, mask=None, body=None, detail=None):
        normed = (image - self.mean) / self.std
        if mask is None:
            return normed
        return normed, mask / 255, body / 255, detail / 255
|
| 21 |
+
|
| 22 |
+
class RandomCrop(object):
    """Randomly shave up to 1/8 of height/width off a sample (image + GT maps)."""

    def __call__(self, image, mask=None, body=None, detail=None):
        H, W, _ = image.shape
        # Amount removed per axis, then a random placement of the crop window.
        # (RNG calls kept in this exact order: width amount, height amount,
        # vertical offset, horizontal offset.)
        randw = np.random.randint(W / 8)
        randh = np.random.randint(H / 8)
        offseth = 0 if randh == 0 else np.random.randint(randh)
        offsetw = 0 if randw == 0 else np.random.randint(randw)
        top, bottom = offseth, H + offseth - randh
        left, right = offsetw, W + offsetw - randw
        if mask is None:
            return image[top:bottom, left:right, :]
        return (image[top:bottom, left:right, :],
                mask[top:bottom, left:right],
                body[top:bottom, left:right],
                detail[top:bottom, left:right])
|
| 33 |
+
|
| 34 |
+
class RandomFlip(object):
    """Horizontally mirror the sample with probability 1/2."""

    def __call__(self, image, mask=None, body=None, detail=None):
        flip = np.random.randint(2) == 0
        if not flip:
            if mask is None:
                return image
            return image, mask, body, detail
        # Mirror along the width axis; .copy() gives contiguous arrays so a
        # later torch.from_numpy does not choke on negative strides.
        if mask is None:
            return image[:, ::-1, :].copy()
        return (image[:, ::-1, :].copy(), mask[:, ::-1].copy(),
                body[:, ::-1].copy(), detail[:, ::-1].copy())
|
| 44 |
+
|
| 45 |
+
class Resize(object):
    """Bilinearly resize an image (and GT maps) to a fixed H x W."""

    def __init__(self, H, W):
        self.H = H
        self.W = W

    def __call__(self, image, mask=None, body=None, detail=None):
        dsize = (self.W, self.H)  # cv2.resize expects (width, height)
        image = cv2.resize(image, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        if mask is None:
            return image
        mask = cv2.resize(mask, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        body = cv2.resize(body, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        detail = cv2.resize(detail, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        return image, mask, body, detail
|
| 58 |
+
|
| 59 |
+
class ToTensor(object):
    """Convert a numpy HWC image to a CHW torch tensor; GT maps stay HW."""

    def __call__(self, image, mask=None, body=None, detail=None):
        tensor = torch.from_numpy(image).permute(2, 0, 1)
        if mask is None:
            return tensor
        return (tensor, torch.from_numpy(mask),
                torch.from_numpy(body), torch.from_numpy(detail))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
########################### Config File ###########################
|
| 72 |
+
class Config(object):
    """Lightweight config: keyword args become attributes; unknown names give None."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        # Per-channel image statistics consumed by Normalize.
        self.mean = np.array([[[124.55, 118.90, 102.94]]])
        self.std = np.array([[[56.77, 55.97, 57.50]]])
        # print('\nParameters...')
        # for k, v in self.kwargs.items():
        #     print('%-10s: %s'%(k, v))

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails; missing keys map to
        # None so callers can probe optional settings without try/except.
        return self.kwargs.get(name)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
########################### Dataset Class ###########################
|
| 89 |
+
class Data(Dataset):
    """LDF dataset reading sample names from <datapath>/<mode>.txt.

    In 'train' mode each item is an augmented (image, mask, body, detail)
    tuple; otherwise it is a resized, normalized CHW tensor plus the original
    (H, W) shape and the sample name.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.normalize = Normalize(mean=cfg.mean, std=cfg.std)
        self.randomcrop = RandomCrop()
        self.randomflip = RandomFlip()
        self.resize = Resize(352, 352)
        self.totensor = ToTensor()

        with open(cfg.datapath + '/' + cfg.mode + '.txt', 'r') as lines:
            self.samples = [line.strip() for line in lines]

    def __getitem__(self, idx):
        name = self.samples[idx]
        # cv2 loads BGR; reversing the channel axis yields RGB.
        image = cv2.imread(self.cfg.datapath + '/image/' + name + '.jpg')[:, :, ::-1].astype(np.float32)

        if self.cfg.mode == 'train':
            mask = cv2.imread(self.cfg.datapath + '/mask/' + name + '.png', 0).astype(np.float32)
            body = cv2.imread(self.cfg.datapath + '/body/' + name + '.png', 0).astype(np.float32)
            detail = cv2.imread(self.cfg.datapath + '/detail/' + name + '.png', 0).astype(np.float32)
            image, mask, body, detail = self.normalize(image, mask, body, detail)
            image, mask, body, detail = self.randomcrop(image, mask, body, detail)
            image, mask, body, detail = self.randomflip(image, mask, body, detail)
            return image, mask, body, detail

        shape = image.shape[:2]
        image = self.totensor(self.resize(self.normalize(image)))
        return image, shape, name

    def __len__(self):
        return len(self.samples)

    def collate(self, batch):
        """Multi-scale training collate: resize the whole batch to one random size."""
        size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
        image, mask, body, detail = [list(item) for item in zip(*batch)]
        for i in range(len(batch)):
            image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            mask[i] = cv2.resize(mask[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            body[i] = cv2.resize(body[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            detail[i] = cv2.resize(detail[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
        image = torch.from_numpy(np.stack(image, axis=0)).permute(0, 3, 1, 2)
        mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1)
        body = torch.from_numpy(np.stack(body, axis=0)).unsqueeze(1)
        detail = torch.from_numpy(np.stack(detail, axis=0)).unsqueeze(1)
        return image, mask, body, detail
|
saliency/LDF/infer.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import cv2
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import numpy as np
|
| 6 |
+
sys.dont_write_bytecode = True
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
import saliency.LDF.dataset as dataset
|
| 12 |
+
from saliency.LDF.net import LDF
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## Implementation of saliency detection with LDF model.
|
| 16 |
+
## Implementation of saliency detection with LDF model.
class Saliency_LDF:
    """Inference wrapper around the LDF saliency network.

    Loads a pretrained snapshot, preprocesses an RGB image to the fixed
    352x352 network input, and returns a per-pixel saliency map.
    """

    def __init__(self, pretrained_model='./saliency/LDF/model-40'):
        self.cfg = dataset.Config(snapshot=pretrained_model, mode='test')
        self.normalize = dataset.Normalize(mean=self.cfg.mean, std=self.cfg.std)
        self.resize = dataset.Resize(352, 352)
        self.totensor = dataset.ToTensor()
        # Fix: fall back to CPU when CUDA is unavailable. The original
        # unconditionally called .cuda(), crashing on CPU-only hosts.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ## network
        self.net = LDF(self.cfg)
        self.net.train(False)
        self.net.to(self.device)

    def inference(self, img_rgb):
        """Return a saliency map in [0, 1] with the same H x W as `img_rgb`."""
        shape = img_rgb.shape[:2]
        image = self.normalize(img_rgb)
        image = self.resize(image)
        image = self.totensor(image)

        with torch.no_grad():
            image = image.unsqueeze(0)
            image = image.to(self.device).float()
            # The network yields body/detail/combined maps for both decoding
            # passes; the second combined map (out2) is the final prediction.
            outb1, outd1, out1, outb2, outd2, out2 = self.net(image, shape)
            pred = torch.sigmoid(out2[0, 0]).cpu().numpy()  # values in [0, 1]

        return pred
|
saliency/LDF/model-40
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:972a77654afdad80f99ce29968e8c02d5f14cad0aa52e21b46b1c47e1468d68e
|
| 3 |
+
size 100920708
|
saliency/LDF/net.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
#coding=utf-8
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
def weight_init(module):
    """Recursively initialize a module's direct children.

    Conv/Linear layers get Kaiming-normal weights and zero bias; norm layers
    get unit weight and zero bias; Sequential containers recurse; ReLU is
    skipped; any other child is expected to expose its own ``initialize()``.
    """
    for child_name, child in module.named_children():
        print('initialize: ' + child_name)
        if isinstance(child, nn.Conv2d):
            nn.init.kaiming_normal_(child.weight, mode='fan_in', nonlinearity='relu')
            if child.bias is not None:
                nn.init.zeros_(child.bias)
        elif isinstance(child, (nn.BatchNorm2d, nn.InstanceNorm2d)):
            nn.init.ones_(child.weight)
            if child.bias is not None:
                nn.init.zeros_(child.bias)
        elif isinstance(child, nn.Linear):
            nn.init.kaiming_normal_(child.weight, mode='fan_in', nonlinearity='relu')
            if child.bias is not None:
                nn.init.zeros_(child.bias)
        elif isinstance(child, nn.Sequential):
            weight_init(child)
        elif isinstance(child, nn.ReLU):
            pass
        else:
            child.initialize()
|
| 31 |
+
|
| 32 |
+
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3 (optionally dilated), 1x1
    expand by 4x, with a residual connection."""

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Padding grows with dilation so the 3x3 conv keeps spatial size
        # (apart from the stride).
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=(3 * dilation - 1) // 2, bias=False,
                               dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        y = F.relu(self.bn1(self.conv1(x)), inplace=True)
        y = F.relu(self.bn2(self.conv2(y)), inplace=True)
        y = self.bn3(self.conv3(y))
        if self.downsample is not None:
            identity = self.downsample(x)
        return F.relu(y + identity, inplace=True)
|
| 50 |
+
|
| 51 |
+
class ResNet(nn.Module):
    """ResNet-50 backbone returning the feature maps of all five stages."""

    def __init__(self):
        super(ResNet, self).__init__()
        # self.cfg = cfg
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
        # if self.training:
        #     self.initialize()

    def make_layer(self, planes, blocks, stride, dilation):
        # The first block may change resolution/width, so the skip path gets
        # a 1x1 projection; the remaining blocks are plain residual blocks.
        proj = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * 4, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * 4))
        stages = [Bottleneck(self.inplanes, planes, stride, proj, dilation=dilation)]
        self.inplanes = planes * 4
        stages.extend(Bottleneck(self.inplanes, planes, dilation=dilation)
                      for _ in range(1, blocks))
        return nn.Sequential(*stages)

    def forward(self, x):
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(out1)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out1, out2, out3, out4, out5

    def initialize(self):
        # NOTE(review): expects ImageNet-pretrained ResNet-50 weights at this
        # relative path — confirm the file exists in deployment.
        self.load_state_dict(torch.load('../res/resnet50-19c8e357.pth'), strict=False)
|
| 84 |
+
|
| 85 |
+
class Decoder(nn.Module):
    """Top-down decoder fusing four 64-channel feature maps (coarse to fine).

    Each stage adds the corresponding second-pass feature (zeros on the first
    pass), applies conv+BN+ReLU, and upsamples to the next finer level.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.conv0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn0 = nn.BatchNorm2d(64)
        self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

    def forward(self, input1, input2=None):
        # Fix: the original used a mutable list default ([0, 0, 0, 0]);
        # a None sentinel with an immutable fallback is equivalent and safe.
        if input2 is None:
            input2 = (0, 0, 0, 0)
        out0 = F.relu(self.bn0(self.conv0(input1[0] + input2[0])), inplace=True)
        out0 = F.interpolate(out0, size=input1[1].size()[2:], mode='bilinear')
        out1 = F.relu(self.bn1(self.conv1(input1[1] + input2[1] + out0)), inplace=True)
        out1 = F.interpolate(out1, size=input1[2].size()[2:], mode='bilinear')
        out2 = F.relu(self.bn2(self.conv2(input1[2] + input2[2] + out1)), inplace=True)
        out2 = F.interpolate(out2, size=input1[3].size()[2:], mode='bilinear')
        out3 = F.relu(self.bn3(self.conv3(input1[3] + input2[3] + out2)), inplace=True)
        return out3

    def initialize(self):
        weight_init(self)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class Encoder(nn.Module):
    """Re-encodes the concatenated first-pass maps into two four-level
    feature pyramids ('b' and 'd' branches) for the second decoding pass."""

    def __init__(self):
        super(Encoder, self).__init__()
        # Shared downsampling trunk (1/1 -> 1/2 -> 1/4 -> 1/8 resolution).
        self.conv1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(64)

        # Per-level heads of the 'b' branch.
        self.conv1b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1b = nn.BatchNorm2d(64)
        self.conv2b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2b = nn.BatchNorm2d(64)
        self.conv3b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3b = nn.BatchNorm2d(64)
        self.conv4b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4b = nn.BatchNorm2d(64)

        # Per-level heads of the 'd' branch.
        self.conv1d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1d = nn.BatchNorm2d(64)
        self.conv2d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2d = nn.BatchNorm2d(64)
        self.conv3d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3d = nn.BatchNorm2d(64)
        self.conv4d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4d = nn.BatchNorm2d(64)

    def forward(self, out1):
        # Trunk: conv at full resolution, then pool+conv three times.
        out1 = F.relu(self.bn1(self.conv1(out1)), inplace=True)
        out2 = F.relu(self.bn2(self.conv2(F.max_pool2d(out1, kernel_size=2, stride=2))), inplace=True)
        out3 = F.relu(self.bn3(self.conv3(F.max_pool2d(out2, kernel_size=2, stride=2))), inplace=True)
        out4 = F.relu(self.bn4(self.conv4(F.max_pool2d(out3, kernel_size=2, stride=2))), inplace=True)

        # Two parallel heads per level.
        out1b = F.relu(self.bn1b(self.conv1b(out1)), inplace=True)
        out2b = F.relu(self.bn2b(self.conv2b(out2)), inplace=True)
        out3b = F.relu(self.bn3b(self.conv3b(out3)), inplace=True)
        out4b = F.relu(self.bn4b(self.conv4b(out4)), inplace=True)

        out1d = F.relu(self.bn1d(self.conv1d(out1)), inplace=True)
        out2d = F.relu(self.bn2d(self.conv2d(out2)), inplace=True)
        out3d = F.relu(self.bn3d(self.conv3d(out3)), inplace=True)
        out4d = F.relu(self.bn4d(self.conv4d(out4)), inplace=True)
        # Coarse-to-fine order, matching Decoder's expected input ordering.
        return (out4b, out3b, out2b, out1b), (out4d, out3d, out2d, out1d)

    def initialize(self):
        weight_init(self)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class LDF(nn.Module):
    """Label Decoupling Framework for salient object detection.

    Backbone features feed two decoder branches — body ('b') and detail
    ('d'). Their fused first-pass output is re-encoded and decoded a second
    time; all six maps (body/detail/combined for each pass) are returned,
    upsampled to `shape`.
    """

    def __init__(self, cfg):
        super(LDF, self).__init__()
        self.cfg = cfg
        self.bkbone = ResNet()
        # Per-stage reducers (1x1 + 3x3 conv) mapping backbone outputs to 64
        # channels; one set per branch.
        self.conv5b = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv4b = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv3b = nn.Sequential(nn.Conv2d( 512, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv2b = nn.Sequential(nn.Conv2d( 256, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.conv5d = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv4d = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv3d = nn.Sequential(nn.Conv2d( 512, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.conv2d = nn.Sequential(nn.Conv2d( 256, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.encoder = Encoder()
        self.decoderb = Decoder()
        self.decoderd = Decoder()
        # 1-channel prediction heads: body, detail, and combined (128 -> 1).
        self.linearb = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        self.lineard = nn.Conv2d(64, 1, kernel_size=3, padding=1)
        self.linear = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 1, kernel_size=3, padding=1))
        self.initialize()

    def forward(self, x, shape=None):
        out1, out2, out3, out4, out5 = self.bkbone(x)
        out2b, out3b, out4b, out5b = self.conv2b(out2), self.conv3b(out3), self.conv4b(out4), self.conv5b(out5)
        out2d, out3d, out4d, out5d = self.conv2d(out2), self.conv3d(out3), self.conv4d(out4), self.conv5d(out5)

        # First pass: independent body/detail decoding.
        outb1 = self.decoderb([out5b, out4b, out3b, out2b])
        outd1 = self.decoderd([out5d, out4d, out3d, out2d])
        out1 = torch.cat([outb1, outd1], dim=1)
        # Second pass: re-encode the fused first-pass features, decode again.
        outb2, outd2 = self.encoder(out1)
        outb2 = self.decoderb([out5b, out4b, out3b, out2b], outb2)
        outd2 = self.decoderd([out5d, out4d, out3d, out2d], outd2)
        out2 = torch.cat([outb2, outd2], dim=1)

        if shape is None:
            shape = x.size()[2:]
        # Project every map to one channel and upsample to the target shape.
        out1 = F.interpolate(self.linear(out1), size=shape, mode='bilinear')
        outb1 = F.interpolate(self.linearb(outb1), size=shape, mode='bilinear')
        outd1 = F.interpolate(self.lineard(outd1), size=shape, mode='bilinear')

        out2 = F.interpolate(self.linear(out2), size=shape, mode='bilinear')
        outb2 = F.interpolate(self.linearb(outb2), size=shape, mode='bilinear')
        outd2 = F.interpolate(self.lineard(outd2), size=shape, mode='bilinear')
        return outb1, outd1, out1, outb2, outd2, out2

    def initialize(self):
        if self.cfg.snapshot:
            # Fix: map_location='cpu' lets a CUDA-trained checkpoint load on
            # CPU-only hosts; the caller moves the model to its target device
            # afterwards, so this is a no-op on GPU machines.
            self.load_state_dict(torch.load(self.cfg.snapshot, map_location='cpu'))
        else:
            weight_init(self)
|
saliency/fast_saliency.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy.spatial.distance
|
| 5 |
+
import scipy.signal
|
| 6 |
+
import skimage
|
| 7 |
+
import skimage.io
|
| 8 |
+
import time
|
| 9 |
+
from skimage.segmentation import slic
|
| 10 |
+
from skimage.util import img_as_float
|
| 11 |
+
import networkx as nx
|
| 12 |
+
#import matplotlib.pyplot as plt
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def S(x1, x2, geodesic, sigma_clr=10):
    """Gaussian similarity of two superpixels from their geodesic color distance."""
    d = geodesic[x1, x2]
    return math.exp(-(d * d) / (2 * sigma_clr * sigma_clr))
|
| 17 |
+
|
| 18 |
+
def compute_saliency_cost(smoothness, w_bg, wCtr):
    """Solve the quadratic saliency-optimization objective (Zhu et al., 2014).

    Builds the linear system A x = b from the background weights, the
    foreground (contrast) weights, and the pairwise smoothness term, and
    returns the optimal per-superpixel saliency values.
    """
    n = len(w_bg)
    A = np.zeros((n, n))
    b = np.zeros(n)

    for i in range(n):
        # Unary terms: background prior pulls toward 0, contrast toward 1.
        A[i, i] = 2 * w_bg[i] + 2 * wCtr[i]
        b[i] = 2 * wCtr[i]
        # Pairwise smoothness adds a graph-Laplacian-like structure.
        for j in range(n):
            A[i, i] += 2 * smoothness[i, j]
            A[i, j] -= 2 * smoothness[i, j]

    return np.linalg.solve(A, b)
|
| 33 |
+
|
| 34 |
+
def path_length(path, G):
    """Sum of edge weights along a node path in graph G."""
    total = 0.0
    # Walk consecutive node pairs along the path.
    for a, b in zip(path, path[1:]):
        total += G[a][b]['weight']
    return total
|
| 39 |
+
|
| 40 |
+
def make_graph(grid):
    """Build a region-adjacency graph from a label image.

    Returns the unique labels and a list of undirected edges between labels
    of 4-connected neighboring pixels.
    """
    # get unique labels
    labels = np.unique(grid)

    # Relabel to the dense range [0, num_labels) so the hashing trick below
    # can pack an edge into a single integer.
    remap = dict(zip(labels, np.arange(len(labels))))
    dense = np.array([remap[v] for v in grid.flat]).reshape(grid.shape)

    # Vertical and horizontal neighbor pairs.
    down = np.c_[dense[:-1, :].ravel(), dense[1:, :].ravel()]
    right = np.c_[dense[:, :-1].ravel(), dense[:, 1:].ravel()]
    pairs = np.vstack([right, down])
    pairs = pairs[pairs[:, 0] != pairs[:, 1], :]   # drop self-edges
    pairs = np.sort(pairs, axis=1)                 # make edges undirected

    # Hash each pair into one int so np.unique deduplicates connections,
    # then decode back to original label pairs.
    n = len(labels)
    hashes = np.unique(pairs[:, 0] + n * pairs[:, 1])
    edges = [[labels[h % n], labels[h // n]] for h in hashes]

    return labels, edges
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Saliency map calculation based on:
|
| 66 |
+
# Wangjiang Zhu, Shuang Liang, Yichen Wei and Jian Sun,
|
| 67 |
+
# Saliency Optimization from Robust Background Detection, (CVPR), 2014
|
| 68 |
+
|
| 69 |
+
# based on the asumption that saliency region has smaller 'boundary connectivity'
|
| 70 |
+
# object regions (salient regions) are much less connected to image boundaries than background ones
|
| 71 |
+
# superpixel
|
| 72 |
+
|
| 73 |
+
def get_saliency_rbd(img):
    """Robust-background-detection saliency (Zhu et al., CVPR 2014).

    Segments the image into SLIC superpixels, estimates each superpixel's
    connectivity to the image boundary via geodesic color distances, and
    solves a quadratic optimization balancing background/foreground priors
    with pairwise smoothness. Returns a saliency map normalized to [0, 1]
    with the same H x W as `img` (expected RGB).
    """

    img_lab = img_as_float(skimage.color.rgb2lab(img))

    img_rgb = img_as_float(img)

    img_gray = img_as_float(skimage.color.rgb2gray(img))

    # Oversegment into ~250 superpixels; labels form a dense 0..K-1 range.
    segments_slic = slic(img_rgb, n_segments=250, compactness=10, sigma=1, enforce_connectivity=False)

    # num_segments = len(np.unique(segments_slic))

    nrows, ncols = segments_slic.shape
    # Image diagonal, used to normalize spatial distances to [0, 1].
    max_dist = math.sqrt(nrows*nrows + ncols*ncols)

    grid = segments_slic

    (vertices,edges) = make_graph(grid)

    # gridx varies over rows (axis 0), gridy over columns (axis 1).
    gridx, gridy = np.mgrid[:grid.shape[0], :grid.shape[1]]

    centers = dict()   # superpixel centroid (col, row)
    colors = dict()    # mean Lab color per superpixel
    # distances = dict()
    boundary = dict()  # 1 if the superpixel touches the image border

    for v in vertices:
        centers[v] = [gridy[grid == v].mean(), gridx[grid == v].mean()]
        colors[v] = np.mean(img_lab[grid==v],axis=0)

        x_pix = gridx[grid == v]
        y_pix = gridy[grid == v]

        if np.any(x_pix == 0) or np.any(y_pix == 0) or np.any(x_pix == nrows - 1) or np.any(y_pix == ncols - 1):
            boundary[v] = 1
        else:
            boundary[v] = 0

    G = nx.Graph()

    #buid the graph
    # Edge weight = Euclidean distance between mean Lab colors.
    for edge in edges:
        pt1 = edge[0]
        pt2 = edge[1]
        color_distance = scipy.spatial.distance.euclidean(colors[pt1],colors[pt2])
        G.add_edge(pt1, pt2, weight=color_distance )

    #add a new edge in graph if edges are both on boundary
    # Connecting all boundary superpixels to each other models the image
    # border as a single virtual background node.
    for v1 in vertices:
        if boundary[v1] == 1:
            for v2 in vertices:
                if boundary[v2] == 1:
                    color_distance = scipy.spatial.distance.euclidean(colors[v1],colors[v2])
                    G.add_edge(v1,v2,weight=color_distance)

    # Pairwise matrices indexed by superpixel label (labels are 0..K-1 here).
    geodesic = np.zeros((len(vertices),len(vertices)),dtype=float)
    spatial = np.zeros((len(vertices),len(vertices)),dtype=float)
    smoothness = np.zeros((len(vertices),len(vertices)),dtype=float)
    adjacency = np.zeros((len(vertices),len(vertices)),dtype=float)

    # Bandwidths for the color, boundary-connectivity and spatial Gaussians.
    sigma_clr = 10.0
    sigma_bndcon = 1.0
    sigma_spa = 0.25
    mu = 0.1

    all_shortest_paths_color = nx.shortest_path(G,source=None,target=None,weight='weight')

    for v1 in vertices:
        for v2 in vertices:
            if v1 == v2:
                geodesic[v1,v2] = 0
                spatial[v1,v2] = 0
                smoothness[v1,v2] = 0
            else:
                # Geodesic color distance = length of the cheapest path in G.
                geodesic[v1,v2] = path_length(all_shortest_paths_color[v1][v2],G)
                spatial[v1,v2] = scipy.spatial.distance.euclidean(centers[v1],centers[v2]) / max_dist
                # mu keeps a small smoothness floor even for dissimilar pairs.
                smoothness[v1,v2] = math.exp( - (geodesic[v1,v2] * geodesic[v1,v2])/(2.0*sigma_clr*sigma_clr)) + mu

    for edge in edges:
        pt1 = edge[0]
        pt2 = edge[1]
        adjacency[pt1,pt2] = 1
        adjacency[pt2,pt1] = 1

    # Restrict the smoothness term to spatially adjacent superpixels.
    for v1 in vertices:
        for v2 in vertices:
            smoothness[v1,v2] = adjacency[v1,v2] * smoothness[v1,v2]

    area = dict()     # soft area of each superpixel's color neighborhood
    len_bnd = dict()  # how much of that neighborhood lies on the boundary
    bnd_con = dict()  # boundary connectivity = len_bnd / sqrt(area)
    w_bg = dict()     # background probability weight
    ctr = dict()      # raw (unweighted) contrast; kept but superseded by wCtr
    wCtr = dict()     # background-weighted contrast (foreground weight)

    for v1 in vertices:
        area[v1] = 0
        len_bnd[v1] = 0
        ctr[v1] = 0
        for v2 in vertices:
            d_app = geodesic[v1,v2]
            d_spa = spatial[v1,v2]
            w_spa = math.exp(- ((d_spa)*(d_spa))/(2.0*sigma_spa*sigma_spa))
            area_i = S(v1,v2,geodesic)
            area[v1] += area_i
            len_bnd[v1] += area_i * boundary[v2]
            ctr[v1] += d_app * w_spa
        bnd_con[v1] = len_bnd[v1] / math.sqrt(area[v1])
        w_bg[v1] = 1.0 - math.exp(- (bnd_con[v1]*bnd_con[v1])/(2*sigma_bndcon*sigma_bndcon))

    for v1 in vertices:
        wCtr[v1] = 0
        for v2 in vertices:
            d_app = geodesic[v1,v2]
            d_spa = spatial[v1,v2]
            w_spa = math.exp(- (d_spa*d_spa)/(2.0*sigma_spa*sigma_spa))
            wCtr[v1] += d_app * w_spa * w_bg[v2]

    # normalise value for wCtr

    min_value = min(wCtr.values())
    max_value = max(wCtr.values())

    # NOTE(review): minVal/maxVal are computed but never used.
    minVal = [key for key, value in wCtr.items() if value == min_value]
    maxVal = [key for key, value in wCtr.items() if value == max_value]

    # NOTE(review): divides by (max - min); degenerate if all wCtr are equal.
    for v in vertices:
        wCtr[v] = (wCtr[v] - min_value)/(max_value - min_value)

    img_disp1 = img_gray.copy()
    # img_disp2 = img_gray.copy()

    # Solve the global optimization for the final per-superpixel saliency.
    x = compute_saliency_cost(smoothness,w_bg,wCtr)

    # Paint each superpixel's saliency value back onto a pixel map.
    for v in vertices:
        img_disp1[grid == v] = x[v]

    img_disp2 = img_disp1.copy()
    sal = np.zeros((img_disp1.shape[0],img_disp1.shape[1],3))

    # Min-max normalize the map to [0, 1].
    sal = img_disp2
    sal_max = np.max(sal)
    sal_min = np.min(sal)
    sal = ((sal - sal_min) / (sal_max - sal_min))

    return sal
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Saliency map calculation based on:
|
| 223 |
+
# R. Achanta, S. Hemami, F. Estrada and S. Süsstrunk,
|
| 224 |
+
# Frequency-tuned Salient Region Detection, (CVPR 2009), pp. 1597 - 1604, 2009
|
| 225 |
+
# a frequency-tuned approach to estimate center-surround contrast using color and luminance features
|
| 226 |
+
# combine several different filters to remove unwanted high-frequency components
|
| 227 |
+
|
| 228 |
+
def get_saliency_ft(img):
    """Frequency-tuned saliency (Achanta et al., CVPR 2009).

    Each pixel's saliency is the distance between the mean image feature
    vector and the pixel's feature in a Gaussian-blurred version of the
    image, computed in CIELab space, rescaled to [0, 1].

    Args:
        img: RGB image (H, W, 3), any integer/float dtype.

    Returns:
        2-D float saliency map in [0, 1].
    """
    img_rgb = img_as_float(img)
    img_lab = skimage.color.rgb2lab(img_rgb)

    # Fix: the paper defines the mean feature vector I_mu in Lab space;
    # the original averaged the RGB image and subtracted it from Lab
    # values, mixing colour spaces.
    mean_val = np.mean(img_lab, axis=(0, 1))

    # separable 5-tap binomial kernel approximating a Gaussian blur
    kernel_h = (1.0 / 16.0) * np.array([[1, 4, 6, 4, 1]])
    kernel_w = kernel_h.transpose()

    blurred = []
    for c in range(3):
        chan = scipy.signal.convolve2d(img_lab[:, :, c], kernel_h, mode='same')
        blurred.append(scipy.signal.convolve2d(chan, kernel_w, mode='same'))
    im_blurred = np.dstack(blurred)

    # per-pixel Euclidean distance to the mean feature
    sal = np.linalg.norm(mean_val - im_blurred, axis=2)
    sal_max = np.max(sal)
    sal_min = np.min(sal)
    sal = ((sal - sal_min) / (sal_max - sal_min))
    return sal
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def raster_scan(img,L,U,D):
    # Forward (top-left to bottom-right) pass of the minimum barrier
    # distance transform: relaxes the distance map D in place using each
    # interior pixel's top and left neighbours. L/U hold the running
    # min/max intensity along the best path to each pixel.
    # NOTE(review): an identical `raster_scan` is defined again further
    # down this file and shadows this one at import time.
    n_rows = len(img)
    n_cols = len(img[0])

    for x in range(1, n_rows - 1):
        for y in range(1, n_cols - 1):
            ix = img[x][y]
            d = D[x][y]

            # top neighbour's running extrema
            u1 = U[x-1][y]
            l1 = L[x-1][y]

            # left neighbour's running extrema
            u2 = U[x][y-1]
            l2 = L[x][y-1]

            # barrier cost of extending either neighbour's path here
            b1 = max(u1,ix) - min(l1,ix)
            b2 = max(u2,ix) - min(l2,ix)

            if d <= b1 and d <= b2:
                continue
            elif b1 < d and b1 <= b2:
                D[x][y] = b1
                U[x][y] = max(u1, ix)
                L[x][y] = min(l1, ix)
            else:
                D[x][y] = b2
                U[x][y] = max(u2, ix)
                L[x][y] = min(l2, ix)
    return True
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def raster_scan(img, L, U, D):
    """Forward raster pass of the minimum-barrier distance transform.

    Relaxes the distance map D in place over all interior pixels, using
    the top and left neighbours. L and U are the running minimum and
    maximum intensity along the current best path to each pixel.
    """
    rows = len(img)
    cols = len(img[0])

    for r in range(1, rows - 1):
        for c in range(1, cols - 1):
            value = img[r][c]
            best = D[r][c]

            # barrier cost of routing through the top / left neighbour
            cost_top = max(U[r-1][c], value) - min(L[r-1][c], value)
            cost_left = max(U[r][c-1], value) - min(L[r][c-1], value)

            if best <= cost_top and best <= cost_left:
                continue
            if cost_top < best and cost_top <= cost_left:
                D[r][c] = cost_top
                U[r][c] = max(U[r-1][c], value)
                L[r][c] = min(L[r-1][c], value)
            else:
                D[r][c] = cost_left
                U[r][c] = max(U[r][c-1], value)
                L[r][c] = min(L[r][c-1], value)
    return True
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def raster_scan_inv(img, L, U, D):
    """Backward (bottom-right to top-left) raster pass of the
    minimum-barrier distance transform.

    Mirror of raster_scan: relaxes the distance map D in place using
    each interior pixel's bottom and right neighbours.
    """
    n_rows = len(img)
    n_cols = len(img[0])

    # Fix: iterate down to index 1 inclusive (stop=0). The original used
    # stop=1, so the first interior row/column (index 1) was never
    # relaxed, unlike the forward pass which covers indices 1..n-2.
    for x in range(n_rows - 2, 0, -1):
        for y in range(n_cols - 2, 0, -1):
            ix = img[x][y]
            d = D[x][y]

            # bottom / right neighbours' running extrema
            u1 = U[x+1][y]
            l1 = L[x+1][y]
            u2 = U[x][y+1]
            l2 = L[x][y+1]

            # barrier cost of extending either neighbour's path here
            b1 = max(u1, ix) - min(l1, ix)
            b2 = max(u2, ix) - min(l2, ix)

            if d <= b1 and d <= b2:
                continue
            elif b1 < d and b1 <= b2:
                D[x][y] = b1
                U[x][y] = max(u1, ix)
                L[x][y] = min(l1, ix)
            else:
                D[x][y] = b2
                U[x][y] = max(u2, ix)
                L[x][y] = min(l2, ix)
    return True
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def mbd(img, num_iters):
    """Minimum barrier distance transform of a 2-D image.

    Seeds the image border with distance 0 and relaxes the interior with
    alternating backward/forward raster passes.

    Args:
        img: 2-D numpy array (must be larger than 4x4).
        num_iters: number of raster passes to run.

    Returns:
        The distance map as a numpy array, or None for invalid input.
    """
    if len(img.shape) != 2:
        print('did not get 2d np array to fast mbd')
        return None
    if img.shape[0] <= 3 or img.shape[1] <= 3:
        print('image is too small')
        return None

    lower = np.copy(img)
    upper = np.copy(img)
    dist = float('Inf') * np.ones(img.shape)
    # the border is the seed set: distance zero
    dist[0, :] = 0
    dist[-1, :] = 0
    dist[:, 0] = 0
    dist[:, -1] = 0

    # plain python lists iterate much faster than numpy arrays here
    img_list = img.tolist()
    lower_list = lower.tolist()
    upper_list = upper.tolist()
    dist_list = dist.tolist()

    for it in range(num_iters):
        # odd iterations scan forward, even iterations scan backward
        scan = raster_scan if it % 2 == 1 else raster_scan_inv
        scan(img_list, lower_list, upper_list, dist_list)

    return np.array(dist_list)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# Saliency map calculation based on:
|
| 389 |
+
# Minimum Barrier Salient Object Detection at 80 FPS
|
| 390 |
+
# based on the Image Boundary Connectivity Cue, which assumes that
|
| 391 |
+
# background regions are usually connected to the image borders
|
| 392 |
+
# cons:
|
| 393 |
+
# doesn't consider spatial info, distant objects are detected as the same object
|
| 394 |
+
# fail when the contrast between foreground and background are small
|
| 395 |
+
|
| 396 |
+
def get_saliency_mbd(img):
    """Minimum-barrier-distance saliency ("80 FPS" method, Zhang et al.).

    Combines the MBD transform of the mean-channel image with a
    border-based background colour map, then applies a centre prior and
    a sigmoid contrast stretch. Returns a 2-D saliency map.
    """

    # MBD transform of the per-pixel channel mean
    img_mean = np.mean(img, axis=(2))
    # start_time = time.time()
    sal = mbd(img_mean,3)
    # end_time = time.time()
    # print('mbd function: ', end_time-start_time)

    # get the background map
    # paper uses 30px for an image of size 300px, so we use 10%
    (n_rows, n_cols, n_channels) = img.shape
    img_size = math.sqrt(n_rows * n_cols)
    border_thickness = int(math.floor(0.1 * img_size))

    img_lab = img_as_float(skimage.color.rgb2lab(img))

    # border strips, one per side
    # NOTE(review): px_left/px_right actually slice rows (top/bottom
    # strips) and px_top/px_bottom slice columns — the names appear
    # swapped relative to the axes; behaviour is symmetric so the result
    # is unaffected, but confirm before renaming.
    px_left = img_lab[0:border_thickness,:,:]
    px_right = img_lab[n_rows - border_thickness-1:-1,:,:]

    px_top = img_lab[:,0:border_thickness,:]
    px_bottom = img_lab[:,n_cols - border_thickness-1:-1,:]

    # mean Lab colour of each border strip
    px_mean_left = np.mean(px_left,axis=(0,1))
    px_mean_right = np.mean(px_right,axis=(0,1))
    px_mean_top = np.mean(px_top,axis=(0,1))
    px_mean_bottom = np.mean(px_bottom,axis=(0,1))

    # flatten each strip to (pixels, 3) for covariance estimation
    px_left = px_left.reshape((n_cols*border_thickness,3))
    px_right = px_right.reshape((n_cols*border_thickness,3))

    px_top = px_top.reshape((n_rows*border_thickness,3))
    px_bottom = px_bottom.reshape((n_rows*border_thickness,3))

    cov_left = np.cov(px_left.T)
    cov_right = np.cov(px_right.T)

    cov_top = np.cov(px_top.T)
    cov_bottom = np.cov(px_bottom.T)

    # inverse covariances feed the Mahalanobis distances below
    cov_left = np.linalg.inv(cov_left)
    cov_right = np.linalg.inv(cov_right)

    cov_top = np.linalg.inv(cov_top)
    cov_bottom = np.linalg.inv(cov_bottom)


    u_left = np.zeros(sal.shape)
    u_right = np.zeros(sal.shape)
    u_top = np.zeros(sal.shape)
    u_bottom = np.zeros(sal.shape)

    u_final = np.zeros(sal.shape)
    img_lab_unrolled = img_lab.reshape(img_lab.shape[0]*img_lab.shape[1],3)

    # Mahalanobis distance of every pixel to each side's colour model
    px_mean_left_2 = np.zeros((1,3))
    px_mean_left_2[0,:] = px_mean_left

    u_left = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_left_2,'mahalanobis', VI=cov_left)
    u_left = u_left.reshape((img_lab.shape[0],img_lab.shape[1]))

    px_mean_right_2 = np.zeros((1,3))
    px_mean_right_2[0,:] = px_mean_right

    u_right = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_right_2,'mahalanobis', VI=cov_right)
    u_right = u_right.reshape((img_lab.shape[0],img_lab.shape[1]))

    px_mean_top_2 = np.zeros((1,3))
    px_mean_top_2[0,:] = px_mean_top

    u_top = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_top_2,'mahalanobis', VI=cov_top)
    u_top = u_top.reshape((img_lab.shape[0],img_lab.shape[1]))

    px_mean_bottom_2 = np.zeros((1,3))
    px_mean_bottom_2[0,:] = px_mean_bottom

    u_bottom = scipy.spatial.distance.cdist(img_lab_unrolled,px_mean_bottom_2,'mahalanobis', VI=cov_bottom)
    u_bottom = u_bottom.reshape((img_lab.shape[0],img_lab.shape[1]))

    # normalise each side map to [0, 1]
    max_u_left = np.max(u_left)
    max_u_right = np.max(u_right)
    max_u_top = np.max(u_top)
    max_u_bottom = np.max(u_bottom)

    u_left = u_left / max_u_left
    u_right = u_right / max_u_right
    u_top = u_top / max_u_top
    u_bottom = u_bottom / max_u_bottom

    # combine the four side maps, discounting the strongest one (robust
    # to the foreground object touching one border)
    u_max = np.maximum(np.maximum(np.maximum(u_left,u_right),u_top),u_bottom)

    u_final = (u_left + u_right + u_top + u_bottom) - u_max

    # fuse the MBD map with the background colour map
    u_max_final = np.max(u_final)
    sal_max = np.max(sal)
    sal = sal / sal_max + u_final / u_max_final

    #postprocessing
    # apply centredness map
    sal = sal / np.max(sal)

    # NOTE(review): delta is computed but never used afterwards.
    s = np.mean(sal)
    alpha = 50.0
    delta = alpha * math.sqrt(s)

    xv,yv = np.meshgrid(np.arange(sal.shape[1]),np.arange(sal.shape[0]))
    # NOTE(review): (w, h) actually unpack as (rows, cols); the formula
    # below pairs xv (columns) with h2 and yv (rows) with w2, so it is
    # consistent despite the misleading names.
    (w,h) = sal.shape
    w2 = w/2.0
    h2 = h/2.0

    # centre prior: 1 at the image centre, 0 at the corners
    C = 1 - np.sqrt(np.power(xv - h2,2) + np.power(yv - w2,2)) / math.sqrt(np.power(w2,2) + np.power(h2,2))
    sal = sal * C

    #increase bg/fg contrast
    def f(x):
        # logistic stretch centred at 0.5 with slope b
        b = 10.0
        return 1.0 / (1.0 + math.exp(-b*(x - 0.5)))

    fv = np.vectorize(f)
    sal = sal / np.max(sal)
    sal = fv(sal)
    return sal
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def binarise_saliency_map(saliency_map, method='adaptive',threshold=0.5):
    """Threshold a 2-D saliency map into a boolean mask.

    method: 'fixed' compares against `threshold`; 'adaptive' uses twice
    the mean saliency; 'clustering' is not implemented.
    Returns the boolean mask, or None for invalid input / unknown method.
    """
    # input must be a numpy array
    if type(saliency_map).__module__ != np.__name__:
        print('Expected numpy array')
        return None

    # input must be 2-D
    if len(saliency_map.shape) != 2:
        print('Saliency map must be 2D')
        return None

    if method == 'fixed':
        return saliency_map > threshold

    if method == 'adaptive':
        # paper-style adaptive threshold: 2x the mean saliency
        return saliency_map > 2.0 * saliency_map.mean()

    if method == 'clustering':
        print('Not yet implemented')
        return None

    print("Method must be one of fixed, adaptive or clustering")
    return None
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
if __name__ == '__main__':
    # Demo: compute all three saliency maps for one image and display
    # them with OpenCV windows.
    # path to the image
    filename = '../images/flower/19569518092_2db12519fd_c.jpg'
    # filename = './images/landmark_04/12004354405_dc546d53ce_c.jpg'

    img = skimage.io.imread(filename)

    if len(img.shape) != 3: # got a grayscale image
        img = skimage.color.gray2rgb(img)

    # get the saliency maps using the 3 implemented methods
    start_time = time.time()
    rbd = get_saliency_rbd(img)
    end_time = time.time()
    print('rbd:', end_time-start_time)

    start_time = time.time()
    ft = get_saliency_ft(img)
    end_time = time.time()
    print('ft:', end_time-start_time)

    start_time = time.time()
    mbd_img = get_saliency_mbd(img)
    end_time = time.time()
    print('mbd:', end_time-start_time)

    # often, it is desirable to have a binary saliency map
    binary_sal = binarise_saliency_map(mbd_img, method='adaptive')

    # re-read with OpenCV (BGR channel order) purely for display
    img = cv2.imread(filename)

    # print(mbd.max())

    cv2.imshow('img', img)
    cv2.imshow('rbd', rbd)
    cv2.imshow('ft', ft)
    cv2.imshow('mbd', mbd_img)

    #openCV cannot display numpy type 0, so convert to uint8 and scale
    cv2.imshow('binary', 255 * binary_sal.astype('uint8'))


    cv2.waitKey(0)
|
solve_group_palette.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.cluster import KMeans
|
| 4 |
+
from sklearn.metrics import pairwise_distances
|
| 5 |
+
|
| 6 |
+
from utils import stack_list
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def solve_group_palette(hist_sample_all, hist_counts_all, center, density, reference, m, lightness=70, eta=0, gamma=1e10, iteration=10):
    """Solve a shared ("group") palette of m colours across several images.

    Alternates between (a) matching each image's individual palette to
    the current group palette and (b) re-solving the group colours from
    those assignments, until the objective stabilises.

    Args:
        hist_sample_all: (N, 2 or 3) stacked histogram bin colours of all images.
        hist_counts_all: (N,) bin weights used to seed the weighted k-means.
        center: list of per-image palette colour arrays (possibly empty).
        density: list of per-image palette weights, parallel to `center`.
        reference: per-image flags; images flagged 0 are skipped by solve_mean.
        m: number of group palette colours.
        lightness: L value inserted when the samples are 2-D (ab channels only).
        eta: penalty for leaving palette entries unassigned.
        gamma: penalty for mapping distinct entries to the same group colour.
        iteration: maximum number of alternation rounds.

    Returns:
        (M, Lout): the (m, 3) group palette and per-image label arrays.
    """
    num_img = len(center)
    Lout = [i for i in range(num_img)]

    if num_img > 1:
        gamma = gamma / ((num_img-1)/2)
    # Fix: lbd is now bound unconditionally; previously it was assigned
    # only in the num_img > 1 branch, raising NameError for one image.
    lbd = gamma / 50

    old_val = 0
    # Fix: np.Inf was removed in NumPy 2.0; use np.inf.
    init_min = np.inf

    # initialise the group palette: mean colour, or weighted k-means with
    # farthest-point-style seeding over the histogram samples
    if m == 1:
        M = np.mean(hist_sample_all, axis=0)
    else:
        cinits = np.zeros((m, np.size(hist_sample_all, 1)))
        cw = hist_counts_all
        for i in range(m):
            # pick the heaviest remaining sample, then down-weight
            # samples close to it so seeds spread out
            id = np.argmax(cw)
            cinits[i,:] = hist_sample_all[id,:]
            d2 = cinits[i,:]* np.ones((np.size(hist_sample_all, 0), 1)) - hist_sample_all
            d2 = np.sum(np.square(d2), axis=1)
            d2 = d2/np.max(d2)
            cw = cw * (d2**2)

        kmeans_grp = KMeans(n_clusters=m, init=cinits, n_init=1).fit(
            hist_sample_all, y=None, sample_weight=hist_counts_all)
        M = kmeans_grp.cluster_centers_

    # 2-D (ab) samples: prepend the fixed lightness channel
    if np.size(hist_sample_all, 1) == 2:
        if M.ndim == 1:
            M = np.expand_dims(M, axis=0)
        M = np.insert(M, 0, values=lightness, axis=1)

    ## choose the nearest cluster center from all the individual palettes as the inital group palette
    center_r0 = delete_num(center)
    density_r0 = delete_num(density)
    center_all = stack_list(center)

    if M.ndim == 1:
        M = M.reshape(1, -1)

    D = pairwise_distances(M, center_all, metric='euclidean')
    idx = np.argmin(D, 1)
    M = center_all[idx,:]

    ## solve the group palette according to the requirement (gamma and eta)
    for t in range(iteration):
        sum_val = 0
        # solve for the assignment (matching)
        for i in range(num_img):
            if center[i].size != 0:
                Lout[i], val = solve_optimal_ind_palette(center[i], density[i], M, lbd, gamma, eta, init_min)
                sum_val = sum_val + val
            else:
                Lout[i] = np.array([])
                # NOTE(review): this resets the accumulated objective for
                # every empty palette; presumably a no-op was intended —
                # kept as-is pending confirmation.
                sum_val = 0

        # re-compute the group color theme (mean colors)
        Lout_r0 = delete_num(Lout)
        idx = detect_num(Lout)
        reference_r0 = reference[idx]
        M = solve_mean(center_r0, reference_r0, density_r0, Lout_r0, m, len(reference_r0), lbd)
        print('Iteration {}, val: {}'.format(t, sum_val))
        # stop once the objective change drops below a small tolerance
        if np.abs(old_val - sum_val) < 10:
            break
        else:
            old_val = sum_val

    return M, Lout
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def delete_num(list_org):
    """Return a new list keeping only the non-empty arrays of list_org."""
    return [entry for entry in list_org if entry.size != 0]
|
| 103 |
+
|
| 104 |
+
def detect_num(list_org):
    """Return the indices of the non-empty arrays in list_org."""
    return [pos for pos, arr in enumerate(list_org) if arr.size != 0]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def solve_optimal_ind_palette(center, density, center_mean, lambd, gamma, eta, init_min):
    """Exhaustively match one image's palette to the group palette.

    Each of the n palette entries gets a label in 0..m, where 0 means
    "unassigned" and 1..m select a group colour. The objective sums a
    density-weighted colour-distance data term (D3), a pairwise
    difference-preservation term (lambd * D2), a penalty for mapping two
    distinct entries to the same group colour (gamma * D1), and an
    unassignment penalty (eta * density). All (m+1)^n labelings are
    enumerated with early pruning against the running minimum.

    Returns (best label vector of shape (n, 1), best objective value).
    """

    n = np.size(center, 0)
    m = np.size(center_mean, 0)
    # print(eta)

    # brute-force all possible cases
    min_obj_func = init_min
    D1 = pairwise_distances(center, center, metric='euclidean')
    D2 = psub2(center, center_mean)

    # data term: density-weighted squared distance to each group colour,
    # with a zero-cost column prepended for the "unassigned" label 0
    dist = pairwise_distances(center, center_mean, metric='euclidean')
    dist = np.tile(np.expand_dims(density, axis=1), (1,m)) * (dist**2)

    D3 = np.insert(dist, 0, values=np.zeros(n), axis=1)


    # enumerate all labelings by counting in base (m+1)
    num_of_pairs = (m+1)**n
    label = np.zeros((n,1)).astype(np.int32)
    min_label_com = label.copy()
    label[n-1, :] = -1
    for idx in range(num_of_pairs):
        # increment the base-(m+1) counter with carry propagation
        label[n-1] = label[n-1] + 1
        curId = n-1
        while label[curId] > m:
            label[curId] = 0
            curId = curId - 1
            label[curId] = label[curId] + 1
        # unassignment penalty over entries labelled 0
        term4 = np.sum((label==0) * np.expand_dims(density, axis=1))
        val = eta * term4
        if val >= min_obj_func:
            continue
        for i in range(n):
            val = val + D3[i, label[i]] #the first term
        # cut down the unsolution
        # NOTE(review): indentation reconstructed — this prune is placed
        # after the data-term loop so `continue` skips to the next
        # labeling; confirm against the original source.
        if val >= min_obj_func:
            continue
        term2 = 0
        term3 = 0
        for ii in range(n-1):
            for jj in range(ii+1, n):
                term2 = term2 + D2[ii, jj, label[ii], label[jj]]
                if label[ii] == label[jj] and label[ii] > 0:
                    term3 = term3 + D1[ii, jj]

        val = val + lambd * term2 + gamma * term3
        if val < min_obj_func:
            min_obj_func = val.copy()
            min_label_com = label.copy()

    return min_label_com, min_obj_func
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def solve_mean(center, reference, density, L, m, n, lambd):
    """Re-solve the m group colours from the per-image assignments.

    Assembles the normal equations of a least-squares objective with a
    density-weighted data term (assigned palette colours) and a pairwise
    term preserving colour differences within each image (weight
    `lambd`), solves via pseudo-inverse, then snaps each solved colour
    to the nearest colour among the individual palettes (medoid step).

    Args:
        center: list of n per-image palette arrays of shape (ni, 3).
        reference: per-image flags; images with reference[i] == 0 are skipped.
        density: list of per-image palette weight arrays.
        L: list of per-image label arrays (0 = unassigned, 1..m = group colour).
        m: number of group colours.
        n: number of images.
        lambd: weight of the pairwise difference-preservation term.
    """
    A = np.zeros((m, m))
    B = np.zeros((m, 3))
    M = np.zeros((m, 3))

    for i in range(n):
        if reference[i] == 0:
            continue
        Pi = center[i]
        Wi = density[i]
        Li = L[i]
        ni = np.size(Pi, 0)
        # first term
        for j in range(ni):
            if Li[j] > 0:
                # labels are 1-based; row Li[j]-1 holds group colour Li[j]
                A[Li[j]-1, Li[j]-1] = A[Li[j]-1, Li[j]-1] + Wi[j]
                B[Li[j]-1, :] = B[Li[j]-1, :] + np.expand_dims(Wi[j] * Pi[j, :], axis=0)

        # second term
        for j1 in range(ni-1):
            for j2 in range(j1+1, ni):
                if Li[j1] > 0 and Li[j2] > 0:
                    # preserve the colour difference Pi[j1]-Pi[j2] between
                    # the two assigned group colours
                    A[Li[j1]-1, Li[j1]-1] = A[Li[j1]-1, Li[j1]-1] + lambd
                    A[Li[j1]-1, Li[j2]-1] = A[Li[j1]-1, Li[j2]-1] - lambd
                    A[Li[j2]-1, Li[j2]-1] = A[Li[j2]-1, Li[j2]-1] + lambd
                    A[Li[j2]-1, Li[j1]-1] = A[Li[j2]-1, Li[j1]-1] - lambd
                    B[Li[j1]-1, :] = B[Li[j1]-1, :] + lambd * (Pi[j1,:]-Pi[j2,:])
                    B[Li[j2]-1, :] = B[Li[j2]-1, :] + lambd * (Pi[j2,:]-Pi[j1,:])

    # solve least squares
    # print(A)
    M = np.dot(np.linalg.pinv(A), B)

    # M[np.isnan[M]] = 0

    # k-median, take the nearest from the individual palettes
    M = choose_mediod(center, M)

    return M
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def concat_list(input, axis=0):
    """Concatenate a list of numpy arrays along `axis` into one array."""
    merged = input[0]
    for part in input[1:]:
        merged = np.concatenate((merged, part), axis=axis)
    return merged
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def choose_mediod(Pin, M):
    """Snap each row of M to its nearest colour among all palettes in Pin."""
    candidates = concat_list(Pin)
    nearest = np.argmin(pairwise_distances(M, candidates), axis=1)
    return candidates[nearest, :]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def psub2(P, M):
    """Precompute pairwise difference-preservation costs.

    D[i1, i2, i3+1, i4+1] = || (P[i1] - P[i2]) - (M[i3] - M[i4]) ||^2
    for i1 < i2. Index 0 in the last two axes (the "unassigned" label)
    and all unvisited entries remain zero.
    """
    num_p = np.size(P, 0)
    num_m = np.size(M, 0)
    D = np.zeros((num_p, num_p, num_m + 1, num_m + 1))
    for a in range(num_p - 1):
        for b in range(a + 1, num_p):
            diff_p = P[a, :] - P[b, :]
            for u in range(num_m):
                for v in range(num_m):
                    residual = diff_p - M[u, :] + M[v, :]
                    D[a, b, u + 1, v + 1] = np.sum(residual**2)
    return D
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
|
utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
# import matplotlib.pyplot as plt
|
| 4 |
+
from skimage.color import rgb2lab, lab2rgb, rgb2hsv, hsv2rgb
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def stack_list(x):
    """Vertically stack every non-empty array in x.

    Returns the plain empty list [] when x contains no non-empty
    entries (matching the original behaviour).
    """
    stacked = []
    first = True
    for item in x:
        if np.size(item) == 0:
            continue
        if first:
            stacked = item
            first = False
        else:
            stacked = np.vstack([stacked, item])
    return stacked
|
| 17 |
+
|
| 18 |
+
def rgb_to_hex(r, g, b):
    """Format an (r, g, b) byte triple as a '#rrggbb' hex string."""
    return f'#{r:02x}{g:02x}{b:02x}'
|
| 20 |
+
|
| 21 |
+
def hex_to_rgb(hex):
    """Parse a '#rrggbb' string into an (r, g, b) tuple of ints."""
    # offsets 1/3/5 skip the leading '#' and pick each two-digit channel
    return tuple(int(hex[pos:pos + 2], 16) for pos in (1, 3, 5))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def image_resize(img, c_w, c_h):
    """Proportionally resize img to fit inside a c_w x c_h box.

    Accepts a PIL Image or a numpy array; returns a PIL Image resized
    with bilinear interpolation, preserving the aspect ratio.
    """
    if type(img) is np.ndarray:
        img = Image.fromarray(img)
    # PIL reports size as (width, height)
    width, height = img.size
    # pick the scale that keeps both dimensions inside the target box
    scale = np.minimum(c_h / height, c_w / width)
    new_size = (np.round(width * scale).astype(np.int64),
                np.round(height * scale).astype(np.int64))
    return img.resize(new_size, Image.BILINEAR)
|
| 44 |
+
|
| 45 |
+
def get_palette(num_cls):
    """Return the color map for visualizing a segmentation mask.

    Args:
        num_cls: Number of classes
    Returns:
        Flat [r0, g0, b0, r1, g1, b1, ...] list of length 3 * num_cls,
        built from the class id's bits (PASCAL-VOC style colormap).
    """
    palette = [0] * (num_cls * 3)
    for cls_id in range(num_cls):
        value = cls_id
        shift = 7
        r = g = b = 0
        # spread successive bit triples of the class id across the
        # high bits of each channel
        while value:
            r |= ((value >> 0) & 1) << shift
            g |= ((value >> 1) & 1) << shift
            b |= ((value >> 2) & 1) << shift
            shift -= 1
            value >>= 3
        palette[cls_id * 3 + 0] = r
        palette[cls_id * 3 + 1] = g
        palette[cls_id * 3 + 2] = b
    return palette
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def visualize_palette(palette_lab, patch_size=20):
    """Render a Lab palette as a horizontal strip of RGB colour patches.

    Returns a single white patch when palette_lab is None.
    """
    if palette_lab is None:
        return np.ones((patch_size, patch_size, 3)) * [1., 1., 1.]

    # lab2rgb expects an image-shaped array: add/strip a row axis
    palette_rgb = lab2rgb(np.expand_dims(palette_lab, axis=0))
    palette_rgb = np.squeeze(palette_rgb, axis=0)

    patches = []
    for row in range(np.size(palette_rgb, 0)):
        colour = np.expand_dims(palette_rgb[row, :], axis=(0, 1))
        patches.append(np.ones((patch_size, patch_size, 3)) * colour)
    return np.concatenate(patches, axis=1)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def visualize_palette_rgb(palette_rgb, patch_size=20):
    """Render an RGB palette (k, 3) as a horizontal strip of colour patches.

    Returns a single white patch when the palette is missing or empty.
    """
    # Fix: the previous guard `palette_rgb == 0` raised ValueError for
    # arrays (ambiguous truth value) and never matched None; treat
    # None / scalar 0 / empty arrays as "no palette", consistent with
    # visualize_palette's None check.
    if (palette_rgb is None
            or np.size(palette_rgb) == 0
            or (np.isscalar(palette_rgb) and palette_rgb == 0)):
        return np.ones((patch_size, patch_size, 3)) * [1., 1., 1.]

    for id in range(np.size(palette_rgb, 0)):
        rgb = np.expand_dims(palette_rgb[id, :], axis=(0, 1))
        if id == 0:
            img_palette = np.ones((patch_size, patch_size, 3)) * rgb
        else:
            img_palette = np.append(img_palette, np.ones((patch_size, patch_size, 3)) * rgb, axis=1)

    return img_palette
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# def vis_consistency(img_rgb_all, img_rgb_out_all, label_colored_all, c_center, L_idx):
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def color_difference(img1, img2):
    """Mean per-pixel CIE76 colour difference (delta-E) between two RGB images."""
    rows, cols, _ = img1.shape
    delta = rgb2lab(img1) - rgb2lab(img2)
    # Euclidean distance in Lab per pixel, averaged over the image
    dE = np.sqrt(np.sum(delta ** 2, axis=2))
    return np.sum(dE) / (rows * cols)
|
| 124 |
+
|