|
|
"""Tests for data loader module.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import pytest |
|
|
|
|
|
from stroke_deepisles_demo.core.exceptions import DataLoadError |
|
|
from stroke_deepisles_demo.data.loader import ( |
|
|
DatasetInfo, |
|
|
get_dataset_info, |
|
|
load_isles_dataset, |
|
|
) |
|
|
|
|
|
|
|
|
class TestLoadIslesDataset:
    """Tests for load_isles_dataset, with the HF ``load_dataset`` patched out."""

    # Dotted path of the HF loader as looked up inside the loader module.
    _HF_LOADER = "stroke_deepisles_demo.data.loader.load_dataset"

    def test_calls_hf_load_dataset(self) -> None:
        """Calls datasets.load_dataset with correct arguments."""
        with patch(self._HF_LOADER, return_value=MagicMock()) as fake_loader:
            load_isles_dataset("test/dataset")

        # The mock keeps its call record after the patch is undone.
        fake_loader.assert_called_once()
        positional, _kwargs = fake_loader.call_args
        assert positional[0] == "test/dataset"

    def test_returns_dataset_object(self) -> None:
        """Returns the loaded Dataset object unchanged."""
        sentinel_ds = MagicMock()
        with patch(self._HF_LOADER, return_value=sentinel_ds):
            outcome = load_isles_dataset()

        assert outcome is sentinel_ds

    def test_handles_load_error(self) -> None:
        """Wraps HF errors in DataLoadError, keeping the original message."""
        hub_failure = Exception("Network error")
        with patch(self._HF_LOADER, side_effect=hub_failure):
            with pytest.raises(DataLoadError, match="Network error"):
                load_isles_dataset()
|
|
|
|
|
|
|
|
class TestGetDatasetInfo:
    """Tests for get_dataset_info, with the HF ``load_dataset`` patched out."""

    def test_returns_datasetinfo(self) -> None:
        """Returns DatasetInfo with expected fields."""
        fake_ds = MagicMock()
        # NOTE(review): both len() and info.splits[...].num_examples are stubbed
        # to 149, presumably because get_dataset_info may derive the case count
        # from either one — confirm against the loader implementation.
        fake_ds.__len__.return_value = 149
        fake_ds.info.splits.__getitem__.return_value.num_examples = 149
        fake_ds.features = {"dwi": None, "adc": None, "mask": None}

        with patch(
            "stroke_deepisles_demo.data.loader.load_dataset",
            return_value=fake_ds,
        ):
            summary = get_dataset_info()

        assert isinstance(summary, DatasetInfo)
        assert summary.num_cases == 149
        assert "dwi" in summary.modalities
        # Presumably derived from the "mask" feature being present — verify.
        assert summary.has_ground_truth is True
|
|
|
|
|
|
|
|
@pytest.mark.integration
class TestLoadIslesDatasetIntegration:
    """Integration tests that hit the real HuggingFace Hub."""

    @pytest.mark.slow
    def test_load_real_dataset(self) -> None:
        """Actually loads ISLES24-MR-Lite from HF Hub in streaming mode."""
        # Only the load step is guarded. Wrapping the assertions in a broad
        # ``except Exception`` would also catch AssertionError and replace
        # pytest's assertion introspection with a flat pytest.fail message.
        try:
            dataset = load_isles_dataset(streaming=True)
        except DataLoadError as e:
            pytest.fail(f"Failed to load real dataset: {e}")

        assert dataset is not None
        assert hasattr(dataset, "features")
        # Streaming datasets may not have resolved features, in which case
        # ``features`` is None. Check explicitly so this fails as an
        # assertion rather than as a TypeError from len(None).
        assert dataset.features is not None
        assert len(dataset.features) > 0
|
|
|