Commit f0e5caa · model upload

Files changed:
- .gitattributes +2 -0
- LICENSE +407 -0
- README.md +168 -0
- app.py +40 -0
- app_utils.py +906 -0
- config.yml +51 -0
- feature_extraction/extractor_mediapipe.py +340 -0
- feature_extraction/features_extractor.py +48 -0
- gradio_app.py +388 -0
- gradio_utils.py +300 -0
- pre_trained_models/ResNet18/left_eye.pt +3 -0
- pre_trained_models/ResNet18/right_eye.pt +3 -0
- pre_trained_models/ResNet50/left_eye.pt +3 -0
- pre_trained_models/ResNet50/right_eye.pt +3 -0
- preprocessing/dataset_creation.py +26 -0
- preprocessing/dataset_creation_utils.py +14 -0
- registrations/models.py +56 -0
- registry.py +82 -0
- registry_utils.py +79 -0
- requirements.txt +28 -0
- sample_videos/All Smiles Ahead.webm +3 -0
- sample_videos/And it was all Yellow.webm +3 -0
- sample_videos/Blink It Like Brian.webm +3 -0
- sample_videos/Focus Pocus.webm +3 -0
- sample_videos/Funny Talks.webm +3 -0
- sample_videos/I like to move it move it.webm +3 -0
- sample_videos/Infinite Blue.webm +3 -0
- sample_videos/Red Ross.webm +3 -0
- sample_videos/Smile, You’re on Camera!.webm +3 -0
- utils.py +11 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,407 @@
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.


Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.
  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  j. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.


Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.


Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.


Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.


Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.


Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.


Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.


Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md
ADDED
@@ -0,0 +1,168 @@
---
title: PupilSense
emoji: 👁️
colorFrom: red
colorTo: pink
sdk: gradio
sdk_version: 4.36.1
app_file: app.py
pinned: false
---

# 👁️ PupilSense 👁️🕵️♂️

PupilSense is a deep learning-powered application for estimating pupil diameter from images and videos. It uses trained ResNet models with Class Activation Mapping (CAM) for interpretable predictions.

## Features

- **Image Processing**: Upload images to get instant pupil diameter estimates
- **Video Processing**: Analyze videos frame-by-frame for temporal pupil diameter analysis
- **Model Selection**: Choose between ResNet18 and ResNet50 architectures
- **Pupil Selection**: Analyze left pupil, right pupil, or both
- **Blink Detection**: Automatically detect and handle blinks in the analysis
- **CAM Visualization**: See which parts of the eye the model focuses on for predictions
- **API Access**: Full Gradio API support for programmatic access

## Usage

### Web Interface
Simply upload an image or video file and configure your analysis parameters:
- Select pupil(s) to analyze (left, right, or both)
- Choose the model architecture (ResNet18 or ResNet50)
- Enable/disable blink detection
- Click process to get results

### API Access
The Gradio interface provides automatic API endpoints. You can access the API documentation at `/docs` when the app is running.

Example API usage:
```python
import requests
import json

# For image processing
files = {"image_input": open("your_image.jpg", "rb")}
data = {
    "pupil_selection": "both",
    "tv_model": "ResNet18",
    "blink_detection": True
}
response = requests.post("https://your-space-url/api/predict", files=files, data=data)
```
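
Alternatively, the official `gradio_client` package can call the same endpoint. This is a minimal sketch, not the app's documented API: the Space URL and the `/predict` endpoint name are assumptions, so check the running app's API docs for the exact signature and argument order.

```python
# Hedged sketch using gradio_client; the URL and api_name are placeholders.
from gradio_client import Client, file

client = Client("https://your-space-url")
result = client.predict(
    file("your_image.jpg"),  # image_input
    "both",                  # pupil_selection: left_pupil | right_pupil | both
    "ResNet18",              # tv_model: ResNet18 | ResNet50
    True,                    # blink_detection
    api_name="/predict",     # assumption: verify against the app's /docs page
)
print(result)
```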
+
|
| 53 |
+
## Model Information
|
| 54 |
+
|
| 55 |
+
The application uses pre-trained ResNet models specifically trained for pupil diameter estimation:
|
| 56 |
+
- **ResNet18**: Faster inference, good accuracy
|
| 57 |
+
- **ResNet50**: Higher accuracy, slower inference
|
| 58 |
+
|
| 59 |
+
Both models support:
|
| 60 |
+
- Input resolution: 32x64 pixels (eye region)
|
| 61 |
+
- Output: Pupil diameter in millimeters
|
| 62 |
+
- CAM visualization for model interpretability
|
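
For reference, here is a minimal sketch of the 32x64 preprocessing used in `app_utils.py`. A plain torchvision ResNet18 with a one-output head stands in for the repo's trained weights, and `left_eye.jpg` is a hypothetical eye crop:

```python
import torch
from PIL import Image
from torchvision import models, transforms

# Stand-in regressor mirroring the repo's `"num_classes": 1` config;
# the real weights live in pre_trained_models/ResNet18/left_eye.pt.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 1)
model.eval()

# Same resize as app_utils.py: eye crop -> 3x32x64 tensor, batch of one.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC,
                      antialias=True),
])

eye_crop = Image.open("left_eye.jpg").convert("RGB")  # hypothetical crop
with torch.no_grad():
    diameter_mm = model(preprocess(eye_crop).unsqueeze(0))[0].item()
print(f"predicted diameter: {diameter_mm:.2f} mm")
```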

## Technical Details

- **Face Detection**: MediaPipe for robust face and eye detection
- **Preprocessing**: Automatic eye region extraction and normalization
- **Deep Learning**: PyTorch-based ResNet models
- **Visualization**: Matplotlib for result plotting and CAM overlays (see the torchcam sketch below)
- **Video Support**: Frame-by-frame analysis with temporal plotting
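
The CAM overlays come from `torchcam`, wired up as in `app_utils.py` (a `CAM` extractor over the last conv layer, plus the fc head). A sketch that continues the stand-in model and preprocessing from the previous snippet:

```python
# Continues the stand-in model/preprocess from the sketch above;
# mirrors app_utils.py's torchcam CAM + overlay_mask usage.
from torchcam.methods import CAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image

batch = preprocess(eye_crop).unsqueeze(0)  # [1, 3, 32, 64]
# A plain resnet18 exposes layer4 directly (the repo wraps it as model.resnet).
cam_extractor = CAM(model, target_layer=model.layer4[-1].conv2,
                    fc_layer=model.fc, input_shape=(3, 32, 64))
scores = model(batch)                      # forward pass feeds the CAM hooks
act_map = cam_extractor(0, scores)[0]      # index 0: the single regression output
overlay = overlay_mask(to_pil_image(batch.squeeze(0)),
                       to_pil_image(act_map, mode="F"), alpha=0.5)
overlay.save("cam_overlay.png")
```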

## Installation & Setup

### Local Development

1. **Clone the repository**
```bash
git clone <repository-url>
cd pupilsense
```

2. **Create virtual environment**
```bash
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

3. **Install dependencies**
```bash
pip install -r requirements.txt
```

4. **Run the application**
```bash
python app.py
```

The app will be available at `http://localhost:7860`

### Hugging Face Spaces Deployment

1. **Create a new Space** on Hugging Face with Gradio SDK
2. **Upload all files** from the pupilsense directory
3. **Ensure the following files are present:**
   - `app.py` (main application file)
   - `gradio_app.py` (Gradio interface)
   - `gradio_utils.py` (utility functions)
   - `requirements.txt` (dependencies)
   - `README.md` (this file with proper YAML header)
   - `pre_trained_models/` (model files)
   - All other supporting files

## Known Issues & Troubleshooting

### MediaPipe Issues
- **Issue**: Segmentation fault or MediaPipe errors in headless environments
- **Solution**: The app includes error handling for MediaPipe failures. In production environments, ensure proper GPU/display drivers are available.

### Model Loading
- **Issue**: Model files not found
- **Solution**: Ensure `pre_trained_models/` directory contains the required `.pt` files for both ResNet18 and ResNet50 models.

### Memory Usage
- **Issue**: High memory usage with large videos
- **Solution**: The app automatically resizes frames to 640x480 to manage memory usage.

## File Structure

```
pupilsense/
├── app.py                 # Main application entry point
├── gradio_app.py          # Gradio interface definition
├── gradio_utils.py        # Utility functions (MediaPipe-free)
├── app_utils.py           # Original Streamlit utilities (legacy)
├── requirements.txt       # Python dependencies
├── README.md              # This file
├── config.yml             # Configuration file
├── registry.py            # Model registry
├── registry_utils.py      # Registry utilities
├── utils.py               # General utilities
├── pre_trained_models/    # Trained model files
│   ├── ResNet18/
│   │   ├── left_eye.pt
│   │   └── right_eye.pt
│   └── ResNet50/
│       ├── left_eye.pt
│       └── right_eye.pt
├── preprocessing/         # Data preprocessing modules
├── feature_extraction/    # Feature extraction modules
├── registrations/         # Model registration modules
└── sample_videos/         # Sample video files
```

## Contributing

1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Test thoroughly
5. Submit a pull request

## License

See LICENSE file for details.

---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,40 @@
import sys
import os.path as osp

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from gradio_app import create_gradio_interface

def main():
    """Main function to launch the Gradio interface."""
    try:
        demo = create_gradio_interface()

        # For Hugging Face Spaces deployment
        import os
        if os.getenv("SPACE_ID") or os.getenv("SYSTEM") == "spaces":
            # Running on Hugging Face Spaces
            demo.launch(share=True)
        else:
            # Running locally
            try:
                demo.launch(
                    server_name="0.0.0.0",
                    server_port=7860,
                    share=False
                )
            except ValueError as e:
                if "shareable link must be created" in str(e):
                    print("Localhost not accessible, creating shareable link...")
                    demo.launch(share=True)
                else:
                    raise e
    except Exception as e:
        print(f"Error launching app: {e}")
        import traceback
        traceback.print_exc()
        raise e

if __name__ == "__main__":
    main()
app_utils.py
ADDED
@@ -0,0 +1,906 @@
| 1 |
+
import base64
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
import io
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import cv2
|
| 7 |
+
from matplotlib import pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import streamlit as st
|
| 11 |
+
import torch
|
| 12 |
+
import tempfile
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from torchvision.transforms.functional import to_pil_image
|
| 15 |
+
from torchvision import transforms
|
| 16 |
+
from PIL import ImageOps
|
| 17 |
+
import altair as alt
|
| 18 |
+
import streamlit.components.v1 as components
|
| 19 |
+
|
| 20 |
+
from torchcam.methods import CAM
|
| 21 |
+
from torchcam import methods as torchcam_methods
|
| 22 |
+
from torchcam.utils import overlay_mask
|
| 23 |
+
import os.path as osp
|
| 24 |
+
|
| 25 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir))
|
| 26 |
+
sys.path.append(root_path)
|
| 27 |
+
|
| 28 |
+
from preprocessing.dataset_creation import EyeDentityDatasetCreation
|
| 29 |
+
from utils import get_model
|
| 30 |
+
|
| 31 |
+
CAM_METHODS = ["CAM"]
|
| 32 |
+
# colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"] # Green, Red, Blue, Orange
|
| 33 |
+
colors = ["#1f77b4", "#ff7f0e", "#636363"] # Blue, Orange, Gray
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@torch.no_grad()
|
| 37 |
+
def load_model(model_configs, device="cpu"):
|
| 38 |
+
"""Loads the pre-trained model."""
|
| 39 |
+
model_path = os.path.join(root_path, model_configs["model_path"])
|
| 40 |
+
model_dict = torch.load(model_path, map_location=device)
|
| 41 |
+
model = get_model(model_configs=model_configs)
|
| 42 |
+
model.load_state_dict(model_dict)
|
| 43 |
+
model = model.to(device).eval()
|
| 44 |
+
return model
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def extract_frames(video_path):
|
| 48 |
+
"""Extracts frames from a video file."""
|
| 49 |
+
vidcap = cv2.VideoCapture(video_path)
|
| 50 |
+
frames = []
|
| 51 |
+
success, image = vidcap.read()
|
| 52 |
+
while success:
|
| 53 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 54 |
+
frames.append(image_rgb)
|
| 55 |
+
success, image = vidcap.read()
|
| 56 |
+
vidcap.release()
|
| 57 |
+
return frames
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def resize_frame(image, max_width=640, max_height=480):
|
| 61 |
+
if not isinstance(image, Image.Image):
|
| 62 |
+
image = Image.fromarray(image)
|
| 63 |
+
original_size = image.size
|
| 64 |
+
|
| 65 |
+
# Resize the frame similarly to the image resizing logic
|
| 66 |
+
if original_size[0] == original_size[1] and original_size[0] >= 256:
|
| 67 |
+
max_size = (256, 256)
|
| 68 |
+
else:
|
| 69 |
+
max_size = list(original_size)
|
| 70 |
+
if original_size[0] >= max_width:
|
| 71 |
+
max_size[0] = max_width
|
| 72 |
+
elif original_size[0] < 64:
|
| 73 |
+
max_size[0] = 64
|
| 74 |
+
if original_size[1] >= max_height:
|
| 75 |
+
max_size[1] = max_height
|
| 76 |
+
elif original_size[1] < 32:
|
| 77 |
+
max_size[1] = 32
|
| 78 |
+
|
| 79 |
+
image.thumbnail(max_size)
|
| 80 |
+
# image = image.resize(max_size)
|
| 81 |
+
return image
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def is_image(file_extension):
|
| 85 |
+
"""Checks if the file is an image."""
|
| 86 |
+
return file_extension.lower() in ["png", "jpeg", "jpg"]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def is_video(file_extension):
|
| 90 |
+
"""Checks if the file is a video."""
|
| 91 |
+
return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_codec_and_extension(file_format):
|
| 95 |
+
"""Return codec and file extension based on the format."""
|
| 96 |
+
if file_format == "mp4":
|
| 97 |
+
return "H264", ".mp4"
|
| 98 |
+
elif file_format == "avi":
|
| 99 |
+
return "MJPG", ".avi"
|
| 100 |
+
elif file_format == "webm":
|
| 101 |
+
return "VP80", ".webm"
|
| 102 |
+
else:
|
| 103 |
+
return "MJPG", ".avi"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def display_results(input_image, cam_frame, pupil_diameter, cols):
|
| 107 |
+
"""Displays the input image and overlayed CAM result."""
|
| 108 |
+
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
|
| 109 |
+
axs[0].imshow(input_image)
|
| 110 |
+
axs[0].axis("off")
|
| 111 |
+
axs[0].set_title("Input Image")
|
| 112 |
+
axs[1].imshow(cam_frame)
|
| 113 |
+
axs[1].axis("off")
|
| 114 |
+
axs[1].set_title("Overlayed CAM")
|
| 115 |
+
cols[-1].pyplot(fig)
|
| 116 |
+
cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def preprocess_image(input_img, max_size=(256, 256)):
|
| 120 |
+
"""Resizes and preprocesses an image."""
|
| 121 |
+
input_img.thumbnail(max_size)
|
| 122 |
+
preprocess_steps = [
|
| 123 |
+
transforms.ToTensor(),
|
| 124 |
+
transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
|
| 125 |
+
]
|
| 126 |
+
return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def overlay_text_on_frame(frame, text, position=(16, 20)):
|
| 130 |
+
"""Write text on the image frame using OpenCV."""
|
| 131 |
+
return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_configs(blink_detection=False):
|
| 135 |
+
upscale = "-"
|
| 136 |
+
upscale_method_or_model = "-"
|
| 137 |
+
if upscale == "-":
|
| 138 |
+
sr_configs = None
|
| 139 |
+
else:
|
| 140 |
+
sr_configs = {
|
| 141 |
+
"method": upscale_method_or_model,
|
| 142 |
+
"params": {"upscale": upscale},
|
| 143 |
+
}
|
| 144 |
+
config_file = {
|
| 145 |
+
"sr_configs": sr_configs,
|
| 146 |
+
"feature_extraction_configs": {
|
| 147 |
+
"blink_detection": blink_detection,
|
| 148 |
+
"upscale": upscale,
|
| 149 |
+
"extraction_library": "mediapipe",
|
| 150 |
+
},
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
return config_file
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def setup(cols, pupil_selection, tv_model, output_path):
|
| 157 |
+
|
| 158 |
+
left_pupil_model = None
|
| 159 |
+
left_pupil_cam_extractor = None
|
| 160 |
+
right_pupil_model = None
|
| 161 |
+
right_pupil_cam_extractor = None
|
| 162 |
+
output_frames = {}
|
| 163 |
+
input_frames = {}
|
| 164 |
+
predicted_diameters = {}
|
| 165 |
+
pred_diameters_frames = {}
|
| 166 |
+
|
| 167 |
+
if pupil_selection == "both":
|
| 168 |
+
selected_eyes = ["left_eye", "right_eye"]
|
| 169 |
+
|
| 170 |
+
elif pupil_selection == "left_pupil":
|
| 171 |
+
selected_eyes = ["left_eye"]
|
| 172 |
+
|
| 173 |
+
elif pupil_selection == "right_pupil":
|
| 174 |
+
selected_eyes = ["right_eye"]
|
| 175 |
+
|
| 176 |
+
for i, eye_type in enumerate(selected_eyes):
|
| 177 |
+
model_configs = {
|
| 178 |
+
"model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
|
| 179 |
+
"registered_model_name": tv_model,
|
| 180 |
+
"num_classes": 1,
|
| 181 |
+
}
|
| 182 |
+
if eye_type == "left_eye":
|
| 183 |
+
left_pupil_model = load_model(model_configs)
|
| 184 |
+
left_pupil_cam_extractor = None
|
| 185 |
+
output_frames[eye_type] = []
|
| 186 |
+
input_frames[eye_type] = []
|
| 187 |
+
predicted_diameters[eye_type] = []
|
| 188 |
+
pred_diameters_frames[eye_type] = []
|
| 189 |
+
else:
|
| 190 |
+
right_pupil_model = load_model(model_configs)
|
| 191 |
+
right_pupil_cam_extractor = None
|
| 192 |
+
output_frames[eye_type] = []
|
| 193 |
+
input_frames[eye_type] = []
|
| 194 |
+
predicted_diameters[eye_type] = []
|
| 195 |
+
pred_diameters_frames[eye_type] = []
|
| 196 |
+
|
| 197 |
+
video_placeholders = {}
|
| 198 |
+
|
| 199 |
+
if output_path:
|
| 200 |
+
video_cols = cols[1].columns(len(input_frames.keys()))
|
| 201 |
+
|
| 202 |
+
for i, eye_type in enumerate(list(input_frames.keys())):
|
| 203 |
+
video_placeholders[eye_type] = video_cols[i].empty()
|
| 204 |
+
|
| 205 |
+
return (
|
| 206 |
+
selected_eyes,
|
| 207 |
+
input_frames,
|
| 208 |
+
output_frames,
|
| 209 |
+
predicted_diameters,
|
| 210 |
+
pred_diameters_frames,
|
| 211 |
+
video_placeholders,
|
| 212 |
+
left_pupil_model,
|
| 213 |
+
left_pupil_cam_extractor,
|
| 214 |
+
right_pupil_model,
|
| 215 |
+
right_pupil_cam_extractor,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def process_frames(
|
| 220 |
+
cols, input_imgs, tv_model, pupil_selection, cam_method, output_path=None, codec=None, blink_detection=False
|
| 221 |
+
):
|
| 222 |
+
|
| 223 |
+
config_file = get_configs(blink_detection)
|
| 224 |
+
|
| 225 |
+
face_frames = []
|
| 226 |
+
|
| 227 |
+
(
|
| 228 |
+
selected_eyes,
|
| 229 |
+
input_frames,
|
| 230 |
+
output_frames,
|
| 231 |
+
predicted_diameters,
|
| 232 |
+
pred_diameters_frames,
|
| 233 |
+
video_placeholders,
|
| 234 |
+
left_pupil_model,
|
| 235 |
+
left_pupil_cam_extractor,
|
| 236 |
+
right_pupil_model,
|
| 237 |
+
right_pupil_cam_extractor,
|
| 238 |
+
) = setup(cols, pupil_selection, tv_model, output_path)
|
| 239 |
+
|
| 240 |
+
ds_creation = EyeDentityDatasetCreation(
|
| 241 |
+
feature_extraction_configs=config_file["feature_extraction_configs"],
|
| 242 |
+
sr_configs=config_file["sr_configs"],
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
preprocess_steps = [
|
| 246 |
+
transforms.Resize(
|
| 247 |
+
[32, 64],
|
| 248 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 249 |
+
antialias=True,
|
| 250 |
+
),
|
| 251 |
+
transforms.ToTensor(),
|
| 252 |
+
]
|
| 253 |
+
preprocess_function = transforms.Compose(preprocess_steps)
|
| 254 |
+
|
| 255 |
+
eyes_ratios = []
|
| 256 |
+
|
| 257 |
+
for idx, input_img in enumerate(input_imgs):
|
| 258 |
+
|
| 259 |
+
img = np.array(input_img)
|
| 260 |
+
ds_results = ds_creation(img)
|
| 261 |
+
|
| 262 |
+
left_eye = None
|
| 263 |
+
right_eye = None
|
| 264 |
+
blinked = False
|
| 265 |
+
eyes_ratio = None
|
| 266 |
+
|
| 267 |
+
if ds_results is not None and "face" in ds_results:
|
| 268 |
+
face_img = to_pil_image(ds_results["face"])
|
| 269 |
+
has_face = True
|
| 270 |
+
else:
|
| 271 |
+
face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
|
| 272 |
+
has_face = False
|
| 273 |
+
face_frames.append({"has_face": has_face, "img": face_img})
|
| 274 |
+
|
| 275 |
+
if ds_results is not None and "eyes" in ds_results.keys():
|
| 276 |
+
blinked = ds_results["eyes"]["blinked"]
|
| 277 |
+
eyes_ratio = ds_results["eyes"]["eyes_ratio"]
|
| 278 |
+
if eyes_ratio is not None:
|
| 279 |
+
eyes_ratios.append(eyes_ratio)
|
| 280 |
+
if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
|
| 281 |
+
left_eye = ds_results["eyes"]["left_eye"]
|
| 282 |
+
left_eye = to_pil_image(left_eye).convert("RGB")
|
| 283 |
+
left_eye = preprocess_function(left_eye)
|
| 284 |
+
left_eye = left_eye.unsqueeze(0)
|
| 285 |
+
if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
|
| 286 |
+
right_eye = ds_results["eyes"]["right_eye"]
|
| 287 |
+
right_eye = to_pil_image(right_eye).convert("RGB")
|
| 288 |
+
right_eye = preprocess_function(right_eye)
|
| 289 |
+
right_eye = right_eye.unsqueeze(0)
|
| 290 |
+
else:
|
| 291 |
+
input_img = preprocess_function(input_img)
|
| 292 |
+
input_img = input_img.unsqueeze(0)
|
| 293 |
+
if pupil_selection == "left_pupil":
|
| 294 |
+
left_eye = input_img
|
| 295 |
+
elif pupil_selection == "right_pupil":
|
| 296 |
+
right_eye = input_img
|
| 297 |
+
else:
|
| 298 |
+
left_eye = input_img
|
| 299 |
+
right_eye = input_img
|
| 300 |
+
|
| 301 |
+
for i, eye_type in enumerate(selected_eyes):
|
| 302 |
+
|
| 303 |
+
if blinked:
|
| 304 |
+
if left_eye is not None and eye_type == "left_eye":
|
| 305 |
+
_, height, width = left_eye.squeeze(0).shape
|
| 306 |
+
input_image_pil = to_pil_image(left_eye.squeeze(0))
|
| 307 |
+
elif right_eye is not None and eye_type == "right_eye":
|
| 308 |
+
_, height, width = right_eye.squeeze(0).shape
|
| 309 |
+
input_image_pil = to_pil_image(right_eye.squeeze(0))
|
| 310 |
+
|
| 311 |
+
input_img_np = np.array(input_image_pil)
|
| 312 |
+
zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
|
| 313 |
+
output_img_np = overlay_text_on_frame(np.array(zeros_img), "blink")
|
| 314 |
+
predicted_diameter = "blink"
|
| 315 |
+
else:
|
| 316 |
+
if left_eye is not None and eye_type == "left_eye":
|
| 317 |
+
if left_pupil_cam_extractor is None:
|
| 318 |
+
if tv_model == "ResNet18":
|
| 319 |
+
target_layer = left_pupil_model.resnet.layer4[-1].conv2
|
| 320 |
+
elif tv_model == "ResNet50":
|
| 321 |
+
target_layer = left_pupil_model.resnet.layer4[-1].conv3
|
| 322 |
                    else:
                        raise Exception(f"No target layer available for selected model: {tv_model}")
                    left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                        left_pupil_model,
                        target_layer=target_layer,
                        fc_layer=left_pupil_model.resnet.fc,
                        input_shape=left_eye.shape,
                    )
                output = left_pupil_model(left_eye)
                predicted_diameter = output[0].item()
                act_maps = left_pupil_cam_extractor(0, output)
                activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
                input_image_pil = to_pil_image(left_eye.squeeze(0))
            elif right_eye is not None and eye_type == "right_eye":
                if right_pupil_cam_extractor is None:
                    if tv_model == "ResNet18":
                        target_layer = right_pupil_model.resnet.layer4[-1].conv2
                    elif tv_model == "ResNet50":
                        target_layer = right_pupil_model.resnet.layer4[-1].conv3
                    else:
                        raise Exception(f"No target layer available for selected model: {tv_model}")
                    right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
                        right_pupil_model,
                        target_layer=target_layer,
                        fc_layer=right_pupil_model.resnet.fc,
                        input_shape=right_eye.shape,
                    )
                output = right_pupil_model(right_eye)
                predicted_diameter = output[0].item()
                act_maps = right_pupil_cam_extractor(0, output)
                activation_map = (
                    act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
                )
                input_image_pil = to_pil_image(right_eye.squeeze(0))

            # Create CAM overlay
            activation_map_pil = to_pil_image(activation_map, mode="F")
            result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
            input_img_np = np.array(input_image_pil)
            output_img_np = np.array(result)

            # Add frame and predicted diameter to lists
            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)

            if output_path:
                height, width, _ = output_img_np.shape
                frame = np.zeros((height, width, 3), dtype=np.uint8)
                if not isinstance(predicted_diameter, str):
                    text = f"{predicted_diameter:.2f}"
                else:
                    text = predicted_diameter
                frame = overlay_text_on_frame(frame, text)
                pred_diameters_frames[eye_type].append(frame)

                combined_frame = np.vstack((input_img_np, output_img_np, frame))

                img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
                image_html = f'<div style="width: {str(50*len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
                video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)

                # video_placeholders[eye_type].image(combined_frame, use_column_width=True)

        st.session_state.current_frame = idx + 1
        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)

    if output_path:
        combine_and_show_frames(
            input_frames, output_frames, pred_diameters_frames, output_path, codec, video_placeholders
        )

    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios

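# For reference, the torchcam pattern used in process_frames above is: construct an
# extractor once per model (CAM-style methods additionally take the fc_layer), run a
# forward pass, then query the extractor with the output index (0, since the regressor
# has a single output). A minimal standalone sketch with an untrained torchvision
# ResNet18 -- illustrative only, not the app's own wiring:
#
#   import torch
#   from torchvision.models import resnet18
#   from torchcam.methods import GradCAMpp
#   from torchcam.utils import overlay_mask
#   from torchvision.transforms.functional import to_pil_image
#
#   model = resnet18(weights=None).eval()
#   cam_extractor = GradCAMpp(model, target_layer=model.layer4[-1].conv2)
#   x = torch.rand(1, 3, 32, 64)  # dummy eye crop
#   out = model(x)
#   act_maps = cam_extractor(0, out)
#   overlay = overlay_mask(to_pil_image(x.squeeze(0)), to_pil_image(act_maps[0], mode="F"), alpha=0.5)
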
# Function to display video with autoplay and loop
def display_video_with_autoplay(video_col, video_path, width):
    # NOTE: despite the name, `video_path` holds base64-encoded video bytes, not a file path
    video_html = f"""
        <video width="{str(width)}%" height="auto" autoplay loop muted>
            <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
        </video>
    """
    video_col.markdown(video_html, unsafe_allow_html=True)


def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method, blink_detection=False):

    resized_frames = []
    for frame in video_frames:
        input_img = resize_frame(frame, max_width=640, max_height=480)
        resized_frames.append(input_img)

    file_format = output_path.split(".")[-1]
    codec, extension = get_codec_and_extension(file_format)

    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols, resized_frames, tv_model, pupil_selection, cam_method, output_path, codec, blink_detection
    )

    return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios


# Function to convert string values to float or None
def convert_diameter(value):
    try:
        return float(value)
    except (ValueError, TypeError):
        return None  # Return None if conversion fails


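# Behavior sketch: convert_diameter maps anything float() accepts to a float and
# everything else (e.g. "blink" markers or None) to None:
#   [convert_diameter(v) for v in ["4.21", 3.9, "blink", None]]  ->  [4.21, 3.9, None, None]
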
def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, output_path, codec, video_cols):
    # Assuming all frames have the same keys (eye types)
    eye_types = input_frames.keys()

    for eye_type in eye_types:
        in_frames = input_frames[eye_type]
        cam_out_frames = cam_frames[eye_type]
        pred_diameters_text_frames = pred_diameters_frames[eye_type]

        # Get frame properties (assuming all frames have the same dimensions)
        height, width, _ = in_frames[0].shape
        fourcc = cv2.VideoWriter_fourcc(*codec)
        fps = 10.0
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height * 3))  # Height is tripled for vertical stacking

        # Loop through each set of frames and concatenate them
        for j in range(len(in_frames)):
            input_frame = in_frames[j]
            cam_frame = cam_out_frames[j]
            pred_frame = pred_diameters_text_frames[j]

            # Convert frames from RGB to BGR for OpenCV
            input_frame_bgr = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
            cam_frame_bgr = cv2.cvtColor(cam_frame, cv2.COLOR_RGB2BGR)
            pred_frame_bgr = cv2.cvtColor(pred_frame, cv2.COLOR_RGB2BGR)

            # Concatenate frames vertically (input, cam, pred)
            combined_frame = np.vstack((input_frame_bgr, cam_frame_bgr, pred_frame_bgr))

            # Write the combined frame to the video
            out.write(combined_frame)

        # Release the video writer
        out.release()

        # Read the video and encode it in base64 for displaying
        with open(output_path, "rb") as video_file:
            video_bytes = video_file.read()
            video_base64 = base64.b64encode(video_bytes).decode("utf-8")

        # Display the combined video
        display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)

        # Clean up
        os.remove(output_path)


def set_input_image_on_ui(uploaded_file, cols):
    input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
    # NOTE: images taken with a phone camera have an EXIF orientation field that often rotates
    # images captured with the phone in a tilted position. PIL has a utility function that
    # applies this data and "uprights" the image.
    input_img = ImageOps.exif_transpose(input_img)
    input_img = resize_frame(input_img, max_width=640, max_height=480)
    cols[0].image(input_img, use_column_width=True)
    st.session_state.total_frames = 1
    return input_img


def set_input_video_on_ui(uploaded_file, cols):
    tfile = tempfile.NamedTemporaryFile(delete=False)
    try:
        tfile.write(uploaded_file.read())
    except Exception:
        tfile.write(uploaded_file)
    video_path = tfile.name
    video_frames = extract_frames(video_path)
    cols[0].video(video_path)
    st.session_state.total_frames = len(video_frames)
    return video_frames, video_path


def set_frames_processed_count_placeholder(cols):
    st.session_state.current_frame = 0
    st.session_state.frame_placeholder = cols[0].empty()
    txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
    st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)


def video_to_bytes(video_path):
    # Open the video file in binary mode and return the bytes
    with open(video_path, "rb") as video_file:
        return video_file.read()


def display_video_library(video_folder="./sample_videos"):
    # Get all video files from the folder
    video_files = [f for f in os.listdir(video_folder) if f.endswith(".webm")]

    # Store the selected video path
    selected_video_path = None

    # Calculate number of columns (adjust based on your layout preferences)
    num_columns = 3  # For a grid of 3 videos per row

    # Display videos in a grid layout with a 'Select' button for each video
    for i in range(0, len(video_files), num_columns):
        cols = st.columns(num_columns)
        for idx, video_file in enumerate(video_files[i : i + num_columns]):
            with cols[idx]:
                st.subheader(video_file.split(".")[0])  # Use the file name as the title
                video_path = os.path.join(video_folder, video_file)
                st.video(video_path)  # Show the video
                if st.button(f"Select {video_file.split('.')[0]}", key=video_file, type="primary"):
                    st.session_state.clear()
                    st.toast("Scroll Down to see the input and predictions", icon="⏬")
                    selected_video_path = video_path  # Store the path of the selected video

    return selected_video_path


def set_page_info_and_sidebar_info():

    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
    st.title("👁️ PupilSense 👁️🕵️‍♂️")
    # st.markdown("Upload your own images or video **OR** select from our sample library below")
    st.markdown(
        "<p style='font-size: 30px;'>"
        "Upload your own image 🖼️ or video 🎞️ <strong>OR</strong> select from our sample videos 📚"
        "</p>",
        unsafe_allow_html=True,
    )
    # video_path = display_video_library()
    show_demo_videos = st.sidebar.checkbox("Show Sample Videos", value=False)
    if show_demo_videos:
        video_path = display_video_library()
    else:
        video_path = None

    st.markdown("<hr id='target_element' style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)
    cols = st.columns((1, 1))
    cols[0].header("Input")
    cols[-1].header("Prediction")
    st.markdown("<hr style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)

    LABEL_MAP = ["left_pupil", "right_pupil"]
    TV_MODELS = ["ResNet18", "ResNet50"]

    if "uploader_key" not in st.session_state:
        st.session_state["uploader_key"] = 1

    st.sidebar.title("Upload Face 👨‍🦱 or Eye 👁️")
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image or Video",
        type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"],
        key=st.session_state["uploader_key"],
    )
    if uploaded_file is not None:
        st.session_state["uploaded_file"] = uploaded_file

    st.sidebar.title("Setup")
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
    )
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")

    blink_detection = st.sidebar.checkbox("Detect Blinks", value=True)

    st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)

    if "uploaded_file" not in st.session_state:
        st.session_state["uploaded_file"] = None

    if "og_video_path" not in st.session_state:
        st.session_state["og_video_path"] = None

    if uploaded_file is None and video_path is not None:
        video_bytes = video_to_bytes(video_path)
        uploaded_file = video_bytes
        st.session_state["uploaded_file"] = uploaded_file
        st.session_state["og_video_path"] = video_path
        st.session_state["uploader_key"] = 0

    return (
        cols,
        st.session_state["og_video_path"],
        st.session_state["uploaded_file"],
        pupil_selection,
        tv_model,
        blink_detection,
    )


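# Note on the uploader reset used in set_page_info_and_sidebar_info above: Streamlit
# identifies widgets by their key, so switching st.session_state["uploader_key"] from
# 1 to 0 makes the file_uploader render as a brand-new (empty) widget on the next
# rerun, clearing any previously uploaded file once a sample video has been selected.
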
def pil_image_to_base64(img):
    """Convert a PIL Image to a base64 encoded string."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str


def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
        cols,
        [input_img],
        tv_model,
        pupil_selection,
        cam_method=CAM_METHODS[-1],
        blink_detection=blink_detection,
    )
    # for ff in face_frames:
    #     if ff["has_face"]:
    #         cols[1].image(face_frames[0]["img"], use_column_width=True)

    input_frames_keys = input_frames.keys()
    video_cols = cols[1].columns(len(input_frames_keys))

    for i, eye_type in enumerate(input_frames_keys):
        # Check the pupil_selection and set the width accordingly
        if pupil_selection == "both":
            video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
        else:
            img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
            image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)

    output_frames_keys = output_frames.keys()
    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
    for i, eye_type in enumerate(output_frames_keys):
        height, width, c = output_frames[eye_type][0].shape
        frame = np.zeros((height, width, c), dtype=np.uint8)
        text = f"{predicted_diameters[eye_type][0]:.2f}"
        frame = overlay_text_on_frame(frame, text)

        if pupil_selection == "both":
            video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
            video_cols[i].image(frame, use_column_width=True)
        else:
            img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
            image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)
            img_base64 = pil_image_to_base64(Image.fromarray(frame))
            image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
            video_cols[i].markdown(image_html, unsafe_allow_html=True)

    return None


def plot_ears(eyes_ratios, eyes_df):
    eyes_df["EAR"] = eyes_ratios
    df = pd.DataFrame(eyes_ratios, columns=["EAR"])
    df["Frame"] = range(1, len(eyes_ratios) + 1)  # Create a frame column starting from 1

    # Create an Altair chart for eyes_ratios
    line_chart = (
        alt.Chart(df)
        .mark_line(color=colors[-1])  # Set color of the line
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
            tooltip=["Frame", "EAR"],
        )
        # .properties(title="Eyes Aspect Ratios (EARs)")
        # .configure_axis(grid=True)
    )
    points_chart = line_chart.mark_point(color=colors[-1], filled=True)

    # Create a horizontal rule at y=0.22
    line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")

    line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")

    # Add text annotations for the lines
    text1 = (
        alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
        .mark_text(align="left", dx=100, dy=9, color="red", size=16)
        .encode(y="y:Q", text="label:N")
    )

    text2 = (
        alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
        .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
        .encode(y="y:Q", text="label:N")
    )

    # Add gray area text for the region between the red and green lines
    gray_area_text = (
        alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
        .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
        .encode(y="y:Q", text="label:N")
    )

    # Combine all elements: line chart, points, rules, and text annotations
    final_chart = (
        line_chart.properties(title="Eyes Aspect Ratios (EARs)")
        + points_chart
        + line1
        + line2
        + text1
        + text2
        + gray_area_text
    ).interactive()

    # Configure axis properties at the chart level
    final_chart = final_chart.configure_axis(grid=True)

    # Display the Altair chart
    # st.subheader("Eyes Aspect Ratios (EARs)")
    st.altair_chart(final_chart, use_container_width=True)
    return eyes_df


def plot_individual_charts(predicted_diameters, cols):
    # Iterate through categories and assign charts to columns
    for i, (category, values) in enumerate(predicted_diameters.items()):
        with cols[i]:  # Directly use the column index
            # st.subheader(category)  # Add a subheader for the category
            if "left" in category:
                selected_color = colors[0]
            elif "right" in category:
                selected_color = colors[1]
            else:
                selected_color = colors[i]

            # Convert values to numeric, replacing non-numeric values with None
            values = [convert_diameter(value) for value in values]

            if "left" in category:
                category_name = "Left Pupil Diameter"
            else:
                category_name = "Right Pupil Diameter"

            # Create a DataFrame from the values for Altair
            df = pd.DataFrame(
                {
                    "Frame": range(1, len(values) + 1),
                    category_name: values,
                }
            )

            # Get the min and max values for y-axis limits, ignoring None
            min_value = min(filter(lambda x: x is not None, values), default=None)
            max_value = max(filter(lambda x: x is not None, values), default=None)

            # Create an Altair chart with y-axis limits
            line_chart = (
                alt.Chart(df)
                .mark_line(color=selected_color)
                .encode(
                    x=alt.X("Frame:Q", title="Frame Number"),
                    y=alt.Y(
                        f"{category_name}:Q",
                        title="Diameter",
                        scale=alt.Scale(domain=[min_value, max_value]),
                    ),
                    tooltip=[
                        "Frame",
                        alt.Tooltip(f"{category_name}:Q", title="Diameter"),
                    ],
                )
                # .properties(title=f"{category} - Predicted Diameters")
                # .configure_axis(grid=True)
            )
            points_chart = line_chart.mark_point(color=selected_color, filled=True)

            final_chart = (
                line_chart.properties(
                    title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
                )
                + points_chart
            ).interactive()

            final_chart = final_chart.configure_axis(grid=True)

            # Display the Altair chart
            st.altair_chart(final_chart, use_container_width=True)
    return df


def plot_combined_charts(predicted_diameters):
    all_min_values = []
    all_max_values = []

    # Create an empty DataFrame to store combined data for plotting
    combined_df = pd.DataFrame()

    # Iterate through categories and collect data
    for category, values in predicted_diameters.items():
        # Convert values to numeric, replacing non-numeric values with None
        values = [convert_diameter(value) for value in values]

        # Get the min and max values for y-axis limits, ignoring None
        min_value = min(filter(lambda x: x is not None, values), default=None)
        max_value = max(filter(lambda x: x is not None, values), default=None)

        all_min_values.append(min_value)
        all_max_values.append(max_value)

        category = "left_pupil" if "left" in category else "right_pupil"

        # Create a DataFrame from the values
        df = pd.DataFrame(
            {
                "Diameter": values,
                "Frame": range(1, len(values) + 1),  # Create a frame column starting from 1
                "Category": category,  # Add a column to specify the category
            }
        )

        # Append to combined DataFrame
        combined_df = pd.concat([combined_df, df], ignore_index=True)

    combined_chart = (
        alt.Chart(combined_df)
        .mark_line()
        .encode(
            x=alt.X("Frame:Q", title="Frame Number"),
            y=alt.Y(
                "Diameter:Q",
                title="Diameter",
                scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
            ),
            color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
            tooltip=["Frame", "Diameter:Q", "Category:N"],
        )
    )
    points_chart = combined_chart.mark_point(filled=True)

    final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()

    final_chart = final_chart.configure_axis(grid=True)

    # Display the combined chart
    st.altair_chart(final_chart, use_container_width=True)

    # --------------------------------------------
    # Convert to a DataFrame
    left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
    right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]

    df = pd.DataFrame(
        {
            "Frame": range(1, len(left_pupil_values) + 1),
            "Left Pupil Diameter": left_pupil_values,
            "Right Pupil Diameter": right_pupil_values,
        }
    )

    # Calculate the difference between left and right pupil diameters
    df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]

    # Determine the status of the difference
    df["Difference Status"] = df.apply(
        lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
        axis=1,
    )

    return df


def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
    output_video_path = f"{root_path}/tmp.webm"
    input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
        cols,
        video_frames,
        tv_model,
        pupil_selection,
        output_video_path,
        cam_method=CAM_METHODS[-1],
        blink_detection=blink_detection,
    )
    os.remove(video_path)

    num_columns = len(predicted_diameters)
    cols = st.columns(num_columns)

    if num_columns == 2:
        df = plot_combined_charts(predicted_diameters)
    else:
        df = plot_individual_charts(predicted_diameters, cols)

    if eyes_ratios is not None and len(eyes_ratios) > 0:
        df = plot_ears(eyes_ratios, df)

    st.dataframe(df, hide_index=True, use_container_width=True)
config.yml
ADDED
@@ -0,0 +1,51 @@
seed: 42

feature_extraction_configs:
  blink_detection: true
  upscale: 1
  extraction_library: "mediapipe"
  show_features: ['faces', 'eyes', 'blinks']

model_configs:
  models_path: "pre_trained_models"
  registered_model_names: ["ResNet18", "ResNet50"]
  labels: ["left_eye", "right_eye"]
  targets: ["left_pupil", "right_pupil"]
  num_classes: 1

xai_configs:
  attribution_methods: [
    "IntegratedGradients",
    "Saliency",
    "InputXGradient",
    "GuidedBackprop",
    "Deconvolution",
    # "GuidedGradCam",
    # "LayerGradCam",
    # "LayerGradientXActivation",
  ]
  cam_methods: [
    "CAM",
    "GradCAM",
    "GradCAMpp",
    "SmoothGradCAMpp",
    "ScoreCAM",
    "SSCAM",
    "ISCAM",
    "XGradCAM",
    "LayerCAM",
  ]

use_sr: false

upscale_configs:
  upscale: [1, 2, 3, 4]
  upscale_method_configs:
    size: [16, 32]
    antialias: true
    interpolation: ["bicubic"]

sr_methods: ["GFPGAN", "RealESRGAN", "SRResNet", "CodeFormer", "HAT"]
sr_method_configs:
  bg_upsampler_name: "realesrgan"
  prefered_net_in_upsampler: "RRDBNet"
feature_extraction/extractor_mediapipe.py
ADDED
@@ -0,0 +1,340 @@
import cv2
import torch
import warnings
import numpy as np
from PIL import Image
from math import sqrt
import mediapipe as mp
from transformers import pipeline

warnings.filterwarnings("ignore")


class ExtractorMediaPipe:

    def __init__(self, upscale=1):

        self.upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ========== Face Extraction ==========
        self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
        self.face_mesh = mp.solutions.face_mesh.FaceMesh(
            max_num_faces=1,
            static_image_mode=True,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

        # ========== Eyes Extraction ==========
        # MediaPipe Face Mesh landmark indices outlining each eye
        self.RIGHT_EYE = [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]
        self.LEFT_EYE = [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]
        # https://huggingface.co/dima806/closed_eyes_image_detection
        # https://www.kaggle.com/code/dima806/closed-eye-image-detection-vit
        self.pipe = pipeline(
            "image-classification",
            model="dima806/closed_eyes_image_detection",
            device=self.device,
        )
        self.blink_lower_thresh = 0.22
        self.blink_upper_thresh = 0.25
        self.blink_confidence = 0.50

        # ========== Iris Extraction ==========
        self.RIGHT_IRIS = [474, 475, 476, 477]
        self.LEFT_IRIS = [469, 470, 471, 472]

    def extract_face(self, image):

        tmp_image = image.copy()
        results = self.face_detector.process(tmp_image)

        if not results.detections:
            # print("No face detected")
            return None
        else:
            bboxC = results.detections[0].location_data.relative_bounding_box
            ih, iw, _ = image.shape

            # Get bounding box coordinates
            x, y, w, h = (
                int(bboxC.xmin * iw),
                int(bboxC.ymin * ih),
                int(bboxC.width * iw),
                int(bboxC.height * ih),
            )

            # Calculate the center of the bounding box
            center_x = x + w // 2
            center_y = y + h // 2

            # Calculate new bounds ensuring they fit within the image dimensions
            half_size = 128 * self.upscale
            x1 = max(center_x - half_size, 0)
            y1 = max(center_y - half_size, 0)
            x2 = min(center_x + half_size, iw)
            y2 = min(center_y + half_size, ih)

            # Adjust x1, x2, y1, and y2 to ensure the cropped region is exactly (256 * self.upscale) x (256 * self.upscale)
            if x2 - x1 < (256 * self.upscale):
                if x1 == 0:
                    x2 = min((256 * self.upscale), iw)
                elif x2 == iw:
                    x1 = max(iw - (256 * self.upscale), 0)

            if y2 - y1 < (256 * self.upscale):
                if y1 == 0:
                    y2 = min((256 * self.upscale), ih)
                elif y2 == ih:
                    y1 = max(ih - (256 * self.upscale), 0)

            cropped_face = image[y1:y2, x1:x2]

            # bicubic upsampling
            # if self.upscale != 1:
            #     cropped_face = cv2.resize(
            #         cropped_face,
            #         (256 * self.upscale, 256 * self.upscale),
            #         interpolation=cv2.INTER_CUBIC,
            #     )

            return cropped_face

    @staticmethod
    def landmarksDetection(image, results, draw=False):
        image_height, image_width = image.shape[:2]
        mesh_coordinates = [
            (int(point.x * image_width), int(point.y * image_height))
            for point in results.multi_face_landmarks[0].landmark
        ]
        if draw:
            [cv2.circle(image, i, 2, (0, 255, 0), -1) for i in mesh_coordinates]
        return mesh_coordinates

    @staticmethod
    def euclideanDistance(point, point1):
        x, y = point
        x1, y1 = point1
        distance = sqrt((x1 - x) ** 2 + (y1 - y) ** 2)
        return distance

    def blinkRatio(self, landmarks, right_indices, left_indices):

        right_eye_landmark1 = landmarks[right_indices[0]]
        right_eye_landmark2 = landmarks[right_indices[8]]

        right_eye_landmark3 = landmarks[right_indices[12]]
        right_eye_landmark4 = landmarks[right_indices[4]]

        left_eye_landmark1 = landmarks[left_indices[0]]
        left_eye_landmark2 = landmarks[left_indices[8]]

        left_eye_landmark3 = landmarks[left_indices[12]]
        left_eye_landmark4 = landmarks[left_indices[4]]

        right_eye_horizontal_distance = self.euclideanDistance(right_eye_landmark1, right_eye_landmark2)
        right_eye_vertical_distance = self.euclideanDistance(right_eye_landmark3, right_eye_landmark4)

        left_eye_vertical_distance = self.euclideanDistance(left_eye_landmark3, left_eye_landmark4)
        left_eye_horizontal_distance = self.euclideanDistance(left_eye_landmark1, left_eye_landmark2)

        right_eye_ratio = right_eye_vertical_distance / right_eye_horizontal_distance
        left_eye_ratio = left_eye_vertical_distance / left_eye_horizontal_distance

        eyes_ratio = (right_eye_ratio + left_eye_ratio) / 2

        return eyes_ratio

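    # blinkRatio above is a simplified eye-aspect ratio (EAR): per eye, the vertical
    # lid-to-lid distance divided by the horizontal corner-to-corner distance, averaged
    # over both eyes. On made-up points: corners (0, 0)-(40, 0) with lids (20, 6)-(20, -6)
    # give 12 / 40 = 0.30 (open), while lids (20, 3)-(20, -3) give 6 / 40 = 0.15, well
    # under the 0.22 definite-blink threshold used below.
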
    def extract_eyes_regions(self, image, landmarks, eye_indices):
        h, w, _ = image.shape
        points = [(int(landmarks[idx].x * w), int(landmarks[idx].y * h)) for idx in eye_indices]

        x_min = min([p[0] for p in points])
        x_max = max([p[0] for p in points])
        y_min = min([p[1] for p in points])
        y_max = max([p[1] for p in points])

        center_x = (x_min + x_max) // 2
        center_y = (y_min + y_max) // 2

        target_width = 32 * self.upscale
        target_height = 16 * self.upscale

        x1 = max(center_x - target_width // 2, 0)
        y1 = max(center_y - target_height // 2, 0)
        x2 = x1 + target_width
        y2 = y1 + target_height

        if x2 > w:
            x1 = w - target_width
            x2 = w
        if y2 > h:
            y1 = h - target_height
            y2 = h

        return image[y1:y2, x1:x2]

    def blink_detection_model(self, left_eye, right_eye):

        left_eye = cv2.cvtColor(left_eye, cv2.COLOR_RGB2GRAY)
        left_eye = Image.fromarray(left_eye)
        preds_left = self.pipe(left_eye)
        if preds_left[0]["label"] == "closeEye":
            closed_left = preds_left[0]["score"] >= self.blink_confidence
        else:
            closed_left = preds_left[1]["score"] >= self.blink_confidence

        right_eye = cv2.cvtColor(right_eye, cv2.COLOR_RGB2GRAY)
        right_eye = Image.fromarray(right_eye)
        preds_right = self.pipe(right_eye)
        if preds_right[0]["label"] == "closeEye":
            closed_right = preds_right[0]["score"] >= self.blink_confidence
        else:
            closed_right = preds_right[1]["score"] >= self.blink_confidence

        # print("preds_left = ", preds_left)
        # print("preds_right = ", preds_right)

        return closed_left or closed_right

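    # The transformers image-classification pipeline used above returns a score-sorted
    # list of {"label", "score"} dicts, e.g. roughly
    # [{"label": "closeEye", "score": 0.97}, {"label": "openEye", "score": 0.03}],
    # which is why blink_detection_model checks preds[0]["label"] and falls back to
    # preds[1] when "closeEye" is not the top prediction. (The label strings are taken
    # from the checks above; the scores here are illustrative.)
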
    def extract_eyes(self, image, blink_detection=False):

        tmp_face = image.copy()
        results = self.face_mesh.process(tmp_face)

        if results.multi_face_landmarks is None:
            return None

        face_landmarks = results.multi_face_landmarks[0].landmark

        left_eye = self.extract_eyes_regions(image, face_landmarks, self.LEFT_EYE)
        right_eye = self.extract_eyes_regions(image, face_landmarks, self.RIGHT_EYE)
        blinked = False
        eyes_ratio = None

        if blink_detection:
            mesh_coordinates = self.landmarksDetection(image, results, False)
            eyes_ratio = self.blinkRatio(mesh_coordinates, self.RIGHT_EYE, self.LEFT_EYE)
            if eyes_ratio > self.blink_lower_thresh and eyes_ratio <= self.blink_upper_thresh:
                # print(
                #     "I think person blinked. eyes_ratio = ",
                #     eyes_ratio,
                #     "Confirming with ViT model...",
                # )
                blinked = self.blink_detection_model(left_eye=left_eye, right_eye=right_eye)
                # if blinked:
                #     print("Yes, person blinked. Confirmed by model")
                # else:
                #     print("No, person didn't blink. False alarm")
            elif eyes_ratio <= self.blink_lower_thresh:
                blinked = True
                # print("Surely person blinked. eyes_ratio = ", eyes_ratio)
            else:
                blinked = False

        return {"left_eye": left_eye, "right_eye": right_eye, "blinked": blinked, "eyes_ratio": eyes_ratio}

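    # Blink decision summary for extract_eyes above: EAR <= 0.22 counts as a certain
    # blink from geometry alone; 0.22 < EAR <= 0.25 is a gray zone escalated to the
    # ViT eye-state classifier; EAR > 0.25 is treated as open eyes.
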
    @staticmethod
    def segment_iris(iris_img):

        # Convert RGB image to grayscale
        iris_img_gray = cv2.cvtColor(iris_img, cv2.COLOR_RGB2GRAY)

        # Apply Gaussian blur for denoising
        iris_img_blur = cv2.GaussianBlur(iris_img_gray, (5, 5), 0)

        # Threshold with Otsu's method (the threshold value is picked automatically)
        _, iris_img_mask = cv2.threshold(iris_img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Invert the mask and apply it to keep only the dark iris/pupil pixels
        segmented_mask = cv2.bitwise_not(iris_img_mask)
        segmented_mask = cv2.cvtColor(segmented_mask, cv2.COLOR_GRAY2RGB)
        segmented_iris = cv2.bitwise_and(iris_img, segmented_mask)

        return {
            "segmented_iris": segmented_iris,
            "segmented_mask": segmented_mask,
        }

    def extract_iris(self, image):

        ih, iw, _ = image.shape
        tmp_face = image.copy()
        results = self.face_mesh.process(tmp_face)

        if results.multi_face_landmarks is None:
            return None

        mesh_coordinates = self.landmarksDetection(image, results, False)
        mesh_points = np.array(mesh_coordinates)

        (l_cx, l_cy), l_radius = cv2.minEnclosingCircle(mesh_points[self.LEFT_IRIS])
        (r_cx, r_cy), r_radius = cv2.minEnclosingCircle(mesh_points[self.RIGHT_IRIS])

        # Crop the left iris to be exactly (16 * upscale) x (16 * upscale)
        l_x1 = max(int(l_cx) - (8 * self.upscale), 0)
        l_y1 = max(int(l_cy) - (8 * self.upscale), 0)
        l_x2 = min(int(l_cx) + (8 * self.upscale), iw)
        l_y2 = min(int(l_cy) + (8 * self.upscale), ih)

        cropped_left_iris = image[l_y1:l_y2, l_x1:l_x2]

        left_iris_segmented_data = self.segment_iris(cv2.cvtColor(cropped_left_iris, cv2.COLOR_BGR2RGB))

        # Crop the right iris to be exactly (16 * upscale) x (16 * upscale)
        r_x1 = max(int(r_cx) - (8 * self.upscale), 0)
        r_y1 = max(int(r_cy) - (8 * self.upscale), 0)
        r_x2 = min(int(r_cx) + (8 * self.upscale), iw)
        r_y2 = min(int(r_cy) + (8 * self.upscale), ih)

        cropped_right_iris = image[r_y1:r_y2, r_x1:r_x2]

        right_iris_segmented_data = self.segment_iris(cv2.cvtColor(cropped_right_iris, cv2.COLOR_BGR2RGB))

        return {
            "left_iris": {
                "img": cropped_left_iris,
                "segmented_iris": left_iris_segmented_data["segmented_iris"],
                "segmented_mask": left_iris_segmented_data["segmented_mask"],
            },
            "right_iris": {
                "img": cropped_right_iris,
                "segmented_iris": right_iris_segmented_data["segmented_iris"],
                "segmented_mask": right_iris_segmented_data["segmented_mask"],
            },
        }
feature_extraction/features_extractor.py
ADDED
@@ -0,0 +1,48 @@
import os
import sys
import warnings
import os.path as osp

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)

from feature_extraction.extractor_mediapipe import ExtractorMediaPipe

warnings.filterwarnings("ignore")


class FeaturesExtractor:

    def __init__(self, extraction_library="mediapipe", blink_detection=False, upscale=1):
        self.upscale = upscale
        self.blink_detection = blink_detection
        self.extraction_library = extraction_library
        self.feature_extractor = ExtractorMediaPipe(self.upscale)

    def __call__(self, image):
        results = {}
        face = self.feature_extractor.extract_face(image)
        if face is None:
            # print("No face found. Skipped feature extraction!")
            return None
        else:
            results["img"] = image
            results["face"] = face
            eyes_data = self.feature_extractor.extract_eyes(image, self.blink_detection)
            if eyes_data is None:
                # print("No eyes found. Skipped feature extraction!")
                return results
            else:
                results["eyes"] = eyes_data
                if eyes_data["blinked"]:
                    # print("Found blinked eyes!")
                    return results
                else:
                    iris_data = self.feature_extractor.extract_iris(image)
                    if iris_data is None:
                        # print("No iris found. Skipped feature extraction!")
                        return results
                    else:
                        results["iris"] = iris_data
                        return results
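
A minimal usage sketch for FeaturesExtractor, assuming an RGB frame loaded with OpenCV (the image path is a placeholder):

import cv2
from feature_extraction.features_extractor import FeaturesExtractor

extractor = FeaturesExtractor(blink_detection=True, upscale=1)
frame = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical input image
features = extractor(frame)
if features is None:
    print("no face detected")
elif "eyes" in features:
    print("blinked:", features["eyes"]["blinked"], "EAR:", features["eyes"]["eyes_ratio"])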
gradio_app.py
ADDED
@@ -0,0 +1,388 @@
import sys
import os
import os.path as osp
import gradio as gr
import numpy as np
import tempfile
from PIL import Image, ImageOps
import cv2
import matplotlib.pyplot as plt
import io
import base64

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from registry_utils import import_registered_modules
from gradio_utils import (
    is_image,
    is_video,
    extract_frames,
    resize_frame,
    CAM_METHODS,
    process_frames_gradio,
)

import_registered_modules()


def process_image_gradio(image, pupil_selection, tv_model, blink_detection):
    """
    Process a single image and return results for the Gradio interface.

    Args:
        image: PIL Image or numpy array
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, diameter_text)
    """
    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Handle EXIF rotation
        image = ImageOps.exif_transpose(image)

        # Resize image
        image = resize_frame(image, max_width=640, max_height=480)

        # Process the image using the Gradio-compatible function
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=[image],
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Check if processing failed (empty results)
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = "Could not detect face/eyes in the image. Please try with a clearer image showing eyes."
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, error_msg

        # Create visualization
        results = []
        diameter_results = []

        for eye_type in input_frames.keys():
            input_img = input_frames[eye_type][-1]
            output_img = output_frames[eye_type][-1]
            diameter = predicted_diameters[eye_type][0]

            # Create side-by-side comparison
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

            ax1.imshow(input_img)
            ax1.set_title(f"Input - {eye_type.replace('_', ' ').title()}")
            ax1.axis('off')

            ax2.imshow(output_img)
            ax2.set_title(f"CAM Overlay - {eye_type.replace('_', ' ').title()}")
            ax2.axis('off')

            plt.tight_layout()

            # Convert plot to image
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
            buf.seek(0)
            plot_img = Image.open(buf)
            plt.close()

            results.append(plot_img)

            # Format diameter result
            if isinstance(diameter, str):
                diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter}")
            else:
                diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter:.2f} mm")

        # Combine results if multiple eyes
        if len(results) == 1:
            final_image = results[0]
        else:
            # Combine multiple eye results side by side
            total_width = sum(img.width for img in results)
            max_height = max(img.height for img in results)
            final_image = Image.new('RGB', (total_width, max_height), 'white')
            x_offset = 0
            for img in results:
                final_image.paste(img, (x_offset, 0))
                x_offset += img.width

        diameter_text = "\n".join(diameter_results)

        return final_image, diameter_text

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        # Create an error placeholder image
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, error_msg


def process_video_gradio(video_file, pupil_selection, tv_model, blink_detection):
    """
    Process a video file and return results for the Gradio interface.

    Args:
        video_file: file path or file object
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (results_plot, csv_data, summary_text)
    """
    try:
        # Handle video file
        if hasattr(video_file, 'name'):
            video_path = video_file.name
        else:
            video_path = video_file

        # Extract frames
        video_frames = extract_frames(video_path)

        if not video_frames:
            return None, "No frames extracted from video", "Error: Could not process video"

        # Resize frames
        resized_frames = []
        for frame in video_frames:
            if isinstance(frame, np.ndarray):
                frame = Image.fromarray(frame)
            input_img = resize_frame(frame, max_width=640, max_height=480)
            resized_frames.append(input_img)

        # Process video frames using the Gradio-compatible function
        input_frames, output_frames, predicted_diameters = process_frames_gradio(
            input_imgs=resized_frames,
            tv_model=tv_model,
            pupil_selection=pupil_selection,
            blink_detection=blink_detection,
        )

        # Check if processing failed (empty results)
        if not input_frames or not output_frames or not predicted_diameters:
            error_msg = "Could not process video. MediaPipe may have issues in this environment."
            error_img = Image.new('RGB', (400, 200), 'white')
            return error_img, "", error_msg

        # Create results visualization
        fig, axes = plt.subplots(len(predicted_diameters), 1, figsize=(12, 6 * len(predicted_diameters)))
        if len(predicted_diameters) == 1:
            axes = [axes]

        summary_stats = []

        for idx, (eye_type, diameters) in enumerate(predicted_diameters.items()):
            # Filter out non-numeric values (like "blink")
            numeric_diameters = [d for d in diameters if isinstance(d, (int, float))]
            numeric_frame_numbers = [i for i, d in enumerate(diameters) if isinstance(d, (int, float))]

            # Plot diameter over time (numeric values only, so string markers don't break the axis)
            axes[idx].plot(numeric_frame_numbers, numeric_diameters, marker='o', markersize=2)
            axes[idx].set_title(f"Pupil Diameter Over Time - {eye_type.replace('_', ' ').title()}")
            axes[idx].set_xlabel("Frame Number")
            axes[idx].set_ylabel("Diameter (mm)")
            axes[idx].grid(True, alpha=0.3)

            # Calculate statistics
            if numeric_diameters:
                mean_diameter = np.mean(numeric_diameters)
                std_diameter = np.std(numeric_diameters)
                min_diameter = np.min(numeric_diameters)
                max_diameter = np.max(numeric_diameters)

                summary_stats.append(f"{eye_type.replace('_', ' ').title()}:")
                summary_stats.append(f"  Mean: {mean_diameter:.2f} mm")
                summary_stats.append(f"  Std: {std_diameter:.2f} mm")
                summary_stats.append(f"  Min: {min_diameter:.2f} mm")
                summary_stats.append(f"  Max: {max_diameter:.2f} mm")
                summary_stats.append("")

        plt.tight_layout()

        # Convert plot to image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        plot_img = Image.open(buf)
        plt.close()

        # Create summary text
        summary_text = f"Processed {len(video_frames)} frames\n\n" + "\n".join(summary_stats)

        # Create CSV data for download
        csv_data = "Frame,Eye_Type,Diameter_mm\n"
        for eye_type, diameters in predicted_diameters.items():
            for frame_idx, diameter in enumerate(diameters):
                csv_data += f"{frame_idx},{eye_type},{diameter}\n"

        # Clean up temporary files if they exist
        # (output_video_path not used in this implementation)

        return plot_img, csv_data, summary_text

    except Exception as e:
        error_msg = f"Error processing video: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, "", error_msg


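# The csv_data string produced above is plain CSV; a caller could load it with pandas,
# coercing the occasional "blink" marker to NaN, e.g.:
#   df = pd.read_csv(io.StringIO(csv_data))
#   df["Diameter_mm"] = pd.to_numeric(df["Diameter_mm"], errors="coerce")
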
def process_media_unified(media_input, pupil_selection, tv_model, blink_detection):
    """
    Unified processing function that handles both images and videos.

    Args:
        media_input: Either an image (PIL) or a video file path
        pupil_selection: str - "left_pupil", "right_pupil", or "both"
        tv_model: str - "ResNet18" or "ResNet50"
        blink_detection: bool - whether to detect blinks

    Returns:
        tuple: (result_image, result_text)
    """
    try:
        # Check if input is an image or video
        if hasattr(media_input, 'name'):
            # It's a file object; dispatch on the file extension
            file_path = media_input.name
            if is_video(file_path):
                plot_img, csv_data, summary_text = process_video_gradio(media_input, pupil_selection, tv_model, blink_detection)
                combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
                return plot_img, combined_output
            elif is_image(file_path):
                # Load the file as a PIL Image
                image = Image.open(file_path)
                return process_image_gradio(image, pupil_selection, tv_model, blink_detection)
        else:
            # It's already a PIL Image
            return process_image_gradio(media_input, pupil_selection, tv_model, blink_detection)

    except Exception as e:
        error_msg = f"Error processing media: {str(e)}"
        error_img = Image.new('RGB', (400, 200), 'white')
        return error_img, error_msg


def create_gradio_interface():
    """Create and configure the Gradio interface with proper API support."""

    # Create a unified interface that can handle both images and videos
    with gr.Blocks(title="👁️ PupilSense 👁️🕵️‍♂️") as demo:
        gr.Markdown("# 👁️ PupilSense - Pupil Diameter Analysis")
        gr.Markdown("Upload an image or video to estimate pupil diameter using deep learning models.")

        with gr.Tab("Image Processing"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload Image")
                    image_pupil_selection = gr.Dropdown(
                        ["left_pupil", "right_pupil", "both"],
                        value="both",
                        label="Pupil Selection"
                    )
                    image_model = gr.Dropdown(
                        ["ResNet18", "ResNet50"],
                        value="ResNet18",
                        label="Model"
                    )
                    image_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                    image_submit = gr.Button("Process Image", variant="primary")

                with gr.Column():
                    image_output = gr.Image(label="Results")
                    image_text_output = gr.Textbox(label="Pupil Diameter Results", lines=5)

            image_submit.click(
                fn=process_image_simple,
                inputs=[image_input, image_pupil_selection, image_model, image_blink_detection],
                outputs=[image_output, image_text_output]
            )

        with gr.Tab("Video Processing"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    video_pupil_selection = gr.Dropdown(
                        ["left_pupil", "right_pupil", "both"],
                        value="both",
                        label="Pupil Selection"
                    )
                    video_model = gr.Dropdown(
                        ["ResNet18", "ResNet50"],
                        value="ResNet18",
                        label="Model"
                    )
                    video_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                    video_submit = gr.Button("Process Video", variant="primary")

                with gr.Column():
                    video_output = gr.Image(label="Diameter Analysis")
                    video_text_output = gr.Textbox(label="Summary Statistics", lines=10)

            video_submit.click(
                fn=process_video_simple,
                inputs=[video_input, video_pupil_selection, video_model, video_blink_detection],
                outputs=[video_output, video_text_output]
            )

        # Add a unified API endpoint that can handle both images and videos
        with gr.Tab("API Testing"):
            gr.Markdown("### API Endpoint for External Access")
            gr.Markdown("This endpoint can process both images and videos programmatically.")

            with gr.Row():
                with gr.Column():
                    api_media_input = gr.File(label="Upload Image or Video File")
                    api_pupil_selection = gr.Dropdown(
                        ["left_pupil", "right_pupil", "both"],
                        value="both",
                        label="Pupil Selection"
                    )
                    api_model = gr.Dropdown(
                        ["ResNet18", "ResNet50"],
                        value="ResNet18",
                        label="Model"
                    )
                    api_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
                    api_submit = gr.Button("Process Media", variant="primary")

                with gr.Column():
                    api_output = gr.Image(label="Results")
                    api_text_output = gr.Textbox(label="Analysis Results", lines=10)

            api_submit.click(
                fn=process_media_unified,
                inputs=[api_media_input, api_pupil_selection, api_model, api_blink_detection],
                outputs=[api_output, api_text_output]
            )

    return demo


def process_image_simple(image, pupil_selection, tv_model, blink_detection):
|
| 373 |
+
"""Simplified image processing function for gr.Interface."""
|
| 374 |
+
result_image, result_text = process_image_gradio(image, pupil_selection, tv_model, blink_detection)
|
| 375 |
+
return result_image, result_text
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def process_video_simple(video_file, pupil_selection, tv_model, blink_detection):
|
| 379 |
+
"""Simplified video processing function for gr.Interface."""
|
| 380 |
+
plot_img, csv_data, summary_text = process_video_gradio(video_file, pupil_selection, tv_model, blink_detection)
|
| 381 |
+
# Combine summary and CSV data for single output
|
| 382 |
+
combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
|
| 383 |
+
return plot_img, combined_output
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
if __name__ == "__main__":
|
| 387 |
+
demo = create_gradio_interface()
|
| 388 |
+
demo.launch()
|
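For orientation, a minimal in-process sketch of driving the unified entry point above; the weights under pre_trained_models/ are assumed present, and sample_face.jpg is a placeholder file name:

    # Hypothetical smoke test for process_media_unified (not part of the app).
    from PIL import Image

    face = Image.open("sample_face.jpg")  # placeholder input image
    result_img, result_text = process_media_unified(
        face, pupil_selection="both", tv_model="ResNet18", blink_detection=True
    )
    print(result_text)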
gradio_utils.py
ADDED
@@ -0,0 +1,300 @@
import os
import sys
import cv2
import numpy as np
import torch
import os.path as osp
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from torchcam import methods as torchcam_methods
from torchcam.utils import overlay_mask

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model

CAM_METHODS = ["CAM"]


@torch.no_grad()
def load_model(model_configs, device="cpu"):
    """Loads a pre-trained model and puts it in eval mode."""
    model_path = os.path.join(root_path, model_configs["model_path"])
    model_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=model_configs)
    model.load_state_dict(model_dict)
    model = model.to(device).eval()
    return model


def extract_frames(video_path):
    """Extracts all frames from a video file as RGB numpy arrays."""
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    success, image = vidcap.read()
    while success:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        frames.append(image_rgb)
        success, image = vidcap.read()
    vidcap.release()
    return frames


def resize_frame(frame, max_width=640, max_height=480):
    """Resizes a frame while maintaining its aspect ratio."""
    if isinstance(frame, np.ndarray):
        frame = Image.fromarray(frame)

    # Calculate the scaling factor
    width, height = frame.size
    scale_w = max_width / width
    scale_h = max_height / height
    scale = min(scale_w, scale_h)

    # Resize the frame
    new_width = int(width * scale)
    new_height = int(height * scale)
    return frame.resize((new_width, new_height), Image.Resampling.LANCZOS)


def is_image(file_extension):
    """Check whether a file extension is a supported image format."""
    return file_extension.lower() in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]


def is_video(file_extension):
    """Check whether a file extension is a supported video format."""
    return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm", "flv", "wmv"]


def get_configs(blink_detection=False):
    """Build the configuration dict for feature extraction."""
    upscale = "-"
    upscale_method_or_model = "-"
    if upscale == "-":
        sr_configs = None
    else:
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    config_file = {
        "sr_configs": sr_configs,
        "feature_extraction_configs": {
            "blink_detection": blink_detection,
            "upscale": upscale,
            "extraction_library": "mediapipe",
        },
    }
    return config_file


def setup_gradio(pupil_selection, tv_model):
    """Set up models and result containers for Gradio processing."""
    left_pupil_model = None
    left_pupil_cam_extractor = None
    right_pupil_model = None
    right_pupil_cam_extractor = None
    output_frames = {}
    input_frames = {}
    predicted_diameters = {}

    if pupil_selection == "both":
        selected_eyes = ["left_eye", "right_eye"]
    elif pupil_selection == "left_pupil":
        selected_eyes = ["left_eye"]
    elif pupil_selection == "right_pupil":
        selected_eyes = ["right_eye"]
    else:
        raise ValueError(f"Unknown pupil selection: {pupil_selection}")

    for eye_type in selected_eyes:
        model_configs = {
            "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
            "registered_model_name": tv_model,
            "num_classes": 1,
        }
        if eye_type == "left_eye":
            left_pupil_model = load_model(model_configs)
            left_pupil_cam_extractor = None
        else:
            right_pupil_model = load_model(model_configs)
            right_pupil_cam_extractor = None

        output_frames[eye_type] = []
        input_frames[eye_type] = []
        predicted_diameters[eye_type] = []

    return (
        selected_eyes,
        input_frames,
        output_frames,
        predicted_diameters,
        left_pupil_model,
        left_pupil_cam_extractor,
        right_pupil_model,
        right_pupil_cam_extractor,
    )


def process_frames_gradio(input_imgs, tv_model, pupil_selection, blink_detection=False):
    """Process frames without Streamlit dependencies."""
    try:
        config_file = get_configs(blink_detection)

        (
            selected_eyes,
            input_frames,
            output_frames,
            predicted_diameters,
            left_pupil_model,
            left_pupil_cam_extractor,
            right_pupil_model,
            right_pupil_cam_extractor,
        ) = setup_gradio(pupil_selection, tv_model)

        ds_creation = EyeDentityDatasetCreation(
            feature_extraction_configs=config_file["feature_extraction_configs"],
            sr_configs=config_file["sr_configs"],
        )
    except Exception as e:
        print(f"Error in setup: {e}")
        # Return empty results if setup fails
        return {}, {}, {}

    preprocess_steps = [
        transforms.Resize(
            [32, 64],
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        transforms.ToTensor(),
    ]
    preprocess_function = transforms.Compose(preprocess_steps)

    for idx, input_img in enumerate(input_imgs):
        try:
            img = np.array(input_img)
            ds_results = ds_creation(img)
        except Exception as e:
            print(f"Error in MediaPipe processing for frame {idx}: {e}")
            ds_results = None

        left_eye = None
        right_eye = None
        blinked = False

        has_face = ds_results is not None and "face" in ds_results

        if has_face:
            if blink_detection and "blinks" in ds_results:
                blinked = ds_results["blinks"]["blinked"]

            if not blinked and "eyes" in ds_results:
                if "left_eye" in ds_results["eyes"] and ds_results["eyes"]["left_eye"] is not None:
                    left_eye_img = to_pil_image(ds_results["eyes"]["left_eye"])
                    input_img_tensor = preprocess_function(left_eye_img)
                    input_img_tensor = input_img_tensor.unsqueeze(0)
                    if pupil_selection in ["left_pupil", "both"]:
                        left_eye = input_img_tensor

                if "right_eye" in ds_results["eyes"] and ds_results["eyes"]["right_eye"] is not None:
                    right_eye_img = to_pil_image(ds_results["eyes"]["right_eye"])
                    input_img_tensor = preprocess_function(right_eye_img)
                    input_img_tensor = input_img_tensor.unsqueeze(0)
                    if pupil_selection in ["right_pupil", "both"]:
                        right_eye = input_img_tensor

        for eye_type in selected_eyes:
            if blinked:
                if left_eye is not None and eye_type == "left_eye":
                    _, height, width = left_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    _, height, width = right_eye.squeeze(0).shape
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                else:
                    # Create a default black image if no eye was detected
                    input_image_pil = Image.new('RGB', (64, 32), 'black')
                    height, width = 32, 64

                input_img_np = np.array(input_image_pil)
                zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
                output_img_np = np.array(zeros_img)
                predicted_diameter = "blink"
            else:
                if left_eye is not None and eye_type == "left_eye":
                    if left_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = left_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        left_pupil_cam_extractor = torchcam_methods.__dict__["CAM"](
                            left_pupil_model,
                            target_layer=target_layer,
                            fc_layer=left_pupil_model.resnet.fc,
                            input_shape=left_eye.shape,
                        )
                    output = left_pupil_model(left_eye)
                    predicted_diameter = output[0].item()
                    act_maps = left_pupil_cam_extractor(0, output)
                    activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
                    input_image_pil = to_pil_image(left_eye.squeeze(0))
                elif right_eye is not None and eye_type == "right_eye":
                    if right_pupil_cam_extractor is None:
                        if tv_model == "ResNet18":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv2
                        elif tv_model == "ResNet50":
                            target_layer = right_pupil_model.resnet.layer4[-1].conv3
                        else:
                            raise Exception(f"No target layer available for selected model: {tv_model}")
                        right_pupil_cam_extractor = torchcam_methods.__dict__["CAM"](
                            right_pupil_model,
                            target_layer=target_layer,
                            fc_layer=right_pupil_model.resnet.fc,
                            input_shape=right_eye.shape,
                        )
                    output = right_pupil_model(right_eye)
                    predicted_diameter = output[0].item()
                    act_maps = right_pupil_cam_extractor(0, output)
                    activation_map = (
                        act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
                    )
                    input_image_pil = to_pil_image(right_eye.squeeze(0))
                else:
                    # No eye detected: record placeholders and move on
                    input_image_pil = Image.new('RGB', (64, 32), 'black')
                    predicted_diameter = "no_eye_detected"
                    output_img_np = np.array(input_image_pil)
                    input_frames[eye_type].append(np.array(input_image_pil))
                    output_frames[eye_type].append(output_img_np)
                    predicted_diameters[eye_type].append(predicted_diameter)
                    continue

                # Create the CAM overlay on the input eye crop
                activation_map_pil = to_pil_image(activation_map, mode="F")
                result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
                input_img_np = np.array(input_image_pil)
                output_img_np = np.array(result)

            input_frames[eye_type].append(input_img_np)
            output_frames[eye_type].append(output_img_np)
            predicted_diameters[eye_type].append(predicted_diameter)

    return input_frames, output_frames, predicted_diameters
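A short sketch of how these helpers chain together, assuming one of the bundled sample videos as input (illustrative only; running it requires the MediaPipe pipeline and the model weights):

    # Hypothetical driver: extract frames, downscale them, then run both pupil models.
    frames = extract_frames("sample_videos/Focus Pocus.webm")
    frames = [resize_frame(f) for f in frames]
    inputs, overlays, diameters = process_frames_gradio(
        frames, tv_model="ResNet18", pupil_selection="both", blink_detection=True
    )
    print({eye: len(vals) for eye, vals in diameters.items()})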
pre_trained_models/ResNet18/left_eye.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98fb2c7880165c59ff975e02cd9e614fcf3a5859455f8d85695f57497dd894e6
size 46843194

pre_trained_models/ResNet18/right_eye.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:68e2928f13900580bcb9b7c1a1f6d4bba863cfcfee2def944b49ef0c09337668
size 46843194

pre_trained_models/ResNet50/left_eye.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5bd4bac728b71dae9e759b86188206a4f38fbc83b9507dd08f2a6abe1568d995
size 102554624

pre_trained_models/ResNet50/right_eye.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5179f569ea1886c9ad63ca9d047fdf721a9b59a63313cd9da3f2e3fae25de73
size 102554624
preprocessing/dataset_creation.py
ADDED
@@ -0,0 +1,26 @@
import sys
import os.path as osp

root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)

from feature_extraction.features_extractor import FeaturesExtractor


class EyeDentityDatasetCreation:

    def __init__(self, feature_extraction_configs, sr_configs=None):
        self.extraction_library = feature_extraction_configs["extraction_library"]
        self.upscale = 1

        self.blink_detection = feature_extraction_configs["blink_detection"]
        self.features_extractor = FeaturesExtractor(
            extraction_library=self.extraction_library,
            blink_detection=self.blink_detection,
            upscale=self.upscale,
        )

    def __call__(self, img):
        result_dict = self.features_extractor(img)
        return result_dict
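A usage sketch mirroring the config shape produced by get_configs() in gradio_utils.py; the all-black frame is just a placeholder:

    import numpy as np

    ds = EyeDentityDatasetCreation(
        feature_extraction_configs={
            "blink_detection": True,
            "upscale": "-",
            "extraction_library": "mediapipe",
        }
    )
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame
    result = ds(frame)  # dict with "face"/"eyes" (and "blinks") keys when a face is found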
preprocessing/dataset_creation_utils.py
ADDED
@@ -0,0 +1,14 @@
import os
import torch
import random
import numpy as np


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
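A quick reproducibility check for seed_everything (illustrative only; re-seeding should reproduce the same draws):

    import torch

    seed_everything(42)
    a = torch.rand(3)
    seed_everything(42)
    b = torch.rand(3)
    assert torch.equal(a, b)  # identical draws after re-seeding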
registrations/models.py
ADDED
@@ -0,0 +1,56 @@
import sys
import torch.nn as nn
import os.path as osp
from torchvision import models
import torch.nn.functional as F
from registry import MODEL_REGISTRY

root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)

# ============================= ResNets =============================


@MODEL_REGISTRY.register()
class ResNet18(nn.Module):
    def __init__(self, model_args):
        super(ResNet18, self).__init__()
        self.num_classes = model_args.get("num_classes", 1)
        self.resnet = models.resnet18(weights=None)
        self.regression_head = nn.Linear(1000, self.num_classes)

    def forward(self, x, masks=None):
        # Calculate the padding dynamically based on the input size
        height, width = x.shape[2], x.shape[3]
        pad_height = max(0, (224 - height) // 2)
        pad_width = max(0, (224 - width) // 2)

        # Apply padding
        x = F.pad(x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0)
        x = self.resnet(x)
        x = self.regression_head(x)
        return x


@MODEL_REGISTRY.register()
class ResNet50(nn.Module):
    def __init__(self, model_args):
        super(ResNet50, self).__init__()
        self.num_classes = model_args.get("num_classes", 1)
        self.resnet = models.resnet50(weights=None)
        self.regression_head = nn.Linear(1000, self.num_classes)

    def forward(self, x, masks=None):
        # Calculate the padding dynamically based on the input size
        height, width = x.shape[2], x.shape[3]
        pad_height = max(0, (224 - height) // 2)
        pad_width = max(0, (224 - width) // 2)

        # Apply padding
        x = F.pad(x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0)
        x = self.resnet(x)
        x = self.regression_head(x)
        return x


# print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
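Because forward() zero-pads any input smaller than 224x224, the raw 32x64 eye crops can be fed in directly; a sketch with random data:

    import torch

    model = ResNet18({"num_classes": 1}).eval()
    eye_crop = torch.rand(1, 3, 32, 64)  # one preprocessed eye crop
    with torch.no_grad():
        diameter = model(eye_crop)
    print(diameter.shape)  # torch.Size([1, 1])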
registry.py
ADDED
@@ -0,0 +1,82 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501


class Registry:
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        assert name not in self._obj_map, (
            f"An object named '{name}' was already registered "
            f"in '{self._name}' registry!"
        )
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used either as a decorator or as a plain function call.
        See the class docstring for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name):
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                f"No object named '{name}' found in '{self._name}' registry!"
            )
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


MODEL_REGISTRY = Registry("model")
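A usage sketch for the registry (TinyModel is a throwaway example class, not part of the app):

    @MODEL_REGISTRY.register()
    class TinyModel:
        pass

    assert "TinyModel" in MODEL_REGISTRY
    cls = MODEL_REGISTRY.get("TinyModel")
    print(list(MODEL_REGISTRY.keys()))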
registry_utils.py
ADDED
@@ -0,0 +1,79 @@
import os
import importlib
from os import path as osp


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith(".") and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None:
                    yield return_path
                elif return_path.endswith(suffix):
                    yield return_path
            else:
                if recursive:
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
                else:
                    continue

    return _scandir(dir_path, suffix=suffix, recursive=recursive)


def import_registered_modules(registration_folder="registrations"):
    """
    Import all registered modules from the specified folder.

    This function automatically scans all the files under the specified
    folder and imports all the modules required for registry.

    Parameters:
        registration_folder (str, optional): Path to the folder containing
            registration modules. Default is "registrations".

    Returns:
        list: List of imported modules.
    """
    registration_modules_folder = (
        osp.dirname(osp.abspath(__file__)) + f"/{registration_folder}"
    )

    registration_modules_file_names = [
        osp.splitext(osp.basename(v))[0]
        for v in scandir(dir_path=registration_modules_folder)
    ]

    imported_modules = [
        importlib.import_module(f"{registration_folder}.{file_name}")
        for file_name in registration_modules_file_names
    ]
    return imported_modules
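A sketch of the intended bootstrap: importing everything under registrations/ causes the @MODEL_REGISTRY.register() decorators to run, after which the models are resolvable by name (whether the app calls this explicitly is an assumption here):

    from registry_utils import import_registered_modules
    from registry import MODEL_REGISTRY

    import_registered_modules()  # imports registrations/models.py, registering ResNet18/ResNet50
    print(list(MODEL_REGISTRY.keys()))  # expected: ['ResNet18', 'ResNet50']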
requirements.txt
ADDED
@@ -0,0 +1,28 @@
# https://huggingface.co/docs/hub/en/spaces-dependencies
tqdm
PyYAML
numpy
pandas
matplotlib
seaborn
mlflow
pillow
scikit_learn
torch
# captum
evaluate
# basicsr
facexlib
# realesrgan
opencv_python
cmake
# dlib
einops
transformers
# gfpgan
gradio==4.36.1
mediapipe
imutils
scipy
torchvision
torchcam
sample_videos/All Smiles Ahead.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bcb016898fac06a517f067ccbe1e6a32366e984aa9b07a7920a8bc9fdd780d17
size 951586

sample_videos/And it was all Yellow.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfd4d681a4a289aa24fc2193fbcbf855e69c6ce19b9619d8d88b7d782c6047dc
size 956742

sample_videos/Blink It Like Brian.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:590a74b0f13415bb7817c7e28f3f3348bb38aa1f517b8e311f8041c978b4c38b
size 961928

sample_videos/Focus Pocus.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf88c7cff8ef627c03c589ca5feb9c83e95db7e7b55294cbcbafefdfe31cdcf6
size 971924

sample_videos/Funny Talks.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bb22539b023294865bc0df2c6c6622ed94d3d5d27df6d0a22ae8fb193c2d6910
size 963970

sample_videos/I like to move it move it.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a4d66abf1707607f826b4b65e20f920d90281e6ab0df5757b57da7216f424b3
size 958302

sample_videos/Infinite Blue.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6047ecda02535212c55714b727c32777a4891be91f642847c586bb24bd3d00d
size 960117

sample_videos/Red Ross.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1078b3fda618c4f5adefe05d389c8360d2a95af1080be7f699a0d72ca454bb3f
size 960661

sample_videos/Smile, You’re on Camera!.webm
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8757fa6781889323f9d577862dcc98f05abd295be6de8e8843a3eb1cd406fdd
size 965710
utils.py
ADDED
@@ -0,0 +1,11 @@
from registry import MODEL_REGISTRY


def get_model(model_configs):
    """Look up the registered model class by name and instantiate it.

    Note: pops "registered_model_name" from the dict passed in.
    """
    registered_model = MODEL_REGISTRY.get(model_configs["registered_model_name"])
    model_configs.pop("registered_model_name")
    if len(model_configs) > 0:
        model = registered_model(model_configs)
    else:
        model = registered_model()
    return model
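A usage sketch, assuming the registration modules have already been imported (see registry_utils.py); since get_model mutates its argument, a throwaway copy is passed:

    configs = {"registered_model_name": "ResNet18", "num_classes": 1}
    model = get_model(dict(configs))  # copy, because get_model pops from the dict
    print(type(model).__name__)  # ResNet18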