txarst committed on
Commit f0e5caa · 0 Parent(s):

model upload

.gitattributes ADDED
@@ -0,0 +1,2 @@
*.pt filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,407 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the “Licensor.” The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
README.md ADDED
@@ -0,0 +1,168 @@
---
title: PupilSense
emoji: 👁️
colorFrom: red
colorTo: pink
sdk: gradio
sdk_version: 4.36.1
app_file: app.py
pinned: false
---

# 👁️ PupilSense 👁️🕵️‍♂️

PupilSense is a deep learning-powered application for estimating pupil diameter from images and videos. It uses trained ResNet models with Class Activation Mapping (CAM) for interpretable predictions.

## Features

- **Image Processing**: Upload images to get instant pupil diameter estimates
- **Video Processing**: Analyze videos frame-by-frame for temporal pupil diameter analysis
- **Model Selection**: Choose between ResNet18 and ResNet50 architectures
- **Pupil Selection**: Analyze the left pupil, the right pupil, or both
- **Blink Detection**: Automatically detect and handle blinks in the analysis
- **CAM Visualization**: See which parts of the eye the model focuses on for its predictions
- **API Access**: Full Gradio API support for programmatic access

## Usage

### Web Interface
Simply upload an image or video file and configure your analysis parameters:
- Select the pupil(s) to analyze (left, right, or both)
- Choose the model architecture (ResNet18 or ResNet50)
- Enable or disable blink detection
- Click process to get results

### API Access
The Gradio interface provides automatic API endpoints. You can access the API documentation at `/docs` when the app is running.

Example API usage:
```python
import requests

# For image processing
files = {"image_input": open("your_image.jpg", "rb")}
data = {
    "pupil_selection": "both",
    "tv_model": "ResNet18",
    "blink_detection": True,
}
response = requests.post("https://your-space-url/api/predict", files=files, data=data)
```
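
If the Space is public, it may also be convenient to call the endpoint through the `gradio_client` package. The snippet below is only a sketch: the Space id, the `api_name`, and the input names are assumptions and should be checked against the "Use via API" page of the running app.

```python
# Hedged sketch using gradio_client (requires a recent gradio_client release).
# The Space id, api_name, and parameter names below are assumptions, not taken
# from this repository.
from gradio_client import Client, handle_file

client = Client("your-username/your-space")        # hypothetical Space id
result = client.predict(
    image_input=handle_file("your_image.jpg"),     # assumed input name
    pupil_selection="both",
    tv_model="ResNet18",
    blink_detection=True,
    api_name="/predict",                           # assumed endpoint name
)
print(result)
```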

## Model Information

The application uses pre-trained ResNet models specifically trained for pupil diameter estimation:
- **ResNet18**: Faster inference, good accuracy
- **ResNet50**: Higher accuracy, slower inference

Both models support:
- Input resolution: 32x64 pixels (eye region)
- Output: Pupil diameter in millimeters
- CAM visualization for model interpretability

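As a rough illustration of this input/output contract, the sketch below pushes a single eye crop through a ResNet-style regressor. It uses a plain torchvision ResNet18 with a one-unit head as a stand-in; the actual model class is built by `utils.get_model` in this repository, so treat the architecture details here as assumptions.

```python
# Minimal inference sketch (stand-in model, not the repository's exact class).
import torch
from torchvision import models, transforms
from PIL import Image

model = models.resnet18(weights=None)                 # hypothetical stand-in backbone
model.fc = torch.nn.Linear(model.fc.in_features, 1)   # single regression output (mm)
model.eval()

preprocess = transforms.Compose([
    # Eye crops are resized to 32x64 before inference, as in app_utils.py.
    transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC,
                      antialias=True),
    transforms.ToTensor(),
])

eye_crop = Image.open("left_eye.jpg").convert("RGB")   # hypothetical eye-region image
with torch.no_grad():
    diameter_mm = model(preprocess(eye_crop).unsqueeze(0))[0].item()
print(f"Predicted pupil diameter: {diameter_mm:.2f} mm")
```
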
## Technical Details

- **Face Detection**: MediaPipe for robust face and eye detection
- **Preprocessing**: Automatic eye region extraction and normalization
- **Deep Learning**: PyTorch-based ResNet models
- **Visualization**: Matplotlib for result plotting and CAM overlays
- **Video Support**: Frame-by-frame analysis with temporal plotting

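The CAM overlays are produced with `torchcam`. The sketch below mirrors the calls made in `app_utils.py`, but substitutes a plain torchvision ResNet18 and a random tensor for the real model and eye crop; in the repository the backbone is reached as `model.resnet`, and the target layer is `layer4[-1].conv2` for ResNet18 or `layer4[-1].conv3` for ResNet50.

```python
# CAM overlay sketch following app_utils.py; the stand-in model and input are assumptions.
import torch
from torchvision import models
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import CAM
from torchcam.utils import overlay_mask

model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 1)
model.eval()

eye = torch.rand(1, 3, 32, 64)                      # placeholder for a preprocessed eye crop

cam_extractor = CAM(model, target_layer=model.layer4[-1].conv2,
                    fc_layer=model.fc, input_shape=eye.shape)
output = model(eye)                                 # forward pass; scalar diameter prediction
activation_map = cam_extractor(0, output)[0]        # CAM for the single output unit
overlay = overlay_mask(to_pil_image(eye.squeeze(0)),
                       to_pil_image(activation_map, mode="F"), alpha=0.5)
overlay.save("cam_overlay.png")
```
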
## Installation & Setup

### Local Development

1. **Clone the repository**
   ```bash
   git clone <repository-url>
   cd pupilsense
   ```

2. **Create a virtual environment**
   ```bash
   python3 -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   ```

3. **Install dependencies**
   ```bash
   pip install -r requirements.txt
   ```

4. **Run the application**
   ```bash
   python app.py
   ```

The app will be available at `http://localhost:7860`.

### Hugging Face Spaces Deployment

1. **Create a new Space** on Hugging Face with the Gradio SDK
2. **Upload all files** from the pupilsense directory
3. **Ensure the following files are present:**
   - `app.py` (main application file)
   - `gradio_app.py` (Gradio interface)
   - `gradio_utils.py` (utility functions)
   - `requirements.txt` (dependencies)
   - `README.md` (this file, with the YAML header above)
   - `pre_trained_models/` (model files)
   - All other supporting files

## Known Issues & Troubleshooting

### MediaPipe Issues
- **Issue**: Segmentation fault or MediaPipe errors in headless environments
- **Solution**: The app includes error handling for MediaPipe failures. In production environments, ensure proper GPU/display drivers are available.

### Model Loading
- **Issue**: Model files not found
- **Solution**: Ensure the `pre_trained_models/` directory contains the required `.pt` files for both the ResNet18 and ResNet50 models.

### Memory Usage
- **Issue**: High memory usage with large videos
- **Solution**: The app automatically resizes frames to at most 640x480 to manage memory usage.

## File Structure

```
pupilsense/
├── app.py                 # Main application entry point
├── gradio_app.py          # Gradio interface definition
├── gradio_utils.py        # Utility functions (MediaPipe-free)
├── app_utils.py           # Original Streamlit utilities (legacy)
├── requirements.txt       # Python dependencies
├── README.md              # This file
├── config.yml             # Configuration file
├── registry.py            # Model registry
├── registry_utils.py      # Registry utilities
├── utils.py               # General utilities
├── pre_trained_models/    # Trained model files
│   ├── ResNet18/
│   │   ├── left_eye.pt
│   │   └── right_eye.pt
│   └── ResNet50/
│       ├── left_eye.pt
│       └── right_eye.pt
├── preprocessing/         # Data preprocessing modules
├── feature_extraction/    # Feature extraction modules
├── registrations/         # Model registration modules
└── sample_videos/         # Sample video files
```

## Contributing

1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Test thoroughly
5. Submit a pull request

## License

See the LICENSE file for details.

---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,40 @@
import sys
import os.path as osp

root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from gradio_app import create_gradio_interface

def main():
    """Main function to launch the Gradio interface."""
    try:
        demo = create_gradio_interface()

        # For Hugging Face Spaces deployment
        import os
        if os.getenv("SPACE_ID") or os.getenv("SYSTEM") == "spaces":
            # Running on Hugging Face Spaces
            demo.launch(share=True)
        else:
            # Running locally
            try:
                demo.launch(
                    server_name="0.0.0.0",
                    server_port=7860,
                    share=False
                )
            except ValueError as e:
                if "shareable link must be created" in str(e):
                    print("Localhost not accessible, creating shareable link...")
                    demo.launch(share=True)
                else:
                    raise e
    except Exception as e:
        print(f"Error launching app: {e}")
        import traceback
        traceback.print_exc()
        raise e

if __name__ == "__main__":
    main()
app_utils.py ADDED
@@ -0,0 +1,906 @@
1
+ import base64
2
+ from io import BytesIO
3
+ import io
4
+ import os
5
+ import sys
6
+ import cv2
7
+ from matplotlib import pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ import streamlit as st
11
+ import torch
12
+ import tempfile
13
+ from PIL import Image
14
+ from torchvision.transforms.functional import to_pil_image
15
+ from torchvision import transforms
16
+ from PIL import ImageOps
17
+ import altair as alt
18
+ import streamlit.components.v1 as components
19
+
20
+ from torchcam.methods import CAM
21
+ from torchcam import methods as torchcam_methods
22
+ from torchcam.utils import overlay_mask
23
+ import os.path as osp
24
+
25
+ root_path = osp.abspath(osp.join(__file__, osp.pardir))
26
+ sys.path.append(root_path)
27
+
28
+ from preprocessing.dataset_creation import EyeDentityDatasetCreation
29
+ from utils import get_model
30
+
31
+ CAM_METHODS = ["CAM"]
32
+ # colors = ["#2ca02c", "#d62728", "#1f77b4", "#ff7f0e"] # Green, Red, Blue, Orange
33
+ colors = ["#1f77b4", "#ff7f0e", "#636363"] # Blue, Orange, Gray
34
+
35
+
36
+ @torch.no_grad()
37
+ def load_model(model_configs, device="cpu"):
38
+ """Loads the pre-trained model."""
39
+ model_path = os.path.join(root_path, model_configs["model_path"])
40
+ model_dict = torch.load(model_path, map_location=device)
41
+ model = get_model(model_configs=model_configs)
42
+ model.load_state_dict(model_dict)
43
+ model = model.to(device).eval()
44
+ return model
45
+
46
+
47
+ def extract_frames(video_path):
48
+ """Extracts frames from a video file."""
49
+ vidcap = cv2.VideoCapture(video_path)
50
+ frames = []
51
+ success, image = vidcap.read()
52
+ while success:
53
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
54
+ frames.append(image_rgb)
55
+ success, image = vidcap.read()
56
+ vidcap.release()
57
+ return frames
58
+
59
+
60
+ def resize_frame(image, max_width=640, max_height=480):
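+ # Shrinks the frame while preserving aspect ratio: width is capped at max_width and height at max_height (square inputs of 256px or more go to 256x256); PIL's thumbnail() never upscales, so smaller frames pass through unchanged.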
61
+ if not isinstance(image, Image.Image):
62
+ image = Image.fromarray(image)
63
+ original_size = image.size
64
+
65
+ # Resize the frame similarly to the image resizing logic
66
+ if original_size[0] == original_size[1] and original_size[0] >= 256:
67
+ max_size = (256, 256)
68
+ else:
69
+ max_size = list(original_size)
70
+ if original_size[0] >= max_width:
71
+ max_size[0] = max_width
72
+ elif original_size[0] < 64:
73
+ max_size[0] = 64
74
+ if original_size[1] >= max_height:
75
+ max_size[1] = max_height
76
+ elif original_size[1] < 32:
77
+ max_size[1] = 32
78
+
79
+ image.thumbnail(max_size)
80
+ # image = image.resize(max_size)
81
+ return image
82
+
83
+
84
+ def is_image(file_extension):
85
+ """Checks if the file is an image."""
86
+ return file_extension.lower() in ["png", "jpeg", "jpg"]
87
+
88
+
89
+ def is_video(file_extension):
90
+ """Checks if the file is a video."""
91
+ return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]
92
+
93
+
94
+ def get_codec_and_extension(file_format):
95
+ """Return codec and file extension based on the format."""
96
+ if file_format == "mp4":
97
+ return "H264", ".mp4"
98
+ elif file_format == "avi":
99
+ return "MJPG", ".avi"
100
+ elif file_format == "webm":
101
+ return "VP80", ".webm"
102
+ else:
103
+ return "MJPG", ".avi"
104
+
105
+
106
+ def display_results(input_image, cam_frame, pupil_diameter, cols):
107
+ """Displays the input image and overlayed CAM result."""
108
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
109
+ axs[0].imshow(input_image)
110
+ axs[0].axis("off")
111
+ axs[0].set_title("Input Image")
112
+ axs[1].imshow(cam_frame)
113
+ axs[1].axis("off")
114
+ axs[1].set_title("Overlayed CAM")
115
+ cols[-1].pyplot(fig)
116
+ cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")
117
+
118
+
119
+ def preprocess_image(input_img, max_size=(256, 256)):
120
+ """Resizes and preprocesses an image."""
121
+ input_img.thumbnail(max_size)
122
+ preprocess_steps = [
123
+ transforms.ToTensor(),
124
+ transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
125
+ ]
126
+ return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)
127
+
128
+
129
+ def overlay_text_on_frame(frame, text, position=(16, 20)):
130
+ """Write text on the image frame using OpenCV."""
131
+ return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)
132
+
133
+
134
+ def get_configs(blink_detection=False):
135
+ upscale = "-"
136
+ upscale_method_or_model = "-"
137
+ if upscale == "-":
138
+ sr_configs = None
139
+ else:
140
+ sr_configs = {
141
+ "method": upscale_method_or_model,
142
+ "params": {"upscale": upscale},
143
+ }
144
+ config_file = {
145
+ "sr_configs": sr_configs,
146
+ "feature_extraction_configs": {
147
+ "blink_detection": blink_detection,
148
+ "upscale": upscale,
149
+ "extraction_library": "mediapipe",
150
+ },
151
+ }
152
+
153
+ return config_file
154
+
155
+
156
+ def setup(cols, pupil_selection, tv_model, output_path):
157
+
158
+ left_pupil_model = None
159
+ left_pupil_cam_extractor = None
160
+ right_pupil_model = None
161
+ right_pupil_cam_extractor = None
162
+ output_frames = {}
163
+ input_frames = {}
164
+ predicted_diameters = {}
165
+ pred_diameters_frames = {}
166
+
167
+ if pupil_selection == "both":
168
+ selected_eyes = ["left_eye", "right_eye"]
169
+
170
+ elif pupil_selection == "left_pupil":
171
+ selected_eyes = ["left_eye"]
172
+
173
+ elif pupil_selection == "right_pupil":
174
+ selected_eyes = ["right_eye"]
175
+
176
+ for i, eye_type in enumerate(selected_eyes):
177
+ model_configs = {
178
+ "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
179
+ "registered_model_name": tv_model,
180
+ "num_classes": 1,
181
+ }
182
+ if eye_type == "left_eye":
183
+ left_pupil_model = load_model(model_configs)
184
+ left_pupil_cam_extractor = None
185
+ output_frames[eye_type] = []
186
+ input_frames[eye_type] = []
187
+ predicted_diameters[eye_type] = []
188
+ pred_diameters_frames[eye_type] = []
189
+ else:
190
+ right_pupil_model = load_model(model_configs)
191
+ right_pupil_cam_extractor = None
192
+ output_frames[eye_type] = []
193
+ input_frames[eye_type] = []
194
+ predicted_diameters[eye_type] = []
195
+ pred_diameters_frames[eye_type] = []
196
+
197
+ video_placeholders = {}
198
+
199
+ if output_path:
200
+ video_cols = cols[1].columns(len(input_frames.keys()))
201
+
202
+ for i, eye_type in enumerate(list(input_frames.keys())):
203
+ video_placeholders[eye_type] = video_cols[i].empty()
204
+
205
+ return (
206
+ selected_eyes,
207
+ input_frames,
208
+ output_frames,
209
+ predicted_diameters,
210
+ pred_diameters_frames,
211
+ video_placeholders,
212
+ left_pupil_model,
213
+ left_pupil_cam_extractor,
214
+ right_pupil_model,
215
+ right_pupil_cam_extractor,
216
+ )
217
+
218
+
219
+ def process_frames(
220
+ cols, input_imgs, tv_model, pupil_selection, cam_method, output_path=None, codec=None, blink_detection=False
221
+ ):
222
+
223
+ config_file = get_configs(blink_detection)
224
+
225
+ face_frames = []
226
+
227
+ (
228
+ selected_eyes,
229
+ input_frames,
230
+ output_frames,
231
+ predicted_diameters,
232
+ pred_diameters_frames,
233
+ video_placeholders,
234
+ left_pupil_model,
235
+ left_pupil_cam_extractor,
236
+ right_pupil_model,
237
+ right_pupil_cam_extractor,
238
+ ) = setup(cols, pupil_selection, tv_model, output_path)
239
+
240
+ ds_creation = EyeDentityDatasetCreation(
241
+ feature_extraction_configs=config_file["feature_extraction_configs"],
242
+ sr_configs=config_file["sr_configs"],
243
+ )
244
+
245
+ preprocess_steps = [
246
+ transforms.Resize(
247
+ [32, 64],
248
+ interpolation=transforms.InterpolationMode.BICUBIC,
249
+ antialias=True,
250
+ ),
251
+ transforms.ToTensor(),
252
+ ]
253
+ preprocess_function = transforms.Compose(preprocess_steps)
254
+
255
+ eyes_ratios = []
256
+
257
+ for idx, input_img in enumerate(input_imgs):
258
+
259
+ img = np.array(input_img)
260
+ ds_results = ds_creation(img)
261
+
262
+ left_eye = None
263
+ right_eye = None
264
+ blinked = False
265
+ eyes_ratio = None
266
+
267
+ if ds_results is not None and "face" in ds_results:
268
+ face_img = to_pil_image(ds_results["face"])
269
+ has_face = True
270
+ else:
271
+ face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
272
+ has_face = False
273
+ face_frames.append({"has_face": has_face, "img": face_img})
274
+
275
+ if ds_results is not None and "eyes" in ds_results.keys():
276
+ blinked = ds_results["eyes"]["blinked"]
277
+ eyes_ratio = ds_results["eyes"]["eyes_ratio"]
278
+ if eyes_ratio is not None:
279
+ eyes_ratios.append(eyes_ratio)
280
+ if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
281
+ left_eye = ds_results["eyes"]["left_eye"]
282
+ left_eye = to_pil_image(left_eye).convert("RGB")
283
+ left_eye = preprocess_function(left_eye)
284
+ left_eye = left_eye.unsqueeze(0)
285
+ if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
286
+ right_eye = ds_results["eyes"]["right_eye"]
287
+ right_eye = to_pil_image(right_eye).convert("RGB")
288
+ right_eye = preprocess_function(right_eye)
289
+ right_eye = right_eye.unsqueeze(0)
290
+ else:
291
+ input_img = preprocess_function(input_img)
292
+ input_img = input_img.unsqueeze(0)
293
+ if pupil_selection == "left_pupil":
294
+ left_eye = input_img
295
+ elif pupil_selection == "right_pupil":
296
+ right_eye = input_img
297
+ else:
298
+ left_eye = input_img
299
+ right_eye = input_img
300
+
301
+ for i, eye_type in enumerate(selected_eyes):
302
+
303
+ if blinked:
304
+ if left_eye is not None and eye_type == "left_eye":
305
+ _, height, width = left_eye.squeeze(0).shape
306
+ input_image_pil = to_pil_image(left_eye.squeeze(0))
307
+ elif right_eye is not None and eye_type == "right_eye":
308
+ _, height, width = right_eye.squeeze(0).shape
309
+ input_image_pil = to_pil_image(right_eye.squeeze(0))
310
+
311
+ input_img_np = np.array(input_image_pil)
312
+ zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
313
+ output_img_np = overlay_text_on_frame(np.array(zeros_img), "blink")
314
+ predicted_diameter = "blink"
315
+ else:
316
+ if left_eye is not None and eye_type == "left_eye":
317
+ if left_pupil_cam_extractor is None:
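+ # Use the last convolution in layer4 as the CAM target: a BasicBlock (ResNet18) ends in conv2, a Bottleneck (ResNet50) ends in conv3.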
318
+ if tv_model == "ResNet18":
319
+ target_layer = left_pupil_model.resnet.layer4[-1].conv2
320
+ elif tv_model == "ResNet50":
321
+ target_layer = left_pupil_model.resnet.layer4[-1].conv3
322
+ else:
323
+ raise Exception(f"No target layer available for selected model: {tv_model}")
324
+ left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
325
+ left_pupil_model,
326
+ target_layer=target_layer,
327
+ fc_layer=left_pupil_model.resnet.fc,
328
+ input_shape=left_eye.shape,
329
+ )
330
+ output = left_pupil_model(left_eye)
331
+ predicted_diameter = output[0].item()
332
+ act_maps = left_pupil_cam_extractor(0, output)
333
+ activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
334
+ input_image_pil = to_pil_image(left_eye.squeeze(0))
335
+ elif right_eye is not None and eye_type == "right_eye":
336
+ if right_pupil_cam_extractor is None:
337
+ if tv_model == "ResNet18":
338
+ target_layer = right_pupil_model.resnet.layer4[-1].conv2
339
+ elif tv_model == "ResNet50":
340
+ target_layer = right_pupil_model.resnet.layer4[-1].conv3
341
+ else:
342
+ raise Exception(f"No target layer available for selected model: {tv_model}")
343
+ right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
344
+ right_pupil_model,
345
+ target_layer=target_layer,
346
+ fc_layer=right_pupil_model.resnet.fc,
347
+ input_shape=right_eye.shape,
348
+ )
349
+ output = right_pupil_model(right_eye)
350
+ predicted_diameter = output[0].item()
351
+ act_maps = right_pupil_cam_extractor(0, output)
352
+ activation_map = (
353
+ act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
354
+ )
355
+ input_image_pil = to_pil_image(right_eye.squeeze(0))
356
+
357
+ # Create CAM overlay
358
+ activation_map_pil = to_pil_image(activation_map, mode="F")
359
+ result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
360
+ input_img_np = np.array(input_image_pil)
361
+ output_img_np = np.array(result)
362
+
363
+ # Add frame and predicted diameter to lists
364
+ input_frames[eye_type].append(input_img_np)
365
+ output_frames[eye_type].append(output_img_np)
366
+ predicted_diameters[eye_type].append(predicted_diameter)
367
+
368
+ if output_path:
369
+ height, width, _ = output_img_np.shape
370
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
371
+ if not isinstance(predicted_diameter, str):
372
+ text = f"{predicted_diameter:.2f}"
373
+ else:
374
+ text = predicted_diameter
375
+ frame = overlay_text_on_frame(frame, text)
376
+ pred_diameters_frames[eye_type].append(frame)
377
+
378
+ combined_frame = np.vstack((input_img_np, output_img_np, frame))
379
+
380
+ img_base64 = pil_image_to_base64(Image.fromarray(combined_frame))
381
+ image_html = f'<div style="width: {str(50*len(selected_eyes))}%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
382
+ video_placeholders[eye_type].markdown(image_html, unsafe_allow_html=True)
383
+
384
+ # video_placeholders[eye_type].image(combined_frame, use_column_width=True)
385
+
386
+ st.session_state.current_frame = idx + 1
387
+ txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
388
+ st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
389
+
390
+ if output_path:
391
+ combine_and_show_frames(
392
+ input_frames, output_frames, pred_diameters_frames, output_path, codec, video_placeholders
393
+ )
394
+
395
+ return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios
396
+
397
+
398
+ # Function to display video with autoplay and loop
399
+ def display_video_with_autoplay(video_col, video_path, width):
400
+ video_html = f"""
401
+ <video width="{str(width)}%" height="auto" autoplay loop muted>
402
+ <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
403
+ </video>
404
+ """
405
+ video_col.markdown(video_html, unsafe_allow_html=True)
406
+
407
+
408
+ def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method, blink_detection=False):
409
+
410
+ resized_frames = []
411
+ for i, frame in enumerate(video_frames):
412
+ input_img = resize_frame(frame, max_width=640, max_height=480)
413
+ resized_frames.append(input_img)
414
+
415
+ file_format = output_path.split(".")[-1]
416
+ codec, extension = get_codec_and_extension(file_format)
417
+
418
+ input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
419
+ cols, resized_frames, tv_model, pupil_selection, cam_method, output_path, codec, blink_detection
420
+ )
421
+
422
+ return input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios
423
+
424
+
425
+ # Function to convert string values to float or None
426
+ def convert_diameter(value):
427
+ try:
428
+ return float(value)
429
+ except (ValueError, TypeError):
430
+ return None # Return None if conversion fails
431
+
432
+
433
+ def combine_and_show_frames(input_frames, cam_frames, pred_diameters_frames, output_path, codec, video_cols):
434
+ # Assuming all frames have the same keys (eye types)
435
+ eye_types = input_frames.keys()
436
+
437
+ for i, eye_type in enumerate(eye_types):
438
+ in_frames = input_frames[eye_type]
439
+ cam_out_frames = cam_frames[eye_type]
440
+ pred_diameters_text_frames = pred_diameters_frames[eye_type]
441
+
442
+ # Get frame properties (assuming all frames have the same dimensions)
443
+ height, width, _ = in_frames[0].shape
444
+ fourcc = cv2.VideoWriter_fourcc(*codec)
445
+ fps = 10.0
446
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height * 3)) # Height is tripled: input, CAM, and prediction frames are stacked vertically
447
+
448
+ # Loop through each set of frames and concatenate them
449
+ for j in range(len(in_frames)):
450
+ input_frame = in_frames[j]
451
+ cam_frame = cam_out_frames[j]
452
+ pred_frame = pred_diameters_text_frames[j]
453
+
454
+ # Convert frames to BGR if necessary
455
+ input_frame_bgr = cv2.cvtColor(input_frame, cv2.COLOR_RGB2BGR)
456
+ cam_frame_bgr = cv2.cvtColor(cam_frame, cv2.COLOR_RGB2BGR)
457
+ pred_frame_bgr = cv2.cvtColor(pred_frame, cv2.COLOR_RGB2BGR)
458
+
459
+ # Stack frames vertically (input, cam, pred)
460
+ combined_frame = np.vstack((input_frame_bgr, cam_frame_bgr, pred_frame_bgr))
461
+
462
+ # Write the combined frame to the video
463
+ out.write(combined_frame)
464
+
465
+ # Release the video writer
466
+ out.release()
467
+
468
+ # Read the video and encode it in base64 for displaying
469
+ with open(output_path, "rb") as video_file:
470
+ video_bytes = video_file.read()
471
+ video_base64 = base64.b64encode(video_bytes).decode("utf-8")
472
+
473
+ # Display the combined video
474
+ display_video_with_autoplay(video_cols[eye_type], video_base64, width=len(video_cols) * 50)
475
+
476
+ # Clean up
477
+ os.remove(output_path)
478
+
479
+
480
+ def set_input_image_on_ui(uploaded_file, cols):
481
+ input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
482
+ # NOTE: images taken with a phone camera carry an EXIF orientation field that often rotates images shot with the phone tilted. PIL has a utility function that removes this data and 'uprights' the image.
483
+ input_img = ImageOps.exif_transpose(input_img)
484
+ input_img = resize_frame(input_img, max_width=640, max_height=480)
486
+ cols[0].image(input_img, use_column_width=True)
487
+ st.session_state.total_frames = 1
488
+ return input_img
489
+
490
+
491
+ def set_input_video_on_ui(uploaded_file, cols):
492
+ tfile = tempfile.NamedTemporaryFile(delete=False)
493
+ try:
494
+ tfile.write(uploaded_file.read())
495
+ except Exception:
496
+ tfile.write(uploaded_file)
497
+ video_path = tfile.name
498
+ video_frames = extract_frames(video_path)
499
+ cols[0].video(video_path)
500
+ st.session_state.total_frames = len(video_frames)
501
+ return video_frames, video_path
502
+
503
+
504
+ def set_frames_processed_count_placeholder(cols):
505
+ st.session_state.current_frame = 0
506
+ st.session_state.frame_placeholder = cols[0].empty()
507
+ txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
508
+ st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
509
+
510
+
511
+ def video_to_bytes(video_path):
512
+ # Open the video file in binary mode and return the bytes
513
+ with open(video_path, "rb") as video_file:
514
+ return video_file.read()
515
+
516
+
517
+ def display_video_library(video_folder="./sample_videos"):
518
+ # Get all video files from the folder
519
+ video_files = [f for f in os.listdir(video_folder) if f.endswith(".webm")]
520
+
521
+ # Store the selected video path
522
+ selected_video_path = None
523
+
524
+ # Calculate number of columns (adjust based on your layout preferences)
525
+ num_columns = 3 # For a grid of 3 videos per row
526
+
527
+ # Display videos in a grid layout with 'Select' button for each video
528
+ for i in range(0, len(video_files), num_columns):
529
+ cols = st.columns(num_columns)
530
+ for idx, video_file in enumerate(video_files[i : i + num_columns]):
531
+ with cols[idx]:
532
+ st.subheader(video_file.split(".")[0]) # Use the file name as the title
533
+ video_path = os.path.join(video_folder, video_file)
534
+ st.video(video_path) # Show the video
535
+ if st.button(f"Select {video_file.split('.')[0]}", key=video_file, type="primary"):
536
+ st.session_state.clear()
537
+ st.toast("Scroll Down to see the input and predictions", icon="⏬")
538
+ selected_video_path = video_path # Store the path of the selected video
539
+
540
+ return selected_video_path
541
+
542
+
543
+ def set_page_info_and_sidebar_info():
544
+
545
+ st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
546
+ st.title("👁️ PupilSense 👁️🕵️‍♂️")
547
+ # st.markdown("Upload your own images or video **OR** select from our sample library below")
548
+ st.markdown(
549
+ "<p style='font-size: 30px;'>"
550
+ "Upload your own image 🖼️ or video 🎞️ <strong>OR</strong> select from our sample videos 📚"
551
+ "</p>",
552
+ unsafe_allow_html=True,
553
+ )
554
+ # video_path = display_video_library()
555
+ show_demo_videos = st.sidebar.checkbox("Show Sample Videos", value=False)
556
+ if show_demo_videos:
557
+ video_path = display_video_library()
558
+ else:
559
+ video_path = None
560
+
561
+ st.markdown("<hr id='target_element' style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)
562
+ cols = st.columns((1, 1))
563
+ cols[0].header("Input")
564
+ cols[-1].header("Prediction")
565
+ st.markdown("<hr style='border: 1px solid #6d6d6d; margin: 20px 0;'>", unsafe_allow_html=True)
566
+
567
+ LABEL_MAP = ["left_pupil", "right_pupil"]
568
+ TV_MODELS = ["ResNet18", "ResNet50"]
569
+
570
+ if "uploader_key" not in st.session_state:
571
+ st.session_state["uploader_key"] = 1
572
+
573
+ st.sidebar.title("Upload Face 👨‍🦱 or Eye 👁️")
574
+ uploaded_file = st.sidebar.file_uploader(
575
+ "Upload Image or Video",
576
+ type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"],
577
+ key=st.session_state["uploader_key"],
578
+ )
579
+ if uploaded_file is not None:
580
+ st.session_state["uploaded_file"] = uploaded_file
581
+
582
+ st.sidebar.title("Setup")
583
+ pupil_selection = st.sidebar.selectbox(
584
+ "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
585
+ )
586
+ tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
587
+
588
+ blink_detection = st.sidebar.checkbox("Detect Blinks", value=True)
589
+
590
+ st.markdown("<style>#vg-tooltip-element{z-index: 1000051}</style>", unsafe_allow_html=True)
591
+
592
+ if "uploaded_file" not in st.session_state:
593
+ st.session_state["uploaded_file"] = None
594
+
595
+ if "og_video_path" not in st.session_state:
596
+ st.session_state["og_video_path"] = None
597
+
598
+ if uploaded_file is None and video_path is not None:
599
+ video_bytes = video_to_bytes(video_path)
600
+ uploaded_file = video_bytes
601
+ st.session_state["uploaded_file"] = uploaded_file
602
+ st.session_state["og_video_path"] = video_path
603
+ st.session_state["uploader_key"] = 0
604
+
605
+ return (
606
+ cols,
607
+ st.session_state["og_video_path"],
608
+ st.session_state["uploaded_file"],
609
+ pupil_selection,
610
+ tv_model,
611
+ blink_detection,
612
+ )
613
+
614
+
615
+ def pil_image_to_base64(img):
616
+ """Convert a PIL Image to a base64 encoded string."""
617
+ buffered = io.BytesIO()
618
+ img.save(buffered, format="PNG")
619
+ img_str = base64.b64encode(buffered.getvalue()).decode()
620
+ return img_str
621
+
622
+
623
+ def process_image_and_vizualize_data(cols, input_img, tv_model, pupil_selection, blink_detection):
624
+ input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_frames(
625
+ cols,
626
+ [input_img],
627
+ tv_model,
628
+ pupil_selection,
629
+ cam_method=CAM_METHODS[-1],
630
+ blink_detection=blink_detection,
631
+ )
632
+ # for ff in face_frames:
633
+ # if ff["has_face"]:
634
+ # cols[1].image(face_frames[0]["img"], use_column_width=True)
635
+
636
+ input_frames_keys = input_frames.keys()
637
+ video_cols = cols[1].columns(len(input_frames_keys))
638
+
639
+ for i, eye_type in enumerate(input_frames_keys):
640
+ # Check the pupil_selection and set the width accordingly
641
+ if pupil_selection == "both":
642
+ video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
643
+ else:
644
+ img_base64 = pil_image_to_base64(Image.fromarray(input_frames[eye_type][-1]))
645
+ image_html = f'<div style="width: 50%; margin-bottom: 1.2%;"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
646
+ video_cols[i].markdown(image_html, unsafe_allow_html=True)
647
+
648
+ output_frames_keys = output_frames.keys()
649
+ fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
650
+ for i, eye_type in enumerate(output_frames_keys):
651
+ height, width, c = output_frames[eye_type][0].shape
652
+ frame = np.zeros((height, width, c), dtype=np.uint8)
653
+ text = f"{predicted_diameters[eye_type][0]:.2f}"
654
+ frame = overlay_text_on_frame(frame, text)
655
+
656
+ if pupil_selection == "both":
657
+ video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
658
+ video_cols[i].image(frame, use_column_width=True)
659
+ else:
660
+ img_base64 = pil_image_to_base64(Image.fromarray(output_frames[eye_type][-1]))
661
+ image_html = f'<div style="width: 50%; margin-top: 1.2%; margin-bottom: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
662
+ video_cols[i].markdown(image_html, unsafe_allow_html=True)
663
+ img_base64 = pil_image_to_base64(Image.fromarray(frame))
664
+ image_html = f'<div style="width: 50%; margin-top: 1.2%"><img src="data:image/png;base64,{img_base64}" style="width: 100%;"></div>'
665
+ video_cols[i].markdown(image_html, unsafe_allow_html=True)
666
+
667
+ return None
668
+
669
+
670
+ def plot_ears(eyes_ratios, eyes_df):
671
+ eyes_df["EAR"] = eyes_ratios
672
+ df = pd.DataFrame(eyes_ratios, columns=["EAR"])
673
+ df["Frame"] = range(1, len(eyes_ratios) + 1) # Create a frame column starting from 1
674
+
675
+ # Create an Altair chart for eyes_ratios
676
+ line_chart = (
677
+ alt.Chart(df)
678
+ .mark_line(color=colors[-1]) # Set color of the line
679
+ .encode(
680
+ x=alt.X("Frame:Q", title="Frame Number"),
681
+ y=alt.Y("EAR:Q", title="Eyes Aspect Ratio"),
682
+ tooltip=["Frame", "EAR"],
683
+ )
684
+ # .properties(title="Eyes Aspect Ratios (EARs)")
685
+ # .configure_axis(grid=True)
686
+ )
687
+ points_chart = line_chart.mark_point(color=colors[-1], filled=True)
688
+
689
+ # Create a horizontal rule at y=0.22
690
+ line1 = alt.Chart(pd.DataFrame({"y": [0.22]})).mark_rule(color="red").encode(y="y:Q")
691
+
692
+ line2 = alt.Chart(pd.DataFrame({"y": [0.25]})).mark_rule(color="green").encode(y="y:Q")
693
+
694
+ # Add text annotations for the lines
695
+ text1 = (
696
+ alt.Chart(pd.DataFrame({"y": [0.22], "label": ["Definite Blinks (<=0.22)"]}))
697
+ .mark_text(align="left", dx=100, dy=9, color="red", size=16)
698
+ .encode(y="y:Q", text="label:N")
699
+ )
700
+
701
+ text2 = (
702
+ alt.Chart(pd.DataFrame({"y": [0.25], "label": ["No Blinks (>=0.25)"]}))
703
+ .mark_text(align="left", dx=-150, dy=-9, color="green", size=16)
704
+ .encode(y="y:Q", text="label:N")
705
+ )
706
+
707
+ # Add gray area text for the region between red and green lines
708
+ gray_area_text = (
709
+ alt.Chart(pd.DataFrame({"y": [0.235], "label": ["Gray Area"]}))
710
+ .mark_text(align="left", dx=0, dy=0, color="gray", size=16)
711
+ .encode(y="y:Q", text="label:N")
712
+ )
713
+
714
+ # Combine all elements: line chart, points, rules, and text annotations
715
+ final_chart = (
716
+ line_chart.properties(title="Eyes Aspect Ratios (EARs)")
717
+ + points_chart
718
+ + line1
719
+ + line2
720
+ + text1
721
+ + text2
722
+ + gray_area_text
723
+ ).interactive()
724
+
725
+ # Configure axis properties at the chart level
726
+ final_chart = final_chart.configure_axis(grid=True)
727
+
728
+ # Display the Altair chart
729
+ # st.subheader("Eyes Aspect Ratios (EARs)")
730
+ st.altair_chart(final_chart, use_container_width=True)
731
+ return eyes_df
732
+
733
+
734
+ def plot_individual_charts(predicted_diameters, cols):
735
+ # Iterate through categories and assign charts to columns
736
+ for i, (category, values) in enumerate(predicted_diameters.items()):
737
+ with cols[i]: # Directly use the column index
738
+ # st.subheader(category) # Add a subheader for the category
739
+ if "left" in category:
740
+ selected_color = colors[0]
741
+ elif "right" in category:
742
+ selected_color = colors[1]
743
+ else:
744
+ selected_color = colors[i]
745
+
746
+ # Convert values to numeric, replacing non-numeric values with None
747
+ values = [convert_diameter(value) for value in values]
748
+
749
+ if "left" in category:
750
+ category_name = "Left Pupil Diameter"
751
+ else:
752
+ category_name = "Right Pupil Diameter"
753
+
754
+ # Create a DataFrame from the values for Altair
755
+ df = pd.DataFrame(
756
+ {
757
+ "Frame": range(1, len(values) + 1),
758
+ category_name: values,
759
+ }
760
+ )
761
+
762
+ # Get the min and max values for y-axis limits, ignoring None
763
+ min_value = min(filter(lambda x: x is not None, values), default=None)
764
+ max_value = max(filter(lambda x: x is not None, values), default=None)
765
+
766
+ # Create an Altair chart with y-axis limits
767
+ line_chart = (
768
+ alt.Chart(df)
769
+ .mark_line(color=selected_color)
770
+ .encode(
771
+ x=alt.X("Frame:Q", title="Frame Number"),
772
+ y=alt.Y(
773
+ f"{category_name}:Q",
774
+ title="Diameter",
775
+ scale=alt.Scale(domain=[min_value, max_value]),
776
+ ),
777
+ tooltip=[
778
+ "Frame",
779
+ alt.Tooltip(f"{category_name}:Q", title="Diameter"),
780
+ ],
781
+ )
782
+ # .properties(title=f"{category} - Predicted Diameters")
783
+ # .configure_axis(grid=True)
784
+ )
785
+ points_chart = line_chart.mark_point(color=selected_color, filled=True)
786
+
787
+ final_chart = (
788
+ line_chart.properties(
789
+ title=f"{'Left Pupil' if 'left' in category else 'Right Pupil'} - Predicted Diameters"
790
+ )
791
+ + points_chart
792
+ ).interactive()
793
+
794
+ final_chart = final_chart.configure_axis(grid=True)
795
+
796
+ # Display the Altair chart
797
+ st.altair_chart(final_chart, use_container_width=True)
798
+ return df
799
+
800
+
801
+ def plot_combined_charts(predicted_diameters):
802
+ all_min_values = []
803
+ all_max_values = []
804
+
805
+ # Create an empty DataFrame to store combined data for plotting
806
+ combined_df = pd.DataFrame()
807
+
808
+ # Iterate through categories and collect data
809
+ for category, values in predicted_diameters.items():
810
+ # Convert values to numeric, replacing non-numeric values with None
811
+ values = [convert_diameter(value) for value in values]
812
+
813
+ # Get the min and max values for y-axis limits, ignoring None
814
+ min_value = min(filter(lambda x: x is not None, values), default=None)
815
+ max_value = max(filter(lambda x: x is not None, values), default=None)
816
+
817
+ all_min_values.append(min_value)
818
+ all_max_values.append(max_value)
819
+
820
+ category = "left_pupil" if "left" in category else "right_pupil"
821
+
822
+ # Create a DataFrame from the values
823
+ df = pd.DataFrame(
824
+ {
825
+ "Diameter": values,
826
+ "Frame": range(1, len(values) + 1), # Create a frame column starting from 1
827
+ "Category": category, # Add a column to specify the category
828
+ }
829
+ )
830
+
831
+ # Append to combined DataFrame
832
+ combined_df = pd.concat([combined_df, df], ignore_index=True)
833
+
834
+ combined_chart = (
835
+ alt.Chart(combined_df)
836
+ .mark_line()
837
+ .encode(
838
+ x=alt.X("Frame:Q", title="Frame Number"),
839
+ y=alt.Y(
840
+ "Diameter:Q",
841
+ title="Diameter",
842
+ scale=alt.Scale(domain=[min(all_min_values), max(all_max_values)]),
843
+ ),
844
+ color=alt.Color("Category:N", scale=alt.Scale(range=colors), title="Pupil Type"),
845
+ tooltip=["Frame", "Diameter:Q", "Category:N"],
846
+ )
847
+ )
848
+ points_chart = combined_chart.mark_point(filled=True)
849
+
850
+ final_chart = (combined_chart.properties(title="Predicted Diameters") + points_chart).interactive()
851
+
852
+ final_chart = final_chart.configure_axis(grid=True)
853
+
854
+ # Display the combined chart
855
+ st.altair_chart(final_chart, use_container_width=True)
856
+
857
+ # --------------------------------------------
858
+ # Convert to a DataFrame
859
+ left_pupil_values = [convert_diameter(value) for value in predicted_diameters["left_eye"]]
860
+ right_pupil_values = [convert_diameter(value) for value in predicted_diameters["right_eye"]]
861
+
862
+ df = pd.DataFrame(
863
+ {
864
+ "Frame": range(1, len(left_pupil_values) + 1),
865
+ "Left Pupil Diameter": left_pupil_values,
866
+ "Right Pupil Diameter": right_pupil_values,
867
+ }
868
+ )
869
+
870
+ # Calculate the difference between left and right pupil diameters
871
+ df["Difference Value"] = df["Left Pupil Diameter"] - df["Right Pupil Diameter"]
872
+
873
+ # Determine the status of the difference
874
+ df["Difference Status"] = df.apply(
875
+ lambda row: "L>R" if row["Left Pupil Diameter"] > row["Right Pupil Diameter"] else "L<R",
876
+ axis=1,
877
+ )
878
+
879
+ return df
880
+
881
+
882
+ def process_video_and_visualize_data(cols, video_frames, tv_model, pupil_selection, blink_detection, video_path):
883
+ output_video_path = f"{root_path}/tmp.webm"
884
+ input_frames, output_frames, predicted_diameters, face_frames, eyes_ratios = process_video(
885
+ cols,
886
+ video_frames,
887
+ tv_model,
888
+ pupil_selection,
889
+ output_video_path,
890
+ cam_method=CAM_METHODS[-1],
891
+ blink_detection=blink_detection,
892
+ )
893
+ os.remove(video_path)
894
+
895
+ num_columns = len(predicted_diameters)
896
+ cols = st.columns(num_columns)
897
+
898
+ if num_columns == 2:
899
+ df = plot_combined_charts(predicted_diameters)
900
+ else:
901
+ df = plot_individual_charts(predicted_diameters, cols)
902
+
903
+ if eyes_ratios is not None and len(eyes_ratios) > 0:
904
+ df = plot_ears(eyes_ratios, df)
905
+
906
+ st.dataframe(df, hide_index=True, use_container_width=True)
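The thresholds drawn above encode a simple rule: EAR <= 0.22 is treated as a definite blink, EAR >= 0.25 as no blink, and the band in between is the gray area that the extractor double-checks with a ViT classifier. A minimal sketch of that rule (the helper name classify_ear is illustrative, not part of the app):

def classify_ear(ear: float) -> str:
    # Illustrative only: mirrors the thresholds plotted in plot_ears
    if ear <= 0.22:
        return "definite blink"
    if ear >= 0.25:
        return "no blink"
    return "gray area"  # in the app this band is confirmed by the blink-detection model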
config.yml ADDED
@@ -0,0 +1,51 @@
1
+ seed: 42
2
+
3
+ feature_extraction_configs:
4
+ blink_detection: true
5
+ upscale: 1
6
+ extraction_library: "mediapipe"
7
+ show_features: ['faces', 'eyes', 'blinks']
8
+
9
+ model_configs:
10
+ models_path: "pre_trained_models"
11
+ registered_model_names: ["ResNet18", "ResNet50"]
12
+ labels: ["left_eye", "right_eye"]
13
+ targets: ["left_pupil", "right_pupil"]
14
+ num_classes: 1
15
+
16
+ xai_configs:
17
+ attribution_methods: [
18
+ "IntegratedGradients",
19
+ "Saliency",
20
+ "InputXGradient",
21
+ "GuidedBackprop",
22
+ "Deconvolution",
23
+ # "GuidedGradCam",
24
+ # "LayerGradCam",
25
+ # "LayerGradientXActivation",
26
+ ]
27
+ cam_methods: [
28
+ "CAM",
29
+ "GradCAM",
30
+ "GradCAMpp",
31
+ "SmoothGradCAMpp",
32
+ "ScoreCAM",
33
+ "SSCAM",
34
+ "ISCAM",
35
+ "XGradCAM",
36
+ "LayerCAM",
37
+ ]
38
+
39
+ use_sr: false
40
+
41
+ upscale_configs:
42
+ upscale: [1, 2, 3, 4]
43
+ upscale_method_configs:
44
+ size: [16, 32]
45
+ antialias: true
46
+ interpolation: ["bicubic"]
47
+
48
+ sr_methods: ["GFPGAN", "RealESRGAN", "SRResNet", "CodeFormer", "HAT"]
49
+ sr_method_configs:
50
+ bg_upsampler_name: "realesrgan"
51
+ prefered_net_in_upsampler: "RRDBNet"
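A minimal sketch of reading this file with PyYAML (assuming it is loaded from the repository root as config.yml; the variable names are illustrative):

import yaml

with open("config.yml", "r") as f:
    cfg = yaml.safe_load(f)

feature_cfg = cfg["feature_extraction_configs"]  # blink_detection, upscale, extraction_library, ...
model_cfg = cfg["model_configs"]                 # models_path, registered_model_names, labels, ...
print(feature_cfg["extraction_library"])         # "mediapipe"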
feature_extraction/extractor_mediapipe.py ADDED
@@ -0,0 +1,340 @@
1
+ import cv2
2
+ import torch
3
+ import warnings
4
+ import numpy as np
5
+ from PIL import Image
6
+ from math import sqrt
7
+ import mediapipe as mp
8
+ from transformers import pipeline
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+
13
+ class ExtractorMediaPipe:
14
+
15
+ def __init__(self, upscale=1):
16
+
17
+ self.upscale = int(upscale)
18
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ # ========== Face Extraction ==========
21
+ self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
22
+ self.face_mesh = mp.solutions.face_mesh.FaceMesh(
23
+ max_num_faces=1,
24
+ static_image_mode=True,
25
+ refine_landmarks=True,
26
+ min_detection_confidence=0.5,
27
+ min_tracking_confidence=0.5,
28
+ )
29
+
30
+ # ========== Eyes Extraction ==========
31
+ self.RIGHT_EYE = [
32
+ 362,
33
+ 382,
34
+ 381,
35
+ 380,
36
+ 374,
37
+ 373,
38
+ 390,
39
+ 249,
40
+ 263,
41
+ 466,
42
+ 388,
43
+ 387,
44
+ 386,
45
+ 385,
46
+ 384,
47
+ 398,
48
+ ]
49
+ self.LEFT_EYE = [
50
+ 33,
51
+ 7,
52
+ 163,
53
+ 144,
54
+ 145,
55
+ 153,
56
+ 154,
57
+ 155,
58
+ 133,
59
+ 173,
60
+ 157,
61
+ 158,
62
+ 159,
63
+ 160,
64
+ 161,
65
+ 246,
66
+ ]
67
+ # https://huggingface.co/dima806/closed_eyes_image_detection
68
+ # https://www.kaggle.com/code/dima806/closed-eye-image-detection-vit
69
+ self.pipe = pipeline(
70
+ "image-classification",
71
+ model="dima806/closed_eyes_image_detection",
72
+ device=self.device,
73
+ )
74
+ self.blink_lower_thresh = 0.22
75
+ self.blink_upper_thresh = 0.25
76
+ self.blink_confidence = 0.50
77
+
78
+ # ========== Iris Extraction ==========
79
+ self.RIGHT_IRIS = [474, 475, 476, 477]
80
+ self.LEFT_IRIS = [469, 470, 471, 472]
81
+
82
+ def extract_face(self, image):
83
+
84
+ tmp_image = image.copy()
85
+ results = self.face_detector.process(tmp_image)
86
+
87
+ if not results.detections:
88
+ # print("No face detected")
89
+ return None
90
+ else:
91
+ bboxC = results.detections[0].location_data.relative_bounding_box
92
+ ih, iw, _ = image.shape
93
+
94
+ # Get bounding box coordinates
95
+ x, y, w, h = (
96
+ int(bboxC.xmin * iw),
97
+ int(bboxC.ymin * ih),
98
+ int(bboxC.width * iw),
99
+ int(bboxC.height * ih),
100
+ )
101
+
102
+ # Calculate the center of the bounding box
103
+ center_x = x + w // 2
104
+ center_y = y + h // 2
105
+
106
+ # Calculate new bounds ensuring they fit within the image dimensions
107
+ half_size = 128 * self.upscale
108
+ x1 = max(center_x - half_size, 0)
109
+ y1 = max(center_y - half_size, 0)
110
+ x2 = min(center_x + half_size, iw)
111
+ y2 = min(center_y + half_size, ih)
112
+
113
+ # Adjust x1, x2, y1, and y2 to ensure the cropped region is exactly (256 * self.upscale) x (256 * self.upscale)
114
+ if x2 - x1 < (256 * self.upscale):
115
+ if x1 == 0:
116
+ x2 = min((256 * self.upscale), iw)
117
+ elif x2 == iw:
118
+ x1 = max(iw - (256 * self.upscale), 0)
119
+
120
+ if y2 - y1 < (256 * self.upscale):
121
+ if y1 == 0:
122
+ y2 = min((256 * self.upscale), ih)
123
+ elif y2 == ih:
124
+ y1 = max(ih - (256 * self.upscale), 0)
125
+
126
+ cropped_face = image[y1:y2, x1:x2]
127
+
128
+ # bicubic upsampling
129
+ # if self.upscale != 1:
130
+ # cropped_face = cv2.resize(
131
+ # cropped_face,
132
+ # (256 * self.upscale, 256 * self.upscale),
133
+ # interpolation=cv2.INTER_CUBIC,
134
+ # )
135
+
136
+ return cropped_face
137
+
138
+ @staticmethod
139
+ def landmarksDetection(image, results, draw=False):
140
+ image_height, image_width = image.shape[:2]
141
+ mesh_coordinates = [
142
+ (int(point.x * image_width), int(point.y * image_height))
143
+ for point in results.multi_face_landmarks[0].landmark
144
+ ]
145
+ if draw:
146
+ [cv2.circle(image, i, 2, (0, 255, 0), -1) for i in mesh_coordinates]
147
+ return mesh_coordinates
148
+
149
+ @staticmethod
150
+ def euclideanDistance(point, point1):
151
+ x, y = point
152
+ x1, y1 = point1
153
+ distance = sqrt((x1 - x) ** 2 + (y1 - y) ** 2)
154
+ return distance
155
+
156
+ def blinkRatio(self, landmarks, right_indices, left_indices):
157
+
158
+ right_eye_landmark1 = landmarks[right_indices[0]]
159
+ right_eye_landmark2 = landmarks[right_indices[8]]
160
+
161
+ right_eye_landmark3 = landmarks[right_indices[12]]
162
+ right_eye_landmark4 = landmarks[right_indices[4]]
163
+
164
+ left_eye_landmark1 = landmarks[left_indices[0]]
165
+ left_eye_landmark2 = landmarks[left_indices[8]]
166
+
167
+ left_eye_landmark3 = landmarks[left_indices[12]]
168
+ left_eye_landmark4 = landmarks[left_indices[4]]
169
+
170
+ right_eye_horizontal_distance = self.euclideanDistance(right_eye_landmark1, right_eye_landmark2)
171
+ right_eye_vertical_distance = self.euclideanDistance(right_eye_landmark3, right_eye_landmark4)
172
+
173
+ left_eye_vertical_distance = self.euclideanDistance(left_eye_landmark3, left_eye_landmark4)
174
+ left_eye_horizontal_distance = self.euclideanDistance(left_eye_landmark1, left_eye_landmark2)
175
+
176
+ right_eye_ratio = right_eye_vertical_distance / right_eye_horizontal_distance
177
+ left_eye_ratio = left_eye_vertical_distance / left_eye_horizontal_distance
178
+
179
+ eyes_ratio = (right_eye_ratio + left_eye_ratio) / 2
180
+
181
+ return eyes_ratio
182
+
183
+ def extract_eyes_regions(self, image, landmarks, eye_indices):
184
+ h, w, _ = image.shape
185
+ points = [(int(landmarks[idx].x * w), int(landmarks[idx].y * h)) for idx in eye_indices]
186
+
187
+ x_min = min([p[0] for p in points])
188
+ x_max = max([p[0] for p in points])
189
+ y_min = min([p[1] for p in points])
190
+ y_max = max([p[1] for p in points])
191
+
192
+ center_x = (x_min + x_max) // 2
193
+ center_y = (y_min + y_max) // 2
194
+
195
+ target_width = 32 * self.upscale
196
+ target_height = 16 * self.upscale
197
+
198
+ x1 = max(center_x - target_width // 2, 0)
199
+ y1 = max(center_y - target_height // 2, 0)
200
+ x2 = x1 + target_width
201
+ y2 = y1 + target_height
202
+
203
+ if x2 > w:
204
+ x1 = w - target_width
205
+ x2 = w
206
+ if y2 > h:
207
+ y1 = h - target_height
208
+ y2 = h
209
+
210
+ return image[y1:y2, x1:x2]
211
+
212
+ def blink_detection_model(self, left_eye, right_eye):
213
+
214
+ left_eye = cv2.cvtColor(left_eye, cv2.COLOR_RGB2GRAY)
215
+ left_eye = Image.fromarray(left_eye)
216
+ preds_left = self.pipe(left_eye)
217
+ if preds_left[0]["label"] == "closeEye":
218
+ closed_left = preds_left[0]["score"] >= self.blink_confidence
219
+ else:
220
+ closed_left = preds_left[1]["score"] >= self.blink_confidence
221
+
222
+ right_eye = cv2.cvtColor(right_eye, cv2.COLOR_RGB2GRAY)
223
+ right_eye = Image.fromarray(right_eye)
224
+ preds_right = self.pipe(right_eye)
225
+ if preds_right[0]["label"] == "closeEye":
226
+ closed_right = preds_right[0]["score"] >= self.blink_confidence
227
+ else:
228
+ closed_right = preds_right[1]["score"] >= self.blink_confidence
229
+
230
+ # print("preds_left = ", preds_left)
231
+ # print("preds_right = ", preds_right)
232
+
233
+ return closed_left or closed_right
234
+
235
+ def extract_eyes(self, image, blink_detection=False):
236
+
237
+ tmp_face = image.copy()
238
+ results = self.face_mesh.process(tmp_face)
239
+
240
+ if results.multi_face_landmarks is None:
241
+ return None
242
+
243
+ face_landmarks = results.multi_face_landmarks[0].landmark
244
+
245
+ left_eye = self.extract_eyes_regions(image, face_landmarks, self.LEFT_EYE)
246
+ right_eye = self.extract_eyes_regions(image, face_landmarks, self.RIGHT_EYE)
247
+ blinked = False
248
+ eyes_ratio = None
249
+
250
+ if blink_detection:
251
+ mesh_coordinates = self.landmarksDetection(image, results, False)
252
+ eyes_ratio = self.blinkRatio(mesh_coordinates, self.RIGHT_EYE, self.LEFT_EYE)
253
+ if eyes_ratio > self.blink_lower_thresh and eyes_ratio <= self.blink_upper_thresh:
254
+ # print(
255
+ # "I think person blinked. eyes_ratio = ",
256
+ # eyes_ratio,
257
+ # "Confirming with ViT model...",
258
+ # )
259
+ blinked = self.blink_detection_model(left_eye=left_eye, right_eye=right_eye)
260
+ # if blinked:
261
+ # print("Yes, person blinked. Confirmed by model")
262
+ # else:
263
+ # print("No, person didn't blinked. False Alarm")
264
+ elif eyes_ratio <= self.blink_lower_thresh:
265
+ blinked = True
266
+ # print("Surely person blinked. eyes_ratio = ", eyes_ratio)
267
+ else:
268
+ blinked = False
269
+
270
+ return {"left_eye": left_eye, "right_eye": right_eye, "blinked": blinked, "eyes_ratio": eyes_ratio}
271
+
272
+ @staticmethod
273
+ def segment_iris(iris_img):
274
+
275
+ # Convert RGB image to grayscale
276
+ iris_img_gray = cv2.cvtColor(iris_img, cv2.COLOR_RGB2GRAY)
277
+
278
+ # Apply Gaussian blur for denoising
279
+ iris_img_blur = cv2.GaussianBlur(iris_img_gray, (5, 5), 0)
280
+
281
+ # Perform adaptive thresholding
282
+ _, iris_img_mask = cv2.threshold(iris_img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
283
+
284
+ # Invert the mask
285
+ segmented_mask = cv2.bitwise_not(iris_img_mask)
286
+ segmented_mask = cv2.cvtColor(segmented_mask, cv2.COLOR_GRAY2RGB)
287
+ segmented_iris = cv2.bitwise_and(iris_img, segmented_mask)
288
+
289
+ return {
290
+ "segmented_iris": segmented_iris,
291
+ "segmented_mask": segmented_mask,
292
+ }
293
+
294
+ def extract_iris(self, image):
295
+
296
+ ih, iw, _ = image.shape
297
+ tmp_face = image.copy()
298
+ results = self.face_mesh.process(tmp_face)
299
+
300
+ if results.multi_face_landmarks is None:
301
+ return None
302
+
303
+ mesh_coordinates = self.landmarksDetection(image, results, False)
304
+ mesh_points = np.array(mesh_coordinates)
305
+
306
+ (l_cx, l_cy), l_radius = cv2.minEnclosingCircle(mesh_points[self.LEFT_IRIS])
307
+ (r_cx, r_cy), r_radius = cv2.minEnclosingCircle(mesh_points[self.RIGHT_IRIS])
308
+
309
+ # Crop the left iris to be exactly 16*upscaled x 16*upscaled
310
+ l_x1 = max(int(l_cx) - (8 * self.upscale), 0)
311
+ l_y1 = max(int(l_cy) - (8 * self.upscale), 0)
312
+ l_x2 = min(int(l_cx) + (8 * self.upscale), iw)
313
+ l_y2 = min(int(l_cy) + (8 * self.upscale), ih)
314
+
315
+ cropped_left_iris = image[l_y1:l_y2, l_x1:l_x2]
316
+
317
+ left_iris_segmented_data = self.segment_iris(cv2.cvtColor(cropped_left_iris, cv2.COLOR_BGR2RGB))
318
+
319
+ # Crop the right iris to be exactly 16*upscaled x 16*upscaled
320
+ r_x1 = max(int(r_cx) - (8 * self.upscale), 0)
321
+ r_y1 = max(int(r_cy) - (8 * self.upscale), 0)
322
+ r_x2 = min(int(r_cx) + (8 * self.upscale), iw)
323
+ r_y2 = min(int(r_cy) + (8 * self.upscale), ih)
324
+
325
+ cropped_right_iris = image[r_y1:r_y2, r_x1:r_x2]
326
+
327
+ right_iris_segmented_data = self.segment_iris(cv2.cvtColor(cropped_right_iris, cv2.COLOR_BGR2RGB))
328
+
329
+ return {
330
+ "left_iris": {
331
+ "img": cropped_left_iris,
332
+ "segmented_iris": left_iris_segmented_data["segmented_iris"],
333
+ "segmented_mask": left_iris_segmented_data["segmented_mask"],
334
+ },
335
+ "right_iris": {
336
+ "img": cropped_right_iris,
337
+ "segmented_iris": right_iris_segmented_data["segmented_iris"],
338
+ "segmented_mask": right_iris_segmented_data["segmented_mask"],
339
+ },
340
+ }
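A rough usage sketch for ExtractorMediaPipe on a single frame (face.jpg is a placeholder path; mediapipe, torch and transformers must be installed so the constructor can load the blink classifier):

import cv2
from feature_extraction.extractor_mediapipe import ExtractorMediaPipe

extractor = ExtractorMediaPipe(upscale=1)

bgr = cv2.imread("face.jpg")                  # placeholder path
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)    # the extractor expects RGB frames

face = extractor.extract_face(rgb)            # 256x256 crop at upscale=1, or None
if face is not None:
    eyes = extractor.extract_eyes(rgb, blink_detection=True)
    if eyes is not None and not eyes["blinked"]:
        iris = extractor.extract_iris(rgb)    # left/right iris crops plus segmented masks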
feature_extraction/features_extractor.py ADDED
@@ -0,0 +1,48 @@
1
+ import os
2
+ import sys
3
+ import warnings
4
+ import os.path as osp
5
+
6
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
8
+ sys.path.append(root_path)
9
+
10
+ from feature_extraction.extractor_mediapipe import ExtractorMediaPipe
11
+
12
+ warnings.filterwarnings("ignore")
13
+
14
+
15
+ class FeaturesExtractor:
16
+
17
+ def __init__(self, extraction_library="mediapipe", blink_detection=False, upscale=1):
18
+ self.upscale = upscale
19
+ self.blink_detection = blink_detection
20
+ self.extraction_library = extraction_library
21
+ self.feature_extractor = ExtractorMediaPipe(self.upscale)
22
+
23
+ def __call__(self, image):
24
+ results = {}
25
+ face = self.feature_extractor.extract_face(image)
26
+ if face is None:
27
+ # print("No face found. Skipped feature extraction!")
28
+ return None
29
+ else:
30
+ results["img"] = image
31
+ results["face"] = face
32
+ eyes_data = self.feature_extractor.extract_eyes(image, self.blink_detection)
33
+ if eyes_data is None:
34
+ # print("No eyes found. Skipped feature extraction!")
35
+ return results
36
+ else:
37
+ results["eyes"] = eyes_data
38
+ if eyes_data["blinked"]:
39
+ # print("Found blinked eyes!")
40
+ return results
41
+ else:
42
+ iris_data = self.feature_extractor.extract_iris(image)
43
+ if iris_data is None:
44
+ # print("No iris found. Skipped feature extraction!")
45
+ return results
46
+ else:
47
+ results["iris"] = iris_data
48
+ return results
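The same flow, driven through the FeaturesExtractor wrapper (a sketch; frame is assumed to be an RGB numpy array):

from feature_extraction.features_extractor import FeaturesExtractor

features_extractor = FeaturesExtractor(extraction_library="mediapipe", blink_detection=True, upscale=1)
result = features_extractor(frame)            # frame: RGB numpy array (assumed)
if result is not None and "iris" in result:
    left_iris_img = result["iris"]["left_iris"]["img"]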
gradio_app.py ADDED
@@ -0,0 +1,388 @@
1
+ import sys
2
+ import os
3
+ import os.path as osp
4
+ import gradio as gr
5
+ import numpy as np
6
+ import tempfile
7
+ from PIL import Image, ImageOps
8
+ import cv2
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ import base64
12
+
13
+ root_path = osp.abspath(osp.join(__file__, osp.pardir))
14
+ sys.path.append(root_path)
15
+
16
+ from registry_utils import import_registered_modules
17
+ from gradio_utils import (
18
+ is_image,
19
+ is_video,
20
+ extract_frames,
21
+ resize_frame,
22
+ CAM_METHODS,
23
+ process_frames_gradio,
24
+ )
25
+
26
+ import_registered_modules()
27
+
28
+
29
+ def process_image_gradio(image, pupil_selection, tv_model, blink_detection):
30
+ """
31
+ Process a single image and return results for Gradio interface.
32
+
33
+ Args:
34
+ image: PIL Image or numpy array
35
+ pupil_selection: str - "left_pupil", "right_pupil", or "both"
36
+ tv_model: str - "ResNet18" or "ResNet50"
37
+ blink_detection: bool - whether to detect blinks
38
+
39
+ Returns:
40
+ tuple: (input_image, cam_overlay, diameter_text, results_plot)
41
+ """
42
+ try:
43
+ # Convert to PIL Image if needed
44
+ if isinstance(image, np.ndarray):
45
+ image = Image.fromarray(image)
46
+
47
+ # Handle EXIF rotation
48
+ image = ImageOps.exif_transpose(image)
49
+
50
+ # Resize image
51
+ image = resize_frame(image, max_width=640, max_height=480)
52
+
53
+ # Process the image using Gradio-compatible function
54
+ input_frames, output_frames, predicted_diameters = process_frames_gradio(
55
+ input_imgs=[image],
56
+ tv_model=tv_model,
57
+ pupil_selection=pupil_selection,
58
+ blink_detection=blink_detection,
59
+ )
60
+
61
+ # Check if processing failed (empty results)
62
+ if not input_frames or not output_frames or not predicted_diameters:
63
+ error_msg = "Could not detect face/eyes in the image. Please try with a clearer image showing eyes."
64
+ error_img = Image.new('RGB', (400, 200), 'white')
65
+ return error_img, error_msg
66
+
67
+ # Create visualization
68
+ results = []
69
+ diameter_results = []
70
+
71
+ for eye_type in input_frames.keys():
72
+ input_img = input_frames[eye_type][-1]
73
+ output_img = output_frames[eye_type][-1]
74
+ diameter = predicted_diameters[eye_type][0]
75
+
76
+ # Create side-by-side comparison
77
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
78
+
79
+ ax1.imshow(input_img)
80
+ ax1.set_title(f"Input - {eye_type.replace('_', ' ').title()}")
81
+ ax1.axis('off')
82
+
83
+ ax2.imshow(output_img)
84
+ ax2.set_title(f"CAM Overlay - {eye_type.replace('_', ' ').title()}")
85
+ ax2.axis('off')
86
+
87
+ plt.tight_layout()
88
+
89
+ # Convert plot to image
90
+ buf = io.BytesIO()
91
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
92
+ buf.seek(0)
93
+ plot_img = Image.open(buf)
94
+ plt.close()
95
+
96
+ results.append(plot_img)
97
+
98
+ # Format diameter result
99
+ if isinstance(diameter, str):
100
+ diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter}")
101
+ else:
102
+ diameter_results.append(f"{eye_type.replace('_', ' ').title()}: {diameter:.2f} mm")
103
+
104
+ # Combine results if multiple eyes
105
+ if len(results) == 1:
106
+ final_image = results[0]
107
+ else:
108
+ # Combine multiple eye results
109
+ total_width = sum(img.width for img in results)
110
+ max_height = max(img.height for img in results)
111
+ final_image = Image.new('RGB', (total_width, max_height), 'white')
112
+ x_offset = 0
113
+ for img in results:
114
+ final_image.paste(img, (x_offset, 0))
115
+ x_offset += img.width
116
+
117
+ diameter_text = "\n".join(diameter_results)
118
+
119
+ return final_image, diameter_text
120
+
121
+ except Exception as e:
122
+ error_msg = f"Error processing image: {str(e)}"
123
+ # Create error image
124
+ error_img = Image.new('RGB', (400, 200), 'white')
125
+ return error_img, error_msg
126
+
127
+
128
+ def process_video_gradio(video_file, pupil_selection, tv_model, blink_detection):
129
+ """
130
+ Process a video file and return results for Gradio interface.
131
+
132
+ Args:
133
+ video_file: file path or file object
134
+ pupil_selection: str - "left_pupil", "right_pupil", or "both"
135
+ tv_model: str - "ResNet18" or "ResNet50"
136
+ blink_detection: bool - whether to detect blinks
137
+
138
+ Returns:
139
+ tuple: (results_plot, diameter_data, summary_text)
140
+ """
141
+ try:
142
+ # Handle video file
143
+ if hasattr(video_file, 'name'):
144
+ video_path = video_file.name
145
+ else:
146
+ video_path = video_file
147
+
148
+ # Extract frames
149
+ video_frames = extract_frames(video_path)
150
+
151
+ if not video_frames:
152
+ return None, "No frames extracted from video", "Error: Could not process video"
153
+
154
+ # Resize frames
155
+ resized_frames = []
156
+ for frame in video_frames:
157
+ if isinstance(frame, np.ndarray):
158
+ frame = Image.fromarray(frame)
159
+ input_img = resize_frame(frame, max_width=640, max_height=480)
160
+ resized_frames.append(input_img)
161
+
162
+ # Process video frames using Gradio-compatible function
163
+ input_frames, output_frames, predicted_diameters = process_frames_gradio(
164
+ input_imgs=resized_frames,
165
+ tv_model=tv_model,
166
+ pupil_selection=pupil_selection,
167
+ blink_detection=blink_detection,
168
+ )
169
+
170
+ # Check if processing failed (empty results)
171
+ if not input_frames or not output_frames or not predicted_diameters:
172
+ error_msg = "Could not process video. MediaPipe may have issues in this environment."
173
+ error_img = Image.new('RGB', (400, 200), 'white')
174
+ return error_img, "", error_msg
175
+
176
+ # Create results visualization
177
+ fig, axes = plt.subplots(len(predicted_diameters), 1, figsize=(12, 6 * len(predicted_diameters)))
178
+ if len(predicted_diameters) == 1:
179
+ axes = [axes]
180
+
181
+ summary_stats = []
182
+
183
+ for idx, (eye_type, diameters) in enumerate(predicted_diameters.items()):
184
+ # Filter out non-numeric values (like "blink")
185
+ numeric_diameters = [d for d in diameters if isinstance(d, (int, float))]
186
+ frame_numbers = list(range(len(diameters)))
187
+
188
+ # Plot diameter over time
189
+ axes[idx].plot(frame_numbers, diameters, marker='o', markersize=2)
190
+ axes[idx].set_title(f"Pupil Diameter Over Time - {eye_type.replace('_', ' ').title()}")
191
+ axes[idx].set_xlabel("Frame Number")
192
+ axes[idx].set_ylabel("Diameter (mm)")
193
+ axes[idx].grid(True, alpha=0.3)
194
+
195
+ # Calculate statistics
196
+ if numeric_diameters:
197
+ mean_diameter = np.mean(numeric_diameters)
198
+ std_diameter = np.std(numeric_diameters)
199
+ min_diameter = np.min(numeric_diameters)
200
+ max_diameter = np.max(numeric_diameters)
201
+
202
+ summary_stats.append(f"{eye_type.replace('_', ' ').title()}:")
203
+ summary_stats.append(f" Mean: {mean_diameter:.2f} mm")
204
+ summary_stats.append(f" Std: {std_diameter:.2f} mm")
205
+ summary_stats.append(f" Min: {min_diameter:.2f} mm")
206
+ summary_stats.append(f" Max: {max_diameter:.2f} mm")
207
+ summary_stats.append("")
208
+
209
+ plt.tight_layout()
210
+
211
+ # Convert plot to image
212
+ buf = io.BytesIO()
213
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
214
+ buf.seek(0)
215
+ plot_img = Image.open(buf)
216
+ plt.close()
217
+
218
+ # Create summary text
219
+ summary_text = f"Processed {len(video_frames)} frames\n\n" + "\n".join(summary_stats)
220
+
221
+ # Create CSV data for download
222
+ csv_data = "Frame,Eye_Type,Diameter_mm\n"
223
+ for eye_type, diameters in predicted_diameters.items():
224
+ for frame_idx, diameter in enumerate(diameters):
225
+ csv_data += f"{frame_idx},{eye_type},{diameter}\n"
226
+
227
+ # Clean up temporary files if they exist
228
+ # (output_video_path not used in this implementation)
229
+
230
+ return plot_img, csv_data, summary_text
231
+
232
+ except Exception as e:
233
+ error_msg = f"Error processing video: {str(e)}"
234
+ error_img = Image.new('RGB', (400, 200), 'white')
235
+ return error_img, "", error_msg
236
+
237
+
238
+ def process_media_unified(media_input, pupil_selection, tv_model, blink_detection):
239
+ """
240
+ Unified processing function that handles both images and videos.
241
+
242
+ Args:
243
+ media_input: Either an image (PIL) or video file path
244
+ pupil_selection: str - "left_pupil", "right_pupil", or "both"
245
+ tv_model: str - "ResNet18" or "ResNet50"
246
+ blink_detection: bool - whether to detect blinks
247
+
248
+ Returns:
249
+ tuple: (result_image, result_text)
250
+ """
251
+ try:
252
+ # Check if input is an image or video
253
+ if hasattr(media_input, 'name'):
254
+ # It's a file object (video)
255
+ file_path = media_input.name
256
+ if is_video(file_path):
257
+ plot_img, csv_data, summary_text = process_video_gradio(media_input, pupil_selection, tv_model, blink_detection)
258
+ combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
259
+ return plot_img, combined_output
260
+ elif is_image(file_path):
261
+ # Convert file to PIL Image
262
+ from PIL import Image
263
+ image = Image.open(file_path)
264
+ return process_image_gradio(image, pupil_selection, tv_model, blink_detection)
265
+ else:
266
+ # It's a PIL Image
267
+ return process_image_gradio(media_input, pupil_selection, tv_model, blink_detection)
268
+
269
+ except Exception as e:
270
+ error_msg = f"Error processing media: {str(e)}"
271
+ from PIL import Image
272
+ error_img = Image.new('RGB', (400, 200), 'white')
273
+ return error_img, error_msg
274
+
275
+
276
+ def create_gradio_interface():
277
+ """Create and configure the Gradio interface with proper API support."""
278
+
279
+ # Create a unified interface that can handle both images and videos
280
+ with gr.Blocks(title="👁️ PupilSense 👁️🕵️‍♂️") as demo:
281
+ gr.Markdown("# 👁️ PupilSense - Pupil Diameter Analysis")
282
+ gr.Markdown("Upload an image or video to estimate pupil diameter using deep learning models.")
283
+
284
+ with gr.Tab("Image Processing"):
285
+ with gr.Row():
286
+ with gr.Column():
287
+ image_input = gr.Image(type="pil", label="Upload Image")
288
+ image_pupil_selection = gr.Dropdown(
289
+ ["left_pupil", "right_pupil", "both"],
290
+ value="both",
291
+ label="Pupil Selection"
292
+ )
293
+ image_model = gr.Dropdown(
294
+ ["ResNet18", "ResNet50"],
295
+ value="ResNet18",
296
+ label="Model"
297
+ )
298
+ image_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
299
+ image_submit = gr.Button("Process Image", variant="primary")
300
+
301
+ with gr.Column():
302
+ image_output = gr.Image(label="Results")
303
+ image_text_output = gr.Textbox(label="Pupil Diameter Results", lines=5)
304
+
305
+ image_submit.click(
306
+ fn=process_image_simple,
307
+ inputs=[image_input, image_pupil_selection, image_model, image_blink_detection],
308
+ outputs=[image_output, image_text_output]
309
+ )
310
+
311
+ with gr.Tab("Video Processing"):
312
+ with gr.Row():
313
+ with gr.Column():
314
+ video_input = gr.Video(label="Upload Video")
315
+ video_pupil_selection = gr.Dropdown(
316
+ ["left_pupil", "right_pupil", "both"],
317
+ value="both",
318
+ label="Pupil Selection"
319
+ )
320
+ video_model = gr.Dropdown(
321
+ ["ResNet18", "ResNet50"],
322
+ value="ResNet18",
323
+ label="Model"
324
+ )
325
+ video_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
326
+ video_submit = gr.Button("Process Video", variant="primary")
327
+
328
+ with gr.Column():
329
+ video_output = gr.Image(label="Diameter Analysis")
330
+ video_text_output = gr.Textbox(label="Summary Statistics", lines=10)
331
+
332
+ video_submit.click(
333
+ fn=process_video_simple,
334
+ inputs=[video_input, video_pupil_selection, video_model, video_blink_detection],
335
+ outputs=[video_output, video_text_output]
336
+ )
337
+
338
+ # Add a unified API endpoint that can handle both images and videos
339
+ with gr.Tab("API Testing"):
340
+ gr.Markdown("### API Endpoint for External Access")
341
+ gr.Markdown("This endpoint can process both images and videos programmatically.")
342
+
343
+ with gr.Row():
344
+ with gr.Column():
345
+ api_media_input = gr.File(label="Upload Image or Video File")
346
+ api_pupil_selection = gr.Dropdown(
347
+ ["left_pupil", "right_pupil", "both"],
348
+ value="both",
349
+ label="Pupil Selection"
350
+ )
351
+ api_model = gr.Dropdown(
352
+ ["ResNet18", "ResNet50"],
353
+ value="ResNet18",
354
+ label="Model"
355
+ )
356
+ api_blink_detection = gr.Checkbox(value=True, label="Detect Blinks")
357
+ api_submit = gr.Button("Process Media", variant="primary")
358
+
359
+ with gr.Column():
360
+ api_output = gr.Image(label="Results")
361
+ api_text_output = gr.Textbox(label="Analysis Results", lines=10)
362
+
363
+ api_submit.click(
364
+ fn=process_media_unified,
365
+ inputs=[api_media_input, api_pupil_selection, api_model, api_blink_detection],
366
+ outputs=[api_output, api_text_output]
367
+ )
368
+
369
+ return demo
370
+
371
+
372
+ def process_image_simple(image, pupil_selection, tv_model, blink_detection):
373
+ """Simplified image processing function for gr.Interface."""
374
+ result_image, result_text = process_image_gradio(image, pupil_selection, tv_model, blink_detection)
375
+ return result_image, result_text
376
+
377
+
378
+ def process_video_simple(video_file, pupil_selection, tv_model, blink_detection):
379
+ """Simplified video processing function for gr.Interface."""
380
+ plot_img, csv_data, summary_text = process_video_gradio(video_file, pupil_selection, tv_model, blink_detection)
381
+ # Combine summary and CSV data for single output
382
+ combined_output = f"{summary_text}\n\n--- CSV Data ---\n{csv_data}"
383
+ return plot_img, combined_output
384
+
385
+
386
+ if __name__ == "__main__":
387
+ demo = create_gradio_interface()
388
+ demo.launch()
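For a quick smoke test without launching the UI, the Gradio handlers can be called directly (a sketch; eye.jpg is a placeholder path and the pre-trained weights must be present under pre_trained_models/):

from PIL import Image
from gradio_app import process_image_simple

img = Image.open("eye.jpg")                   # placeholder path
result_image, result_text = process_image_simple(img, "both", "ResNet18", True)
print(result_text)                            # e.g. "Left Eye: 4.12 mm" / "Right Eye: 4.05 mm" (illustrative values)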
gradio_utils.py ADDED
@@ -0,0 +1,300 @@
1
+ import base64
2
+ from io import BytesIO
3
+ import io
4
+ import os
5
+ import sys
6
+ import cv2
7
+ from matplotlib import pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ import tempfile
11
+ from PIL import Image
12
+ from torchvision.transforms.functional import to_pil_image
13
+ from torchvision import transforms
14
+ from PIL import ImageOps
15
+ import os.path as osp
16
+
17
+ from torchcam.methods import CAM
18
+ from torchcam import methods as torchcam_methods
19
+ from torchcam.utils import overlay_mask
20
+
21
+ root_path = osp.abspath(osp.join(__file__, osp.pardir))
22
+ sys.path.append(root_path)
23
+
24
+ from preprocessing.dataset_creation import EyeDentityDatasetCreation
25
+ from utils import get_model
26
+
27
+ CAM_METHODS = ["CAM"]
28
+
29
+
30
+ @torch.no_grad()
31
+ def load_model(model_configs, device="cpu"):
32
+ """Loads the pre-trained model."""
33
+ model_path = os.path.join(root_path, model_configs["model_path"])
34
+ model_dict = torch.load(model_path, map_location=device)
35
+ model = get_model(model_configs=model_configs)
36
+ model.load_state_dict(model_dict)
37
+ model = model.to(device).eval()
38
+ return model
39
+
40
+
41
+ def extract_frames(video_path):
42
+ """Extracts frames from a video file."""
43
+ vidcap = cv2.VideoCapture(video_path)
44
+ frames = []
45
+ success, image = vidcap.read()
46
+ while success:
47
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
48
+ frames.append(image_rgb)
49
+ success, image = vidcap.read()
50
+ vidcap.release()
51
+ return frames
52
+
53
+
54
+ def resize_frame(frame, max_width=640, max_height=480):
55
+ """Resizes a frame while maintaining aspect ratio."""
56
+ if isinstance(frame, np.ndarray):
57
+ frame = Image.fromarray(frame)
58
+
59
+ # Calculate the scaling factor
60
+ width, height = frame.size
61
+ scale_w = max_width / width
62
+ scale_h = max_height / height
63
+ scale = min(scale_w, scale_h)
64
+
65
+ # Resize the frame
66
+ new_width = int(width * scale)
67
+ new_height = int(height * scale)
68
+ return frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
69
+
70
+
71
+ def is_image(file_extension):
72
+ """Check if file extension is an image format."""
73
+ return file_extension.lower() in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]
74
+
75
+
76
+ def is_video(file_extension):
77
+ """Check if file extension is a video format."""
78
+ return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm", "flv", "wmv"]
79
+
80
+
81
+ def get_configs(blink_detection=False):
82
+ """Get configuration for feature extraction."""
83
+ upscale = "-"
84
+ upscale_method_or_model = "-"
85
+ if upscale == "-":
86
+ sr_configs = None
87
+ else:
88
+ sr_configs = {
89
+ "method": upscale_method_or_model,
90
+ "params": {"upscale": upscale},
91
+ }
92
+ config_file = {
93
+ "sr_configs": sr_configs,
94
+ "feature_extraction_configs": {
95
+ "blink_detection": blink_detection,
96
+ "upscale": upscale,
97
+ "extraction_library": "mediapipe",
98
+ },
99
+ }
100
+ return config_file
101
+
102
+
103
+ def setup_gradio(pupil_selection, tv_model):
104
+ """Setup models and data structures for Gradio processing."""
105
+ left_pupil_model = None
106
+ left_pupil_cam_extractor = None
107
+ right_pupil_model = None
108
+ right_pupil_cam_extractor = None
109
+ output_frames = {}
110
+ input_frames = {}
111
+ predicted_diameters = {}
112
+
113
+ if pupil_selection == "both":
114
+ selected_eyes = ["left_eye", "right_eye"]
115
+ elif pupil_selection == "left_pupil":
116
+ selected_eyes = ["left_eye"]
117
+ elif pupil_selection == "right_pupil":
118
+ selected_eyes = ["right_eye"]
119
+
120
+ for eye_type in selected_eyes:
121
+ model_configs = {
122
+ "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
123
+ "registered_model_name": tv_model,
124
+ "num_classes": 1,
125
+ }
126
+ if eye_type == "left_eye":
127
+ left_pupil_model = load_model(model_configs)
128
+ left_pupil_cam_extractor = None
129
+ else:
130
+ right_pupil_model = load_model(model_configs)
131
+ right_pupil_cam_extractor = None
132
+
133
+ output_frames[eye_type] = []
134
+ input_frames[eye_type] = []
135
+ predicted_diameters[eye_type] = []
136
+
137
+ return (
138
+ selected_eyes,
139
+ input_frames,
140
+ output_frames,
141
+ predicted_diameters,
142
+ left_pupil_model,
143
+ left_pupil_cam_extractor,
144
+ right_pupil_model,
145
+ right_pupil_cam_extractor,
146
+ )
147
+
148
+
149
+ def process_frames_gradio(input_imgs, tv_model, pupil_selection, blink_detection=False):
150
+ """
151
+ Process frames without Streamlit dependencies.
152
+ """
153
+ try:
154
+ config_file = get_configs(blink_detection)
155
+
156
+ (
157
+ selected_eyes,
158
+ input_frames,
159
+ output_frames,
160
+ predicted_diameters,
161
+ left_pupil_model,
162
+ left_pupil_cam_extractor,
163
+ right_pupil_model,
164
+ right_pupil_cam_extractor,
165
+ ) = setup_gradio(pupil_selection, tv_model)
166
+
167
+ ds_creation = EyeDentityDatasetCreation(
168
+ feature_extraction_configs=config_file["feature_extraction_configs"],
169
+ sr_configs=config_file["sr_configs"],
170
+ )
171
+ except Exception as e:
172
+ print(f"Error in setup: {e}")
173
+ # Return empty results if setup fails
174
+ return {}, {}, {}
175
+
176
+ preprocess_steps = [
177
+ transforms.Resize(
178
+ [32, 64],
179
+ interpolation=transforms.InterpolationMode.BICUBIC,
180
+ antialias=True,
181
+ ),
182
+ transforms.ToTensor(),
183
+ ]
184
+ preprocess_function = transforms.Compose(preprocess_steps)
185
+
186
+ for idx, input_img in enumerate(input_imgs):
187
+ try:
188
+ img = np.array(input_img)
189
+ ds_results = ds_creation(img)
190
+ except Exception as e:
191
+ print(f"Error in MediaPipe processing for frame {idx}: {e}")
192
+ ds_results = None
193
+
194
+ left_eye = None
195
+ right_eye = None
196
+ blinked = False
197
+
198
+ if ds_results is not None and "face" in ds_results:
199
+ has_face = True
200
+ else:
201
+ has_face = False
202
+
203
+ if has_face and ds_results is not None:
204
+ if blink_detection and "blinks" in ds_results:
205
+ blinked = ds_results["blinks"]["blinked"]
206
+
207
+ if not blinked and "eyes" in ds_results:
208
+ if "left_eye" in ds_results["eyes"] and ds_results["eyes"]["left_eye"] is not None:
209
+ left_eye_img = to_pil_image(ds_results["eyes"]["left_eye"])
210
+ input_img_tensor = preprocess_function(left_eye_img)
211
+ input_img_tensor = input_img_tensor.unsqueeze(0)
212
+ if pupil_selection in ["left_pupil", "both"]:
213
+ left_eye = input_img_tensor
214
+
215
+ if "right_eye" in ds_results["eyes"] and ds_results["eyes"]["right_eye"] is not None:
216
+ right_eye_img = to_pil_image(ds_results["eyes"]["right_eye"])
217
+ input_img_tensor = preprocess_function(right_eye_img)
218
+ input_img_tensor = input_img_tensor.unsqueeze(0)
219
+ if pupil_selection in ["right_pupil", "both"]:
220
+ right_eye = input_img_tensor
221
+
222
+ for eye_type in selected_eyes:
223
+ if blinked:
224
+ if left_eye is not None and eye_type == "left_eye":
225
+ _, height, width = left_eye.squeeze(0).shape
226
+ input_image_pil = to_pil_image(left_eye.squeeze(0))
227
+ elif right_eye is not None and eye_type == "right_eye":
228
+ _, height, width = right_eye.squeeze(0).shape
229
+ input_image_pil = to_pil_image(right_eye.squeeze(0))
230
+ else:
231
+ # Create a default black image if no eye detected
232
+ input_image_pil = Image.new('RGB', (64, 32), 'black')
233
+ height, width = 32, 64
234
+
235
+ input_img_np = np.array(input_image_pil)
236
+ zeros_img = to_pil_image(np.zeros((height, width, 3), dtype=np.uint8))
237
+ output_img_np = np.array(zeros_img)
238
+ predicted_diameter = "blink"
239
+ else:
240
+ if left_eye is not None and eye_type == "left_eye":
241
+ if left_pupil_cam_extractor is None:
242
+ if tv_model == "ResNet18":
243
+ target_layer = left_pupil_model.resnet.layer4[-1].conv2
244
+ elif tv_model == "ResNet50":
245
+ target_layer = left_pupil_model.resnet.layer4[-1].conv3
246
+ else:
247
+ raise Exception(f"No target layer available for selected model: {tv_model}")
248
+ left_pupil_cam_extractor = torchcam_methods.__dict__["CAM"](
249
+ left_pupil_model,
250
+ target_layer=target_layer,
251
+ fc_layer=left_pupil_model.resnet.fc,
252
+ input_shape=left_eye.shape,
253
+ )
254
+ output = left_pupil_model(left_eye)
255
+ predicted_diameter = output[0].item()
256
+ act_maps = left_pupil_cam_extractor(0, output)
257
+ activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
258
+ input_image_pil = to_pil_image(left_eye.squeeze(0))
259
+ elif right_eye is not None and eye_type == "right_eye":
260
+ if right_pupil_cam_extractor is None:
261
+ if tv_model == "ResNet18":
262
+ target_layer = right_pupil_model.resnet.layer4[-1].conv2
263
+ elif tv_model == "ResNet50":
264
+ target_layer = right_pupil_model.resnet.layer4[-1].conv3
265
+ else:
266
+ raise Exception(f"No target layer available for selected model: {tv_model}")
267
+ right_pupil_cam_extractor = torchcam_methods.__dict__["CAM"](
268
+ right_pupil_model,
269
+ target_layer=target_layer,
270
+ fc_layer=right_pupil_model.resnet.fc,
271
+ input_shape=right_eye.shape,
272
+ )
273
+ output = right_pupil_model(right_eye)
274
+ predicted_diameter = output[0].item()
275
+ act_maps = right_pupil_cam_extractor(0, output)
276
+ activation_map = (
277
+ act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
278
+ )
279
+ input_image_pil = to_pil_image(right_eye.squeeze(0))
280
+ else:
281
+ # No eye detected, create default values
282
+ input_image_pil = Image.new('RGB', (64, 32), 'black')
283
+ predicted_diameter = "no_eye_detected"
284
+ output_img_np = np.array(input_image_pil)
285
+ input_frames[eye_type].append(np.array(input_image_pil))
286
+ output_frames[eye_type].append(output_img_np)
287
+ predicted_diameters[eye_type].append(predicted_diameter)
288
+ continue
289
+
290
+ # Create CAM overlay
291
+ activation_map_pil = to_pil_image(activation_map, mode="F")
292
+ result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
293
+ input_img_np = np.array(input_image_pil)
294
+ output_img_np = np.array(result)
295
+
296
+ input_frames[eye_type].append(input_img_np)
297
+ output_frames[eye_type].append(output_img_np)
298
+ predicted_diameters[eye_type].append(predicted_diameter)
299
+
300
+ return input_frames, output_frames, predicted_diameters
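A sketch of running this helper over one of the bundled sample videos (the weights under pre_trained_models/ must be available, and the registration modules have to be imported first so the model registry is populated):

from registry_utils import import_registered_modules
from gradio_utils import extract_frames, resize_frame, process_frames_gradio

import_registered_modules()

frames = [resize_frame(f) for f in extract_frames("sample_videos/Focus Pocus.webm")]
inputs, overlays, diameters = process_frames_gradio(
    input_imgs=frames,
    tv_model="ResNet18",
    pupil_selection="both",
    blink_detection=True,
)

left = [d for d in diameters.get("left_eye", []) if isinstance(d, (int, float))]
if left:
    print(f"mean left pupil diameter: {sum(left) / len(left):.2f} mm")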
pre_trained_models/ResNet18/left_eye.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98fb2c7880165c59ff975e02cd9e614fcf3a5859455f8d85695f57497dd894e6
3
+ size 46843194
pre_trained_models/ResNet18/right_eye.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68e2928f13900580bcb9b7c1a1f6d4bba863cfcfee2def944b49ef0c09337668
3
+ size 46843194
pre_trained_models/ResNet50/left_eye.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bd4bac728b71dae9e759b86188206a4f38fbc83b9507dd08f2a6abe1568d995
3
+ size 102554624
pre_trained_models/ResNet50/right_eye.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5179f569ea1886c9ad63ca9d047fdf721a9b59a63313cd9da3f2e3fae25de73
3
+ size 102554624
preprocessing/dataset_creation.py ADDED
@@ -0,0 +1,26 @@
1
+ import sys
2
+ import cv2
3
+ import os.path as osp
4
+
5
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
6
+ sys.path.append(root_path)
7
+
8
+ from feature_extraction.features_extractor import FeaturesExtractor
9
+
10
+
11
+ class EyeDentityDatasetCreation:
12
+
13
+ def __init__(self, feature_extraction_configs, sr_configs=None):
14
+ self.extraction_library = feature_extraction_configs["extraction_library"]
15
+ self.upscale = 1
16
+
17
+ self.blink_detection = feature_extraction_configs["blink_detection"]
18
+ self.features_extractor = FeaturesExtractor(
19
+ extraction_library=self.extraction_library,
20
+ blink_detection=self.blink_detection,
21
+ upscale=self.upscale,
22
+ )
23
+
24
+ def __call__(self, img):
25
+ result_dict = self.features_extractor(img)
26
+ return result_dict
preprocessing/dataset_creation_utils.py ADDED
@@ -0,0 +1,14 @@
1
+ import os
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+
6
+
7
+ def seed_everything(seed=42):
8
+ random.seed(seed)
9
+ os.environ["PYTHONHASHSEED"] = str(seed)
10
+ np.random.seed(seed)
11
+ torch.manual_seed(seed)
12
+ torch.cuda.manual_seed(seed)
13
+ torch.backends.cudnn.benchmark = True
14
+ torch.backends.cudnn.deterministic = True
registrations/models.py ADDED
@@ -0,0 +1,56 @@
1
+ import sys
2
+ import torch.nn as nn
3
+ import os.path as osp
4
+ from torchvision import models
5
+ import torch.nn.functional as F
6
+ from registry import MODEL_REGISTRY
7
+
8
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
9
+ sys.path.append(root_path)
10
+
11
+ # ============================= ResNets =============================
12
+
13
+
14
+ @MODEL_REGISTRY.register()
15
+ class ResNet18(nn.Module):
16
+ def __init__(self, model_args):
17
+ super(ResNet18, self).__init__()
18
+ self.num_classes = model_args.get("num_classes", 1)
19
+ self.resnet = models.resnet18(weights=None)
20
+ self.regression_head = nn.Linear(1000, self.num_classes)
21
+
22
+ def forward(self, x, masks=None):
23
+ # Calculate the padding dynamically based on the input size
24
+ height, width = x.shape[2], x.shape[3]
25
+ pad_height = max(0, (224 - height) // 2)
26
+ pad_width = max(0, (224 - width) // 2)
27
+
28
+ # Apply padding
29
+ x = F.pad(x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0)
30
+ x = self.resnet(x)
31
+ x = self.regression_head(x)
32
+ return x
33
+
34
+
35
+ @MODEL_REGISTRY.register()
36
+ class ResNet50(nn.Module):
37
+ def __init__(self, model_args):
38
+ super(ResNet50, self).__init__()
39
+ self.num_classes = model_args.get("num_classes", 1)
40
+ self.resnet = models.resnet50(weights=None)
41
+ self.regression_head = nn.Linear(1000, self.num_classes)
42
+
43
+ def forward(self, x, masks=None):
44
+ # Calculate the padding dynamically based on the input size
45
+ height, width = x.shape[2], x.shape[3]
46
+ pad_height = max(0, (224 - height) // 2)
47
+ pad_width = max(0, (224 - width) // 2)
48
+
49
+ # Apply padding
50
+ x = F.pad(x, (pad_width, pad_width, pad_height, pad_height), mode="constant", value=0)
51
+ x = self.resnet(x)
52
+ x = self.regression_head(x)
53
+ return x
54
+
55
+
56
+ # print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
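A small shape sanity check (a sketch): both models pad the 32x64 eye crops used by the apps up to 224x224 before the ResNet backbone, so one diameter value per image comes out of the regression head.

import torch
from registrations.models import ResNet18

model = ResNet18({"num_classes": 1}).eval()
dummy = torch.randn(1, 3, 32, 64)             # eye-crop resolution used in the apps
with torch.no_grad():
    out = model(dummy)
print(out.shape)                              # torch.Size([1, 1])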
registry.py ADDED
@@ -0,0 +1,82 @@
1
+ # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2
+
3
+
4
+ class Registry:
5
+ """
6
+ The registry that provides name -> object mapping, to support third-party
7
+ users' custom modules.
8
+
9
+ To create a registry (e.g. a backbone registry):
10
+
11
+ .. code-block:: python
12
+
13
+ BACKBONE_REGISTRY = Registry('BACKBONE')
14
+
15
+ To register an object:
16
+
17
+ .. code-block:: python
18
+
19
+ @BACKBONE_REGISTRY.register()
20
+ class MyBackbone():
21
+ ...
22
+
23
+ Or:
24
+
25
+ .. code-block:: python
26
+
27
+ BACKBONE_REGISTRY.register(MyBackbone)
28
+ """
29
+
30
+ def __init__(self, name):
31
+ """
32
+ Args:
33
+ name (str): the name of this registry
34
+ """
35
+ self._name = name
36
+ self._obj_map = {}
37
+
38
+ def _do_register(self, name, obj):
39
+ assert name not in self._obj_map, (
40
+ f"An object named '{name}' was already registered "
41
+ f"in '{self._name}' registry!"
42
+ )
43
+ self._obj_map[name] = obj
44
+
45
+ def register(self, obj=None):
46
+ """
47
+ Register the given object under the name `obj.__name__`.
48
+ Can be used as either a decorator or not.
49
+ See docstring of this class for usage.
50
+ """
51
+ if obj is None:
52
+ # used as a decorator
53
+ def deco(func_or_class):
54
+ name = func_or_class.__name__
55
+ self._do_register(name, func_or_class)
56
+ return func_or_class
57
+
58
+ return deco
59
+
60
+ # used as a function call
61
+ name = obj.__name__
62
+ self._do_register(name, obj)
63
+
64
+ def get(self, name):
65
+ ret = self._obj_map.get(name)
66
+ if ret is None:
67
+ raise KeyError(
68
+ f"No object named '{name}' found in '{self._name}' registry!"
69
+ )
70
+ return ret
71
+
72
+ def __contains__(self, name):
73
+ return name in self._obj_map
74
+
75
+ def __iter__(self):
76
+ return iter(self._obj_map.items())
77
+
78
+ def keys(self):
79
+ return self._obj_map.keys()
80
+
81
+
82
+ MODEL_REGISTRY = Registry("model")
registry_utils.py ADDED
@@ -0,0 +1,79 @@
1
+ import os
2
+ import importlib
3
+ from os import path as osp
4
+
5
+
6
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
7
+ """Scan a directory to find the interested files.
8
+
9
+ Args:
10
+ dir_path (str): Path of the directory.
11
+ suffix (str | tuple(str), optional): File suffix that we are
12
+ interested in. Default: None.
13
+ recursive (bool, optional): If set to True, recursively scan the
14
+ directory. Default: False.
15
+ full_path (bool, optional): If set to True, include the dir_path.
16
+ Default: False.
17
+
18
+ Returns:
19
+ A generator for all the interested files with relative paths.
20
+ """
21
+
22
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
23
+ raise TypeError('"suffix" must be a string or tuple of strings')
24
+
25
+ root = dir_path
26
+
27
+ def _scandir(dir_path, suffix, recursive):
28
+ for entry in os.scandir(dir_path):
29
+ if not entry.name.startswith(".") and entry.is_file():
30
+ if full_path:
31
+ return_path = entry.path
32
+ else:
33
+ return_path = osp.relpath(entry.path, root)
34
+
35
+ if suffix is None:
36
+ yield return_path
37
+ elif return_path.endswith(suffix):
38
+ yield return_path
39
+ else:
40
+ if recursive:
41
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
42
+ else:
43
+ continue
44
+
45
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
46
+
47
+
48
+ def import_registered_modules(registration_folder="registrations"):
49
+ """
50
+ Import all registered modules from the specified folder.
51
+
52
+ This function automatically scans all the files under the specified folder and imports all the required modules for registry.
53
+
54
+ Parameters:
55
+ registration_folder (str, optional): Path to the folder containing registration modules. Default is "registrations".
56
+
57
+ Returns:
58
+ list: List of imported modules.
59
+ """
60
+
61
+ # print("\n")
62
+
63
+ registration_modules_folder = (
64
+ osp.dirname(osp.abspath(__file__)) + f"/{registration_folder}"
65
+ )
66
+ # print("registration_modules_folder = ", registration_modules_folder)
67
+
68
+ registration_modules_file_names = [
69
+ osp.splitext(osp.basename(v))[0]
70
+ for v in scandir(dir_path=registration_modules_folder)
71
+ ]
72
+ # print("registration_modules_file_names = ", registration_modules_file_names)
73
+
74
+ imported_modules = [
75
+ importlib.import_module(f"{registration_folder}.{file_name}")
76
+ for file_name in registration_modules_file_names
77
+ ]
78
+ # print("imported_modules = ", imported_modules)
79
+ # print("\n")
+ return imported_modules
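Once the registration modules have been imported, the registry can be inspected (a sketch):

from registry import MODEL_REGISTRY
from registry_utils import import_registered_modules

import_registered_modules()                   # imports registrations/models.py, which registers the models
print(list(MODEL_REGISTRY.keys()))            # ['ResNet18', 'ResNet50']
print(MODEL_REGISTRY.get("ResNet18"))         # <class 'registrations.models.ResNet18'>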
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ # https://huggingface.co/docs/hub/en/spaces-dependencies
2
+ tqdm
3
+ PyYAML
4
+ numpy
5
+ pandas
6
+ matplotlib
7
+ seaborn
8
+ mlflow
9
+ pillow
10
+ scikit_learn
11
+ torch
12
+ # captum
13
+ evaluate
14
+ # basicsr
15
+ facexlib
16
+ # realesrgan
17
+ opencv_python
18
+ cmake
19
+ # dlib
20
+ einops
21
+ transformers
22
+ # gfpgan
23
+ gradio==4.36.1
24
+ mediapipe
25
+ imutils
26
+ scipy
27
+ torchvision
28
+ torchcam
sample_videos/All Smiles Ahead.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcb016898fac06a517f067ccbe1e6a32366e984aa9b07a7920a8bc9fdd780d17
3
+ size 951586
sample_videos/And it was all Yellow.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfd4d681a4a289aa24fc2193fbcbf855e69c6ce19b9619d8d88b7d782c6047dc
3
+ size 956742
sample_videos/Blink It Like Brian.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:590a74b0f13415bb7817c7e28f3f3348bb38aa1f517b8e311f8041c978b4c38b
3
+ size 961928
sample_videos/Focus Pocus.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf88c7cff8ef627c03c589ca5feb9c83e95db7e7b55294cbcbafefdfe31cdcf6
3
+ size 971924
sample_videos/Funny Talks.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb22539b023294865bc0df2c6c6622ed94d3d5d27df6d0a22ae8fb193c2d6910
3
+ size 963970
sample_videos/I like to move it move it.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a4d66abf1707607f826b4b65e20f920d90281e6ab0df5757b57da7216f424b3
3
+ size 958302
sample_videos/Infinite Blue.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6047ecda02535212c55714b727c32777a4891be91f642847c586bb24bd3d00d
3
+ size 960117
sample_videos/Red Ross.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1078b3fda618c4f5adefe05d389c8360d2a95af1080be7f699a0d72ca454bb3f
3
+ size 960661
sample_videos/Smile, You’re on Camera!.webm ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8757fa6781889323f9d577862dcc98f05abd295be6de8e8843a3eb1cd406fdd
3
+ size 965710
utils.py ADDED
@@ -0,0 +1,11 @@
1
+ from registry import MODEL_REGISTRY
2
+
3
+
4
+ def get_model(model_configs):
5
+ registered_model = MODEL_REGISTRY.get(model_configs["registered_model_name"])
6
+ model_configs.pop("registered_model_name")
7
+ if len(model_configs) > 0:
8
+ model = registered_model(model_configs)
9
+ else:
10
+ model = registered_model()
11
+ return model
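A sketch of restoring a released checkpoint through get_model, mirroring gradio_utils.load_model (requires the LFS weight files to be pulled; note that get_model pops registered_model_name and forwards the remaining dict to the model constructor):

import torch
from registry_utils import import_registered_modules
from utils import get_model

import_registered_modules()

model = get_model({"registered_model_name": "ResNet18", "num_classes": 1})
state_dict = torch.load("pre_trained_models/ResNet18/left_eye.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()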