Spaces:
Build error
Build error
Setup instructions, docstrings
Browse files- README.md +45 -1
- app.py +35 -3
- environment.yml +247 -18
- labelmap.py +2 -0
- requirements.txt +74 -0
- train.py +69 -15
README.md
CHANGED
|
@@ -1 +1,45 @@
|
|
| 1 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diabetic Retinopathy Detection with AI
|
| 2 |
+
|
| 3 |
+
## Setup
|
| 4 |
+
|
| 5 |
+
### Gradio app environment
|
| 6 |
+
|
| 7 |
+
Install from pip requirements file:
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
conda create -n retinopathy_app python=3.10
|
| 11 |
+
conda activate retinopathy_app
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
python app.py
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Install manually:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
pip install pytorch --index-url https://download.pytorch.org/whl/cpu
|
| 20 |
+
pip install gradio
|
| 21 |
+
pip install transformers
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
### Training environment
|
| 25 |
+
|
| 26 |
+
Create conda environment from YAML:
|
| 27 |
+
```bash
|
| 28 |
+
mamba env create -n retinopathy_train -f environment.yml
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
Download the data from [Kaggle](https://www.kaggle.com/competitions/diabetic-retinopathy-detection/data) or use kaggle API:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
pip install kaggle
|
| 35 |
+
kaggle competitions download -c diabetic-retinopathy-detection
|
| 36 |
+
mkdir retinopathy_data/
|
| 37 |
+
unzip diabetic-retinopathy-detection.zip -d retinopathy_data/
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
Launch training:
|
| 41 |
+
```bash
|
| 42 |
+
conda activate retinopathy_train
|
| 43 |
+
python train.py
|
| 44 |
+
```
|
| 45 |
+
The trained model will be put into `lightning_logs/`.
|
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import gradio as gr
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
-
from typing import Tuple, Optional, Dict, List
|
| 6 |
import glob
|
| 7 |
from collections import defaultdict
|
| 8 |
|
|
@@ -13,7 +13,10 @@ from labelmap import DR_LABELMAP
|
|
| 13 |
|
| 14 |
|
| 15 |
class App:
|
|
|
|
|
|
|
| 16 |
def __init__(self) -> None:
|
|
|
|
| 17 |
|
| 18 |
ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
|
| 19 |
|
|
@@ -41,7 +44,7 @@ class App:
|
|
| 41 |
output = gr.Label(num_top_classes=len(DR_LABELMAP),
|
| 42 |
label="Retinopathy level prediction")
|
| 43 |
with gr.Column(scale=4):
|
| 44 |
-
gr.Markdown(":
|
| 46 |
with gr.Column(scale=9, min_width=100):
|
| 47 |
image = gr.Image(label="Retina scan")
|
|
@@ -66,9 +69,19 @@ class App:
|
|
| 66 |
self.ui = ui
|
| 67 |
|
| 68 |
def launch(self) -> None:
|
|
|
|
| 69 |
self.ui.queue().launch(share=True)
|
| 70 |
|
| 71 |
-
def predict(self, image: Optional[np.ndarray]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
if image is None:
|
| 73 |
return dict()
|
| 74 |
cls_name, prob, probs = self._infer(image)
|
|
@@ -79,6 +92,19 @@ class App:
|
|
| 79 |
return probs_dict
|
| 80 |
|
| 81 |
def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
assert isinstance(self.model, ResNetForImageClassification)
|
| 83 |
|
| 84 |
inputs = self.image_processor(image_chw, return_tensors="pt")
|
|
@@ -98,6 +124,11 @@ class App:
|
|
| 98 |
|
| 99 |
@staticmethod
|
| 100 |
def _load_example_lists() -> Dict[int, List[str]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
|
| 103 |
|
|
@@ -115,6 +146,7 @@ class App:
|
|
| 115 |
|
| 116 |
|
| 117 |
def main():
|
|
|
|
| 118 |
app = App()
|
| 119 |
app.launch()
|
| 120 |
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import numpy as np
|
| 4 |
import torch
|
| 5 |
+
from typing import Tuple, Optional, Dict, List, Dict
|
| 6 |
import glob
|
| 7 |
from collections import defaultdict
|
| 8 |
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class App:
|
| 16 |
+
""" Demonstration of the Diabetic Retinopathy model as a Gradio app. """
|
| 17 |
+
|
| 18 |
def __init__(self) -> None:
|
| 19 |
+
""" Constructor. """
|
| 20 |
|
| 21 |
ckpt_name = "2023-12-24_20-02-18_30345221_V100_x4_resnet34/"
|
| 22 |
|
|
|
|
| 44 |
output = gr.Label(num_top_classes=len(DR_LABELMAP),
|
| 45 |
label="Retinopathy level prediction")
|
| 46 |
with gr.Column(scale=4):
|
| 47 |
+
gr.Markdown("")
|
| 48 |
with gr.Row():
|
| 49 |
with gr.Column(scale=9, min_width=100):
|
| 50 |
image = gr.Image(label="Retina scan")
|
|
|
|
| 69 |
self.ui = ui
|
| 70 |
|
| 71 |
def launch(self) -> None:
|
| 72 |
+
""" Launch the application, blocking. """
|
| 73 |
self.ui.queue().launch(share=True)
|
| 74 |
|
| 75 |
+
def predict(self, image: Optional[np.ndarray]) -> Dict[str, float]:
|
| 76 |
+
""" Gradio callback for pricessing of an image.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
image (Optional[np.ndarray]): Provided image.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Dict[str, float]: Label-compatible dict.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
if image is None:
|
| 86 |
return dict()
|
| 87 |
cls_name, prob, probs = self._infer(image)
|
|
|
|
| 92 |
return probs_dict
|
| 93 |
|
| 94 |
def _infer(self, image_chw: np.ndarray) -> Tuple[str, float, np.ndarray]:
|
| 95 |
+
""" Low-level method to perform neural network inference.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
image_chw (np.ndarray): Provided image.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Tuple[str, float, np.ndarray]:
|
| 102 |
+
- Most probable class name
|
| 103 |
+
- Probability of the most probable class name.
|
| 104 |
+
- Probablilities of all classes in the order of
|
| 105 |
+
being listed in the label map.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
assert isinstance(self.model, ResNetForImageClassification)
|
| 109 |
|
| 110 |
inputs = self.image_processor(image_chw, return_tensors="pt")
|
|
|
|
| 124 |
|
| 125 |
@staticmethod
|
| 126 |
def _load_example_lists() -> Dict[int, List[str]]:
|
| 127 |
+
""" Load example retina images from disk.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Dict[int, List[str]]: Dictionary of cls_id -> list of images paths.
|
| 131 |
+
"""
|
| 132 |
|
| 133 |
example_flat_list = glob.glob("demo_data/train/**/*.jpeg")
|
| 134 |
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
def main():
|
| 149 |
+
""" App entry point. """
|
| 150 |
app = App()
|
| 151 |
app.launch()
|
| 152 |
|
environment.yml
CHANGED
|
@@ -1,22 +1,251 @@
|
|
| 1 |
-
name:
|
| 2 |
-
|
| 3 |
channels:
|
| 4 |
-
-
|
| 5 |
-
- nvidia
|
| 6 |
- conda-forge
|
| 7 |
- defaults
|
| 8 |
-
|
| 9 |
dependencies:
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
-
|
| 14 |
-
-
|
| 15 |
-
-
|
| 16 |
-
-
|
| 17 |
-
-
|
| 18 |
-
-
|
| 19 |
-
-
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: retinopathy
|
|
|
|
| 2 |
channels:
|
| 3 |
+
- anaconda
|
|
|
|
| 4 |
- conda-forge
|
| 5 |
- defaults
|
|
|
|
| 6 |
dependencies:
|
| 7 |
+
- _libgcc_mutex=0.1=main
|
| 8 |
+
- _openmp_mutex=5.1=1_gnu
|
| 9 |
+
- aiofiles=22.1.0=py310h06a4308_0
|
| 10 |
+
- aiosqlite=0.18.0=py310h06a4308_0
|
| 11 |
+
- argon2-cffi=21.3.0=pyhd3eb1b0_0
|
| 12 |
+
- argon2-cffi-bindings=21.2.0=py310h7f8727e_0
|
| 13 |
+
- asttokens=2.0.5=pyhd3eb1b0_0
|
| 14 |
+
- attrs=23.1.0=py310h06a4308_0
|
| 15 |
+
- babel=2.11.0=py310h06a4308_0
|
| 16 |
+
- backcall=0.2.0=pyhd3eb1b0_0
|
| 17 |
+
- beautifulsoup4=4.12.2=py310h06a4308_0
|
| 18 |
+
- bleach=4.1.0=pyhd3eb1b0_0
|
| 19 |
+
- brotli-python=1.0.9=py310h6a678d5_7
|
| 20 |
+
- bzip2=1.0.8=h7b6447c_0
|
| 21 |
+
- ca-certificates=2023.08.22=h06a4308_0
|
| 22 |
+
- certifi=2023.11.17=py310h06a4308_0
|
| 23 |
+
- cffi=1.16.0=py310h5eee18b_0
|
| 24 |
+
- comm=0.1.2=py310h06a4308_0
|
| 25 |
+
- cryptography=41.0.3=py310hdda0065_0
|
| 26 |
+
- debugpy=1.6.7=py310h6a678d5_0
|
| 27 |
+
- decorator=5.1.1=pyhd3eb1b0_0
|
| 28 |
+
- defusedxml=0.7.1=pyhd3eb1b0_0
|
| 29 |
+
- executing=0.8.3=pyhd3eb1b0_0
|
| 30 |
+
- ipykernel=6.25.0=py310h2f386ee_0
|
| 31 |
+
- ipython=8.15.0=py310h06a4308_0
|
| 32 |
+
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
| 33 |
+
- jedi=0.18.1=py310h06a4308_1
|
| 34 |
+
- jinja2=3.1.2=py310h06a4308_0
|
| 35 |
+
- json5=0.9.6=pyhd3eb1b0_0
|
| 36 |
+
- jsonschema=4.19.2=py310h06a4308_0
|
| 37 |
+
- jsonschema-specifications=2023.7.1=py310h06a4308_0
|
| 38 |
+
- jupyter_client=8.6.0=py310h06a4308_0
|
| 39 |
+
- jupyter_core=5.5.0=py310h06a4308_0
|
| 40 |
+
- jupyter_events=0.8.0=py310h06a4308_0
|
| 41 |
+
- jupyter_server=2.10.0=py310h06a4308_0
|
| 42 |
+
- jupyter_server_fileid=0.9.0=py310h06a4308_0
|
| 43 |
+
- jupyter_server_terminals=0.4.4=py310h06a4308_1
|
| 44 |
+
- jupyter_server_ydoc=0.8.0=py310h06a4308_1
|
| 45 |
+
- jupyter_ydoc=0.2.4=py310h06a4308_0
|
| 46 |
+
- jupyterlab=3.6.3=py310h06a4308_0
|
| 47 |
+
- jupyterlab_pygments=0.2.2=py310h06a4308_0
|
| 48 |
+
- jupyterlab_server=2.25.1=py310h06a4308_0
|
| 49 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
| 50 |
+
- libffi=3.4.4=h6a678d5_0
|
| 51 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 52 |
+
- libgomp=11.2.0=h1234567_1
|
| 53 |
+
- libsodium=1.0.18=h7b6447c_0
|
| 54 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 55 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 56 |
+
- matplotlib-inline=0.1.6=py310h06a4308_0
|
| 57 |
+
- mistune=2.0.4=py310h06a4308_0
|
| 58 |
+
- nbclassic=1.0.0=py310h06a4308_0
|
| 59 |
+
- nbclient=0.8.0=py310h06a4308_0
|
| 60 |
+
- nbconvert=7.10.0=py310h06a4308_0
|
| 61 |
+
- nbformat=5.9.2=py310h06a4308_0
|
| 62 |
+
- ncurses=6.4=h6a678d5_0
|
| 63 |
+
- nest-asyncio=1.5.6=py310h06a4308_0
|
| 64 |
+
- notebook=6.5.4=py310h06a4308_0
|
| 65 |
+
- notebook-shim=0.2.3=py310h06a4308_0
|
| 66 |
+
- openssl=3.0.12=h7f8727e_0
|
| 67 |
+
- overrides=7.4.0=py310h06a4308_0
|
| 68 |
+
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
| 69 |
+
- parso=0.8.3=pyhd3eb1b0_0
|
| 70 |
+
- pexpect=4.8.0=pyhd3eb1b0_3
|
| 71 |
+
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
| 72 |
+
- platformdirs=3.10.0=py310h06a4308_0
|
| 73 |
+
- prometheus_client=0.14.1=py310h06a4308_0
|
| 74 |
+
- prompt-toolkit=3.0.36=py310h06a4308_0
|
| 75 |
+
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
| 76 |
+
- pure_eval=0.2.2=pyhd3eb1b0_0
|
| 77 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
| 78 |
+
- pyopenssl=23.2.0=py310h06a4308_0
|
| 79 |
+
- pysocks=1.7.1=py310h06a4308_0
|
| 80 |
+
- python=3.10.13=h955ad1f_0
|
| 81 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 82 |
+
- python-fastjsonschema=2.16.2=py310h06a4308_0
|
| 83 |
+
- python-json-logger=2.0.7=py310h06a4308_0
|
| 84 |
+
- pytz=2023.3.post1=py310h06a4308_0
|
| 85 |
+
- pyyaml=6.0.1=py310h5eee18b_0
|
| 86 |
+
- pyzmq=25.1.0=py310h6a678d5_0
|
| 87 |
+
- readline=8.2=h5eee18b_0
|
| 88 |
+
- referencing=0.30.2=py310h06a4308_0
|
| 89 |
+
- requests=2.31.0=py310h06a4308_0
|
| 90 |
+
- rfc3339-validator=0.1.4=py310h06a4308_0
|
| 91 |
+
- rfc3986-validator=0.1.1=py310h06a4308_0
|
| 92 |
+
- rpds-py=0.10.6=py310hb02cf49_0
|
| 93 |
+
- send2trash=1.8.2=py310h06a4308_0
|
| 94 |
+
- setuptools=68.0.0=py310h06a4308_0
|
| 95 |
+
- six=1.16.0=pyhd3eb1b0_1
|
| 96 |
+
- soupsieve=2.5=py310h06a4308_0
|
| 97 |
+
- sqlite=3.41.2=h5eee18b_0
|
| 98 |
+
- stack_data=0.2.0=pyhd3eb1b0_0
|
| 99 |
+
- terminado=0.17.1=py310h06a4308_0
|
| 100 |
+
- tinycss2=1.2.1=py310h06a4308_0
|
| 101 |
+
- tk=8.6.12=h1ccaba5_0
|
| 102 |
+
- tomli=2.0.1=py310h06a4308_0
|
| 103 |
+
- tornado=6.3.3=py310h5eee18b_0
|
| 104 |
+
- webencodings=0.5.1=py310h06a4308_1
|
| 105 |
+
- wheel=0.41.2=py310h06a4308_0
|
| 106 |
+
- y-py=0.5.9=py310h52d8a92_0
|
| 107 |
+
- yaml=0.2.5=h7b6447c_0
|
| 108 |
+
- ypy-websocket=0.8.2=py310h06a4308_0
|
| 109 |
+
- zeromq=4.3.4=h2531618_0
|
| 110 |
+
- zlib=1.2.13=h5eee18b_0
|
| 111 |
+
- pip:
|
| 112 |
+
- absl-py==2.0.0
|
| 113 |
+
- aiobotocore==2.8.0
|
| 114 |
+
- aiohttp==3.9.1
|
| 115 |
+
- aioitertools==0.11.0
|
| 116 |
+
- aiosignal==1.3.1
|
| 117 |
+
- altair==5.2.0
|
| 118 |
+
- annotated-types==0.6.0
|
| 119 |
+
- antlr4-python3-runtime==4.9.3
|
| 120 |
+
- anyio==3.7.1
|
| 121 |
+
- arrow==1.3.0
|
| 122 |
+
- async-timeout==4.0.3
|
| 123 |
+
- backoff==2.2.1
|
| 124 |
+
- bitsandbytes==0.41.3
|
| 125 |
+
- blessed==1.20.0
|
| 126 |
+
- boto3==1.33.1
|
| 127 |
+
- botocore==1.33.1
|
| 128 |
+
- cachetools==5.3.2
|
| 129 |
+
- chardet==5.2.0
|
| 130 |
+
- charset-normalizer==3.3.2
|
| 131 |
+
- click==8.1.7
|
| 132 |
+
- colorama==0.4.6
|
| 133 |
+
- contourpy==1.2.0
|
| 134 |
+
- croniter==1.4.1
|
| 135 |
+
- cycler==0.12.1
|
| 136 |
+
- dateutils==0.6.12
|
| 137 |
+
- deepdiff==6.7.1
|
| 138 |
+
- docker==6.1.3
|
| 139 |
+
- docstring-parser==0.15
|
| 140 |
+
- exceptiongroup==1.2.0
|
| 141 |
+
- fastapi==0.104.1
|
| 142 |
+
- ffmpy==0.3.1
|
| 143 |
+
- filelock==3.13.1
|
| 144 |
+
- fonttools==4.46.0
|
| 145 |
+
- frozenlist==1.4.0
|
| 146 |
+
- fsspec==2023.12.1
|
| 147 |
+
- google-auth==2.25.1
|
| 148 |
+
- google-auth-oauthlib==1.1.0
|
| 149 |
+
- gradio==4.12.0
|
| 150 |
+
- gradio-client==0.8.0
|
| 151 |
+
- grpcio==1.59.3
|
| 152 |
+
- h11==0.14.0
|
| 153 |
+
- httpcore==1.0.2
|
| 154 |
+
- httpx==0.26.0
|
| 155 |
+
- huggingface-hub==0.19.4
|
| 156 |
+
- hydra-core==1.3.2
|
| 157 |
+
- idna==3.6
|
| 158 |
+
- importlib-resources==6.1.1
|
| 159 |
+
- inquirer==3.1.4
|
| 160 |
+
- itsdangerous==2.1.2
|
| 161 |
+
- jmespath==1.0.1
|
| 162 |
+
- jsonargparse==4.27.1
|
| 163 |
+
- kiwisolver==1.4.5
|
| 164 |
+
- lightning==2.1.2
|
| 165 |
+
- lightning-api-access==0.0.5
|
| 166 |
+
- lightning-cloud==0.5.52
|
| 167 |
+
- lightning-fabric==2.1.2
|
| 168 |
+
- lightning-utilities==0.10.0
|
| 169 |
+
- line-profiler==4.1.2
|
| 170 |
+
- markdown==3.5.1
|
| 171 |
+
- markdown-it-py==3.0.0
|
| 172 |
+
- markupsafe==2.1.3
|
| 173 |
+
- matplotlib==3.8.2
|
| 174 |
+
- mdurl==0.1.2
|
| 175 |
+
- mpmath==1.3.0
|
| 176 |
+
- multidict==6.0.4
|
| 177 |
+
- networkx==3.2.1
|
| 178 |
+
- numpy==1.26.2
|
| 179 |
+
- nvidia-cublas-cu12==12.1.3.1
|
| 180 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
| 181 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
| 182 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
| 183 |
+
- nvidia-cudnn-cu12==8.9.2.26
|
| 184 |
+
- nvidia-cufft-cu12==11.0.2.54
|
| 185 |
+
- nvidia-curand-cu12==10.3.2.106
|
| 186 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
| 187 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
| 188 |
+
- nvidia-nccl-cu12==2.18.1
|
| 189 |
+
- nvidia-nvjitlink-cu12==12.3.101
|
| 190 |
+
- nvidia-nvtx-cu12==12.1.105
|
| 191 |
+
- oauthlib==3.2.2
|
| 192 |
+
- omegaconf==2.3.0
|
| 193 |
+
- ordered-set==4.1.0
|
| 194 |
+
- orjson==3.9.10
|
| 195 |
+
- packaging==23.2
|
| 196 |
+
- pandas==2.1.3
|
| 197 |
+
- pillow==10.1.0
|
| 198 |
+
- protobuf==4.23.4
|
| 199 |
+
- psutil==5.9.6
|
| 200 |
+
- pyasn1==0.5.1
|
| 201 |
+
- pyasn1-modules==0.3.0
|
| 202 |
+
- pydantic==2.5.2
|
| 203 |
+
- pydantic-core==2.14.5
|
| 204 |
+
- pydub==0.25.1
|
| 205 |
+
- pygments==2.17.2
|
| 206 |
+
- pyjwt==2.8.0
|
| 207 |
+
- pyparsing==3.1.1
|
| 208 |
+
- python-editor==1.0.4
|
| 209 |
+
- python-multipart==0.0.6
|
| 210 |
+
- pytorch-lightning==2.1.2
|
| 211 |
+
- readchar==4.0.5
|
| 212 |
+
- redis==5.0.1
|
| 213 |
+
- regex==2023.10.3
|
| 214 |
+
- requests-oauthlib==1.3.1
|
| 215 |
+
- rich==13.7.0
|
| 216 |
+
- rsa==4.9
|
| 217 |
+
- s3fs==2023.12.1
|
| 218 |
+
- s3transfer==0.8.0
|
| 219 |
+
- safetensors==0.4.1
|
| 220 |
+
- semantic-version==2.10.0
|
| 221 |
+
- shellingham==1.5.4
|
| 222 |
+
- sniffio==1.3.0
|
| 223 |
+
- starlette==0.27.0
|
| 224 |
+
- starsessions==1.3.0
|
| 225 |
+
- sympy==1.12
|
| 226 |
+
- tensorboard==2.15.1
|
| 227 |
+
- tensorboard-data-server==0.7.2
|
| 228 |
+
- tensorboardx==2.6.2.2
|
| 229 |
+
- tokenizers==0.15.0
|
| 230 |
+
- tomlkit==0.12.0
|
| 231 |
+
- toolz==0.12.0
|
| 232 |
+
- torch==2.1.1
|
| 233 |
+
- torchmetrics==1.2.1
|
| 234 |
+
- torchvision==0.16.1
|
| 235 |
+
- tqdm==4.66.1
|
| 236 |
+
- traitlets==5.14.0
|
| 237 |
+
- transformers==4.35.2
|
| 238 |
+
- triton==2.1.0
|
| 239 |
+
- typer==0.9.0
|
| 240 |
+
- types-python-dateutil==2.8.19.14
|
| 241 |
+
- typeshed-client==2.4.0
|
| 242 |
+
- typing-extensions==4.9.0
|
| 243 |
+
- tzdata==2023.3
|
| 244 |
+
- urllib3==2.0.7
|
| 245 |
+
- uvicorn==0.24.0.post1
|
| 246 |
+
- wcwidth==0.2.12
|
| 247 |
+
- websocket-client==1.7.0
|
| 248 |
+
- websockets==11.0.3
|
| 249 |
+
- werkzeug==3.0.1
|
| 250 |
+
- wrapt==1.16.0
|
| 251 |
+
- yarl==1.9.4
|
labelmap.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
DR_LABELMAP = {
|
| 2 |
0: 'No DR',
|
| 3 |
1: 'Mild',
|
|
|
|
| 1 |
+
""" Mapping of class IDs to lables. """
|
| 2 |
+
|
| 3 |
DR_LABELMAP = {
|
| 4 |
0: 'No DR',
|
| 5 |
1: 'Mild',
|
requirements.txt
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 2 |
+
aiofiles==23.2.1
|
| 3 |
+
altair==5.2.0
|
| 4 |
+
annotated-types==0.6.0
|
| 5 |
+
anyio==4.2.0
|
| 6 |
+
attrs==23.2.0
|
| 7 |
+
certifi==2023.11.17
|
| 8 |
+
charset-normalizer==3.3.2
|
| 9 |
+
click==8.1.7
|
| 10 |
+
colorama==0.4.6
|
| 11 |
+
contourpy==1.2.0
|
| 12 |
+
cycler==0.12.1
|
| 13 |
+
exceptiongroup==1.2.0
|
| 14 |
+
fastapi==0.108.0
|
| 15 |
+
ffmpy==0.3.1
|
| 16 |
+
filelock==3.13.1
|
| 17 |
+
fonttools==4.47.0
|
| 18 |
+
fsspec==2023.12.2
|
| 19 |
+
gradio==4.13.0
|
| 20 |
+
gradio_client==0.8.0
|
| 21 |
+
h11==0.14.0
|
| 22 |
+
httpcore==1.0.2
|
| 23 |
+
httpx==0.26.0
|
| 24 |
+
huggingface-hub==0.20.2
|
| 25 |
+
idna==3.6
|
| 26 |
+
importlib-resources==6.1.1
|
| 27 |
+
Jinja2==3.1.2
|
| 28 |
+
jsonschema==4.20.0
|
| 29 |
+
jsonschema-specifications==2023.12.1
|
| 30 |
+
kiwisolver==1.4.5
|
| 31 |
+
markdown-it-py==3.0.0
|
| 32 |
+
MarkupSafe==2.1.3
|
| 33 |
+
matplotlib==3.8.2
|
| 34 |
+
mdurl==0.1.2
|
| 35 |
+
mpmath==1.3.0
|
| 36 |
+
networkx==3.2.1
|
| 37 |
+
numpy==1.26.3
|
| 38 |
+
orjson==3.9.10
|
| 39 |
+
packaging==23.2
|
| 40 |
+
pandas==2.1.4
|
| 41 |
+
pillow==10.2.0
|
| 42 |
+
pydantic==2.5.3
|
| 43 |
+
pydantic_core==2.14.6
|
| 44 |
+
pydub==0.25.1
|
| 45 |
+
Pygments==2.17.2
|
| 46 |
+
pyparsing==3.1.1
|
| 47 |
+
python-dateutil==2.8.2
|
| 48 |
+
python-multipart==0.0.6
|
| 49 |
+
pytz==2023.3.post1
|
| 50 |
+
PyYAML==6.0.1
|
| 51 |
+
referencing==0.32.1
|
| 52 |
+
regex==2023.12.25
|
| 53 |
+
requests==2.31.0
|
| 54 |
+
rich==13.7.0
|
| 55 |
+
rpds-py==0.16.2
|
| 56 |
+
safetensors==0.4.1
|
| 57 |
+
semantic-version==2.10.0
|
| 58 |
+
shellingham==1.5.4
|
| 59 |
+
six==1.16.0
|
| 60 |
+
sniffio==1.3.0
|
| 61 |
+
starlette==0.32.0.post1
|
| 62 |
+
sympy==1.12
|
| 63 |
+
tokenizers==0.15.0
|
| 64 |
+
tomlkit==0.12.0
|
| 65 |
+
toolz==0.12.0
|
| 66 |
+
torch==2.1.2+cpu
|
| 67 |
+
tqdm==4.66.1
|
| 68 |
+
transformers==4.36.2
|
| 69 |
+
typer==0.9.0
|
| 70 |
+
typing_extensions==4.9.0
|
| 71 |
+
tzdata==2023.4
|
| 72 |
+
urllib3==2.1.0
|
| 73 |
+
uvicorn==0.25.0
|
| 74 |
+
websockets==11.0.3
|
train.py
CHANGED
|
@@ -49,7 +49,15 @@ DataRecord = Tuple[Image.Image, int]
|
|
| 49 |
|
| 50 |
|
| 51 |
class RetinopathyDataset(data.Dataset[DataRecord]):
|
|
|
|
|
|
|
| 52 |
def __init__(self, data_path: str) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
super().__init__()
|
| 54 |
|
| 55 |
self.data_path = data_path
|
|
@@ -88,21 +96,25 @@ class RetinopathyDataset(data.Dataset[DataRecord]):
|
|
| 88 |
return img_path
|
| 89 |
|
| 90 |
|
|
|
|
| 91 |
class Purpose(Enum):
|
| 92 |
Train = 0
|
| 93 |
Val = 1
|
| 94 |
|
| 95 |
-
|
| 96 |
FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
|
| 97 |
Callable[..., torch.Tensor]]
|
| 98 |
|
|
|
|
| 99 |
TensorRecord = Tuple[torch.Tensor, torch.Tensor]
|
| 100 |
|
| 101 |
-
def normalize(arr: np.ndarray) -> np.ndarray:
|
| 102 |
-
return arr / np.sum(arr)
|
| 103 |
-
|
| 104 |
|
| 105 |
class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
def __init__(self, dataset: RetinopathyDataset,
|
| 107 |
indices: np.ndarray,
|
| 108 |
purpose: Purpose,
|
|
@@ -111,7 +123,24 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
| 111 |
stratify_classes: bool = False,
|
| 112 |
use_log_frequencies: bool = False,
|
| 113 |
):
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
self.dataset = dataset
|
| 116 |
self.indices = indices
|
| 117 |
self.purpose = purpose
|
|
@@ -124,22 +153,26 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
| 124 |
self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
|
| 125 |
self.frequencies: Optional[Dict[int, float]] = None
|
| 126 |
if self.stratify_classes:
|
| 127 |
-
self.
|
| 128 |
if self.use_log_frequencies:
|
| 129 |
-
self.
|
| 130 |
|
| 131 |
-
def
|
| 132 |
assert self.per_class_indices is not None
|
| 133 |
counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
|
| 134 |
counts = np.array(list(counts_dict.values()))
|
| 135 |
-
counts_nrm =
|
| 136 |
temperature = 50.0 # > 1 to even-out frequencies
|
| 137 |
-
freqs =
|
| 138 |
self.frequencies = {k: freq.item() for k, freq
|
| 139 |
in zip(self.per_class_indices.keys(), freqs)}
|
| 140 |
print(self.frequencies)
|
| 141 |
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
buckets = defaultdict(list)
|
| 144 |
for index in self.indices:
|
| 145 |
label = self.dataset.get_label_at(index)
|
|
@@ -191,6 +224,14 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
| 191 |
seed: int = 54,
|
| 192 |
) -> Tuple['Split', 'Split']:
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
prng = RandomState(seed)
|
| 195 |
|
| 196 |
num_train = int(len(all_data) * train_fraction)
|
|
@@ -204,7 +245,8 @@ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
|
| 204 |
return train_data, val_data
|
| 205 |
|
| 206 |
|
| 207 |
-
def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
|
|
|
|
| 208 |
labels = []
|
| 209 |
for _, label in dataset:
|
| 210 |
if isinstance(label, torch.Tensor):
|
|
@@ -261,7 +303,16 @@ class Metrics:
|
|
| 261 |
return self
|
| 262 |
|
| 263 |
|
| 264 |
-
def worker_init_fn(worker_id):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
state = np.random.get_state()
|
| 266 |
assert isinstance(state, tuple)
|
| 267 |
assert isinstance(state[1], np.ndarray)
|
|
@@ -274,6 +325,7 @@ def worker_init_fn(worker_id):
|
|
| 274 |
|
| 275 |
|
| 276 |
class ViTLightningModule(L.LightningModule):
|
|
|
|
| 277 |
def __init__(self, debug: bool) -> None:
|
| 278 |
super().__init__()
|
| 279 |
|
|
@@ -443,6 +495,7 @@ class ViTLightningModule(L.LightningModule):
|
|
| 443 |
return loss
|
| 444 |
|
| 445 |
def _dump_train_images(self) -> None:
|
|
|
|
| 446 |
img_batch, label_batch = next(iter(self._train_dataloader))
|
| 447 |
for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
|
| 448 |
img_np = img.cpu().numpy()
|
|
@@ -494,18 +547,19 @@ class ViTLightningModule(L.LightningModule):
|
|
| 494 |
|
| 495 |
|
| 496 |
def main():
|
|
|
|
| 497 |
|
| 498 |
parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
|
| 499 |
parser.add_argument('--tag', action='store', type=str,
|
| 500 |
help='Extra suffix to put on the artefact dir name')
|
| 501 |
-
parser.add_argument('--debug', action='store_true'
|
|
|
|
| 502 |
parser.add_argument('--convert-checkpoint', action='store', type=str,
|
| 503 |
help='Convert a checkpoint from training to pickle-independent '
|
| 504 |
'predictor-compatible directory')
|
| 505 |
|
| 506 |
args = parser.parse_args()
|
| 507 |
|
| 508 |
-
|
| 509 |
torch.set_float32_matmul_precision('high') # for V100/A100
|
| 510 |
|
| 511 |
if args.convert_checkpoint is not None:
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
class RetinopathyDataset(data.Dataset[DataRecord]):
|
| 52 |
+
""" A class to access the pre-downloaded Diabetic Retinopathy dataset. """
|
| 53 |
+
|
| 54 |
def __init__(self, data_path: str) -> None:
|
| 55 |
+
""" Constructor.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
data_path (str): path to the dataset, ex: "retinopathy_data"
|
| 59 |
+
containing "trainLabels.csv" and "train/".
|
| 60 |
+
"""
|
| 61 |
super().__init__()
|
| 62 |
|
| 63 |
self.data_path = data_path
|
|
|
|
| 96 |
return img_path
|
| 97 |
|
| 98 |
|
| 99 |
+
""" Purpose of a split: training or validation. """
|
| 100 |
class Purpose(Enum):
|
| 101 |
Train = 0
|
| 102 |
Val = 1
|
| 103 |
|
| 104 |
+
""" Augmentation transformations for an image and a label. """
|
| 105 |
FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
|
| 106 |
Callable[..., torch.Tensor]]
|
| 107 |
|
| 108 |
+
""" Feature (image) and target (label) tensors. """
|
| 109 |
TensorRecord = Tuple[torch.Tensor, torch.Tensor]
|
| 110 |
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
|
| 113 |
+
""" Split is a class that keep a view on a part of a dataset.
|
| 114 |
+
Split is used to hold the imormation about which samples go to training
|
| 115 |
+
and which to validation without a need to put these groups of files into
|
| 116 |
+
separate folders.
|
| 117 |
+
"""
|
| 118 |
def __init__(self, dataset: RetinopathyDataset,
|
| 119 |
indices: np.ndarray,
|
| 120 |
purpose: Purpose,
|
|
|
|
| 123 |
stratify_classes: bool = False,
|
| 124 |
use_log_frequencies: bool = False,
|
| 125 |
):
|
| 126 |
+
""" Constructor.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
dataset (RetinopathyDataset): The dataset on which the Split "views".
|
| 130 |
+
indices (np.ndarray): Externally provided indices of samples that
|
| 131 |
+
are "viewed" on.
|
| 132 |
+
purpose (Purpose): Either train or val, to be able to replicate
|
| 133 |
+
the data for train split for effecient workers utilization.
|
| 134 |
+
transforms (FeatureAndTargetTransforms): Functors of feature and
|
| 135 |
+
target transforms.
|
| 136 |
+
oversample_factor (int, optional): Expand the training dataset by
|
| 137 |
+
replication to avoid dataloader stalls on epoch ends. Defaults to 1.
|
| 138 |
+
stratify_classes (bool, optional): Whether to apply stratified sampling.
|
| 139 |
+
Defaults to False.
|
| 140 |
+
use_log_frequencies (bool, optional): If stratify_classes=True,
|
| 141 |
+
whether to use logarithmic sampling strategy. If False, apply
|
| 142 |
+
regular even sampling. Defaults to False.
|
| 143 |
+
"""
|
| 144 |
self.dataset = dataset
|
| 145 |
self.indices = indices
|
| 146 |
self.purpose = purpose
|
|
|
|
| 153 |
self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
|
| 154 |
self.frequencies: Optional[Dict[int, float]] = None
|
| 155 |
if self.stratify_classes:
|
| 156 |
+
self._bucketize_indices()
|
| 157 |
if self.use_log_frequencies:
|
| 158 |
+
self._calc_frequencies()
|
| 159 |
|
| 160 |
+
def _calc_frequencies(self):
|
| 161 |
assert self.per_class_indices is not None
|
| 162 |
counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
|
| 163 |
counts = np.array(list(counts_dict.values()))
|
| 164 |
+
counts_nrm = self._normalize(counts)
|
| 165 |
temperature = 50.0 # > 1 to even-out frequencies
|
| 166 |
+
freqs = self._normalize(np.log1p(counts_nrm * temperature))
|
| 167 |
self.frequencies = {k: freq.item() for k, freq
|
| 168 |
in zip(self.per_class_indices.keys(), freqs)}
|
| 169 |
print(self.frequencies)
|
| 170 |
|
| 171 |
+
@staticmethod
|
| 172 |
+
def _normalize(arr: np.ndarray) -> np.ndarray:
|
| 173 |
+
return arr / np.sum(arr)
|
| 174 |
+
|
| 175 |
+
def _bucketize_indices(self):
|
| 176 |
buckets = defaultdict(list)
|
| 177 |
for index in self.indices:
|
| 178 |
label = self.dataset.get_label_at(index)
|
|
|
|
| 224 |
seed: int = 54,
|
| 225 |
) -> Tuple['Split', 'Split']:
|
| 226 |
|
| 227 |
+
""" Prepare train and val splits deterministically.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
Tuple[Split, Split]:
|
| 231 |
+
- Train split
|
| 232 |
+
- Val split
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
prng = RandomState(seed)
|
| 236 |
|
| 237 |
num_train = int(len(all_data) * train_fraction)
|
|
|
|
| 245 |
return train_data, val_data
|
| 246 |
|
| 247 |
|
| 248 |
+
def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
|
| 249 |
+
split_name: str) -> None:
|
| 250 |
labels = []
|
| 251 |
for _, label in dataset:
|
| 252 |
if isinstance(label, torch.Tensor):
|
|
|
|
| 303 |
return self
|
| 304 |
|
| 305 |
|
| 306 |
+
def worker_init_fn(worker_id: int) -> None:
|
| 307 |
+
""" Initialize workers in a way that they draw different
|
| 308 |
+
random samples and do not repeat identical pseudorandom
|
| 309 |
+
sequences of each other, which may be the case with Fork
|
| 310 |
+
multiprocessing.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
worker_id (int): id of a preprocessing worker process launched
|
| 314 |
+
by one DDP training process.
|
| 315 |
+
"""
|
| 316 |
state = np.random.get_state()
|
| 317 |
assert isinstance(state, tuple)
|
| 318 |
assert isinstance(state[1], np.ndarray)
|
|
|
|
| 325 |
|
| 326 |
|
| 327 |
class ViTLightningModule(L.LightningModule):
|
| 328 |
+
""" Lightning Module that implements neural network training hooks. """
|
| 329 |
def __init__(self, debug: bool) -> None:
|
| 330 |
super().__init__()
|
| 331 |
|
|
|
|
| 495 |
return loss
|
| 496 |
|
| 497 |
def _dump_train_images(self) -> None:
|
| 498 |
+
""" Save augmented images to disk for inspection. """
|
| 499 |
img_batch, label_batch = next(iter(self._train_dataloader))
|
| 500 |
for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
|
| 501 |
img_np = img.cpu().numpy()
|
|
|
|
| 547 |
|
| 548 |
|
| 549 |
def main():
|
| 550 |
+
""" Neural network trainer entry point. """
|
| 551 |
|
| 552 |
parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
|
| 553 |
parser.add_argument('--tag', action='store', type=str,
|
| 554 |
help='Extra suffix to put on the artefact dir name')
|
| 555 |
+
parser.add_argument('--debug', action='store_true',
|
| 556 |
+
help="Dummy training cycle for testing purposes")
|
| 557 |
parser.add_argument('--convert-checkpoint', action='store', type=str,
|
| 558 |
help='Convert a checkpoint from training to pickle-independent '
|
| 559 |
'predictor-compatible directory')
|
| 560 |
|
| 561 |
args = parser.parse_args()
|
| 562 |
|
|
|
|
| 563 |
torch.set_float32_matmul_precision('high') # for V100/A100
|
| 564 |
|
| 565 |
if args.convert_checkpoint is not None:
|