zyf0717 committed on
Commit
41a3267
·
1 Parent(s): d8a41b8

Migrate repo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +12 -0
  3. README.md +193 -3
  4. artery_vein/av_july24.pt +3 -0
  5. artery_vein/av_july24_AVRDB.pt +3 -0
  6. artery_vein/av_july24_IOSTAR.pt +3 -0
  7. artery_vein/av_july24_LEUVEN.pt +3 -0
  8. artery_vein/av_july24_RS.pt +3 -0
  9. config.yaml +16 -0
  10. disc/disc_july24.pt +3 -0
  11. disc/disc_july24_ADAM.pt +3 -0
  12. disc/disc_july24_IDRID.pt +3 -0
  13. disc/disc_july24_ORIGA.pt +3 -0
  14. disc/disc_july24_PAPILA.pt +3 -0
  15. discedge/discedge_july24.pt +3 -0
  16. environment.yml +19 -0
  17. fovea/fovea_july24.pt +3 -0
  18. imgs/CHASEDB1_08L.png +3 -0
  19. imgs/CHASEDB1_08L_rgb.png +3 -0
  20. imgs/CHASEDB1_12R.png +3 -0
  21. imgs/CHASEDB1_12R_rgb.png +3 -0
  22. imgs/DRIVE_22.png +3 -0
  23. imgs/DRIVE_22_rgb.png +3 -0
  24. imgs/DRIVE_40.png +3 -0
  25. imgs/DRIVE_40_rgb.png +3 -0
  26. imgs/HRF_04_g.png +3 -0
  27. imgs/HRF_04_g_rgb.png +3 -0
  28. imgs/HRF_07_dr.png +3 -0
  29. imgs/HRF_07_dr_rgb.png +3 -0
  30. imgs/samples_vascx_hrf.png +3 -0
  31. notebooks/0_preprocess.ipynb +138 -0
  32. notebooks/1_segment_preprocessed.ipynb +217 -0
  33. odfd/odfd_march25.pt +3 -0
  34. quality/quality.pt +3 -0
  35. run.sh +60 -0
  36. samples/fundus/original/CHASEDB1_08L.png +3 -0
  37. samples/fundus/original/CHASEDB1_12R.png +3 -0
  38. samples/fundus/original/DRIVE_22.png +3 -0
  39. samples/fundus/original/DRIVE_40.png +3 -0
  40. samples/fundus/original/HRF_04_g.jpg +3 -0
  41. samples/fundus/original/HRF_07_dr.jpg +3 -0
  42. setup.py +36 -0
  43. vascx_models/__init__.py +0 -0
  44. vascx_models/cli.py +259 -0
  45. vascx_models/config.py +196 -0
  46. vascx_models/disc_rings.py +118 -0
  47. vascx_models/inference.py +292 -0
  48. vascx_models/utils.py +196 -0
  49. vessels/vessels_july24.pt +3 -0
  50. vessels/vessels_july24_DRHAGIS.pt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
1
+ *.pyc
2
+ __pycache__
3
+ *.egg-info
4
+ *.zip
5
+ .DS_Store
6
+ .cache/
7
+ .mplconfig/
8
+ model_releases/
9
+ output_*/
10
+ output_*.zip
11
+ /samples/fundus/*
12
+ !/samples/fundus/original
README.md CHANGED
@@ -1,3 +1,193 @@
1
- ---
2
- license: mit
3
- ---
1
+ ---
2
+ license: agpl-3.0
3
+ pipeline_tag: image-segmentation
4
+ tags:
5
+ - medical
6
+ - biology
7
+ ---
8
+
9
+ # 👁️ VascX Fork
10
+
11
+ This repository contains the instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016). This fork is published as `zyf0717/vascx-fork` on the Hugging Face Hub.
12
+
13
+ The model weights are hosted on the [Hugging Face Hub](https://huggingface.co/zyf0717/vascx-fork).
14
+
15
+ <img src="imgs/samples_vascx_hrf.png">
16
+
17
+ ## 🛠️ Installation
18
+
19
+ To install the entire fundus analysis pipeline including fundus preprocessing, model inference code and vascular biomarker extraction:
20
+
21
+ 1. Create a conda or virtualenv virtual environment, or otherwise ensure a clean environment.
22
+
23
+ 2. Install `torch` and `torchvision` for your platform.
24
+
25
+ 3. Install the pipeline runtime packages:
26
+
27
+ ```bash
28
+ pip install retinalysis-fundusprep retinalysis-inference
29
+ pip install -e .
30
+ ```
31
+
32
+ The `environment.yml` in this repository includes the same runtime dependencies.
33
+
34
+
35
+ ## 🚀 `vascx run` Command
36
+
37
+ The repository name is `vascx-fork`, but the installed CLI entry point remains `vascx` for compatibility.
38
+
39
+ The `run` command provides a comprehensive pipeline for processing fundus images, performing various analyses, and creating visualizations.
40
+
41
+ ### Usage
42
+
43
+ ```bash
44
+ vascx run DATA_PATH OUTPUT_PATH [OPTIONS]
45
+ ```
46
+
47
+ If `config.yaml` exists in the current working directory or at the repository root, `vascx run` loads it automatically. You can also point to a specific file with `--config /path/to/config.yaml`.
48
+
49
+ ### Arguments
50
+
51
+ - `DATA_PATH`: Path to input data. Can be either:
52
+ - A directory containing fundus images
53
+ - A CSV file with a 'path' column containing paths to images
54
+
55
+ - `OUTPUT_PATH`: Directory where processed results will be stored
56
+
57
+ ### Options
58
+
59
+ | Option | Default | Description |
60
+ |--------|---------|-------------|
61
+ | `--preprocess/--no-preprocess` | `--preprocess` | Run preprocessing to standardize images for model input |
62
+ | `--vessels/--no-vessels` | `--vessels` | Run vessel segmentation and artery-vein classification |
63
+ | `--disc/--no-disc` | `--disc` | Run optic disc segmentation |
64
+ | `--quality/--no-quality` | `--quality` | Run image quality assessment |
65
+ | `--fovea/--no-fovea` | `--fovea` | Run fovea detection |
66
+ | `--overlay/--no-overlay` | `config.yaml` or `--overlay` | Create visualization overlays combining all results |
67
+ | `--config PATH` | auto-detect `config.yaml` | Load pipeline configuration from YAML |
68
+ | `--n_jobs` | `4` | Number of preprocessing workers for parallel processing |
69
+
70
+ ### 📁 Output Structure
71
+
72
+ When run with default options, the command creates the following structure in `OUTPUT_PATH`:
73
+
74
+ ```
75
+ OUTPUT_PATH/
76
+ ├── preprocessed_rgb/ # Standardized fundus images
77
+ ├── vessels/ # Vessel segmentation results
78
+ ├── artery_vein/ # Artery-vein classification
79
+ ├── disc/ # Optic disc segmentation
80
+ ├── disc_ring_2r/ # Binary masks for the 2r optic-disc ring
81
+ ├── disc_ring_3r/ # Binary masks for the 3r optic-disc ring
82
+ ├── overlays/ # Visualization images
83
+ ├── bounds.csv # Image boundary information
84
+ ├── disc_geometry.csv # Disc center and radius estimates in pixels
85
+ ├── quality.csv # Image quality scores
86
+ └── fovea.csv # Fovea coordinates
87
+ ```
88
+
89
+ ### 🔄 Processing Stages
90
+
91
+ 1. **Preprocessing**:
92
+ - Standardizes input images for consistent analysis
93
+ - Outputs preprocessed images and boundary information
94
+
95
+ 2. **Quality Assessment**:
96
+ - Evaluates image quality with three quality metrics (q1, q2, q3)
97
+ - Higher scores indicate better image quality
98
+
99
+ 3. **Vessel Segmentation and Artery-Vein Classification**:
100
+ - Identifies blood vessels in the retina
101
+ - Classifies vessels as arteries (1) or veins (2), with artery-vein intersections labeled (3)
102
+
103
+ 4. **Optic Disc Segmentation**:
104
+ - Identifies the optic disc location and boundaries
105
+ - Estimates disc center and radius from the disc mask
106
+ - Generates 2r and 3r ring masks around the disc
107
+
108
+ 5. **Fovea Detection**:
109
+ - Determines the coordinates of the fovea (center of vision)
110
+
111
+ 6. **Visualization Overlays**:
112
+ - Creates color-coded images showing:
113
+ - Arteries in red
114
+ - Veins in blue
115
+ - Optic disc in white
116
+ - 2r ring in green
117
+ - 3r ring in magenta
118
+ - Fovea marked with yellow X
119
+ - Overlay layers and colors can be controlled from `config.yaml`
120
+
121
+ ### ⚙️ `config.yaml`
122
+
123
+ The repository root now includes a `config.yaml` file for overlay settings. The default file looks like this:
124
+
125
+ ```yaml
126
+ overlay:
127
+ enabled: true
128
+ layers:
129
+ arteries: true
130
+ veins: true
131
+ disc: true
132
+ ring_2r: true
133
+ ring_3r: true
134
+ fovea: true
135
+ colours:
136
+ artery: "#FF0000"
137
+ vein: "#0000FF"
138
+ disc: "#FFFFFF"
139
+ ring_2r: "#00FF00"
140
+ ring_3r: "#FF00FF"
141
+ fovea: "#FFFF00"
142
+ ```
143
+
144
+ Notes:
145
+
146
+ - `overlay.enabled` controls whether overlays are produced when `--overlay/--no-overlay` is not set explicitly.
147
+ - `overlay.layers` lets you choose which predictions are drawn.
148
+ - `overlay.colors` and `overlay.colours` are both accepted.
149
+ - Colors can be written as `#RRGGBB` strings or 3-value RGB arrays such as `[255, 0, 0]`.
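The precedence in the first note (an explicit command-line flag wins, then `overlay.enabled`, then the documented default of `--overlay`) can be sketched as below. This is an illustration only, not the code in `vascx_models/config.py`; the helper name `resolve_overlay_enabled` and the use of `None` for "no flag given" are assumptions.

```python
from typing import Any, Mapping, Optional


def resolve_overlay_enabled(
    cli_flag: Optional[bool],
    config: Mapping[str, Any],
    default: bool = True,
) -> bool:
    """Hypothetical helper: decide whether overlays are produced.

    cli_flag is True for --overlay, False for --no-overlay, and None when
    neither flag was passed on the command line.
    """
    if cli_flag is not None:
        # An explicit command-line flag always wins.
        return cli_flag
    overlay_cfg = config.get("overlay") or {}
    # Fall back to overlay.enabled, then to the documented default (--overlay).
    return bool(overlay_cfg.get("enabled", default))
```

Under these assumptions, `resolve_overlay_enabled(None, {"overlay": {"enabled": False}})` disables overlays, while passing `--overlay` re-enables them regardless of the file.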
150
+
151
+ ### 💻 Examples
152
+
153
+ **Process a directory of images with all analyses:**
154
+ ```bash
155
+ vascx run /path/to/images /path/to/output
156
+ ```
157
+
158
+ **Process specific images listed in a CSV:**
159
+ ```bash
160
+ vascx run /path/to/image_list.csv /path/to/output
161
+ ```
162
+
163
+ **Only run preprocessing and vessel segmentation:**
164
+ ```bash
165
+ vascx run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
166
+ ```
167
+
168
+ **Skip preprocessing on already preprocessed images:**
169
+ ```bash
170
+ vascx run /path/to/preprocessed/images /path/to/output --no-preprocess
171
+ ```
172
+
173
+ **Increase parallel processing workers:**
174
+ ```bash
175
+ vascx run /path/to/images /path/to/output --n_jobs 8
176
+ ```
177
+
178
+ ### 📝 Notes
179
+
180
+ - The CSV input must contain a 'path' column with image file paths
181
+ - If the CSV includes an 'id' column, these IDs will be used instead of filenames
182
+ - When `--no-preprocess` is used, input images must already be in the proper format
183
+ - The overlay visualization requires at least one analysis component to be enabled
184
+
185
+ ## 📓 Notebooks
186
+
187
+ For more advanced usage, we have Jupyter notebooks showing how preprocessing and inference are run.
188
+
189
+ To speed up re-execution of `vascx`, we recommend running the preprocessing and segmentation steps separately:
190
+
191
+ 1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
192
+
193
+ 2. Inference. See [this notebook](./notebooks/1_segment_preprocessed.ipynb). All models can be run on a single GPU with more than 10 GB of VRAM.
artery_vein/av_july24.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b11d4e26ada8e1f0f279747afa5f4ef348d9d7350c4bf80e7ae6b5ac8d0b95b5
3
+ size 352774102
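The three added lines above form a Git LFS pointer: the `.pt` weights themselves live in LFS storage, and the repository tracks only the spec version, a SHA-256 `oid`, and the byte size. A minimal sketch of reading such a pointer and checking a downloaded file against it follows; the function names are hypothetical, and the pointer text is assumed to be passed without the leading `+` diff markers.

```python
import hashlib
from pathlib import Path
from typing import Dict, Union


def parse_lfs_pointer(text: str) -> Dict[str, str]:
    """Split each 'key value' line of a Git LFS pointer into a dict entry."""
    fields = {}
    for line in text.splitlines():
        if line.strip():
            key, _, value = line.partition(" ")
            fields[key] = value.strip()
    return fields


def matches_pointer(weights_path: Union[str, Path], pointer: Dict[str, str]) -> bool:
    """Check a downloaded file against the pointer's size and SHA-256 oid."""
    path = Path(weights_path)
    if path.stat().st_size != int(pointer["size"]):
        return False
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        # Hash in 1 MiB chunks so large weight files are not loaded at once.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return pointer["oid"] == f"sha256:{digest.hexdigest()}"
```

A check like this is useful after a clone made without `git lfs pull`, where the `.pt` files on disk are still small pointer files rather than the real weights.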
artery_vein/av_july24_AVRDB.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21fd6a693be9a5ffbf1b56e624612a493470d6e63025ce11f2e3886bd6f18b4c
3
+ size 352791110
artery_vein/av_july24_IOSTAR.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21fd6a693be9a5ffbf1b56e624612a493470d6e63025ce11f2e3886bd6f18b4c
3
+ size 352791110
artery_vein/av_july24_LEUVEN.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:446580e6cda2acec8dc2ab30d9526735fc670f296048055cbb5ebb9ccac28d0b
3
+ size 352830466
artery_vein/av_july24_RS.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3de549b74ebd9c9a4f49c17043b831665cc7a1981773ff8c17db2416b9dfe48
3
+ size 352805874
config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ overlay:
2
+ enabled: true
3
+ layers:
4
+ arteries: true
5
+ veins: true
6
+ disc: true
7
+ ring_2r: true
8
+ ring_3r: true
9
+ fovea: true
10
+ colours:
11
+ artery: "#FF0000"
12
+ vein: "#0000FF"
13
+ disc: "#FFFFFF"
14
+ ring_2r: "#00FF00"
15
+ ring_3r: "#FF00FF"
16
+ fovea: "#FFFF00"
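The README in this commit states that colours may be given either as `#RRGGBB` strings, as in the file above, or as 3-value RGB arrays. One way such values could be normalised to a single form is sketched below; `parse_colour` is a hypothetical helper, not the function used in `vascx_models/config.py`.

```python
import string
from typing import Sequence, Tuple, Union

RGB = Tuple[int, int, int]


def parse_colour(value: Union[str, Sequence[int]]) -> RGB:
    """Hypothetical helper: normalise a config colour to an (R, G, B) tuple."""
    if isinstance(value, str):
        text = value.strip()
        is_hex = all(ch in string.hexdigits for ch in text[1:])
        if len(text) != 7 or not text.startswith("#") or not is_hex:
            raise ValueError(f"expected '#RRGGBB', got {value!r}")
        # Convert each two-digit hex pair into one 0..255 channel.
        return (int(text[1:3], 16), int(text[3:5], 16), int(text[5:7], 16))
    channels = tuple(int(c) for c in value)
    if len(channels) != 3 or any(not 0 <= c <= 255 for c in channels):
        raise ValueError(f"expected three values in 0..255, got {value!r}")
    return (channels[0], channels[1], channels[2])
```

With this sketch, `"#FF0000"` and `[255, 0, 0]` both normalise to the same tuple, so downstream drawing code would only need to handle one representation.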
disc/disc_july24.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6892b1bdb3bb68b666ea9a7891b0fb2f6fbb5fd4f05038c013c1c69ec6c7910c
3
+ size 352801898
disc/disc_july24_ADAM.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f423e047c88d9f6c03ada2c706bc84265d61ec473eaab28e8ca12a0f1738401
3
+ size 352819138
disc/disc_july24_IDRID.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbf239cfec2ee9b550aa09e3623af01cd79de688cac1c79902d0feb7d24bb3f7
3
+ size 352835178
disc/disc_july24_ORIGA.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18cc9ffac522fc9b91e46fe3aed8d6bb8fbf00e44bfb768b622f5b881713add6
3
+ size 352826358
disc/disc_july24_PAPILA.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:080de64a1c005a921cb09eb04e42ce9c087dca2c1090209f77a3635af9eb1d19
3
+ size 352832490
discedge/discedge_july24.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:891d8ef9bbc0676b019a81b1eceb349c1cdf4b5665a834196dd252915af64392
3
+ size 352723146
environment.yml ADDED
@@ -0,0 +1,19 @@
1
+ name: vascx-fork
2
+ channels:
3
+ - conda-forge
4
+ dependencies:
5
+ - python=3.11
6
+ - pip
7
+ - numpy=2.*
8
+ - pandas=2.*
9
+ - tqdm=4.*
10
+ - pillow=11.*
11
+ - click=8.*
12
+ - pyyaml=6.*
13
+ - pip:
14
+ - torch==2.11.0
15
+ - torchvision==0.26.0
16
+ - torchaudio==2.11.0
17
+ - retinalysis-fundusprep
18
+ - retinalysis-inference
19
+ - -e .
fovea/fovea_july24.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1af042f7e2a398f512be8a8d54cc480300312c3f3692c13bd98be65439a33222
3
+ size 352714676
imgs/CHASEDB1_08L.png ADDED

Git LFS Details

  • SHA256: 4b3537bcda4faa0abd2f187bf508d9dfc3b469f73f3b889d158bd3ef30fa64a9
  • Pointer size: 131 Bytes
  • Size of remote file: 694 kB
imgs/CHASEDB1_08L_rgb.png ADDED

Git LFS Details

  • SHA256: 923cbd785406a4d370b48cc0ffe2525309d35f8ecdf66a7552db6d0e3b0fd758
  • Pointer size: 131 Bytes
  • Size of remote file: 757 kB
imgs/CHASEDB1_12R.png ADDED

Git LFS Details

  • SHA256: d5457e090dc4de46bdc5c7eae45e536d680862e6059eff2c46f7425030672a79
  • Pointer size: 131 Bytes
  • Size of remote file: 804 kB
imgs/CHASEDB1_12R_rgb.png ADDED

Git LFS Details

  • SHA256: d0af405bbd3e8df582bfdd4cd91ca0006aeea307a129221c2d5d46da0ed62234
  • Pointer size: 131 Bytes
  • Size of remote file: 883 kB
imgs/DRIVE_22.png ADDED

Git LFS Details

  • SHA256: cf12b1603f3a50aa125a327aefa07c512ed4b804243e70b3d23b2f4145416d91
  • Pointer size: 131 Bytes
  • Size of remote file: 852 kB
imgs/DRIVE_22_rgb.png ADDED

Git LFS Details

  • SHA256: 87df6604a7348fd328cc5c4e51c028bd996e183fea5f012d9b045de15d8608eb
  • Pointer size: 131 Bytes
  • Size of remote file: 893 kB
imgs/DRIVE_40.png ADDED

Git LFS Details

  • SHA256: 33a24859edb67575ee6fbd2c797dc903cd64df13d314649a2bb9643706895c70
  • Pointer size: 131 Bytes
  • Size of remote file: 834 kB
imgs/DRIVE_40_rgb.png ADDED

Git LFS Details

  • SHA256: b0dcb48533f7b6859a4187eab7ca386e0655be5f7e356ad1d46d02bb3b52caa7
  • Pointer size: 131 Bytes
  • Size of remote file: 874 kB
imgs/HRF_04_g.png ADDED

Git LFS Details

  • SHA256: 64113c3789edace497c717418879e7257a0b20f73af972ac61417ba3c709a50f
  • Pointer size: 131 Bytes
  • Size of remote file: 711 kB
imgs/HRF_04_g_rgb.png ADDED

Git LFS Details

  • SHA256: 4f4f9698e15221b6dd61a3636c5b266b35fc6b221dbbaa1fc25b9e4b410c77b9
  • Pointer size: 131 Bytes
  • Size of remote file: 843 kB
imgs/HRF_07_dr.png ADDED

Git LFS Details

  • SHA256: cfcb0a41b79cd31d3531e0277ec8c44fbc829cc8c733f7fc21c99183453dcd17
  • Pointer size: 131 Bytes
  • Size of remote file: 767 kB
imgs/HRF_07_dr_rgb.png ADDED

Git LFS Details

  • SHA256: 74046f4d4d50dd3673394ba1cf1db33aeabfdd8e733c9fa1f0ad1fc9ef38dd15
  • Pointer size: 131 Bytes
  • Size of remote file: 898 kB
imgs/samples_vascx_hrf.png ADDED

Git LFS Details

  • SHA256: 17499c0fef958fe55ed8bc359d71d803048ef16c106e7cee78e01d95a38de1ec
  • Pointer size: 132 Bytes
  • Size of remote file: 6.08 MB
notebooks/0_preprocess.ipynb ADDED
@@ -0,0 +1,138 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pathlib import Path\n",
10
+ "\n",
11
+ "import pandas as pd\n",
12
+ "\n",
13
+ "from rtnls_fundusprep.preprocessor import parallel_preprocess"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "## Preprocessing\n",
21
+ "\n",
22
+ "This code will preprocess the images and write .png files with the square fundus image and the contrast enhanced version\n",
23
+ "\n",
24
+ "This step is not strictly necessary, but it is useful if you want to run the preprocessing step separately before model inference\n"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "Create a list of files to be preprocessed:"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 2,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "ds_path = Path(\"../samples/fundus\")\n",
41
+ "files = list((ds_path / \"original\").glob(\"*\"))"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "markdown",
46
+ "metadata": {},
47
+ "source": [
48
+ "Images with a .dcm extension are read as DICOM, with pixel_array interpreted as RGB. All other images are read using PIL's Image.open."
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "metadata": {},
55
+ "outputs": [
56
+ {
57
+ "name": "stderr",
58
+ "output_type": "stream",
59
+ "text": [
60
+ "0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
61
+ "6it [00:00, 154.80it/s]\n"
62
+ ]
63
+ },
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "Error with image ../samples/fundus/original/HRF_07_dr.jpg\n",
69
+ "Error with image ../samples/fundus/original/HRF_04_g.jpg\n"
70
+ ]
71
+ },
72
+ {
73
+ "name": "stderr",
74
+ "output_type": "stream",
75
+ "text": [
76
+ "[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 0.9s remaining: 1.8s\n",
77
+ "[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 1.5s remaining: 1.5s\n",
78
+ "[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 1.5s remaining: 0.8s\n",
79
+ "[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 1.6s finished\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "bounds = parallel_preprocess(\n",
85
+ " files, # List of image files\n",
86
+ " rgb_path=ds_path / \"rgb\", # Output path for RGB images\n",
87
+ " ce_path=ds_path / \"ce\", # Output path for Contrast Enhanced images\n",
88
+ " n_jobs=4, # number of preprocessing workers\n",
89
+ ")\n",
90
+ "df_bounds = pd.DataFrame(bounds).set_index(\"id\")"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {},
96
+ "source": [
97
+ "The preprocessor produces RGB and contrast-enhanced images cropped to a square and returns a dataframe with the image bounds, which can be used to reconstruct the original image. Output files keep the input filename but take a .png extension. Avoid passing multiple inputs that share a filename apart from the extension, as the outputs will overwrite each other. Exceptions during preprocessing do not stop execution; an error message is printed instead. Images that fail preprocessing for any reason are marked with `success=False` in the df_bounds dataframe."
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": 4,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "df_bounds.to_csv(ds_path / \"meta.csv\")"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": []
115
+ }
116
+ ],
117
+ "metadata": {
118
+ "kernelspec": {
119
+ "display_name": "retinalysis",
120
+ "language": "python",
121
+ "name": "python3"
122
+ },
123
+ "language_info": {
124
+ "codemirror_mode": {
125
+ "name": "ipython",
126
+ "version": 3
127
+ },
128
+ "file_extension": ".py",
129
+ "mimetype": "text/x-python",
130
+ "name": "python",
131
+ "nbconvert_exporter": "python",
132
+ "pygments_lexer": "ipython3",
133
+ "version": "3.10.13"
134
+ }
135
+ },
136
+ "nbformat": 4,
137
+ "nbformat_minor": 2
138
+ }
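The notebook above warns that inputs sharing a filename apart from the extension overwrite each other's outputs. A small pre-flight check could flag such cases before calling `parallel_preprocess`. The sketch below is one possible approach, assuming output names are derived from the filename stem; `find_stem_collisions` is a hypothetical helper and is not part of `rtnls_fundusprep`.

```python
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List


def find_stem_collisions(files: Iterable[Path]) -> Dict[str, List[Path]]:
    """Group inputs whose names differ only by extension or directory."""
    by_stem = defaultdict(list)
    for path in files:
        # Path.stem drops the final extension: "HRF_04_g.jpg" -> "HRF_04_g".
        by_stem[Path(path).stem].append(Path(path))
    return {stem: paths for stem, paths in by_stem.items() if len(paths) > 1}
```

In the notebook, this could be called on `files` right after the list is built, raising or logging if the returned dictionary is not empty.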
notebooks/1_segment_preprocessed.ipynb ADDED
@@ -0,0 +1,217 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pathlib import Path\n",
10
+ "\n",
11
+ "import torch\n",
12
+ "\n",
13
+ "from rtnls_inference import (\n",
14
+ " HeatmapRegressionEnsemble,\n",
15
+ " SegmentationEnsemble,\n",
16
+ ")"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "markdown",
21
+ "metadata": {},
22
+ "source": [
23
+ "## Segmentation of preprocessed images\n",
24
+ "\n",
25
+ "Here we segment images preprocessed using 0_preprocess.ipynb\n"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {},
31
+ "source": []
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 2,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "ds_path = Path(\"../samples/fundus\")\n",
40
+ "\n",
41
+ "# input folders. these are the folders where we stored the preprocessed images\n",
42
+ "rgb_path = ds_path / \"rgb\"\n",
43
+ "ce_path = ds_path / \"ce\"\n",
44
+ "\n",
45
+ "# these are the output folders for:\n",
46
+ "av_path = ds_path / \"av\" # artery-vein segmentations\n",
47
+ "discs_path = ds_path / \"discs\" # optic disc segmentations\n",
48
+ "overlays_path = ds_path / \"overlays\" # optional overlay visualizations\n",
49
+ "\n",
50
+ "device = torch.device(\"cuda:0\") # device to use for inference"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 3,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "rgb_paths = sorted(list(rgb_path.glob(\"*.png\")))\n",
60
+ "ce_paths = sorted(list(ce_path.glob(\"*.png\")))\n",
61
+ "paired_paths = list(zip(rgb_paths, ce_paths))"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "paired_paths[0] # important to make sure that the paths are paired correctly"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "metadata": {},
76
+ "source": [
77
+ "### Artery-vein segmentation\n"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "av_ensemble = SegmentationEnsemble.from_huggingface('zyf0717/vascx-fork:artery_vein/av_july24.pt').to(device)\n",
87
+ "\n",
88
+ "av_ensemble.predict_preprocessed(paired_paths, dest_path=av_path, num_workers=2)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "markdown",
93
+ "metadata": {},
94
+ "source": [
95
+ "### Disc segmentation\n"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "disc_ensemble = SegmentationEnsemble.from_huggingface('zyf0717/vascx-fork:disc/disc_july24.pt').to(device)\n",
105
+ "disc_ensemble.predict_preprocessed(paired_paths, dest_path=discs_path, num_workers=2)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "### Fovea detection\n"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "fovea_ensemble = HeatmapRegressionEnsemble.from_huggingface('zyf0717/vascx-fork:fovea/fovea_july24.pt').to(device)\n",
122
+ "# note: this model does not use contrast enhanced images\n",
123
+ "df = fovea_ensemble.predict_preprocessed(paired_paths, num_workers=2)\n",
124
+ "df.columns = [\"mean_x\", \"mean_y\"]\n",
125
+ "df.to_csv(ds_path / \"fovea.csv\")"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "df"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {},
140
+ "source": [
141
+ "### Plotting the retinas (optional)\n",
142
+ "\n",
143
+ "This will only work if you ran all the models and stored the outputs using the same folder/file names as above\n"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "from vascx.fundus.loader import RetinaLoader\n",
153
+ "\n",
154
+ "from rtnls_enface.utils.plotting import plot_gridfns\n",
155
+ "\n",
156
+ "loader = RetinaLoader.from_folder(ds_path)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "plot_gridfns([ret.plot for ret in loader[:6]])"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "metadata": {},
171
+ "source": [
172
+ "### Storing visualizations (optional)\n"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 10,
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "if not overlays_path.exists():\n",
182
+ " overlays_path.mkdir()\n",
183
+ "for ret in loader:\n",
184
+ " fig, _ = ret.plot()\n",
185
+ " fig.savefig(overlays_path / f\"{ret.id}.png\", bbox_inches=\"tight\", pad_inches=0)"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": []
194
+ }
195
+ ],
196
+ "metadata": {
197
+ "kernelspec": {
198
+ "display_name": "retinalysis",
199
+ "language": "python",
200
+ "name": "python3"
201
+ },
202
+ "language_info": {
203
+ "codemirror_mode": {
204
+ "name": "ipython",
205
+ "version": 3
206
+ },
207
+ "file_extension": ".py",
208
+ "mimetype": "text/x-python",
209
+ "name": "python",
210
+ "nbconvert_exporter": "python",
211
+ "pygments_lexer": "ipython3",
212
+ "version": "3.10.13"
213
+ }
214
+ },
215
+ "nbformat": 4,
216
+ "nbformat_minor": 2
217
+ }
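The notebook pairs RGB and contrast-enhanced images by sorting both folders and zipping the results, then inspects the first pair by hand. If one folder is missing a file, every later pair silently shifts. A stricter variant that pairs by filename stem and fails loudly on a mismatch could look like the sketch below; `pair_by_stem` is a hypothetical helper, not part of `rtnls_inference`.

```python
from pathlib import Path
from typing import List, Tuple


def pair_by_stem(rgb_dir: Path, ce_dir: Path) -> List[Tuple[Path, Path]]:
    """Pair RGB and contrast-enhanced PNGs that share a filename stem."""
    rgb = {p.stem: p for p in Path(rgb_dir).glob("*.png")}
    ce = {p.stem: p for p in Path(ce_dir).glob("*.png")}
    # Stems present in only one of the two folders cannot be paired.
    unmatched = sorted(set(rgb) ^ set(ce))
    if unmatched:
        raise ValueError(f"images without a counterpart: {unmatched}")
    return [(rgb[stem], ce[stem]) for stem in sorted(rgb)]
```

The returned list has the same shape as `paired_paths` in the notebook, so it could be passed to `predict_preprocessed` in its place.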
odfd/odfd_march25.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa2be119eb915bc9da6ba42234f703b5cc270d53b62c1e7d7e1bdff52c1e0edd
3
+ size 855538988
quality/quality.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80034ccb21a57522ba0cb86be0d46d2b659e193b86db909ad6ec2e85e61f87aa
3
+ size 855578258
run.sh ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+ CONDA_ENV="${CONDA_ENV:-vascx-fork}"
6
+ SAMPLE_INPUT_PATH="$REPO_ROOT/samples/fundus/original"
7
+ DEFAULT_INPUT_PATH="$SAMPLE_INPUT_PATH"
8
+ INPUT_PATH="${INPUT_PATH:-$DEFAULT_INPUT_PATH}"
9
+ TIMESTAMP="$(date +"%Y%m%d_%H%M%S")"
10
+ DEFAULT_OUTPUT_PATH="$REPO_ROOT/output_$TIMESTAMP"
11
+ OUTPUT_PATH="${OUTPUT_PATH:-$DEFAULT_OUTPUT_PATH}"
12
+ N_JOBS="${N_JOBS:-1}"
13
+ MODEL_RELEASES_DIR="$REPO_ROOT/model_releases"
14
+
15
+ while [[ $# -gt 0 ]]; do
16
+ case "$1" in
17
+ --sample-run)
18
+ INPUT_PATH="$SAMPLE_INPUT_PATH"
19
+ shift
20
+ ;;
21
+ *)
22
+ echo "Unknown argument: $1" >&2
23
+ echo "Usage: $0 [--sample-run]" >&2
24
+ exit 1
25
+ ;;
26
+ esac
27
+ done
28
+
29
+ if [[ ! -d "$INPUT_PATH" ]]; then
30
+ echo "Input directory does not exist: $INPUT_PATH" >&2
31
+ exit 1
32
+ fi
33
+
34
+ mkdir -p "$REPO_ROOT/.mplconfig" "$REPO_ROOT/.cache" "$MODEL_RELEASES_DIR" "$OUTPUT_PATH"
35
+
36
+ for model_path in "$REPO_ROOT"/*/*.pt; do
37
+ [[ -e "$model_path" ]] || continue
38
+ if [[ "$model_path" == "$MODEL_RELEASES_DIR"/* ]]; then
39
+ continue
40
+ fi
41
+ ln -sf "$model_path" "$MODEL_RELEASES_DIR/$(basename "$model_path")"
42
+ done
43
+
44
+ export MPLCONFIGDIR="$REPO_ROOT/.mplconfig"
45
+ export XDG_CACHE_HOME="$REPO_ROOT/.cache"
46
+ export RTNLS_MODEL_RELEASES="$MODEL_RELEASES_DIR"
47
+
48
+ echo "Running VascX Fork"
49
+ echo " conda env: $CONDA_ENV"
50
+ echo " input path: $INPUT_PATH"
51
+ echo " output path: $OUTPUT_PATH"
52
+ echo " n_jobs: $N_JOBS"
53
+ echo " models dir: $RTNLS_MODEL_RELEASES"
54
+
55
+ CONDA_BASE="$(conda info --base)"
56
+ # shellcheck disable=SC1091
57
+ source "$CONDA_BASE/etc/profile.d/conda.sh"
58
+ conda activate "$CONDA_ENV"
59
+
60
+ exec python -c "from vascx_models.cli import cli; cli()" run "$INPUT_PATH" "$OUTPUT_PATH" --n_jobs "$N_JOBS"
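The loop in `run.sh` that symlinks every `*/*.pt` file into `model_releases/` can also be expressed in Python, which may help on systems where the shell script is not convenient. The sketch below mirrors that loop under stated assumptions: `link_model_files` is a hypothetical name, and hidden directories are skipped explicitly because `pathlib` globbing matches them while the shell's `*` does not.

```python
from pathlib import Path
from typing import List


def link_model_files(repo_root: Path, releases_dir: Path) -> List[Path]:
    """Symlink each <repo_root>/<subdir>/*.pt file into releases_dir."""
    repo_root, releases_dir = Path(repo_root), Path(releases_dir)
    releases_dir.mkdir(parents=True, exist_ok=True)
    links = []
    for model_path in sorted(repo_root.glob("*/*.pt")):
        parent = model_path.parent
        if parent.name.startswith("."):
            continue  # the shell glob in run.sh never descends into dot-dirs
        if parent.resolve() == releases_dir.resolve():
            continue  # skip links created by an earlier run
        link = releases_dir / model_path.name
        if link.is_symlink() or link.exists():
            link.unlink()  # mirror `ln -sf`, which replaces an existing link
        link.symlink_to(model_path.resolve())
        links.append(link)
    return links
```

After linking, the script exports `RTNLS_MODEL_RELEASES` pointing at that directory; a Python caller would need to set the same environment variable before running the pipeline.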
samples/fundus/original/CHASEDB1_08L.png ADDED

Git LFS Details

  • SHA256: 16735352efdb2d951be1f07882e3906ee159c81e12e69c8230a7172d562cfc6b
  • Pointer size: 131 Bytes
  • Size of remote file: 621 kB
samples/fundus/original/CHASEDB1_12R.png ADDED

Git LFS Details

  • SHA256: a541717baf5d83c7295657604d1f529e9ed1cd3a4327aa224dbb83d80d49cc3a
  • Pointer size: 131 Bytes
  • Size of remote file: 776 kB
samples/fundus/original/DRIVE_22.png ADDED

Git LFS Details

  • SHA256: 58a0a44558d23d9cd4ffc60326abf91eed824bbe5718e995cb181595499f595b
  • Pointer size: 131 Bytes
  • Size of remote file: 394 kB
samples/fundus/original/DRIVE_40.png ADDED

Git LFS Details

  • SHA256: 0d8d7685974b7c0eff3583245dbb9e88a1a6a82ed60dbe09112364ba51894438
  • Pointer size: 131 Bytes
  • Size of remote file: 387 kB
samples/fundus/original/HRF_04_g.jpg ADDED

Git LFS Details

  • SHA256: fc9ed13ef42502eeecb3f1754dc0d3b72a454c82884b40dde934e8a516495588
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
samples/fundus/original/HRF_07_dr.jpg ADDED

Git LFS Details

  • SHA256: 203ddec480816b6c9d7ea3c19c1ff0870a5a61b5b6c9a176300402ac47fbc10f
  • Pointer size: 131 Bytes
  • Size of remote file: 921 kB
setup.py ADDED
@@ -0,0 +1,36 @@
1
+ from setuptools import find_packages, setup
2
+
3
+ with open("README.md", "r") as fh:
4
+ long_description = fh.read()
5
+
6
+ setup(
7
+ name="vascx_models",
8
+ # using versioneer for versioning using git tags
9
+ # https://github.com/python-versioneer/python-versioneer/blob/master/INSTALL.md
10
+ # version=versioneer.get_version(),
11
+ # cmdclass=versioneer.get_cmdclass(),
12
+ author="Jose Vargas",
13
+ author_email="j.vargasquiros@erasmusmc.nl",
14
+ description="Retinal analysis toolbox for Python",
15
+ long_description=long_description,
16
+ long_description_content_type="text/markdown",
17
+ packages=find_packages(),
18
+ include_package_data=True,
19
+ zip_safe=False,
20
+ entry_points={
21
+ "console_scripts": [
22
+ "vascx = vascx_models.cli:cli",
23
+ ]
24
+ },
25
+ install_requires=[
26
+ "numpy == 2.*",
27
+ "pandas == 2.*",
28
+ "tqdm == 4.*",
29
+ "Pillow == 11.*",
30
+ "click==8.*",
31
+ "PyYAML == 6.*",
32
+ "retinalysis-fundusprep",
33
+ "retinalysis-inference",
34
+ ],
35
+ python_requires=">=3.10, <3.13",
36
+ )
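The `console_scripts` entry `vascx = vascx_models.cli:cli` above binds the `vascx` command to the `cli` object in `vascx_models/cli.py`, and `run.sh` reaches the same object through `python -c`. As a simplified illustration of how a `module:attribute` string of this kind maps to a Python object, consider the sketch below. It is not what packaging tools literally execute, and `resolve_entry_point` is a hypothetical name.

```python
import importlib
from typing import Any


def resolve_entry_point(spec: str) -> Any:
    """Resolve a 'module.path:attribute' string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    obj = importlib.import_module(module_name.strip())
    # Walk dotted attributes, e.g. "pkg.mod:Class.method".
    for attr in filter(None, attr_path.strip().split(".")):
        obj = getattr(obj, attr)
    return obj
```

With the package installed, `resolve_entry_point("vascx_models.cli:cli")` would be expected to return the click group that the `vascx` command invokes.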
vascx_models/__init__.py ADDED
File without changes
vascx_models/cli.py ADDED
@@ -0,0 +1,259 @@
+import logging
+import warnings
+from pathlib import Path
+
+import click
+import pandas as pd
+
+from rtnls_fundusprep.cli import _run_preprocessing
+
+from .config import load_app_config
+from .disc_rings import generate_disc_rings
+from .inference import (
+    preferred_device,
+    run_fovea_detection,
+    run_quality_estimation,
+    run_segmentation_disc,
+    run_segmentation_vessels_and_av,
+)
+from .utils import batch_create_overlays
+
+logger = logging.getLogger(__name__)
+
+
+def configure_logging() -> None:
+    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
+    warnings.filterwarnings(
+        "ignore",
+        message=(
+            "Using a non-tuple sequence for multidimensional indexing is deprecated "
+            "and will be changed in pytorch 2.9; use x\\[tuple\\(seq\\)\\] instead of x\\[seq\\].*"
+        ),
+        category=UserWarning,
+        module=r"monai\.inferers\.utils",
+    )
+
+
+@click.group(name="vascx")
+def cli():
+    configure_logging()
+
+
+@cli.command()
+@click.argument("data_path", type=click.Path(exists=True))
+@click.argument("output_path", type=click.Path())
+@click.option(
+    "--config",
+    "config_path",
+    type=click.Path(exists=True, dir_okay=False, path_type=Path),
+    default=None,
+    help="Path to a YAML config file. Defaults to ./config.yaml or the repo-root config.yaml when present.",
+)
+@click.option(
+    "--preprocess/--no-preprocess",
+    default=True,
+    help="Run preprocessing or use preprocessed images",
+)
+@click.option(
+    "--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
+)
+@click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
+@click.option(
+    "--quality/--no-quality", default=True, help="Run image quality estimation"
+)
+@click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
+@click.option(
+    "--overlay/--no-overlay",
+    default=None,
+    help="Create visualization overlays. Defaults to the config value when set.",
+)
+@click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
+def run(
+    data_path,
+    output_path,
+    config_path,
+    preprocess,
+    vessels,
+    disc,
+    quality,
+    fovea,
+    overlay,
+    n_jobs,
+):
+    """Run the complete inference pipeline on fundus images.
+
+    DATA_PATH is either a directory containing images or a CSV file with 'path' column.
+    OUTPUT_PATH is the directory where results will be stored.
+    """
+
+    output_path = Path(output_path)
+    output_path.mkdir(exist_ok=True, parents=True)
+    try:
+        app_config = load_app_config(config_path)
+    except (FileNotFoundError, ValueError) as exc:
+        raise click.ClickException(str(exc)) from exc
+    overlay_enabled = app_config.overlay.enabled if overlay is None else overlay
+    if app_config.source_path is not None:
+        logger.info("Loaded config from %s", app_config.source_path)
+
+    # Setup output directories
+    preprocess_rgb_path = output_path / "preprocessed_rgb"
+    vessels_path = output_path / "vessels"
+    av_path = output_path / "artery_vein"
+    disc_path = output_path / "disc"
+    disc_ring_2r_path = output_path / "disc_ring_2r"
+    disc_ring_3r_path = output_path / "disc_ring_3r"
+    overlay_path = output_path / "overlays"
+
+    # Create required directories
+    if preprocess:
+        preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
+    if vessels:
+        av_path.mkdir(exist_ok=True, parents=True)
+        vessels_path.mkdir(exist_ok=True, parents=True)
+    if disc:
+        disc_path.mkdir(exist_ok=True, parents=True)
+        disc_ring_2r_path.mkdir(exist_ok=True, parents=True)
+        disc_ring_3r_path.mkdir(exist_ok=True, parents=True)
+    if overlay_enabled:
+        overlay_path.mkdir(exist_ok=True, parents=True)
+
+    bounds_path = output_path / "bounds.csv" if preprocess else None
+    quality_path = output_path / "quality.csv" if quality else None
+    fovea_path = output_path / "fovea.csv" if fovea else None
+    disc_geometry_path = output_path / "disc_geometry.csv" if disc else None
+
+    # Determine if input is a folder or CSV file
+    data_path = Path(data_path)
+    is_csv = data_path.suffix.lower() == ".csv"
+
+    # Get files to process
+    files = []
+    ids = None
+
+    if is_csv:
+        logger.info("Reading file paths from CSV: %s", data_path)
+        try:
+            df = pd.read_csv(data_path)
+            if "path" not in df.columns:
+                logger.error("CSV must contain a 'path' column")
+                return
+
+            # Get file paths and convert to Path objects
+            files = [Path(p) for p in df["path"]]
+
+            if "id" in df.columns:
+                ids = df["id"].tolist()
+                logger.info("Using IDs from CSV 'id' column")
+
+        except Exception as e:
+            logger.exception("Error reading CSV file: %s", e)
+            return
+    else:
+        logger.info("Finding files in directory: %s", data_path)
+        files = list(data_path.glob("*"))
+        ids = [f.stem for f in files]
+
+    if not files:
+        logger.warning("No files found to process")
+        return
+
+    logger.info("Found %d files to process", len(files))
+
+    # Step 1: Preprocess images if requested
+    if preprocess:
+        logger.info("Running preprocessing")
+        _run_preprocessing(
+            files=files,
+            ids=ids,
+            rgb_path=preprocess_rgb_path,
+            bounds_path=bounds_path,
+            n_jobs=n_jobs,
+        )
+        # Use the preprocessed images for subsequent steps
+        preprocessed_files = list(preprocess_rgb_path.glob("*.png"))
+    else:
+        # Use the input files directly
+        preprocessed_files = files
+    ids = [f.stem for f in preprocessed_files]
+    logger.info("Prepared %d images for inference", len(preprocessed_files))
+
+    # Prefer hardware acceleration when the active torch build supports it.
+    device = preferred_device()
+    logger.info("Using device: %s", device)
+
+    # Step 2: Run quality estimation if requested
+    if quality:
+        logger.info("Running quality estimation")
+        df_quality = run_quality_estimation(
+            fpaths=preprocessed_files, ids=ids, device=device
+        )
+        df_quality.to_csv(quality_path)
+        logger.info("Quality results saved to %s", quality_path)
+
+    # Step 3: Run vessels and AV segmentation if requested
+    if vessels:
+        logger.info("Running vessels and AV segmentation")
+        run_segmentation_vessels_and_av(
+            rgb_paths=preprocessed_files,
+            ids=ids,
+            av_path=av_path,
+            vessels_path=vessels_path,
+            device=device,
+        )
+        logger.info("Vessel segmentation saved to %s", vessels_path)
+        logger.info("AV segmentation saved to %s", av_path)
+
+    # Step 4: Run optic disc segmentation if requested
+    if disc:
+        logger.info("Running optic disc segmentation")
+        run_segmentation_disc(
+            rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
+        )
+        logger.info("Disc segmentation saved to %s", disc_path)
+        generate_disc_rings(
+            disc_dir=disc_path,
+            ring_2r_dir=disc_ring_2r_path,
+            ring_3r_dir=disc_ring_3r_path,
+            measurements_path=disc_geometry_path,
+        )
+        logger.info("2r disc rings saved to %s", disc_ring_2r_path)
+        logger.info("3r disc rings saved to %s", disc_ring_3r_path)
+
+    # Step 5: Run fovea detection if requested
+    df_fovea = None
+    if fovea:
+        logger.info("Running fovea detection")
+        df_fovea = run_fovea_detection(
+            rgb_paths=preprocessed_files, ids=ids, device=device
+        )
+        df_fovea.to_csv(fovea_path)
+        logger.info("Fovea detection results saved to %s", fovea_path)
+
+    # Step 6: Create overlays if requested
+    if overlay_enabled:
+        logger.info("Creating visualization overlays")
+
+        # Prepare fovea data if available
+        fovea_data = None
+        if df_fovea is not None:
+            fovea_data = {
+                idx: (row["x_fovea"], row["y_fovea"])
+                for idx, row in df_fovea.iterrows()
+            }
+
+        # Create visualization overlays
+        batch_create_overlays(
+            rgb_dir=preprocess_rgb_path if preprocess else data_path,
+            output_dir=overlay_path,
+            av_dir=av_path,
+            disc_dir=disc_path,
+            ring_2r_dir=disc_ring_2r_path,
+            ring_3r_dir=disc_ring_3r_path,
+            fovea_data=fovea_data,
+            overlay_config=app_config.overlay,
+        )
+
+        logger.info("Visualization overlays saved to %s", overlay_path)
+
+    logger.info("All requested processing complete. Results saved to %s", output_path)
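The `run` command accepts either a directory or a CSV file with a `path` column and an optional `id` column. As a minimal sketch of preparing such a CSV, the helper below uses only the standard library. The name `write_manifest` and the choice to use each file's stem as its ID are illustrative assumptions, not part of this commit.

```python
import csv
from pathlib import Path


def write_manifest(image_dir: Path, csv_path: Path) -> int:
    """Write a CSV with 'path' and 'id' columns for every file in image_dir.

    Hypothetical helper: the column names follow what cli.py reads,
    everything else is an assumption made for illustration.
    """
    # Sort for a deterministic row order.
    files = sorted(p for p in image_dir.iterdir() if p.is_file())
    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["path", "id"])
        for path in files:
            # Use the file stem as the ID, mirroring the directory-input branch.
            writer.writerow([str(path), path.stem])
    return len(files)
```

The resulting file could then be passed as `DATA_PATH`, for example `vascx run manifest.csv results/`, where `manifest.csv` and `results/` are placeholder names.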
vascx_models/config.py ADDED
@@ -0,0 +1,196 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Iterable, Mapping
+
+import yaml
+
+
+DEFAULT_CONFIG_NAME = "config.yaml"
+
+
+def _repo_root() -> Path:
+    return Path(__file__).resolve().parent.parent
+
+
+@dataclass(frozen=True)
+class OverlayLayers:
+    arteries: bool = True
+    veins: bool = True
+    disc: bool = True
+    ring_2r: bool = True
+    ring_3r: bool = True
+    fovea: bool = True
+
+
+@dataclass(frozen=True)
+class OverlayColors:
+    artery: tuple[int, int, int] = (255, 0, 0)
+    vein: tuple[int, int, int] = (0, 0, 255)
+    disc: tuple[int, int, int] = (255, 255, 255)
+    ring_2r: tuple[int, int, int] = (0, 255, 0)
+    ring_3r: tuple[int, int, int] = (255, 0, 255)
+    fovea: tuple[int, int, int] = (255, 255, 0)
+
+
+@dataclass(frozen=True)
+class OverlayConfig:
+    enabled: bool = True
+    layers: OverlayLayers = field(default_factory=OverlayLayers)
+    colors: OverlayColors = field(default_factory=OverlayColors)
+
+
+@dataclass(frozen=True)
+class AppConfig:
+    overlay: OverlayConfig = field(default_factory=OverlayConfig)
+    source_path: Path | None = None
+
+
+def default_config_candidates() -> list[Path]:
+    candidates = [Path.cwd() / DEFAULT_CONFIG_NAME, _repo_root() / DEFAULT_CONFIG_NAME]
+    unique_candidates: list[Path] = []
+    seen: set[Path] = set()
+    for candidate in candidates:
+        resolved = candidate.resolve()
+        if resolved not in seen:
+            unique_candidates.append(candidate)
+            seen.add(resolved)
+    return unique_candidates
+
+
+def resolve_config_path(config_path: str | Path | None) -> Path | None:
+    if config_path is not None:
+        candidate = Path(config_path).expanduser()
+        if not candidate.exists():
+            raise FileNotFoundError(f"Config file not found: {candidate}")
+        return candidate
+
+    for candidate in default_config_candidates():
+        if candidate.exists():
+            return candidate
+    return None
+
+
+def load_app_config(config_path: str | Path | None = None) -> AppConfig:
+    resolved_path = resolve_config_path(config_path)
+    if resolved_path is None:
+        return AppConfig()
+
+    with resolved_path.open("r", encoding="utf-8") as handle:
+        raw_config = yaml.safe_load(handle) or {}
+
+    if not isinstance(raw_config, dict):
+        raise ValueError("Config root must be a mapping")
+
+    overlay_raw = raw_config.get("overlay", {})
+    if overlay_raw is None:
+        overlay_raw = {}
+    if not isinstance(overlay_raw, dict):
+        raise ValueError("'overlay' must be a mapping")
+
+    layer_overrides = overlay_raw.get("layers", {})
+    if layer_overrides is None:
+        layer_overrides = {}
+    if not isinstance(layer_overrides, dict):
+        raise ValueError("'overlay.layers' must be a mapping")
+
+    color_overrides = overlay_raw.get("colors", overlay_raw.get("colours", {}))
+    if color_overrides is None:
+        color_overrides = {}
+    if not isinstance(color_overrides, dict):
+        raise ValueError("'overlay.colors' must be a mapping")
+
+    return AppConfig(
+        overlay=OverlayConfig(
+            enabled=_coerce_bool(overlay_raw.get("enabled", True), "overlay.enabled"),
+            layers=_build_overlay_layers(layer_overrides),
+            colors=_build_overlay_colors(color_overrides),
+        ),
+        source_path=resolved_path,
+    )
+
+
+def _build_overlay_layers(raw_layers: Mapping[str, object]) -> OverlayLayers:
+    defaults = OverlayLayers()
+    alias_map = {
+        "artery": "arteries",
+        "arteries": "arteries",
+        "vein": "veins",
+        "veins": "veins",
+        "disc": "disc",
+        "ring_2r": "ring_2r",
+        "disc_ring_2r": "ring_2r",
+        "ring_3r": "ring_3r",
+        "disc_ring_3r": "ring_3r",
+        "fovea": "fovea",
+    }
+    values = defaults.__dict__.copy()
+    for raw_key, raw_value in raw_layers.items():
+        if raw_key not in alias_map:
+            raise ValueError(f"Unsupported overlay layer '{raw_key}'")
+        normalized_key = alias_map[raw_key]
+        values[normalized_key] = _coerce_bool(
+            raw_value, f"overlay.layers.{raw_key}"
+        )
+    return OverlayLayers(**values)
+
+
+def _build_overlay_colors(raw_colors: Mapping[str, object]) -> OverlayColors:
+    defaults = OverlayColors()
+    alias_map = {
+        "artery": "artery",
+        "arteries": "artery",
+        "vein": "vein",
+        "veins": "vein",
+        "disc": "disc",
+        "ring_2r": "ring_2r",
+        "disc_ring_2r": "ring_2r",
+        "ring_3r": "ring_3r",
+        "disc_ring_3r": "ring_3r",
+        "fovea": "fovea",
+    }
+    values = defaults.__dict__.copy()
+    for raw_key, raw_value in raw_colors.items():
+        if raw_key not in alias_map:
+            raise ValueError(f"Unsupported overlay color '{raw_key}'")
+        normalized_key = alias_map[raw_key]
+        values[normalized_key] = _parse_rgb(raw_value, f"overlay.colors.{raw_key}")
+    return OverlayColors(**values)
+
+
+def _coerce_bool(value: object, field_name: str) -> bool:
+    if isinstance(value, bool):
+        return value
+    raise ValueError(f"'{field_name}' must be a boolean")
+
+
+def _parse_rgb(value: object, field_name: str) -> tuple[int, int, int]:
+    if isinstance(value, str):
+        return _parse_hex_color(value, field_name)
+    if isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)):
+        channels = tuple(value)
+        if len(channels) != 3:
+            raise ValueError(f"'{field_name}' must contain exactly 3 channels")
+        return tuple(_coerce_channel(channel, field_name) for channel in channels)
+    raise ValueError(
+        f"'{field_name}' must be a '#RRGGBB' string or a 3-item RGB sequence"
+    )
+
+
+def _parse_hex_color(value: str, field_name: str) -> tuple[int, int, int]:
+    normalized = value.strip()
+    if normalized.startswith("#"):
+        normalized = normalized[1:]
+    if len(normalized) != 6:
+        raise ValueError(f"'{field_name}' must be a 6-digit hex color")
+    try:
+        return tuple(int(normalized[index : index + 2], 16) for index in (0, 2, 4))
+    except ValueError as exc:
+        raise ValueError(f"'{field_name}' must be a valid hex color") from exc
+
+
+def _coerce_channel(value: object, field_name: str) -> int:
+    if isinstance(value, int) and 0 <= value <= 255:
+        return value
+    raise ValueError(f"'{field_name}' channels must be integers between 0 and 255")
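The color handling in `config.py` accepts either a `#RRGGBB` string or a three-item sequence of integers in the range 0 to 255. The standalone sketch below restates that rule in one function so the accepted inputs are easy to try out. The name `parse_rgb` is a hypothetical stand-in for illustration and omits the `field_name` argument and error wording used in the commit.

```python
from typing import Sequence, Tuple, Union

RGB = Tuple[int, int, int]


def parse_rgb(value: Union[str, Sequence[int]]) -> RGB:
    """Parse '#RRGGBB' (leading '#' optional) or a 3-item sequence of 0..255 ints."""
    if isinstance(value, str):
        digits = value.strip()
        if digits.startswith("#"):
            digits = digits[1:]
        if len(digits) != 6:
            raise ValueError("expected a 6-digit hex color")
        # int(..., 16) raises ValueError on non-hex characters.
        return tuple(int(digits[i : i + 2], 16) for i in (0, 2, 4))

    channels = tuple(value)
    if len(channels) != 3:
        raise ValueError("expected exactly 3 channels")
    for channel in channels:
        if not (isinstance(channel, int) and 0 <= channel <= 255):
            raise ValueError("channels must be integers between 0 and 255")
    return channels
```

For example, `parse_rgb("#FF8000")` and `parse_rgb([255, 128, 0])` describe the same color, while `parse_rgb("#FFF")` raises `ValueError` because three-digit shorthand is not accepted.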
vascx_models/disc_rings.py ADDED
@@ -0,0 +1,118 @@
+import logging
+from pathlib import Path
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+from PIL import Image, ImageDraw
+
+logger = logging.getLogger(__name__)
+
+
+def estimate_disc_geometry(
+    disc_mask: np.ndarray,
+) -> Optional[Tuple[float, float, float]]:
+    """Estimate optic disc center and radius from a binary mask."""
+    mask = disc_mask > 0
+    if not np.any(mask):
+        return None
+
+    ys, xs = np.nonzero(mask)
+    center_x = float(xs.mean())
+    center_y = float(ys.mean())
+
+    # Use the equivalent-circle radius so the estimate is stable for irregular masks.
+    radius = float(np.sqrt(mask.sum() / np.pi))
+    return center_x, center_y, radius
+
+
+def create_ring_mask(
+    image_shape: Tuple[int, int],
+    center: Tuple[float, float],
+    radius: float,
+    thickness: int,
+) -> np.ndarray:
+    """Create a binary ring mask for a circle outline."""
+    height, width = image_shape
+    ring = Image.new("L", (width, height), 0)
+    draw = ImageDraw.Draw(ring)
+
+    center_x, center_y = center
+    bbox = (
+        center_x - radius,
+        center_y - radius,
+        center_x + radius,
+        center_y + radius,
+    )
+    draw.ellipse(bbox, outline=255, width=thickness)
+    return np.array(ring, dtype=np.uint8)
+
+
+def generate_disc_rings(
+    disc_dir: Path,
+    ring_2r_dir: Path,
+    ring_3r_dir: Path,
+    measurements_path: Optional[Path] = None,
+) -> pd.DataFrame:
+    """Generate 2r and 3r optic-disc ring masks from saved disc segmentations."""
+    ring_2r_dir.mkdir(exist_ok=True, parents=True)
+    ring_3r_dir.mkdir(exist_ok=True, parents=True)
+
+    disc_files = list(disc_dir.glob("*.png"))
+    if not disc_files:
+        logger.warning("No disc masks found for ring generation in %s", disc_dir)
+        columns = [
+            "x_disc_center",
+            "y_disc_center",
+            "disc_radius_px",
+            "ring_2r_px",
+            "ring_3r_px",
+        ]
+        return pd.DataFrame(columns=columns)
+
+    records: Dict[str, Dict[str, float]] = {}
+    logger.info("Generating 2r and 3r rings for %d disc masks", len(disc_files))
+
+    for disc_file in disc_files:
+        image_id = disc_file.stem
+        disc_mask = np.array(Image.open(disc_file)) > 0
+        geometry = estimate_disc_geometry(disc_mask)
+
+        if geometry is None:
+            logger.warning("Disc mask is empty for %s; writing blank ring masks", image_id)
+            blank = np.zeros(disc_mask.shape, dtype=np.uint8)
+            Image.fromarray(blank).save(ring_2r_dir / f"{image_id}.png")
+            Image.fromarray(blank).save(ring_3r_dir / f"{image_id}.png")
+            records[image_id] = {
+                "x_disc_center": np.nan,
+                "y_disc_center": np.nan,
+                "disc_radius_px": np.nan,
+                "ring_2r_px": np.nan,
+                "ring_3r_px": np.nan,
+            }
+            continue
+
+        center_x, center_y, disc_radius = geometry
+        line_width = max(1, int(round(disc_radius * 0.08)))
+        ring_2r = create_ring_mask(
+            disc_mask.shape, (center_x, center_y), radius=disc_radius * 2.0, thickness=line_width
+        )
+        ring_3r = create_ring_mask(
+            disc_mask.shape, (center_x, center_y), radius=disc_radius * 3.0, thickness=line_width
+        )
+
+        Image.fromarray(ring_2r).save(ring_2r_dir / f"{image_id}.png")
+        Image.fromarray(ring_3r).save(ring_3r_dir / f"{image_id}.png")
+        records[image_id] = {
+            "x_disc_center": center_x,
+            "y_disc_center": center_y,
+            "disc_radius_px": disc_radius,
+            "ring_2r_px": disc_radius * 2.0,
+            "ring_3r_px": disc_radius * 3.0,
+        }
+
+    df_measurements = pd.DataFrame.from_dict(records, orient="index")
+    if measurements_path is not None:
+        df_measurements.to_csv(measurements_path)
+        logger.info("Disc ring measurements saved to %s", measurements_path)
+    return df_measurements
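`estimate_disc_geometry` takes the disc center as the centroid of the mask pixels and the radius as that of a circle with the same area, r = sqrt(area / pi). The sketch below restates the idea in plain Python on nested lists, without NumPy, so it can be checked on a synthetic disc. The names `estimate_geometry` and `filled_disc` are hypothetical helpers for illustration and are not part of this commit.

```python
import math
from typing import List, Optional, Tuple


def estimate_geometry(mask: List[List[int]]) -> Optional[Tuple[float, float, float]]:
    """Return (center_x, center_y, radius) for the nonzero pixels, or None if empty."""
    xs, ys = [], []
    for y, row in enumerate(mask):
        for x, value in enumerate(row):
            if value > 0:
                xs.append(x)
                ys.append(y)
    if not xs:
        return None
    area = len(xs)  # pixel count stands in for the mask area
    # Equivalent-circle radius: the radius of a circle with the same area.
    radius = math.sqrt(area / math.pi)
    return sum(xs) / area, sum(ys) / area, radius


def filled_disc(size: int, cx: int, cy: int, r: int) -> List[List[int]]:
    """Rasterize a filled circle into a size x size grid of 0/1 values."""
    return [
        [1 if (x - cx) ** 2 + (y - cy) ** 2 <= r * r else 0 for x in range(size)]
        for y in range(size)
    ]


geometry = estimate_geometry(filled_disc(101, 50, 50, 20))
```

On this synthetic disc the centroid falls on the true center by symmetry, and the recovered radius is close to 20 pixels. The 2r and 3r rings in the commit are then circles of twice and three times that radius around the same center.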
vascx_models/inference.py ADDED
@@ -0,0 +1,292 @@
+import logging
+import os
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Optional
+
+import numpy as np
+import pandas as pd
+import torch
+from PIL import Image
+from tqdm import tqdm
+
+from rtnls_inference.ensembles.ensemble_classification import ClassificationEnsemble
+from rtnls_inference.ensembles.ensemble_heatmap_regression import (
+    HeatmapRegressionEnsemble,
+)
+from rtnls_inference.ensembles.ensemble_segmentation import SegmentationEnsemble
+from rtnls_inference.utils import decollate_batch, extract_keypoints_from_heatmaps
+
+logger = logging.getLogger(__name__)
+
+
+def preferred_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda:0")
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def _inference_num_workers(device: torch.device) -> int:
+    # Torch shared-memory workers can fail in restricted CPU environments.
+    return 8 if device.type in {"cuda", "mps"} else 0
+
+
+def _autocast_context(device: torch.device):
+    return torch.autocast(device_type=device.type) if device.type == "cuda" else nullcontext()
+
+
+def run_quality_estimation(fpaths, ids, device: torch.device):
+    logger.info("Loading quality model on %s", device)
+    ensemble_quality = ClassificationEnsemble.from_release("quality.pt").to(device)
+    dataloader = ensemble_quality._make_inference_dataloader(
+        fpaths,
+        ids=ids,
+        num_workers=_inference_num_workers(device),
+        preprocess=False,
+        batch_size=16,
+    )
+    logger.info("Quality dataloader ready with %d images", len(fpaths))
+
+    output_ids, outputs = [], []
+    with torch.no_grad():
+        for batch in tqdm(dataloader):
+            if len(batch) == 0:
+                continue
+
+            im = batch["image"].to(device)
+
+            # QUALITY
+            quality = ensemble_quality.predict_step(im)
+            quality = torch.mean(quality, dim=0)
+
+            items = {"id": batch["id"], "quality": quality}
+            items = decollate_batch(items)
+
+            for item in items:
+                output_ids.append(item["id"])
+                outputs.append(item["quality"].tolist())
+
+    return pd.DataFrame(
+        outputs,
+        index=output_ids,
+        columns=["q1", "q2", "q3"],
+    )
+
+
+def run_segmentation_vessels_and_av(
+    rgb_paths: List[Path],
+    ce_paths: Optional[List[Path]] = None,
+    ids: Optional[List[str]] = None,
+    av_path: Optional[Path] = None,
+    vessels_path: Optional[Path] = None,
+    device: torch.device = preferred_device(),
+) -> None:
+    """
+    Run AV and vessel segmentation on the provided images.
+
+    Args:
+        rgb_paths: List of paths to RGB fundus images
+        ce_paths: Optional list of paths to contrast enhanced images
+        ids: Optional list of ids to pass to _make_inference_dataloader
+        av_path: Folder where to store output AV segmentations
+        vessels_path: Folder where to store output vessel segmentations
+        device: Device to run inference on
+    """
+    # Create output directories if they don't exist
+    if av_path is not None:
+        av_path.mkdir(exist_ok=True, parents=True)
+    if vessels_path is not None:
+        vessels_path.mkdir(exist_ok=True, parents=True)
+
+    # Load models
+    logger.info("Loading AV and vessel models on %s", device)
+    ensemble_av = SegmentationEnsemble.from_release("av_july24.pt").to(device).eval()
+    ensemble_vessels = (
+        SegmentationEnsemble.from_release("vessels_july24.pt").to(device).eval()
+    )
+
+    # Prepare input paths
+    if ce_paths is None:
+        # If CE paths are not provided, use RGB paths for both inputs
+        fpaths = rgb_paths
+    else:
+        # If CE paths are provided, pair them with RGB paths
+        if len(rgb_paths) != len(ce_paths):
+            raise ValueError("rgb_paths and ce_paths must have the same length")
+        fpaths = list(zip(rgb_paths, ce_paths))
+
+    # Create dataloader
+    dataloader = ensemble_av._make_inference_dataloader(
+        fpaths,
+        ids=ids,
+        num_workers=_inference_num_workers(device),
+        preprocess=False,
+        batch_size=8,
+    )
+    logger.info("AV and vessel dataloader ready with %d images", len(fpaths))
+
+    # Run inference
+    with torch.no_grad():
+        for batch in tqdm(dataloader):
+            # AV segmentation
+            if av_path is not None:
+                with _autocast_context(device):
+                    proba = ensemble_av.forward(batch["image"].to(device))
+                proba = torch.mean(proba, dim=1)  # average over models
+                proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
+                proba = torch.nn.functional.softmax(proba, dim=-1)
+
+                items = {
+                    "id": batch["id"],
+                    "image": proba,
+                }
+
+                items = decollate_batch(items)
+                for i, item in enumerate(items):
+                    fpath = os.path.join(av_path, f"{item['id']}.png")
+                    mask = np.argmax(item["image"], -1)
+                    Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
+
+            # Vessel segmentation
+            if vessels_path is not None:
+                with _autocast_context(device):
+                    proba = ensemble_vessels.forward(batch["image"].to(device))
+                proba = torch.mean(proba, dim=1)  # average over models
+                proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
+                proba = torch.nn.functional.softmax(proba, dim=-1)
+
+                items = {
+                    "id": batch["id"],
+                    "image": proba,
+                }
+
+                items = decollate_batch(items)
+                for i, item in enumerate(items):
+                    fpath = os.path.join(vessels_path, f"{item['id']}.png")
+                    mask = np.argmax(item["image"], -1)
+                    Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
+
+
+def run_segmentation_disc(
+    rgb_paths: List[Path],
+    ce_paths: Optional[List[Path]] = None,
+    ids: Optional[List[str]] = None,
+    output_path: Optional[Path] = None,
+    device: torch.device = preferred_device(),
+) -> None:
+    logger.info("Loading disc model on %s", device)
+    ensemble_disc = (
+        SegmentationEnsemble.from_release("disc_july24.pt").to(device).eval()
+    )
+
+    # Prepare input paths
+    if ce_paths is None:
+        # If CE paths are not provided, use RGB paths for both inputs
+        fpaths = rgb_paths
+    else:
+        # If CE paths are provided, pair them with RGB paths
+        if len(rgb_paths) != len(ce_paths):
+            raise ValueError("rgb_paths and ce_paths must have the same length")
+        fpaths = list(zip(rgb_paths, ce_paths))
+
+    dataloader = ensemble_disc._make_inference_dataloader(
+        fpaths,
+        ids=ids,
+        num_workers=_inference_num_workers(device),
+        preprocess=False,
+        batch_size=8,
+    )
+    logger.info("Disc dataloader ready with %d images", len(fpaths))
+
+    with torch.no_grad():
+        for batch in tqdm(dataloader):
+            # AV
+            with _autocast_context(device):
+                proba = ensemble_disc.forward(batch["image"].to(device))
+            proba = torch.mean(proba, dim=1)  # average over models
+            proba = torch.permute(proba, (0, 2, 3, 1))  # NCHW -> NHWC
+            proba = torch.nn.functional.softmax(proba, dim=-1)
+
+            items = {
+                "id": batch["id"],
+                "image": proba,
+            }
+
+            items = decollate_batch(items)
+            items = [dataloader.dataset.transform.undo_item(item) for item in items]
+            for i, item in enumerate(items):
+                fpath = os.path.join(output_path, f"{item['id']}.png")
+
+                mask = np.argmax(item["image"], -1)
+                Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
+
+
+def run_fovea_detection(
+    rgb_paths: List[Path],
+    ce_paths: Optional[List[Path]] = None,
+    ids: Optional[List[str]] = None,
+    device: torch.device = preferred_device(),
+) -> None:
+    # def run_fovea_detection(fpaths, ids, device: torch.device):
+    logger.info("Loading fovea model on %s", device)
+    ensemble_fovea = HeatmapRegressionEnsemble.from_release("fovea_july24.pt").to(
+        device
+    )
+
+    # Prepare input paths
+    if ce_paths is None:
+        # If CE paths are not provided, use RGB paths for both inputs
+        fpaths = rgb_paths
+    else:
+        # If CE paths are provided, pair them with RGB paths
+        if len(rgb_paths) != len(ce_paths):
+            raise ValueError("rgb_paths and ce_paths must have the same length")
+        fpaths = list(zip(rgb_paths, ce_paths))
+
+    dataloader = ensemble_fovea._make_inference_dataloader(
+        fpaths,
+        ids=ids,
+        num_workers=_inference_num_workers(device),
+        preprocess=False,
+        batch_size=8,
+    )
+    logger.info("Fovea dataloader ready with %d images", len(fpaths))
+
+    output_ids, outputs = [], []
+    with torch.no_grad():
+        for batch in tqdm(dataloader):
+            if len(batch) == 0:
+                continue
+
+            im = batch["image"].to(device)
+
+            # FOVEA DETECTION
+            with _autocast_context(device):
+                heatmap = ensemble_fovea.forward(im)
+            keypoints = extract_keypoints_from_heatmaps(heatmap)
+
+            kp_fovea = torch.mean(keypoints, dim=1)  # average over models
+
+            items = {
+                "id": batch["id"],
+                "keypoints": kp_fovea,
+                "metadata": batch["metadata"],
+            }
+            items = decollate_batch(items)
+
+            items = [dataloader.dataset.transform.undo_item(item) for item in items]
+
+            for item in items:
+                output_ids.append(item["id"])
+                outputs.append(
+                    [
+                        *item["keypoints"][0].tolist(),
+                    ]
+                )
+    return pd.DataFrame(
+        outputs,
+        index=output_ids,
+        columns=["x_fovea", "y_fovea"],
+    )
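The segmentation functions in `inference.py` average the per-model outputs of an ensemble, apply a softmax over the class axis, and take the argmax as the class label of each pixel. As a rough, dependency-free sketch of that reduction for a single pixel, the snippet below works on plain lists instead of tensors. The function names are hypothetical, and treating the ensemble outputs as logits is an assumption here.

```python
import math
from typing import List, Tuple


def mean_over_models(logits_per_model: List[List[float]]) -> List[float]:
    """Average class scores across ensemble members (one inner list per model)."""
    n_models = len(logits_per_model)
    return [sum(scores) / n_models for scores in zip(*logits_per_model)]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of class scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_class(logits_per_model: List[List[float]]) -> Tuple[int, List[float]]:
    """Return (class index, class probabilities) for one pixel."""
    probs = softmax(mean_over_models(logits_per_model))
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs


# Two hypothetical models scoring three classes for one pixel.
label, probs = predict_class([[2.0, 0.0, 1.0], [0.0, 3.0, 1.0]])
```

Because softmax preserves the ordering of its inputs, the argmax of the probabilities is the same as the argmax of the averaged scores. The softmax matters only when the probabilities themselves are needed.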
vascx_models/utils.py ADDED
@@ -0,0 +1,196 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw
7
+
8
+ from .config import OverlayConfig
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
+ def create_fundus_overlay(
+     rgb_path: str,
+     av_path: Optional[str] = None,
+     disc_path: Optional[str] = None,
+     ring_2r_path: Optional[str] = None,
+     ring_3r_path: Optional[str] = None,
+     fovea_location: Optional[Tuple[int, int]] = None,
+     output_path: Optional[str] = None,
+     overlay_config: Optional[OverlayConfig] = None,
+ ) -> np.ndarray:
+     """
+     Create a visualization of a fundus image with overlaid segmentations and markers.
+
+     Args:
+         rgb_path: Path to the RGB fundus image
+         av_path: Optional path to artery-vein segmentation (1=artery, 2=vein, 3=intersection)
+         disc_path: Optional path to binary disc segmentation
+         ring_2r_path: Optional path to a binary 2r ring mask
+         ring_3r_path: Optional path to a binary 3r ring mask
+         fovea_location: Optional (x,y) tuple indicating the location of the fovea
+         output_path: Optional path to save the visualization image
+         overlay_config: Overlay display configuration including enabled layers and colors
+
+     Returns:
+         Numpy array containing the visualization image
+     """
+     overlay_config = overlay_config or OverlayConfig()
+
+     # Load RGB image
+     rgb_img = np.array(Image.open(rgb_path))
+
+     # Create output image starting with the RGB image
+     output_img = rgb_img.copy()
+
+     # Load and overlay AV segmentation if provided
+     if av_path and (overlay_config.layers.arteries or overlay_config.layers.veins):
+         av_mask = np.array(Image.open(av_path))
+
+         # Create masks for arteries (1), veins (2) and intersections (3)
+         artery_mask = av_mask == 1
+         vein_mask = av_mask == 2
+         intersection_mask = av_mask == 3
+
+         if overlay_config.layers.arteries:
+             artery_combined = np.logical_or(artery_mask, intersection_mask)
+             output_img[artery_combined, :] = overlay_config.colors.artery
+
+         if overlay_config.layers.veins:
+             vein_combined = np.logical_or(vein_mask, intersection_mask)
+             output_img[vein_combined, :] = overlay_config.colors.vein
+
+     # Load and overlay optic disc segmentation if provided
+     if disc_path and overlay_config.layers.disc:
+         disc_mask = np.array(Image.open(disc_path)) > 0
+         output_img[disc_mask, :] = overlay_config.colors.disc
+
+     if ring_2r_path and overlay_config.layers.ring_2r:
+         ring_2r_mask = np.array(Image.open(ring_2r_path)) > 0
+         output_img[ring_2r_mask, :] = overlay_config.colors.ring_2r
+
+     if ring_3r_path and overlay_config.layers.ring_3r:
+         ring_3r_mask = np.array(Image.open(ring_3r_path)) > 0
+         output_img[ring_3r_mask, :] = overlay_config.colors.ring_3r
+
+     # Convert to PIL image for drawing the fovea marker
+     pil_img = Image.fromarray(output_img)
+
+     # Add fovea marker if provided
+     if fovea_location and overlay_config.layers.fovea:
+         draw = ImageDraw.Draw(pil_img)
+         x, y = fovea_location
+         marker_size = (
+             min(pil_img.width, pil_img.height) // 50
+         )  # Scale marker with image
+
+         # Draw an X at the fovea location in the configured fovea color
+         draw.line(
+             [(x - marker_size, y - marker_size), (x + marker_size, y + marker_size)],
+             fill=overlay_config.colors.fovea,
+             width=2,
+         )
+         draw.line(
+             [(x - marker_size, y + marker_size), (x + marker_size, y - marker_size)],
+             fill=overlay_config.colors.fovea,
+             width=2,
+         )
+
+     # Convert back to numpy array
+     output_img = np.array(pil_img)
+
+     # Save output if path provided
+     if output_path:
+         Image.fromarray(output_img).save(output_path)
+
+     return output_img
+
+
+ def batch_create_overlays(
+     rgb_dir: Path,
+     output_dir: Path,
+     av_dir: Optional[Path] = None,
+     disc_dir: Optional[Path] = None,
+     ring_2r_dir: Optional[Path] = None,
+     ring_3r_dir: Optional[Path] = None,
+     fovea_data: Optional[Dict[str, Tuple[int, int]]] = None,
+     overlay_config: Optional[OverlayConfig] = None,
+ ) -> None:
+     """
+     Create visualization overlays for a batch of images.
+
+     Args:
+         rgb_dir: Directory containing RGB fundus images
+         output_dir: Directory to save visualization images
+         av_dir: Optional directory containing AV segmentations
+         disc_dir: Optional directory containing disc segmentations
+         ring_2r_dir: Optional directory containing 2r ring masks
+         ring_3r_dir: Optional directory containing 3r ring masks
+         fovea_data: Optional dictionary mapping image IDs to fovea coordinates
+         overlay_config: Overlay display configuration including enabled layers and colors
+
+     Returns:
+         None. Each overlay is written to output_dir as <image_id>.png.
+     """
+     # Create output directory if it doesn't exist
+     output_dir.mkdir(exist_ok=True, parents=True)
+     overlay_config = overlay_config or OverlayConfig()
+
+     # Get all RGB images
+     rgb_files = list(rgb_dir.glob("*.png"))
+     if not rgb_files:
+         logger.warning("No RGB images found for overlays in %s", rgb_dir)
+         return
+     logger.info("Creating overlays for %d images", len(rgb_files))
+
+     # Process each image
+     for rgb_file in rgb_files:
+         image_id = rgb_file.stem
+
+         # Check for corresponding AV segmentation
+         av_file = None
+         if av_dir:
+             av_file_path = av_dir / f"{image_id}.png"
+             if av_file_path.exists():
+                 av_file = str(av_file_path)
+
+         # Check for corresponding disc segmentation
+         disc_file = None
+         if disc_dir:
+             disc_file_path = disc_dir / f"{image_id}.png"
+             if disc_file_path.exists():
+                 disc_file = str(disc_file_path)
+
+         # Check for corresponding 2r ring mask
+         ring_2r_file = None
+         if ring_2r_dir:
+             ring_2r_file_path = ring_2r_dir / f"{image_id}.png"
+             if ring_2r_file_path.exists():
+                 ring_2r_file = str(ring_2r_file_path)
+
+         # Check for corresponding 3r ring mask
+         ring_3r_file = None
+         if ring_3r_dir:
+             ring_3r_file_path = ring_3r_dir / f"{image_id}.png"
+             if ring_3r_file_path.exists():
+                 ring_3r_file = str(ring_3r_file_path)
+
+         # Get fovea location if available
+         fovea_location = None
+         if fovea_data and image_id in fovea_data:
+             fovea_location = fovea_data[image_id]
+
+         # Create output path
+         output_file = output_dir / f"{image_id}.png"
+
+         # Create and save overlay
+         create_fundus_overlay(
+             rgb_path=str(rgb_file),
+             av_path=av_file,
+             disc_path=disc_file,
+             ring_2r_path=ring_2r_file,
+             ring_3r_path=ring_3r_file,
+             fovea_location=fovea_location,
+             output_path=str(output_file),
+             overlay_config=overlay_config,
+         )
+     logger.info("Finished overlay generation in %s", output_dir)
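The two functions in `vascx_models/utils.py` rely on two small rules: each optional layer is looked up as `<layer_dir>/<image_id>.png` and skipped when the file is missing, and the fovea marker is an X whose half-width is `min(width, height) // 50`. The sketch below isolates both rules using only the standard library, so they can be checked without model outputs. The helper names `find_layer_file` and `fovea_cross` are hypothetical and do not appear in the commit.

```python
from pathlib import Path
from typing import List, Optional, Tuple

Point = Tuple[int, int]


def find_layer_file(layer_dir: Optional[Path], image_id: str) -> Optional[str]:
    """Return <layer_dir>/<image_id>.png as a string, or None if it is absent."""
    if layer_dir is None:
        return None
    candidate = layer_dir / f"{image_id}.png"
    return str(candidate) if candidate.exists() else None


def fovea_cross(width: int, height: int, center: Point) -> List[Tuple[Point, Point]]:
    """Return the two diagonal segments of the X marker around `center`."""
    x, y = center
    size = min(width, height) // 50  # same scaling rule as in the diff
    return [
        ((x - size, y - size), (x + size, y + size)),
        ((x - size, y + size), (x + size, y - size)),
    ]
```

A helper along the lines of `find_layer_file` could replace the four near-identical lookup blocks in `batch_create_overlays`, though whether that refactor fits the project is a judgment for its maintainers.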
vessels/vessels_july24.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bdabae77502648acd1c176bff6d5e3c8295da60f18f2f84fe1bcd6181d2b2ca4
+ size 352821632
vessels/vessels_july24_DRHAGIS.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:406cf89ed2c35713296096cdfdbc2d6e67c164e25dbd6542612f31e2bfa0c85e
+ size 352848262