Spaces:
Sleeping
Sleeping
GilbertKrantz
committed on
Commit
·
61c2d3f
0
Parent(s):
Initial Commit
Browse files- .gitattributes +1 -0
- .gitignore +182 -0
- .gradio/certificate.pem +31 -0
- .python-version +1 -0
- DockerFile +28 -0
- LICENSE +201 -0
- README.md +123 -0
- gradio-inference.py +221 -0
- main.py +342 -0
- pyproject.toml +36 -0
- requirements.txt +11 -0
- training.ipynb +1 -0
- utils/Callback.py +21 -0
- utils/Comparator.py +179 -0
- utils/DatasetHandler.py +93 -0
- utils/Evaluator.py +364 -0
- utils/ModelCreator.py +180 -0
- utils/Trainer.py +285 -0
- uv.lock +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
weights/efficientvit.pth filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
|
| 110 |
+
# pdm
|
| 111 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 112 |
+
#pdm.lock
|
| 113 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 114 |
+
# in version control.
|
| 115 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 116 |
+
.pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 121 |
+
__pypackages__/
|
| 122 |
+
|
| 123 |
+
# Celery stuff
|
| 124 |
+
celerybeat-schedule
|
| 125 |
+
celerybeat.pid
|
| 126 |
+
|
| 127 |
+
# SageMath parsed files
|
| 128 |
+
*.sage.py
|
| 129 |
+
|
| 130 |
+
# Environments
|
| 131 |
+
.env
|
| 132 |
+
.venv
|
| 133 |
+
env/
|
| 134 |
+
venv/
|
| 135 |
+
ENV/
|
| 136 |
+
env.bak/
|
| 137 |
+
venv.bak/
|
| 138 |
+
|
| 139 |
+
# Spyder project settings
|
| 140 |
+
.spyderproject
|
| 141 |
+
.spyproject
|
| 142 |
+
|
| 143 |
+
# Rope project settings
|
| 144 |
+
.ropeproject
|
| 145 |
+
|
| 146 |
+
# mkdocs documentation
|
| 147 |
+
/site
|
| 148 |
+
|
| 149 |
+
# mypy
|
| 150 |
+
.mypy_cache/
|
| 151 |
+
.dmypy.json
|
| 152 |
+
dmypy.json
|
| 153 |
+
|
| 154 |
+
# Pyre type checker
|
| 155 |
+
.pyre/
|
| 156 |
+
|
| 157 |
+
# pytype static type analyzer
|
| 158 |
+
.pytype/
|
| 159 |
+
|
| 160 |
+
# Cython debug symbols
|
| 161 |
+
cython_debug/
|
| 162 |
+
|
| 163 |
+
# PyCharm
|
| 164 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 165 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 166 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 167 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 168 |
+
#.idea/
|
| 169 |
+
|
| 170 |
+
# Ruff stuff:
|
| 171 |
+
.ruff_cache/
|
| 172 |
+
|
| 173 |
+
# PyPI configuration file
|
| 174 |
+
.pypirc
|
| 175 |
+
|
| 176 |
+
# Data
|
| 177 |
+
Data/
|
| 178 |
+
|
| 179 |
+
EDA/
|
| 180 |
+
|
| 181 |
+
# Model Outputs
|
| 182 |
+
model_outputs/
|
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12.9
|
DockerFile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
git \
|
| 8 |
+
libgl1-mesa-glx \
|
| 9 |
+
libglib2.0-0 \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# Copy requirements first to leverage Docker cache
|
| 13 |
+
COPY requirements.txt .
|
| 14 |
+
|
| 15 |
+
# Install Python dependencies
|
| 16 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Copy project files
|
| 19 |
+
COPY . .
|
| 20 |
+
|
| 21 |
+
# Make the gradio inference script executable
|
| 22 |
+
RUN chmod +x gradio_inference.py
|
| 23 |
+
|
| 24 |
+
# Expose port for Gradio
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
# Set the entrypoint command
|
| 28 |
+
CMD ["python", "gradio-inference.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eye Disease Detection
|
| 2 |
+
|
| 3 |
+
This repository contains a Gradio web application for eye disease detection using deep learning models. The application allows users to upload fundus images and get predictions for common eye conditions.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Easy-to-use web interface** for eye disease detection
|
| 8 |
+
- Support for **multiple model architectures** (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
|
| 9 |
+
- **Custom model loading** from saved model checkpoints
|
| 10 |
+
- **Visualization** of prediction probabilities
|
| 11 |
+
- **Dockerized deployment** option
|
| 12 |
+
|
| 13 |
+
## Supported Eye Conditions
|
| 14 |
+
|
| 15 |
+
The system can detect the following eye conditions:
|
| 16 |
+
- Central Serous Chorioretinopathy
|
| 17 |
+
- Diabetic Retinopathy
|
| 18 |
+
- Disc Edema
|
| 19 |
+
- Glaucoma
|
| 20 |
+
- Healthy (normal eye)
|
| 21 |
+
- Macular Scar
|
| 22 |
+
- Myopia
|
| 23 |
+
- Retinal Detachment
|
| 24 |
+
- Retinitis Pigmentosa
|
| 25 |
+
|
| 26 |
+
## Installation
|
| 27 |
+
|
| 28 |
+
### Prerequisites
|
| 29 |
+
|
| 30 |
+
- Python 3.12+
|
| 31 |
+
- PyTorch 2.7.0+
|
| 32 |
+
- CUDA-compatible GPU (optional, but recommended for faster inference)
|
| 33 |
+
|
| 34 |
+
### Option 1: Local Installation
|
| 35 |
+
|
| 36 |
+
1. Clone this repository:
|
| 37 |
+
```bash
|
| 38 |
+
git clone https://github.com/GilbertKrantz/eye-disease-detection.git
|
| 39 |
+
cd eye-disease-detection
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
2. Install the required packages:
|
| 43 |
+
```bash
|
| 44 |
+
pip install -r requirements.txt
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
3. Run the application:
|
| 48 |
+
```bash
|
| 49 |
+
python gradio-inference.py
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
4. Open your browser and go to http://localhost:7860
|
| 53 |
+
|
| 54 |
+
### Option 2: Docker Installation
|
| 55 |
+
|
| 56 |
+
1. Build the Docker image:
|
| 57 |
+
```bash
|
| 58 |
+
docker build -t eye-disease-detection .
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
2. Run the container:
|
| 62 |
+
```bash
|
| 63 |
+
docker run -p 7860:7860 eye-disease-detection
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
3. Open your browser and go to http://localhost:7860
|
| 67 |
+
|
| 68 |
+
## Usage
|
| 69 |
+
|
| 70 |
+
1. Upload a fundus image of the eye
|
| 71 |
+
2. (Optional) Specify the path to your trained model file (.pth)
|
| 72 |
+
3. Select the model architecture (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
|
| 73 |
+
4. Click "Analyze Image" to get the prediction
|
| 74 |
+
5. View the results and probability distribution
|
| 75 |
+
|
| 76 |
+
## Model Training
|
| 77 |
+
|
| 78 |
+
This repository focuses on inference. For training your own models, refer to the main training script and follow these steps:
|
| 79 |
+
|
| 80 |
+
1. Prepare your dataset in the required directory structure
|
| 81 |
+
2. Train a model using the main.py script:
|
| 82 |
+
```bash
|
| 83 |
+
python main.py --train-dir "/path/to/training/data" --eval-dir "/path/to/eval/data" --model mobilenetv4 --epochs 20 --save-model "my_model.pth"
|
| 84 |
+
```
|
| 85 |
+
3. Use the saved model with the inference application
|
| 86 |
+
|
| 87 |
+
## Project Structure
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
.
|
| 91 |
+
├── gradio-inference.py # Main Gradio application
|
| 92 |
+
├── requirements.txt # Python dependencies
|
| 93 |
+
├── Dockerfile # Docker configuration
|
| 94 |
+
├── README.md # This documentation
|
| 95 |
+
├── utils/ # Utility modules
|
| 96 |
+
│ ├── ModelCreator.py # Model architecture definitions
|
| 97 |
+
│ ├── Evaluator.py # Model evaluation utilities
|
| 98 |
+
│ ├── DatasetHandler.py # Dataset handling utilities
|
| 99 |
+
│ ├── Trainer.py # Model training utilities
|
| 100 |
+
│ └── Callback.py # Training callbacks
|
| 101 |
+
└── main.py # Main training script
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
## Performance
|
| 105 |
+
|
| 106 |
+
The performance of the models depends on the quality of training data and the specific architecture used. In general, these models can achieve accuracy rates of 85-95% on standard eye disease datasets.
|
| 107 |
+
|
| 108 |
+
## Customization
|
| 109 |
+
|
| 110 |
+
You can customize the application in several ways:
|
| 111 |
+
- Add example images in the Gradio interface
|
| 112 |
+
- Extend the list of supported classes by modifying the CLASSES variable in gradio-inference.py
|
| 113 |
+
- Add support for additional model architectures in ModelCreator.py
|
| 114 |
+
|
| 115 |
+
## License
|
| 116 |
+
|
| 117 |
+
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
|
| 118 |
+
|
| 119 |
+
## Acknowledgments
|
| 120 |
+
|
| 121 |
+
- The models are built using PyTorch and the TIMM library
|
| 122 |
+
- The web interface is built using Gradio
|
| 123 |
+
- Special thanks to the open-source community for making this project possible
|
gradio-inference.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Eye Disease Detection - Gradio Inference App
|
| 3 |
+
# Date: May 11, 2025
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
|
| 16 |
+
# Import custom modules
|
| 17 |
+
sys.path.append("./utils")
|
| 18 |
+
from ModelCreator import EyeDetectionModels
|
| 19 |
+
|
| 20 |
+
# Set device once at import time; every tensor and model in this script must
# be moved to this device before use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define class names (make sure these match your model's classes).
# Order matters: index i of the model's output logits maps to CLASSES[i].
CLASSES = [
    "Central Serous Chorioretinopathy",
    "Diabetic Retinopathy",
    "Disc Edema",
    "Glaucoma",
    "Healthy",
    "Macular Scar",
    "Myopia",
    "Retinal Detachment",
    "Retinitis Pigmentosa",
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_transform():
    """Build the inference preprocessing pipeline.

    Resizes, center-crops to 224x224, converts to tensor and applies the
    standard ImageNet channel statistics expected by the TIMM backbones.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_model(model_path, model_type="efficientvit"):
    """
    Load a model for inference, restoring trained weights when available.

    Args:
        model_path: Path to a saved state dict (.pth). None or "" falls back
            to ./weights/<model_type>.pth when that file exists; otherwise an
            untrained model is returned (with a warning).
        model_type: Type of model to load (mobilenetv4, levit, efficientvit,
            gernet, regnetx).

    Returns:
        Model in eval mode, moved to the global inference device.

    Raises:
        ValueError: If model_type is not a known architecture.
        FileNotFoundError: If an explicitly given model_path does not exist.
    """
    # Initialize model creator
    logging.info("Initializing model creator...")
    model_creator = EyeDetectionModels(
        num_classes=len(CLASSES), freeze_layers=False  # Not relevant for inference
    )

    # Check if model type exists
    if model_type not in model_creator.models:
        raise ValueError(
            f"Model type '{model_type}' not found. Available models: {list(model_creator.models.keys())}"
        )

    # Create model of specified type
    logging.info(f"Creating model of type: {model_type}")
    model = model_creator.models[model_type]()

    # Treat an empty string (Gradio's default textbox value) as "no path".
    if not model_path:
        model_path = None

    if model_path is not None and not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path '{model_path}' does not exist.")

    if model_path is None:
        # Fall back to the bundled default weights when present.
        default_path = f"./weights/{model_type}.pth"
        if os.path.exists(default_path):
            model_path = default_path
        else:
            # Bug fix: the original logged model_path *after* resetting it to
            # None; log the path that was actually checked.
            logging.warning(
                f"Default model path '{default_path}' not found. Using untrained model."
            )

    # Bug fix: the original never restored the saved weights, so inference
    # always ran on a randomly initialized model.
    if model_path is not None:
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        logging.info(f"Loaded weights from {model_path}")

    # Bug fix: inputs are sent to `device` in predict_image; keep the model
    # on the same device to avoid a CPU/CUDA mismatch.
    model = model.to(device)

    # Set model to evaluation mode
    model.eval()
    return model
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def predict_image(image, model_path, model_type):
    """
    Predict eye disease probabilities for an uploaded fundus image.

    Args:
        image: Input image from Gradio as a numpy array (or None).
        model_path: Path to the model state dict; "" means "use default".
        model_type: Architecture key understood by load_model.

    Returns:
        Dict mapping each class name to its probability; all zeros when no
        image is given or an error occurs (keeps the Gradio Label happy).
    """
    try:
        logging.info("Starting prediction...")

        # Bug fix: the original checked `image is None` twice; one guard,
        # placed before the expensive model load, is enough.
        if image is None:
            logging.warning("No image provided.")
            return {cls: 0.0 for cls in CLASSES}

        # Bug fix: Gradio passes "" when the textbox is left empty, which
        # slipped through both path branches in load_model; normalize to None
        # so the default-weights fallback triggers.
        model = load_model(model_path or None, model_type)
        # Ensure model and input share a device (idempotent if already moved).
        model = model.to(device)

        # Preprocess image
        logging.info("Preprocessing image...")
        transform = get_transform()
        img = Image.fromarray(image).convert("RGB")
        img_tensor = transform(img).unsqueeze(0).to(device)
        logging.info("Image preprocessed successfully.")

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()

        # Return probabilities for each class
        return {cls: float(prob) for cls, prob in zip(CLASSES, probabilities)}

    except Exception:
        import traceback

        traceback.print_exc()
        return {cls: 0.0 for cls in CLASSES}
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main():
    """Build and launch the Gradio interface for eye disease detection."""
    # Define available models
    model_types = ["mobilenetv4", "levit", "efficientvit", "gernet", "regnetx"]

    # Create the Gradio interface
    with gr.Blocks(title="Eye Disease Detection") as demo:
        gr.Markdown("# Eye Disease Detection System")
        # Bug fix: the original text advertised "Cataract ... and Normal eyes",
        # which does not match CLASSES; derive the list from CLASSES so the UI
        # always agrees with the model's label set.
        gr.Markdown(
            "This application uses deep learning to detect eye diseases from fundus images.\n"
            f"Currently supports detection of: {', '.join(CLASSES)}."
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload Fundus Image", type="numpy")
                model_path = gr.Textbox(
                    label="Model Path (leave empty to use default)",
                    placeholder="Path to model .pth file",
                    value="",
                )
                model_type = gr.Dropdown(
                    label="Model Architecture", choices=model_types, value="mobilenetv4"
                )
                submit_btn = gr.Button("Analyze Image", variant="primary")

            with gr.Column():
                output_chart = gr.Label(label="Prediction")

        # Process the image when the button is clicked
        submit_btn.click(
            fn=predict_image,
            inputs=[input_image, model_path, model_type],
            outputs=output_chart,
        )

        # Examples section
        gr.Markdown("### Examples (Please add your own example images)")
        # Bug fix: predict_image takes (image, model_path, model_type), so the
        # examples must wire all three inputs; caching stays off until real
        # example entries exist (caching an empty/mismatched example set fails).
        gr.Examples(
            examples=[],  # Add example entries as [image, model_path, model_type]
            inputs=[input_image, model_path, model_type],
            outputs=[output_chart],
            fn=predict_image,
            cache_examples=False,
        )

        # Usage instructions
        with gr.Accordion("Usage Instructions", open=False):
            gr.Markdown(
                """
## How to use this application:

1. **Upload an image**: Click the upload button to select a fundus image from your computer
2. **Specify model** (Optional):
   - Enter the path to your trained model file (.pth)
   - Select the model architecture that was used for training
3. **Analyze**: Click the "Analyze Image" button to get results
4. **Interpret results**: The system will show the detected condition and probability distribution

## Model Information:

This system supports multiple model architectures:
- **MobileNetV4**: Lightweight and efficient model
- **LeViT**: Vision Transformer designed for efficiency
- **EfficientViT**: Hybrid CNN-Transformer architecture
- **GENet**: General and Efficient Network
- **RegNetX**: Systematically designed CNN architecture

For best results, ensure you're using a high-quality fundus image and the correct model type.
"""
            )

    # Launch the app
    demo.launch(
        share=True,
        pwa=True,
    )


if __name__ == "__main__":
    main()
|
main.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Eye Disease Detection - Main Application
|
| 3 |
+
# Date: May 11, 2025
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import argparse
|
| 8 |
+
import random
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
from torchvision import transforms, datasets
|
| 14 |
+
from torch.utils.data import DataLoader, random_split
|
| 15 |
+
|
| 16 |
+
# Import custom modules
|
| 17 |
+
sys.path.append("./utils")
|
| 18 |
+
from ModelCreator import EyeDetectionModels
|
| 19 |
+
from DatasetHandler import FilteredImageDataset
|
| 20 |
+
from Evaluator import ClassificationEvaluator
|
| 21 |
+
from Comparator import compare_models
|
| 22 |
+
from Trainer import model_train
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Set random seeds for reproducibility
|
| 26 |
+
def set_seed(seed=42):
    """Seed every RNG source (Python, NumPy, PyTorch/CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels trade a little speed for repeatable runs.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_transform():
    """
    Build the single preprocessing pipeline shared by training, validation
    and testing.

    Returns:
        transform: Resize -> center crop (224) -> tensor -> ImageNet
            normalization.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_data(args):
    """
    Load training and evaluation datasets from separate directories and wrap
    them in DataLoaders.

    Args:
        args: Parsed command line arguments (train_dir, eval_dir, batch_size,
            val_split, test_split, num_workers, exclude_classes).

    Returns:
        train_loader: DataLoader for training
        val_loader: DataLoader for validation
        test_loader: DataLoader for testing
        dataset_ref: Reference to the evaluation dataset for class information
    """
    print(f"Loading training dataset from: {args.train_dir}")
    print(f"Loading evaluation dataset from: {args.eval_dir}")

    # One shared transform keeps train/eval preprocessing identical.
    transform = get_transform()

    train_dataset = datasets.ImageFolder(args.train_dir, transform=transform)
    print(f"Training dataset classes: {train_dataset.classes}")
    print(f"Training dataset size: {len(train_dataset)}")

    eval_dataset = datasets.ImageFolder(args.eval_dir, transform=transform)
    print(f"Evaluation dataset classes: {eval_dataset.classes}")

    # Optionally drop classes the user asked to exclude.
    excluded_classes = args.exclude_classes.split(",") if args.exclude_classes else None
    if excluded_classes and any(excluded_classes):
        train_dataset = FilteredImageDataset(train_dataset, excluded_classes)
        eval_dataset = FilteredImageDataset(eval_dataset, excluded_classes)
        print(f"After filtering - Classes: {eval_dataset.classes}")

    print(f"After filtering - Train size: {len(train_dataset)}")
    print(f"After filtering - Eval size: {len(eval_dataset)}")

    # Carve the evaluation set into validation/test by the requested ratio.
    val_size = int(
        len(eval_dataset) * (args.val_split / (args.val_split + args.test_split))
    )
    test_size = len(eval_dataset) - val_size
    val_dataset, test_dataset = random_split(eval_dataset, [val_size, test_size])

    print(
        f"Split sizes - Train: {len(train_dataset)}, "
        f"Validation: {len(val_dataset)}, Test: {len(test_dataset)}"
    )

    def _make_loader(ds, shuffle):
        # All three loaders share batch size, worker count and pinned memory.
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            pin_memory=True,
        )

    train_loader = _make_loader(train_dataset, True)
    val_loader = _make_loader(val_dataset, False)
    test_loader = _make_loader(test_dataset, False)

    # Use eval_dataset as the reference for class information
    return train_loader, val_loader, test_loader, eval_dataset
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def train_single_model(args, train_loader, val_loader, test_loader, dataset):
    """Train the single architecture selected on the command line, then
    evaluate it on the held-out test set and optionally save its weights."""
    print(f"Creating {args.model} model...")

    # Initialize model creator
    builder = EyeDetectionModels(
        num_classes=len(dataset.classes), freeze_layers=(not args.unfreeze_all)
    )

    # Look up the requested architecture; bail out with a helpful message if
    # it is unknown.
    factory = builder.models.get(args.model)
    if factory is None:
        available_models = list(builder.models.keys())
        print(
            f"Error: Model '{args.model}' not found. Available models: {available_models}"
        )
        sys.exit(1)
    model = factory()

    # Train and evaluate model
    results = model_train(model, train_loader, val_loader, dataset, epochs=args.epochs)

    # Guard clause: a None accuracy means training never completed.
    if results["accuracy"] is None:
        print("Training failed. Cannot evaluate on test set.")
        return

    print("\nEvaluating on test set...")
    evaluator = ClassificationEvaluator(class_names=dataset.classes)
    test_results = evaluator.evaluate_model(model, test_loader)
    print(f"Test accuracy: {test_results['accuracy']:.4f}")

    # Persist the weights when a destination was given.
    if args.save_model:
        try:
            torch.save(model.state_dict(), args.save_model)
            print(f"Model saved to {args.save_model}")
        except Exception as e:
            print(f"Error saving model: {e}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def compare_multiple_models(args, train_loader, val_loader, test_loader, dataset):
    """Instantiate every requested architecture and benchmark them together."""
    print("Preparing to compare multiple models...")

    # Initialize model creator
    builder = EyeDetectionModels(
        num_classes=len(dataset.classes), freeze_layers=(not args.unfreeze_all)
    )

    # Resolve the comma-separated architecture names, skipping unknown ones.
    models, names = [], []
    for candidate in (n.strip() for n in args.compare_models.split(",")):
        if candidate in builder.models:
            print(f"Adding {candidate} to comparison...")
            models.append(builder.models[candidate]())
            names.append(candidate)
        else:
            print(f"Warning: Model '{candidate}' not found, skipping.")

    if not models:
        print("No valid models to compare. Exiting.")
        return

    # Run comparison
    compare_models(
        models,
        train_loader,
        val_loader,
        test_loader,
        dataset,
        epochs=args.epochs,
        names=names,
    )
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def main():
    """Main function to run the eye disease detection application.

    Parses command line arguments, seeds all RNGs, reports GPU availability,
    loads the datasets, then dispatches to either single-model training or
    multi-model comparison.
    """
    # Set up argument parser with example usage
    parser = argparse.ArgumentParser(
        description="Eye Disease Detection using Deep Learning",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        epilog="""
Examples:
  # Train a single model
  python main.py --train-dir "/path/to/augmented_dataset" --eval-dir "/path/to/original_dataset" --model mobilenetv4 --epochs 20 --save-model best_model.pth

  # Compare multiple models
  python main.py --train-dir "/path/to/augmented_dataset" --eval-dir "/path/to/original_dataset" --compare-models mobilenetv4,levit,efficientvit --epochs 15
""",
    )

    # Dataset and data loading arguments
    data_group = parser.add_argument_group("Data Options")
    data_group.add_argument(
        "--train-dir",
        type=str,
        required=True,
        help="Path to the training dataset directory (Augmented Dataset)",
    )
    data_group.add_argument(
        "--eval-dir",
        type=str,
        required=True,
        help="Path to the evaluation dataset directory (Original Dataset)",
    )
    data_group.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training and evaluation",
    )
    # val-split and test-split are relative weights over the evaluation set;
    # load_data normalizes them, so they need not sum to 1.
    data_group.add_argument(
        "--val-split",
        type=float,
        default=0.5,
        help="Validation split ratio within evaluation set",
    )
    data_group.add_argument(
        "--test-split",
        type=float,
        default=0.5,
        help="Test split ratio within evaluation set",
    )
    data_group.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of worker processes for data loading",
    )
    data_group.add_argument(
        "--exclude-classes",
        type=str,
        default=None,
        help="Comma-separated list of class names to exclude",
    )

    # Model arguments
    model_group = parser.add_argument_group("Model Options")
    model_group.add_argument(
        "--model",
        type=str,
        default="mobilenetv4",
        help="Model architecture to use. Options: mobilenetv4, levit, efficientvit, gernet, regnetx",
    )
    model_group.add_argument(
        "--unfreeze-all", action="store_true", help="Unfreeze all layers for training"
    )
    # When --compare-models is given it takes precedence over --model below.
    model_group.add_argument(
        "--compare-models",
        type=str,
        default=None,
        help="Comma-separated list of models to compare",
    )

    # Training arguments
    train_group = parser.add_argument_group("Training Options")
    train_group.add_argument(
        "--epochs", type=int, default=20, help="Number of training epochs"
    )
    train_group.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    train_group.add_argument(
        "--save-model", type=str, default=None, help="Path to save the trained model"
    )

    # Parse arguments
    args = parser.parse_args()

    # Set random seed for reproducibility
    set_seed(args.seed)

    # Display GPU information
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        print(f"Using {device_count} GPU{'s' if device_count > 1 else ''}")
        for i in range(device_count):
            print(f"  Device {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("No GPU available, using CPU")

    # Load data
    train_loader, val_loader, test_loader, dataset = load_data(args)

    # Check if comparing multiple models
    if args.compare_models:
        compare_multiple_models(args, train_loader, val_loader, test_loader, dataset)
    else:
        train_single_model(args, train_loader, val_loader, test_loader, dataset)


if __name__ == "__main__":
    # Example usage for direct execution:
    # python main.py --train-dir "/kaggle/input/eye-disease-image-dataset/Augmented Dataset/Augmented Dataset" \
    #     --eval-dir "/kaggle/input/eye-disease-image-dataset/Original Dataset/Original Dataset" \
    #     --model mobilenetv4 --epochs 10
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "eyediseasedetection"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12.9"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"gradio>=5.29.0",
|
| 9 |
+
"matplotlib>=3.10.3",
|
| 10 |
+
"pandas>=2.2.3",
|
| 11 |
+
"scikit-learn>=1.6.1",
|
| 12 |
+
"seaborn>=0.13.2",
|
| 13 |
+
"timm>=1.0.15",
|
| 14 |
+
"torch>=2.7.0",
|
| 15 |
+
"torchaudio>=2.7.0",
|
| 16 |
+
"torchvision>=0.22.0",
|
| 17 |
+
"tqdm>=4.67.1",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[tool.uv.sources]
|
| 21 |
+
torch = [
|
| 22 |
+
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 23 |
+
]
|
| 24 |
+
torchvision = [
|
| 25 |
+
{ index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[[tool.uv.index]]
|
| 29 |
+
name = "pytorch-cpu"
|
| 30 |
+
url = "https://download.pytorch.org/whl/cpu"
|
| 31 |
+
explicit = true
|
| 32 |
+
|
| 33 |
+
[[tool.uv.index]]
|
| 34 |
+
name = "pytorch-cu128"
|
| 35 |
+
url = "https://download.pytorch.org/whl/cu128"
|
| 36 |
+
explicit = true
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.7.0
|
| 2 |
+
torchvision>=0.22.0
|
| 3 |
+
torchaudio>=2.7.0
|
| 4 |
+
numpy>=1.26.0
|
| 5 |
+
Pillow>=10.0.0
|
| 6 |
+
gradio>=4.11.0
|
| 7 |
+
matplotlib>=3.8.0
|
| 8 |
+
seaborn>=0.13.0
|
| 9 |
+
scikit-learn>=1.4.0
|
| 10 |
+
tqdm>=4.66.0
|
| 11 |
+
timm>=1.0.0
|
training.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.11.11","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":10951558,"sourceType":"datasetVersion","datasetId":6812365}],"dockerImageVersionId":31011,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import os\nimport time\nimport random\nimport copy\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchvision.transforms as transforms\nfrom torchvision import transforms, datasets\nimport torchvision.models as models\nimport timm\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nfrom PIL import Image\nfrom torch.utils.data import DataLoader, random_split, Subset, Dataset\nfrom sklearn.metrics import (\n accuracy_score, confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve, average_precision_score\n)\nfrom sklearn.preprocessing import label_binarize\nfrom tqdm import tqdm\nimport gc\n\n\n# Set device\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Using device: {device}\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:11.680667Z","iopub.execute_input":"2025-05-07T15:22:11.680991Z","iopub.status.idle":"2025-05-07T15:22:16.455808Z","shell.execute_reply.started":"2025-05-07T15:22:11.680969Z","shell.execute_reply":"2025-05-07T15:22:16.455137Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Modified FilteredImageDataset class with Pterygium filtering\nclass FilteredImageDataset(Dataset):\n def __init__(self, dataset, excluded_classes=None):\n \"\"\"\n Create a 
filtered dataset that excludes specific classes.\n \n Args:\n dataset: Original dataset (ImageFolder or similar)\n excluded_classes: List of class names to exclude (e.g., [\"Pterygium\"])\n \"\"\"\n self.dataset = dataset\n self.excluded_classes = excluded_classes or []\n \n # Get original class information\n self.orig_classes = dataset.classes\n self.orig_class_to_idx = dataset.class_to_idx\n \n # Create indices of samples to keep (excluding specified classes)\n self.indices = []\n for idx, (_, target) in enumerate(dataset.samples):\n class_name = self.orig_classes[target]\n if class_name not in self.excluded_classes:\n self.indices.append(idx)\n \n # Create new class mapping without excluded classes\n remaining_classes = [c for c in self.orig_classes if c not in self.excluded_classes]\n self.classes = remaining_classes\n self.class_to_idx = {cls: idx for idx, cls in enumerate(remaining_classes)}\n self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}\n \n # Create a mapping from old indices to new indices\n self.target_mapping = {}\n for old_class, old_idx in self.orig_class_to_idx.items():\n if old_class in self.class_to_idx:\n self.target_mapping[old_idx] = self.class_to_idx[old_class]\n \n print(f\"Filtered out classes: {self.excluded_classes}\")\n print(f\"Remaining classes: {self.classes}\")\n print(f\"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}\")\n\n def __getitem__(self, index):\n \"\"\"Get item from the filtered dataset with remapped class labels.\"\"\"\n orig_idx = self.indices[index]\n img, old_target = self.dataset[orig_idx]\n \n # Remap target to new class index\n new_target = self.target_mapping[old_target]\n \n return img, new_target\n\n def __len__(self):\n \"\"\"Return the number of samples in the filtered dataset.\"\"\"\n return len(self.indices)\n \n # Allow transform to be updated\n def set_transform(self, transform):\n \"\"\"Update the transform for the dataset.\"\"\"\n self.dataset.transform = 
transform","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.457202Z","iopub.execute_input":"2025-05-07T15:22:16.457579Z","iopub.status.idle":"2025-05-07T15:22:16.465134Z","shell.execute_reply.started":"2025-05-07T15:22:16.457560Z","shell.execute_reply":"2025-05-07T15:22:16.464502Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Model Definition","metadata":{}},{"cell_type":"code","source":"# Early stopping class\nclass EarlyStopping:\n def __init__(self, patience=5, delta=0):\n self.patience = patience\n self.delta = delta\n self.counter = 0\n self.best_score = None\n self.early_stop = False\n \n def __call__(self, val_loss):\n score = -val_loss\n \n if self.best_score is None:\n self.best_score = score\n elif score < self.best_score + self.delta:\n self.counter += 1\n if self.counter >= self.patience:\n self.early_stop = True\n else:\n self.best_score = score\n self.counter = 0\n\n# Model architecture functions\ndef _get_feature_blocks(model):\n \"\"\"\n Utility: locate the main feature blocks container in a timm model.\n Returns a list-like module of blocks.\n \"\"\"\n for attr in ('features', 'blocks', 'layers', 'stem'): # common container names\n if hasattr(model, attr):\n return getattr(model, attr)\n # fallback: collect all children except classifier/head\n return list(model.children())[:-1]\n\ndef _freeze_except_last_n(blocks, n):\n total = len(blocks)\n for idx, block in enumerate(blocks):\n requires = (idx >= total - n)\n for p in block.parameters():\n p.requires_grad = requires\n\ndef get_model_mobilenetv4(num_classes, freeze_layers=True, device='cuda'):\n model = timm.create_model('mobilenetv4_conv_medium.e500_r256_in1k', pretrained=True)\n if freeze_layers:\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n # replace classifier\n in_features = model.classifier.in_features\n model.classifier = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n 
nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n\ndef get_model_levit(num_classes, freeze_layers=True, device='cuda'):\n model = timm.create_model('levit_128s.fb_dist_in1k', pretrained=True)\n if freeze_layers:\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n # Attempt to extract in_features from model.head or classifier\n head = getattr(model, 'head_dist', None) or getattr(model, 'classifier', None)\n linear = getattr(head, 'linear')\n in_features = 384\n model.head = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n model.head_dist = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n\ndef get_model_efficientvit(num_classes, freeze_layers=True, device='cuda'):\n model = timm.create_model('efficientvit_m1.r224_in1k', pretrained=True)\n if freeze_layers:\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n # handle different head naming\n head = getattr(model, 'head', None)\n print(head)\n linear = getattr(head, 'linear')\n in_features = 192\n model.head.linear = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n \ndef get_model_gernet(num_classes, freeze_layers=True, device='cuda'):\n \"\"\"\n Load and configure a GENet (General and Efficient Network) model with customizable classifier.\n \n Args:\n num_classes: Number of output classes\n freeze_layers: If True, freeze all but the last 2 blocks\n device: Device to load the model on ('cuda' or 'cpu')\n \n Returns:\n Configured GENet model\n \"\"\"\n model = timm.create_model('gernet_s.idstcv_in1k', pretrained=True)\n \n if freeze_layers:\n # For GENet, we need to specifically handle its structure\n # It typically has a 'stem' and 'stages' structure\n if 
hasattr(model, 'stem') and hasattr(model, 'stages'):\n # Freeze stem completely\n for param in model.stem.parameters():\n param.requires_grad = False\n \n # Freeze all stages except the last two\n stages = list(model.stages.children())\n total_stages = len(stages)\n for i, stage in enumerate(stages):\n requires_grad = (i >= total_stages - 2)\n for param in stage.parameters():\n param.requires_grad = requires_grad\n else:\n # Fallback to generic approach\n blocks = _get_feature_blocks(model)\n _freeze_except_last_n(blocks, 2)\n \n # Replace classifier\n in_features = model.head.fc.in_features\n model.head.fc = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)\n\ndef get_model_regnetx(num_classes, freeze_layers=True, device='cuda'):\n \"\"\"\n Load and configure a RegNetX model with customizable classifier.\n \n Args:\n num_classes: Number of output classes\n freeze_layers: If True, freeze all but the last 2 blocks\n device: Device to load the model on ('cuda' or 'cpu')\n \n Returns:\n Configured RegNetX model\n \"\"\"\n model = timm.create_model('regnetx_008.tv2_in1k', pretrained=True)\n \n if freeze_layers:\n # Looking at the error, we need to inspect the model structure carefully\n # Print the model structure to understand it better in real use\n # print(model)\n \n # Direct approach: check the model structure and freeze components individually\n # First, freeze all parameters\n for param in model.parameters():\n param.requires_grad = False\n \n # Then unfreeze the last few layers manually based on RegNetX structure\n # RegNetX typically has 'stem' + 'trunk' structure in timm\n if hasattr(model, 'trunk'):\n # Unfreeze final stages of the trunk\n trunk_blocks = list(model.trunk.children())\n # Unfreeze approximately last 25% of trunk blocks\n unfreeze_from = max(0, int(len(trunk_blocks) * 0.75))\n for i in range(unfreeze_from, len(trunk_blocks)):\n for param in 
trunk_blocks[i].parameters():\n param.requires_grad = True\n \n # Always unfreeze the classifier/head for fine-tuning\n for param in model.head.parameters():\n param.requires_grad = True\n \n # Replace classifier\n in_features = model.head.fc.in_features\n model.head.fc = nn.Sequential(\n nn.Linear(in_features, 512),\n nn.ReLU(inplace=True),\n nn.Dropout(0.4),\n nn.Linear(512, num_classes)\n )\n return model.to(device)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.465842Z","iopub.execute_input":"2025-05-07T15:22:16.466084Z","iopub.status.idle":"2025-05-07T15:22:16.492719Z","shell.execute_reply.started":"2025-05-07T15:22:16.466067Z","shell.execute_reply":"2025-05-07T15:22:16.491782Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Training function","metadata":{}},{"cell_type":"code","source":"def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, early_stopping, epochs=15, use_ddp=False):\n \"\"\"\n Train the model and perform validation using multiple GPUs.\n Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.\n \n Args:\n model: Model to train\n criterion: Loss function\n optimizer: Optimizer for training\n scheduler: Learning rate scheduler\n train_loader: DataLoader for training data\n val_loader: DataLoader for validation data\n early_stopping: Early stopping handler\n epochs: Maximum number of epochs to train\n use_ddp: Whether to use DistributedDataParallel (True) or DataParallel (False)\n \"\"\"\n # Check available GPUs\n num_gpus = torch.cuda.device_count()\n if num_gpus < 2:\n print(f\"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. 
Continuing with available resources.\")\n else:\n print(f\"Using {num_gpus} GPUs for training\")\n \n # Setup device and model\n if num_gpus >= 2:\n if use_ddp:\n # For DistributedDataParallel\n import torch.distributed as dist\n from torch.nn.parallel import DistributedDataParallel as DDP\n \n # Initialize process group\n dist.init_process_group(backend='nccl')\n local_rank = dist.get_rank()\n torch.cuda.set_device(local_rank)\n device = torch.device(f\"cuda:{local_rank}\")\n \n model = model.to(device)\n model = DDP(model, device_ids=[local_rank])\n else:\n # For DataParallel (simpler to use)\n device = torch.device(\"cuda:0\")\n model = model.to(device)\n model = torch.nn.DataParallel(model)\n else:\n # Single GPU\n device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n model = model.to(device)\n \n train_losses = []\n val_losses = []\n train_accs = []\n val_accs = []\n \n # Store validation predictions and labels for final evaluation\n all_val_labels = []\n all_val_preds = []\n all_val_scores = []\n \n for epoch in range(epochs):\n print(f\"Epoch {epoch+1}/{epochs}\")\n \n # Training phase\n model.train()\n running_loss = 0.0\n correct = 0\n total = 0\n \n for inputs, labels in tqdm(train_loader, desc=\"Training\"):\n inputs, labels = inputs.to(device), labels.to(device)\n \n optimizer.zero_grad()\n outputs = model(inputs)\n loss = criterion(outputs, labels)\n loss.backward()\n optimizer.step()\n \n running_loss += loss.item() * inputs.size(0)\n _, predicted = torch.max(outputs, 1)\n total += labels.size(0)\n correct += (predicted == labels).sum().item()\n \n epoch_train_loss = running_loss / len(train_loader.dataset)\n epoch_train_acc = correct / total\n train_losses.append(epoch_train_loss)\n train_accs.append(epoch_train_acc)\n \n # Validation phase\n model.eval()\n running_loss = 0.0\n correct = 0\n total = 0\n \n all_labels = []\n all_preds = []\n all_scores = []\n \n with torch.no_grad():\n for inputs, labels in tqdm(val_loader, 
desc=\"Validation\"):\n inputs, labels = inputs.to(device), labels.to(device)\n outputs = model(inputs)\n loss = criterion(outputs, labels)\n \n running_loss += loss.item() * inputs.size(0)\n probs = F.softmax(outputs, dim=1)\n _, predicted = torch.max(outputs, 1)\n total += labels.size(0)\n correct += (predicted == labels).sum().item()\n \n all_labels.extend(labels.cpu().numpy().tolist())\n all_preds.extend(predicted.cpu().numpy().tolist())\n all_scores.append(probs.cpu().numpy())\n \n epoch_val_loss = running_loss / len(val_loader.dataset)\n epoch_val_acc = correct / total\n val_losses.append(epoch_val_loss)\n val_accs.append(epoch_val_acc)\n \n all_scores = np.vstack(all_scores) if all_scores else np.array([])\n \n # Store validation results for the final epoch\n all_val_labels = all_labels\n all_val_preds = all_preds\n all_val_scores = all_scores\n \n # Update learning rate scheduler\n scheduler.step(epoch_val_loss)\n \n print(f\"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}\")\n print(f\"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}\")\n print(f\"Learning rate: {optimizer.param_groups[0]['lr']:.6f}\")\n \n # Check early stopping\n early_stopping(epoch_val_loss)\n if early_stopping.early_stop:\n print(\"Early stopping triggered!\")\n break\n \n # Free up memory\n del all_labels, all_preds, all_scores\n gc.collect()\n torch.cuda.empty_cache()\n \n # Clean up DDP if used\n if num_gpus >= 2 and use_ddp:\n dist.destroy_process_group()\n \n return model, train_losses, val_losses, train_accs, val_accs, all_val_labels, all_val_preds, all_val_scores\n","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.494736Z","iopub.execute_input":"2025-05-07T15:22:16.495264Z","iopub.status.idle":"2025-05-07T15:22:16.517964Z","shell.execute_reply.started":"2025-05-07T15:22:16.495245Z","shell.execute_reply":"2025-05-07T15:22:16.517204Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# 
Evaluation plotting functions","metadata":{}},{"cell_type":"code","source":"def plot_roc_curves(y_true, y_scores, class_names):\n \"\"\"\n Plot ROC curves for multi-class classification.\n \n Parameters:\n - y_true: true labels\n - y_scores: predicted probability scores from model\n - class_names: list of class names\n \"\"\"\n # Ensure inputs are numpy arrays\n if torch.is_tensor(y_true):\n y_true = y_true.cpu().numpy()\n if torch.is_tensor(y_scores):\n y_scores = y_scores.cpu().numpy()\n \n n_classes = len(class_names)\n \n # Binarize the labels for one-vs-rest ROC calculation\n y_true_bin = label_binarize(y_true, classes=range(n_classes))\n \n # Compute ROC curve and ROC area for each class\n fpr = {}\n tpr = {}\n roc_auc = {}\n \n plt.figure(figsize=(12, 8))\n \n for i in range(n_classes):\n fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_scores[:, i])\n roc_auc[i] = auc(fpr[i], tpr[i])\n \n plt.plot(fpr[i], tpr[i], lw=2,\n label=f'{class_names[i]} (area = {roc_auc[i]:.2f})')\n \n # Plot the diagonal (random classifier)\n plt.plot([0, 1], [0, 1], 'k--', lw=2)\n \n # Calculate and plot micro-average ROC curve\n fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(y_true_bin.ravel(), y_scores.ravel())\n roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n plt.plot(fpr[\"micro\"], tpr[\"micro\"], \n label=f'Micro-average (area = {roc_auc[\"micro\"]:.2f})', \n lw=2, linestyle=':', color='deeppink')\n \n plt.xlim([0.0, 1.0])\n plt.ylim([0.0, 1.05])\n plt.xlabel('False Positive Rate')\n plt.ylabel('True Positive Rate')\n plt.title('ROC Curves')\n plt.legend(loc=\"lower right\")\n plt.grid(True, alpha=0.3)\n plt.tight_layout()\n plt.show()\n \n # Return the AUC values for reporting\n return roc_auc\n\ndef plot_pr_curves(y_true, y_scores, class_names):\n \"\"\"\n Plot Precision-Recall curves for multi-class classification.\n \n Parameters:\n - y_true: true labels\n - y_scores: predicted probability scores from model\n - class_names: list of class names\n \"\"\"\n # 
Ensure inputs are numpy arrays\n if torch.is_tensor(y_true):\n y_true = y_true.cpu().numpy()\n if torch.is_tensor(y_scores):\n y_scores = y_scores.cpu().numpy()\n \n n_classes = len(class_names)\n \n # Binarize the labels\n y_true_bin = label_binarize(y_true, classes=range(n_classes))\n \n # Compute PR curve and average precision for each class\n precision = {}\n recall = {}\n avg_precision = {}\n \n plt.figure(figsize=(12, 8))\n \n for i in range(n_classes):\n precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], y_scores[:, i])\n avg_precision[i] = average_precision_score(y_true_bin[:, i], y_scores[:, i])\n \n plt.plot(recall[i], precision[i], lw=2,\n label=f'{class_names[i]} (AP = {avg_precision[i]:.2f})')\n \n # Calculate and plot micro-average PR curve\n precision[\"micro\"], recall[\"micro\"], _ = precision_recall_curve(\n y_true_bin.ravel(), y_scores.ravel())\n avg_precision[\"micro\"] = average_precision_score(y_true_bin.ravel(), y_scores.ravel())\n \n plt.plot(recall[\"micro\"], precision[\"micro\"],\n label=f'Micro-average (AP = {avg_precision[\"micro\"]:.2f})',\n lw=2, linestyle=':', color='deeppink')\n \n plt.xlim([0.0, 1.0])\n plt.ylim([0.0, 1.05])\n plt.xlabel('Recall')\n plt.ylabel('Precision')\n plt.title('Precision-Recall Curves')\n plt.legend(loc=\"best\")\n plt.grid(True, alpha=0.3)\n plt.tight_layout()\n plt.show()\n \n # Return the average precision values for reporting\n return avg_precision\n\ndef plot_accuracy_and_loss(train_losses, val_losses, train_accs, val_accs):\n plt.figure(figsize=(12, 5))\n # Accuracy curve\n plt.subplot(1, 2, 1)\n plt.plot(train_accs, label=\"Train Accuracy\")\n plt.plot(val_accs, label=\"Validation Accuracy\")\n plt.xlabel(\"Epochs\")\n plt.ylabel(\"Accuracy\")\n plt.title(\"Accuracy Curve\")\n plt.legend()\n plt.grid(True)\n \n # Loss curve\n plt.subplot(1, 2, 2)\n plt.plot(train_losses, label=\"Train Loss\")\n plt.plot(val_losses, label=\"Validation Loss\")\n plt.xlabel(\"Epochs\")\n 
plt.ylabel(\"Loss\")\n plt.title(\"Loss Curve\")\n plt.legend()\n plt.grid(True)\n \n plt.tight_layout()\n plt.show()\n\ndef plot_confusion_matrix(y_true, y_pred, class_names):\n # Ensure we're working with numpy arrays\n y_true = np.array(y_true)\n y_pred = np.array(y_pred)\n \n # Get unique values in both arrays\n unique_values = np.unique(np.concatenate([y_true, y_pred]))\n print(f\"Unique values in confusion matrix data: {unique_values}\")\n \n # Create the confusion matrix with explicit labels\n cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))\n \n plt.figure(figsize=(10, 8))\n sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n xticklabels=class_names, yticklabels=class_names)\n plt.title(\"Confusion Matrix\")\n plt.xlabel(\"Predicted\")\n plt.ylabel(\"True\")\n plt.tight_layout()\n plt.show()\n\ndef plot_per_class_accuracy(y_true, y_pred, class_names):\n # Convert to numpy arrays\n y_true = np.array(y_true)\n y_pred = np.array(y_pred)\n \n # Get number of expected classes\n num_classes = len(class_names)\n \n # Create the confusion matrix with explicit labels\n cm = confusion_matrix(y_true, y_pred, labels=range(num_classes))\n \n # Calculate per-class accuracy\n per_class_accuracy = np.zeros(num_classes)\n for i in range(num_classes):\n if i < cm.shape[0] and np.sum(cm[i, :]) > 0:\n per_class_accuracy[i] = cm[i, i] / np.sum(cm[i, :])\n \n # Create the bar plot\n plt.figure(figsize=(14, 7))\n plt.bar(range(num_classes), per_class_accuracy, color=\"skyblue\")\n plt.xticks(range(num_classes), class_names, rotation=45, ha='right')\n plt.xlabel(\"Classes\")\n plt.ylabel(\"Accuracy\")\n plt.title(\"Per-Class Accuracy\")\n plt.tight_layout()\n 
plt.show()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.518997Z","iopub.execute_input":"2025-05-07T15:22:16.519284Z","iopub.status.idle":"2025-05-07T15:22:16.546032Z","shell.execute_reply.started":"2025-05-07T15:22:16.519262Z","shell.execute_reply":"2025-05-07T15:22:16.545330Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from sklearn.metrics import cohen_kappa_score\n\ndef compute_classification_metrics(y_true, y_pred, y_scores, num_classes, class_names, model_name=\"\"):\n \"\"\"\n Compute comprehensive classification metrics including ROC AUC, PR AUC, and Cohen's Kappa.\n \n Parameters:\n - y_true: true labels\n - y_pred: predicted labels\n - y_scores: predicted probability scores from model\n - num_classes: number of classes\n - class_names: list of class names\n - model_name: name of the model (for display purposes)\n \n Returns:\n - accuracy: overall accuracy score\n - report_dict: classification report as dictionary\n - roc_auc_dict: ROC AUC scores by class\n - pr_auc_dict: PR AUC scores by class\n - kappa: Cohen's Kappa score\n \"\"\"\n # Calculate accuracy\n accuracy = accuracy_score(y_true, y_pred)\n print(f\"Overall Accuracy: {accuracy:.4f}\")\n \n # Calculate and display Cohen's Kappa\n kappa = cohen_kappa_score(y_true, y_pred)\n print(f\"Cohen's Kappa Score: {kappa:.4f}\")\n \n # Generate classification report\n report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)\n \n # Print formatted classification report\n print(\"\\nClassification Report:\")\n print(classification_report(y_true, y_pred, target_names=class_names))\n \n # Calculate ROC curves and AUC for each class\n print(\"\\nCalculating ROC curves...\")\n roc_auc_dict = plot_roc_curves(y_true, y_scores, class_names)\n \n # Calculate PR curves and AUC for each class\n print(\"\\nCalculating PR curves...\")\n pr_auc_dict = plot_pr_curves(y_true, y_scores, class_names)\n \n # Return metrics for 
comparison\n return accuracy, report, roc_auc_dict, pr_auc_dict, kappa\n\n# Also update evaluate_on_test_set to include kappa\ndef evaluate_on_test_set(model, test_loader, dataset):\n \"\"\"Evaluate a trained model on test dataset\"\"\"\n class_names = dataset.classes\n num_classes = len(class_names)\n \n model.eval()\n device = next(model.parameters()).device\n \n all_labels = []\n all_preds = []\n all_scores = []\n \n with torch.no_grad():\n for inputs, labels in test_loader:\n inputs, labels = inputs.to(device), labels.to(device)\n outputs = model(inputs)\n _, preds = torch.max(outputs, 1)\n \n all_labels.extend(labels.cpu().numpy())\n all_preds.extend(preds.cpu().numpy())\n all_scores.append(torch.nn.functional.softmax(outputs, dim=1).cpu().numpy())\n \n all_scores = np.vstack(all_scores)\n \n # Compute metrics including kappa\n accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = compute_classification_metrics(\n all_labels, all_preds, all_scores, num_classes, class_names)\n \n # Build results dictionary with kappa\n results = {\n 'accuracy': accuracy,\n 'report': report_dict,\n 'roc_auc': roc_auc_dict,\n 'pr_auc': pr_auc_dict,\n 'kappa': kappa\n }\n \n return results","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.547383Z","iopub.execute_input":"2025-05-07T15:22:16.547662Z","iopub.status.idle":"2025-05-07T15:22:16.571142Z","shell.execute_reply.started":"2025-05-07T15:22:16.547638Z","shell.execute_reply":"2025-05-07T15:22:16.570633Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Main training function","metadata":{}},{"cell_type":"code","source":"# Update the model_train function to include kappa in the results\ndef model_train(model, train_loader, val_loader, dataset, epochs=20):\n model_name = type(model).__name__\n if hasattr(model, 'pretrained_cfg') and 'name' in model.pretrained_cfg:\n model_name = model.pretrained_cfg['name']\n \n print(f\"\\n{'='*20} Training {model_name} 
{'='*20}\\n\")\n \n class_names = dataset.classes\n num_classes = len(class_names)\n learning_rate = 0.001\n \n try:\n optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)\n early_stopping = EarlyStopping(patience=5)\n \n model, train_losses, val_losses, train_accs, val_accs, val_labels, val_preds, val_scores = train_model(\n model, nn.CrossEntropyLoss(), optimizer, scheduler,\n train_loader, val_loader, early_stopping, epochs=epochs, use_ddp=False\n )\n \n print(f\"\\n{'='*20} Evaluation for {model_name} {'='*20}\\n\")\n \n # Plot training curves\n plot_accuracy_and_loss(train_losses, val_losses, train_accs, val_accs)\n \n # Process validation predictions and labels\n try:\n plot_confusion_matrix(val_labels, val_preds, class_names)\n plot_per_class_accuracy(val_labels, val_preds, class_names)\n \n # Get metrics from the updated function including kappa\n accuracy, report_dict, roc_auc_dict, pr_auc_dict, kappa = compute_classification_metrics(\n val_labels, val_preds, val_scores, num_classes, class_names, model_name)\n \n # Build a results dictionary including kappa\n results = {\n 'accuracy': accuracy,\n 'report': report_dict,\n 'roc_auc': roc_auc_dict,\n 'pr_auc': pr_auc_dict,\n 'kappa': kappa\n }\n \n return results\n except Exception as viz_error:\n print(f\"Error in visualization: {viz_error}\")\n import traceback\n traceback.print_exc()\n return {'accuracy': None}\n \n except Exception as e:\n print(f'Error occurred when training {model_name}: {e}')\n import traceback\n traceback.print_exc()\n return {'accuracy': None}\n finally:\n # Clean up memory\n if 'optimizer' in locals():\n del optimizer\n if 'scheduler' in locals():\n del scheduler\n if 'early_stopping' in locals():\n del early_stopping\n if 'train_losses' in locals():\n del train_losses\n del val_losses\n del train_accs\n del val_accs\n del val_labels\n del val_preds\n del val_scores\n \n 
gc.collect()\n torch.cuda.empty_cache()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.571806Z","iopub.execute_input":"2025-05-07T15:22:16.572176Z","iopub.status.idle":"2025-05-07T15:22:16.594545Z","shell.execute_reply.started":"2025-05-07T15:22:16.572146Z","shell.execute_reply":"2025-05-07T15:22:16.593850Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Comparison Function","metadata":{}},{"cell_type":"code","source":"def compare_models(models, train_loader, val_loader, test_loader, dataset, epochs=20, names=None):\n if names is None:\n names = [f\"Model {i+1}\" for i in range(len(models))]\n \n val_results = {}\n test_results = {}\n best_model_obj = None\n best_accuracy = -1\n best_model_name = \"\"\n \n # Summary dictionaries for metrics\n val_roc_auc_summary = {}\n test_roc_auc_summary = {}\n val_pr_auc_summary = {}\n test_pr_auc_summary = {}\n val_kappa_summary = {}\n test_kappa_summary = {}\n \n for i, (model, name) in enumerate(zip(models, names)):\n print(f\"\\n\\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\\n\")\n model_results = model_train(model, train_loader, val_loader, dataset, epochs)\n \n # Extract accuracy from results\n accuracy = model_results.get('accuracy')\n val_results[name] = accuracy\n \n # Extract and store metrics\n if 'roc_auc' in model_results and 'micro' in model_results['roc_auc']:\n val_roc_auc_summary[name] = model_results['roc_auc']['micro']\n else:\n val_roc_auc_summary[name] = None\n \n if 'pr_auc' in model_results and 'micro' in model_results['pr_auc']:\n val_pr_auc_summary[name] = model_results['pr_auc']['micro']\n else:\n val_pr_auc_summary[name] = None\n \n # Store kappa score\n if 'kappa' in model_results:\n val_kappa_summary[name] = model_results['kappa']\n else:\n val_kappa_summary[name] = None\n \n # Evaluate on test set\n if accuracy is not None:\n print(f\"\\n{'='*20} Testing {name} on Test Set {'='*20}\\n\")\n test_model_results = 
evaluate_on_test_set(model, test_loader, dataset)\n \n # Extract accuracy from test results\n test_accuracy = test_model_results.get('accuracy')\n test_results[name] = test_accuracy\n \n # Extract and store test metrics\n if 'roc_auc' in test_model_results and 'micro' in test_model_results['roc_auc']:\n test_roc_auc_summary[name] = test_model_results['roc_auc']['micro']\n else:\n test_roc_auc_summary[name] = None\n \n if 'pr_auc' in test_model_results and 'micro' in test_model_results['pr_auc']:\n test_pr_auc_summary[name] = test_model_results['pr_auc']['micro']\n else:\n test_pr_auc_summary[name] = None\n \n # Store test kappa score\n if 'kappa' in test_model_results:\n test_kappa_summary[name] = test_model_results['kappa']\n else:\n test_kappa_summary[name] = None\n \n # Track best model\n if test_accuracy > best_accuracy:\n best_accuracy = test_accuracy\n best_model_obj = copy.deepcopy(model)\n best_model_name = name\n \n # Print comprehensive comparison\n print(\"\\n\\n\" + \"=\"*100)\n print(\"COMPREHENSIVE MODEL COMPARISON\")\n print(\"=\"*100)\n print(f\"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}{'Val ROC AUC':<14}{'Test ROC AUC':<14}{'Val PR AUC':<14}{'Test PR AUC':<14}{'Val Kappa':<14}{'Test Kappa':<14}\")\n print(\"-\"*100)\n \n for name in val_results.keys():\n val_acc = val_results[name]\n test_acc = test_results.get(name, None)\n val_roc = val_roc_auc_summary.get(name, None)\n test_roc = test_roc_auc_summary.get(name, None)\n val_pr = val_pr_auc_summary.get(name, None)\n test_pr = test_pr_auc_summary.get(name, None)\n val_kappa = val_kappa_summary.get(name, None)\n test_kappa = test_kappa_summary.get(name, None)\n \n # Format values for display\n val_acc_str = f\"{val_acc:.4f}\" if val_acc is not None else \"Failed\"\n test_acc_str = f\"{test_acc:.4f}\" if test_acc is not None else \"N/A\"\n val_roc_str = f\"{val_roc:.4f}\" if val_roc is not None else \"N/A\"\n test_roc_str = f\"{test_roc:.4f}\" if test_roc is not None else \"N/A\"\n val_pr_str = 
f\"{val_pr:.4f}\" if val_pr is not None else \"N/A\"\n test_pr_str = f\"{test_pr:.4f}\" if test_pr is not None else \"N/A\"\n val_kappa_str = f\"{val_kappa:.4f}\" if val_kappa is not None else \"N/A\"\n test_kappa_str = f\"{test_kappa:.4f}\" if test_kappa is not None else \"N/A\"\n \n print(f\"{name:<20}{val_acc_str:<10}{test_acc_str:<10}{val_roc_str:<14}{test_roc_str:<14}{val_pr_str:<14}{test_pr_str:<14}{val_kappa_str:<14}{test_kappa_str:<14}\")\n \n # Identify best model based on test metrics\n if test_results:\n # Best model by accuracy\n best_acc_model = max(test_results.items(), key=lambda x: x[1] if x[1] is not None else -1)\n print(f\"\\nBest model by accuracy: {best_acc_model[0]} (Test Accuracy: {best_acc_model[1]:.4f})\")\n \n # Best model by ROC AUC (if available)\n if any(v is not None for v in test_roc_auc_summary.values()):\n best_roc_model = max(\n [(k, v) for k, v in test_roc_auc_summary.items() if v is not None], \n key=lambda x: x[1] if x[1] is not None else -1\n )\n print(f\"Best model by ROC AUC: {best_roc_model[0]} (Test ROC AUC: {best_roc_model[1]:.4f})\")\n \n # Best model by PR AUC (if available)\n if any(v is not None for v in test_pr_auc_summary.values()):\n best_pr_model = max(\n [(k, v) for k, v in test_pr_auc_summary.items() if v is not None], \n key=lambda x: x[1] if x[1] is not None else -1\n )\n print(f\"Best model by PR AUC: {best_pr_model[0]} (Test PR AUC: {best_pr_model[1]:.4f})\")\n \n # Best model by Kappa (if available)\n if any(v is not None for v in test_kappa_summary.values()):\n best_kappa_model = max(\n [(k, v) for k, v in test_kappa_summary.items() if v is not None], \n key=lambda x: x[1] if x[1] is not None else -1\n )\n print(f\"Best model by Cohen's Kappa: {best_kappa_model[0]} (Test Kappa: {best_kappa_model[1]:.4f})\")\n \n # Save the best model (by accuracy)\n if best_model_obj is not None:\n try:\n model_save_path = f\"best_model_{best_model_name.lower().replace(' ', '_')}.pth\"\n 
torch.save(best_model_obj.state_dict(), model_save_path)\n print(f\"Best model saved to {model_save_path}\")\n except Exception as save_error:\n print(f\"Error saving best model: {save_error}\")\n else:\n print(\"\\nNo models successfully completed testing.\")\n \n print(\"=\"*100)\n \n # Visualize comparison\n try:\n # Create bar charts comparing different metrics\n plot_model_comparison(val_results, test_results, val_roc_auc_summary, \n test_roc_auc_summary, val_pr_auc_summary, test_pr_auc_summary,\n val_kappa_summary, test_kappa_summary)\n except Exception as viz_error:\n print(f\"Error in comparison visualization: {viz_error}\")\n import traceback\n traceback.print_exc()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.595427Z","iopub.execute_input":"2025-05-07T15:22:16.595661Z","iopub.status.idle":"2025-05-07T15:22:16.620099Z","shell.execute_reply.started":"2025-05-07T15:22:16.595635Z","shell.execute_reply":"2025-05-07T15:22:16.619560Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"def plot_model_comparison(val_acc, test_acc, val_roc, test_roc, val_pr, test_pr, val_kappa, test_kappa):\n \"\"\"\n Create visualizations to compare model performance across different metrics including Cohen's Kappa.\n \"\"\"\n # Get the list of model names (should be the same across all dictionaries)\n models = list(val_acc.keys())\n \n # Create a figure with 4 subplots for Accuracy, ROC AUC, PR AUC, and Kappa\n fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(12, 24))\n \n # Plot Accuracy\n x = np.arange(len(models))\n width = 0.35\n \n val_acc_values = [val_acc.get(model, None) for model in models]\n test_acc_values = [test_acc.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_acc_values = [v if v is not None else float('nan') for v in val_acc_values]\n test_acc_values = [v if v is not None else float('nan') for v in test_acc_values]\n \n ax1.bar(x - width/2, val_acc_values, width, 
label='Validation', color='skyblue')\n ax1.bar(x + width/2, test_acc_values, width, label='Test', color='salmon')\n ax1.set_ylabel('Accuracy')\n ax1.set_title('Model Accuracy Comparison')\n ax1.set_xticks(x)\n ax1.set_xticklabels(models, rotation=45, ha='right')\n ax1.legend()\n ax1.grid(True, alpha=0.3)\n \n # Plot ROC AUC\n val_roc_values = [val_roc.get(model, None) for model in models]\n test_roc_values = [test_roc.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_roc_values = [v if v is not None else float('nan') for v in val_roc_values]\n test_roc_values = [v if v is not None else float('nan') for v in test_roc_values]\n \n ax2.bar(x - width/2, val_roc_values, width, label='Validation', color='skyblue')\n ax2.bar(x + width/2, test_roc_values, width, label='Test', color='salmon')\n ax2.set_ylabel('ROC AUC')\n ax2.set_title('Model ROC AUC Comparison')\n ax2.set_xticks(x)\n ax2.set_xticklabels(models, rotation=45, ha='right')\n ax2.legend()\n ax2.grid(True, alpha=0.3)\n \n # Plot PR AUC\n val_pr_values = [val_pr.get(model, None) for model in models]\n test_pr_values = [test_pr.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_pr_values = [v if v is not None else float('nan') for v in val_pr_values]\n test_pr_values = [v if v is not None else float('nan') for v in test_pr_values]\n \n ax3.bar(x - width/2, val_pr_values, width, label='Validation', color='skyblue')\n ax3.bar(x + width/2, test_pr_values, width, label='Test', color='salmon')\n ax3.set_ylabel('PR AUC')\n ax3.set_title('Model PR AUC Comparison')\n ax3.set_xticks(x)\n ax3.set_xticklabels(models, rotation=45, ha='right')\n ax3.legend()\n ax3.grid(True, alpha=0.3)\n \n # Plot Kappa scores\n val_kappa_values = [val_kappa.get(model, None) for model in models]\n test_kappa_values = [test_kappa.get(model, None) for model in models]\n \n # Replace None with NaN for plotting\n val_kappa_values = [v if v is not None else float('nan') for v in 
val_kappa_values]\n test_kappa_values = [v if v is not None else float('nan') for v in test_kappa_values]\n \n ax4.bar(x - width/2, val_kappa_values, width, label='Validation', color='skyblue')\n ax4.bar(x + width/2, test_kappa_values, width, label='Test', color='salmon')\n ax4.set_ylabel(\"Cohen's Kappa\")\n ax4.set_title(\"Model Cohen's Kappa Comparison\")\n ax4.set_xticks(x)\n ax4.set_xticklabels(models, rotation=45, ha='right')\n ax4.legend()\n ax4.grid(True, alpha=0.3)\n \n plt.tight_layout()\n plt.show()\n \n # Create a comprehensive heatmap for all metrics\n try:\n plot_metrics_heatmap(models, val_acc_values, test_acc_values, \n val_roc_values, test_roc_values,\n val_pr_values, test_pr_values,\n val_kappa_values, test_kappa_values)\n except Exception as e:\n print(f\"Error creating metrics heatmap: {e}\")\n\ndef plot_metrics_heatmap(models, val_acc, test_acc, val_roc, test_roc, val_pr, test_pr, val_kappa, test_kappa):\n \"\"\"\n Create a heatmap visualization of all metrics for easy comparison across models.\n \"\"\"\n # Prepare data for heatmap\n metric_names = ['Val Acc', 'Test Acc', 'Val ROC', 'Test ROC', \n 'Val PR', 'Test PR', 'Val Kappa', 'Test Kappa']\n \n data = np.array([\n val_acc, test_acc, val_roc, test_roc, val_pr, test_pr, val_kappa, test_kappa\n ])\n \n # Create the heatmap\n plt.figure(figsize=(12, 8))\n ax = sns.heatmap(data, annot=True, fmt=\".4f\", cmap=\"YlGnBu\", \n xticklabels=models, yticklabels=metric_names)\n \n plt.title(\"Comprehensive Model Performance Metrics\")\n plt.tight_layout()\n plt.show()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:16.595427Z","iopub.execute_input":"2025-05-07T15:22:16.595661Z","iopub.status.idle":"2025-05-07T15:22:16.620099Z","shell.execute_reply.started":"2025-05-07T15:22:16.595635Z","shell.execute_reply":"2025-05-07T15:22:16.619560Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"# Main 
Function","metadata":{}},{"cell_type":"code","source":"train_dir = '/kaggle/input/eye-disease-image-dataset/Augmented Dataset/Augmented Dataset' # For training (pre-augmented data)\neval_dir = '/kaggle/input/eye-disease-image-dataset/Original Dataset/Original Dataset' # For val and test\nepochs = 15\nclasses_to_exclude = [\"Pterygium\"]\nbatch_size = 32\n\nseed = 42\nrandom.seed(seed)\nnp.random.seed(seed)\ntorch.manual_seed(seed)\nif torch.cuda.is_available():\n torch.cuda.manual_seed(seed)\n torch.backends.cudnn.deterministic = True\n\n# Define transformations\ntransform = transforms.Compose([\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.ToTensor(),\n transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n])\n\ntry:\n # Load datasets\n print(f\"Loading training dataset from {train_dir}...\")\n train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)\n \n print(f\"Loading evaluation dataset from {eval_dir}...\")\n eval_dataset = datasets.ImageFolder(root=eval_dir, transform=transform)\n \n # Print dataset information\n print(f\"Training dataset loaded with {len(train_dataset)} images across {len(train_dataset.classes)} classes.\")\n print(f\"Training classes: {train_dataset.classes}\")\n \n print(f\"Evaluation dataset loaded with {len(eval_dataset)} images across {len(eval_dataset.classes)} classes.\")\n print(f\"Evaluation classes: {eval_dataset.classes}\")\n \n # Filter datasets if needed\n excluded_classes = classes_to_exclude or []\n if excluded_classes:\n print(f\"Filtering out classes: {excluded_classes}\")\n filtered_train_dataset = FilteredImageDataset(train_dataset, excluded_classes=excluded_classes)\n filtered_eval_dataset = FilteredImageDataset(eval_dataset, excluded_classes=excluded_classes)\n else:\n filtered_train_dataset = train_dataset\n filtered_eval_dataset = eval_dataset\n \n # Check if the filtered classes match between training and evaluation datasets\n if 
set(filtered_train_dataset.classes) != set(filtered_eval_dataset.classes):\n print(\"Warning: Class mismatch between filtered training and evaluation datasets!\")\n print(f\"Filtered training classes: {filtered_train_dataset.classes}\")\n print(f\"Filtered evaluation classes: {filtered_eval_dataset.classes}\")\n \n # Find common classes\n common_classes = set(filtered_train_dataset.classes).intersection(set(filtered_eval_dataset.classes))\n print(f\"Common classes: {common_classes}\")\n \n # Create additional filtering based on common classes\n filtered_train_dataset = FilteredImageDataset(train_dataset, \n included_classes=common_classes)\n filtered_eval_dataset = FilteredImageDataset(eval_dataset, \n included_classes=common_classes)\n \n # Split evaluation dataset into validation and test sets\n eval_ratio = 0.7 # 70% validation, 30% test\n eval_size = len(filtered_eval_dataset)\n val_size = int(eval_ratio * eval_size)\n test_size = eval_size - val_size\n \n val_dataset, test_dataset = random_split(filtered_eval_dataset, [val_size, test_size])\n \n print(f\"Training set size: {len(filtered_train_dataset)}\")\n print(f\"Validation set size: {len(val_dataset)}\")\n print(f\"Test set size: {len(test_dataset)}\")\n \n # Create data loaders\n train_loader = DataLoader(filtered_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)\n test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)\n \n # Get the number of classes (after filtering)\n num_classes = len(filtered_train_dataset.classes)\n print(f\"Number of classes after filtering: {num_classes}\")\n print(f\"Classes after filtering: {filtered_train_dataset.classes}\")\n \n # Initialize models\n print(\"Initializing models...\")\n \n # Initialize models with the updated number of classes\n all_models = {\n \"MobileNetV4\": get_model_mobilenetv4(num_classes, 
freeze_layers=True),\n \"LeViT\": get_model_levit(num_classes, freeze_layers=True),\n \"EfficientViT\": get_model_efficientvit(num_classes, freeze_layers=True),\n \"GENet\": get_model_gernet(num_classes, freeze_layers=True),\n \"RegNetX\": get_model_regnetx(num_classes, freeze_layers=True)\n }\n \n models = list(all_models.values())\n model_names = list(all_models.keys())\n \n # Train and compare models\n print(\"Starting model training and comparison...\")\n compare_models(models, train_loader, val_loader, test_loader, filtered_train_dataset, epochs=epochs, names=model_names)\n \nexcept Exception as e:\n print(f\"Error in eye disease classification pipeline: {e}\")\n import traceback\n traceback.print_exc()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-07T15:22:59.764465Z","iopub.execute_input":"2025-05-07T15:22:59.765257Z","execution_failed":"2025-05-07T15:22:59.047Z"}},"outputs":[],"execution_count":null}]}
|
utils/Callback.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Early stopping callback: halts training when validation loss stops improving.
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        """
        Args:
            patience: number of consecutive non-improving validation checks
                tolerated before flagging early stop.
            delta: minimum improvement in (negated) validation loss required
                to count as progress.
        """
        self.patience = patience
        self.delta = delta
        self.counter = 0          # consecutive checks without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set True once patience is exhausted

    def __call__(self, val_loss):
        """Record one validation loss and update the early-stop state."""
        # Negate so that "higher is better" throughout.
        score = -val_loss

        if self.best_score is None:
            # First observation establishes the baseline.
            self.best_score = score
            return

        if score < self.best_score + self.delta:
            # No sufficient improvement this round.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: reset the patience counter.
            self.best_score = score
            self.counter = 0
|
utils/Comparator.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy

import torch

from Evaluator import ClassificationEvaluator
from Trainer import model_train
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compare_models(
    models, train_loader, val_loader, test_loader, dataset, epochs=20, names=None
):
    """
    Train, validate, and test a collection of models, then report a comparison.

    For each model this trains via ``model_train``, evaluates on the test set
    with ``ClassificationEvaluator``, tabulates accuracy / ROC AUC / PR AUC /
    Cohen's Kappa for validation and test, and saves the best model (by test
    accuracy) to ``best_model_<name>.pth``.

    Args:
        models: list of PyTorch models to compare.
        train_loader: DataLoader for the training split.
        val_loader: DataLoader for the validation split.
        test_loader: DataLoader for the test split.
        dataset: dataset exposing ``classes`` (class names for the evaluator).
        epochs: training epochs per model.
        names: optional display names; defaults to "Model 1", "Model 2", ...

    Returns:
        None. Results are printed; the best model's state_dict is written to disk.
    """
    if names is None:
        names = [f"Model {i+1}" for i in range(len(models))]

    def _micro_metric(results, key):
        """Return results[key]["micro"] when available, else None."""
        if key in results and "micro" in results[key]:
            return results[key]["micro"]
        return None

    def _best_entry(summary):
        """Return the (name, value) pair with the highest non-None value, or None."""
        candidates = [(k, v) for k, v in summary.items() if v is not None]
        return max(candidates, key=lambda kv: kv[1]) if candidates else None

    val_results = {}
    test_results = {}
    best_model_obj = None
    best_accuracy = -1
    best_model_name = ""

    # Summary dictionaries for metrics
    val_roc_auc_summary = {}
    test_roc_auc_summary = {}
    val_pr_auc_summary = {}
    test_pr_auc_summary = {}
    val_kappa_summary = {}
    test_kappa_summary = {}

    for i, (model, name) in enumerate(zip(models, names)):
        # BUG FIX: ClassificationEvaluator.__init__ accepts class_names and
        # derives the class count itself; the old extra num_classes kwarg
        # raised TypeError.
        evaluator = ClassificationEvaluator(class_names=dataset.classes)

        print(f"\n\n{'#'*30} Training {name} ({i+1}/{len(models)}) {'#'*30}\n")
        model_results = model_train(model, train_loader, val_loader, dataset, epochs)

        # Validation metrics
        accuracy = model_results.get("accuracy")
        val_results[name] = accuracy
        val_roc_auc_summary[name] = _micro_metric(model_results, "roc_auc")
        val_pr_auc_summary[name] = _micro_metric(model_results, "pr_auc")
        val_kappa_summary[name] = model_results.get("kappa")

        # Evaluate on the test set only when training produced a usable result.
        if accuracy is not None:
            print(f"\n{'='*20} Testing {name} on Test Set {'='*20}\n")
            test_model_results = evaluator.evaluate_model(model, test_loader)

            test_accuracy = test_model_results.get("accuracy")
            test_results[name] = test_accuracy
            test_roc_auc_summary[name] = _micro_metric(test_model_results, "roc_auc")
            test_pr_auc_summary[name] = _micro_metric(test_model_results, "pr_auc")
            test_kappa_summary[name] = test_model_results.get("kappa")

            # Track best model. BUG FIX: guard against a None accuracy —
            # the old `test_accuracy > best_accuracy` raised TypeError.
            if test_accuracy is not None and test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                best_model_obj = copy.deepcopy(model)
                best_model_name = name

    # Print comprehensive comparison
    print("\n\n" + "=" * 100)
    print("COMPREHENSIVE MODEL COMPARISON")
    print("=" * 100)
    print(
        f"{'Model':<20}{'Val Acc':<10}{'Test Acc':<10}{'Val ROC AUC':<14}{'Test ROC AUC':<14}{'Val PR AUC':<14}{'Test PR AUC':<14}{'Val Kappa':<14}{'Test Kappa':<14}"
    )
    print("-" * 100)

    def _fmt(value, fallback):
        """Format a metric to 4 decimals, or return the fallback label."""
        return f"{value:.4f}" if value is not None else fallback

    for name in val_results.keys():
        print(
            f"{name:<20}"
            f"{_fmt(val_results[name], 'Failed'):<10}"
            f"{_fmt(test_results.get(name), 'N/A'):<10}"
            f"{_fmt(val_roc_auc_summary.get(name), 'N/A'):<14}"
            f"{_fmt(test_roc_auc_summary.get(name), 'N/A'):<14}"
            f"{_fmt(val_pr_auc_summary.get(name), 'N/A'):<14}"
            f"{_fmt(test_pr_auc_summary.get(name), 'N/A'):<14}"
            f"{_fmt(val_kappa_summary.get(name), 'N/A'):<14}"
            f"{_fmt(test_kappa_summary.get(name), 'N/A'):<14}"
        )

    # Identify best model based on test metrics
    if test_results:
        # BUG FIX: the old code formatted the max value unconditionally and
        # crashed when every accuracy was None; skip the line instead.
        best_acc_model = _best_entry(test_results)
        if best_acc_model is not None:
            print(
                f"\nBest model by accuracy: {best_acc_model[0]} (Test Accuracy: {best_acc_model[1]:.4f})"
            )

        best_roc_model = _best_entry(test_roc_auc_summary)
        if best_roc_model is not None:
            print(
                f"Best model by ROC AUC: {best_roc_model[0]} (Test ROC AUC: {best_roc_model[1]:.4f})"
            )

        best_pr_model = _best_entry(test_pr_auc_summary)
        if best_pr_model is not None:
            print(
                f"Best model by PR AUC: {best_pr_model[0]} (Test PR AUC: {best_pr_model[1]:.4f})"
            )

        best_kappa_model = _best_entry(test_kappa_summary)
        if best_kappa_model is not None:
            print(
                f"Best model by Cohen's Kappa: {best_kappa_model[0]} (Test Kappa: {best_kappa_model[1]:.4f})"
            )

        # Save the best model (by accuracy)
        if best_model_obj is not None:
            try:
                model_save_path = (
                    f"best_model_{best_model_name.lower().replace(' ', '_')}.pth"
                )
                torch.save(best_model_obj.state_dict(), model_save_path)
                print(f"Best model saved to {model_save_path}")
            except Exception as save_error:
                print(f"Error saving best model: {save_error}")
    else:
        print("\nNo models successfully completed testing.")

    print("=" * 100)
|
utils/DatasetHandler.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import random
|
| 4 |
+
import copy
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.optim as optim
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
from torchvision import transforms, datasets
|
| 12 |
+
import torchvision.models as models
|
| 13 |
+
import timm
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import seaborn as sns
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from torch.utils.data import DataLoader, random_split, Subset, Dataset
|
| 18 |
+
from sklearn.metrics import (
|
| 19 |
+
accuracy_score,
|
| 20 |
+
confusion_matrix,
|
| 21 |
+
classification_report,
|
| 22 |
+
roc_curve,
|
| 23 |
+
auc,
|
| 24 |
+
precision_recall_curve,
|
| 25 |
+
average_precision_score,
|
| 26 |
+
)
|
| 27 |
+
from sklearn.preprocessing import label_binarize
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
import gc
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Modified FilteredImageDataset class with Pterygium filtering
class FilteredImageDataset(Dataset):
    def __init__(self, dataset, excluded_classes=None, included_classes=None):
        """
        Create a filtered dataset view that excludes and/or whitelists classes.

        Args:
            dataset: Original dataset (ImageFolder or similar) exposing
                ``classes``, ``class_to_idx`` and ``samples``.
            excluded_classes: List of class names to exclude (e.g., ["Pterygium"]).
            included_classes: Optional iterable of class names to keep. When
                given, only these classes survive (exclusions still apply).
                Added because callers that intersect two datasets' class lists
                pass ``included_classes=common_classes``.
        """
        self.dataset = dataset
        self.excluded_classes = list(excluded_classes) if excluded_classes else []
        # None means "no whitelist"; otherwise membership is checked below.
        self.included_classes = (
            set(included_classes) if included_classes is not None else None
        )

        # Get original class information
        self.orig_classes = dataset.classes
        self.orig_class_to_idx = dataset.class_to_idx

        def _keep(class_name):
            """True when a class survives both the blacklist and whitelist."""
            if class_name in self.excluded_classes:
                return False
            if (
                self.included_classes is not None
                and class_name not in self.included_classes
            ):
                return False
            return True

        # Indices of samples whose class survives filtering.
        self.indices = [
            idx
            for idx, (_, target) in enumerate(dataset.samples)
            if _keep(self.orig_classes[target])
        ]

        # New contiguous class mapping over the surviving classes.
        remaining_classes = [c for c in self.orig_classes if _keep(c)]
        self.classes = remaining_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(remaining_classes)}
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

        # Map original label indices to the new contiguous indices.
        self.target_mapping = {
            old_idx: self.class_to_idx[old_class]
            for old_class, old_idx in self.orig_class_to_idx.items()
            if old_class in self.class_to_idx
        }

        print(f"Filtered out classes: {self.excluded_classes}")
        print(f"Remaining classes: {self.classes}")
        print(
            f"Original dataset size: {len(dataset)}, Filtered dataset size: {len(self.indices)}"
        )

    def __getitem__(self, index):
        """Get item from the filtered dataset with remapped class labels."""
        orig_idx = self.indices[index]
        img, old_target = self.dataset[orig_idx]

        # Remap the original label index to the filtered index space.
        return img, self.target_mapping[old_target]

    def __len__(self):
        """Return the number of samples in the filtered dataset."""
        return len(self.indices)

    # Allow transform to be updated
    def set_transform(self, transform):
        """Update the transform on the underlying dataset (shared view)."""
        self.dataset.transform = transform
|
utils/Evaluator.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import seaborn as sns
|
| 5 |
+
from sklearn.metrics import (
|
| 6 |
+
accuracy_score,
|
| 7 |
+
classification_report,
|
| 8 |
+
confusion_matrix,
|
| 9 |
+
roc_curve,
|
| 10 |
+
precision_recall_curve,
|
| 11 |
+
auc,
|
| 12 |
+
average_precision_score,
|
| 13 |
+
cohen_kappa_score,
|
| 14 |
+
)
|
| 15 |
+
from sklearn.preprocessing import label_binarize
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ClassificationEvaluator:
|
| 19 |
+
"""
|
| 20 |
+
A class to evaluate and visualize classification model performance.
|
| 21 |
+
|
| 22 |
+
This class provides methods to compute various classification metrics
|
| 23 |
+
and generate visualizations for model evaluation.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, class_names):
|
| 27 |
+
"""
|
| 28 |
+
Initialize the evaluator with class names.
|
| 29 |
+
|
| 30 |
+
Parameters:
|
| 31 |
+
- class_names: list of class names
|
| 32 |
+
"""
|
| 33 |
+
self.class_names = class_names
|
| 34 |
+
self.num_classes = len(class_names)
|
| 35 |
+
|
| 36 |
+
def _ensure_numpy(self, data):
|
| 37 |
+
"""Convert tensor to numpy if needed."""
|
| 38 |
+
if torch.is_tensor(data):
|
| 39 |
+
return data.cpu().numpy()
|
| 40 |
+
return np.array(data)
|
| 41 |
+
|
| 42 |
+
def evaluate_model(self, model, test_loader):
|
| 43 |
+
"""
|
| 44 |
+
Evaluate a trained model on test dataset.
|
| 45 |
+
|
| 46 |
+
Parameters:
|
| 47 |
+
- model: PyTorch model to evaluate
|
| 48 |
+
- test_loader: DataLoader containing test data
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
- results: Dictionary containing evaluation metrics
|
| 52 |
+
"""
|
| 53 |
+
model.eval()
|
| 54 |
+
device = next(model.parameters()).device
|
| 55 |
+
|
| 56 |
+
all_labels = []
|
| 57 |
+
all_preds = []
|
| 58 |
+
all_scores = []
|
| 59 |
+
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
for inputs, labels in test_loader:
|
| 62 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
| 63 |
+
outputs = model(inputs)
|
| 64 |
+
_, preds = torch.max(outputs, 1)
|
| 65 |
+
|
| 66 |
+
all_labels.extend(labels.cpu().numpy())
|
| 67 |
+
all_preds.extend(preds.cpu().numpy())
|
| 68 |
+
all_scores.append(
|
| 69 |
+
torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
all_scores = np.vstack(all_scores)
|
| 73 |
+
|
| 74 |
+
# Compute metrics
|
| 75 |
+
results = self.compute_metrics(all_labels, all_preds, all_scores)
|
| 76 |
+
return results
|
| 77 |
+
|
| 78 |
+
def compute_metrics(self, y_true, y_pred, y_scores, model_name=""):
|
| 79 |
+
"""
|
| 80 |
+
Compute comprehensive classification metrics.
|
| 81 |
+
|
| 82 |
+
Parameters:
|
| 83 |
+
- y_true: true labels
|
| 84 |
+
- y_pred: predicted labels
|
| 85 |
+
- y_scores: predicted probability scores
|
| 86 |
+
- model_name: name of the model (optional)
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
- Dictionary containing all metrics
|
| 90 |
+
"""
|
| 91 |
+
# Ensure numpy arrays
|
| 92 |
+
y_true = self._ensure_numpy(y_true)
|
| 93 |
+
y_pred = self._ensure_numpy(y_pred)
|
| 94 |
+
y_scores = self._ensure_numpy(y_scores)
|
| 95 |
+
|
| 96 |
+
# Calculate accuracy
|
| 97 |
+
accuracy = accuracy_score(y_true, y_pred)
|
| 98 |
+
print(f"Overall Accuracy: {accuracy:.4f}")
|
| 99 |
+
|
| 100 |
+
# Calculate and display Cohen's Kappa
|
| 101 |
+
kappa = cohen_kappa_score(y_true, y_pred)
|
| 102 |
+
print(f"Cohen's Kappa Score: {kappa:.4f}")
|
| 103 |
+
|
| 104 |
+
# Generate classification report
|
| 105 |
+
report = classification_report(
|
| 106 |
+
y_true, y_pred, target_names=self.class_names, output_dict=True
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Print formatted classification report
|
| 110 |
+
print("\nClassification Report:")
|
| 111 |
+
print(classification_report(y_true, y_pred, target_names=self.class_names))
|
| 112 |
+
|
| 113 |
+
# Calculate ROC curves and AUC for each class
|
| 114 |
+
print("\nCalculating ROC curves...")
|
| 115 |
+
roc_auc_dict = self.plot_roc_curves(y_true, y_scores)
|
| 116 |
+
|
| 117 |
+
# Calculate PR curves and AUC for each class
|
| 118 |
+
print("\nCalculating PR curves...")
|
| 119 |
+
pr_auc_dict = self.plot_pr_curves(y_true, y_scores)
|
| 120 |
+
|
| 121 |
+
# Plot confusion matrix
|
| 122 |
+
print("\nGenerating confusion matrix...")
|
| 123 |
+
self.plot_confusion_matrix(y_true, y_pred)
|
| 124 |
+
|
| 125 |
+
# Plot per-class accuracy
|
| 126 |
+
print("\nCalculating per-class accuracy...")
|
| 127 |
+
self.plot_per_class_accuracy(y_true, y_pred)
|
| 128 |
+
|
| 129 |
+
# Return metrics dictionary
|
| 130 |
+
return {
|
| 131 |
+
"accuracy": accuracy,
|
| 132 |
+
"report": report,
|
| 133 |
+
"roc_auc": roc_auc_dict,
|
| 134 |
+
"pr_auc": pr_auc_dict,
|
| 135 |
+
"kappa": kappa,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
def plot_roc_curves(self, y_true, y_scores):
    """Plot one-vs-rest ROC curves for every class plus a micro-average.

    Parameters:
    - y_true: true labels
    - y_scores: predicted probability scores

    Returns:
    - Dictionary containing AUC values for each class (plus key "micro")
    """
    labels = self._ensure_numpy(y_true)
    scores = self._ensure_numpy(y_scores)

    # One-hot encode the labels so each class can be scored one-vs-rest.
    onehot = label_binarize(labels, classes=range(self.num_classes))

    auc_by_class = {}
    plt.figure(figsize=(12, 8))

    for cls in range(self.num_classes):
        cls_fpr, cls_tpr, _ = roc_curve(onehot[:, cls], scores[:, cls])
        auc_by_class[cls] = auc(cls_fpr, cls_tpr)

        plt.plot(
            cls_fpr,
            cls_tpr,
            lw=2,
            label=f"{self.class_names[cls]} (area = {auc_by_class[cls]:.2f})",
        )

    # Diagonal reference line: a random classifier.
    plt.plot([0, 1], [0, 1], "k--", lw=2)

    # Micro-average: pool every (sample, class) decision into a single curve.
    micro_fpr, micro_tpr, _ = roc_curve(onehot.ravel(), scores.ravel())
    auc_by_class["micro"] = auc(micro_fpr, micro_tpr)
    plt.plot(
        micro_fpr,
        micro_tpr,
        label=f'Micro-average (area = {auc_by_class["micro"]:.2f})',
        lw=2,
        linestyle=":",
        color="deeppink",
    )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curves")
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return auc_by_class
|
| 199 |
+
|
| 200 |
+
def plot_pr_curves(self, y_true, y_scores):
    """Plot one-vs-rest Precision-Recall curves plus a micro-average.

    Parameters:
    - y_true: true labels
    - y_scores: predicted probability scores

    Returns:
    - Dictionary containing average precision values for each class
      (plus key "micro")
    """
    labels = self._ensure_numpy(y_true)
    scores = self._ensure_numpy(y_scores)

    # One-hot encode the labels for one-vs-rest scoring.
    onehot = label_binarize(labels, classes=range(self.num_classes))

    ap_by_class = {}
    plt.figure(figsize=(12, 8))

    for cls in range(self.num_classes):
        cls_precision, cls_recall, _ = precision_recall_curve(
            onehot[:, cls], scores[:, cls]
        )
        ap_by_class[cls] = average_precision_score(onehot[:, cls], scores[:, cls])

        plt.plot(
            cls_recall,
            cls_precision,
            lw=2,
            label=f"{self.class_names[cls]} (AP = {ap_by_class[cls]:.2f})",
        )

    # Micro-average: pool every (sample, class) decision into one curve.
    micro_precision, micro_recall, _ = precision_recall_curve(
        onehot.ravel(), scores.ravel()
    )
    ap_by_class["micro"] = average_precision_score(
        onehot.ravel(), scores.ravel()
    )

    plt.plot(
        micro_recall,
        micro_precision,
        label=f'Micro-average (AP = {ap_by_class["micro"]:.2f})',
        lw=2,
        linestyle=":",
        color="deeppink",
    )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curves")
    plt.legend(loc="best")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return ap_by_class
|
| 265 |
+
|
| 266 |
+
def plot_confusion_matrix(self, y_true, y_pred):
    """Render the confusion matrix as an annotated heatmap.

    Parameters:
    - y_true: true labels
    - y_pred: predicted labels
    """
    labels_arr = self._ensure_numpy(y_true)
    preds_arr = self._ensure_numpy(y_pred)

    # Report which label values actually occur (debug aid for label mismatches).
    unique_values = np.unique(np.concatenate([labels_arr, preds_arr]))
    print(f"Unique values in confusion matrix data: {unique_values}")

    # Explicit label range keeps the matrix square even for absent classes.
    cm = confusion_matrix(labels_arr, preds_arr, labels=range(self.num_classes))

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=self.class_names,
        yticklabels=self.class_names,
    )
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.show()
|
| 298 |
+
|
| 299 |
+
def plot_per_class_accuracy(self, y_true, y_pred):
    """Compute and bar-plot the recall (per-class accuracy) of each class.

    Parameters:
    - y_true: true labels
    - y_pred: predicted labels

    Returns:
    - numpy array of per-class accuracies (0.0 for classes with no samples)
    """
    labels_arr = self._ensure_numpy(y_true)
    preds_arr = self._ensure_numpy(y_pred)

    # Explicit label range keeps the matrix square even for absent classes.
    cm = confusion_matrix(labels_arr, preds_arr, labels=range(self.num_classes))

    # Diagonal over row totals, guarding against empty classes.
    per_class_accuracy = np.zeros(self.num_classes)
    row_totals = cm.sum(axis=1)
    for cls in range(self.num_classes):
        if cls < cm.shape[0] and row_totals[cls] > 0:
            per_class_accuracy[cls] = cm[cls, cls] / row_totals[cls]

    plt.figure(figsize=(14, 7))
    plt.bar(range(self.num_classes), per_class_accuracy, color="skyblue")
    plt.xticks(range(self.num_classes), self.class_names, rotation=45, ha="right")
    plt.xlabel("Classes")
    plt.ylabel("Accuracy")
    plt.title("Per-Class Accuracy")
    plt.tight_layout()
    plt.show()

    return per_class_accuracy
|
| 330 |
+
|
| 331 |
+
def plot_training_history(self, train_losses, val_losses, train_accs, val_accs):
    """Plot the accuracy and loss curves recorded during training.

    Parameters:
    - train_losses: list of training losses
    - val_losses: list of validation losses
    - train_accs: list of training accuracies
    - val_accs: list of validation accuracies
    """
    plt.figure(figsize=(12, 5))

    # (subplot index, y-label, title, train series, val series, train legend, val legend)
    panels = [
        (1, "Accuracy", "Accuracy Curve", train_accs, val_accs,
         "Train Accuracy", "Validation Accuracy"),
        (2, "Loss", "Loss Curve", train_losses, val_losses,
         "Train Loss", "Validation Loss"),
    ]

    for idx, ylab, title, train_series, val_series, train_lbl, val_lbl in panels:
        plt.subplot(1, 2, idx)
        plt.plot(train_series, label=train_lbl)
        plt.plot(val_series, label=val_lbl)
        plt.xlabel("Epochs")
        plt.ylabel(ylab)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()
|
utils/ModelCreator.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import timm
|
| 4 |
+
|
| 5 |
+
# Set device
|
| 6 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EyeDetectionModels:
    """Factory for timm-based eye-detection classifiers.

    Each ``get_model_*`` method loads an ImageNet-pretrained backbone,
    optionally freezes all backbone blocks except the last few, swaps in a
    fresh classification head sized for ``num_classes``, and moves the
    model to ``device``.
    """

    def __init__(self, num_classes, freeze_layers=True, device=DEVICE):
        """
        Initialize the EyeDetectionModels class.

        Parameters:
        - num_classes: number of output classes for the new classifier head
        - freeze_layers: if True, freeze backbone blocks except the last ones
        - device: torch device the created models are moved to
        """
        self.num_classes = num_classes
        self.freeze_layers = freeze_layers
        self.device = device
        # Registry mapping a short model name to its factory method.
        self.models = {
            "mobilenetv4": self.get_model_mobilenetv4,
            "levit": self.get_model_levit,
            "efficientvit": self.get_model_efficientvit,
            "gernet": self.get_model_gernet,
            "regnetx": self.get_model_regnetx,
        }

    # Model architecture functions
    @staticmethod
    def _get_feature_blocks(model):
        """
        Utility: locate the main feature blocks container in a timm model.
        Returns a list-like module of blocks.
        """
        for attr in ("features", "blocks", "layers", "stem"):  # common container names
            if hasattr(model, attr):
                return getattr(model, attr)
        # fallback: collect all children except classifier/head
        return list(model.children())[:-1]

    @staticmethod
    def _freeze_except_last_n(blocks, n):
        """Freeze every block except the last ``n`` (those stay trainable)."""
        total = len(blocks)
        for idx, block in enumerate(blocks):
            requires = idx >= total - n
            for p in block.parameters():
                p.requires_grad = requires

    def _make_head(self, in_features):
        """Build the classification head shared by all backbones.

        Two-layer MLP: in_features -> 512 -> num_classes with ReLU + dropout.
        Extracted because the identical head was previously duplicated in
        every ``get_model_*`` method.
        """
        return nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, self.num_classes),
        )

    def get_model_mobilenetv4(self):
        """Load MobileNetV4-Conv-Medium and attach a new classifier."""
        model = timm.create_model(
            "mobilenetv4_conv_medium.e500_r256_in1k", pretrained=True
        )
        if self.freeze_layers:
            blocks = self._get_feature_blocks(model)
            self._freeze_except_last_n(blocks, 2)
        # replace classifier
        in_features = model.classifier.in_features
        model.classifier = self._make_head(in_features)
        return model.to(self.device)

    def get_model_levit(self):
        """Load LeViT-128S and replace both its heads (main + distillation)."""
        model = timm.create_model("levit_128s.fb_dist_in1k", pretrained=True)
        if self.freeze_layers:
            blocks = self._get_feature_blocks(model)
            self._freeze_except_last_n(blocks, 2)
        # BUG FIX: removed dead `head = getattr(...)` / `linear = getattr(head,
        # "linear")` lookups — the results were never used and the bare getattr
        # could raise AttributeError. The head width is fixed at 384 for this
        # variant — TODO confirm against the timm model config.
        in_features = 384
        model.head = self._make_head(in_features)
        model.head_dist = self._make_head(in_features)
        return model.to(self.device)

    def get_model_efficientvit(self):
        """Load EfficientViT-M1 and replace its head's linear projection."""
        model = timm.create_model("efficientvit_m1.r224_in1k", pretrained=True)
        if self.freeze_layers:
            blocks = self._get_feature_blocks(model)
            self._freeze_except_last_n(blocks, 2)
        # BUG FIX: removed leftover debug `print(head)` and the unused
        # `linear = getattr(head, "linear")` lookup. The head input width is
        # fixed at 192 for this variant — TODO confirm against the timm config.
        in_features = 192
        model.head.linear = self._make_head(in_features)
        return model.to(self.device)

    def get_model_gernet(self):
        """
        Load and configure a GENet (General and Efficient Network) model with customizable classifier.

        Returns:
            Configured GENet model
        """
        model = timm.create_model("gernet_s.idstcv_in1k", pretrained=True)

        if self.freeze_layers:
            # For GENet, we need to specifically handle its structure
            # It typically has a 'stem' and 'stages' structure
            if hasattr(model, "stem") and hasattr(model, "stages"):
                # Freeze stem completely
                for param in model.stem.parameters():
                    param.requires_grad = False

                # Freeze all stages except the last two
                stages = list(model.stages.children())
                total_stages = len(stages)
                for i, stage in enumerate(stages):
                    requires_grad = i >= total_stages - 2
                    for param in stage.parameters():
                        param.requires_grad = requires_grad
            else:
                # Fallback to generic approach
                blocks = self._get_feature_blocks(model)
                self._freeze_except_last_n(blocks, 2)

        # Replace classifier
        in_features = model.head.fc.in_features
        model.head.fc = self._make_head(in_features)
        return model.to(self.device)

    def get_model_regnetx(self):
        """
        Load and configure a RegNetX model with customizable classifier.

        Returns:
            Configured RegNetX model
        """
        model = timm.create_model("regnetx_008.tv2_in1k", pretrained=True)

        if self.freeze_layers:
            for param in model.parameters():
                param.requires_grad = False

            # RegNetX typically has 'stem' + 'trunk' structure in timm
            if hasattr(model, "trunk"):
                # Unfreeze approximately last 25% of trunk blocks
                trunk_blocks = list(model.trunk.children())
                unfreeze_from = max(0, int(len(trunk_blocks) * 0.75))
                for i in range(unfreeze_from, len(trunk_blocks)):
                    for param in trunk_blocks[i].parameters():
                        param.requires_grad = True

            # Always unfreeze the classifier/head for fine-tuning
            for param in model.head.parameters():
                param.requires_grad = True

        # Replace classifier
        in_features = model.head.fc.in_features
        model.head.fc = self._make_head(in_features)
        return model.to(self.device)
|
utils/Trainer.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import gc
|
| 8 |
+
|
| 9 |
+
from Evaluator import ClassificationEvaluator
|
| 10 |
+
from Callback import EarlyStopping
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    train_loader,
    val_loader,
    early_stopping,
    epochs=15,
    use_ddp=False,
):
    """
    Train the model and perform validation using multiple GPUs.
    Supports both DataParallel (DP) and DistributedDataParallel (DDP) modes.

    Args:
        model: Model to train
        criterion: Loss function
        optimizer: Optimizer for training
        scheduler: Learning rate scheduler (stepped with the validation loss,
            i.e. a ReduceLROnPlateau-style scheduler)
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        early_stopping: Early stopping handler (callable on val loss, exposes
            an ``early_stop`` flag)
        epochs: Maximum number of epochs to train
        use_ddp: Whether to use DistributedDataParallel (True) or DataParallel (False)

    Returns:
        Tuple of (model, train_losses, val_losses, train_accs, val_accs,
        all_val_labels, all_val_preds, all_val_scores) — the ``all_val_*``
        entries hold the labels/predictions/probabilities of the last
        completed validation pass.
    """
    # Check available GPUs
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        print(
            f"Warning: Requested multi-GPU training but only {num_gpus} GPU(s) available. Continuing with available resources."
        )
    else:
        print(f"Using {num_gpus} GPUs for training")

    # Setup device and model
    if num_gpus >= 2:
        if use_ddp:
            # For DistributedDataParallel
            import torch.distributed as dist
            from torch.nn.parallel import DistributedDataParallel as DDP

            # Initialize process group
            # NOTE(review): assumes the process was launched via torchrun /
            # torch.distributed with NCCL env vars set — confirm the launcher.
            dist.init_process_group(backend="nccl")
            local_rank = dist.get_rank()
            torch.cuda.set_device(local_rank)
            device = torch.device(f"cuda:{local_rank}")

            model = model.to(device)
            model = DDP(model, device_ids=[local_rank])
        else:
            # For DataParallel (simpler to use)
            device = torch.device("cuda:0")
            model = model.to(device)
            model = torch.nn.DataParallel(model)
    else:
        # Single GPU
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

    # Per-epoch history, returned to the caller for plotting.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    # Store validation predictions and labels for final evaluation
    all_val_labels = []
    all_val_preds = []
    all_val_scores = []

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight batch loss by batch size so the epoch mean is exact.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_train_loss = running_loss / len(train_loader.dataset)
        epoch_train_acc = correct / total
        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)

        # Validation phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        all_labels = []
        all_preds = []
        all_scores = []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * inputs.size(0)
                # Softmax probabilities are collected for ROC/PR evaluation.
                probs = F.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_labels.extend(labels.cpu().numpy().tolist())
                all_preds.extend(predicted.cpu().numpy().tolist())
                all_scores.append(probs.cpu().numpy())

        epoch_val_loss = running_loss / len(val_loader.dataset)
        epoch_val_acc = correct / total
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)

        # Stack per-batch probability arrays into one (n_samples, n_classes) array.
        all_scores = np.vstack(all_scores) if all_scores else np.array([])

        # Store validation results for the final epoch
        all_val_labels = all_labels
        all_val_preds = all_preds
        all_val_scores = all_scores

        # Update learning rate scheduler
        scheduler.step(epoch_val_loss)

        print(f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}")
        print(f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Check early stopping
        early_stopping(epoch_val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

        # Free up memory
        # NOTE(review): this cleanup is skipped when early stopping breaks out
        # above; harmless, since all_val_* still reference the same objects.
        del all_labels, all_preds, all_scores
        gc.collect()
        torch.cuda.empty_cache()

    # Clean up DDP if used
    if num_gpus >= 2 and use_ddp:
        dist.destroy_process_group()

    return (
        model,
        train_losses,
        val_losses,
        train_accs,
        val_accs,
        all_val_labels,
        all_val_preds,
        all_val_scores,
    )
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def model_train(model, train_loader, val_loader, dataset, epochs=20):
    """Train ``model``, then evaluate and visualize it on the validation set.

    Parameters:
    - model: torch model whose classifier matches the dataset's classes
    - train_loader: DataLoader for training data
    - val_loader: DataLoader for validation data
    - dataset: dataset object exposing ``classes`` (used for names and count)
    - epochs: maximum number of training epochs

    Returns:
    - Dictionary with keys "accuracy", "report", "roc_auc", "pr_auc", "kappa";
      on any training/visualization failure, ``{"accuracy": None}``.
    """
    # Prefer the timm pretrained config name over the Python class name.
    model_name = type(model).__name__
    if hasattr(model, "pretrained_cfg") and "name" in model.pretrained_cfg:
        model_name = model.pretrained_cfg["name"]

    print(f"\n{'='*20} Training {model_name} {'='*20}\n")

    class_names = dataset.classes
    num_classes = len(class_names)
    learning_rate = 0.001

    try:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=3
        )
        early_stopping = EarlyStopping(patience=5)

        (
            model,
            train_losses,
            val_losses,
            train_accs,
            val_accs,
            val_labels,
            val_preds,
            val_scores,
        ) = train_model(
            model,
            nn.CrossEntropyLoss(),
            optimizer,
            scheduler,
            train_loader,
            val_loader,
            early_stopping,
            epochs=epochs,
            use_ddp=False,
        )

        print(f"\n{'='*20} Evaluation for {model_name} {'='*20}\n")
        evaluator = ClassificationEvaluator(
            num_classes=num_classes,
            class_names=class_names,
        )

        evaluator.plot_training_history(train_losses, val_losses, train_accs, val_accs)
        # Process validation predictions and labels
        try:
            evaluator.plot_confusion_matrix(val_labels, val_preds)
            evaluator.plot_per_class_accuracy(val_labels, val_preds)

            # BUG FIX: compute_metrics returns a dict with keys "accuracy",
            # "report", "roc_auc", "pr_auc", "kappa". The previous code
            # tuple-unpacked that dict into five names, which silently bound
            # the *key strings* (iterating a dict yields its keys) instead of
            # the metric values. Use the returned dict directly.
            results = evaluator.compute_metrics(
                val_labels,
                val_preds,
                val_scores,
                model_name,
            )

            return results
        except Exception as viz_error:
            print(f"Error in visualization: {viz_error}")
            import traceback

            traceback.print_exc()
            return {"accuracy": None}

    except Exception as e:
        print(f"Error occurred when training {model_name}: {e}")
        import traceback

        traceback.print_exc()
        return {"accuracy": None}
    finally:
        # Clean up memory: only delete names that were actually bound
        # (an early exception can leave later names undefined).
        if "optimizer" in locals():
            del optimizer
        if "scheduler" in locals():
            del scheduler
        if "early_stopping" in locals():
            del early_stopping
        if "train_losses" in locals():
            del train_losses
            del val_losses
            del train_accs
            del val_accs
            del val_labels
            del val_preds
            del val_scores

        gc.collect()
        torch.cuda.empty_cache()
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|