GilbertKrantz commited on
Commit
61c2d3f
·
0 Parent(s):

Initial Commit

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ weights/efficientvit.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+ # Data
177
+ Data/
178
+
179
+ EDA/
180
+
181
+ # Model Outputs
182
+ model_outputs/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12.9
DockerFile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ git \
8
+ libgl1-mesa-glx \
9
+ libglib2.0-0 \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Copy requirements first to leverage Docker cache
13
+ COPY requirements.txt .
14
+
15
+ # Install Python dependencies
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy project files
19
+ COPY . .
20
+
21
+ # Make the gradio inference script executable
22
+ RUN chmod +x gradio_inference.py
23
+
24
+ # Expose port for Gradio
25
+ EXPOSE 7860
26
+
27
+ # Set the entrypoint command
28
+ CMD ["python", "gradio_inference.py"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eye Disease Detection
2
+
3
+ This repository contains a Gradio web application for eye disease detection using deep learning models. The application allows users to upload fundus images and get predictions for common eye conditions.
4
+
5
+ ## Features
6
+
7
+ - **Easy-to-use web interface** for eye disease detection
8
+ - Support for **multiple model architectures** (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
9
+ - **Custom model loading** from saved model checkpoints
10
+ - **Visualization** of prediction probabilities
11
+ - **Dockerized deployment** option
12
+
13
+ ## Supported Eye Conditions
14
+
15
+ The system can detect the following eye conditions:
16
+ - Central Serous Chorioretinopathy
17
+ - Diabetic Retinopathy
18
+ - Disc Edema
19
+ - Glaucoma
20
+ - Healthy (normal eye)
21
+ - Macular Scar
22
+ - Myopia
23
+ - Retinal Detachment
24
+ - Retinitis Pigmentosa
25
+
26
+ ## Installation
27
+
28
+ ### Prerequisites
29
+
30
+ - Python 3.12+
31
+ - PyTorch 2.7.0+
32
+ - CUDA-compatible GPU (optional, but recommended for faster inference)
33
+
34
+ ### Option 1: Local Installation
35
+
36
+ 1. Clone this repository:
37
+ ```bash
38
+ git clone https://github.com/GilbertKrantz/eye-disease-detection.git
39
+ cd eye-disease-detection
40
+ ```
41
+
42
+ 2. Install the required packages:
43
+ ```bash
44
+ pip install -r requirements.txt
45
+ ```
46
+
47
+ 3. Run the application:
48
+ ```bash
49
+ python gradio-inference.py
50
+ ```
51
+
52
+ 4. Open your browser and go to http://localhost:7860
53
+
54
+ ### Option 2: Docker Installation
55
+
56
+ 1. Build the Docker image:
57
+ ```bash
58
+ docker build -t eye-disease-detection .
59
+ ```
60
+
61
+ 2. Run the container:
62
+ ```bash
63
+ docker run -p 7860:7860 eye-disease-detection
64
+ ```
65
+
66
+ 3. Open your browser and go to http://localhost:7860
67
+
68
+ ## Usage
69
+
70
+ 1. Upload a fundus image of the eye
71
+ 2. (Optional) Specify the path to your trained model file (.pth)
72
+ 3. Select the model architecture (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
73
+ 4. Click "Analyze Image" to get the prediction
74
+ 5. View the results and probability distribution
75
+
76
+ ## Model Training
77
+
78
+ This repository focuses on inference. For training your own models, refer to the main training script and follow these steps:
79
+
80
+ 1. Prepare your dataset in the required directory structure
81
+ 2. Train a model using the main.py script:
82
+ ```bash
83
+ python main.py --train-dir "/path/to/training/data" --eval-dir "/path/to/eval/data" --model mobilenetv4 --epochs 20 --save-model "my_model.pth"
84
+ ```
85
+ 3. Use the saved model with the inference application
86
+
87
+ ## Project Structure
88
+
89
+ ```
90
+ .
91
+ ├── gradio-inference.py # Main Gradio application
92
+ ├── requirements.txt # Python dependencies
93
+ ├── DockerFile # Docker configuration
94
+ ├── README.md # This documentation
95
+ ├── utils/ # Utility modules
96
+ │ ├── ModelCreator.py # Model architecture definitions
97
+ │ ├── Evaluator.py # Model evaluation utilities
98
+ │ ├── DatasetHandler.py # Dataset handling utilities
99
+ │ ├── Trainer.py # Model training utilities
100
+ │ └── Callback.py # Training callbacks
101
+ └── main.py # Main training script
102
+ ```
103
+
104
+ ## Performance
105
+
106
+ The performance of the models depends on the quality of training data and the specific architecture used. In general, these models can achieve accuracy rates of 85-95% on standard eye disease datasets.
107
+
108
+ ## Customization
109
+
110
+ You can customize the application in several ways:
111
+ - Add example images in the Gradio interface
112
+ - Extend the list of supported classes by modifying the CLASSES variable in gradio-inference.py
113
+ - Add support for additional model architectures in ModelCreator.py
114
+
115
+ ## License
116
+
117
+ This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
118
+
119
+ ## Acknowledgments
120
+
121
+ - The models are built using PyTorch and the TIMM library
122
+ - The web interface is built using Gradio
123
+ - Special thanks to the open-source community for making this project possible
gradio-inference.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Eye Disease Detection - Gradio Inference App
3
+ # Date: May 11, 2025
4
+
5
+ import os
6
+ import sys
7
+ import torch
8
+ import numpy as np
9
+ import gradio as gr
10
+ from PIL import Image
11
+ from torchvision import transforms
12
+ import logging
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+ # Import custom modules
17
+ sys.path.append("./utils")
18
+ from ModelCreator import EyeDetectionModels
19
+
20
+ # Set device
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ print(f"Using device: {device}")
23
+
24
+ # Define class names (make sure these match your model's classes)
25
+ CLASSES = [
26
+ "Central Serous Chorioretinopathy",
27
+ "Diabetic Retinopathy",
28
+ "Disc Edema",
29
+ "Glaucoma",
30
+ "Healthy",
31
+ "Macular Scar",
32
+ "Myopia",
33
+ "Retinal Detachment",
34
+ "Retinitis Pigmentosa",
35
+ ]
36
+
37
+
38
def get_transform():
    """Build the preprocessing pipeline used at inference time.

    Resizes the short side to 256, center-crops to 224x224, converts to a
    tensor, and normalizes with the standard ImageNet mean/std statistics.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
48
+
49
+
50
def load_model(model_path, model_type="efficientvit"):
    """
    Load a model for inference.

    Args:
        model_path: Path to a saved state dict (.pth). If None or empty,
            falls back to ./weights/<model_type>.pth when present, otherwise
            an untrained model is returned.
        model_type: Architecture key (mobilenetv4, levit, efficientvit,
            gernet, regnetx).

    Returns:
        Model in eval() mode, moved to the module-level `device`.

    Raises:
        ValueError: If model_type is not a known architecture.
        FileNotFoundError: If an explicitly given model_path does not exist.
    """
    logging.info("Initializing model creator...")
    # freeze_layers is a training-time option; irrelevant for inference.
    model_creator = EyeDetectionModels(num_classes=len(CLASSES), freeze_layers=False)

    # Validate the architecture key before constructing anything.
    if model_type not in model_creator.models:
        raise ValueError(
            f"Model type '{model_type}' not found. Available models: {list(model_creator.models.keys())}"
        )

    logging.info(f"Creating model of type: {model_type}")
    model = model_creator.models[model_type]()

    if model_path and not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path '{model_path}' does not exist.")
    elif not model_path:  # None or "" (Gradio's empty textbox) -> use default
        default_path = f"./weights/{model_type}.pth"
        if os.path.exists(default_path):
            model_path = default_path
        else:
            # BUG FIX: the message was previously logged after model_path was
            # reset, so it printed "Default model path 'None' not found".
            logging.warning(
                f"Default model path '{default_path}' not found. Using untrained model."
            )
            model_path = None

    # BUG FIX: the original never loaded the weights, so even a valid
    # model_path produced an untrained (randomly initialized) model.
    if model_path:
        logging.info(f"Loading weights from: {model_path}")
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)

    # BUG FIX: move the model to the same device predict_image puts the
    # input tensor on; previously the model stayed on CPU and inference
    # crashed on CUDA machines.
    model = model.to(device)
    model.eval()  # disable dropout / freeze batch-norm statistics
    return model
92
+
93
+
94
def predict_image(image, model_path, model_type):
    """
    Predict eye disease probabilities from an uploaded fundus image.

    Args:
        image: Image as a numpy array (H, W, C) from the Gradio component,
            or None when nothing was uploaded.
        model_path: Path to the model state dict ("" or None uses the default).
        model_type: Architecture key understood by load_model.

    Returns:
        Dict mapping each class name in CLASSES to its softmax probability.
        On error or missing image, every class maps to 0.0 so the UI stays
        responsive instead of crashing.
    """
    try:
        # BUG FIX: validate the input first (the original checked for None
        # twice, once after the expensive model load and once redundantly).
        if image is None:
            logging.warning("No image provided.")
            return {cls: 0.0 for cls in CLASSES}

        logging.info("Starting prediction...")
        model = load_model(model_path, model_type)

        logging.info("Preprocessing image...")
        transform = get_transform()
        img = Image.fromarray(image).convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)
        logging.info("Image preprocessed successfully.")

        # Belt-and-braces: ensure the model and the input share a device.
        model = model.to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()

        return {cls: float(prob) for cls, prob in zip(CLASSES, probabilities)}

    except Exception:
        # Log the full traceback (instead of printing to stderr) and return
        # a zeroed distribution, preserving the original best-effort behavior.
        logging.exception("Prediction failed")
        return {cls: 0.0 for cls in CLASSES}
139
+
140
+
141
def main():
    """Build and launch the Gradio interface for eye disease detection."""
    # Architecture keys understood by load_model / EyeDetectionModels.
    model_types = ["mobilenetv4", "levit", "efficientvit", "gernet", "regnetx"]

    # Create the Gradio interface
    with gr.Blocks(title="Eye Disease Detection") as demo:
        gr.Markdown("# Eye Disease Detection System")
        # BUG FIX: the old text advertised "Cataract, Diabetic Retinopathy,
        # Glaucoma, and Normal eyes", which contradicts the CLASSES list the
        # models actually predict; derive the description from CLASSES.
        gr.Markdown(
            "This application uses deep learning to detect eye diseases from fundus images.\n"
            "Currently supports detection of: "
            + ", ".join(CLASSES[:-1])
            + f", and {CLASSES[-1]}."
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload Fundus Image", type="numpy")
                model_path = gr.Textbox(
                    label="Model Path (leave empty to use default)",
                    placeholder="Path to model .pth file",
                    value="",
                )
                model_type = gr.Dropdown(
                    label="Model Architecture", choices=model_types, value="mobilenetv4"
                )
                submit_btn = gr.Button("Analyze Image", variant="primary")

            with gr.Column():
                output_chart = gr.Label(label="Prediction")

        # Process the image when the button is clicked
        submit_btn.click(
            fn=predict_image,
            inputs=[input_image, model_path, model_type],
            outputs=output_chart,
        )

        # Examples section
        gr.Markdown("### Examples (Please add your own example images)")
        gr.Examples(
            examples=[],  # Add example paths here
            # BUG FIX: inputs must match predict_image's 3-argument signature;
            # caching is disabled because there are no examples to cache yet.
            inputs=[input_image, model_path, model_type],
            outputs=[output_chart],
            fn=predict_image,
            cache_examples=False,
        )

        # Usage instructions
        with gr.Accordion("Usage Instructions", open=False):
            gr.Markdown(
                """
                ## How to use this application:

                1. **Upload an image**: Click the upload button to select a fundus image from your computer
                2. **Specify model** (Optional):
                - Enter the path to your trained model file (.pth)
                - Select the model architecture that was used for training
                3. **Analyze**: Click the "Analyze Image" button to get results
                4. **Interpret results**: The system will show the detected condition and probability distribution

                ## Model Information:

                This system supports multiple model architectures:
                - **MobileNetV4**: Lightweight and efficient model
                - **LeViT**: Vision Transformer designed for efficiency
                - **EfficientViT**: Hybrid CNN-Transformer architecture
                - **GENet**: General and Efficient Network
                - **RegNetX**: Systematically designed CNN architecture

                For best results, ensure you're using a high-quality fundus image and the correct model type.
                """
            )

    # Launch the app (share=True creates a public tunnel; pwa enables
    # install-as-app support in the browser).
    demo.launch(
        share=True,
        pwa=True,
    )


if __name__ == "__main__":
    main()
main.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Eye Disease Detection - Main Application
3
+ # Date: May 11, 2025
4
+
5
+ import os
6
+ import sys
7
+ import argparse
8
+ import random
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import matplotlib.pyplot as plt
13
+ from torchvision import transforms, datasets
14
+ from torch.utils.data import DataLoader, random_split
15
+
16
+ # Import custom modules
17
+ sys.path.append("./utils")
18
+ from ModelCreator import EyeDetectionModels
19
+ from DatasetHandler import FilteredImageDataset
20
+ from Evaluator import ClassificationEvaluator
21
+ from Comparator import compare_models
22
+ from Trainer import model_train
23
+
24
+
25
+ # Set random seeds for reproducibility
26
def set_seed(seed=42):
    """Seed every RNG source (python, numpy, torch) for reproducible runs.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels trade some speed for reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
36
+
37
+
38
def get_transform():
    """
    Build the shared preprocessing pipeline used for training,
    validation and testing alike.

    Pipeline: resize shorter side to 256 -> center crop 224x224 ->
    tensor conversion -> ImageNet mean/std normalization.

    Returns:
        transform: Standard transform for all datasets
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
    return transforms.Compose(steps)
56
+
57
+
58
def load_data(args):
    """
    Load and prepare datasets from separate directories for training and evaluation.

    Args:
        args: Command line arguments

    Returns:
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        test_loader: DataLoader for testing
        dataset_ref: Reference to the evaluation dataset for class information
    """
    print(f"Loading training dataset from: {args.train_dir}")
    print(f"Loading evaluation dataset from: {args.eval_dir}")

    # One shared preprocessing pipeline for every split.
    transform = get_transform()

    train_dataset = datasets.ImageFolder(args.train_dir, transform=transform)
    print(f"Training dataset classes: {train_dataset.classes}")
    print(f"Training dataset size: {len(train_dataset)}")

    eval_dataset = datasets.ImageFolder(args.eval_dir, transform=transform)
    print(f"Evaluation dataset classes: {eval_dataset.classes}")

    # Optionally drop unwanted classes from both datasets.
    excluded = args.exclude_classes.split(",") if args.exclude_classes else None
    if excluded and any(excluded):
        train_dataset = FilteredImageDataset(train_dataset, excluded)
        eval_dataset = FilteredImageDataset(eval_dataset, excluded)
        print(f"After filtering - Classes: {eval_dataset.classes}")

    print(f"After filtering - Train size: {len(train_dataset)}")
    print(f"After filtering - Eval size: {len(eval_dataset)}")

    # Carve the evaluation set into validation and test partitions,
    # proportionally to the requested split ratios.
    eval_total = len(eval_dataset)
    val_fraction = args.val_split / (args.val_split + args.test_split)
    val_size = int(eval_total * val_fraction)
    test_size = eval_total - val_size

    val_dataset, test_dataset = random_split(eval_dataset, [val_size, test_size])

    print(
        f"Split sizes - Train: {len(train_dataset)}, "
        f"Validation: {len(val_dataset)}, Test: {len(test_dataset)}"
    )

    def make_loader(ds, shuffle):
        # All loaders share batch size, worker count and pinned memory;
        # only the training loader shuffles.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            pin_memory=True,
        )

    train_loader = make_loader(train_dataset, True)
    val_loader = make_loader(val_dataset, False)
    test_loader = make_loader(test_dataset, False)

    # eval_dataset retains the class metadata that random_split's Subsets lack.
    return train_loader, val_loader, test_loader, eval_dataset
136
+
137
+
138
def train_single_model(args, train_loader, val_loader, test_loader, dataset):
    """Train a single model specified by the arguments."""

    print(f"Creating {args.model} model...")

    # Model factory; freezing of backbone layers is controlled by the CLI flag.
    factory = EyeDetectionModels(
        num_classes=len(dataset.classes), freeze_layers=(not args.unfreeze_all)
    )

    # Abort with a helpful message when the requested architecture is unknown.
    if args.model not in factory.models:
        available_models = list(factory.models.keys())
        print(
            f"Error: Model '{args.model}' not found. Available models: {available_models}"
        )
        sys.exit(1)

    model = factory.models[args.model]()

    # Train (and validate) the model.
    results = model_train(model, train_loader, val_loader, dataset, epochs=args.epochs)

    # Guard clause: a None accuracy signals that training did not complete.
    if results["accuracy"] is None:
        print("Training failed. Cannot evaluate on test set.")
        return

    print("\nEvaluating on test set...")
    evaluator = ClassificationEvaluator(class_names=dataset.classes)
    test_results = evaluator.evaluate_model(model, test_loader)
    print(f"Test accuracy: {test_results['accuracy']:.4f}")

    # Persist weights only when a destination path was supplied.
    if args.save_model:
        save_path = args.save_model
        try:
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")
        except Exception as e:
            print(f"Error saving model: {e}")
178
+
179
+
180
def compare_multiple_models(args, train_loader, val_loader, test_loader, dataset):
    """Compare multiple models."""

    print("Preparing to compare multiple models...")

    # Shared factory so every candidate gets the same head size / freezing.
    factory = EyeDetectionModels(
        num_classes=len(dataset.classes), freeze_layers=(not args.unfreeze_all)
    )

    models = []
    names = []
    # Instantiate each requested architecture, skipping unknown names.
    for raw_name in args.compare_models.split(","):
        candidate = raw_name.strip()
        if candidate not in factory.models:
            print(f"Warning: Model '{candidate}' not found, skipping.")
            continue
        print(f"Adding {candidate} to comparison...")
        models.append(factory.models[candidate]())
        names.append(candidate)

    if not models:
        print("No valid models to compare. Exiting.")
        return

    # Delegate the actual train/evaluate loop to the comparison utility.
    compare_models(
        models,
        train_loader,
        val_loader,
        test_loader,
        dataset,
        epochs=args.epochs,
        names=names,
    )
218
+
219
+
220
def _build_arg_parser():
    """Assemble the CLI parser (data / model / training option groups)."""
    parser = argparse.ArgumentParser(
        description="Eye Disease Detection using Deep Learning",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        epilog="""
    Examples:
    # Train a single model
    python main.py --train-dir "/path/to/augmented_dataset" --eval-dir "/path/to/original_dataset" --model mobilenetv4 --epochs 20 --save-model best_model.pth

    # Compare multiple models
    python main.py --train-dir "/path/to/augmented_dataset" --eval-dir "/path/to/original_dataset" --compare-models mobilenetv4,levit,efficientvit --epochs 15
    """,
    )

    # Dataset and data loading arguments
    data_group = parser.add_argument_group("Data Options")
    data_group.add_argument(
        "--train-dir",
        type=str,
        required=True,
        help="Path to the training dataset directory (Augmented Dataset)",
    )
    data_group.add_argument(
        "--eval-dir",
        type=str,
        required=True,
        help="Path to the evaluation dataset directory (Original Dataset)",
    )
    data_group.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training and evaluation",
    )
    data_group.add_argument(
        "--val-split",
        type=float,
        default=0.5,
        help="Validation split ratio within evaluation set",
    )
    data_group.add_argument(
        "--test-split",
        type=float,
        default=0.5,
        help="Test split ratio within evaluation set",
    )
    data_group.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of worker processes for data loading",
    )
    data_group.add_argument(
        "--exclude-classes",
        type=str,
        default=None,
        help="Comma-separated list of class names to exclude",
    )

    # Model arguments
    model_group = parser.add_argument_group("Model Options")
    model_group.add_argument(
        "--model",
        type=str,
        default="mobilenetv4",
        help="Model architecture to use. Options: mobilenetv4, levit, efficientvit, gernet, regnetx",
    )
    model_group.add_argument(
        "--unfreeze-all", action="store_true", help="Unfreeze all layers for training"
    )
    model_group.add_argument(
        "--compare-models",
        type=str,
        default=None,
        help="Comma-separated list of models to compare",
    )

    # Training arguments
    train_group = parser.add_argument_group("Training Options")
    train_group.add_argument(
        "--epochs", type=int, default=20, help="Number of training epochs"
    )
    train_group.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    train_group.add_argument(
        "--save-model", type=str, default=None, help="Path to save the trained model"
    )

    return parser


def main():
    """Main function to run the eye disease detection application."""

    args = _build_arg_parser().parse_args()

    # Seed everything before any data loading or weight initialization.
    set_seed(args.seed)

    # Report the available compute devices.
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print(f"Using {device_count} GPU{'s' if device_count > 1 else ''}")
        for i in range(device_count):
            print(f" Device {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("No GPU available, using CPU")

    # Build the train/val/test loaders from the two dataset roots.
    train_loader, val_loader, test_loader, dataset = load_data(args)

    # Dispatch to comparison mode or single-model training.
    if args.compare_models:
        compare_multiple_models(args, train_loader, val_loader, test_loader, dataset)
    else:
        train_single_model(args, train_loader, val_loader, test_loader, dataset)
335
+
336
+
337
# Script entry point: parse CLI arguments and run training/comparison.
if __name__ == "__main__":
    # Example usage for direct execution:
    # python main.py --train-dir "/kaggle/input/eye-disease-image-dataset/Augmented Dataset/Augmented Dataset" \
    # --eval-dir "/kaggle/input/eye-disease-image-dataset/Original Dataset/Original Dataset" \
    # --model mobilenetv4 --epochs 10
    main()
pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "eyediseasedetection"
3
+ version = "0.1.0"
4
+ description = "Eye disease detection from fundus images using deep learning (PyTorch/timm) with a Gradio interface"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12.9"
7
+ dependencies = [
8
+ "gradio>=5.29.0",
9
+ "matplotlib>=3.10.3",
10
+ "pandas>=2.2.3",
11
+ "scikit-learn>=1.6.1",
12
+ "seaborn>=0.13.2",
13
+ "timm>=1.0.15",
14
+ "torch>=2.7.0",
15
+ "torchaudio>=2.7.0",
16
+ "torchvision>=0.22.0",
17
+ "tqdm>=4.67.1",
18
+ ]
19
+
20
+ [tool.uv.sources]
21
+ torch = [
22
+ { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
23
+ ]
24
+ torchvision = [
25
+ { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
26
+ ]
27
+
28
+ [[tool.uv.index]]
29
+ name = "pytorch-cpu"
30
+ url = "https://download.pytorch.org/whl/cpu"
31
+ explicit = true
32
+
33
+ [[tool.uv.index]]
34
+ name = "pytorch-cu128"
35
+ url = "https://download.pytorch.org/whl/cu128"
36
+ explicit = true
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.7.0
2
+ torchvision>=0.22.0
3
+ torchaudio>=2.7.0
4
+ numpy>=1.26.0
5
+ Pillow>=10.0.0
6
+ gradio>=4.11.0
7
+ matplotlib>=3.8.0
8
+ seaborn>=0.13.0
9
+ scikit-learn>=1.4.0
10
+ tqdm>=4.66.0
11
+ timm>=1.0.0
training.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.11.11","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":10951558,"sourceType":"datasetVersion","datasetId":6812365}],"dockerImageVersionId":31011,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import os\nimport time\nimport random\nimport copy\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchvision.transforms as transforms\nfrom torchvision import transforms, datasets\nimport torchvision.models as models\nimport timm\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nfrom PIL import Image\nfrom torch.utils.data import DataLoader, random_split, Subset, Dataset\nfrom sklearn.metrics import (\n accuracy_score, confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve, average_precision_score\n)\nfrom sklearn.preprocessing import label_binarize\nfrom tqdm import tqdm\nimport gc\n\n\n# Set device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Using device: {device}\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:11.680667Z","iopub.execute_input":"2025-05-07T15:22:11.680991Z","iopub.status.idle":"2025-05-07T15:22:16.455808Z","shell.execute_reply.started":"2025-05-07T15:22:11.680969Z","shell.execute_reply":"2025-05-07T15:22:16.455137Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Modified FilteredImageDataset class with Pterygium filtering\nclass FilteredImageDataset(Dataset):\n def __init__(self, dataset, excluded_classes=None):\n \"\"\"\n Create a 
filtered dataset that excludes specific classes.\n \n Args:\n dataset: Original dataset (ImageFolder or similar)\n excluded_classes: List of class names to exclude (e.g., [\"Pterygium\"])\n \"\"\"\n self.dataset = dataset\n self.excluded_classes = excluded_classes or []\n \n # Get original class information\n self.orig_classes = dataset.classes\n self.orig_class_to_idx = dataset.class_to_idx\n \n # Create indices of samples to keep (excluding specified classes)\n self.indices = []\n for idx, (_, target) in enumerate(dataset.samples):\n class_name = self.orig_classes[target]\n if class_name not in self.excluded_classes:\n self.indices.append(idx)\n \n # Create new class mapping without excluded classes\n remaining_classes = [c for c in self.orig_classes if c not in self.excluded_classes]\n self.classes = remaining_classes\n self.class_to_idx = {cls: idx for idx, cls in enumerate(remaining_classes)}\n self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}\n \n # Create a mapping from old indices to new indices\n self.target_mapping = {}\n for old_class, old_idx in self.orig_class_to_idx.items():\n if old_class in self.class_to_idx:\n self.target_mapping[old_idx] = self.class_to_idx[old_class]\n \n print(f\"Filtered out classes: {self.excluded_classes}\")\n print(f\"Remaining classes: {self.classes}\")\n print(f\"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}\")\n\n def __getitem__(self, index):\n \"\"\"Get item from the filtered dataset with remapped class labels.\"\"\"\n orig_idx = self.indices[index]\n img, old_target = self.dataset[orig_idx]\n \n # Remap target to new class index\n new_target = self.target_mapping[old_target]\n \n return img, new_target\n\n def __len__(self):\n \"\"\"Return the number of samples in the filtered dataset.\"\"\"\n return len(self.indices)\n \n # Allow transform to be updated\n def set_transform(self, transform):\n \"\"\"Update the transform for the dataset.\"\"\"\n self.dataset.transform = 
transform","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.457202Z","iopub.execute_input":"2025-05-07T15:22:16.457579Z","iopub.status.idle":"2025-05-07T15:22:16.465134Z","shell.execute_reply.started":"2025-05-07T15:22:16.457560Z","shell.execute_reply":"2025-05-07T15:22:16.464502Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Model Definition","metadata":{}},{"cell_type":"code","source":"# Early stopping class\nclass EarlyStopping:\n def __init__(self, patience=5, delta=0):\n self.patience = patience\n self.delta = delta\n self.counter = 0\n self.best_score = None\n self.early_stop = False\n \n def __call__(self, val_loss):\n score = -val_loss\n \n if self.best_score is None:\n self.best_score = score\n elif score < self.best_score + self.delta:\n self.counter += 1\n if self.counter >= self.patience:\n self.early_stop = True\n else:\n self.best_score = score\n self.counter = 0\n\n# Model architecture functions\ndef _get_feature_blocks(model):\n \"\"\"\n Utility: locate the main feature blocks container in a timm model.\n Returns a list-like module of blocks.\n \"\"\"\n for attr in ('features', 'blocks', 'layers', 'stem'): # common container names\n if hasattr(model, attr):\n return getattr(model, attr)\n # fallback: collect all children except classifier/head\n return list(model.children())[:-1]\n\ndef _freeze_except_last_n(blocks, n):\n total = len(blocks)\n for idx, block in enumerate(blocks):\n requires = (idx >= total - n)\n for p in block.parameters():\n p.requires_grad = requires\n\ndef get_model_mobilenetv4(num_classes, freeze_layers=True, device='cuda'):\n model = timm.create_model('mobilenetv4_conv_medium.e500_r256_in1k', pretrained=True)\n if freeze_layers:\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n # replace classifier\n in_features = model.classifier.in_features\n model.classifier = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n 
nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n\ndef get_model_levit(num_classes, freeze_layers=True, device='cuda'):\n model = timm.create_model('levit_128s.fb_dist_in1k', pretrained=True)\n if freeze_layers:\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n # Attempt to extract in_features from model.head or classifier\n head = getattr(model, 'head_dist', None) or getattr(model, 'classifier', None)\n linear = getattr(head, 'linear')\n in_features = 384\n model.head = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n model.head_dist = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n\ndef get_model_efficientvit(num_classes, freeze_layers=True, device='cuda'):\n model = timm.create_model('efficientvit_m1.r224_in1k', pretrained=True)\n if freeze_layers:\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n # handle different head naming\n head = getattr(model, 'head', None)\n print(head)\n linear = getattr(head, 'linear')\n in_features = 192\n model.head.linear = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n \ndef get_model_gernet(num_classes, freeze_layers=True, device='cuda'):\n \"\"\"\n Load and configure a GENet (General and Efficient Network) model with customizable classifier.\n \n Args:\n num_classes: Number of output classes\n freeze_layers: If True, freeze all but the last 2 blocks\n device: Device to load the model on ('cuda' or 'cpu')\n \n Returns:\n Configured GENet model\n \"\"\"\n model = timm.create_model('gernet_s.idstcv_in1k', pretrained=True)\n \n if freeze_layers:\n # For GENet, we need to specifically handle its structure\n # It typically has a 'stem' and 'stages' structure\n if 
hasattr(model, 'stem') and hasattr(model, 'stages'):\n # Freeze stem completely\n for param in model.stem.parameters():\n param.requires_grad = False\n \n # Freeze all stages except the last two\n stages = list(model.stages.children())\n total_stages = len(stages)\n for i, stage in enumerate(stages):\n requires_grad = (i >= total_stages - 2)\n for param in stage.parameters():\n param.requires_grad = requires_grad\n else:\n # Fallback to generic approach\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n \n # Replace classifier\n in_features = model.head.fc.in_features\n model.head.fc = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n\ndef get_model_regnetx(num_classes, freeze_layers=True, device='cuda'):\n \"\"\"\n Load and configure a RegNetX model with customizable classifier.\n \n Args:\n num_classes: Number of output classes\n freeze_layers: If True, freeze all but the last 2 blocks\n device: Device to load the model on ('cuda' or 'cpu')\n \n Returns:\n Configured RegNetX model\n \"\"\"\n model = timm.create_model('regnetx_008.tv2_in1k', pretrained=True)\n \n if freeze_layers:\n # Looking at the error, we need to inspect the model structure carefully\n # Print the model structure to understand it better in real use\n # print(model)\n \n # Direct approach: check the model structure and freeze components individually\n # First, freeze all parameters\n for param in model.parameters():\n param.requires_grad = False\n \n # Then unfreeze the last few layers manually based on RegNetX structure\n # RegNetX typically has 'stem' + 'trunk' structure in timm\n if hasattr(model, 'trunk'):\n # Unfreeze final stages of the trunk\n trunk_blocks = list(model.trunk.children())\n # Unfreeze approximately last 25% of trunk blocks\n unfreeze_from = max(0, int(len(trunk_blocks) * 0.75))\n for i in range(unfreeze_from, len(trunk_blocks)):\n for param in 
trunk_blocks[i].parameters():\n param.requires_grad = True\n \n # Always unfreeze the classifier/head for fine-tuning\n for param in model.head.parameters():\n param.requires_grad = True\n \n # Replace classifier\n in_features = model.head.fc.in_features\n model.head.fc = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.465842Z","iopub.execute_input":"2025-05-07T15:22:16.466084Z","iopub.status.idle":"2025-05-07T15:22:16.492719Z","shell.execute_reply.started":"2025-05-07T15:22:16.466067Z","shell.execute_reply":"2025-05-07T15:22:16.491782Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Training function","metadata":{}},{"cell_type":"code","source":"def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, early_stopping, epochs=15, use_ddp=False):\n \"\"\"\n Train the model and perform validation using multiple GPUs.\n Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.\n \n Args:\n model: Model to train\n criterion: Loss function\n optimizer: Optimizer for training\n scheduler: Learning rate scheduler\n train_loader: DataLoader for training data\n val_loader: DataLoader for validation data\n early_stopping: Early stopping handler\n epochs: Maximum number of epochs to train\n use_ddp: Whether to use DistributedDataParallel (True) or DataParallel (False)\n \"\"\"\n # Check available GPUs\n num_gpus = torch.cuda.device_count()\n if num_gpus < 2:\n print(f\"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. 
Continuing with available resources.\")\n else:\n print(f\"Using {num_gpus} GPUs for training\")\n \n # Setup device and model\n if num_gpus >= 2:\n if use_ddp:\n # For DistributedDataParallel\n import torch.distributed as dist\n from torch.nn.parallel import DistributedDataParallel as DDP\n \n # Initialize process group\n dist.init_process_group(backend='nccl')\n local_rank = dist.get_rank()\n torch.cuda.set_device(local_rank)\n device = torch.device(f\"cuda:{local_rank}\")\n \n model = model.to(device)\n model = DDP(model, device_ids=[local_rank])\n else:\n # For DataParallel (simpler to use)\n device = torch.device(\"cuda:0\")\n model = model.to(device)\n model = torch.nn.DataParallel(model)\n else:\n # Single GPU\n device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n model = model.to(device)\n \n train_losses = []\n val_losses = []\n train_accs = []\n val_accs = []\n \n # Store validation predictions and labels for final evaluation\n all_val_labels = []\n all_val_preds = []\n all_val_scores = []\n \n for epoch in range(epochs):\n print(f\"Epoch {epoch+1}/{epochs}\")\n \n # Training phase\n model.train()\n running_loss = 0.0\n correct = 0\n total = 0\n \n for inputs, labels in tqdm(train_loader, desc=\"Training\"):\n inputs, labels = inputs.to(device), labels.to(device)\n \n optimizer.zero_grad()\n outputs = model(inputs)\n loss = criterion(outputs, labels)\n loss.backward()\n optimizer.step()\n \n running_loss += loss.item() * inputs.size(0)\n _, predicted = torch.max(outputs, 1)\n total += labels.size(0)\n correct += (predicted == labels).sum().item()\n \n epoch_train_loss = running_loss / len(train_loader.dataset)\n epoch_train_acc = correct / total\n train_losses.append(epoch_train_loss)\n train_accs.append(epoch_train_acc)\n \n # Validation phase\n model.eval()\n running_loss = 0.0\n correct = 0\n total = 0\n \n all_labels = []\n all_preds = []\n all_scores = []\n \n with torch.no_grad():\n for inputs, labels in tqdm(val_loader, 
desc=\"Validation\"):\n inputs, labels = inputs.to(device), labels.to(device)\n outputs = model(inputs)\n loss = criterion(outputs, labels)\n \n running_loss += loss.item() * inputs.size(0)\n probs = F.softmax(outputs, dim=1)\n _, predicted = torch.max(outputs, 1)\n total += labels.size(0)\n correct += (predicted == labels).sum().item()\n \n all_labels.extend(labels.cpu().numpy().tolist())\n all_preds.extend(predicted.cpu().numpy().tolist())\n all_scores.append(probs.cpu().numpy())\n \n epoch_val_loss = running_loss / len(val_loader.dataset)\n epoch_val_acc = correct / total\n val_losses.append(epoch_val_loss)\n val_accs.append(epoch_val_acc)\n \n all_scores = np.vstack(all_scores) if all_scores else np.array([])\n \n # Store validation results for the final epoch\n all_val_labels = all_labels\n all_val_preds = all_preds\n all_val_scores = all_scores\n \n # Update learning rate scheduler\n scheduler.step(epoch_val_loss)\n \n print(f\"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}\")\n print(f\"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}\")\n print(f\"Learning rate: {optimizer.param_groups[0]['lr']:.6f}\")\n \n # Check early stopping\n early_stopping(epoch_val_loss)\n if early_stopping.early_stop:\n print(\"Early stopping triggered!\")\n break\n \n # Free up memory\n del all_labels, all_preds, all_scores\n gc.collect()\n torch.cuda.empty_cache()\n \n # Clean up DDP if used\n if num_gpus >= 2 and use_ddp:\n dist.destroy_process_group()\n \n return model, train_losses, val_losses, train_accs, val_accs, all_val_labels, all_val_preds, all_val_scores\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.494736Z","iopub.execute_input":"2025-05-07T15:22:16.495264Z","iopub.status.idle":"2025-05-07T15:22:16.517964Z","shell.execute_reply.started":"2025-05-07T15:22:16.495245Z","shell.execute_reply":"2025-05-07T15:22:16.517204Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# 
Evaluation plotting functions","metadata":{}},{"cell_type":"code","source":"def plot_roc_curves(y_true, y_scores, class_names):\n \"\"\"\n Plot ROC curves for multi-class classification.\n \n Parameters:\n - y_true: true labels\n - y_scores: predicted probability scores from model\n - class_names: list of class names\n \"\"\"\n # Ensure inputs are numpy arrays\n if torch.is_tensor(y_true):\n y_true = y_true.cpu().numpy()\n if torch.is_tensor(y_scores):\n y_scores = y_scores.cpu().numpy()\n \n n_classes = len(class_names)\n \n # Binarize the labels for one-vs-rest ROC calculation\n y_true_bin = label_binarize(y_true, classes=range(n_classes))\n \n # Compute ROC curve and ROC area for each class\n fpr = {}\n tpr = {}\n roc_auc = {}\n \n plt.figure(figsize=(12, 8))\n \n for i in range(n_classes):\n fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n roc_auc[i] = auc(fpr[i], tpr[i])\n \n plt.plot(fpr[i], tpr[i], lw=2,\n label=f'{class_names[i]} (area = {roc_auc[i]:.2f})')\n \n # Plot the diagonal (random classifier)\n plt.plot([0, 1], [0, 1], 'k--', lw=2)\n \n # Calculate and plot micro-average ROC curve\n fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(y_true_bin.ravel(), y_scores.ravel())\n roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n plt.plot(fpr[\"micro\"], tpr[\"micro\"], \n label=f'Micro-average (area = {roc_auc[\"micro\"]:.2f})', \n lw=2, linestyle=':', color='deeppink')\n \n plt.xlim([0.0, 1.0])\n plt.ylim([0.0, 1.05])\n plt.xlabel('False Positive Rate')\n plt.ylabel('True Positive Rate')\n plt.title('ROC Curves')\n plt.legend(loc=\"lower right\")\n plt.grid(True, alpha=0.3)\n plt.tight_layout()\n plt.show()\n \n # Return the AUC values for reporting\n return roc_auc\n\ndef plot_pr_curves(y_true, y_scores, class_names):\n \"\"\"\n Plot Precision-Recall curves for multi-class classification.\n \n Parameters:\n - y_true: true labels\n - y_scores: predicted probability scores from model\n - class_names: list of class names\n \"\"\"\n # 
Ensure inputs are numpy arrays\n if torch.is_tensor(y_true):\n y_true = y_true.cpu().numpy()\n if torch.is_tensor(y_scores):\n y_scores = y_scores.cpu().numpy()\n \n n_classes = len(class_names)\n \n # Binarize the labels\n y_true_bin = label_binarize(y_true, classes=range(n_classes))\n \n # Compute PR curve and average precision for each class\n precision = {}\n recall = {}\n avg_precision = {}\n \n plt.figure(figsize=(12, 8))\n \n for i in range(n_classes):\n precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], y_scores[:, i])\n avg_precision[i] = average_precision_score(y_true_bin[:, i], y_scores[:, i])\n \n plt.plot(recall[i], precision[i], lw=2,\n label=f'{class_names[i]} (AP = {avg_precision[i]:.2f})')\n \n # Calculate and plot micro-average PR curve\n precision[\"micro\"], recall[\"micro\"], _ = precision_recall_curve(\n y_true_bin.ravel(), y_scores.ravel())\n avg_precision[\"micro\"] = average_precision_score(y_true_bin.ravel(), y_scores.ravel())\n \n plt.plot(recall[\"micro\"], precision[\"micro\"],\n label=f'Micro-average (AP = {avg_precision[\"micro\"]:.2f})',\n lw=2, linestyle=':', color='deeppink')\n \n plt.xlim([0.0, 1.0])\n plt.ylim([0.0, 1.05])\n plt.xlabel('Recall')\n plt.ylabel('Precision')\n plt.title('Precision-Recall Curves')\n plt.legend(loc=\"best\")\n plt.grid(True, alpha=0.3)\n plt.tight_layout()\n plt.show()\n \n # Return the average precision values for reporting\n return avg_precision\n\ndef plot_accuracy_and_loss(train_losses, val_losses, train_accs, val_accs):\n plt.figure(figsize=(12, 5))\n # Accuracy curve\n plt.subplot(1, 2, 1)\n plt.plot(train_accs, label=\"Train Accuracy\")\n plt.plot(val_accs, label=\"Validation Accuracy\")\n plt.xlabel(\"Epochs\")\n plt.ylabel(\"Accuracy\")\n plt.title(\"Accuracy Curve\")\n plt.legend()\n plt.grid(True)\n \n # Loss curve\n plt.subplot(1, 2, 2)\n plt.plot(train_losses, label=\"Train Loss\")\n plt.plot(val_losses, label=\"Validation Loss\")\n plt.xlabel(\"Epochs\")\n 
plt.ylabel(\"Loss\")\n plt.title(\"Loss Curve\")\n plt.legend()\n plt.grid(True)\n \n plt.tight_layout()\n plt.show()\n\ndef plot_confusion_matrix(y_true, y_pred, class_names):\n # Ensure we're working with numpy arrays\n y_true = np.array(y_true)\n y_pred = np.array(y_pred)\n \n # Get unique values in both arrays\n unique_values = np.unique(np.concatenate([y_true, y_pred]))\n print(f\"Unique values in confusion matrix data: {unique_values}\")\n \n # Create the confusion matrix with explicit labels\n cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))\n \n plt.figure(figsize=(10, 8))\n sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n xticklabels=class_names, yticklabels=class_names)\n plt.title(\"Confusion Matrix\")\n plt.xlabel(\"Predicted\")\n plt.ylabel(\"True\")\n plt.tight_layout()\n plt.show()\n\ndef plot_per_class_accuracy(y_true, y_pred, class_names):\n # Convert to numpy arrays\n y_true = np.array(y_true)\n y_pred = np.array(y_pred)\n \n # Get number of expected classes\n num_classes = len(class_names)\n \n # Create the confusion matrix with explicit labels\n cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))\n \n # Calculate per-class accuracy\n per_class_accuracy = np.zeros(num_classes)\n for i in range(num_classes):\n if i < cm.shape[0] and np.sum(cm[i, :]) > 0:\n per_class_accuracy[i] = cm[i, i] / np.sum(cm[i, :])\n \n # Create the bar plot\n plt.figure(figsize=(14, 7))\n plt.bar(range(num_classes), per_class_accuracy, color=\"skyblue\")\n plt.xticks(range(num_classes), class_names, rotation=45, ha='right')\n plt.xlabel(\"Classes\")\n plt.ylabel(\"Accuracy\")\n plt.title(\"Per-Class Accuracy\")\n plt.tight_layout()\n 
plt.show()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.518997Z","iopub.execute_input":"2025-05-07T15:22:16.519284Z","iopub.status.idle":"2025-05-07T15:22:16.546032Z","shell.execute_reply.started":"2025-05-07T15:22:16.519262Z","shell.execute_reply":"2025-05-07T15:22:16.545330Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from sklearn.metrics import cohen_kappa_score\n\ndef compute_classification_metrics(y_true, y_pred, y_scores, num_classes, class_names, model_name=\"\"):\n \"\"\"\n Compute comprehensive classification metrics including ROC AUC, PR AUC, and Cohen's Kappa.\n \n Parameters:\n - y_true: true labels\n - y_pred: predicted labels\n - y_scores: predicted probability scores from model\n - num_classes: number of classes\n - class_names: list of class names\n - model_name: name of the model (for display purposes)\n \n Returns:\n - accuracy: overall accuracy score\n - report_dict: classification report as dictionary\n - roc_auc_dict: ROC AUC scores by class\n - pr_auc_dict: PR AUC scores by class\n - kappa: Cohen's Kappa score\n \"\"\"\n # Calculate accuracy\n accuracy = accuracy_score(y_true, y_pred)\n print(f\"Overall Accuracy: {accuracy:.4f}\")\n \n # Calculate and display Cohen's Kappa\n kappa = cohen_kappa_score(y_true, y_pred)\n print(f\"Cohen's Kappa Score: {kappa:.4f}\")\n \n # Generate classification report\n report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)\n \n # Print formatted classification report\n print(\"\\nClassification Report:\")\n print(classification_report(y_true, y_pred, target_names=class_names))\n \n # Calculate ROC curves and AUC for each class\n print(\"\\nCalculating ROC curves...\")\n roc_auc_dict = plot_roc_curves(y_true, y_scores, class_names)\n \n # Calculate PR curves and AUC for each class\n print(\"\\nCalculating PR curves...\")\n pr_auc_dict = plot_pr_curves(y_true, y_scores, class_names)\n \n # Return metrics for 
comparison\n return accuracy, report, roc_auc_dict, pr_auc_dict, kappa\n\n# Also update evaluate_on_test_set to include kappa\ndef evaluate_on_test_set(model, test_loader, dataset):\n \"\"\"Evaluate a trained model on test dataset\"\"\"\n class_names = dataset.classes\n num_classes = len(class_names)\n \n model.eval()\n device = next(model.parameters()).device\n \n all_labels = []\n all_preds = []\n all_scores = []\n \n with torch.no_grad():\n for inputs, labels in test_loader:\n inputs, labels = inputs.to(device), labels.to(device)\n outputs = model(inputs)\n _, preds = torch.max(outputs, 1)\n \n all_labels.extend(labels.cpu().numpy())\n all_preds.extend(preds.cpu().numpy())\n all_scores.append(torch.nn.functional.softmax(outputs, dim=1).cpu().numpy())\n \n all_scores = np.vstack(all_scores)\n \n # Compute metrics including kappa\n accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = compute_classification_metrics(\n all_labels, all_preds, all_scores, num_classes, class_names)\n \n # Build results dictionary with kappa\n results = {\n 'accuracy': accuracy,\n 'report': report_dict,\n 'roc_auc': roc_auc_dict,\n 'pr_auc': pr_auc_dict,\n 'kappa': kappa\n }\n \n return results","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.547383Z","iopub.execute_input":"2025-05-07T15:22:16.547662Z","iopub.status.idle":"2025-05-07T15:22:16.571142Z","shell.execute_reply.started":"2025-05-07T15:22:16.547638Z","shell.execute_reply":"2025-05-07T15:22:16.570633Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Main training function","metadata":{}},{"cell_type":"code","source":"# Update the model_train function to include kappa in the results\ndef model_train(model, train_loader, val_loader, dataset, epochs=20):\n model_name = type(model).__name__\n if hasattr(model, 'pretrained_cfg') and 'name' in model.pretrained_cfg:\n model_name = model.pretrained_cfg['name']\n \n print(f\"\\n{'='*20} Training {model_name} 
{'='*20}\\n\")\n \n class_names = dataset.classes\n num_classes = len(class_names)\n learning_rate = 0.001\n \n try:\n optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)\n early_stopping = EarlyStopping(patience=5)\n \n model, train_losses, val_losses, train_accs, val_accs, val_labels, val_preds, val_scores = train_model(\n model, nn.CrossEntropyLoss(), optimizer, scheduler,\n train_loader, val_loader, early_stopping, epochs=epochs, use_ddp=False\n )\n \n print(f\"\\n{'='*20} Evaluation for {model_name} {'='*20}\\n\")\n \n # Plot training curves\n plot_accuracy_and_loss(train_losses, val_losses, train_accs, val_accs)\n \n # Process validation predictions and labels\n try:\n plot_confusion_matrix(val_labels, val_preds, class_names)\n plot_per_class_accuracy(val_labels, val_preds, class_names)\n \n # Get metrics from the updated function including kappa\n accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = compute_classification_metrics(\n val_labels, val_preds, val_scores, num_classes, class_names, model_name)\n \n # Build a results dictionary including kappa\n results = {\n 'accuracy': accuracy,\n 'report': report_dict,\n 'roc_auc': roc_auc_dict,\n 'pr_auc': pr_auc_dict,\n 'kappa': kappa\n }\n \n return results\n except Exception as viz_error:\n print(f\"Error in visualization: {viz_error}\")\n import traceback\n traceback.print_exc()\n return {'accuracy': None}\n \n except Exception as e:\n print(f'Error occurred when training {model_name}: {e}')\n import traceback\n traceback.print_exc()\n return {'accuracy': None}\n finally:\n # Clean up memory\n if 'optimizer' in locals():\n del optimizer\n if 'scheduler' in locals():\n del scheduler\n if 'early_stopping' in locals():\n del early_stopping\n if 'train_losses' in locals():\n del train_losses\n del val_losses\n del train_accs\n del val_accs\n del val_labels\n del val_preds\n del val_scores\n \n 
gc.collect()\n torch.cuda.empty_cache()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.571806Z","iopub.execute_input":"2025-05-07T15:22:16.572176Z","iopub.status.idle":"2025-05-07T15:22:16.594545Z","shell.execute_reply.started":"2025-05-07T15:22:16.572146Z","shell.execute_reply":"2025-05-07T15:22:16.593850Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Comparison Function","metadata":{}},{"cell_type":"code","source":"def compare_models(models, train_loader, val_loader, test_loader, dataset, epochs=20, names=None):\n if names is None:\n names = [f\"Model {i+1}\" for i in range(len(models))]\n \n val_results = {}\n test_results = {}\n best_model_obj = None\n best_accuracy = -1\n best_model_name = \"\"\n \n # Summary dictionaries for metrics\n val_roc_auc_summary = {}\n test_roc_auc_summary = {}\n val_pr_auc_summary = {}\n test_pr_auc_summary = {}\n val_kappa_summary = {}\n test_kappa_summary = {}\n \n for i, (model, name) in enumerate(zip(models, names)):\n print(f\"\\n\\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\\n\")\n model_results = model_train(model, train_loader, val_loader, dataset, epochs)\n \n # Extract accuracy from results\n accuracy = model_results.get('accuracy')\n val_results[name] = accuracy\n \n # Extract and store metrics\n if 'roc_auc' in model_results and 'micro' in model_results['roc_auc']:\n val_roc_auc_summary[name] = model_results['roc_auc']['micro']\n else:\n val_roc_auc_summary[name] = None\n \n if 'pr_auc' in model_results and 'micro' in model_results['pr_auc']:\n val_pr_auc_summary[name] = model_results['pr_auc']['micro']\n else:\n val_pr_auc_summary[name] = None\n \n # Store kappa score\n if 'kappa' in model_results:\n val_kappa_summary[name] = model_results['kappa']\n else:\n val_kappa_summary[name] = None\n \n # Evaluate on test set\n if accuracy is not None:\n print(f\"\\n{'='*20} Testing {name} on Test Set {'='*20}\\n\")\n test_model_results = 
evaluate_on_test_set(model, test_loader, dataset)\n \n # Extract accuracy from test results\n test_accuracy = test_model_results.get('accuracy')\n test_results[name] = test_accuracy\n \n # Extract and store test metrics\n if 'roc_auc' in test_model_results and 'micro' in test_model_results['roc_auc']:\n test_roc_auc_summary[name] = test_model_results['roc_auc']['micro']\n else:\n test_roc_auc_summary[name] = None\n \n if 'pr_auc' in test_model_results and 'micro' in test_model_results['pr_auc']:\n test_pr_auc_summary[name] = test_model_results['pr_auc']['micro']\n else:\n test_pr_auc_summary[name] = None\n \n # Store test kappa score\n if 'kappa' in test_model_results:\n test_kappa_summary[name] = test_model_results['kappa']\n else:\n test_kappa_summary[name] = None\n \n # Track best model\n if test_accuracy > best_accuracy:\n best_accuracy = test_accuracy\n best_model_obj = copy.deepcopy(model)\n best_model_name = name\n \n # Print comprehensive comparison\n print(\"\\n\\n\" + \"=\"*100)\n print(\"COMPREHENSIVE MODEL COMPARISON\")\n print(\"=\"*100)\n print(f\"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}{'Val ROC AUC':<14}{'Test ROC AUC':<14}{'Val PR AUC':<14}{'Test PR AUC':<14}{'Val Kappa':<14}{'Test Kappa':<14}\")\n print(\"-\"*100)\n \n for name in val_results.keys():\n val_acc = val_results[name]\n test_acc = test_results.get(name, None)\n val_roc = val_roc_auc_summary.get(name, None)\n test_roc = test_roc_auc_summary.get(name, None)\n val_pr = val_pr_auc_summary.get(name, None)\n test_pr = test_pr_auc_summary.get(name, None)\n val_kappa = val_kappa_summary.get(name, None)\n test_kappa = test_kappa_summary.get(name, None)\n \n # Format values for display\n val_acc_str = f\"{val_acc:.4f}\" if val_acc is not None else \"Failed\"\n test_acc_str = f\"{test_acc:.4f}\" if test_acc is not None else \"N/A\"\n val_roc_str = f\"{val_roc:.4f}\" if val_roc is not None else \"N/A\"\n test_roc_str = f\"{test_roc:.4f}\" if test_roc is not None else \"N/A\"\n val_pr_str = 
f\"{val_pr:.4f}\" if val_pr is not None else \"N/A\"\n test_pr_str = f\"{test_pr:.4f}\" if test_pr is not None else \"N/A\"\n val_kappa_str = f\"{val_kappa:.4f}\" if val_kappa is not None else \"N/A\"\n test_kappa_str = f\"{test_kappa:.4f}\" if test_kappa is not None else \"N/A\"\n \n print(f\"{name:<20}{val_acc_str:<10}{test_acc_str:<10}{val_roc_str:<14}{test_roc_str:<14}{val_pr_str:<14}{test_pr_str:<14}{val_kappa_str:<14}{test_kappa_str:<14}\")\n \n # Identify best model based on test metrics\n if test_results:\n # Best model by accuracy\n best_acc_model = max(test_results.items(), key=lambda x: x[1] if x[1] is not None else -1)\n print(f\"\\nBest model by accuracy: {best_acc_model[0]} (Test Accuracy: {best_acc_model[1]:.4f})\")\n \n # Best model by ROC AUC (if available)\n if any(v is not None for v in test_roc_auc_summary.values()):\n best_roc_model = max(\n [(k, v) for k, v in test_roc_auc_summary.items() if v is not None], \n key=lambda x: x[1] if x[1] is not None else -1\n )\n print(f\"Best model by ROC AUC: {best_roc_model[0]} (Test ROC AUC: {best_roc_model[1]:.4f})\")\n \n # Best model by PR AUC (if available)\n if any(v is not None for v in test_pr_auc_summary.values()):\n best_pr_model = max(\n [(k, v) for k, v in test_pr_auc_summary.items() if v is not None], \n key=lambda x: x[1] if x[1] is not None else -1\n )\n print(f\"Best model by PR AUC: {best_pr_model[0]} (Test PR AUC: {best_pr_model[1]:.4f})\")\n \n # Best model by Kappa (if available)\n if any(v is not None for v in test_kappa_summary.values()):\n best_kappa_model = max(\n [(k, v) for k, v in test_kappa_summary.items() if v is not None], \n key=lambda x: x[1] if x[1] is not None else -1\n )\n print(f\"Best model by Cohen's Kappa: {best_kappa_model[0]} (Test Kappa: {best_kappa_model[1]:.4f})\")\n \n # Save the best model (by accuracy)\n if best_model_obj is not None:\n try:\n model_save_path = f\"best_model_{best_model_name.lower().replace(' ', '_')}.pth\"\n 
torch.save(best_model_obj.state_dict(), model_save_path)\n print(f\"Best model saved to {model_save_path}\")\n except Exception as save_error:\n print(f\"Error saving best model: {save_error}\")\n else:\n print(\"\\nNo models successfully completed testing.\")\n \n print(\"=\"*100)\n \n # Visualize comparison\n try:\n # Create bar charts comparing different metrics\n plot_model_comparison(val_results, test_results, val_roc_auc_summary, \n test_roc_auc_summary, val_pr_auc_summary, test_pr_auc_summary,\n val_kappa_summary, test_kappa_summary)\n except Exception as viz_error:\n print(f\"Error in comparison visualization: {viz_error}\")\n import traceback\n traceback.print_exc()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.595427Z","iopub.execute_input":"2025-05-07T15:22:16.595661Z","iopub.status.idle":"2025-05-07T15:22:16.620099Z","shell.execute_reply.started":"2025-05-07T15:22:16.595635Z","shell.execute_reply":"2025-05-07T15:22:16.619560Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def plot_model_comparison(val_acc, test_acc, val_roc, test_roc, val_pr, test_pr, val_kappa, test_kappa):\n \"\"\"\n Create visualizations to compare model performance across different metrics including Cohen's Kappa.\n \"\"\"\n # Get the list of model names (should be the same across all dictionaries)\n models = list(val_acc.keys())\n \n # Create a figure with 4 subplots for Accuracy, ROC AUC, PR AUC, and Kappa\n fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(12, 24))\n \n # Plot Accuracy\n x = np.arange(len(models))\n width = 0.35\n \n val_acc_values = [val_acc.get(model, None) for model in models]\n test_acc_values = [test_acc.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_acc_values = [v if v is not None else float('nan') for v in val_acc_values]\n test_acc_values = [v if v is not None else float('nan') for v in test_acc_values]\n \n ax1.bar(x - width/2, val_acc_values, width, 
label='Validation', color='skyblue')\n ax1.bar(x + width/2, test_acc_values, width, label='Test', color='salmon')\n ax1.set_ylabel('Accuracy')\n ax1.set_title('Model Accuracy Comparison')\n ax1.set_xticks(x)\n ax1.set_xticklabels(models, rotation=45, ha='right')\n ax1.legend()\n ax1.grid(True, alpha=0.3)\n \n # Plot ROC AUC\n val_roc_values = [val_roc.get(model, None) for model in models]\n test_roc_values = [test_roc.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_roc_values = [v if v is not None else float('nan') for v in val_roc_values]\n test_roc_values = [v if v is not None else float('nan') for v in test_roc_values]\n \n ax2.bar(x - width/2, val_roc_values, width, label='Validation', color='skyblue')\n ax2.bar(x + width/2, test_roc_values, width, label='Test', color='salmon')\n ax2.set_ylabel('ROC AUC')\n ax2.set_title('Model ROC AUC Comparison')\n ax2.set_xticks(x)\n ax2.set_xticklabels(models, rotation=45, ha='right')\n ax2.legend()\n ax2.grid(True, alpha=0.3)\n \n # Plot PR AUC\n val_pr_values = [val_pr.get(model, None) for model in models]\n test_pr_values = [test_pr.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_pr_values = [v if v is not None else float('nan') for v in val_pr_values]\n test_pr_values = [v if v is not None else float('nan') for v in test_pr_values]\n \n ax3.bar(x - width/2, val_pr_values, width, label='Validation', color='skyblue')\n ax3.bar(x + width/2, test_pr_values, width, label='Test', color='salmon')\n ax3.set_ylabel('PR AUC')\n ax3.set_title('Model PR AUC Comparison')\n ax3.set_xticks(x)\n ax3.set_xticklabels(models, rotation=45, ha='right')\n ax3.legend()\n ax3.grid(True, alpha=0.3)\n \n # Plot Kappa scores\n val_kappa_values = [val_kappa.get(model, None) for model in models]\n test_kappa_values = [test_kappa.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_kappa_values = [v if v is not None else float('nan') for v in 
val_kappa_values]\n test_kappa_values = [v if v is not None else float('nan') for v in test_kappa_values]\n \n ax4.bar(x - width/2, val_kappa_values, width, label='Validation', color='skyblue')\n ax4.bar(x + width/2, test_kappa_values, width, label='Test', color='salmon')\n ax4.set_ylabel(\"Cohen's Kappa\")\n ax4.set_title(\"Model Cohen's Kappa Comparison\")\n ax4.set_xticks(x)\n ax4.set_xticklabels(models, rotation=45, ha='right')\n ax4.legend()\n ax4.grid(True, alpha=0.3)\n \n plt.tight_layout()\n plt.show()\n \n # Create a comprehensive heatmap for all metrics\n try:\n plot_metrics_heatmap(models, val_acc_values, test_acc_values, \n val_roc_values, test_roc_values,\n val_pr_values, test_pr_values,\n val_kappa_values, test_kappa_values)\n except Exception as e:\n print(f\"Error creating metrics heatmap: {e}\")\n\ndef plot_metrics_heatmap(models, val_acc, test_acc, val_roc, test_roc, val_pr, test_pr, val_kappa, test_kappa):\n \"\"\"\n Create a heatmap visualization of all metrics for easy comparison across models.\n \"\"\"\n # Prepare data for heatmap\n metric_names = ['Val Acc', 'Test Acc', 'Val ROC', 'Test ROC', \n 'Val PR', 'Test PR', 'Val Kappa', 'Test Kappa']\n \n data = np.array([\n val_acc, test_acc, val_roc, test_roc, val_pr, test_pr, val_kappa, test_kappa\n ])\n \n # Create the heatmap\n plt.figure(figsize=(12, 8))\n ax = sns.heatmap(data, annot=True, fmt=\".4f\", cmap=\"YlGnBu\", \n xticklabels=models, yticklabels=metric_names)\n \n plt.title(\"Comprehensive Model Performance Metrics\")\n plt.tight_layout()\n plt.show()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.595427Z","iopub.execute_input":"2025-05-07T15:22:16.595661Z","iopub.status.idle":"2025-05-07T15:22:16.620099Z","shell.execute_reply.started":"2025-05-07T15:22:16.595635Z","shell.execute_reply":"2025-05-07T15:22:16.619560Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Main 
Function","metadata":{}},{"cell_type":"code","source":"train_dir = '/kaggle/input/eye-disease-image-dataset/Augmented Dataset/Augmented Dataset' # For training (pre-augmented data)\neval_dir = '/kaggle/input/eye-disease-image-dataset/Original Dataset/Original Dataset' # For val and test\nepochs = 15\nclasses_to_exclude = [\"Pterygium\"]\nbatch_size = 32\n\nseed = 42\nrandom.seed(seed)\nnp.random.seed(seed)\ntorch.manual_seed(seed)\nif torch.cuda.is_available():\n torch.cuda.manual_seed(seed)\n torch.backends.cudnn.deterministic = True\n\n# Define transformations\ntransform = transforms.Compose([\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.ToTensor(),\n transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n])\n\ntry:\n # Load datasets\n print(f\"Loading training dataset from {train_dir}...\")\n train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)\n \n print(f\"Loading evaluation dataset from {eval_dir}...\")\n eval_dataset = datasets.ImageFolder(root=eval_dir, transform=transform)\n \n # Print dataset information\n print(f\"Training dataset loaded with {len(train_dataset)} images across {len(train_dataset.classes)} classes.\")\n print(f\"Training classes: {train_dataset.classes}\")\n \n print(f\"Evaluation dataset loaded with {len(eval_dataset)} images across {len(eval_dataset.classes)} classes.\")\n print(f\"Evaluation classes: {eval_dataset.classes}\")\n \n # Filter datasets if needed\n excluded_classes = classes_to_exclude or []\n if excluded_classes:\n print(f\"Filtering out classes: {excluded_classes}\")\n filtered_train_dataset = FilteredImageDataset(train_dataset, excluded_classes=excluded_classes)\n filtered_eval_dataset = FilteredImageDataset(eval_dataset, excluded_classes=excluded_classes)\n else:\n filtered_train_dataset = train_dataset\n filtered_eval_dataset = eval_dataset\n \n # Check if the filtered classes match between training and evaluation datasets\n if 
set(filtered_train_dataset.classes) != set(filtered_eval_dataset.classes):\n print(\"Warning: Class mismatch between filtered training and evaluation datasets!\")\n print(f\"Filtered training classes: {filtered_train_dataset.classes}\")\n print(f\"Filtered evaluation classes: {filtered_eval_dataset.classes}\")\n \n # Find common classes\n common_classes = set(filtered_train_dataset.classes).intersection(set(filtered_eval_dataset.classes))\n print(f\"Common classes: {common_classes}\")\n \n # Create additional filtering based on common classes\n filtered_train_dataset = FilteredImageDataset(train_dataset, \n included_classes=common_classes)\n filtered_eval_dataset = FilteredImageDataset(eval_dataset, \n included_classes=common_classes)\n \n # Split evaluation dataset into validation and test sets\n eval_ratio = 0.7 # 70% validation, 30% test\n eval_size = len(filtered_eval_dataset)\n val_size = int(eval_ratio * eval_size)\n test_size = eval_size - val_size\n \n val_dataset, test_dataset = random_split(filtered_eval_dataset, [val_size, test_size])\n \n print(f\"Training set size: {len(filtered_train_dataset)}\")\n print(f\"Validation set size: {len(val_dataset)}\")\n print(f\"Test set size: {len(test_dataset)}\")\n \n # Create data loaders\n train_loader = DataLoader(filtered_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)\n test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)\n \n # Get the number of classes (after filtering)\n num_classes = len(filtered_train_dataset.classes)\n print(f\"Number of classes after filtering: {num_classes}\")\n print(f\"Classes after filtering: {filtered_train_dataset.classes}\")\n \n # Initialize models\n print(\"Initializing models...\")\n \n # Initialize models with the updated number of classes\n all_models = {\n \"MobileNetV4\": get_model_mobilenetv4(num_classes, 
freeze_layers=True),\n \"LeViT\": get_model_levit(num_classes, freeze_layers=True),\n \"EfficientViT\": get_model_efficientvit(num_classes, freeze_layers=True),\n \"GENet\": get_model_gernet(num_classes, freeze_layers=True),\n \"RegNetX\": get_model_regnetx(num_classes, freeze_layers=True)\n }\n \n models = list(all_models.values())\n model_names = list(all_models.keys())\n \n # Train and compare models\n print(\"Starting model training and comparison...\")\n compare_models(models, train_loader, val_loader, test_loader, filtered_train_dataset, epochs=epochs, names=model_names)\n \nexcept Exception as e:\n print(f\"Error in eye disease classification pipeline: {e}\")\n import traceback\n traceback.print_exc()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:59.764465Z","iopub.execute_input":"2025-05-07T15:22:59.765257Z","execution_failed":"2025-05-07T15:22:59.047Z"}},"outputs":[],"execution_count":null}]}
utils/Callback.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Early stopping class
2
+ class EarlyStopping:
3
+ def __init__(self, patience=5, delta=0):
4
+ self.patience = patience
5
+ self.delta = delta
6
+ self.counter = 0
7
+ self.best_score = None
8
+ self.early_stop = False
9
+
10
+ def __call__(self, val_loss):
11
+ score = -val_loss
12
+
13
+ if self.best_score is None:
14
+ self.best_score = score
15
+ elif score < self.best_score + self.delta:
16
+ self.counter += 1
17
+ if self.counter >= self.patience:
18
+ self.early_stop = True
19
+ else:
20
+ self.best_score = score
21
+ self.counter = 0
utils/Comparator.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Trainer import model_train
2
+ from Evaluator import ClassificationEvaluator
3
+
4
+
5
+ def compare_models(
6
+ models, train_loader, val_loader, test_loader, dataset, epochs=20, names=None
7
+ ):
8
+ if names is None:
9
+ names = [f"Model {i+1}" for i in range(len(models))]
10
+
11
+ val_results = {}
12
+ test_results = {}
13
+ best_model_obj = None
14
+ best_accuracy = -1
15
+ best_model_name = ""
16
+
17
+ # Summary dictionaries for metrics
18
+ val_roc_auc_summary = {}
19
+ test_roc_auc_summary = {}
20
+ val_pr_auc_summary = {}
21
+ test_pr_auc_summary = {}
22
+ val_kappa_summary = {}
23
+ test_kappa_summary = {}
24
+
25
+ for i, (model, name) in enumerate(zip(models, names)):
26
+ evaluator = ClassificationEvaluator(
27
+ num_classes=len(dataset.classes),
28
+ class_names=dataset.classes,
29
+ )
30
+
31
+ print(f"\n\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\n")
32
+ model_results = model_train(model, train_loader, val_loader, dataset, epochs)
33
+
34
+ # Extract accuracy from results
35
+ accuracy = model_results.get("accuracy")
36
+ val_results[name] = accuracy
37
+
38
+ # Extract and store metrics
39
+ if "roc_auc" in model_results and "micro" in model_results["roc_auc"]:
40
+ val_roc_auc_summary[name] = model_results["roc_auc"]["micro"]
41
+ else:
42
+ val_roc_auc_summary[name] = None
43
+
44
+ if "pr_auc" in model_results and "micro" in model_results["pr_auc"]:
45
+ val_pr_auc_summary[name] = model_results["pr_auc"]["micro"]
46
+ else:
47
+ val_pr_auc_summary[name] = None
48
+
49
+ # Store kappa score
50
+ if "kappa" in model_results:
51
+ val_kappa_summary[name] = model_results["kappa"]
52
+ else:
53
+ val_kappa_summary[name] = None
54
+
55
+ # Evaluate on test set
56
+ if accuracy is not None:
57
+ print(f"\n{'='*20} Testing {name} on Test Set {'='*20}\n")
58
+ test_model_results = evaluator.evaluate_model(model, test_loader)
59
+
60
+ # Extract accuracy from test results
61
+ test_accuracy = test_model_results.get("accuracy")
62
+ test_results[name] = test_accuracy
63
+
64
+ # Extract and store test metrics
65
+ if (
66
+ "roc_auc" in test_model_results
67
+ and "micro" in test_model_results["roc_auc"]
68
+ ):
69
+ test_roc_auc_summary[name] = test_model_results["roc_auc"]["micro"]
70
+ else:
71
+ test_roc_auc_summary[name] = None
72
+
73
+ if (
74
+ "pr_auc" in test_model_results
75
+ and "micro" in test_model_results["pr_auc"]
76
+ ):
77
+ test_pr_auc_summary[name] = test_model_results["pr_auc"]["micro"]
78
+ else:
79
+ test_pr_auc_summary[name] = None
80
+
81
+ # Store test kappa score
82
+ if "kappa" in test_model_results:
83
+ test_kappa_summary[name] = test_model_results["kappa"]
84
+ else:
85
+ test_kappa_summary[name] = None
86
+
87
+ # Track best model
88
+ if test_accuracy > best_accuracy:
89
+ best_accuracy = test_accuracy
90
+ best_model_obj = copy.deepcopy(model)
91
+ best_model_name = name
92
+
93
+ # Print comprehensive comparison
94
+ print("\n\n" + "=" * 100)
95
+ print("COMPREHENSIVE MODEL COMPARISON")
96
+ print("=" * 100)
97
+ print(
98
+ f"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}{'Val ROC AUC':<14}{'Test ROC AUC':<14}{'Val PR AUC':<14}{'Test PR AUC':<14}{'Val Kappa':<14}{'Test Kappa':<14}"
99
+ )
100
+ print("-" * 100)
101
+
102
+ for name in val_results.keys():
103
+ val_acc = val_results[name]
104
+ test_acc = test_results.get(name, None)
105
+ val_roc = val_roc_auc_summary.get(name, None)
106
+ test_roc = test_roc_auc_summary.get(name, None)
107
+ val_pr = val_pr_auc_summary.get(name, None)
108
+ test_pr = test_pr_auc_summary.get(name, None)
109
+ val_kappa = val_kappa_summary.get(name, None)
110
+ test_kappa = test_kappa_summary.get(name, None)
111
+
112
+ # Format values for display
113
+ val_acc_str = f"{val_acc:.4f}" if val_acc is not None else "Failed"
114
+ test_acc_str = f"{test_acc:.4f}" if test_acc is not None else "N/A"
115
+ val_roc_str = f"{val_roc:.4f}" if val_roc is not None else "N/A"
116
+ test_roc_str = f"{test_roc:.4f}" if test_roc is not None else "N/A"
117
+ val_pr_str = f"{val_pr:.4f}" if val_pr is not None else "N/A"
118
+ test_pr_str = f"{test_pr:.4f}" if test_pr is not None else "N/A"
119
+ val_kappa_str = f"{val_kappa:.4f}" if val_kappa is not None else "N/A"
120
+ test_kappa_str = f"{test_kappa:.4f}" if test_kappa is not None else "N/A"
121
+
122
+ print(
123
+ f"{name:<20}{val_acc_str:<10}{test_acc_str:<10}{val_roc_str:<14}{test_roc_str:<14}{val_pr_str:<14}{test_pr_str:<14}{val_kappa_str:<14}{test_kappa_str:<14}"
124
+ )
125
+
126
+ # Identify best model based on test metrics
127
+ if test_results:
128
+ # Best model by accuracy
129
+ best_acc_model = max(
130
+ test_results.items(), key=lambda x: x[1] if x[1] is not None else -1
131
+ )
132
+ print(
133
+ f"\nBest model by accuracy: {best_acc_model[0]} (Test Accuracy: {best_acc_model[1]:.4f})"
134
+ )
135
+
136
+ # Best model by ROC AUC (if available)
137
+ if any(v is not None for v in test_roc_auc_summary.values()):
138
+ best_roc_model = max(
139
+ [(k, v) for k, v in test_roc_auc_summary.items() if v is not None],
140
+ key=lambda x: x[1] if x[1] is not None else -1,
141
+ )
142
+ print(
143
+ f"Best model by ROC AUC: {best_roc_model[0]} (Test ROC AUC: {best_roc_model[1]:.4f})"
144
+ )
145
+
146
+ # Best model by PR AUC (if available)
147
+ if any(v is not None for v in test_pr_auc_summary.values()):
148
+ best_pr_model = max(
149
+ [(k, v) for k, v in test_pr_auc_summary.items() if v is not None],
150
+ key=lambda x: x[1] if x[1] is not None else -1,
151
+ )
152
+ print(
153
+ f"Best model by PR AUC: {best_pr_model[0]} (Test PR AUC: {best_pr_model[1]:.4f})"
154
+ )
155
+
156
+ # Best model by Kappa (if available)
157
+ if any(v is not None for v in test_kappa_summary.values()):
158
+ best_kappa_model = max(
159
+ [(k, v) for k, v in test_kappa_summary.items() if v is not None],
160
+ key=lambda x: x[1] if x[1] is not None else -1,
161
+ )
162
+ print(
163
+ f"Best model by Cohen's Kappa: {best_kappa_model[0]} (Test Kappa: {best_kappa_model[1]:.4f})"
164
+ )
165
+
166
+ # Save the best model (by accuracy)
167
+ if best_model_obj is not None:
168
+ try:
169
+ model_save_path = (
170
+ f"best_model_{best_model_name.lower().replace(' ', '_')}.pth"
171
+ )
172
+ torch.save(best_model_obj.state_dict(), model_save_path)
173
+ print(f"Best model saved to {model_save_path}")
174
+ except Exception as save_error:
175
+ print(f"Error saving best model: {save_error}")
176
+ else:
177
+ print("\nNo models successfully completed testing.")
178
+
179
+ print("=" * 100)
utils/DatasetHandler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import random
4
+ import copy
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ import torchvision.transforms as transforms
11
+ from torchvision import transforms, datasets
12
+ import torchvision.models as models
13
+ import timm
14
+ import matplotlib.pyplot as plt
15
+ import seaborn as sns
16
+ from PIL import Image
17
+ from torch.utils.data import DataLoader, random_split, Subset, Dataset
18
+ from sklearn.metrics import (
19
+ accuracy_score,
20
+ confusion_matrix,
21
+ classification_report,
22
+ roc_curve,
23
+ auc,
24
+ precision_recall_curve,
25
+ average_precision_score,
26
+ )
27
+ from sklearn.preprocessing import label_binarize
28
+ from tqdm import tqdm
29
+ import gc
30
+
31
+
32
+ # Modified FilteredImageDataset class with Pterygium filtering
33
+ class FilteredImageDataset(Dataset):
34
+ def __init__(self, dataset, excluded_classes=None):
35
+ """
36
+ Create a filtered dataset that excludes specific classes.
37
+
38
+ Args:
39
+ dataset: Original dataset (ImageFolder or similar)
40
+ excluded_classes: List of class names to exclude (e.g., ["Pterygium"])
41
+ """
42
+ self.dataset = dataset
43
+ self.excluded_classes = excluded_classes or []
44
+
45
+ # Get original class information
46
+ self.orig_classes = dataset.classes
47
+ self.orig_class_to_idx = dataset.class_to_idx
48
+
49
+ # Create indices of samples to keep (excluding specified classes)
50
+ self.indices = []
51
+ for idx, (_, target) in enumerate(dataset.samples):
52
+ class_name = self.orig_classes[target]
53
+ if class_name not in self.excluded_classes:
54
+ self.indices.append(idx)
55
+
56
+ # Create new class mapping without excluded classes
57
+ remaining_classes = [
58
+ c for c in self.orig_classes if c not in self.excluded_classes
59
+ ]
60
+ self.classes = remaining_classes
61
+ self.class_to_idx = {cls: idx for idx, cls in enumerate(remaining_classes)}
62
+ self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
63
+
64
+ # Create a mapping from old indices to new indices
65
+ self.target_mapping = {}
66
+ for old_class, old_idx in self.orig_class_to_idx.items():
67
+ if old_class in self.class_to_idx:
68
+ self.target_mapping[old_idx] = self.class_to_idx[old_class]
69
+
70
+ print(f"Filtered out classes: {self.excluded_classes}")
71
+ print(f"Remaining classes: {self.classes}")
72
+ print(
73
+ f"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}"
74
+ )
75
+
76
+ def __getitem__(self, index):
77
+ """Get item from the filtered dataset with remapped class labels."""
78
+ orig_idx = self.indices[index]
79
+ img, old_target = self.dataset[orig_idx]
80
+
81
+ # Remap target to new class index
82
+ new_target = self.target_mapping[old_target]
83
+
84
+ return img, new_target
85
+
86
+ def __len__(self):
87
+ """Return the number of samples in the filtered dataset."""
88
+ return len(self.indices)
89
+
90
+ # Allow transform to be updated
91
+ def set_transform(self, transform):
92
+ """Update the transform for the dataset."""
93
+ self.dataset.transform = transform
utils/Evaluator.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from sklearn.metrics import (
6
+ accuracy_score,
7
+ classification_report,
8
+ confusion_matrix,
9
+ roc_curve,
10
+ precision_recall_curve,
11
+ auc,
12
+ average_precision_score,
13
+ cohen_kappa_score,
14
+ )
15
+ from sklearn.preprocessing import label_binarize
16
+
17
+
18
+ class ClassificationEvaluator:
19
+ """
20
+ A class to evaluate and visualize classification model performance.
21
+
22
+ This class provides methods to compute various classification metrics
23
+ and generate visualizations for model evaluation.
24
+ """
25
+
26
+ def __init__(self, class_names):
27
+ """
28
+ Initialize the evaluator with class names.
29
+
30
+ Parameters:
31
+ - class_names: list of class names
32
+ """
33
+ self.class_names = class_names
34
+ self.num_classes = len(class_names)
35
+
36
+ def _ensure_numpy(self, data):
37
+ """Convert tensor to numpy if needed."""
38
+ if torch.is_tensor(data):
39
+ return data.cpu().numpy()
40
+ return np.array(data)
41
+
42
+ def evaluate_model(self, model, test_loader):
43
+ """
44
+ Evaluate a trained model on test dataset.
45
+
46
+ Parameters:
47
+ - model: PyTorch model to evaluate
48
+ - test_loader: DataLoader containing test data
49
+
50
+ Returns:
51
+ - results: Dictionary containing evaluation metrics
52
+ """
53
+ model.eval()
54
+ device = next(model.parameters()).device
55
+
56
+ all_labels = []
57
+ all_preds = []
58
+ all_scores = []
59
+
60
+ with torch.no_grad():
61
+ for inputs, labels in test_loader:
62
+ inputs, labels = inputs.to(device), labels.to(device)
63
+ outputs = model(inputs)
64
+ _, preds = torch.max(outputs, 1)
65
+
66
+ all_labels.extend(labels.cpu().numpy())
67
+ all_preds.extend(preds.cpu().numpy())
68
+ all_scores.append(
69
+ torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
70
+ )
71
+
72
+ all_scores = np.vstack(all_scores)
73
+
74
+ # Compute metrics
75
+ results = self.compute_metrics(all_labels, all_preds, all_scores)
76
+ return results
77
+
78
+ def compute_metrics(self, y_true, y_pred, y_scores, model_name=""):
79
+ """
80
+ Compute comprehensive classification metrics.
81
+
82
+ Parameters:
83
+ - y_true: true labels
84
+ - y_pred: predicted labels
85
+ - y_scores: predicted probability scores
86
+ - model_name: name of the model (optional)
87
+
88
+ Returns:
89
+ - Dictionary containing all metrics
90
+ """
91
+ # Ensure numpy arrays
92
+ y_true = self._ensure_numpy(y_true)
93
+ y_pred = self._ensure_numpy(y_pred)
94
+ y_scores = self._ensure_numpy(y_scores)
95
+
96
+ # Calculate accuracy
97
+ accuracy = accuracy_score(y_true, y_pred)
98
+ print(f"Overall Accuracy: {accuracy:.4f}")
99
+
100
+ # Calculate and display Cohen's Kappa
101
+ kappa = cohen_kappa_score(y_true, y_pred)
102
+ print(f"Cohen's Kappa Score: {kappa:.4f}")
103
+
104
+ # Generate classification report
105
+ report = classification_report(
106
+ y_true, y_pred, target_names=self.class_names, output_dict=True
107
+ )
108
+
109
+ # Print formatted classification report
110
+ print("\nClassification Report:")
111
+ print(classification_report(y_true, y_pred, target_names=self.class_names))
112
+
113
+ # Calculate ROC curves and AUC for each class
114
+ print("\nCalculating ROC curves...")
115
+ roc_auc_dict = self.plot_roc_curves(y_true, y_scores)
116
+
117
+ # Calculate PR curves and AUC for each class
118
+ print("\nCalculating PR curves...")
119
+ pr_auc_dict = self.plot_pr_curves(y_true, y_scores)
120
+
121
+ # Plot confusion matrix
122
+ print("\nGenerating confusion matrix...")
123
+ self.plot_confusion_matrix(y_true, y_pred)
124
+
125
+ # Plot per-class accuracy
126
+ print("\nCalculating per-class accuracy...")
127
+ self.plot_per_class_accuracy(y_true, y_pred)
128
+
129
+ # Return metrics dictionary
130
+ return {
131
+ "accuracy": accuracy,
132
+ "report": report,
133
+ "roc_auc": roc_auc_dict,
134
+ "pr_auc": pr_auc_dict,
135
+ "kappa": kappa,
136
+ }
137
+
138
+ def plot_roc_curves(self, y_true, y_scores):
139
+ """
140
+ Plot ROC curves for multi-class classification.
141
+
142
+ Parameters:
143
+ - y_true: true labels
144
+ - y_scores: predicted probability scores
145
+
146
+ Returns:
147
+ - Dictionary containing AUC values for each class
148
+ """
149
+ y_true = self._ensure_numpy(y_true)
150
+ y_scores = self._ensure_numpy(y_scores)
151
+
152
+ # Binarize the labels for one-vs-rest ROC calculation
153
+ y_true_bin = label_binarize(y_true, classes=range(self.num_classes))
154
+
155
+ # Compute ROC curve and ROC area for each class
156
+ fpr = {}
157
+ tpr = {}
158
+ roc_auc = {}
159
+
160
+ plt.figure(figsize=(12, 8))
161
+
162
+ for i in range(self.num_classes):
163
+ fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])
164
+ roc_auc[i] = auc(fpr[i], tpr[i])
165
+
166
+ plt.plot(
167
+ fpr[i],
168
+ tpr[i],
169
+ lw=2,
170
+ label=f"{self.class_names[i]} (area = {roc_auc[i]:.2f})",
171
+ )
172
+
173
+ # Plot the diagonal (random classifier)
174
+ plt.plot([0, 1], [0, 1], "k--", lw=2)
175
+
176
+ # Calculate and plot micro-average ROC curve
177
+ fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_scores.ravel())
178
+ roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
179
+ plt.plot(
180
+ fpr["micro"],
181
+ tpr["micro"],
182
+ label=f'Micro-average (area = {roc_auc["micro"]:.2f})',
183
+ lw=2,
184
+ linestyle=":",
185
+ color="deeppink",
186
+ )
187
+
188
+ plt.xlim([0.0, 1.0])
189
+ plt.ylim([0.0, 1.05])
190
+ plt.xlabel("False Positive Rate")
191
+ plt.ylabel("True Positive Rate")
192
+ plt.title("ROC Curves")
193
+ plt.legend(loc="lower right")
194
+ plt.grid(True, alpha=0.3)
195
+ plt.tight_layout()
196
+ plt.show()
197
+
198
+ return roc_auc
199
+
200
+ def plot_pr_curves(self, y_true, y_scores):
201
+ """
202
+ Plot Precision-Recall curves for multi-class classification.
203
+
204
+ Parameters:
205
+ - y_true: true labels
206
+ - y_scores: predicted probability scores
207
+
208
+ Returns:
209
+ - Dictionary containing average precision values for each class
210
+ """
211
+ y_true = self._ensure_numpy(y_true)
212
+ y_scores = self._ensure_numpy(y_scores)
213
+
214
+ # Binarize the labels
215
+ y_true_bin = label_binarize(y_true, classes=range(self.num_classes))
216
+
217
+ # Compute PR curve and average precision for each class
218
+ precision = {}
219
+ recall = {}
220
+ avg_precision = {}
221
+
222
+ plt.figure(figsize=(12, 8))
223
+
224
+ for i in range(self.num_classes):
225
+ precision[i], recall[i], _ = precision_recall_curve(
226
+ y_true_bin[:, i], y_scores[:, i]
227
+ )
228
+ avg_precision[i] = average_precision_score(y_true_bin[:, i], y_scores[:, i])
229
+
230
+ plt.plot(
231
+ recall[i],
232
+ precision[i],
233
+ lw=2,
234
+ label=f"{self.class_names[i]} (AP = {avg_precision[i]:.2f})",
235
+ )
236
+
237
+ # Calculate and plot micro-average PR curve
238
+ precision["micro"], recall["micro"], _ = precision_recall_curve(
239
+ y_true_bin.ravel(), y_scores.ravel()
240
+ )
241
+ avg_precision["micro"] = average_precision_score(
242
+ y_true_bin.ravel(), y_scores.ravel()
243
+ )
244
+
245
+ plt.plot(
246
+ recall["micro"],
247
+ precision["micro"],
248
+ label=f'Micro-average (AP = {avg_precision["micro"]:.2f})',
249
+ lw=2,
250
+ linestyle=":",
251
+ color="deeppink",
252
+ )
253
+
254
+ plt.xlim([0.0, 1.0])
255
+ plt.ylim([0.0, 1.05])
256
+ plt.xlabel("Recall")
257
+ plt.ylabel("Precision")
258
+ plt.title("Precision-Recall Curves")
259
+ plt.legend(loc="best")
260
+ plt.grid(True, alpha=0.3)
261
+ plt.tight_layout()
262
+ plt.show()
263
+
264
+ return avg_precision
265
+
266
+ def plot_confusion_matrix(self, y_true, y_pred):
267
+ """
268
+ Plot confusion matrix.
269
+
270
+ Parameters:
271
+ - y_true: true labels
272
+ - y_pred: predicted labels
273
+ """
274
+ y_true = self._ensure_numpy(y_true)
275
+ y_pred = self._ensure_numpy(y_pred)
276
+
277
+ # Get unique values in both arrays
278
+ unique_values = np.unique(np.concatenate([y_true, y_pred]))
279
+ print(f"Unique values in confusion matrix data: {unique_values}")
280
+
281
+ # Create the confusion matrix with explicit labels
282
+ cm = confusion_matrix(y_true, y_pred, labels=range(self.num_classes))
283
+
284
+ plt.figure(figsize=(10, 8))
285
+ sns.heatmap(
286
+ cm,
287
+ annot=True,
288
+ fmt="d",
289
+ cmap="Blues",
290
+ xticklabels=self.class_names,
291
+ yticklabels=self.class_names,
292
+ )
293
+ plt.title("Confusion Matrix")
294
+ plt.xlabel("Predicted")
295
+ plt.ylabel("True")
296
+ plt.tight_layout()
297
+ plt.show()
298
+
299
+ def plot_per_class_accuracy(self, y_true, y_pred):
300
+ """
301
+ Plot per-class accuracy.
302
+
303
+ Parameters:
304
+ - y_true: true labels
305
+ - y_pred: predicted labels
306
+ """
307
+ y_true = self._ensure_numpy(y_true)
308
+ y_pred = self._ensure_numpy(y_pred)
309
+
310
+ # Create the confusion matrix with explicit labels
311
+ cm = confusion_matrix(y_true, y_pred, labels=range(self.num_classes))
312
+
313
+ # Calculate per-class accuracy
314
+ per_class_accuracy = np.zeros(self.num_classes)
315
+ for i in range(self.num_classes):
316
+ if i < cm.shape[0] and np.sum(cm[i, :]) > 0:
317
+ per_class_accuracy[i] = cm[i, i] / np.sum(cm[i, :])
318
+
319
+ # Create the bar plot
320
+ plt.figure(figsize=(14, 7))
321
+ plt.bar(range(self.num_classes), per_class_accuracy, color="skyblue")
322
+ plt.xticks(range(self.num_classes), self.class_names, rotation=45, ha="right")
323
+ plt.xlabel("Classes")
324
+ plt.ylabel("Accuracy")
325
+ plt.title("Per-Class Accuracy")
326
+ plt.tight_layout()
327
+ plt.show()
328
+
329
+ return per_class_accuracy
330
+
331
+ def plot_training_history(self, train_losses, val_losses, train_accs, val_accs):
332
+ """
333
+ Plot accuracy and loss curves from training history.
334
+
335
+ Parameters:
336
+ - train_losses: list of training losses
337
+ - val_losses: list of validation losses
338
+ - train_accs: list of training accuracies
339
+ - val_accs: list of validation accuracies
340
+ """
341
+ plt.figure(figsize=(12, 5))
342
+
343
+ # Accuracy curve
344
+ plt.subplot(1, 2, 1)
345
+ plt.plot(train_accs, label="Train Accuracy")
346
+ plt.plot(val_accs, label="Validation Accuracy")
347
+ plt.xlabel("Epochs")
348
+ plt.ylabel("Accuracy")
349
+ plt.title("Accuracy Curve")
350
+ plt.legend()
351
+ plt.grid(True)
352
+
353
+ # Loss curve
354
+ plt.subplot(1, 2, 2)
355
+ plt.plot(train_losses, label="Train Loss")
356
+ plt.plot(val_losses, label="Validation Loss")
357
+ plt.xlabel("Epochs")
358
+ plt.ylabel("Loss")
359
+ plt.title("Loss Curve")
360
+ plt.legend()
361
+ plt.grid(True)
362
+
363
+ plt.tight_layout()
364
+ plt.show()
utils/ModelCreator.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+
5
+ # Set device
6
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+
9
class EyeDetectionModels:
    """
    Factory for eye-detection backbones built on timm.

    Each ``get_model_*`` method loads a pretrained backbone, optionally
    freezes all but the last few feature blocks, replaces the classifier
    head with a small MLP sized for ``num_classes``, and moves the model
    to ``device``.
    """

    def __init__(self, num_classes, freeze_layers=True, device=DEVICE):
        """
        Initialize the EyeDetectionModels class.

        Args:
            num_classes: number of output classes for the replacement head.
            freeze_layers: when True, freeze the backbone except its last
                two feature blocks so only the tail and new head fine-tune.
            device: torch device the created models are moved to.
        """
        self.num_classes = num_classes
        self.freeze_layers = freeze_layers
        self.device = device
        # Registry mapping a short architecture key to its factory method.
        self.models = {
            "mobilenetv4": self.get_model_mobilenetv4,
            "levit": self.get_model_levit,
            "efficientvit": self.get_model_efficientvit,
            "gernet": self.get_model_gernet,
            "regnetx": self.get_model_regnetx,
        }

    # Model architecture helpers
    @staticmethod
    def _get_feature_blocks(model):
        """
        Utility: locate the main feature blocks container in a timm model.
        Returns a list-like module of blocks.
        """
        for attr in ("features", "blocks", "layers", "stem"):  # common container names
            if hasattr(model, attr):
                return getattr(model, attr)
        # fallback: collect all children except classifier/head
        return list(model.children())[:-1]

    @staticmethod
    def _freeze_except_last_n(blocks, n):
        """Freeze every block except the last ``n`` (those stay trainable)."""
        total = len(blocks)
        for idx, block in enumerate(blocks):
            requires = idx >= total - n
            for p in block.parameters():
                p.requires_grad = requires

    def _make_head(self, in_features):
        """Build the shared replacement classifier head (Linear-ReLU-Dropout-Linear)."""
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, self.num_classes),
        )

    def get_model_mobilenetv4(self):
        """Load MobileNetV4-conv-medium and replace its classifier."""
        model = timm.create_model(
            "mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
        )
        if self.freeze_layers:
            blocks = self._get_feature_blocks(model)
            self._freeze_except_last_n(blocks, 2)
        # replace classifier, keeping the backbone's feature width
        in_features = model.classifier.in_features
        model.classifier = self._make_head(in_features)
        return model.to(self.device)

    def get_model_levit(self):
        """Load LeViT-128s and replace both of its (distilled) heads."""
        model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
        if self.freeze_layers:
            blocks = self._get_feature_blocks(model)
            self._freeze_except_last_n(blocks, 2)
        # Derive the head input width from the existing head when exposed;
        # fall back to LeViT-128s's known embedding size (384).
        head = getattr(model, "head_dist", None) or getattr(model, "classifier", None)
        linear = getattr(head, "linear", None)
        in_features = getattr(linear, "in_features", 384)
        # LeViT is distillation-trained, so it carries two parallel heads.
        model.head = self._make_head(in_features)
        model.head_dist = self._make_head(in_features)
        return model.to(self.device)

    def get_model_efficientvit(self):
        """Load EfficientViT-M1 and replace its head's linear classifier."""
        model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
        if self.freeze_layers:
            blocks = self._get_feature_blocks(model)
            self._freeze_except_last_n(blocks, 2)
        # handle different head naming; fall back to M1's known width (192)
        head = getattr(model, "head", None)
        linear = getattr(head, "linear", None)
        in_features = getattr(linear, "in_features", 192)
        model.head.linear = self._make_head(in_features)
        return model.to(self.device)

    def get_model_gernet(self):
        """
        Load and configure a GENet (General and Efficient Network) model with customizable classifier.

        Returns:
            Configured GENet model
        """
        model = timm.create_model("gernet_s.idstcv_in1k", pretrained=True)

        if self.freeze_layers:
            # For GENet, we need to specifically handle its structure
            # It typically has a 'stem' and 'stages' structure
            if hasattr(model, "stem") and hasattr(model, "stages"):
                # Freeze stem completely
                for param in model.stem.parameters():
                    param.requires_grad = False

                # Freeze all stages except the last two
                stages = list(model.stages.children())
                total_stages = len(stages)
                for i, stage in enumerate(stages):
                    requires_grad = i >= total_stages - 2
                    for param in stage.parameters():
                        param.requires_grad = requires_grad
            else:
                # Fallback to generic approach
                blocks = self._get_feature_blocks(model)
                self._freeze_except_last_n(blocks, 2)

        # Replace classifier
        in_features = model.head.fc.in_features
        model.head.fc = self._make_head(in_features)
        return model.to(self.device)

    def get_model_regnetx(self):
        """
        Load and configure a RegNetX model with customizable classifier.

        Returns:
            Configured RegNetX model
        """
        model = timm.create_model("regnetx_008.tv2_in1k", pretrained=True)

        if self.freeze_layers:
            for param in model.parameters():
                param.requires_grad = False

            # RegNetX typically has 'stem' + 'trunk' structure in timm
            if hasattr(model, "trunk"):
                # Unfreeze final stages of the trunk
                trunk_blocks = list(model.trunk.children())
                # Unfreeze approximately last 25% of trunk blocks
                unfreeze_from = max(0, int(len(trunk_blocks) * 0.75))
                for i in range(unfreeze_from, len(trunk_blocks)):
                    for param in trunk_blocks[i].parameters():
                        param.requires_grad = True

            # Always unfreeze the classifier/head for fine-tuning
            for param in model.head.parameters():
                param.requires_grad = True

        # Replace classifier
        in_features = model.head.fc.in_features
        model.head.fc = self._make_head(in_features)
        return model.to(self.device)
utils/Trainer.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ import torch.nn as nn
6
+ from tqdm import tqdm
7
+ import gc
8
+
9
+ from Evaluator import ClassificationEvaluator
10
+ from Callback import EarlyStopping
11
+
12
+
13
def train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    train_loader,
    val_loader,
    early_stopping,
    epochs=15,
    use_ddp=False,
):
    """
    Train the model and perform validation using multiple GPUs.
    Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.

    Args:
        model: Model to train
        criterion: Loss function
        optimizer: Optimizer for training
        scheduler: Learning rate scheduler (stepped with the validation loss,
            i.e. ReduceLROnPlateau-style)
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        early_stopping: Early stopping handler — a callable taking the
            validation loss and exposing an ``early_stop`` flag
        epochs: Maximum number of epochs to train
        use_ddp: Whether to use DistributedDataParallel (True) or DataParallel (False)

    Returns:
        Tuple of (model, train_losses, val_losses, train_accs, val_accs,
        val_labels, val_preds, val_scores); the last three come from the
        final completed validation epoch.
    """
    # Check available GPUs
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        print(
            f"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. Continuing with available resources."
        )
    else:
        print(f"Using {num_gpus} GPUs for training")

    # Setup device and wrap the model for the chosen parallelism mode
    if num_gpus >= 2:
        if use_ddp:
            # For DistributedDataParallel
            import torch.distributed as dist
            from torch.nn.parallel import DistributedDataParallel as DDP

            # Initialize process group
            dist.init_process_group(backend="nccl")
            local_rank = dist.get_rank()
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")

            model = model.to(device)
            model = DDP(model, device_ids=[local_rank])
        else:
            # For DataParallel (simpler to use)
            device = torch.device("cuda:0")
            model = model.to(device)
            model = torch.nn.DataParallel(model)
    else:
        # Single GPU (or CPU fallback)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    # Store validation predictions and labels for final evaluation
    all_val_labels = []
    all_val_preds = []
    all_val_scores = []

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # loss.item() is the batch mean; weight by batch size so the
            # epoch average is correct even with a ragged final batch
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_train_loss = running_loss / len(train_loader.dataset)
        epoch_train_acc = correct / total
        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)

        # Validation phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        all_labels = []
        all_preds = []
        all_scores = []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                probs = F.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_labels.extend(labels.cpu().numpy().tolist())
                all_preds.extend(predicted.cpu().numpy().tolist())
                all_scores.append(probs.cpu().numpy())

        epoch_val_loss = running_loss / len(val_loader.dataset)
        epoch_val_acc = correct / total
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)

        all_scores = np.vstack(all_scores) if all_scores else np.array([])

        # Store validation results for the most recent completed epoch
        all_val_labels = all_labels
        all_val_preds = all_preds
        all_val_scores = all_scores

        # Update learning rate scheduler
        scheduler.step(epoch_val_loss)

        print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Check early stopping, but run the per-epoch cleanup before
        # breaking — previously the `del`/cache-clear was skipped on the
        # very epoch that triggered the stop
        early_stopping(epoch_val_loss)
        should_stop = early_stopping.early_stop

        # Free up memory
        del all_labels, all_preds, all_scores
        gc.collect()
        torch.cuda.empty_cache()

        if should_stop:
            print("Early stopping triggered!")
            break

    # Clean up DDP if used
    if num_gpus >= 2 and use_ddp:
        dist.destroy_process_group()

    return (
        model,
        train_losses,
        val_losses,
        train_accs,
        val_accs,
        all_val_labels,
        all_val_preds,
        all_val_scores,
    )
181
+
182
+
183
def model_train(
    model, train_loader, val_loader, dataset, epochs=20, learning_rate=0.001
):
    """
    Train a model end-to-end and evaluate it on the validation set.

    Args:
        model: model to train; its ``pretrained_cfg`` name is used for
            logging when available (timm models expose this).
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        dataset: dataset object exposing ``classes`` (list of class names).
        epochs: maximum number of training epochs.
        learning_rate: Adam learning rate; defaults to 0.001 (the value
            previously hard-coded here).

    Returns:
        Dict with keys ``accuracy``, ``report``, ``roc_auc``, ``pr_auc``
        and ``kappa`` on success, or ``{"accuracy": None}`` on failure.
    """
    # Prefer the pretrained-config name (e.g. "levit_128s") for logging
    model_name = type(model).__name__
    if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
        model_name = model.pretrained_cfg["name"]

    print(f"\n{'='*20} Training {model_name} {'='*20}\n")

    class_names = dataset.classes
    num_classes = len(class_names)

    try:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # Cut the LR by 10x after 3 epochs without val-loss improvement
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=3
        )
        early_stopping = EarlyStopping(patience=5)

        (
            model,
            train_losses,
            val_losses,
            train_accs,
            val_accs,
            val_labels,
            val_preds,
            val_scores,
        ) = train_model(
            model,
            nn.CrossEntropyLoss(),
            optimizer,
            scheduler,
            train_loader,
            val_loader,
            early_stopping,
            epochs=epochs,
            use_ddp=False,
        )

        print(f"\n{'='*20} Evaluation for {model_name} {'='*20}\n")
        evaluator = ClassificationEvaluator(
            num_classes=num_classes,
            class_names=class_names,
        )

        evaluator.plot_training_history(train_losses, val_losses, train_accs, val_accs)
        # Process validation predictions and labels
        try:
            evaluator.plot_confusion_matrix(val_labels, val_preds)
            evaluator.plot_per_class_accuracy(val_labels, val_preds)

            # Get metrics from the updated function including kappa
            accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = (
                evaluator.compute_metrics(
                    val_labels,
                    val_preds,
                    val_scores,
                    model_name,
                )
            )

            # Build a results dictionary including kappa
            results = {
                "accuracy": accuracy,
                "report": report_dict,
                "roc_auc": roc_auc_dict,
                "pr_auc": pr_auc_dict,
                "kappa": kappa,
            }

            return results
        except Exception as viz_error:
            print(f"Error in visualization: {viz_error}")
            import traceback

            traceback.print_exc()
            return {"accuracy": None}

    except Exception as e:
        print(f"Error occurred when training {model_name}: {e}")
        import traceback

        traceback.print_exc()
        return {"accuracy": None}
    finally:
        # Release training state so back-to-back runs don't accumulate
        # GPU/host memory
        if "optimizer" in locals():
            del optimizer
        if "scheduler" in locals():
            del scheduler
        if "early_stopping" in locals():
            del early_stopping
        if "train_losses" in locals():
            # train_model returns these together, so one guard covers all
            del train_losses, val_losses, train_accs, val_accs
            del val_labels, val_preds, val_scores

        gc.collect()
        torch.cuda.empty_cache()
uv.lock ADDED
The diff for this file is too large to render. See raw diff