| import pytest |
| import torch |
|
|
| from vascx_models import inference |
|
|
|
|
def test_resolve_device_auto_prefers_cuda_then_mps_then_cpu(monkeypatch) -> None:
    """Check the fallback order used by ``resolve_device("auto")``.

    The test stubs ``torch.cuda.is_available`` and
    ``torch.backends.mps.is_available`` (as seen through the ``inference``
    module) for three scenarios. It asserts that "auto" picks ``cuda:0`` when
    CUDA is available, ``mps`` when only MPS is, and ``cpu`` when neither is.
    """

    def _always(flag):
        # Build a zero-argument stub that returns a fixed availability flag.
        # The factory binds ``flag`` now, so each stub keeps its own value.
        return lambda: flag

    # (cuda available, mps available, device "auto" should resolve to)
    scenarios = (
        (True, True, "cuda:0"),
        (False, True, "mps"),
        (False, False, "cpu"),
    )

    for cuda_ok, mps_ok, expected_device in scenarios:
        monkeypatch.setattr(inference.torch.cuda, "is_available", _always(cuda_ok))
        monkeypatch.setattr(inference.torch.backends.mps, "is_available", _always(mps_ok))
        assert inference.resolve_device("auto") == torch.device(expected_device)
|
|
|
|
def test_resolve_device_rejects_unavailable_requested_accelerator(monkeypatch) -> None:
    """Check that explicitly requesting a missing accelerator raises.

    Both the CUDA and MPS availability probes are stubbed to report
    ``False``. Then ``resolve_device`` is asked for ``"cuda"`` and ``"mps"``
    in turn. Each call must raise ``RuntimeError`` whose message matches the
    expected "not available" text. Note that ``pytest.raises(match=...)``
    treats the pattern as a regex searched within the message.
    """
    # Make every accelerator look unavailable before requesting any of them.
    for backend in (inference.torch.cuda, inference.torch.backends.mps):
        monkeypatch.setattr(backend, "is_available", lambda: False)

    # (device string passed to resolve_device, expected error-message pattern)
    expectations = (
        ("cuda", "Requested device 'cuda' is not available"),
        ("mps", "Requested device 'mps' is not available"),
    )

    for requested, message_pattern in expectations:
        with pytest.raises(RuntimeError, match=message_pattern):
            inference.resolve_device(requested)
|
|