Migrate repo
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- .gitignore +12 -0
- README.md +193 -3
- artery_vein/av_july24.pt +3 -0
- artery_vein/av_july24_AVRDB.pt +3 -0
- artery_vein/av_july24_IOSTAR.pt +3 -0
- artery_vein/av_july24_LEUVEN.pt +3 -0
- artery_vein/av_july24_RS.pt +3 -0
- config.yaml +16 -0
- disc/disc_july24.pt +3 -0
- disc/disc_july24_ADAM.pt +3 -0
- disc/disc_july24_IDRID.pt +3 -0
- disc/disc_july24_ORIGA.pt +3 -0
- disc/disc_july24_PAPILA.pt +3 -0
- discedge/discedge_july24.pt +3 -0
- environment.yml +19 -0
- fovea/fovea_july24.pt +3 -0
- imgs/CHASEDB1_08L.png +3 -0
- imgs/CHASEDB1_08L_rgb.png +3 -0
- imgs/CHASEDB1_12R.png +3 -0
- imgs/CHASEDB1_12R_rgb.png +3 -0
- imgs/DRIVE_22.png +3 -0
- imgs/DRIVE_22_rgb.png +3 -0
- imgs/DRIVE_40.png +3 -0
- imgs/DRIVE_40_rgb.png +3 -0
- imgs/HRF_04_g.png +3 -0
- imgs/HRF_04_g_rgb.png +3 -0
- imgs/HRF_07_dr.png +3 -0
- imgs/HRF_07_dr_rgb.png +3 -0
- imgs/samples_vascx_hrf.png +3 -0
- notebooks/0_preprocess.ipynb +138 -0
- notebooks/1_segment_preprocessed.ipynb +217 -0
- odfd/odfd_march25.pt +3 -0
- quality/quality.pt +3 -0
- run.sh +60 -0
- samples/fundus/original/CHASEDB1_08L.png +3 -0
- samples/fundus/original/CHASEDB1_12R.png +3 -0
- samples/fundus/original/DRIVE_22.png +3 -0
- samples/fundus/original/DRIVE_40.png +3 -0
- samples/fundus/original/HRF_04_g.jpg +3 -0
- samples/fundus/original/HRF_07_dr.jpg +3 -0
- setup.py +36 -0
- vascx_models/__init__.py +0 -0
- vascx_models/cli.py +259 -0
- vascx_models/config.py +196 -0
- vascx_models/disc_rings.py +118 -0
- vascx_models/inference.py +292 -0
- vascx_models/utils.py +196 -0
- vessels/vessels_july24.pt +3 -0
- vessels/vessels_july24_DRHAGIS.pt +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
| 2 |
+
__pycache__
|
| 3 |
+
*.egg-info
|
| 4 |
+
*.zip
|
| 5 |
+
.DS_Store
|
| 6 |
+
.cache/
|
| 7 |
+
.mplconfig/
|
| 8 |
+
model_releases/
|
| 9 |
+
output_*/
|
| 10 |
+
output_*.zip
|
| 11 |
+
/samples/fundus/*
|
| 12 |
+
!/samples/fundus/original
|
README.md
CHANGED
|
@@ -1,3 +1,193 @@
|
|
| 1 |
-
---
|
| 2 |
-
license:
|
| 3 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: agpl-3.0
|
| 3 |
+
pipeline_tag: image-segmentation
|
| 4 |
+
tags:
|
| 5 |
+
- medical
|
| 6 |
+
- biology
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# 👁️ VascX Fork
|
| 10 |
+
|
| 11 |
+
This repository contains the instructions for using the VascX models from the paper [VascX Models: Model Ensembles for Retinal Vascular Analysis from Color Fundus Images](https://arxiv.org/abs/2409.16016). This fork is published as `zyf0717/vascx-fork` on the Hugging Face Hub.
|
| 12 |
+
|
| 13 |
+
The model weights are in [huggingface](https://huggingface.co/zyf0717/vascx-fork).
|
| 14 |
+
|
| 15 |
+
<img src="imgs/samples_vascx_hrf.png">
|
| 16 |
+
|
| 17 |
+
## 🛠️ Installation
|
| 18 |
+
|
| 19 |
+
To install the entire fundus analysis pipeline including fundus preprocessing, model inference code and vascular biomarker extraction:
|
| 20 |
+
|
| 21 |
+
1. Create a conda or virtualenv virtual environment, or otherwise ensure a clean environment.
|
| 22 |
+
|
| 23 |
+
2. Install `torch` and `torchvision` for your platform.
|
| 24 |
+
|
| 25 |
+
3. Install the pipeline runtime packages:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
pip install retinalysis-fundusprep retinalysis-inference
|
| 29 |
+
pip install -e .
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
The `environment.yml` in this repository includes the same runtime dependencies.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## 🚀 `vascx run` Command
|
| 36 |
+
|
| 37 |
+
The repository name is `vascx-fork`, but the installed CLI entry point remains `vascx` for compatibility.
|
| 38 |
+
|
| 39 |
+
The `run` command provides a comprehensive pipeline for processing fundus images, performing various analyses, and creating visualizations.
|
| 40 |
+
|
| 41 |
+
### Usage
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
vascx run DATA_PATH OUTPUT_PATH [OPTIONS]
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
If `config.yaml` exists in the current working directory or at the repository root, `vascx run` loads it automatically. You can also point to a specific file with `--config /path/to/config.yaml`.
|
| 48 |
+
|
| 49 |
+
### Arguments
|
| 50 |
+
|
| 51 |
+
- `DATA_PATH`: Path to input data. Can be either:
|
| 52 |
+
- A directory containing fundus images
|
| 53 |
+
- A CSV file with a 'path' column containing paths to images
|
| 54 |
+
|
| 55 |
+
- `OUTPUT_PATH`: Directory where processed results will be stored
|
| 56 |
+
|
| 57 |
+
### Options
|
| 58 |
+
|
| 59 |
+
| Option | Default | Description |
|
| 60 |
+
|--------|---------|-------------|
|
| 61 |
+
| `--preprocess/--no-preprocess` | `--preprocess` | Run preprocessing to standardize images for model input |
|
| 62 |
+
| `--vessels/--no-vessels` | `--vessels` | Run vessel segmentation and artery-vein classification |
|
| 63 |
+
| `--disc/--no-disc` | `--disc` | Run optic disc segmentation |
|
| 64 |
+
| `--quality/--no-quality` | `--quality` | Run image quality assessment |
|
| 65 |
+
| `--fovea/--no-fovea` | `--fovea` | Run fovea detection |
|
| 66 |
+
| `--overlay/--no-overlay` | `config.yaml` or `--overlay` | Create visualization overlays combining all results |
|
| 67 |
+
| `--config PATH` | auto-detect `config.yaml` | Load pipeline configuration from YAML |
|
| 68 |
+
| `--n_jobs` | `4` | Number of preprocessing workers for parallel processing |
|
| 69 |
+
|
| 70 |
+
### 📁 Output Structure
|
| 71 |
+
|
| 72 |
+
When run with default options, the command creates the following structure in `OUTPUT_PATH`:
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
OUTPUT_PATH/
|
| 76 |
+
├── preprocessed_rgb/ # Standardized fundus images
|
| 77 |
+
├── vessels/ # Vessel segmentation results
|
| 78 |
+
├── artery_vein/ # Artery-vein classification
|
| 79 |
+
├── disc/ # Optic disc segmentation
|
| 80 |
+
├── disc_ring_2r/ # Binary masks for the 2r optic-disc ring
|
| 81 |
+
├── disc_ring_3r/ # Binary masks for the 3r optic-disc ring
|
| 82 |
+
├── overlays/ # Visualization images
|
| 83 |
+
├── bounds.csv # Image boundary information
|
| 84 |
+
├── disc_geometry.csv # Disc center and radius estimates in pixels
|
| 85 |
+
├── quality.csv # Image quality scores
|
| 86 |
+
└── fovea.csv # Fovea coordinates
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### 🔄 Processing Stages
|
| 90 |
+
|
| 91 |
+
1. **Preprocessing**:
|
| 92 |
+
- Standardizes input images for consistent analysis
|
| 93 |
+
- Outputs preprocessed images and boundary information
|
| 94 |
+
|
| 95 |
+
2. **Quality Assessment**:
|
| 96 |
+
- Evaluates image quality with three quality metrics (q1, q2, q3)
|
| 97 |
+
- Higher scores indicate better image quality
|
| 98 |
+
|
| 99 |
+
3. **Vessel Segmentation and Artery-Vein Classification**:
|
| 100 |
+
- Identifies blood vessels in the retina
|
| 101 |
+
- Classifies vessels as arteries (1) or veins (2) with intersections (3)
|
| 102 |
+
|
| 103 |
+
4. **Optic Disc Segmentation**:
|
| 104 |
+
- Identifies the optic disc location and boundaries
|
| 105 |
+
- Estimates disc center and radius from the disc mask
|
| 106 |
+
- Generates 2r and 3r ring masks around the disc
|
| 107 |
+
|
| 108 |
+
5. **Fovea Detection**:
|
| 109 |
+
- Determines the coordinates of the fovea (center of vision)
|
| 110 |
+
|
| 111 |
+
6. **Visualization Overlays**:
|
| 112 |
+
- Creates color-coded images showing:
|
| 113 |
+
- Arteries in red
|
| 114 |
+
- Veins in blue
|
| 115 |
+
- Optic disc in white
|
| 116 |
+
- 2r ring in green
|
| 117 |
+
- 3r ring in magenta
|
| 118 |
+
- Fovea marked with yellow X
|
| 119 |
+
- Overlay layers and colors can be controlled from `config.yaml`
|
| 120 |
+
|
| 121 |
+
### ⚙️ `config.yaml`
|
| 122 |
+
|
| 123 |
+
The repository root now includes a `config.yaml` file for overlay settings. The default file looks like this:
|
| 124 |
+
|
| 125 |
+
```yaml
|
| 126 |
+
overlay:
|
| 127 |
+
enabled: true
|
| 128 |
+
layers:
|
| 129 |
+
arteries: true
|
| 130 |
+
veins: true
|
| 131 |
+
disc: true
|
| 132 |
+
ring_2r: true
|
| 133 |
+
ring_3r: true
|
| 134 |
+
fovea: true
|
| 135 |
+
colours:
|
| 136 |
+
artery: "#FF0000"
|
| 137 |
+
vein: "#0000FF"
|
| 138 |
+
disc: "#FFFFFF"
|
| 139 |
+
ring_2r: "#00FF00"
|
| 140 |
+
ring_3r: "#FF00FF"
|
| 141 |
+
fovea: "#FFFF00"
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Notes:
|
| 145 |
+
|
| 146 |
+
- `overlay.enabled` controls whether overlays are produced when `--overlay/--no-overlay` is not set explicitly.
|
| 147 |
+
- `overlay.layers` lets you choose which predictions are drawn.
|
| 148 |
+
- `overlay.colors` and `overlay.colours` are both accepted.
|
| 149 |
+
- Colors can be written as `#RRGGBB` strings or 3-value RGB arrays such as `[255, 0, 0]`.
|
| 150 |
+
|
| 151 |
+
### 💻 Examples
|
| 152 |
+
|
| 153 |
+
**Process a directory of images with all analyses:**
|
| 154 |
+
```bash
|
| 155 |
+
vascx run /path/to/images /path/to/output
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
**Process specific images listed in a CSV:**
|
| 159 |
+
```bash
|
| 160 |
+
vascx run /path/to/image_list.csv /path/to/output
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
**Only run preprocessing and vessel segmentation:**
|
| 164 |
+
```bash
|
| 165 |
+
vascx run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
**Skip preprocessing on already preprocessed images:**
|
| 169 |
+
```bash
|
| 170 |
+
vascx run /path/to/preprocessed/images /path/to/output --no-preprocess
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
**Increase parallel processing workers:**
|
| 174 |
+
```bash
|
| 175 |
+
vascx run /path/to/images /path/to/output --n_jobs 8
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
### 📝 Notes
|
| 179 |
+
|
| 180 |
+
- The CSV input must contain a 'path' column with image file paths
|
| 181 |
+
- If the CSV includes an 'id' column, these IDs will be used instead of filenames
|
| 182 |
+
- When `--no-preprocess` is used, input images must already be in the proper format
|
| 183 |
+
- The overlay visualization requires at least one analysis component to be enabled
|
| 184 |
+
|
| 185 |
+
## 📓 Notebooks
|
| 186 |
+
|
| 187 |
+
For more advanced usage, we have Jupyter notebooks showing how preprocessing and inference are run.
|
| 188 |
+
|
| 189 |
+
To speed up re-execution of vascx we recommend to run the preprocessing and segmentation steps separately:
|
| 190 |
+
|
| 191 |
+
1. Preprocessing. See [this notebook](./notebooks/0_preprocess.ipynb). This step is CPU-heavy and benefits from parallelization (see notebook).
|
| 192 |
+
|
| 193 |
+
2. Inference. See [this notebook](./notebooks/1_segment_preprocessed.ipynb). All models can be ran in a single GPU with >10GB VRAM.
|
artery_vein/av_july24.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b11d4e26ada8e1f0f279747afa5f4ef348d9d7350c4bf80e7ae6b5ac8d0b95b5
|
| 3 |
+
size 352774102
|
artery_vein/av_july24_AVRDB.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:21fd6a693be9a5ffbf1b56e624612a493470d6e63025ce11f2e3886bd6f18b4c
|
| 3 |
+
size 352791110
|
artery_vein/av_july24_IOSTAR.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:21fd6a693be9a5ffbf1b56e624612a493470d6e63025ce11f2e3886bd6f18b4c
|
| 3 |
+
size 352791110
|
artery_vein/av_july24_LEUVEN.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:446580e6cda2acec8dc2ab30d9526735fc670f296048055cbb5ebb9ccac28d0b
|
| 3 |
+
size 352830466
|
artery_vein/av_july24_RS.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c3de549b74ebd9c9a4f49c17043b831665cc7a1981773ff8c17db2416b9dfe48
|
| 3 |
+
size 352805874
|
config.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
overlay:
|
| 2 |
+
enabled: true
|
| 3 |
+
layers:
|
| 4 |
+
arteries: true
|
| 5 |
+
veins: true
|
| 6 |
+
disc: true
|
| 7 |
+
ring_2r: true
|
| 8 |
+
ring_3r: true
|
| 9 |
+
fovea: true
|
| 10 |
+
colours:
|
| 11 |
+
artery: "#FF0000"
|
| 12 |
+
vein: "#0000FF"
|
| 13 |
+
disc: "#FFFFFF"
|
| 14 |
+
ring_2r: "#00FF00"
|
| 15 |
+
ring_3r: "#FF00FF"
|
| 16 |
+
fovea: "#FFFF00"
|
disc/disc_july24.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6892b1bdb3bb68b666ea9a7891b0fb2f6fbb5fd4f05038c013c1c69ec6c7910c
|
| 3 |
+
size 352801898
|
disc/disc_july24_ADAM.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9f423e047c88d9f6c03ada2c706bc84265d61ec473eaab28e8ca12a0f1738401
|
| 3 |
+
size 352819138
|
disc/disc_july24_IDRID.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbf239cfec2ee9b550aa09e3623af01cd79de688cac1c79902d0feb7d24bb3f7
|
| 3 |
+
size 352835178
|
disc/disc_july24_ORIGA.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18cc9ffac522fc9b91e46fe3aed8d6bb8fbf00e44bfb768b622f5b881713add6
|
| 3 |
+
size 352826358
|
disc/disc_july24_PAPILA.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:080de64a1c005a921cb09eb04e42ce9c087dca2c1090209f77a3635af9eb1d19
|
| 3 |
+
size 352832490
|
discedge/discedge_july24.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:891d8ef9bbc0676b019a81b1eceb349c1cdf4b5665a834196dd252915af64392
|
| 3 |
+
size 352723146
|
environment.yml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: vascx-fork
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- python=3.11
|
| 6 |
+
- pip
|
| 7 |
+
- numpy=2.*
|
| 8 |
+
- pandas=2.*
|
| 9 |
+
- tqdm=4.*
|
| 10 |
+
- pillow=11.*
|
| 11 |
+
- click=8.*
|
| 12 |
+
- pyyaml=6.*
|
| 13 |
+
- pip:
|
| 14 |
+
- torch==2.11.0
|
| 15 |
+
- torchvision==0.26.0
|
| 16 |
+
- torchaudio==2.11.0
|
| 17 |
+
- retinalysis-fundusprep
|
| 18 |
+
- retinalysis-inference
|
| 19 |
+
- -e .
|
fovea/fovea_july24.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1af042f7e2a398f512be8a8d54cc480300312c3f3692c13bd98be65439a33222
|
| 3 |
+
size 352714676
|
imgs/CHASEDB1_08L.png
ADDED
|
Git LFS Details
|
imgs/CHASEDB1_08L_rgb.png
ADDED
|
Git LFS Details
|
imgs/CHASEDB1_12R.png
ADDED
|
Git LFS Details
|
imgs/CHASEDB1_12R_rgb.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_22.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_22_rgb.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_40.png
ADDED
|
Git LFS Details
|
imgs/DRIVE_40_rgb.png
ADDED
|
Git LFS Details
|
imgs/HRF_04_g.png
ADDED
|
Git LFS Details
|
imgs/HRF_04_g_rgb.png
ADDED
|
Git LFS Details
|
imgs/HRF_07_dr.png
ADDED
|
Git LFS Details
|
imgs/HRF_07_dr_rgb.png
ADDED
|
Git LFS Details
|
imgs/samples_vascx_hrf.png
ADDED
|
Git LFS Details
|
notebooks/0_preprocess.ipynb
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from pathlib import Path\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"import pandas as pd\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"from rtnls_fundusprep.preprocessor import parallel_preprocess"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"## Preprocessing\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"This code will preprocess the images and write .png files with the square fundus image and the contrast enhanced version\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"This step is not strictly necessary, but it is useful if you want to run the preprocessing step separately before model inference\n"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"Create a list of files to be preprocessed:"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": 2,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"ds_path = Path(\"../samples/fundus\")\n",
|
| 41 |
+
"files = list((ds_path / \"original\").glob(\"*\"))"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "markdown",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"source": [
|
| 48 |
+
"Images with .dcm extension will be read as dicom and the pixel_array will be read as RGB. All other images will be read using PIL's Image.open"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": 3,
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"outputs": [
|
| 56 |
+
{
|
| 57 |
+
"name": "stderr",
|
| 58 |
+
"output_type": "stream",
|
| 59 |
+
"text": [
|
| 60 |
+
"0it [00:00, ?it/s][Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.\n",
|
| 61 |
+
"6it [00:00, 154.80it/s]\n"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"name": "stdout",
|
| 66 |
+
"output_type": "stream",
|
| 67 |
+
"text": [
|
| 68 |
+
"Error with image ../samples/fundus/original/HRF_07_dr.jpg\n",
|
| 69 |
+
"Error with image ../samples/fundus/original/HRF_04_g.jpg\n"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"name": "stderr",
|
| 74 |
+
"output_type": "stream",
|
| 75 |
+
"text": [
|
| 76 |
+
"[Parallel(n_jobs=4)]: Done 2 out of 6 | elapsed: 0.9s remaining: 1.8s\n",
|
| 77 |
+
"[Parallel(n_jobs=4)]: Done 3 out of 6 | elapsed: 1.5s remaining: 1.5s\n",
|
| 78 |
+
"[Parallel(n_jobs=4)]: Done 4 out of 6 | elapsed: 1.5s remaining: 0.8s\n",
|
| 79 |
+
"[Parallel(n_jobs=4)]: Done 6 out of 6 | elapsed: 1.6s finished\n"
|
| 80 |
+
]
|
| 81 |
+
}
|
| 82 |
+
],
|
| 83 |
+
"source": [
|
| 84 |
+
"bounds = parallel_preprocess(\n",
|
| 85 |
+
" files, # List of image files\n",
|
| 86 |
+
" rgb_path=ds_path / \"rgb\", # Output path for RGB images\n",
|
| 87 |
+
" ce_path=ds_path / \"ce\", # Output path for Contrast Enhanced images\n",
|
| 88 |
+
" n_jobs=4, # number of preprocessing workers\n",
|
| 89 |
+
")\n",
|
| 90 |
+
"df_bounds = pd.DataFrame(bounds).set_index(\"id\")"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "markdown",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"source": [
|
| 97 |
+
"The preprocessor will produce RGB and contrast-enhanced preprocessed images cropped to a square and return a dataframe with the image bounds that can be used to reconstruct the original image. Output files will be named the same as input images, but with .png extension. Be careful with providing multiple inputs with the same filename without extension as this will result in over-written images. Any exceptions during pre-processing will not stop execution but will print error. Images that failed pre-processing for any reason will be marked with `success=False` in the df_bounds dataframe."
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": 4,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"df_bounds.to_csv(ds_path / \"meta.csv\")"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"outputs": [],
|
| 114 |
+
"source": []
|
| 115 |
+
}
|
| 116 |
+
],
|
| 117 |
+
"metadata": {
|
| 118 |
+
"kernelspec": {
|
| 119 |
+
"display_name": "retinalysis",
|
| 120 |
+
"language": "python",
|
| 121 |
+
"name": "python3"
|
| 122 |
+
},
|
| 123 |
+
"language_info": {
|
| 124 |
+
"codemirror_mode": {
|
| 125 |
+
"name": "ipython",
|
| 126 |
+
"version": 3
|
| 127 |
+
},
|
| 128 |
+
"file_extension": ".py",
|
| 129 |
+
"mimetype": "text/x-python",
|
| 130 |
+
"name": "python",
|
| 131 |
+
"nbconvert_exporter": "python",
|
| 132 |
+
"pygments_lexer": "ipython3",
|
| 133 |
+
"version": "3.10.13"
|
| 134 |
+
}
|
| 135 |
+
},
|
| 136 |
+
"nbformat": 4,
|
| 137 |
+
"nbformat_minor": 2
|
| 138 |
+
}
|
notebooks/1_segment_preprocessed.ipynb
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"from pathlib import Path\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"import torch\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"from rtnls_inference import (\n",
|
| 14 |
+
" HeatmapRegressionEnsemble,\n",
|
| 15 |
+
" SegmentationEnsemble,\n",
|
| 16 |
+
")"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "markdown",
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"source": [
|
| 23 |
+
"## Segmentation of preprocessed images\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"Here we segment images preprocessed using 0_preprocess.ipynb\n"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "markdown",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": []
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": 2,
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"source": [
|
| 39 |
+
"ds_path = Path(\"../samples/fundus\")\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"# input folders. these are the folders where we stored the preprocessed images\n",
|
| 42 |
+
"rgb_path = ds_path / \"rgb\"\n",
|
| 43 |
+
"ce_path = ds_path / \"ce\"\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# these are the output folders for:\n",
|
| 46 |
+
"av_path = ds_path / \"av\" # artery-vein segmentations\n",
|
| 47 |
+
"discs_path = ds_path / \"discs\" # optic disc segmentations\n",
|
| 48 |
+
"overlays_path = ds_path / \"overlays\" # optional overlay visualizations\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"device = torch.device(\"cuda:0\") # device to use for inference"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": 3,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"rgb_paths = sorted(list(rgb_path.glob(\"*.png\")))\n",
|
| 60 |
+
"ce_paths = sorted(list(ce_path.glob(\"*.png\")))\n",
|
| 61 |
+
"paired_paths = list(zip(rgb_paths, ce_paths))"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": null,
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"paired_paths[0] # important to make sure that the paths are paired correctly"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"source": [
|
| 77 |
+
"### Artery-vein segmentation\n"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"av_ensemble = SegmentationEnsemble.from_huggingface('zyf0717/vascx-fork:artery_vein/av_july24.pt').to(device)\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"av_ensemble.predict_preprocessed(paired_paths, dest_path=av_path, num_workers=2)"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "markdown",
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"source": [
|
| 95 |
+
"### Disc segmentation\n"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "code",
|
| 100 |
+
"execution_count": null,
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
+
"disc_ensemble = SegmentationEnsemble.from_huggingface('zyf0717/vascx-fork:disc/disc_july24.pt').to(device)\n",
|
| 105 |
+
"disc_ensemble.predict_preprocessed(paired_paths, dest_path=discs_path, num_workers=2)"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "markdown",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"source": [
|
| 112 |
+
"### Fovea detection\n"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": null,
|
| 118 |
+
"metadata": {},
|
| 119 |
+
"outputs": [],
|
| 120 |
+
"source": [
|
| 121 |
+
"fovea_ensemble = HeatmapRegressionEnsemble.from_huggingface('zyf0717/vascx-fork:fovea/fovea_july24.pt').to(device)\n",
|
| 122 |
+
"# note: this model does not use contrast enhanced images\n",
|
| 123 |
+
"df = fovea_ensemble.predict_preprocessed(paired_paths, num_workers=2)\n",
|
| 124 |
+
"df.columns = [\"mean_x\", \"mean_y\"]\n",
|
| 125 |
+
"df.to_csv(ds_path / \"fovea.csv\")"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"df"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "markdown",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"source": [
|
| 141 |
+
"### Plotting the retinas (optional)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"This will only work if you ran all the models and stored the outputs using the same folder/file names as above\n"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"outputs": [],
|
| 151 |
+
"source": [
|
| 152 |
+
"from vascx.fundus.loader import RetinaLoader\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"from rtnls_enface.utils.plotting import plot_gridfns\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"loader = RetinaLoader.from_folder(ds_path)"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "code",
|
| 161 |
+
"execution_count": null,
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"outputs": [],
|
| 164 |
+
"source": [
|
| 165 |
+
"plot_gridfns([ret.plot for ret in loader[:6]])"
|
| 166 |
+
]
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"cell_type": "markdown",
|
| 170 |
+
"metadata": {},
|
| 171 |
+
"source": [
|
| 172 |
+
"### Storing visualizations (optional)\n"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "code",
|
| 177 |
+
"execution_count": 10,
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"outputs": [],
|
| 180 |
+
"source": [
|
| 181 |
+
"if not overlays_path.exists():\n",
|
| 182 |
+
" overlays_path.mkdir()\n",
|
| 183 |
+
"for ret in loader:\n",
|
| 184 |
+
" fig, _ = ret.plot()\n",
|
| 185 |
+
" fig.savefig(overlays_path / f\"{ret.id}.png\", bbox_inches=\"tight\", pad_inches=0)"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": []
|
| 194 |
+
}
|
| 195 |
+
],
|
| 196 |
+
"metadata": {
|
| 197 |
+
"kernelspec": {
|
| 198 |
+
"display_name": "retinalysis",
|
| 199 |
+
"language": "python",
|
| 200 |
+
"name": "python3"
|
| 201 |
+
},
|
| 202 |
+
"language_info": {
|
| 203 |
+
"codemirror_mode": {
|
| 204 |
+
"name": "ipython",
|
| 205 |
+
"version": 3
|
| 206 |
+
},
|
| 207 |
+
"file_extension": ".py",
|
| 208 |
+
"mimetype": "text/x-python",
|
| 209 |
+
"name": "python",
|
| 210 |
+
"nbconvert_exporter": "python",
|
| 211 |
+
"pygments_lexer": "ipython3",
|
| 212 |
+
"version": "3.10.13"
|
| 213 |
+
}
|
| 214 |
+
},
|
| 215 |
+
"nbformat": 4,
|
| 216 |
+
"nbformat_minor": 2
|
| 217 |
+
}
|
odfd/odfd_march25.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa2be119eb915bc9da6ba42234f703b5cc270d53b62c1e7d7e1bdff52c1e0edd
|
| 3 |
+
size 855538988
|
quality/quality.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:80034ccb21a57522ba0cb86be0d46d2b659e193b86db909ad6ec2e85e61f87aa
|
| 3 |
+
size 855578258
|
run.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
CONDA_ENV="${CONDA_ENV:-vascx-fork}"
|
| 6 |
+
SAMPLE_INPUT_PATH="$REPO_ROOT/samples/fundus/original"
|
| 7 |
+
DEFAULT_INPUT_PATH="$SAMPLE_INPUT_PATH"
|
| 8 |
+
INPUT_PATH="${INPUT_PATH:-$DEFAULT_INPUT_PATH}"
|
| 9 |
+
TIMESTAMP="$(date +"%Y%m%d_%H%M%S")"
|
| 10 |
+
DEFAULT_OUTPUT_PATH="$REPO_ROOT/output_$TIMESTAMP"
|
| 11 |
+
OUTPUT_PATH="${OUTPUT_PATH:-$DEFAULT_OUTPUT_PATH}"
|
| 12 |
+
N_JOBS="${N_JOBS:-1}"
|
| 13 |
+
MODEL_RELEASES_DIR="$REPO_ROOT/model_releases"
|
| 14 |
+
|
| 15 |
+
while [[ $# -gt 0 ]]; do
|
| 16 |
+
case "$1" in
|
| 17 |
+
--sample-run)
|
| 18 |
+
INPUT_PATH="$SAMPLE_INPUT_PATH"
|
| 19 |
+
shift
|
| 20 |
+
;;
|
| 21 |
+
*)
|
| 22 |
+
echo "Unknown argument: $1" >&2
|
| 23 |
+
echo "Usage: $0 [--sample-run]" >&2
|
| 24 |
+
exit 1
|
| 25 |
+
;;
|
| 26 |
+
esac
|
| 27 |
+
done
|
| 28 |
+
|
| 29 |
+
if [[ ! -d "$INPUT_PATH" ]]; then
|
| 30 |
+
echo "Input directory does not exist: $INPUT_PATH" >&2
|
| 31 |
+
exit 1
|
| 32 |
+
fi
|
| 33 |
+
|
| 34 |
+
mkdir -p "$REPO_ROOT/.mplconfig" "$REPO_ROOT/.cache" "$MODEL_RELEASES_DIR" "$OUTPUT_PATH"
|
| 35 |
+
|
| 36 |
+
for model_path in "$REPO_ROOT"/*/*.pt; do
|
| 37 |
+
[[ -e "$model_path" ]] || continue
|
| 38 |
+
if [[ "$model_path" == "$MODEL_RELEASES_DIR"/* ]]; then
|
| 39 |
+
continue
|
| 40 |
+
fi
|
| 41 |
+
ln -sf "$model_path" "$MODEL_RELEASES_DIR/$(basename "$model_path")"
|
| 42 |
+
done
|
| 43 |
+
|
| 44 |
+
export MPLCONFIGDIR="$REPO_ROOT/.mplconfig"
|
| 45 |
+
export XDG_CACHE_HOME="$REPO_ROOT/.cache"
|
| 46 |
+
export RTNLS_MODEL_RELEASES="$MODEL_RELEASES_DIR"
|
| 47 |
+
|
| 48 |
+
echo "Running VascX Fork"
|
| 49 |
+
echo " conda env: $CONDA_ENV"
|
| 50 |
+
echo " input path: $INPUT_PATH"
|
| 51 |
+
echo " output path: $OUTPUT_PATH"
|
| 52 |
+
echo " n_jobs: $N_JOBS"
|
| 53 |
+
echo " models dir: $RTNLS_MODEL_RELEASES"
|
| 54 |
+
|
| 55 |
+
CONDA_BASE="$(conda info --base)"
|
| 56 |
+
# shellcheck disable=SC1091
|
| 57 |
+
source "$CONDA_BASE/etc/profile.d/conda.sh"
|
| 58 |
+
conda activate "$CONDA_ENV"
|
| 59 |
+
|
| 60 |
+
exec python -c "from vascx_models.cli import cli; cli()" run "$INPUT_PATH" "$OUTPUT_PATH" --n_jobs "$N_JOBS"
|
samples/fundus/original/CHASEDB1_08L.png
ADDED
|
Git LFS Details
|
samples/fundus/original/CHASEDB1_12R.png
ADDED
|
Git LFS Details
|
samples/fundus/original/DRIVE_22.png
ADDED
|
Git LFS Details
|
samples/fundus/original/DRIVE_40.png
ADDED
|
Git LFS Details
|
samples/fundus/original/HRF_04_g.jpg
ADDED
|
Git LFS Details
|
samples/fundus/original/HRF_07_dr.jpg
ADDED
|
Git LFS Details
|
setup.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import find_packages, setup
|
| 2 |
+
|
| 3 |
+
with open("README.md", "r") as fh:
|
| 4 |
+
long_description = fh.read()
|
| 5 |
+
|
| 6 |
+
setup(
|
| 7 |
+
name="vascx_models",
|
| 8 |
+
# using versioneer for versioning using git tags
|
| 9 |
+
# https://github.com/python-versioneer/python-versioneer/blob/master/INSTALL.md
|
| 10 |
+
# version=versioneer.get_version(),
|
| 11 |
+
# cmdclass=versioneer.get_cmdclass(),
|
| 12 |
+
author="Jose Vargas",
|
| 13 |
+
author_email="j.vargasquiros@erasmusmc.nl",
|
| 14 |
+
description="Retinal analysis toolbox for Python",
|
| 15 |
+
long_description=long_description,
|
| 16 |
+
long_description_content_type="text/markdown",
|
| 17 |
+
packages=find_packages(),
|
| 18 |
+
include_package_data=True,
|
| 19 |
+
zip_safe=False,
|
| 20 |
+
entry_points={
|
| 21 |
+
"console_scripts": [
|
| 22 |
+
"vascx = vascx_models.cli:cli",
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
install_requires=[
|
| 26 |
+
"numpy == 2.*",
|
| 27 |
+
"pandas == 2.*",
|
| 28 |
+
"tqdm == 4.*",
|
| 29 |
+
"Pillow == 11.*",
|
| 30 |
+
"click==8.*",
|
| 31 |
+
"PyYAML == 6.*",
|
| 32 |
+
"retinalysis-fundusprep",
|
| 33 |
+
"retinalysis-inference",
|
| 34 |
+
],
|
| 35 |
+
python_requires=">=3.10, <3.13",
|
| 36 |
+
)
|
vascx_models/__init__.py
ADDED
|
File without changes
|
vascx_models/cli.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import warnings
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import click
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from rtnls_fundusprep.cli import _run_preprocessing
|
| 9 |
+
|
| 10 |
+
from .config import load_app_config
|
| 11 |
+
from .disc_rings import generate_disc_rings
|
| 12 |
+
from .inference import (
|
| 13 |
+
preferred_device,
|
| 14 |
+
run_fovea_detection,
|
| 15 |
+
run_quality_estimation,
|
| 16 |
+
run_segmentation_disc,
|
| 17 |
+
run_segmentation_vessels_and_av,
|
| 18 |
+
)
|
| 19 |
+
from .utils import batch_create_overlays
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def configure_logging() -> None:
|
| 25 |
+
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
| 26 |
+
warnings.filterwarnings(
|
| 27 |
+
"ignore",
|
| 28 |
+
message=(
|
| 29 |
+
"Using a non-tuple sequence for multidimensional indexing is deprecated "
|
| 30 |
+
"and will be changed in pytorch 2.9; use x\\[tuple\\(seq\\)\\] instead of x\\[seq\\].*"
|
| 31 |
+
),
|
| 32 |
+
category=UserWarning,
|
| 33 |
+
module=r"monai\.inferers\.utils",
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@click.group(name="vascx")
|
| 38 |
+
def cli():
|
| 39 |
+
configure_logging()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@cli.command()
|
| 43 |
+
@click.argument("data_path", type=click.Path(exists=True))
|
| 44 |
+
@click.argument("output_path", type=click.Path())
|
| 45 |
+
@click.option(
|
| 46 |
+
"--config",
|
| 47 |
+
"config_path",
|
| 48 |
+
type=click.Path(exists=True, dir_okay=False, path_type=Path),
|
| 49 |
+
default=None,
|
| 50 |
+
help="Path to a YAML config file. Defaults to ./config.yaml or the repo-root config.yaml when present.",
|
| 51 |
+
)
|
| 52 |
+
@click.option(
|
| 53 |
+
"--preprocess/--no-preprocess",
|
| 54 |
+
default=True,
|
| 55 |
+
help="Run preprocessing or use preprocessed images",
|
| 56 |
+
)
|
| 57 |
+
@click.option(
|
| 58 |
+
"--vessels/--no-vessels", default=True, help="Run vessels and AV segmentation"
|
| 59 |
+
)
|
| 60 |
+
@click.option("--disc/--no-disc", default=True, help="Run optic disc segmentation")
|
| 61 |
+
@click.option(
|
| 62 |
+
"--quality/--no-quality", default=True, help="Run image quality estimation"
|
| 63 |
+
)
|
| 64 |
+
@click.option("--fovea/--no-fovea", default=True, help="Run fovea detection")
|
| 65 |
+
@click.option(
|
| 66 |
+
"--overlay/--no-overlay",
|
| 67 |
+
default=None,
|
| 68 |
+
help="Create visualization overlays. Defaults to the config value when set.",
|
| 69 |
+
)
|
| 70 |
+
@click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
|
| 71 |
+
def run(
|
| 72 |
+
data_path,
|
| 73 |
+
output_path,
|
| 74 |
+
config_path,
|
| 75 |
+
preprocess,
|
| 76 |
+
vessels,
|
| 77 |
+
disc,
|
| 78 |
+
quality,
|
| 79 |
+
fovea,
|
| 80 |
+
overlay,
|
| 81 |
+
n_jobs,
|
| 82 |
+
):
|
| 83 |
+
"""Run the complete inference pipeline on fundus images.
|
| 84 |
+
|
| 85 |
+
DATA_PATH is either a directory containing images or a CSV file with 'path' column.
|
| 86 |
+
OUTPUT_PATH is the directory where results will be stored.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
output_path = Path(output_path)
|
| 90 |
+
output_path.mkdir(exist_ok=True, parents=True)
|
| 91 |
+
try:
|
| 92 |
+
app_config = load_app_config(config_path)
|
| 93 |
+
except (FileNotFoundError, ValueError) as exc:
|
| 94 |
+
raise click.ClickException(str(exc)) from exc
|
| 95 |
+
overlay_enabled = app_config.overlay.enabled if overlay is None else overlay
|
| 96 |
+
if app_config.source_path is not None:
|
| 97 |
+
logger.info("Loaded config from %s", app_config.source_path)
|
| 98 |
+
|
| 99 |
+
# Setup output directories
|
| 100 |
+
preprocess_rgb_path = output_path / "preprocessed_rgb"
|
| 101 |
+
vessels_path = output_path / "vessels"
|
| 102 |
+
av_path = output_path / "artery_vein"
|
| 103 |
+
disc_path = output_path / "disc"
|
| 104 |
+
disc_ring_2r_path = output_path / "disc_ring_2r"
|
| 105 |
+
disc_ring_3r_path = output_path / "disc_ring_3r"
|
| 106 |
+
overlay_path = output_path / "overlays"
|
| 107 |
+
|
| 108 |
+
# Create required directories
|
| 109 |
+
if preprocess:
|
| 110 |
+
preprocess_rgb_path.mkdir(exist_ok=True, parents=True)
|
| 111 |
+
if vessels:
|
| 112 |
+
av_path.mkdir(exist_ok=True, parents=True)
|
| 113 |
+
vessels_path.mkdir(exist_ok=True, parents=True)
|
| 114 |
+
if disc:
|
| 115 |
+
disc_path.mkdir(exist_ok=True, parents=True)
|
| 116 |
+
disc_ring_2r_path.mkdir(exist_ok=True, parents=True)
|
| 117 |
+
disc_ring_3r_path.mkdir(exist_ok=True, parents=True)
|
| 118 |
+
if overlay_enabled:
|
| 119 |
+
overlay_path.mkdir(exist_ok=True, parents=True)
|
| 120 |
+
|
| 121 |
+
bounds_path = output_path / "bounds.csv" if preprocess else None
|
| 122 |
+
quality_path = output_path / "quality.csv" if quality else None
|
| 123 |
+
fovea_path = output_path / "fovea.csv" if fovea else None
|
| 124 |
+
disc_geometry_path = output_path / "disc_geometry.csv" if disc else None
|
| 125 |
+
|
| 126 |
+
# Determine if input is a folder or CSV file
|
| 127 |
+
data_path = Path(data_path)
|
| 128 |
+
is_csv = data_path.suffix.lower() == ".csv"
|
| 129 |
+
|
| 130 |
+
# Get files to process
|
| 131 |
+
files = []
|
| 132 |
+
ids = None
|
| 133 |
+
|
| 134 |
+
if is_csv:
|
| 135 |
+
logger.info("Reading file paths from CSV: %s", data_path)
|
| 136 |
+
try:
|
| 137 |
+
df = pd.read_csv(data_path)
|
| 138 |
+
if "path" not in df.columns:
|
| 139 |
+
logger.error("CSV must contain a 'path' column")
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
# Get file paths and convert to Path objects
|
| 143 |
+
files = [Path(p) for p in df["path"]]
|
| 144 |
+
|
| 145 |
+
if "id" in df.columns:
|
| 146 |
+
ids = df["id"].tolist()
|
| 147 |
+
logger.info("Using IDs from CSV 'id' column")
|
| 148 |
+
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.exception("Error reading CSV file: %s", e)
|
| 151 |
+
return
|
| 152 |
+
else:
|
| 153 |
+
logger.info("Finding files in directory: %s", data_path)
|
| 154 |
+
files = list(data_path.glob("*"))
|
| 155 |
+
ids = [f.stem for f in files]
|
| 156 |
+
|
| 157 |
+
if not files:
|
| 158 |
+
logger.warning("No files found to process")
|
| 159 |
+
return
|
| 160 |
+
|
| 161 |
+
logger.info("Found %d files to process", len(files))
|
| 162 |
+
|
| 163 |
+
# Step 1: Preprocess images if requested
|
| 164 |
+
if preprocess:
|
| 165 |
+
logger.info("Running preprocessing")
|
| 166 |
+
_run_preprocessing(
|
| 167 |
+
files=files,
|
| 168 |
+
ids=ids,
|
| 169 |
+
rgb_path=preprocess_rgb_path,
|
| 170 |
+
bounds_path=bounds_path,
|
| 171 |
+
n_jobs=n_jobs,
|
| 172 |
+
)
|
| 173 |
+
# Use the preprocessed images for subsequent steps
|
| 174 |
+
preprocessed_files = list(preprocess_rgb_path.glob("*.png"))
|
| 175 |
+
else:
|
| 176 |
+
# Use the input files directly
|
| 177 |
+
preprocessed_files = files
|
| 178 |
+
ids = [f.stem for f in preprocessed_files]
|
| 179 |
+
logger.info("Prepared %d images for inference", len(preprocessed_files))
|
| 180 |
+
|
| 181 |
+
# Prefer hardware acceleration when the active torch build supports it.
|
| 182 |
+
device = preferred_device()
|
| 183 |
+
logger.info("Using device: %s", device)
|
| 184 |
+
|
| 185 |
+
# Step 2: Run quality estimation if requested
|
| 186 |
+
if quality:
|
| 187 |
+
logger.info("Running quality estimation")
|
| 188 |
+
df_quality = run_quality_estimation(
|
| 189 |
+
fpaths=preprocessed_files, ids=ids, device=device
|
| 190 |
+
)
|
| 191 |
+
df_quality.to_csv(quality_path)
|
| 192 |
+
logger.info("Quality results saved to %s", quality_path)
|
| 193 |
+
|
| 194 |
+
# Step 3: Run vessels and AV segmentation if requested
|
| 195 |
+
if vessels:
|
| 196 |
+
logger.info("Running vessels and AV segmentation")
|
| 197 |
+
run_segmentation_vessels_and_av(
|
| 198 |
+
rgb_paths=preprocessed_files,
|
| 199 |
+
ids=ids,
|
| 200 |
+
av_path=av_path,
|
| 201 |
+
vessels_path=vessels_path,
|
| 202 |
+
device=device,
|
| 203 |
+
)
|
| 204 |
+
logger.info("Vessel segmentation saved to %s", vessels_path)
|
| 205 |
+
logger.info("AV segmentation saved to %s", av_path)
|
| 206 |
+
|
| 207 |
+
# Step 4: Run optic disc segmentation if requested
|
| 208 |
+
if disc:
|
| 209 |
+
logger.info("Running optic disc segmentation")
|
| 210 |
+
run_segmentation_disc(
|
| 211 |
+
rgb_paths=preprocessed_files, ids=ids, output_path=disc_path, device=device
|
| 212 |
+
)
|
| 213 |
+
logger.info("Disc segmentation saved to %s", disc_path)
|
| 214 |
+
generate_disc_rings(
|
| 215 |
+
disc_dir=disc_path,
|
| 216 |
+
ring_2r_dir=disc_ring_2r_path,
|
| 217 |
+
ring_3r_dir=disc_ring_3r_path,
|
| 218 |
+
measurements_path=disc_geometry_path,
|
| 219 |
+
)
|
| 220 |
+
logger.info("2r disc rings saved to %s", disc_ring_2r_path)
|
| 221 |
+
logger.info("3r disc rings saved to %s", disc_ring_3r_path)
|
| 222 |
+
|
| 223 |
+
# Step 5: Run fovea detection if requested
|
| 224 |
+
df_fovea = None
|
| 225 |
+
if fovea:
|
| 226 |
+
logger.info("Running fovea detection")
|
| 227 |
+
df_fovea = run_fovea_detection(
|
| 228 |
+
rgb_paths=preprocessed_files, ids=ids, device=device
|
| 229 |
+
)
|
| 230 |
+
df_fovea.to_csv(fovea_path)
|
| 231 |
+
logger.info("Fovea detection results saved to %s", fovea_path)
|
| 232 |
+
|
| 233 |
+
# Step 6: Create overlays if requested
|
| 234 |
+
if overlay_enabled:
|
| 235 |
+
logger.info("Creating visualization overlays")
|
| 236 |
+
|
| 237 |
+
# Prepare fovea data if available
|
| 238 |
+
fovea_data = None
|
| 239 |
+
if df_fovea is not None:
|
| 240 |
+
fovea_data = {
|
| 241 |
+
idx: (row["x_fovea"], row["y_fovea"])
|
| 242 |
+
for idx, row in df_fovea.iterrows()
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
# Create visualization overlays
|
| 246 |
+
batch_create_overlays(
|
| 247 |
+
rgb_dir=preprocess_rgb_path if preprocess else data_path,
|
| 248 |
+
output_dir=overlay_path,
|
| 249 |
+
av_dir=av_path,
|
| 250 |
+
disc_dir=disc_path,
|
| 251 |
+
ring_2r_dir=disc_ring_2r_path,
|
| 252 |
+
ring_3r_dir=disc_ring_3r_path,
|
| 253 |
+
fovea_data=fovea_data,
|
| 254 |
+
overlay_config=app_config.overlay,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
logger.info("Visualization overlays saved to %s", overlay_path)
|
| 258 |
+
|
| 259 |
+
logger.info("All requested processing complete. Results saved to %s", output_path)
|
vascx_models/config.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Iterable, Mapping
|
| 6 |
+
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DEFAULT_CONFIG_NAME = "config.yaml"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _repo_root() -> Path:
|
| 14 |
+
return Path(__file__).resolve().parent.parent
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
|
| 18 |
+
class OverlayLayers:
|
| 19 |
+
arteries: bool = True
|
| 20 |
+
veins: bool = True
|
| 21 |
+
disc: bool = True
|
| 22 |
+
ring_2r: bool = True
|
| 23 |
+
ring_3r: bool = True
|
| 24 |
+
fovea: bool = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
|
| 28 |
+
class OverlayColors:
|
| 29 |
+
artery: tuple[int, int, int] = (255, 0, 0)
|
| 30 |
+
vein: tuple[int, int, int] = (0, 0, 255)
|
| 31 |
+
disc: tuple[int, int, int] = (255, 255, 255)
|
| 32 |
+
ring_2r: tuple[int, int, int] = (0, 255, 0)
|
| 33 |
+
ring_3r: tuple[int, int, int] = (255, 0, 255)
|
| 34 |
+
fovea: tuple[int, int, int] = (255, 255, 0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass(frozen=True)
|
| 38 |
+
class OverlayConfig:
|
| 39 |
+
enabled: bool = True
|
| 40 |
+
layers: OverlayLayers = field(default_factory=OverlayLayers)
|
| 41 |
+
colors: OverlayColors = field(default_factory=OverlayColors)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass(frozen=True)
|
| 45 |
+
class AppConfig:
|
| 46 |
+
overlay: OverlayConfig = field(default_factory=OverlayConfig)
|
| 47 |
+
source_path: Path | None = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def default_config_candidates() -> list[Path]:
|
| 51 |
+
candidates = [Path.cwd() / DEFAULT_CONFIG_NAME, _repo_root() / DEFAULT_CONFIG_NAME]
|
| 52 |
+
unique_candidates: list[Path] = []
|
| 53 |
+
seen: set[Path] = set()
|
| 54 |
+
for candidate in candidates:
|
| 55 |
+
resolved = candidate.resolve()
|
| 56 |
+
if resolved not in seen:
|
| 57 |
+
unique_candidates.append(candidate)
|
| 58 |
+
seen.add(resolved)
|
| 59 |
+
return unique_candidates
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def resolve_config_path(config_path: str | Path | None) -> Path | None:
|
| 63 |
+
if config_path is not None:
|
| 64 |
+
candidate = Path(config_path).expanduser()
|
| 65 |
+
if not candidate.exists():
|
| 66 |
+
raise FileNotFoundError(f"Config file not found: {candidate}")
|
| 67 |
+
return candidate
|
| 68 |
+
|
| 69 |
+
for candidate in default_config_candidates():
|
| 70 |
+
if candidate.exists():
|
| 71 |
+
return candidate
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_app_config(config_path: str | Path | None = None) -> AppConfig:
|
| 76 |
+
resolved_path = resolve_config_path(config_path)
|
| 77 |
+
if resolved_path is None:
|
| 78 |
+
return AppConfig()
|
| 79 |
+
|
| 80 |
+
with resolved_path.open("r", encoding="utf-8") as handle:
|
| 81 |
+
raw_config = yaml.safe_load(handle) or {}
|
| 82 |
+
|
| 83 |
+
if not isinstance(raw_config, dict):
|
| 84 |
+
raise ValueError("Config root must be a mapping")
|
| 85 |
+
|
| 86 |
+
overlay_raw = raw_config.get("overlay", {})
|
| 87 |
+
if overlay_raw is None:
|
| 88 |
+
overlay_raw = {}
|
| 89 |
+
if not isinstance(overlay_raw, dict):
|
| 90 |
+
raise ValueError("'overlay' must be a mapping")
|
| 91 |
+
|
| 92 |
+
layer_overrides = overlay_raw.get("layers", {})
|
| 93 |
+
if layer_overrides is None:
|
| 94 |
+
layer_overrides = {}
|
| 95 |
+
if not isinstance(layer_overrides, dict):
|
| 96 |
+
raise ValueError("'overlay.layers' must be a mapping")
|
| 97 |
+
|
| 98 |
+
color_overrides = overlay_raw.get("colors", overlay_raw.get("colours", {}))
|
| 99 |
+
if color_overrides is None:
|
| 100 |
+
color_overrides = {}
|
| 101 |
+
if not isinstance(color_overrides, dict):
|
| 102 |
+
raise ValueError("'overlay.colors' must be a mapping")
|
| 103 |
+
|
| 104 |
+
return AppConfig(
|
| 105 |
+
overlay=OverlayConfig(
|
| 106 |
+
enabled=_coerce_bool(overlay_raw.get("enabled", True), "overlay.enabled"),
|
| 107 |
+
layers=_build_overlay_layers(layer_overrides),
|
| 108 |
+
colors=_build_overlay_colors(color_overrides),
|
| 109 |
+
),
|
| 110 |
+
source_path=resolved_path,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _build_overlay_layers(raw_layers: Mapping[str, object]) -> OverlayLayers:
|
| 115 |
+
defaults = OverlayLayers()
|
| 116 |
+
alias_map = {
|
| 117 |
+
"artery": "arteries",
|
| 118 |
+
"arteries": "arteries",
|
| 119 |
+
"vein": "veins",
|
| 120 |
+
"veins": "veins",
|
| 121 |
+
"disc": "disc",
|
| 122 |
+
"ring_2r": "ring_2r",
|
| 123 |
+
"disc_ring_2r": "ring_2r",
|
| 124 |
+
"ring_3r": "ring_3r",
|
| 125 |
+
"disc_ring_3r": "ring_3r",
|
| 126 |
+
"fovea": "fovea",
|
| 127 |
+
}
|
| 128 |
+
values = defaults.__dict__.copy()
|
| 129 |
+
for raw_key, raw_value in raw_layers.items():
|
| 130 |
+
if raw_key not in alias_map:
|
| 131 |
+
raise ValueError(f"Unsupported overlay layer '{raw_key}'")
|
| 132 |
+
normalized_key = alias_map[raw_key]
|
| 133 |
+
values[normalized_key] = _coerce_bool(
|
| 134 |
+
raw_value, f"overlay.layers.{raw_key}"
|
| 135 |
+
)
|
| 136 |
+
return OverlayLayers(**values)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _build_overlay_colors(raw_colors: Mapping[str, object]) -> OverlayColors:
|
| 140 |
+
defaults = OverlayColors()
|
| 141 |
+
alias_map = {
|
| 142 |
+
"artery": "artery",
|
| 143 |
+
"arteries": "artery",
|
| 144 |
+
"vein": "vein",
|
| 145 |
+
"veins": "vein",
|
| 146 |
+
"disc": "disc",
|
| 147 |
+
"ring_2r": "ring_2r",
|
| 148 |
+
"disc_ring_2r": "ring_2r",
|
| 149 |
+
"ring_3r": "ring_3r",
|
| 150 |
+
"disc_ring_3r": "ring_3r",
|
| 151 |
+
"fovea": "fovea",
|
| 152 |
+
}
|
| 153 |
+
values = defaults.__dict__.copy()
|
| 154 |
+
for raw_key, raw_value in raw_colors.items():
|
| 155 |
+
if raw_key not in alias_map:
|
| 156 |
+
raise ValueError(f"Unsupported overlay color '{raw_key}'")
|
| 157 |
+
normalized_key = alias_map[raw_key]
|
| 158 |
+
values[normalized_key] = _parse_rgb(raw_value, f"overlay.colors.{raw_key}")
|
| 159 |
+
return OverlayColors(**values)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _coerce_bool(value: object, field_name: str) -> bool:
|
| 163 |
+
if isinstance(value, bool):
|
| 164 |
+
return value
|
| 165 |
+
raise ValueError(f"'{field_name}' must be a boolean")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _parse_rgb(value: object, field_name: str) -> tuple[int, int, int]:
|
| 169 |
+
if isinstance(value, str):
|
| 170 |
+
return _parse_hex_color(value, field_name)
|
| 171 |
+
if isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)):
|
| 172 |
+
channels = tuple(value)
|
| 173 |
+
if len(channels) != 3:
|
| 174 |
+
raise ValueError(f"'{field_name}' must contain exactly 3 channels")
|
| 175 |
+
return tuple(_coerce_channel(channel, field_name) for channel in channels)
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"'{field_name}' must be a '#RRGGBB' string or a 3-item RGB sequence"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _parse_hex_color(value: str, field_name: str) -> tuple[int, int, int]:
|
| 182 |
+
normalized = value.strip()
|
| 183 |
+
if normalized.startswith("#"):
|
| 184 |
+
normalized = normalized[1:]
|
| 185 |
+
if len(normalized) != 6:
|
| 186 |
+
raise ValueError(f"'{field_name}' must be a 6-digit hex color")
|
| 187 |
+
try:
|
| 188 |
+
return tuple(int(normalized[index : index + 2], 16) for index in (0, 2, 4))
|
| 189 |
+
except ValueError as exc:
|
| 190 |
+
raise ValueError(f"'{field_name}' must be a valid hex color") from exc
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _coerce_channel(value: object, field_name: str) -> int:
|
| 194 |
+
if isinstance(value, int) and 0 <= value <= 255:
|
| 195 |
+
return value
|
| 196 |
+
raise ValueError(f"'{field_name}' channels must be integers between 0 and 255")
|
vascx_models/disc_rings.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def estimate_disc_geometry(
|
| 13 |
+
disc_mask: np.ndarray,
|
| 14 |
+
) -> Optional[Tuple[float, float, float]]:
|
| 15 |
+
"""Estimate optic disc center and radius from a binary mask."""
|
| 16 |
+
mask = disc_mask > 0
|
| 17 |
+
if not np.any(mask):
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
+
ys, xs = np.nonzero(mask)
|
| 21 |
+
center_x = float(xs.mean())
|
| 22 |
+
center_y = float(ys.mean())
|
| 23 |
+
|
| 24 |
+
# Use the equivalent-circle radius so the estimate is stable for irregular masks.
|
| 25 |
+
radius = float(np.sqrt(mask.sum() / np.pi))
|
| 26 |
+
return center_x, center_y, radius
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def create_ring_mask(
|
| 30 |
+
image_shape: Tuple[int, int],
|
| 31 |
+
center: Tuple[float, float],
|
| 32 |
+
radius: float,
|
| 33 |
+
thickness: int,
|
| 34 |
+
) -> np.ndarray:
|
| 35 |
+
"""Create a binary ring mask for a circle outline."""
|
| 36 |
+
height, width = image_shape
|
| 37 |
+
ring = Image.new("L", (width, height), 0)
|
| 38 |
+
draw = ImageDraw.Draw(ring)
|
| 39 |
+
|
| 40 |
+
center_x, center_y = center
|
| 41 |
+
bbox = (
|
| 42 |
+
center_x - radius,
|
| 43 |
+
center_y - radius,
|
| 44 |
+
center_x + radius,
|
| 45 |
+
center_y + radius,
|
| 46 |
+
)
|
| 47 |
+
draw.ellipse(bbox, outline=255, width=thickness)
|
| 48 |
+
return np.array(ring, dtype=np.uint8)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def generate_disc_rings(
|
| 52 |
+
disc_dir: Path,
|
| 53 |
+
ring_2r_dir: Path,
|
| 54 |
+
ring_3r_dir: Path,
|
| 55 |
+
measurements_path: Optional[Path] = None,
|
| 56 |
+
) -> pd.DataFrame:
|
| 57 |
+
"""Generate 2r and 3r optic-disc ring masks from saved disc segmentations."""
|
| 58 |
+
ring_2r_dir.mkdir(exist_ok=True, parents=True)
|
| 59 |
+
ring_3r_dir.mkdir(exist_ok=True, parents=True)
|
| 60 |
+
|
| 61 |
+
disc_files = list(disc_dir.glob("*.png"))
|
| 62 |
+
if not disc_files:
|
| 63 |
+
logger.warning("No disc masks found for ring generation in %s", disc_dir)
|
| 64 |
+
columns = [
|
| 65 |
+
"x_disc_center",
|
| 66 |
+
"y_disc_center",
|
| 67 |
+
"disc_radius_px",
|
| 68 |
+
"ring_2r_px",
|
| 69 |
+
"ring_3r_px",
|
| 70 |
+
]
|
| 71 |
+
return pd.DataFrame(columns=columns)
|
| 72 |
+
|
| 73 |
+
records: Dict[str, Dict[str, float]] = {}
|
| 74 |
+
logger.info("Generating 2r and 3r rings for %d disc masks", len(disc_files))
|
| 75 |
+
|
| 76 |
+
for disc_file in disc_files:
|
| 77 |
+
image_id = disc_file.stem
|
| 78 |
+
disc_mask = np.array(Image.open(disc_file)) > 0
|
| 79 |
+
geometry = estimate_disc_geometry(disc_mask)
|
| 80 |
+
|
| 81 |
+
if geometry is None:
|
| 82 |
+
logger.warning("Disc mask is empty for %s; writing blank ring masks", image_id)
|
| 83 |
+
blank = np.zeros(disc_mask.shape, dtype=np.uint8)
|
| 84 |
+
Image.fromarray(blank).save(ring_2r_dir / f"{image_id}.png")
|
| 85 |
+
Image.fromarray(blank).save(ring_3r_dir / f"{image_id}.png")
|
| 86 |
+
records[image_id] = {
|
| 87 |
+
"x_disc_center": np.nan,
|
| 88 |
+
"y_disc_center": np.nan,
|
| 89 |
+
"disc_radius_px": np.nan,
|
| 90 |
+
"ring_2r_px": np.nan,
|
| 91 |
+
"ring_3r_px": np.nan,
|
| 92 |
+
}
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
center_x, center_y, disc_radius = geometry
|
| 96 |
+
line_width = max(1, int(round(disc_radius * 0.08)))
|
| 97 |
+
ring_2r = create_ring_mask(
|
| 98 |
+
disc_mask.shape, (center_x, center_y), radius=disc_radius * 2.0, thickness=line_width
|
| 99 |
+
)
|
| 100 |
+
ring_3r = create_ring_mask(
|
| 101 |
+
disc_mask.shape, (center_x, center_y), radius=disc_radius * 3.0, thickness=line_width
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
Image.fromarray(ring_2r).save(ring_2r_dir / f"{image_id}.png")
|
| 105 |
+
Image.fromarray(ring_3r).save(ring_3r_dir / f"{image_id}.png")
|
| 106 |
+
records[image_id] = {
|
| 107 |
+
"x_disc_center": center_x,
|
| 108 |
+
"y_disc_center": center_y,
|
| 109 |
+
"disc_radius_px": disc_radius,
|
| 110 |
+
"ring_2r_px": disc_radius * 2.0,
|
| 111 |
+
"ring_3r_px": disc_radius * 3.0,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
df_measurements = pd.DataFrame.from_dict(records, orient="index")
|
| 115 |
+
if measurements_path is not None:
|
| 116 |
+
df_measurements.to_csv(measurements_path)
|
| 117 |
+
logger.info("Disc ring measurements saved to %s", measurements_path)
|
| 118 |
+
return df_measurements
|
vascx_models/inference.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from contextlib import nullcontext
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from rtnls_inference.ensembles.ensemble_classification import ClassificationEnsemble
|
| 14 |
+
from rtnls_inference.ensembles.ensemble_heatmap_regression import (
|
| 15 |
+
HeatmapRegressionEnsemble,
|
| 16 |
+
)
|
| 17 |
+
from rtnls_inference.ensembles.ensemble_segmentation import SegmentationEnsemble
|
| 18 |
+
from rtnls_inference.utils import decollate_batch, extract_keypoints_from_heatmaps
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def preferred_device() -> torch.device:
|
| 24 |
+
if torch.cuda.is_available():
|
| 25 |
+
return torch.device("cuda:0")
|
| 26 |
+
if torch.backends.mps.is_available():
|
| 27 |
+
return torch.device("mps")
|
| 28 |
+
return torch.device("cpu")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _inference_num_workers(device: torch.device) -> int:
|
| 32 |
+
# Torch shared-memory workers can fail in restricted CPU environments.
|
| 33 |
+
return 8 if device.type in {"cuda", "mps"} else 0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _autocast_context(device: torch.device):
|
| 37 |
+
return torch.autocast(device_type=device.type) if device.type == "cuda" else nullcontext()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def run_quality_estimation(fpaths, ids, device: torch.device):
|
| 41 |
+
logger.info("Loading quality model on %s", device)
|
| 42 |
+
ensemble_quality = ClassificationEnsemble.from_release("quality.pt").to(device)
|
| 43 |
+
dataloader = ensemble_quality._make_inference_dataloader(
|
| 44 |
+
fpaths,
|
| 45 |
+
ids=ids,
|
| 46 |
+
num_workers=_inference_num_workers(device),
|
| 47 |
+
preprocess=False,
|
| 48 |
+
batch_size=16,
|
| 49 |
+
)
|
| 50 |
+
logger.info("Quality dataloader ready with %d images", len(fpaths))
|
| 51 |
+
|
| 52 |
+
output_ids, outputs = [], []
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
for batch in tqdm(dataloader):
|
| 55 |
+
if len(batch) == 0:
|
| 56 |
+
continue
|
| 57 |
+
|
| 58 |
+
im = batch["image"].to(device)
|
| 59 |
+
|
| 60 |
+
# QUALITY
|
| 61 |
+
quality = ensemble_quality.predict_step(im)
|
| 62 |
+
quality = torch.mean(quality, dim=0)
|
| 63 |
+
|
| 64 |
+
items = {"id": batch["id"], "quality": quality}
|
| 65 |
+
items = decollate_batch(items)
|
| 66 |
+
|
| 67 |
+
for item in items:
|
| 68 |
+
output_ids.append(item["id"])
|
| 69 |
+
outputs.append(item["quality"].tolist())
|
| 70 |
+
|
| 71 |
+
return pd.DataFrame(
|
| 72 |
+
outputs,
|
| 73 |
+
index=output_ids,
|
| 74 |
+
columns=["q1", "q2", "q3"],
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def run_segmentation_vessels_and_av(
|
| 79 |
+
rgb_paths: List[Path],
|
| 80 |
+
ce_paths: Optional[List[Path]] = None,
|
| 81 |
+
ids: Optional[List[str]] = None,
|
| 82 |
+
av_path: Optional[Path] = None,
|
| 83 |
+
vessels_path: Optional[Path] = None,
|
| 84 |
+
device: torch.device = preferred_device(),
|
| 85 |
+
) -> None:
|
| 86 |
+
"""
|
| 87 |
+
Run AV and vessel segmentation on the provided images.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
rgb_paths: List of paths to RGB fundus images
|
| 91 |
+
ce_paths: Optional list of paths to contrast enhanced images
|
| 92 |
+
ids: Optional list of ids to pass to _make_inference_dataloader
|
| 93 |
+
av_path: Folder where to store output AV segmentations
|
| 94 |
+
vessels_path: Folder where to store output vessel segmentations
|
| 95 |
+
device: Device to run inference on
|
| 96 |
+
"""
|
| 97 |
+
# Create output directories if they don't exist
|
| 98 |
+
if av_path is not None:
|
| 99 |
+
av_path.mkdir(exist_ok=True, parents=True)
|
| 100 |
+
if vessels_path is not None:
|
| 101 |
+
vessels_path.mkdir(exist_ok=True, parents=True)
|
| 102 |
+
|
| 103 |
+
# Load models
|
| 104 |
+
logger.info("Loading AV and vessel models on %s", device)
|
| 105 |
+
ensemble_av = SegmentationEnsemble.from_release("av_july24.pt").to(device).eval()
|
| 106 |
+
ensemble_vessels = (
|
| 107 |
+
SegmentationEnsemble.from_release("vessels_july24.pt").to(device).eval()
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Prepare input paths
|
| 111 |
+
if ce_paths is None:
|
| 112 |
+
# If CE paths are not provided, use RGB paths for both inputs
|
| 113 |
+
fpaths = rgb_paths
|
| 114 |
+
else:
|
| 115 |
+
# If CE paths are provided, pair them with RGB paths
|
| 116 |
+
if len(rgb_paths) != len(ce_paths):
|
| 117 |
+
raise ValueError("rgb_paths and ce_paths must have the same length")
|
| 118 |
+
fpaths = list(zip(rgb_paths, ce_paths))
|
| 119 |
+
|
| 120 |
+
# Create dataloader
|
| 121 |
+
dataloader = ensemble_av._make_inference_dataloader(
|
| 122 |
+
fpaths,
|
| 123 |
+
ids=ids,
|
| 124 |
+
num_workers=_inference_num_workers(device),
|
| 125 |
+
preprocess=False,
|
| 126 |
+
batch_size=8,
|
| 127 |
+
)
|
| 128 |
+
logger.info("AV and vessel dataloader ready with %d images", len(fpaths))
|
| 129 |
+
|
| 130 |
+
# Run inference
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
for batch in tqdm(dataloader):
|
| 133 |
+
# AV segmentation
|
| 134 |
+
if av_path is not None:
|
| 135 |
+
with _autocast_context(device):
|
| 136 |
+
proba = ensemble_av.forward(batch["image"].to(device))
|
| 137 |
+
proba = torch.mean(proba, dim=1) # average over models
|
| 138 |
+
proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 139 |
+
proba = torch.nn.functional.softmax(proba, dim=-1)
|
| 140 |
+
|
| 141 |
+
items = {
|
| 142 |
+
"id": batch["id"],
|
| 143 |
+
"image": proba,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
items = decollate_batch(items)
|
| 147 |
+
for i, item in enumerate(items):
|
| 148 |
+
fpath = os.path.join(av_path, f"{item['id']}.png")
|
| 149 |
+
mask = np.argmax(item["image"], -1)
|
| 150 |
+
Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
|
| 151 |
+
|
| 152 |
+
# Vessel segmentation
|
| 153 |
+
if vessels_path is not None:
|
| 154 |
+
with _autocast_context(device):
|
| 155 |
+
proba = ensemble_vessels.forward(batch["image"].to(device))
|
| 156 |
+
proba = torch.mean(proba, dim=1) # average over models
|
| 157 |
+
proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 158 |
+
proba = torch.nn.functional.softmax(proba, dim=-1)
|
| 159 |
+
|
| 160 |
+
items = {
|
| 161 |
+
"id": batch["id"],
|
| 162 |
+
"image": proba,
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
items = decollate_batch(items)
|
| 166 |
+
for i, item in enumerate(items):
|
| 167 |
+
fpath = os.path.join(vessels_path, f"{item['id']}.png")
|
| 168 |
+
mask = np.argmax(item["image"], -1)
|
| 169 |
+
Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def run_segmentation_disc(
|
| 173 |
+
rgb_paths: List[Path],
|
| 174 |
+
ce_paths: Optional[List[Path]] = None,
|
| 175 |
+
ids: Optional[List[str]] = None,
|
| 176 |
+
output_path: Optional[Path] = None,
|
| 177 |
+
device: torch.device = preferred_device(),
|
| 178 |
+
) -> None:
|
| 179 |
+
logger.info("Loading disc model on %s", device)
|
| 180 |
+
ensemble_disc = (
|
| 181 |
+
SegmentationEnsemble.from_release("disc_july24.pt").to(device).eval()
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Prepare input paths
|
| 185 |
+
if ce_paths is None:
|
| 186 |
+
# If CE paths are not provided, use RGB paths for both inputs
|
| 187 |
+
fpaths = rgb_paths
|
| 188 |
+
else:
|
| 189 |
+
# If CE paths are provided, pair them with RGB paths
|
| 190 |
+
if len(rgb_paths) != len(ce_paths):
|
| 191 |
+
raise ValueError("rgb_paths and ce_paths must have the same length")
|
| 192 |
+
fpaths = list(zip(rgb_paths, ce_paths))
|
| 193 |
+
|
| 194 |
+
dataloader = ensemble_disc._make_inference_dataloader(
|
| 195 |
+
fpaths,
|
| 196 |
+
ids=ids,
|
| 197 |
+
num_workers=_inference_num_workers(device),
|
| 198 |
+
preprocess=False,
|
| 199 |
+
batch_size=8,
|
| 200 |
+
)
|
| 201 |
+
logger.info("Disc dataloader ready with %d images", len(fpaths))
|
| 202 |
+
|
| 203 |
+
with torch.no_grad():
|
| 204 |
+
for batch in tqdm(dataloader):
|
| 205 |
+
# AV
|
| 206 |
+
with _autocast_context(device):
|
| 207 |
+
proba = ensemble_disc.forward(batch["image"].to(device))
|
| 208 |
+
proba = torch.mean(proba, dim=1) # average over models
|
| 209 |
+
proba = torch.permute(proba, (0, 2, 3, 1)) # NCHW -> NHWC
|
| 210 |
+
proba = torch.nn.functional.softmax(proba, dim=-1)
|
| 211 |
+
|
| 212 |
+
items = {
|
| 213 |
+
"id": batch["id"],
|
| 214 |
+
"image": proba,
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
items = decollate_batch(items)
|
| 218 |
+
items = [dataloader.dataset.transform.undo_item(item) for item in items]
|
| 219 |
+
for i, item in enumerate(items):
|
| 220 |
+
fpath = os.path.join(output_path, f"{item['id']}.png")
|
| 221 |
+
|
| 222 |
+
mask = np.argmax(item["image"], -1)
|
| 223 |
+
Image.fromarray(mask.squeeze().astype(np.uint8)).save(fpath)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def run_fovea_detection(
|
| 227 |
+
rgb_paths: List[Path],
|
| 228 |
+
ce_paths: Optional[List[Path]] = None,
|
| 229 |
+
ids: Optional[List[str]] = None,
|
| 230 |
+
device: torch.device = preferred_device(),
|
| 231 |
+
) -> None:
|
| 232 |
+
# def run_fovea_detection(fpaths, ids, device: torch.device):
|
| 233 |
+
logger.info("Loading fovea model on %s", device)
|
| 234 |
+
ensemble_fovea = HeatmapRegressionEnsemble.from_release("fovea_july24.pt").to(
|
| 235 |
+
device
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
# Prepare input paths
|
| 239 |
+
if ce_paths is None:
|
| 240 |
+
# If CE paths are not provided, use RGB paths for both inputs
|
| 241 |
+
fpaths = rgb_paths
|
| 242 |
+
else:
|
| 243 |
+
# If CE paths are provided, pair them with RGB paths
|
| 244 |
+
if len(rgb_paths) != len(ce_paths):
|
| 245 |
+
raise ValueError("rgb_paths and ce_paths must have the same length")
|
| 246 |
+
fpaths = list(zip(rgb_paths, ce_paths))
|
| 247 |
+
|
| 248 |
+
dataloader = ensemble_fovea._make_inference_dataloader(
|
| 249 |
+
fpaths,
|
| 250 |
+
ids=ids,
|
| 251 |
+
num_workers=_inference_num_workers(device),
|
| 252 |
+
preprocess=False,
|
| 253 |
+
batch_size=8,
|
| 254 |
+
)
|
| 255 |
+
logger.info("Fovea dataloader ready with %d images", len(fpaths))
|
| 256 |
+
|
| 257 |
+
output_ids, outputs = [], []
|
| 258 |
+
with torch.no_grad():
|
| 259 |
+
for batch in tqdm(dataloader):
|
| 260 |
+
if len(batch) == 0:
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
im = batch["image"].to(device)
|
| 264 |
+
|
| 265 |
+
# FOVEA DETECTION
|
| 266 |
+
with _autocast_context(device):
|
| 267 |
+
heatmap = ensemble_fovea.forward(im)
|
| 268 |
+
keypoints = extract_keypoints_from_heatmaps(heatmap)
|
| 269 |
+
|
| 270 |
+
kp_fovea = torch.mean(keypoints, dim=1) # average over models
|
| 271 |
+
|
| 272 |
+
items = {
|
| 273 |
+
"id": batch["id"],
|
| 274 |
+
"keypoints": kp_fovea,
|
| 275 |
+
"metadata": batch["metadata"],
|
| 276 |
+
}
|
| 277 |
+
items = decollate_batch(items)
|
| 278 |
+
|
| 279 |
+
items = [dataloader.dataset.transform.undo_item(item) for item in items]
|
| 280 |
+
|
| 281 |
+
for item in items:
|
| 282 |
+
output_ids.append(item["id"])
|
| 283 |
+
outputs.append(
|
| 284 |
+
[
|
| 285 |
+
*item["keypoints"][0].tolist(),
|
| 286 |
+
]
|
| 287 |
+
)
|
| 288 |
+
return pd.DataFrame(
|
| 289 |
+
outputs,
|
| 290 |
+
index=output_ids,
|
| 291 |
+
columns=["x_fovea", "y_fovea"],
|
| 292 |
+
)
|
vascx_models/utils.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image, ImageDraw
|
| 7 |
+
|
| 8 |
+
from .config import OverlayConfig
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_fundus_overlay(
|
| 14 |
+
rgb_path: str,
|
| 15 |
+
av_path: Optional[str] = None,
|
| 16 |
+
disc_path: Optional[str] = None,
|
| 17 |
+
ring_2r_path: Optional[str] = None,
|
| 18 |
+
ring_3r_path: Optional[str] = None,
|
| 19 |
+
fovea_location: Optional[Tuple[int, int]] = None,
|
| 20 |
+
output_path: Optional[str] = None,
|
| 21 |
+
overlay_config: Optional[OverlayConfig] = None,
|
| 22 |
+
) -> np.ndarray:
|
| 23 |
+
"""
|
| 24 |
+
Create a visualization of a fundus image with overlaid segmentations and markers.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
rgb_path: Path to the RGB fundus image
|
| 28 |
+
av_path: Optional path to artery-vein segmentation (1=artery, 2=vein, 3=intersection)
|
| 29 |
+
disc_path: Optional path to binary disc segmentation
|
| 30 |
+
ring_2r_path: Optional path to a binary 2r ring mask
|
| 31 |
+
ring_3r_path: Optional path to a binary 3r ring mask
|
| 32 |
+
fovea_location: Optional (x,y) tuple indicating the location of the fovea
|
| 33 |
+
output_path: Optional path to save the visualization image
|
| 34 |
+
overlay_config: Overlay display configuration including enabled layers and colors
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Numpy array containing the visualization image
|
| 38 |
+
"""
|
| 39 |
+
overlay_config = overlay_config or OverlayConfig()
|
| 40 |
+
|
| 41 |
+
# Load RGB image
|
| 42 |
+
rgb_img = np.array(Image.open(rgb_path))
|
| 43 |
+
|
| 44 |
+
# Create output image starting with the RGB image
|
| 45 |
+
output_img = rgb_img.copy()
|
| 46 |
+
|
| 47 |
+
# Load and overlay AV segmentation if provided
|
| 48 |
+
if av_path and (overlay_config.layers.arteries or overlay_config.layers.veins):
|
| 49 |
+
av_mask = np.array(Image.open(av_path))
|
| 50 |
+
|
| 51 |
+
# Create masks for arteries (1), veins (2) and intersections (3)
|
| 52 |
+
artery_mask = av_mask == 1
|
| 53 |
+
vein_mask = av_mask == 2
|
| 54 |
+
intersection_mask = av_mask == 3
|
| 55 |
+
|
| 56 |
+
if overlay_config.layers.arteries:
|
| 57 |
+
artery_combined = np.logical_or(artery_mask, intersection_mask)
|
| 58 |
+
output_img[artery_combined, :] = overlay_config.colors.artery
|
| 59 |
+
|
| 60 |
+
if overlay_config.layers.veins:
|
| 61 |
+
vein_combined = np.logical_or(vein_mask, intersection_mask)
|
| 62 |
+
output_img[vein_combined, :] = overlay_config.colors.vein
|
| 63 |
+
|
| 64 |
+
# Load and overlay optic disc segmentation if provided
|
| 65 |
+
if disc_path and overlay_config.layers.disc:
|
| 66 |
+
disc_mask = np.array(Image.open(disc_path)) > 0
|
| 67 |
+
output_img[disc_mask, :] = overlay_config.colors.disc
|
| 68 |
+
|
| 69 |
+
if ring_2r_path and overlay_config.layers.ring_2r:
|
| 70 |
+
ring_2r_mask = np.array(Image.open(ring_2r_path)) > 0
|
| 71 |
+
output_img[ring_2r_mask, :] = overlay_config.colors.ring_2r
|
| 72 |
+
|
| 73 |
+
if ring_3r_path and overlay_config.layers.ring_3r:
|
| 74 |
+
ring_3r_mask = np.array(Image.open(ring_3r_path)) > 0
|
| 75 |
+
output_img[ring_3r_mask, :] = overlay_config.colors.ring_3r
|
| 76 |
+
|
| 77 |
+
# Convert to PIL image for drawing the fovea marker
|
| 78 |
+
pil_img = Image.fromarray(output_img)
|
| 79 |
+
|
| 80 |
+
# Add fovea marker if provided
|
| 81 |
+
if fovea_location and overlay_config.layers.fovea:
|
| 82 |
+
draw = ImageDraw.Draw(pil_img)
|
| 83 |
+
x, y = fovea_location
|
| 84 |
+
marker_size = (
|
| 85 |
+
min(pil_img.width, pil_img.height) // 50
|
| 86 |
+
) # Scale marker with image
|
| 87 |
+
|
| 88 |
+
# Draw yellow X at fovea location
|
| 89 |
+
draw.line(
|
| 90 |
+
[(x - marker_size, y - marker_size), (x + marker_size, y + marker_size)],
|
| 91 |
+
fill=overlay_config.colors.fovea,
|
| 92 |
+
width=2,
|
| 93 |
+
)
|
| 94 |
+
draw.line(
|
| 95 |
+
[(x - marker_size, y + marker_size), (x + marker_size, y - marker_size)],
|
| 96 |
+
fill=overlay_config.colors.fovea,
|
| 97 |
+
width=2,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Convert back to numpy array
|
| 101 |
+
output_img = np.array(pil_img)
|
| 102 |
+
|
| 103 |
+
# Save output if path provided
|
| 104 |
+
if output_path:
|
| 105 |
+
Image.fromarray(output_img).save(output_path)
|
| 106 |
+
|
| 107 |
+
return output_img
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def batch_create_overlays(
|
| 111 |
+
rgb_dir: Path,
|
| 112 |
+
output_dir: Path,
|
| 113 |
+
av_dir: Optional[Path] = None,
|
| 114 |
+
disc_dir: Optional[Path] = None,
|
| 115 |
+
ring_2r_dir: Optional[Path] = None,
|
| 116 |
+
ring_3r_dir: Optional[Path] = None,
|
| 117 |
+
fovea_data: Optional[Dict[str, Tuple[int, int]]] = None,
|
| 118 |
+
overlay_config: Optional[OverlayConfig] = None,
|
| 119 |
+
) -> None:
|
| 120 |
+
"""
|
| 121 |
+
Create visualization overlays for a batch of images.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
rgb_dir: Directory containing RGB fundus images
|
| 125 |
+
output_dir: Directory to save visualization images
|
| 126 |
+
av_dir: Optional directory containing AV segmentations
|
| 127 |
+
disc_dir: Optional directory containing disc segmentations
|
| 128 |
+
ring_2r_dir: Optional directory containing 2r ring masks
|
| 129 |
+
ring_3r_dir: Optional directory containing 3r ring masks
|
| 130 |
+
fovea_data: Optional dictionary mapping image IDs to fovea coordinates
|
| 131 |
+
overlay_config: Overlay display configuration including enabled layers and colors
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
List of paths to created visualization images
|
| 135 |
+
"""
|
| 136 |
+
# Create output directory if it doesn't exist
|
| 137 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
| 138 |
+
overlay_config = overlay_config or OverlayConfig()
|
| 139 |
+
|
| 140 |
+
# Get all RGB images
|
| 141 |
+
rgb_files = list(rgb_dir.glob("*.png"))
|
| 142 |
+
if not rgb_files:
|
| 143 |
+
logger.warning("No RGB images found for overlays in %s", rgb_dir)
|
| 144 |
+
return []
|
| 145 |
+
logger.info("Creating overlays for %d images", len(rgb_files))
|
| 146 |
+
|
| 147 |
+
# Process each image
|
| 148 |
+
for rgb_file in rgb_files:
|
| 149 |
+
image_id = rgb_file.stem
|
| 150 |
+
|
| 151 |
+
# Check for corresponding AV segmentation
|
| 152 |
+
av_file = None
|
| 153 |
+
if av_dir:
|
| 154 |
+
av_file_path = av_dir / f"{image_id}.png"
|
| 155 |
+
if av_file_path.exists():
|
| 156 |
+
av_file = str(av_file_path)
|
| 157 |
+
|
| 158 |
+
# Check for corresponding disc segmentation
|
| 159 |
+
disc_file = None
|
| 160 |
+
if disc_dir:
|
| 161 |
+
disc_file_path = disc_dir / f"{image_id}.png"
|
| 162 |
+
if disc_file_path.exists():
|
| 163 |
+
disc_file = str(disc_file_path)
|
| 164 |
+
|
| 165 |
+
ring_2r_file = None
|
| 166 |
+
if ring_2r_dir:
|
| 167 |
+
ring_2r_file_path = ring_2r_dir / f"{image_id}.png"
|
| 168 |
+
if ring_2r_file_path.exists():
|
| 169 |
+
ring_2r_file = str(ring_2r_file_path)
|
| 170 |
+
|
| 171 |
+
ring_3r_file = None
|
| 172 |
+
if ring_3r_dir:
|
| 173 |
+
ring_3r_file_path = ring_3r_dir / f"{image_id}.png"
|
| 174 |
+
if ring_3r_file_path.exists():
|
| 175 |
+
ring_3r_file = str(ring_3r_file_path)
|
| 176 |
+
|
| 177 |
+
# Get fovea location if available
|
| 178 |
+
fovea_location = None
|
| 179 |
+
if fovea_data and image_id in fovea_data:
|
| 180 |
+
fovea_location = fovea_data[image_id]
|
| 181 |
+
|
| 182 |
+
# Create output path
|
| 183 |
+
output_file = output_dir / f"{image_id}.png"
|
| 184 |
+
|
| 185 |
+
# Create and save overlay
|
| 186 |
+
create_fundus_overlay(
|
| 187 |
+
rgb_path=str(rgb_file),
|
| 188 |
+
av_path=av_file,
|
| 189 |
+
disc_path=disc_file,
|
| 190 |
+
ring_2r_path=ring_2r_file,
|
| 191 |
+
ring_3r_path=ring_3r_file,
|
| 192 |
+
fovea_location=fovea_location,
|
| 193 |
+
output_path=str(output_file),
|
| 194 |
+
overlay_config=overlay_config,
|
| 195 |
+
)
|
| 196 |
+
logger.info("Finished overlay generation in %s", output_dir)
|
vessels/vessels_july24.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bdabae77502648acd1c176bff6d5e3c8295da60f18f2f84fe1bcd6181d2b2ca4
|
| 3 |
+
size 352821632
|
vessels/vessels_july24_DRHAGIS.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:406cf89ed2c35713296096cdfdbc2d6e67c164e25dbd6542612f31e2bfa0c85e
|
| 3 |
+
size 352848262
|