File size: 3,234 Bytes
3c4c67b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Tests for data loader module."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.data.loader import (
    DatasetInfo,
    get_dataset_info,
    load_isles_dataset,
)


class TestLoadIslesDataset:
    """Unit tests for ``load_isles_dataset`` with the HF loader patched out."""

    # Patch target: the ``load_dataset`` name as looked up inside the loader module.
    _HF_LOAD = "stroke_deepisles_demo.data.loader.load_dataset"

    def test_calls_hf_load_dataset(self) -> None:
        """Forwards the repository id as the first positional argument to the HF loader."""
        with patch(self._HF_LOAD, return_value=MagicMock()) as fake_load:
            load_isles_dataset("test/dataset")

            fake_load.assert_called_once()
            positional, _ = fake_load.call_args
            assert positional[0] == "test/dataset"

    def test_returns_dataset_object(self) -> None:
        """Hands back exactly the object produced by the HF loader."""
        sentinel = MagicMock()
        with patch(self._HF_LOAD, return_value=sentinel):
            outcome = load_isles_dataset()

            assert outcome is sentinel

    def test_handles_load_error(self) -> None:
        """Surfaces an HF loader failure as ``DataLoadError`` mentioning the original error."""
        failure = Exception("Network error")
        with patch(self._HF_LOAD, side_effect=failure), pytest.raises(
            DataLoadError, match="Network error"
        ):
            load_isles_dataset()


class TestGetDatasetInfo:
    """Unit tests for ``get_dataset_info`` against a fully mocked dataset."""

    def test_returns_datasetinfo(self) -> None:
        """Summarises a mocked 149-case dataset exposing dwi/adc/mask features."""
        fake_ds = MagicMock()
        # Both len(ds) and ds.info.splits[...].num_examples report 149 cases.
        fake_ds.__len__.return_value = 149
        fake_ds.info.splits.__getitem__.return_value.num_examples = 149
        # A plain dict stands in for the dict-like features mapping.
        fake_ds.features = {"dwi": None, "adc": None, "mask": None}

        with patch(
            "stroke_deepisles_demo.data.loader.load_dataset", return_value=fake_ds
        ):
            summary = get_dataset_info()

            assert isinstance(summary, DatasetInfo)
            assert summary.num_cases == 149
            assert "dwi" in summary.modalities
            assert summary.has_ground_truth is True


@pytest.mark.integration
class TestLoadIslesDatasetIntegration:
    """Integration tests that hit the real HuggingFace Hub.

    These require network access. Run with: pytest -m integration
    """

    @pytest.mark.slow
    def test_load_real_dataset(self) -> None:
        """Actually loads ISLES24-MR-Lite from HF Hub."""
        # Using streaming=True to avoid downloading everything.
        try:
            dataset = load_isles_dataset(streaming=True)
        except Exception as e:
            # Only the load step is wrapped. Assertion failures below must
            # surface unchanged (with pytest's assertion introspection) instead
            # of being mislabelled as a load failure.
            pytest.fail(f"Failed to load real dataset: {e}")

        assert dataset is not None
        # Verify we got metadata/features - this confirms connectivity.
        # Iterating might trigger heavy downloads or fail if dataset is empty/gated.
        # A streamed dataset may report ``features`` as None when the schema is
        # not known up front; check explicitly so that case fails with a clear
        # assertion rather than a TypeError from len(None).
        features = getattr(dataset, "features", None)
        assert features is not None
        assert len(features) > 0