zyf0717 commited on
Commit
1386847
·
1 Parent(s): b655085

Enhance device selection and logging for inference; add end-to-end tests

Browse files
README.md CHANGED
@@ -33,6 +33,7 @@ This is not the canonical upstream repository. The upstream project remains `Eye
33
  - The legacy `setup.py` and installed `vascx` console script were removed
34
  - Supported entrypoints are `./run.sh` and `python -m vascx_models`
35
  - Overlay generation can now be configured from the root `config.yaml`
 
36
  - Local helper scripts and docs were updated to point at this fork instead of the upstream Hub repo
37
  - Generated outputs, caches, and other non-repository artifacts are excluded from version control
38
 
@@ -71,6 +72,7 @@ Run the full pipeline:
71
 
72
  ```bash
73
  INPUT_PATH=/path/to/images OUTPUT_PATH=/path/to/output N_JOBS=4 ./run.sh
 
74
  ./run.sh --sample-run
75
  ```
76
 
@@ -94,10 +96,25 @@ Typical examples:
94
  python -m vascx_models run /path/to/images /path/to/output
95
  python -m vascx_models run /path/to/image_list.csv /path/to/output
96
  python -m vascx_models run /path/to/preprocessed/images /path/to/output --no-preprocess
 
 
97
  python -m vascx_models run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
98
  python -m vascx_models run /path/to/images /path/to/output --no-vessels
99
  ```
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  ## Configuration
102
 
103
  This fork adds a root-level `config.yaml` for overlay behavior, disc-circle generation, and vessel-width sampling.
@@ -197,10 +214,30 @@ Current measurement behavior is intentionally conservative:
197
  - `vascx_models/`: package source and CLI
198
  - `artery_vein/`, `disc/`, `fovea/`, `vessels/`, `quality/`, `odfd/`, `discedge/`: model artifacts
199
  - `config.yaml`: fork-specific overlay configuration
 
200
  - `run.sh`: primary local runner
201
  - `tests/`: pytest suite
202
  - `notebooks/`: preprocessing and inference examples
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  ## Upstream Reference
205
 
206
  Original upstream project:
 
33
  - The legacy `setup.py` and installed `vascx` console script were removed
34
  - Supported entrypoints are `./run.sh` and `python -m vascx_models`
35
  - Overlay generation can now be configured from the root `config.yaml`
36
+ - Inference device selection is automatic by default and can be overridden explicitly
37
  - Local helper scripts and docs were updated to point at this fork instead of the upstream Hub repo
38
  - Generated outputs, caches, and other non-repository artifacts are excluded from version control
39
 
 
72
 
73
  ```bash
74
  INPUT_PATH=/path/to/images OUTPUT_PATH=/path/to/output N_JOBS=4 ./run.sh
75
+ DEVICE=cpu INPUT_PATH=/path/to/images OUTPUT_PATH=/path/to/output ./run.sh
76
  ./run.sh --sample-run
77
  ```
78
 
 
96
  python -m vascx_models run /path/to/images /path/to/output
97
  python -m vascx_models run /path/to/image_list.csv /path/to/output
98
  python -m vascx_models run /path/to/preprocessed/images /path/to/output --no-preprocess
99
+ python -m vascx_models run /path/to/images /path/to/output --device auto
100
+ python -m vascx_models run /path/to/images /path/to/output --device cpu
101
  python -m vascx_models run /path/to/images /path/to/output --no-disc --no-quality --no-fovea --no-overlay
102
  python -m vascx_models run /path/to/images /path/to/output --no-vessels
103
  ```
104
 
105
+ ## Device Selection
106
+
107
+ Inference device selection is automatic by default.
108
+
109
+ - `--device auto` is the default for `python -m vascx_models run`
110
+ - `DEVICE=auto` is the default for `./run.sh`
111
+ - Auto-selection priority is `cuda` first, then Apple Metal `mps`, then `cpu`
112
+ - The CLI logs detected availability as `cuda=...`, `mps=...`, `cpu=True`
113
+ - The CLI also logs the selected device for each run
114
+ - You can force a backend with `--device cuda`, `--device mps`, or `--device cpu`
115
+ - `./run.sh` forwards the `DEVICE` environment variable to the Python CLI
116
+ - If you request `cuda` or `mps` explicitly and that backend is unavailable, the run exits with a clear error instead of silently falling back
117
+
118
  ## Configuration
119
 
120
  This fork adds a root-level `config.yaml` for overlay behavior, disc-circle generation, and vessel-width sampling.
 
214
  - `vascx_models/`: package source and CLI
215
  - `artery_vein/`, `disc/`, `fovea/`, `vessels/`, `quality/`, `odfd/`, `discedge/`: model artifacts
216
  - `config.yaml`: fork-specific overlay configuration
217
+ - `pytest.ini`: pytest marker definitions for slow and end-to-end tests
218
  - `run.sh`: primary local runner
219
  - `tests/`: pytest suite
220
  - `notebooks/`: preprocessing and inference examples
221
 
222
+ ## Testing
223
+
224
+ The test suite includes unit tests, CLI tests, and an opt-in real-model single-image end-to-end smoke test in `tests/test_e2e.py`.
225
+
226
+ Useful commands:
227
+
228
+ ```bash
229
+ conda run -n vascx-fork pytest
230
+ KMP_DUPLICATE_LIB_OK=TRUE conda run -n vascx-fork pytest tests/test_e2e.py -q
231
+ KMP_DUPLICATE_LIB_OK=TRUE VASCX_RUN_E2E=1 conda run -n vascx-fork pytest tests/test_e2e.py -q -k cpu
232
+ ```
233
+
234
+ Explicitly tested in this fork as of April 21, 2026:
235
+
236
+ - README and CLI/config behavior updates are covered by the regular pytest suite
237
+ - device resolution priority and explicit unavailable-device failures are covered by unit tests
238
+ - the real single-image end-to-end pipeline was run successfully on CPU with preprocessing enabled
239
+ - the end-to-end test is parameterized for `cpu`, `cuda`, and `mps`, but actual `cuda` and `mps` execution were not exercised in this workspace because those backends were unavailable
240
+
241
  ## Upstream Reference
242
 
243
  Original upstream project:
pytest.ini ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [pytest]
2
+ markers =
3
+ e2e: real-model end-to-end integration tests
4
+ slow: long-running tests
5
+ filterwarnings =
6
+ ignore:`torch\.jit\.interface` is deprecated\. Please use `torch\.compile` instead\.:DeprecationWarning
run.sh CHANGED
@@ -10,6 +10,7 @@ TIMESTAMP="$(date +"%Y%m%d_%H%M%S")"
10
  DEFAULT_OUTPUT_PATH="$REPO_ROOT/output_$TIMESTAMP"
11
  OUTPUT_PATH="${OUTPUT_PATH:-$DEFAULT_OUTPUT_PATH}"
12
  N_JOBS="${N_JOBS:-1}"
 
13
 
14
  while [[ $# -gt 0 ]]; do
15
  case "$1" in
@@ -37,6 +38,7 @@ echo " conda env: $CONDA_ENV"
37
  echo " input path: $INPUT_PATH"
38
  echo " output path: $OUTPUT_PATH"
39
  echo " n_jobs: $N_JOBS"
 
40
 
41
  CONDA_BASE="$(conda info --base)"
42
  # shellcheck disable=SC1091
@@ -44,4 +46,4 @@ source "$CONDA_BASE/etc/profile.d/conda.sh"
44
  conda activate "$CONDA_ENV"
45
 
46
  cd "$REPO_ROOT"
47
- exec python -m vascx_models run "$INPUT_PATH" "$OUTPUT_PATH" --n_jobs "$N_JOBS"
 
10
  DEFAULT_OUTPUT_PATH="$REPO_ROOT/output_$TIMESTAMP"
11
  OUTPUT_PATH="${OUTPUT_PATH:-$DEFAULT_OUTPUT_PATH}"
12
  N_JOBS="${N_JOBS:-1}"
13
+ DEVICE="${DEVICE:-auto}"
14
 
15
  while [[ $# -gt 0 ]]; do
16
  case "$1" in
 
38
  echo " input path: $INPUT_PATH"
39
  echo " output path: $OUTPUT_PATH"
40
  echo " n_jobs: $N_JOBS"
41
+ echo " device: $DEVICE"
42
 
43
  CONDA_BASE="$(conda info --base)"
44
  # shellcheck disable=SC1091
 
46
  conda activate "$CONDA_ENV"
47
 
48
  cd "$REPO_ROOT"
49
+ exec python -m vascx_models run "$INPUT_PATH" "$OUTPUT_PATH" --n_jobs "$N_JOBS" --device "$DEVICE"
tests/test_cli.py CHANGED
@@ -1,8 +1,9 @@
1
  from pathlib import Path
2
 
 
3
  import pandas as pd
4
  from click.testing import CliRunner
5
- import logging
6
  from PIL import Image
7
 
8
  from vascx_models.cli import cli
@@ -20,7 +21,15 @@ def test_cli_run_passes_measurement_config_and_data_to_overlays(
20
 
21
  calls: dict[str, object] = {}
22
 
23
- monkeypatch.setattr("vascx_models.cli.preferred_device", lambda: "cpu")
 
 
 
 
 
 
 
 
24
 
25
  def fake_run_segmentation_vessels_and_av(**kwargs):
26
  calls["run_segmentation_vessels_and_av"] = kwargs
@@ -128,9 +137,11 @@ def test_cli_run_passes_measurement_config_and_data_to_overlays(
128
  )
129
 
130
  assert result.exit_code == 0, result.output
 
131
  assert calls["run_segmentation_vessels_and_av"]["artery_color"] == (170, 0, 0)
132
  assert calls["run_segmentation_vessels_and_av"]["vein_color"] == (0, 0, 187)
133
  assert calls["run_segmentation_vessels_and_av"]["vessel_color"] == (0, 204, 0)
 
134
  assert calls["run_segmentation_disc"]["disc_color"] == (221, 221, 221)
135
  assert calls["measure_vessel_widths"]["inner_circle"].name == "2r"
136
  assert calls["measure_vessel_widths"]["outer_circle"].name == "3r"
@@ -167,3 +178,50 @@ def test_cli_run_reports_missing_path_column_in_csv(tmp_path: Path, caplog) -> N
167
 
168
  assert result.exit_code == 0
169
  assert "CSV must contain a 'path' column" in caplog.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
 
3
+ import logging
4
  import pandas as pd
5
  from click.testing import CliRunner
6
+ import torch
7
  from PIL import Image
8
 
9
  from vascx_models.cli import cli
 
21
 
22
  calls: dict[str, object] = {}
23
 
24
+ def fake_available_device_types():
25
+ return {"cuda": False, "mps": False, "cpu": True}
26
+
27
+ def fake_resolve_device(device_name):
28
+ calls["device_name"] = device_name
29
+ return torch.device("cpu")
30
+
31
+ monkeypatch.setattr("vascx_models.cli.available_device_types", fake_available_device_types)
32
+ monkeypatch.setattr("vascx_models.cli.resolve_device", fake_resolve_device)
33
 
34
  def fake_run_segmentation_vessels_and_av(**kwargs):
35
  calls["run_segmentation_vessels_and_av"] = kwargs
 
137
  )
138
 
139
  assert result.exit_code == 0, result.output
140
+ assert calls["device_name"] == "auto"
141
  assert calls["run_segmentation_vessels_and_av"]["artery_color"] == (170, 0, 0)
142
  assert calls["run_segmentation_vessels_and_av"]["vein_color"] == (0, 0, 187)
143
  assert calls["run_segmentation_vessels_and_av"]["vessel_color"] == (0, 204, 0)
144
+ assert calls["run_segmentation_vessels_and_av"]["device"] == torch.device("cpu")
145
  assert calls["run_segmentation_disc"]["disc_color"] == (221, 221, 221)
146
  assert calls["measure_vessel_widths"]["inner_circle"].name == "2r"
147
  assert calls["measure_vessel_widths"]["outer_circle"].name == "3r"
 
178
 
179
  assert result.exit_code == 0
180
  assert "CSV must contain a 'path' column" in caplog.text
181
+
182
+
183
+ def test_cli_run_accepts_explicit_device_and_logs_selection(
184
+ tmp_path: Path, monkeypatch, caplog
185
+ ) -> None:
186
+ input_dir = tmp_path / "input"
187
+ output_dir = tmp_path / "output"
188
+ input_dir.mkdir()
189
+ Image.new("RGB", (32, 32), color=(0, 0, 0)).save(input_dir / "sample.png")
190
+
191
+ calls: dict[str, object] = {}
192
+
193
+ monkeypatch.setattr(
194
+ "vascx_models.cli.available_device_types",
195
+ lambda: {"cuda": False, "mps": False, "cpu": True},
196
+ )
197
+
198
+ def fake_resolve_device(device_name):
199
+ calls["device_name"] = device_name
200
+ return torch.device("cpu")
201
+
202
+ monkeypatch.setattr("vascx_models.cli.resolve_device", fake_resolve_device)
203
+ monkeypatch.setattr("vascx_models.cli.run_quality_estimation", lambda **kwargs: pd.DataFrame())
204
+ monkeypatch.setattr("vascx_models.cli.run_fovea_detection", lambda **kwargs: pd.DataFrame())
205
+
206
+ with caplog.at_level(logging.INFO):
207
+ result = CliRunner().invoke(
208
+ cli,
209
+ [
210
+ "run",
211
+ str(input_dir),
212
+ str(output_dir),
213
+ "--no-preprocess",
214
+ "--no-vessels",
215
+ "--no-disc",
216
+ "--no-quality",
217
+ "--no-fovea",
218
+ "--no-overlay",
219
+ "--device",
220
+ "cpu",
221
+ ],
222
+ )
223
+
224
+ assert result.exit_code == 0, result.output
225
+ assert calls["device_name"] == "cpu"
226
+ assert "Device availability: cuda=False, mps=False, cpu=True" in caplog.text
227
+ assert "Using requested device 'cpu': cpu" in caplog.text
tests/test_e2e.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pytest
10
+ import torch
11
+ from PIL import Image
12
+ from rtnls_fundusprep.cli import _run_preprocessing
13
+
14
+ from vascx_models.config import AppConfig
15
+ from vascx_models.disc_circles import generate_disc_circles
16
+ from vascx_models.inference import (
17
+ run_fovea_detection,
18
+ run_quality_estimation,
19
+ run_segmentation_disc,
20
+ run_segmentation_vessels_and_av,
21
+ )
22
+ from vascx_models.runtime import configure_runtime_environment
23
+ from vascx_models.utils import batch_create_overlays
24
+ from vascx_models.vessel_widths import (
25
+ measure_vessel_widths_between_disc_circle_pair,
26
+ resolve_vessel_width_circle_pair,
27
+ )
28
+
29
+ pytestmark = [pytest.mark.e2e, pytest.mark.slow]
30
+
31
+ REPO_ROOT = Path(__file__).resolve().parents[1]
32
+ SAMPLE_IMAGE = REPO_ROOT / "samples" / "fundus" / "original" / "DRIVE_22.png"
33
+ EXPECTED_VESSEL_WIDTH_COLUMNS = [
34
+ "image_id",
35
+ "inner_circle",
36
+ "outer_circle",
37
+ "inner_circle_radius_px",
38
+ "outer_circle_radius_px",
39
+ "connection_index",
40
+ "sample_index",
41
+ "x",
42
+ "y",
43
+ "width_px",
44
+ "x_start",
45
+ "y_start",
46
+ "x_end",
47
+ "y_end",
48
+ "vessel_type",
49
+ ]
50
+
51
+
52
+ def _require_e2e_opt_in() -> None:
53
+ if os.environ.get("VASCX_RUN_E2E") != "1":
54
+ pytest.skip("Set VASCX_RUN_E2E=1 to run real-model end-to-end tests")
55
+
56
+
57
+ def _device_or_skip(device_name: str) -> torch.device:
58
+ if device_name == "cpu":
59
+ return torch.device("cpu")
60
+ if device_name == "cuda":
61
+ if not torch.cuda.is_available():
62
+ pytest.skip("CUDA is not available in this environment")
63
+ return torch.device("cuda:0")
64
+ if device_name == "mps":
65
+ if not torch.backends.mps.is_available():
66
+ pytest.skip("MPS is not available in this environment")
67
+ return torch.device("mps")
68
+ raise AssertionError(f"Unsupported device name: {device_name}")
69
+
70
+
71
+ def _prepare_single_image_input(tmp_path: Path) -> tuple[str, Path, Path]:
72
+ input_dir = tmp_path / "input"
73
+ input_dir.mkdir()
74
+
75
+ image_path = input_dir / SAMPLE_IMAGE.name
76
+ shutil.copy2(SAMPLE_IMAGE, image_path)
77
+ return SAMPLE_IMAGE.stem, image_path, input_dir
78
+
79
+
80
+ def _assert_nonempty_mask(path: Path) -> None:
81
+ assert path.exists()
82
+ assert np.any(np.array(Image.open(path)) > 0)
83
+
84
+
85
+ @pytest.mark.parametrize("device_name", ["cpu", "cuda", "mps"])
86
+ def test_single_image_pipeline_smoke(tmp_path: Path, device_name: str) -> None:
87
+ _require_e2e_opt_in()
88
+ configure_runtime_environment()
89
+ device = _device_or_skip(device_name)
90
+ app_config = AppConfig()
91
+
92
+ image_id, image_path, input_dir = _prepare_single_image_input(tmp_path)
93
+
94
+ output_dir = tmp_path / "output"
95
+ output_dir.mkdir()
96
+ preprocessed_rgb_dir = output_dir / "preprocessed_rgb"
97
+ av_dir = output_dir / "artery_vein"
98
+ vessels_dir = output_dir / "vessels"
99
+ disc_dir = output_dir / "disc"
100
+ disc_circles_dir = output_dir / "disc_circles"
101
+ overlay_dir = output_dir / "overlays"
102
+ preprocessed_rgb_dir.mkdir()
103
+ av_dir.mkdir()
104
+ vessels_dir.mkdir()
105
+ disc_dir.mkdir()
106
+ overlay_dir.mkdir()
107
+
108
+ bounds_path = output_dir / "bounds.csv"
109
+ quality_path = output_dir / "quality.csv"
110
+ fovea_path = output_dir / "fovea.csv"
111
+ disc_geometry_path = output_dir / "disc_geometry.csv"
112
+ vessel_widths_path = output_dir / "vessel_widths.csv"
113
+
114
+ _run_preprocessing(
115
+ files=[image_path],
116
+ ids=[image_id],
117
+ rgb_path=preprocessed_rgb_dir,
118
+ bounds_path=bounds_path,
119
+ n_jobs=1,
120
+ )
121
+ preprocessed_image_path = preprocessed_rgb_dir / f"{image_id}.png"
122
+
123
+ df_quality = run_quality_estimation([preprocessed_image_path], ids=[image_id], device=device)
124
+ df_quality.to_csv(quality_path)
125
+
126
+ run_segmentation_vessels_and_av(
127
+ rgb_paths=[preprocessed_image_path],
128
+ ids=[image_id],
129
+ av_path=av_dir,
130
+ vessels_path=vessels_dir,
131
+ artery_color=app_config.overlay.colors.artery,
132
+ vein_color=app_config.overlay.colors.vein,
133
+ vessel_color=app_config.overlay.colors.vessel,
134
+ device=device,
135
+ )
136
+ run_segmentation_disc(
137
+ rgb_paths=[preprocessed_image_path],
138
+ ids=[image_id],
139
+ output_path=disc_dir,
140
+ disc_color=app_config.overlay.colors.disc,
141
+ device=device,
142
+ )
143
+
144
+ df_disc_geometry = generate_disc_circles(
145
+ disc_dir=disc_dir,
146
+ circle_output_dir=disc_circles_dir,
147
+ circles=app_config.overlay.circles,
148
+ measurements_path=disc_geometry_path,
149
+ )
150
+ inner_circle, outer_circle = resolve_vessel_width_circle_pair(
151
+ app_config.overlay.circles,
152
+ inner_circle_name=app_config.vessel_widths.inner_circle,
153
+ outer_circle_name=app_config.vessel_widths.outer_circle,
154
+ )
155
+ df_vessel_widths = measure_vessel_widths_between_disc_circle_pair(
156
+ vessels_dir=vessels_dir,
157
+ av_dir=av_dir,
158
+ disc_geometry_path=disc_geometry_path,
159
+ inner_circle=inner_circle,
160
+ outer_circle=outer_circle,
161
+ output_path=vessel_widths_path,
162
+ samples_per_connection=app_config.vessel_widths.samples_per_connection,
163
+ )
164
+
165
+ df_fovea = run_fovea_detection([preprocessed_image_path], ids=[image_id], device=device)
166
+ df_fovea.to_csv(fovea_path)
167
+
168
+ batch_create_overlays(
169
+ rgb_dir=preprocessed_rgb_dir,
170
+ output_dir=overlay_dir,
171
+ av_dir=av_dir,
172
+ disc_dir=disc_dir,
173
+ vessels_dir=vessels_dir,
174
+ circle_dirs={
175
+ circle.name: disc_circles_dir / circle.name
176
+ for circle in app_config.overlay.circles
177
+ },
178
+ vessel_width_data=df_vessel_widths,
179
+ fovea_data={
180
+ index: (row["x_fovea"], row["y_fovea"])
181
+ for index, row in df_fovea.iterrows()
182
+ },
183
+ overlay_config=app_config.overlay,
184
+ )
185
+
186
+ assert df_quality.index.tolist() == [image_id]
187
+ assert df_quality.columns.tolist() == ["q1", "q2", "q3"]
188
+ assert np.isfinite(df_quality.to_numpy()).all()
189
+ assert quality_path.exists()
190
+ assert bounds_path.exists()
191
+ assert preprocessed_image_path.exists()
192
+
193
+ _assert_nonempty_mask(av_dir / f"{image_id}.png")
194
+ _assert_nonempty_mask(vessels_dir / f"{image_id}.png")
195
+ _assert_nonempty_mask(disc_dir / f"{image_id}.png")
196
+
197
+ assert df_disc_geometry.index.tolist() == [image_id]
198
+ assert float(df_disc_geometry.loc[image_id, "disc_radius_px"]) > 0.0
199
+ assert disc_geometry_path.exists()
200
+ for circle in app_config.overlay.circles:
201
+ _assert_nonempty_mask(disc_circles_dir / circle.name / f"{image_id}.png")
202
+
203
+ assert vessel_widths_path.exists()
204
+ df_vessel_widths_disk = pd.read_csv(vessel_widths_path)
205
+ assert df_vessel_widths_disk.columns.tolist() == EXPECTED_VESSEL_WIDTH_COLUMNS
206
+ assert df_vessel_widths.columns.tolist() == EXPECTED_VESSEL_WIDTH_COLUMNS
207
+ if not df_vessel_widths.empty:
208
+ assert df_vessel_widths["image_id"].eq(image_id).all()
209
+ assert df_vessel_widths["vessel_type"].isin(["artery", "vein"]).all()
210
+ assert (df_vessel_widths["width_px"] > 0).all()
211
+
212
+ assert df_fovea.index.tolist() == [image_id]
213
+ assert df_fovea.columns.tolist() == ["x_fovea", "y_fovea"]
214
+ assert np.isfinite(df_fovea.to_numpy()).all()
215
+ assert fovea_path.exists()
216
+
217
+ overlay_path = overlay_dir / f"{image_id}.png"
218
+ assert overlay_path.exists()
219
+ assert Image.open(overlay_path).size == Image.open(preprocessed_image_path).size
tests/test_inference.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import torch
3
+
4
+ from vascx_models import inference
5
+
6
+
7
+ def test_resolve_device_auto_prefers_cuda_then_mps_then_cpu(monkeypatch) -> None:
8
+ monkeypatch.setattr(inference.torch.cuda, "is_available", lambda: True)
9
+ monkeypatch.setattr(inference.torch.backends.mps, "is_available", lambda: True)
10
+ assert inference.resolve_device("auto") == torch.device("cuda:0")
11
+
12
+ monkeypatch.setattr(inference.torch.cuda, "is_available", lambda: False)
13
+ monkeypatch.setattr(inference.torch.backends.mps, "is_available", lambda: True)
14
+ assert inference.resolve_device("auto") == torch.device("mps")
15
+
16
+ monkeypatch.setattr(inference.torch.cuda, "is_available", lambda: False)
17
+ monkeypatch.setattr(inference.torch.backends.mps, "is_available", lambda: False)
18
+ assert inference.resolve_device("auto") == torch.device("cpu")
19
+
20
+
21
+ def test_resolve_device_rejects_unavailable_requested_accelerator(monkeypatch) -> None:
22
+ monkeypatch.setattr(inference.torch.cuda, "is_available", lambda: False)
23
+ monkeypatch.setattr(inference.torch.backends.mps, "is_available", lambda: False)
24
+
25
+ with pytest.raises(RuntimeError, match="Requested device 'cuda' is not available"):
26
+ inference.resolve_device("cuda")
27
+
28
+ with pytest.raises(RuntimeError, match="Requested device 'mps' is not available"):
29
+ inference.resolve_device("mps")
vascx_models/cli.py CHANGED
@@ -14,7 +14,8 @@ from .runtime import configure_runtime_environment
14
  configure_runtime_environment()
15
 
16
  from .inference import (
17
- preferred_device,
 
18
  run_fovea_detection,
19
  run_quality_estimation,
20
  run_segmentation_disc,
@@ -75,6 +76,14 @@ def cli():
75
  default=None,
76
  help="Create visualization overlays. Defaults to the config value when set.",
77
  )
 
 
 
 
 
 
 
 
78
  @click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
79
  def run(
80
  data_path,
@@ -86,6 +95,7 @@ def run(
86
  quality,
87
  fovea,
88
  overlay,
 
89
  n_jobs,
90
  ):
91
  """Run the complete inference pipeline on fundus images.
@@ -185,9 +195,21 @@ def run(
185
  ids = [f.stem for f in preprocessed_files]
186
  logger.info("Prepared %d images for inference", len(preprocessed_files))
187
 
188
- # Prefer hardware acceleration when the active torch build supports it.
189
- device = preferred_device()
190
- logger.info("Using device: %s", device)
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  # Step 2: Run quality estimation if requested
193
  if quality:
 
14
  configure_runtime_environment()
15
 
16
  from .inference import (
17
+ available_device_types,
18
+ resolve_device,
19
  run_fovea_detection,
20
  run_quality_estimation,
21
  run_segmentation_disc,
 
76
  default=None,
77
  help="Create visualization overlays. Defaults to the config value when set.",
78
  )
79
+ @click.option(
80
+ "--device",
81
+ "device_name",
82
+ type=click.Choice(["auto", "cuda", "mps", "cpu"], case_sensitive=False),
83
+ default="auto",
84
+ show_default=True,
85
+ help="Inference device. 'auto' prefers CUDA first, then Apple Metal (MPS), then CPU.",
86
+ )
87
  @click.option("--n_jobs", type=int, default=4, help="Number of preprocessing workers")
88
  def run(
89
  data_path,
 
95
  quality,
96
  fovea,
97
  overlay,
98
+ device_name,
99
  n_jobs,
100
  ):
101
  """Run the complete inference pipeline on fundus images.
 
195
  ids = [f.stem for f in preprocessed_files]
196
  logger.info("Prepared %d images for inference", len(preprocessed_files))
197
 
198
+ available_devices = available_device_types()
199
+ logger.info(
200
+ "Device availability: cuda=%s, mps=%s, cpu=%s",
201
+ available_devices["cuda"],
202
+ available_devices["mps"],
203
+ available_devices["cpu"],
204
+ )
205
+ try:
206
+ device = resolve_device(device_name)
207
+ except (RuntimeError, ValueError) as exc:
208
+ raise click.ClickException(str(exc)) from exc
209
+ if device_name == "auto":
210
+ logger.info("Auto-selected device: %s", device)
211
+ else:
212
+ logger.info("Using requested device '%s': %s", device_name, device)
213
 
214
  # Step 2: Run quality estimation if requested
215
  if quality:
vascx_models/disc_circles.py CHANGED
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
12
 
13
 
14
  def _save_visual_circle_mask(mask: np.ndarray, path: Path, color: tuple[int, int, int]) -> None:
15
- image = Image.fromarray(mask.astype(np.uint8), mode="P")
16
  palette = [0] * (256 * 3)
17
  palette[255 * 3 : 255 * 3 + 3] = list(color)
18
  image.putpalette(palette)
 
12
 
13
 
14
  def _save_visual_circle_mask(mask: np.ndarray, path: Path, color: tuple[int, int, int]) -> None:
15
+ image = Image.fromarray(mask.astype(np.uint8))
16
  palette = [0] * (256 * 3)
17
  palette[255 * 3 : 255 * 3 + 3] = list(color)
18
  image.putpalette(palette)
vascx_models/inference.py CHANGED
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
23
  def _save_visual_mask(mask: np.ndarray, path: str, color_by_value: dict[int, tuple[int, int, int]]) -> None:
24
  """Save a label mask with a palette while preserving label values."""
25
  mask_uint8 = mask.squeeze().astype(np.uint8)
26
- image = Image.fromarray(mask_uint8, mode="P")
27
  palette = [0] * (256 * 3)
28
  for value, color in color_by_value.items():
29
  start = int(value) * 3
@@ -32,12 +32,37 @@ def _save_visual_mask(mask: np.ndarray, path: str, color_by_value: dict[int, tup
32
  image.save(path)
33
 
34
 
 
 
 
 
 
 
 
 
35
  def preferred_device() -> torch.device:
36
- if torch.cuda.is_available():
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  return torch.device("cuda:0")
38
- if torch.backends.mps.is_available():
 
 
39
  return torch.device("mps")
40
- return torch.device("cpu")
 
 
41
 
42
 
43
  def _inference_num_workers(device: torch.device) -> int:
 
23
  def _save_visual_mask(mask: np.ndarray, path: str, color_by_value: dict[int, tuple[int, int, int]]) -> None:
24
  """Save a label mask with a palette while preserving label values."""
25
  mask_uint8 = mask.squeeze().astype(np.uint8)
26
+ image = Image.fromarray(mask_uint8)
27
  palette = [0] * (256 * 3)
28
  for value, color in color_by_value.items():
29
  start = int(value) * 3
 
32
  image.save(path)
33
 
34
 
35
+ def available_device_types() -> dict[str, bool]:
36
+ return {
37
+ "cuda": torch.cuda.is_available(),
38
+ "mps": torch.backends.mps.is_available(),
39
+ "cpu": True,
40
+ }
41
+
42
+
43
  def preferred_device() -> torch.device:
44
+ return resolve_device("auto")
45
+
46
+
47
+ def resolve_device(device_name: str = "auto") -> torch.device:
48
+ available = available_device_types()
49
+ if device_name == "auto":
50
+ if available["cuda"]:
51
+ return torch.device("cuda:0")
52
+ if available["mps"]:
53
+ return torch.device("mps")
54
+ return torch.device("cpu")
55
+ if device_name == "cuda":
56
+ if not available["cuda"]:
57
+ raise RuntimeError("Requested device 'cuda' is not available")
58
  return torch.device("cuda:0")
59
+ if device_name == "mps":
60
+ if not available["mps"]:
61
+ raise RuntimeError("Requested device 'mps' is not available")
62
  return torch.device("mps")
63
+ if device_name == "cpu":
64
+ return torch.device("cpu")
65
+ raise ValueError(f"Unsupported device '{device_name}'")
66
 
67
 
68
  def _inference_num_workers(device: torch.device) -> int: