# vascx-fork/tests/test_inference.py
# Commit 1386847 (page header lists zyf0717):
#   Enhance device selection and logging for inference; add end-to-end tests
import pytest
import torch
from vascx_models import inference
def test_resolve_device_auto_prefers_cuda_then_mps_then_cpu(monkeypatch) -> None:
    """Check the fallback order of ``inference.resolve_device("auto")``.

    For each scenario the CUDA and MPS ``is_available`` probes reachable
    through ``inference.torch`` are patched, and the resolved device is
    compared against the expected ``torch.device``.
    """
    # (CUDA reported available, MPS reported available, expected device string)
    scenarios = (
        (True, True, "cuda:0"),
        (False, True, "mps"),
        (False, False, "cpu"),
    )

    for has_cuda, has_mps, expected_name in scenarios:
        # Bind each flag as a default argument so the stub returns this
        # scenario's value regardless of when it is invoked.
        monkeypatch.setattr(
            inference.torch.cuda,
            "is_available",
            lambda flag=has_cuda: flag,
        )
        monkeypatch.setattr(
            inference.torch.backends.mps,
            "is_available",
            lambda flag=has_mps: flag,
        )

        assert inference.resolve_device("auto") == torch.device(expected_name)
def test_resolve_device_rejects_unavailable_requested_accelerator(monkeypatch) -> None:
    """Check that requesting an unavailable accelerator by name raises.

    Both the CUDA and MPS ``is_available`` probes are patched to return
    ``False``; ``inference.resolve_device`` is then expected to raise
    ``RuntimeError`` with a message naming the requested device.
    """

    def report_unavailable() -> bool:
        return False

    monkeypatch.setattr(inference.torch.cuda, "is_available", report_unavailable)
    monkeypatch.setattr(
        inference.torch.backends.mps, "is_available", report_unavailable
    )

    # Requested device name -> pattern the error message must match.
    expected_errors = {
        "cuda": "Requested device 'cuda' is not available",
        "mps": "Requested device 'mps' is not available",
    }

    for requested, message in expected_errors.items():
        with pytest.raises(RuntimeError, match=message):
            inference.resolve_device(requested)