Spaces:
Sleeping
Sleeping
Commit ·
b8c9192
1
Parent(s): 1691c67
add src
Browse files- .gitignore +148 -0
- LICENCE.md +21 -0
- README.md +41 -1
- RESULTS.md +87 -0
- app.py +206 -0
- augmentations.py +85 -0
- checkpoints/advamd.pt +3 -0
- checkpoints/drus.pt +3 -0
- checkpoints/pig.pt +3 -0
- dataloader.py +55 -0
- model.py +61 -0
- requirements.txt +10 -0
- run_inference.py +99 -0
- test.py +368 -0
- train.py +366 -0
.gitignore
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by .ignore support plugin (hsz.mobi)
|
| 2 |
+
|
| 3 |
+
deep-learning-models
|
| 4 |
+
.pytest_cache/
|
| 5 |
+
backup
|
| 6 |
+
examples-local
|
| 7 |
+
|
| 8 |
+
### Python template
|
| 9 |
+
# Byte-compiled / optimized / DLL files
|
| 10 |
+
__pycache__/
|
| 11 |
+
*.py[cod]
|
| 12 |
+
*$py.class
|
| 13 |
+
|
| 14 |
+
# C extensions
|
| 15 |
+
*.so
|
| 16 |
+
|
| 17 |
+
# Distribution / packaging
|
| 18 |
+
.Python
|
| 19 |
+
env/
|
| 20 |
+
build/
|
| 21 |
+
develop-eggs/
|
| 22 |
+
dist/
|
| 23 |
+
downloads/
|
| 24 |
+
eggs/
|
| 25 |
+
.eggs/
|
| 26 |
+
lib/
|
| 27 |
+
lib64/
|
| 28 |
+
parts/
|
| 29 |
+
sdist/
|
| 30 |
+
var/
|
| 31 |
+
wheels/
|
| 32 |
+
*.egg-info/
|
| 33 |
+
.installed.cfg
|
| 34 |
+
*.egg
|
| 35 |
+
|
| 36 |
+
# PyInstaller
|
| 37 |
+
# Usually these files are written by a python script from a template
|
| 38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 39 |
+
*.manifest
|
| 40 |
+
*.spec
|
| 41 |
+
|
| 42 |
+
# Installer logs
|
| 43 |
+
pip-log.txt
|
| 44 |
+
pip-delete-this-directory.txt
|
| 45 |
+
|
| 46 |
+
# Unit test / coverage reports
|
| 47 |
+
htmlcov/
|
| 48 |
+
.tox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*,cover
|
| 55 |
+
.hypothesis/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
|
| 65 |
+
# Flask stuff:
|
| 66 |
+
instance/
|
| 67 |
+
.webassets-cache
|
| 68 |
+
|
| 69 |
+
# Scrapy stuff:
|
| 70 |
+
.scrapy
|
| 71 |
+
|
| 72 |
+
# Sphinx documentation
|
| 73 |
+
docs/_build/
|
| 74 |
+
|
| 75 |
+
# PyBuilder
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# pyenv
|
| 82 |
+
.python-version
|
| 83 |
+
|
| 84 |
+
# celery beat schedule file
|
| 85 |
+
celerybeat-schedule
|
| 86 |
+
|
| 87 |
+
# SageMath parsed files
|
| 88 |
+
*.sage.py
|
| 89 |
+
|
| 90 |
+
# dotenv
|
| 91 |
+
.env
|
| 92 |
+
|
| 93 |
+
# virtualenv
|
| 94 |
+
.venv
|
| 95 |
+
venv/
|
| 96 |
+
ENV/
|
| 97 |
+
|
| 98 |
+
# Spyder project settings
|
| 99 |
+
.spyderproject
|
| 100 |
+
|
| 101 |
+
# Rope project settings
|
| 102 |
+
.ropeproject
|
| 103 |
+
### JetBrains template
|
| 104 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
|
| 105 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
| 106 |
+
|
| 107 |
+
# User-specific stuff:
|
| 108 |
+
.idea
|
| 109 |
+
.idea/**/workspace.xml
|
| 110 |
+
.idea/**/tasks.xml
|
| 111 |
+
.idea/dictionaries
|
| 112 |
+
|
| 113 |
+
# Sensitive or high-churn files:
|
| 114 |
+
.idea/**/dataSources/
|
| 115 |
+
.idea/**/dataSources.ids
|
| 116 |
+
.idea/**/dataSources.xml
|
| 117 |
+
.idea/**/dataSources.local.xml
|
| 118 |
+
.idea/**/sqlDataSources.xml
|
| 119 |
+
.idea/**/dynamic.xml
|
| 120 |
+
.idea/**/uiDesigner.xml
|
| 121 |
+
|
| 122 |
+
# Gradle:
|
| 123 |
+
.idea/**/gradle.xml
|
| 124 |
+
.idea/**/libraries
|
| 125 |
+
|
| 126 |
+
# Mongo Explorer plugin:
|
| 127 |
+
.idea/**/mongoSettings.xml
|
| 128 |
+
|
| 129 |
+
## File-based project format:
|
| 130 |
+
*.iws
|
| 131 |
+
|
| 132 |
+
## Plugin-specific files:
|
| 133 |
+
|
| 134 |
+
# IntelliJ
|
| 135 |
+
/out/
|
| 136 |
+
|
| 137 |
+
# mpeltonen/sbt-idea plugin
|
| 138 |
+
.idea_modules/
|
| 139 |
+
|
| 140 |
+
# JIRA plugin
|
| 141 |
+
atlassian-ide-plugin.xml
|
| 142 |
+
|
| 143 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
| 144 |
+
com_crashlytics_export_strings.xml
|
| 145 |
+
crashlytics.properties
|
| 146 |
+
crashlytics-build.properties
|
| 147 |
+
fabric.properties
|
| 148 |
+
|
LICENCE.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 NIH/DIR
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -12,4 +12,44 @@ license: mit
|
|
| 12 |
short_description: Framework for Classifying patient-based AMD in CFP images
|
| 13 |
---
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
short_description: Framework for Classifying patient-based AMD in CFP images
|
| 13 |
---
|
| 14 |
|
| 15 |
+
# DeepSeeNet PyTorch
|
| 16 |
+
|
| 17 |
+
This repository is a PyTorch reimplementation of the original DeepSeeNet model:
|
| 18 |
+
|
| 19 |
+
https://github.com/ncbi-nlp/DeepSeeNet
|
| 20 |
+
|
| 21 |
+
DeepSeeNet predicts patient-level AREDS Simplified Severity Scale scores for age-related macular degeneration (AMD) from bilateral color fundus photographs. The model follows the original DeepSeeNet design by first predicting eye-level AMD risk factors, then combining predictions from both eyes into a patient-level simplified severity score.
|
| 22 |
+
|
| 23 |
+
## Tasks
|
| 24 |
+
|
| 25 |
+
The implementation trains three image-level subnetworks:
|
| 26 |
+
|
| 27 |
+
| Task | Classes | Output |
|
| 28 |
+
|---|---:|---|
|
| 29 |
+
| `ADVAMD` | 2 | late AMD absent / present |
|
| 30 |
+
| `DRUS` | 3 | small/none, medium, large drusen |
|
| 31 |
+
| `PIG` | 2 | pigmentary abnormality absent / present |
|
| 32 |
+
|
| 33 |
+
The final AREDS simplified score is computed from bilateral predictions:
|
| 34 |
+
|
| 35 |
+
- score `5` if late AMD is predicted in either eye
|
| 36 |
+
- otherwise, score is based on large drusen and pigmentary abnormalities across both eyes
|
| 37 |
+
- bilateral medium drusen contributes one point
|
| 38 |
+
|
| 39 |
+
## Citation
|
| 40 |
+
|
| 41 |
+
If you use this repository, please cite the original DeepSeeNet paper:
|
| 42 |
+
|
| 43 |
+
```bibtex
|
| 44 |
+
@article{peng2019deepseenet,
|
| 45 |
+
title={DeepSeeNet: A Deep Learning Model for Automated Classification of Patient-based Age-related Macular Degeneration Severity from Color Fundus Photographs},
|
| 46 |
+
author={Peng, Yifan and Dharssi, Shazia and Chen, Qingyu and Keenan, Tiarnan D. and Agr\'{o}n, Elvira and Wong, Wai T. and Chew, Emily Y. and Lu, Zhiyong},
|
| 47 |
+
journal={Ophthalmology},
|
| 48 |
+
volume={126},
|
| 49 |
+
number={4},
|
| 50 |
+
pages={565--575},
|
| 51 |
+
year={2019},
|
| 52 |
+
publisher={Elsevier},
|
| 53 |
+
doi={10.1016/j.ophtha.2018.11.015}
|
| 54 |
+
}
|
| 55 |
+
```
|
RESULTS.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Checkpoint Results
|
| 2 |
+
|
| 3 |
+
## ADVAMD
|
| 4 |
+
|
| 5 |
+
```text
|
| 6 |
+
Task: ADVAMD | endpoint: late_amd | positive_class=1
|
| 7 |
+
|
| 8 |
+
Metrics
|
| 9 |
+
-------
|
| 10 |
+
overall_accuracy 0.9658 (0.9628-0.9689)
|
| 11 |
+
sensitivity 0.8417 (0.8255-0.8576)
|
| 12 |
+
specificity 0.9852 (0.9831-0.9874)
|
| 13 |
+
kappa 0.8498 (0.8367-0.8632)
|
| 14 |
+
auc 0.9811 (0.9777-0.9844)
|
| 15 |
+
|
| 16 |
+
Classifier metrics
|
| 17 |
+
------------------
|
| 18 |
+
loss 0.1119
|
| 19 |
+
exact_accuracy 0.9658 (0.9628-0.9689)
|
| 20 |
+
exact_kappa 0.8498 (0.8367-0.8632)
|
| 21 |
+
|
| 22 |
+
Confusion matrix (rows=true, cols=pred):
|
| 23 |
+
[[11217 168]
|
| 24 |
+
[ 282 1499]]
|
| 25 |
+
|
| 26 |
+
Binary confusion matrix (rows=true, cols=pred):
|
| 27 |
+
[[11217 168]
|
| 28 |
+
[ 282 1499]]
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## DRUS
|
| 32 |
+
|
| 33 |
+
```text
|
| 34 |
+
Task: DRUS | endpoint: large_drusen | positive_class=2
|
| 35 |
+
|
| 36 |
+
Metrics
|
| 37 |
+
-------
|
| 38 |
+
overall_accuracy 0.8816 (0.8763-0.8869)
|
| 39 |
+
sensitivity 0.7708 (0.7588-0.7832)
|
| 40 |
+
specificity 0.9368 (0.9319-0.9418)
|
| 41 |
+
kappa 0.7263 (0.7144-0.7386)
|
| 42 |
+
auc 0.9489 (0.9452-0.9524)
|
| 43 |
+
|
| 44 |
+
Classifier metrics
|
| 45 |
+
------------------
|
| 46 |
+
loss 0.5903
|
| 47 |
+
exact_accuracy 0.7471 (0.7400-0.7542)
|
| 48 |
+
exact_kappa 0.6170 (0.6066-0.6280)
|
| 49 |
+
macro_ovr_auc 0.8960 (0.8919-0.9001)
|
| 50 |
+
|
| 51 |
+
Confusion matrix (rows=true, cols=pred):
|
| 52 |
+
[[4205 820 115]
|
| 53 |
+
[ 951 2255 440]
|
| 54 |
+
[ 182 822 3376]]
|
| 55 |
+
|
| 56 |
+
Binary confusion matrix (rows=true, cols=pred):
|
| 57 |
+
[[8231 555]
|
| 58 |
+
[1004 3376]]
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## PIG
|
| 62 |
+
|
| 63 |
+
```text
|
| 64 |
+
Task: PIG | endpoint: pigmentary_abnormality | positive_class=1
|
| 65 |
+
|
| 66 |
+
Metrics
|
| 67 |
+
-------
|
| 68 |
+
overall_accuracy 0.8925 (0.8874-0.8976)
|
| 69 |
+
sensitivity 0.8606 (0.8502-0.8701)
|
| 70 |
+
specificity 0.9113 (0.9053-0.9171)
|
| 71 |
+
kappa 0.7702 (0.7594-0.7811)
|
| 72 |
+
auc 0.9498 (0.9460-0.9536)
|
| 73 |
+
|
| 74 |
+
Classifier metrics
|
| 75 |
+
------------------
|
| 76 |
+
loss 0.2734
|
| 77 |
+
exact_accuracy 0.8925 (0.8874-0.8976)
|
| 78 |
+
exact_kappa 0.7702 (0.7594-0.7811)
|
| 79 |
+
|
| 80 |
+
Confusion matrix (rows=true, cols=pred):
|
| 81 |
+
[[7541 734]
|
| 82 |
+
[ 682 4209]]
|
| 83 |
+
|
| 84 |
+
Binary confusion matrix (rows=true, cols=pred):
|
| 85 |
+
[[7541 734]
|
| 86 |
+
[ 682 4209]]
|
| 87 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from augmentations import get_val_transforms
|
| 10 |
+
from model import DeepSeeNet
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
N_CLASSES = {
|
| 14 |
+
"ADVAMD": 2,
|
| 15 |
+
"DRUS": 3,
|
| 16 |
+
"PIG": 2,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
LABELS = {
|
| 20 |
+
"ADVAMD": ["no_late_amd", "late_amd"],
|
| 21 |
+
"DRUS": ["small_none", "medium", "large"],
|
| 22 |
+
"PIG": ["no_pigment", "pigment"],
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AlbumentationsTransform:
|
| 27 |
+
def __init__(self, transform):
|
| 28 |
+
self.transform = transform
|
| 29 |
+
|
| 30 |
+
def __call__(self, image):
|
| 31 |
+
return self.transform(image=np.asarray(image))["image"]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_args():
|
| 35 |
+
parser = argparse.ArgumentParser()
|
| 36 |
+
parser.add_argument("--checkpoint-folder", default="./checkpoints")
|
| 37 |
+
parser.add_argument("--backbone", default="inception_v3")
|
| 38 |
+
parser.add_argument("--image-size", type=int, default=1024)
|
| 39 |
+
parser.add_argument("--server-name", default="127.0.0.1")
|
| 40 |
+
parser.add_argument("--server-port", type=int, default=7860)
|
| 41 |
+
parser.add_argument("--share", action="store_true")
|
| 42 |
+
return parser.parse_args()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def load_model(path, task, backbone, device):
|
| 46 |
+
checkpoint = torch.load(path, map_location=device)
|
| 47 |
+
checkpoint_args = checkpoint.get("args", {})
|
| 48 |
+
|
| 49 |
+
model = DeepSeeNet(
|
| 50 |
+
n_classes=N_CLASSES[task],
|
| 51 |
+
backbone=checkpoint_args.get("backbone", backbone),
|
| 52 |
+
pretrained=False,
|
| 53 |
+
).to(device)
|
| 54 |
+
|
| 55 |
+
model.load_state_dict(checkpoint["model"])
|
| 56 |
+
model.eval()
|
| 57 |
+
return model
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_image(image, transform, device):
|
| 61 |
+
if image is None:
|
| 62 |
+
raise ValueError("Please upload both left and right images.")
|
| 63 |
+
|
| 64 |
+
image = image.convert("RGB")
|
| 65 |
+
return transform(image).unsqueeze(0).to(device)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@torch.no_grad()
|
| 69 |
+
def predict(model, image, task):
|
| 70 |
+
logits = model(image)[0].detach().cpu()
|
| 71 |
+
probs = F.softmax(logits, dim=0)
|
| 72 |
+
pred = int(torch.argmax(logits).item())
|
| 73 |
+
|
| 74 |
+
return {
|
| 75 |
+
"prediction": pred,
|
| 76 |
+
"label": LABELS[task][pred],
|
| 77 |
+
"confidence": float(probs[pred]),
|
| 78 |
+
"probabilities": {
|
| 79 |
+
LABELS[task][i]: float(probs[i])
|
| 80 |
+
for i in range(len(LABELS[task]))
|
| 81 |
+
},
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def simplified_score(scores):
|
| 86 |
+
if scores["ADVAMD"]["left"]["prediction"] == 1 or scores["ADVAMD"]["right"]["prediction"] == 1:
|
| 87 |
+
return 5
|
| 88 |
+
|
| 89 |
+
score = 0
|
| 90 |
+
score += scores["PIG"]["left"]["prediction"] == 1
|
| 91 |
+
score += scores["PIG"]["right"]["prediction"] == 1
|
| 92 |
+
score += scores["DRUS"]["left"]["prediction"] == 2
|
| 93 |
+
score += scores["DRUS"]["right"]["prediction"] == 2
|
| 94 |
+
score += (
|
| 95 |
+
scores["DRUS"]["left"]["prediction"] == 1
|
| 96 |
+
and scores["DRUS"]["right"]["prediction"] == 1
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
return int(min(score, 5))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def format_probs(probabilities):
|
| 103 |
+
return " | ".join(
|
| 104 |
+
f"{label}: {prob:.3f}"
|
| 105 |
+
for label, prob in probabilities.items()
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def model_info(args, device):
|
| 110 |
+
return f"""
|
| 111 |
+
# DeepSeeNet
|
| 112 |
+
|
| 113 |
+
<div style="display: grid; grid-template-columns: repeat(4, max-content); gap: 0.75rem 2rem; align-items: center;">
|
| 114 |
+
<div><b>Model</b><br><code>{args.backbone}</code></div>
|
| 115 |
+
<div><b>Input size</b><br><code>{args.image_size} × {args.image_size}</code></div>
|
| 116 |
+
<div><b>Device</b><br><code>{device.type}</code></div>
|
| 117 |
+
<div><b>Checkpoint folder</b><br><code>{args.checkpoint_folder}</code></div>
|
| 118 |
+
</div>
|
| 119 |
+
"""
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def make_app(args):
|
| 123 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 124 |
+
checkpoint_folder = Path(args.checkpoint_folder)
|
| 125 |
+
transform = AlbumentationsTransform(get_val_transforms(args.image_size))
|
| 126 |
+
|
| 127 |
+
models = {
|
| 128 |
+
"ADVAMD": load_model(checkpoint_folder / "advamd.pt", "ADVAMD", args.backbone, device),
|
| 129 |
+
"DRUS": load_model(checkpoint_folder / "drus.pt", "DRUS", args.backbone, device),
|
| 130 |
+
"PIG": load_model(checkpoint_folder / "pig.pt", "PIG", args.backbone, device),
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
def run(left_image, right_image):
|
| 134 |
+
left = load_image(left_image, transform, device)
|
| 135 |
+
right = load_image(right_image, transform, device)
|
| 136 |
+
|
| 137 |
+
scores = {}
|
| 138 |
+
for task, model in models.items():
|
| 139 |
+
scores[task] = {
|
| 140 |
+
"left": predict(model, left, task),
|
| 141 |
+
"right": predict(model, right, task),
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
score = simplified_score(scores)
|
| 145 |
+
|
| 146 |
+
summary_rows = [
|
| 147 |
+
["AREDS simplified score", score],
|
| 148 |
+
["Left eye", f"{scores['DRUS']['left']['label']}, {scores['PIG']['left']['label']}, {scores['ADVAMD']['left']['label']}"],
|
| 149 |
+
["Right eye", f"{scores['DRUS']['right']['label']}, {scores['PIG']['right']['label']}, {scores['ADVAMD']['right']['label']}"],
|
| 150 |
+
]
|
| 151 |
+
|
| 152 |
+
detail_rows = []
|
| 153 |
+
for task in ["ADVAMD", "DRUS", "PIG"]:
|
| 154 |
+
for eye in ["left", "right"]:
|
| 155 |
+
result = scores[task][eye]
|
| 156 |
+
detail_rows.append(
|
| 157 |
+
[
|
| 158 |
+
task,
|
| 159 |
+
eye,
|
| 160 |
+
result["label"],
|
| 161 |
+
f"{result['confidence']:.3f}",
|
| 162 |
+
format_probs(result["probabilities"]),
|
| 163 |
+
]
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
return summary_rows, detail_rows
|
| 167 |
+
|
| 168 |
+
with gr.Blocks(title="DeepSeeNet") as demo:
|
| 169 |
+
gr.Markdown(model_info(args, device))
|
| 170 |
+
|
| 171 |
+
with gr.Row():
|
| 172 |
+
left_image = gr.Image(type="pil", label="Left image")
|
| 173 |
+
right_image = gr.Image(type="pil", label="Right image")
|
| 174 |
+
|
| 175 |
+
button = gr.Button("Run")
|
| 176 |
+
|
| 177 |
+
summary = gr.Dataframe(
|
| 178 |
+
headers=["Item", "Result"],
|
| 179 |
+
label="Summary",
|
| 180 |
+
)
|
| 181 |
+
details = gr.Dataframe(
|
| 182 |
+
headers=["Task", "Eye", "Prediction", "Confidence", "Probabilities"],
|
| 183 |
+
label="Model outputs",
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
button.click(
|
| 187 |
+
run,
|
| 188 |
+
inputs=[left_image, right_image],
|
| 189 |
+
outputs=[summary, details],
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
return demo
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def main():
|
| 196 |
+
args = parse_args()
|
| 197 |
+
demo = make_app(args)
|
| 198 |
+
demo.launch(
|
| 199 |
+
server_name=args.server_name,
|
| 200 |
+
server_port=args.server_port,
|
| 201 |
+
share=args.share,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
|
| 206 |
+
main()
|
augmentations.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
augmentations.py
|
| 3 |
+
|
| 4 |
+
Simple camera-style augmentations for color fundus photography (CFP)
|
| 5 |
+
classification.
|
| 6 |
+
|
| 7 |
+
Expected input:
|
| 8 |
+
RGB NumPy image, shape (H, W, 3)
|
| 9 |
+
|
| 10 |
+
Dependencies:
|
| 11 |
+
pip install albumentations opencv-python
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import cv2
|
| 15 |
+
import albumentations as A
|
| 16 |
+
from albumentations.pytorch import ToTensorV2
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 20 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_train_transforms(
|
| 24 |
+
image_size=1024,
|
| 25 |
+
mean=IMAGENET_MEAN,
|
| 26 |
+
std=IMAGENET_STD,
|
| 27 |
+
):
|
| 28 |
+
return A.Compose([
|
| 29 |
+
A.Resize(image_size, image_size),
|
| 30 |
+
|
| 31 |
+
# Geometry is safe
|
| 32 |
+
A.HorizontalFlip(p=0.5),
|
| 33 |
+
|
| 34 |
+
A.ShiftScaleRotate(
|
| 35 |
+
shift_limit=0.02,
|
| 36 |
+
scale_limit=0.03, # slightly reduced
|
| 37 |
+
rotate_limit=5, # slightly reduced
|
| 38 |
+
border_mode=0,
|
| 39 |
+
value=0,
|
| 40 |
+
p=0.3,
|
| 41 |
+
),
|
| 42 |
+
|
| 43 |
+
# MUCH weaker photometric changes
|
| 44 |
+
A.RandomBrightnessContrast(
|
| 45 |
+
brightness_limit=0.08, # ↓ from 0.15
|
| 46 |
+
contrast_limit=0.08,
|
| 47 |
+
p=0.3,
|
| 48 |
+
),
|
| 49 |
+
|
| 50 |
+
# Remove or reduce gamma
|
| 51 |
+
A.RandomGamma(
|
| 52 |
+
gamma_limit=(95, 105), # very mild
|
| 53 |
+
p=0.2,
|
| 54 |
+
),
|
| 55 |
+
|
| 56 |
+
# Remove hue shift entirely (important)
|
| 57 |
+
# Hue shifts are not realistic for fundus physiology
|
| 58 |
+
# -> comment this out or reduce heavily
|
| 59 |
+
# A.HueSaturationValue(...)
|
| 60 |
+
|
| 61 |
+
# Keep mild quality perturbation
|
| 62 |
+
A.OneOf([
|
| 63 |
+
A.GaussianBlur(blur_limit=(3, 5)),
|
| 64 |
+
A.Downscale(scale_min=0.85, scale_max=0.95, interpolation=cv2.INTER_LINEAR),
|
| 65 |
+
A.ImageCompression(quality_lower=80, quality_upper=100),
|
| 66 |
+
], p=0.15),
|
| 67 |
+
|
| 68 |
+
A.Normalize(mean=mean, std=std),
|
| 69 |
+
ToTensorV2(),
|
| 70 |
+
])
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_val_transforms(
|
| 74 |
+
image_size=1024,
|
| 75 |
+
mean=IMAGENET_MEAN,
|
| 76 |
+
std=IMAGENET_STD,
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
Validation/test transforms.
|
| 80 |
+
"""
|
| 81 |
+
return A.Compose([
|
| 82 |
+
A.Resize(image_size, image_size),
|
| 83 |
+
A.Normalize(mean=mean, std=std),
|
| 84 |
+
ToTensorV2(),
|
| 85 |
+
])
|
checkpoints/advamd.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c600e3f70526da0c4d65d5e4d55a563f9a083d44da223385d7a8572e194e191e
|
| 3 |
+
size 89723328
|
checkpoints/drus.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:efde754bd4d6be4b9afc4d0c17d016f138f52239ab36c4ce74ba6b24cef16245
|
| 3 |
+
size 89697156
|
checkpoints/pig.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fcdafe1a939b5603b3fbf76276702a60c68b657b359a0702b4cf9ae74fea8ee
|
| 3 |
+
size 89696006
|
dataloader.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch datasets and dataloaders for AREDS fundus images."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Callable, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
from torchvision import transforms
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
TASKS = ("ADVAMD", "DRUS", "PIG")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DEFAULT_TRANSFORM = transforms.Compose(
|
| 18 |
+
[
|
| 19 |
+
transforms.Resize(224),
|
| 20 |
+
transforms.CenterCrop(224),
|
| 21 |
+
transforms.ToTensor(),
|
| 22 |
+
transforms.Normalize(
|
| 23 |
+
mean=(0.485, 0.456, 0.406),
|
| 24 |
+
std=(0.229, 0.224, 0.225),
|
| 25 |
+
),
|
| 26 |
+
]
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AREDSDataset(Dataset):
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
csv_path: Union[str, Path],
|
| 34 |
+
image_root: Union[str, Path],
|
| 35 |
+
task: str,
|
| 36 |
+
transform: Optional[Callable[[Image.Image], Tensor]] = None,
|
| 37 |
+
) -> None:
|
| 38 |
+
task = task.upper()
|
| 39 |
+
if task not in TASKS:
|
| 40 |
+
raise ValueError(f"task must be one of {TASKS}")
|
| 41 |
+
self.image_root = Path(image_root)
|
| 42 |
+
self.task = task
|
| 43 |
+
self.transform = transform or DEFAULT_TRANSFORM
|
| 44 |
+
self.data = pd.read_csv(csv_path)
|
| 45 |
+
|
| 46 |
+
def __len__(self) -> int:
|
| 47 |
+
return len(self.data)
|
| 48 |
+
|
| 49 |
+
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
|
| 50 |
+
row = self.data.iloc[index]
|
| 51 |
+
image_path = self.image_root / row.pathname
|
| 52 |
+
image = Image.open(image_path).convert("RGB")
|
| 53 |
+
image = self.transform(image)
|
| 54 |
+
label = torch.tensor(int(row[self.task]), dtype=torch.long)
|
| 55 |
+
return image, label
|
model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DeepSeeNet model definition."""
|
| 2 |
+
|
| 3 |
+
from torch import Tensor, nn
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import timm
|
| 7 |
+
except ImportError: # pragma: no cover - handled when timm is absent.
|
| 8 |
+
timm = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DeepSeeNet(nn.Module):
|
| 12 |
+
"""DeepSeeNet risk-factor classifier in PyTorch.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
n_classes: Number of output classes.
|
| 16 |
+
backbone: Any timm model name that supports ``num_classes=0``. The
|
| 17 |
+
default uses InceptionV3.
|
| 18 |
+
pretrained: Load ImageNet weights for the backbone.
|
| 19 |
+
dropout: Dropout probability used by the classifier head.
|
| 20 |
+
freeze_backbone: If true, keep the backbone frozen and train only the
|
| 21 |
+
classifier head.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
n_classes: int = 2,
|
| 27 |
+
backbone: str = "inception_v3",
|
| 28 |
+
pretrained: bool = True,
|
| 29 |
+
dropout: float = 0.5,
|
| 30 |
+
freeze_backbone: bool = False,
|
| 31 |
+
) -> None:
|
| 32 |
+
super().__init__()
|
| 33 |
+
if n_classes < 1:
|
| 34 |
+
raise ValueError("n_classes must be positive")
|
| 35 |
+
if timm is None:
|
| 36 |
+
raise ImportError("timm is required to build DeepSeeNet")
|
| 37 |
+
|
| 38 |
+
self.backbone_name = backbone
|
| 39 |
+
self.backbone = timm.create_model(
|
| 40 |
+
backbone,
|
| 41 |
+
pretrained=pretrained,
|
| 42 |
+
num_classes=0,
|
| 43 |
+
global_pool="avg",
|
| 44 |
+
)
|
| 45 |
+
in_features = self.backbone.num_features
|
| 46 |
+
self.classifier = nn.Sequential(
|
| 47 |
+
nn.Linear(in_features, 256),
|
| 48 |
+
nn.ReLU(inplace=True),
|
| 49 |
+
nn.Dropout(dropout),
|
| 50 |
+
nn.Linear(256, 128),
|
| 51 |
+
nn.ReLU(inplace=True),
|
| 52 |
+
nn.Dropout(dropout),
|
| 53 |
+
nn.Linear(128, n_classes),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if freeze_backbone:
|
| 57 |
+
self.backbone.requires_grad_(False)
|
| 58 |
+
|
| 59 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 60 |
+
features = self.backbone(x)
|
| 61 |
+
return self.classifier(features)
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
timm
|
| 4 |
+
albumentations
|
| 5 |
+
numpy
|
| 6 |
+
pandas
|
| 7 |
+
scikit-learn
|
| 8 |
+
tqdm
|
| 9 |
+
gradio
|
| 10 |
+
pillow
|
run_inference.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run DeepSeeNet inference for AREDS simplified score."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from dataloader import DEFAULT_TRANSFORM
|
| 10 |
+
from model import DeepSeeNet
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
N_CLASSES = {
|
| 14 |
+
"ADVAMD": 2,
|
| 15 |
+
"DRUS": 3,
|
| 16 |
+
"PIG": 2,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def parse_args() -> argparse.Namespace:
|
| 21 |
+
parser = argparse.ArgumentParser()
|
| 22 |
+
parser.add_argument("--left-image", required=True)
|
| 23 |
+
parser.add_argument("--right-image", required=True)
|
| 24 |
+
parser.add_argument("--advamd-checkpoint", required=True)
|
| 25 |
+
parser.add_argument("--drus-checkpoint", required=True)
|
| 26 |
+
parser.add_argument("--pig-checkpoint", required=True)
|
| 27 |
+
parser.add_argument("--backbone", default="inception_v3")
|
| 28 |
+
return parser.parse_args()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_model(checkpoint_path: str, task: str, backbone: str, device) -> DeepSeeNet:
|
| 32 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 33 |
+
checkpoint_args = checkpoint.get("args", {})
|
| 34 |
+
model = DeepSeeNet(
|
| 35 |
+
n_classes=N_CLASSES[task],
|
| 36 |
+
backbone=checkpoint_args.get("backbone", backbone),
|
| 37 |
+
pretrained=False,
|
| 38 |
+
).to(device)
|
| 39 |
+
model.load_state_dict(checkpoint["model"])
|
| 40 |
+
model.eval()
|
| 41 |
+
return model
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_image(path: str, device) -> torch.Tensor:
|
| 45 |
+
image = Image.open(path).convert("RGB")
|
| 46 |
+
return DEFAULT_TRANSFORM(image).unsqueeze(0).to(device)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@torch.no_grad()
|
| 50 |
+
def predict(model: DeepSeeNet, image: torch.Tensor) -> int:
|
| 51 |
+
return int(model(image).argmax(dim=1).item())
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def simplified_score(scores: dict[str, tuple[int, int]]) -> int:
|
| 55 |
+
score = 0
|
| 56 |
+
if scores["ADVAMD"][0] or scores["ADVAMD"][1]:
|
| 57 |
+
return 5
|
| 58 |
+
score += scores["PIG"][0] == 1
|
| 59 |
+
score += scores["PIG"][1] == 1
|
| 60 |
+
score += scores["DRUS"][0] == 2
|
| 61 |
+
score += scores["DRUS"][1] == 2
|
| 62 |
+
score += scores["DRUS"][0] == 1 and scores["DRUS"][1] == 1
|
| 63 |
+
return min(score, 5)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def main() -> None:
|
| 67 |
+
args = parse_args()
|
| 68 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 69 |
+
images = {
|
| 70 |
+
"left": load_image(args.left_image, device),
|
| 71 |
+
"right": load_image(args.right_image, device),
|
| 72 |
+
}
|
| 73 |
+
checkpoints = {
|
| 74 |
+
"ADVAMD": args.advamd_checkpoint,
|
| 75 |
+
"DRUS": args.drus_checkpoint,
|
| 76 |
+
"PIG": args.pig_checkpoint,
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
scores = {}
|
| 80 |
+
for task, checkpoint in checkpoints.items():
|
| 81 |
+
model = load_model(checkpoint, task, args.backbone, device)
|
| 82 |
+
scores[task] = (
|
| 83 |
+
predict(model, images["left"]),
|
| 84 |
+
predict(model, images["right"]),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
print(
|
| 88 |
+
json.dumps(
|
| 89 |
+
{
|
| 90 |
+
"simplified_score": simplified_score(scores),
|
| 91 |
+
"risk_factors": scores,
|
| 92 |
+
},
|
| 93 |
+
indent=2,
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
main()
|
test.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Callable
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from sklearn.metrics import (
|
| 17 |
+
accuracy_score,
|
| 18 |
+
cohen_kappa_score,
|
| 19 |
+
confusion_matrix,
|
| 20 |
+
recall_score,
|
| 21 |
+
roc_auc_score,
|
| 22 |
+
)
|
| 23 |
+
except ImportError as exc:
|
| 24 |
+
raise ImportError(
|
| 25 |
+
"This evaluation script needs scikit-learn. Install with: pip install scikit-learn"
|
| 26 |
+
) from exc
|
| 27 |
+
|
| 28 |
+
from augmentations import get_val_transforms
|
| 29 |
+
from dataloader import AREDSDataset
|
| 30 |
+
from model import DeepSeeNet
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
N_CLASSES = {
|
| 34 |
+
"ADVAMD": 2,
|
| 35 |
+
"DRUS": 3,
|
| 36 |
+
"PIG": 2,
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
DEFAULT_POSITIVE_CLASS = {
|
| 40 |
+
"ADVAMD": 1,
|
| 41 |
+
"DRUS": 2,
|
| 42 |
+
"PIG": 1,
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
ENDPOINT_NAME = {
|
| 46 |
+
"ADVAMD": "late_amd",
|
| 47 |
+
"DRUS": "large_drusen",
|
| 48 |
+
"PIG": "pigmentary_abnormality",
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def parse_args() -> argparse.Namespace:
|
| 53 |
+
parser = argparse.ArgumentParser()
|
| 54 |
+
parser.add_argument("--test-csv", required=True)
|
| 55 |
+
parser.add_argument("--image-root", required=True)
|
| 56 |
+
parser.add_argument("--checkpoint", required=True)
|
| 57 |
+
parser.add_argument("--task", required=True, type=str.upper, choices=N_CLASSES)
|
| 58 |
+
parser.add_argument("--backbone", default="inception_v3")
|
| 59 |
+
parser.add_argument("--image-size", type=int, default=1024)
|
| 60 |
+
parser.add_argument("--batch-size", type=int, default=32)
|
| 61 |
+
parser.add_argument("--num-workers", type=int, default=16)
|
| 62 |
+
|
| 63 |
+
parser.add_argument("--positive-class", type=int, default=None)
|
| 64 |
+
parser.add_argument("--bootstrap-iters", type=int, default=2000)
|
| 65 |
+
parser.add_argument("--seed", type=int, default=123)
|
| 66 |
+
parser.add_argument("--bootstrap-unit-column", default=None)
|
| 67 |
+
parser.add_argument("--output-dir", default=None)
|
| 68 |
+
return parser.parse_args()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class AlbumentationsTransform:
|
| 72 |
+
def __init__(self, transform) -> None:
|
| 73 |
+
self.transform = transform
|
| 74 |
+
|
| 75 |
+
def __call__(self, image):
|
| 76 |
+
return self.transform(image=np.asarray(image))["image"]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
|
| 80 |
+
def collect_predictions(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> dict[str, np.ndarray | float]:
|
| 81 |
+
model.eval()
|
| 82 |
+
total_loss = 0.0
|
| 83 |
+
total_samples = 0
|
| 84 |
+
all_labels: list[np.ndarray] = []
|
| 85 |
+
all_logits: list[np.ndarray] = []
|
| 86 |
+
|
| 87 |
+
for images, labels in tqdm(loader, desc="test"):
|
| 88 |
+
images = images.to(device)
|
| 89 |
+
labels = labels.to(device)
|
| 90 |
+
|
| 91 |
+
logits = model(images)
|
| 92 |
+
if isinstance(logits, (tuple, list)):
|
| 93 |
+
logits = logits[0]
|
| 94 |
+
|
| 95 |
+
loss = F.cross_entropy(logits, labels)
|
| 96 |
+
batch_size = labels.size(0)
|
| 97 |
+
total_loss += loss.item() * batch_size
|
| 98 |
+
total_samples += batch_size
|
| 99 |
+
|
| 100 |
+
all_labels.append(labels.detach().cpu().numpy())
|
| 101 |
+
all_logits.append(logits.detach().cpu().numpy())
|
| 102 |
+
|
| 103 |
+
labels_np = np.concatenate(all_labels).astype(int)
|
| 104 |
+
logits_np = np.concatenate(all_logits, axis=0)
|
| 105 |
+
probs_np = torch.softmax(torch.from_numpy(logits_np), dim=1).numpy()
|
| 106 |
+
preds_np = probs_np.argmax(axis=1).astype(int)
|
| 107 |
+
|
| 108 |
+
return {
|
| 109 |
+
"loss": float(total_loss / max(total_samples, 1)),
|
| 110 |
+
"labels": labels_np,
|
| 111 |
+
"logits": logits_np,
|
| 112 |
+
"probs": probs_np,
|
| 113 |
+
"preds": preds_np,
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def specificity_score(y_true_bin: np.ndarray, y_pred_bin: np.ndarray) -> float:
|
| 118 |
+
tn = np.sum((y_true_bin == 0) & (y_pred_bin == 0))
|
| 119 |
+
fp = np.sum((y_true_bin == 0) & (y_pred_bin == 1))
|
| 120 |
+
denom = tn + fp
|
| 121 |
+
return float(tn / denom) if denom else float("nan")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def safe_auc(y_true_bin: np.ndarray, y_score: np.ndarray) -> float:
|
| 125 |
+
if len(np.unique(y_true_bin)) < 2:
|
| 126 |
+
return float("nan")
|
| 127 |
+
return float(roc_auc_score(y_true_bin, y_score))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def compute_metrics(
|
| 131 |
+
y_true: np.ndarray,
|
| 132 |
+
y_pred: np.ndarray,
|
| 133 |
+
probs: np.ndarray,
|
| 134 |
+
n_classes: int,
|
| 135 |
+
positive_class: int,
|
| 136 |
+
) -> dict[str, float]:
|
| 137 |
+
y_true_bin = (y_true == positive_class).astype(int)
|
| 138 |
+
y_pred_bin = (y_pred == positive_class).astype(int)
|
| 139 |
+
pos_score = probs[:, positive_class]
|
| 140 |
+
|
| 141 |
+
metrics = {
|
| 142 |
+
"loss": float("nan"),
|
| 143 |
+
"exact_accuracy": float(accuracy_score(y_true, y_pred)),
|
| 144 |
+
"exact_kappa": float(cohen_kappa_score(y_true, y_pred)),
|
| 145 |
+
"overall_accuracy": float(accuracy_score(y_true_bin, y_pred_bin)),
|
| 146 |
+
"sensitivity": float(recall_score(y_true_bin, y_pred_bin, pos_label=1, zero_division=0)),
|
| 147 |
+
"specificity": specificity_score(y_true_bin, y_pred_bin),
|
| 148 |
+
"kappa": float(cohen_kappa_score(y_true_bin, y_pred_bin)),
|
| 149 |
+
"auc": safe_auc(y_true_bin, pos_score),
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
if n_classes > 2 and len(np.unique(y_true)) > 1:
|
| 153 |
+
try:
|
| 154 |
+
metrics["macro_ovr_auc"] = float(
|
| 155 |
+
roc_auc_score(y_true, probs, labels=list(range(n_classes)), multi_class="ovr", average="macro")
|
| 156 |
+
)
|
| 157 |
+
except ValueError:
|
| 158 |
+
metrics["macro_ovr_auc"] = float("nan")
|
| 159 |
+
|
| 160 |
+
return metrics
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def make_bootstrap_indices(
|
| 164 |
+
n: int,
|
| 165 |
+
n_iters: int,
|
| 166 |
+
rng: np.random.Generator,
|
| 167 |
+
units: np.ndarray | None = None,
|
| 168 |
+
) -> list[np.ndarray]:
|
| 169 |
+
if n_iters <= 0:
|
| 170 |
+
return []
|
| 171 |
+
|
| 172 |
+
if units is None:
|
| 173 |
+
return [rng.integers(0, n, size=n) for _ in range(n_iters)]
|
| 174 |
+
|
| 175 |
+
unique_units = np.array(pd.unique(units))
|
| 176 |
+
row_indices_by_unit = {unit: np.where(units == unit)[0] for unit in unique_units}
|
| 177 |
+
out = []
|
| 178 |
+
for _ in range(n_iters):
|
| 179 |
+
sampled_units = rng.choice(unique_units, size=len(unique_units), replace=True)
|
| 180 |
+
out.append(np.concatenate([row_indices_by_unit[u] for u in sampled_units]))
|
| 181 |
+
return out
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def bootstrap_ci(
|
| 185 |
+
metric_fn: Callable[[np.ndarray], dict[str, float]],
|
| 186 |
+
indices: list[np.ndarray],
|
| 187 |
+
) -> dict[str, dict[str, float]]:
|
| 188 |
+
if not indices:
|
| 189 |
+
return {}
|
| 190 |
+
|
| 191 |
+
values_by_metric: dict[str, list[float]] = {}
|
| 192 |
+
for idx in tqdm(indices, desc="bootstrap", leave=False):
|
| 193 |
+
vals = metric_fn(idx)
|
| 194 |
+
for key, value in vals.items():
|
| 195 |
+
values_by_metric.setdefault(key, []).append(value)
|
| 196 |
+
|
| 197 |
+
intervals: dict[str, dict[str, float]] = {}
|
| 198 |
+
for key, values in values_by_metric.items():
|
| 199 |
+
arr = np.asarray(values, dtype=float)
|
| 200 |
+
intervals[key] = {
|
| 201 |
+
"ci_low": float(np.nanpercentile(arr, 2.5)),
|
| 202 |
+
"ci_high": float(np.nanpercentile(arr, 97.5)),
|
| 203 |
+
}
|
| 204 |
+
return intervals
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def combine_with_ci(metrics: dict[str, float], ci: dict[str, dict[str, float]]) -> dict[str, Any]:
|
| 208 |
+
out: dict[str, Any] = {}
|
| 209 |
+
for key, value in metrics.items():
|
| 210 |
+
out[key] = {"value": float(value)}
|
| 211 |
+
if key in ci:
|
| 212 |
+
out[key].update(ci[key])
|
| 213 |
+
return out
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def print_metric_table(metrics_with_ci: dict[str, Any]) -> None:
|
| 217 |
+
print("\nMetrics")
|
| 218 |
+
print("-------")
|
| 219 |
+
for key in ["overall_accuracy", "sensitivity", "specificity", "kappa", "auc"]:
|
| 220 |
+
item = metrics_with_ci[key]
|
| 221 |
+
if "ci_low" in item:
|
| 222 |
+
print(f"{key:20s} {item['value']:.4f} ({item['ci_low']:.4f}-{item['ci_high']:.4f})")
|
| 223 |
+
else:
|
| 224 |
+
print(f"{key:20s} {item['value']:.4f}")
|
| 225 |
+
|
| 226 |
+
print("\nClassifier metrics")
|
| 227 |
+
print("------------------")
|
| 228 |
+
for key in ["loss", "exact_accuracy", "exact_kappa", "macro_ovr_auc"]:
|
| 229 |
+
if key not in metrics_with_ci:
|
| 230 |
+
continue
|
| 231 |
+
item = metrics_with_ci[key]
|
| 232 |
+
if "ci_low" in item:
|
| 233 |
+
print(f"{key:20s} {item['value']:.4f} ({item['ci_low']:.4f}-{item['ci_high']:.4f})")
|
| 234 |
+
else:
|
| 235 |
+
print(f"{key:20s} {item['value']:.4f}")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def main() -> None:
|
| 239 |
+
args = parse_args()
|
| 240 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 241 |
+
|
| 242 |
+
task = args.task.upper()
|
| 243 |
+
n_classes = N_CLASSES[task]
|
| 244 |
+
positive_class = DEFAULT_POSITIVE_CLASS[task] if args.positive_class is None else args.positive_class
|
| 245 |
+
if not 0 <= positive_class < n_classes:
|
| 246 |
+
raise ValueError(f"positive_class={positive_class} is invalid for task={task} with {n_classes} classes")
|
| 247 |
+
|
| 248 |
+
dataset = AREDSDataset(
|
| 249 |
+
args.test_csv,
|
| 250 |
+
args.image_root,
|
| 251 |
+
task,
|
| 252 |
+
transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
|
| 253 |
+
)
|
| 254 |
+
loader = DataLoader(
|
| 255 |
+
dataset,
|
| 256 |
+
batch_size=args.batch_size,
|
| 257 |
+
shuffle=False,
|
| 258 |
+
num_workers=args.num_workers,
|
| 259 |
+
pin_memory=device.type == "cuda",
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
model = DeepSeeNet(
|
| 263 |
+
n_classes=n_classes,
|
| 264 |
+
backbone=args.backbone,
|
| 265 |
+
pretrained=False,
|
| 266 |
+
).to(device)
|
| 267 |
+
checkpoint = torch.load(args.checkpoint, map_location=device)
|
| 268 |
+
model.load_state_dict(checkpoint["model"])
|
| 269 |
+
|
| 270 |
+
pred_dict = collect_predictions(model, loader, device)
|
| 271 |
+
y_true = pred_dict["labels"]
|
| 272 |
+
y_pred = pred_dict["preds"]
|
| 273 |
+
probs = pred_dict["probs"]
|
| 274 |
+
|
| 275 |
+
metrics = compute_metrics(y_true, y_pred, probs, n_classes=n_classes, positive_class=positive_class)
|
| 276 |
+
metrics["loss"] = float(pred_dict["loss"])
|
| 277 |
+
|
| 278 |
+
units = None
|
| 279 |
+
if args.bootstrap_unit_column:
|
| 280 |
+
df_for_units = pd.read_csv(args.test_csv)
|
| 281 |
+
if args.bootstrap_unit_column not in df_for_units.columns:
|
| 282 |
+
raise KeyError(
|
| 283 |
+
f"--bootstrap-unit-column {args.bootstrap_unit_column!r} not found in {args.test_csv}. "
|
| 284 |
+
f"Available columns: {list(df_for_units.columns)}"
|
| 285 |
+
)
|
| 286 |
+
if len(df_for_units) != len(y_true):
|
| 287 |
+
raise ValueError(
|
| 288 |
+
"CSV length does not match dataset length. "
|
| 289 |
+
f"CSV rows={len(df_for_units)}, dataset rows={len(y_true)}"
|
| 290 |
+
)
|
| 291 |
+
units = df_for_units[args.bootstrap_unit_column].to_numpy()
|
| 292 |
+
|
| 293 |
+
rng = np.random.default_rng(args.seed)
|
| 294 |
+
bs_indices = make_bootstrap_indices(
|
| 295 |
+
n=len(y_true),
|
| 296 |
+
n_iters=args.bootstrap_iters,
|
| 297 |
+
rng=rng,
|
| 298 |
+
units=units,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def metric_fn(idx: np.ndarray) -> dict[str, float]:
|
| 302 |
+
out = compute_metrics(
|
| 303 |
+
y_true[idx],
|
| 304 |
+
y_pred[idx],
|
| 305 |
+
probs[idx],
|
| 306 |
+
n_classes=n_classes,
|
| 307 |
+
positive_class=positive_class,
|
| 308 |
+
)
|
| 309 |
+
out.pop("loss", None)
|
| 310 |
+
return out
|
| 311 |
+
|
| 312 |
+
ci = bootstrap_ci(metric_fn, bs_indices)
|
| 313 |
+
metrics_with_ci = combine_with_ci(metrics, ci)
|
| 314 |
+
|
| 315 |
+
cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
|
| 316 |
+
endpoint_cm = confusion_matrix(
|
| 317 |
+
(y_true == positive_class).astype(int),
|
| 318 |
+
(y_pred == positive_class).astype(int),
|
| 319 |
+
labels=[0, 1],
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
meta = {
|
| 323 |
+
"task": task,
|
| 324 |
+
"endpoint": ENDPOINT_NAME[task],
|
| 325 |
+
"positive_class": int(positive_class),
|
| 326 |
+
"n_classes": int(n_classes),
|
| 327 |
+
"n_samples": int(len(y_true)),
|
| 328 |
+
"bootstrap_iters": int(args.bootstrap_iters),
|
| 329 |
+
"bootstrap_unit_column": args.bootstrap_unit_column,
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
print(f"\nTask: {task} | endpoint: {ENDPOINT_NAME[task]} | positive_class={positive_class}")
|
| 333 |
+
print_metric_table(metrics_with_ci)
|
| 334 |
+
print("\nConfusion matrix (rows=true, cols=pred):")
|
| 335 |
+
print(cm)
|
| 336 |
+
print("\nBinary confusion matrix (rows=true, cols=pred):")
|
| 337 |
+
print(endpoint_cm)
|
| 338 |
+
|
| 339 |
+
if args.output_dir:
|
| 340 |
+
output_dir = Path(args.output_dir)
|
| 341 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 342 |
+
|
| 343 |
+
with (output_dir / "metrics.json").open("w") as f:
|
| 344 |
+
json.dump({"meta": meta, "metrics": metrics_with_ci}, f, indent=2)
|
| 345 |
+
|
| 346 |
+
pd.DataFrame(cm).to_csv(output_dir / "confusion_matrix.csv", index=False)
|
| 347 |
+
pd.DataFrame(endpoint_cm, index=["true_neg", "true_pos"], columns=["pred_neg", "pred_pos"]).to_csv(
|
| 348 |
+
output_dir / "endpoint_confusion_matrix.csv"
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
pred_df = pd.read_csv(args.test_csv)
|
| 352 |
+
if len(pred_df) == len(y_true):
|
| 353 |
+
pred_df = pred_df.copy()
|
| 354 |
+
else:
|
| 355 |
+
pred_df = pd.DataFrame(index=np.arange(len(y_true)))
|
| 356 |
+
pred_df["y_true"] = y_true
|
| 357 |
+
pred_df["y_pred"] = y_pred
|
| 358 |
+
pred_df[f"y_true_{ENDPOINT_NAME[task]}"] = (y_true == positive_class).astype(int)
|
| 359 |
+
pred_df[f"y_pred_{ENDPOINT_NAME[task]}"] = (y_pred == positive_class).astype(int)
|
| 360 |
+
for c in range(n_classes):
|
| 361 |
+
pred_df[f"prob_class_{c}"] = probs[:, c]
|
| 362 |
+
pred_df.to_csv(output_dir / "predictions.csv", index=False)
|
| 363 |
+
|
| 364 |
+
print(f"\nSaved outputs to: {output_dir}")
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
if __name__ == "__main__":
|
| 368 |
+
main()
|
train.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from augmentations import get_train_transforms, get_val_transforms
|
| 12 |
+
from dataloader import AREDSDataset
|
| 13 |
+
from model import DeepSeeNet
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
N_CLASSES = {
|
| 17 |
+
"ADVAMD": 2,
|
| 18 |
+
"DRUS": 3,
|
| 19 |
+
"PIG": 2,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AlbumentationsTransform:
|
| 24 |
+
def __init__(self, transform):
|
| 25 |
+
self.transform = transform
|
| 26 |
+
|
| 27 |
+
def __call__(self, image):
|
| 28 |
+
return self.transform(image=np.asarray(image))["image"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def set_seed(seed):
|
| 32 |
+
random.seed(seed)
|
| 33 |
+
np.random.seed(seed)
|
| 34 |
+
torch.manual_seed(seed)
|
| 35 |
+
torch.cuda.manual_seed_all(seed)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_class_weights(dataset, task, device):
|
| 39 |
+
labels = torch.tensor(dataset.data[task].to_numpy(), dtype=torch.long)
|
| 40 |
+
counts = torch.bincount(labels, minlength=N_CLASSES[task]).clamp_min(1)
|
| 41 |
+
weights = counts.sum() / (len(counts) * counts)
|
| 42 |
+
return weights.to(device)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def build_scheduler(optimizer, args):
|
| 46 |
+
if args.scheduler == "cosine":
|
| 47 |
+
return torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 48 |
+
optimizer,
|
| 49 |
+
T_max=args.epochs,
|
| 50 |
+
eta_min=args.min_lr,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
if args.scheduler == "step":
|
| 54 |
+
return torch.optim.lr_scheduler.StepLR(
|
| 55 |
+
optimizer,
|
| 56 |
+
step_size=args.step_size,
|
| 57 |
+
gamma=args.gamma,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def train_one_epoch(
|
| 64 |
+
model,
|
| 65 |
+
loader,
|
| 66 |
+
optimizer,
|
| 67 |
+
scaler,
|
| 68 |
+
criterion,
|
| 69 |
+
device,
|
| 70 |
+
use_amp=True,
|
| 71 |
+
grad_clip=0.0,
|
| 72 |
+
):
|
| 73 |
+
model.train()
|
| 74 |
+
|
| 75 |
+
running_loss = 0.0
|
| 76 |
+
running_correct = 0
|
| 77 |
+
running_samples = 0
|
| 78 |
+
|
| 79 |
+
pbar = tqdm(loader, desc="Train", leave=False)
|
| 80 |
+
|
| 81 |
+
for images, labels in pbar:
|
| 82 |
+
images = images.to(device, non_blocking=True)
|
| 83 |
+
labels = labels.to(device, non_blocking=True)
|
| 84 |
+
|
| 85 |
+
optimizer.zero_grad(set_to_none=True)
|
| 86 |
+
|
| 87 |
+
with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
|
| 88 |
+
logits = model(images)
|
| 89 |
+
loss = criterion(logits, labels)
|
| 90 |
+
|
| 91 |
+
if scaler is not None:
|
| 92 |
+
scaler.scale(loss).backward()
|
| 93 |
+
|
| 94 |
+
if grad_clip > 0:
|
| 95 |
+
scaler.unscale_(optimizer)
|
| 96 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 97 |
+
|
| 98 |
+
scaler.step(optimizer)
|
| 99 |
+
scaler.update()
|
| 100 |
+
else:
|
| 101 |
+
loss.backward()
|
| 102 |
+
|
| 103 |
+
if grad_clip > 0:
|
| 104 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 105 |
+
|
| 106 |
+
optimizer.step()
|
| 107 |
+
|
| 108 |
+
batch_size = labels.size(0)
|
| 109 |
+
running_loss += loss.item() * batch_size
|
| 110 |
+
running_correct += (logits.argmax(dim=1) == labels).sum().item()
|
| 111 |
+
running_samples += batch_size
|
| 112 |
+
|
| 113 |
+
pbar.set_postfix(
|
| 114 |
+
loss=f"{running_loss / running_samples:.4f}",
|
| 115 |
+
acc=f"{running_correct / running_samples:.4f}",
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return running_loss / running_samples, running_correct / running_samples
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@torch.no_grad()
|
| 122 |
+
def evaluate(model, loader, criterion, device, use_amp=True):
|
| 123 |
+
model.eval()
|
| 124 |
+
|
| 125 |
+
running_loss = 0.0
|
| 126 |
+
running_correct = 0
|
| 127 |
+
running_samples = 0
|
| 128 |
+
|
| 129 |
+
pbar = tqdm(loader, desc="Val", leave=False)
|
| 130 |
+
|
| 131 |
+
for images, labels in pbar:
|
| 132 |
+
images = images.to(device, non_blocking=True)
|
| 133 |
+
labels = labels.to(device, non_blocking=True)
|
| 134 |
+
|
| 135 |
+
with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
|
| 136 |
+
logits = model(images)
|
| 137 |
+
loss = criterion(logits, labels)
|
| 138 |
+
|
| 139 |
+
batch_size = labels.size(0)
|
| 140 |
+
running_loss += loss.item() * batch_size
|
| 141 |
+
running_correct += (logits.argmax(dim=1) == labels).sum().item()
|
| 142 |
+
running_samples += batch_size
|
| 143 |
+
|
| 144 |
+
pbar.set_postfix(
|
| 145 |
+
loss=f"{running_loss / running_samples:.4f}",
|
| 146 |
+
acc=f"{running_correct / running_samples:.4f}",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
return running_loss / running_samples, running_correct / running_samples
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def save_checkpoint(path, model, optimizer, epoch, best_val_loss, args):
|
| 153 |
+
path = Path(path)
|
| 154 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 155 |
+
|
| 156 |
+
torch.save(
|
| 157 |
+
{
|
| 158 |
+
"epoch": epoch,
|
| 159 |
+
"model": model.state_dict(),
|
| 160 |
+
"optimizer": optimizer.state_dict(),
|
| 161 |
+
"best_val_loss": best_val_loss,
|
| 162 |
+
"args": vars(args),
|
| 163 |
+
},
|
| 164 |
+
path,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def main(args):
|
| 169 |
+
set_seed(args.seed)
|
| 170 |
+
|
| 171 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 172 |
+
use_amp = args.amp and device.type == "cuda"
|
| 173 |
+
|
| 174 |
+
train_dataset = AREDSDataset(
|
| 175 |
+
args.train_csv,
|
| 176 |
+
args.image_root,
|
| 177 |
+
args.task,
|
| 178 |
+
transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
val_dataset = AREDSDataset(
|
| 182 |
+
args.valid_csv,
|
| 183 |
+
args.image_root,
|
| 184 |
+
args.task,
|
| 185 |
+
transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
train_loader = DataLoader(
|
| 189 |
+
train_dataset,
|
| 190 |
+
batch_size=args.batch_size,
|
| 191 |
+
shuffle=True,
|
| 192 |
+
num_workers=args.num_workers,
|
| 193 |
+
pin_memory=device.type == "cuda",
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
val_loader = DataLoader(
|
| 197 |
+
val_dataset,
|
| 198 |
+
batch_size=args.batch_size,
|
| 199 |
+
shuffle=False,
|
| 200 |
+
num_workers=args.num_workers,
|
| 201 |
+
pin_memory=device.type == "cuda",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
model = DeepSeeNet(
|
| 205 |
+
n_classes=N_CLASSES[args.task],
|
| 206 |
+
backbone=args.backbone,
|
| 207 |
+
pretrained=not args.no_pretrained,
|
| 208 |
+
freeze_backbone=args.freeze_backbone,
|
| 209 |
+
).to(device)
|
| 210 |
+
|
| 211 |
+
class_weights = None
|
| 212 |
+
if not args.no_class_weights:
|
| 213 |
+
class_weights = get_class_weights(train_dataset, args.task, device)
|
| 214 |
+
|
| 215 |
+
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
|
| 216 |
+
|
| 217 |
+
optimizer = torch.optim.AdamW(
|
| 218 |
+
model.parameters(),
|
| 219 |
+
lr=args.lr,
|
| 220 |
+
weight_decay=args.weight_decay,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
scheduler = build_scheduler(optimizer, args)
|
| 224 |
+
scaler = torch.amp.GradScaler("cuda") if use_amp else None
|
| 225 |
+
|
| 226 |
+
wandb = None
|
| 227 |
+
if args.wandb:
|
| 228 |
+
import wandb
|
| 229 |
+
|
| 230 |
+
wandb.init(project=args.wandb_project, config=vars(args))
|
| 231 |
+
|
| 232 |
+
output_dir = Path(args.output_dir)
|
| 233 |
+
best_val_loss = float("inf")
|
| 234 |
+
|
| 235 |
+
print(f"Device: {device}")
|
| 236 |
+
print(f"Task: {args.task}")
|
| 237 |
+
print(f"Train samples: {len(train_dataset)}")
|
| 238 |
+
print(f"Val samples: {len(val_dataset)}")
|
| 239 |
+
print(f"Image size: {args.image_size}")
|
| 240 |
+
print(f"Batch size: {args.batch_size}")
|
| 241 |
+
print(f"Pretrained: {not args.no_pretrained}")
|
| 242 |
+
if class_weights is not None:
|
| 243 |
+
print(f"Class weights: {class_weights.detach().cpu().tolist()}")
|
| 244 |
+
|
| 245 |
+
for epoch in range(1, args.epochs + 1):
|
| 246 |
+
print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
|
| 247 |
+
|
| 248 |
+
train_loss, train_acc = train_one_epoch(
|
| 249 |
+
model=model,
|
| 250 |
+
loader=train_loader,
|
| 251 |
+
optimizer=optimizer,
|
| 252 |
+
scaler=scaler,
|
| 253 |
+
criterion=criterion,
|
| 254 |
+
device=device,
|
| 255 |
+
use_amp=args.amp,
|
| 256 |
+
grad_clip=args.grad_clip,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
val_loss, val_acc = evaluate(
|
| 260 |
+
model=model,
|
| 261 |
+
loader=val_loader,
|
| 262 |
+
criterion=torch.nn.CrossEntropyLoss(),
|
| 263 |
+
device=device,
|
| 264 |
+
use_amp=args.amp,
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 268 |
+
|
| 269 |
+
print(
|
| 270 |
+
f"train_loss={train_loss:.4f} "
|
| 271 |
+
f"train_acc={train_acc:.4f} "
|
| 272 |
+
f"val_loss={val_loss:.4f} "
|
| 273 |
+
f"val_acc={val_acc:.4f} "
|
| 274 |
+
f"lr={lr:.2e}"
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if wandb is not None:
|
| 278 |
+
wandb.log(
|
| 279 |
+
{
|
| 280 |
+
"epoch": epoch,
|
| 281 |
+
"lr": lr,
|
| 282 |
+
"train_loss": train_loss,
|
| 283 |
+
"train_acc": train_acc,
|
| 284 |
+
"val_loss": val_loss,
|
| 285 |
+
"val_acc": val_acc,
|
| 286 |
+
}
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
if val_loss < best_val_loss:
|
| 290 |
+
best_val_loss = val_loss
|
| 291 |
+
save_checkpoint(
|
| 292 |
+
output_dir / "best.pt",
|
| 293 |
+
model,
|
| 294 |
+
optimizer,
|
| 295 |
+
epoch,
|
| 296 |
+
best_val_loss,
|
| 297 |
+
args,
|
| 298 |
+
)
|
| 299 |
+
print(f"Saved best checkpoint: val_loss={best_val_loss:.4f}")
|
| 300 |
+
|
| 301 |
+
if args.save_every > 0 and epoch % args.save_every == 0:
|
| 302 |
+
save_checkpoint(
|
| 303 |
+
output_dir / f"epoch_{epoch:03d}.pt",
|
| 304 |
+
model,
|
| 305 |
+
optimizer,
|
| 306 |
+
epoch,
|
| 307 |
+
best_val_loss,
|
| 308 |
+
args,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if scheduler is not None:
|
| 312 |
+
scheduler.step()
|
| 313 |
+
|
| 314 |
+
save_checkpoint(
|
| 315 |
+
output_dir / "last.pt",
|
| 316 |
+
model,
|
| 317 |
+
optimizer,
|
| 318 |
+
args.epochs,
|
| 319 |
+
best_val_loss,
|
| 320 |
+
args,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
print("Training complete.")
|
| 324 |
+
print(f"Best val loss: {best_val_loss:.4f}")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def parse_args():
|
| 328 |
+
parser = argparse.ArgumentParser(description="Train DeepSeeNet.")
|
| 329 |
+
|
| 330 |
+
parser.add_argument("--train-csv", required=True)
|
| 331 |
+
parser.add_argument("--valid-csv", required=True)
|
| 332 |
+
parser.add_argument("--image-root", required=True)
|
| 333 |
+
parser.add_argument("--task", required=True, type=str.upper, choices=N_CLASSES)
|
| 334 |
+
parser.add_argument("--output-dir", default="checkpoints/deepseenet")
|
| 335 |
+
|
| 336 |
+
parser.add_argument("--backbone", default="inception_v3")
|
| 337 |
+
parser.add_argument("--image-size", type=int, default=1024)
|
| 338 |
+
parser.add_argument("--epochs", type=int, default=20)
|
| 339 |
+
parser.add_argument("--batch-size", type=int, default=32)
|
| 340 |
+
parser.add_argument("--num-workers", type=int, default=4)
|
| 341 |
+
|
| 342 |
+
parser.add_argument("--lr", type=float, default=1e-4)
|
| 343 |
+
parser.add_argument("--weight-decay", type=float, default=1e-4)
|
| 344 |
+
parser.add_argument("--no-pretrained", action="store_true")
|
| 345 |
+
parser.add_argument("--freeze-backbone", action="store_true")
|
| 346 |
+
parser.add_argument("--no-class-weights", action="store_true")
|
| 347 |
+
|
| 348 |
+
parser.add_argument("--scheduler", choices=("none", "cosine", "step"), default="cosine")
|
| 349 |
+
parser.add_argument("--min-lr", type=float, default=1e-6)
|
| 350 |
+
parser.add_argument("--step-size", type=int, default=5)
|
| 351 |
+
parser.add_argument("--gamma", type=float, default=0.5)
|
| 352 |
+
|
| 353 |
+
parser.add_argument("--amp", action="store_true")
|
| 354 |
+
parser.add_argument("--grad-clip", type=float, default=0.0)
|
| 355 |
+
parser.add_argument("--save-every", type=int, default=0)
|
| 356 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 357 |
+
|
| 358 |
+
parser.add_argument("--wandb", action="store_true")
|
| 359 |
+
parser.add_argument("--wandb-project", default="deepseenet")
|
| 360 |
+
|
| 361 |
+
return parser.parse_args()
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
if __name__ == "__main__":
|
| 365 |
+
args = parse_args()
|
| 366 |
+
main(args)
|