vascx-fork / tests /test_cli.py
zyf0717's picture
Enhance device selection and logging for inference; add end-to-end tests
1386847
from pathlib import Path
import logging
import pandas as pd
from click.testing import CliRunner
import torch
from PIL import Image
from vascx_models.cli import cli
def test_cli_run_passes_measurement_config_and_data_to_overlays(
tmp_path: Path, monkeypatch
) -> None:
input_dir = tmp_path / "input"
output_dir = tmp_path / "output"
input_dir.mkdir()
output_dir.mkdir()
Image.new("RGB", (32, 32), color=(0, 0, 0)).save(input_dir / "sample.png")
calls: dict[str, object] = {}
def fake_available_device_types():
return {"cuda": False, "mps": False, "cpu": True}
def fake_resolve_device(device_name):
calls["device_name"] = device_name
return torch.device("cpu")
monkeypatch.setattr("vascx_models.cli.available_device_types", fake_available_device_types)
monkeypatch.setattr("vascx_models.cli.resolve_device", fake_resolve_device)
def fake_run_segmentation_vessels_and_av(**kwargs):
calls["run_segmentation_vessels_and_av"] = kwargs
def fake_run_segmentation_disc(**kwargs):
calls["run_segmentation_disc"] = kwargs
monkeypatch.setattr(
"vascx_models.cli.run_segmentation_vessels_and_av",
fake_run_segmentation_vessels_and_av,
)
monkeypatch.setattr("vascx_models.cli.run_segmentation_disc", fake_run_segmentation_disc)
monkeypatch.setattr("vascx_models.cli.run_quality_estimation", lambda **kwargs: pd.DataFrame())
monkeypatch.setattr("vascx_models.cli.run_fovea_detection", lambda **kwargs: pd.DataFrame())
def fake_generate_disc_circles(**kwargs):
calls["generate_disc_circles"] = kwargs
pd.DataFrame(
{
"x_disc_center": [16.0],
"y_disc_center": [16.0],
"disc_radius_px": [5.0],
},
index=["sample"],
).to_csv(kwargs["measurements_path"])
def fake_measure_vessel_widths_between_disc_circle_pair(**kwargs):
calls["measure_vessel_widths"] = kwargs
df = pd.DataFrame(
[
{
"image_id": "sample",
"inner_circle": "2r",
"outer_circle": "3r",
"inner_circle_radius_px": 10.0,
"outer_circle_radius_px": 15.0,
"connection_index": 1,
"sample_index": 1,
"x": 16.0,
"y": 12.0,
"width_px": 7.0,
"x_start": 13.0,
"y_start": 12.0,
"x_end": 19.0,
"y_end": 12.0,
"vessel_type": "artery",
}
]
)
df.to_csv(kwargs["output_path"], index=False)
return df
def fake_batch_create_overlays(**kwargs):
calls["batch_create_overlays"] = kwargs
monkeypatch.setattr("vascx_models.cli.generate_disc_circles", fake_generate_disc_circles)
monkeypatch.setattr(
"vascx_models.cli.measure_vessel_widths_between_disc_circle_pair",
fake_measure_vessel_widths_between_disc_circle_pair,
)
monkeypatch.setattr("vascx_models.cli.batch_create_overlays", fake_batch_create_overlays)
config_path = tmp_path / "config.yaml"
config_path.write_text(
"\n".join(
[
"overlay:",
" colours:",
" artery: '#AA0000'",
" vein: '#0000BB'",
" vessel: '#00CC00'",
" disc: '#DDDDDD'",
" circles:",
" - name: 2r",
" diameter: 2.0",
" color: '#00FF00'",
" - name: 3r",
" diameter: 3.0",
" color: '#00FF00'",
"vessel_widths:",
" inner_circle: 2r",
" outer_circle: 3r",
" samples_per_connection: 4",
]
),
encoding="utf-8",
)
runner = CliRunner()
result = runner.invoke(
cli,
[
"run",
str(input_dir),
str(output_dir),
"--config",
str(config_path),
"--no-preprocess",
"--overlay",
"--vessels",
"--disc",
"--no-quality",
"--no-fovea",
],
)
assert result.exit_code == 0, result.output
assert calls["device_name"] == "auto"
assert calls["run_segmentation_vessels_and_av"]["artery_color"] == (170, 0, 0)
assert calls["run_segmentation_vessels_and_av"]["vein_color"] == (0, 0, 187)
assert calls["run_segmentation_vessels_and_av"]["vessel_color"] == (0, 204, 0)
assert calls["run_segmentation_vessels_and_av"]["device"] == torch.device("cpu")
assert calls["run_segmentation_disc"]["disc_color"] == (221, 221, 221)
assert calls["measure_vessel_widths"]["inner_circle"].name == "2r"
assert calls["measure_vessel_widths"]["outer_circle"].name == "3r"
assert calls["measure_vessel_widths"]["samples_per_connection"] == 4
assert calls["generate_disc_circles"]["circles"][0].name == "2r"
assert calls["generate_disc_circles"]["circles"][0].color == (0, 255, 0)
assert calls["batch_create_overlays"]["vessels_dir"] == output_dir / "vessels"
measurement_data = calls["batch_create_overlays"]["vessel_width_data"]
assert isinstance(measurement_data, pd.DataFrame)
assert measurement_data.iloc[0]["width_px"] == 7.0
def test_cli_run_reports_missing_path_column_in_csv(tmp_path: Path, caplog) -> None:
csv_path = tmp_path / "images.csv"
output_dir = tmp_path / "output"
csv_path.write_text("id\nsample\n", encoding="utf-8")
runner = CliRunner()
with caplog.at_level(logging.ERROR):
result = runner.invoke(
cli,
[
"run",
str(csv_path),
str(output_dir),
"--no-preprocess",
"--no-vessels",
"--no-disc",
"--no-quality",
"--no-fovea",
"--no-overlay",
],
)
assert result.exit_code == 0
assert "CSV must contain a 'path' column" in caplog.text
def test_cli_run_accepts_explicit_device_and_logs_selection(
tmp_path: Path, monkeypatch, caplog
) -> None:
input_dir = tmp_path / "input"
output_dir = tmp_path / "output"
input_dir.mkdir()
Image.new("RGB", (32, 32), color=(0, 0, 0)).save(input_dir / "sample.png")
calls: dict[str, object] = {}
monkeypatch.setattr(
"vascx_models.cli.available_device_types",
lambda: {"cuda": False, "mps": False, "cpu": True},
)
def fake_resolve_device(device_name):
calls["device_name"] = device_name
return torch.device("cpu")
monkeypatch.setattr("vascx_models.cli.resolve_device", fake_resolve_device)
monkeypatch.setattr("vascx_models.cli.run_quality_estimation", lambda **kwargs: pd.DataFrame())
monkeypatch.setattr("vascx_models.cli.run_fovea_detection", lambda **kwargs: pd.DataFrame())
with caplog.at_level(logging.INFO):
result = CliRunner().invoke(
cli,
[
"run",
str(input_dir),
str(output_dir),
"--no-preprocess",
"--no-vessels",
"--no-disc",
"--no-quality",
"--no-fovea",
"--no-overlay",
"--device",
"cpu",
],
)
assert result.exit_code == 0, result.output
assert calls["device_name"] == "cpu"
assert "Device availability: cuda=False, mps=False, cpu=True" in caplog.text
assert "Using requested device 'cpu': cpu" in caplog.text