File size: 10,536 Bytes
80cbb1a
363ba14
 
 
 
 
 
 
 
 
 
 
 
 
 
80cbb1a
363ba14
 
80cbb1a
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
 
 
 
 
363ba14
80cbb1a
 
 
 
 
 
 
 
 
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
363ba14
 
 
80cbb1a
363ba14
 
 
80cbb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
 
 
 
363ba14
80cbb1a
 
 
 
 
363ba14
80cbb1a
 
363ba14
80cbb1a
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
 
 
 
 
363ba14
 
 
 
 
 
 
80cbb1a
363ba14
 
80cbb1a
 
363ba14
 
 
80cbb1a
363ba14
 
80cbb1a
 
 
 
 
 
 
 
363ba14
 
 
 
80cbb1a
 
 
 
363ba14
 
 
80cbb1a
 
363ba14
 
80cbb1a
 
 
 
 
363ba14
80cbb1a
 
 
 
 
 
 
363ba14
 
 
80cbb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363ba14
 
 
 
80cbb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e244238
363ba14
80cbb1a
 
 
 
 
 
 
 
 
 
 
363ba14
80cbb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""Unit tests for HuggingFace dataset adapter with mocked HF data access."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, build_huggingface_dataset


class TestHuggingFaceDataset:
    """Tests for HuggingFaceDataset class.

    Each test builds a dataset over a hand-written case list and, where data
    is needed, patches ``_download_case_from_parquet`` on that instance so the
    test supplies the bytes itself instead of going through the real download
    path.
    """

    @staticmethod
    def _make_dataset(num_cases: int = 3) -> HuggingFaceDataset:
        """Build a dataset holding ``num_cases`` sequentially numbered cases.

        Case IDs are ``sub-stroke0001``, ``sub-stroke0002``, ... and the index
        mapping follows list order, mirroring the fixtures the tests used to
        spell out by hand.
        """
        case_ids = [f"sub-stroke{number:04d}" for number in range(1, num_cases + 1)]
        return HuggingFaceDataset(
            dataset_id="test/dataset",
            _case_ids=case_ids,
            _case_index={cid: idx for idx, cid in enumerate(case_ids)},
        )

    @staticmethod
    def _make_payload(*, with_mask: bool = False) -> dict[str, bytes]:
        """Return a fresh fake download payload.

        A new dict is created on every call so that tests never share mutable
        state through the mocked return value.
        """
        payload = {
            "dwi_bytes": b"fake_dwi_nifti_data",
            "adc_bytes": b"fake_adc_nifti_data",
        }
        if with_mask:
            payload["mask_bytes"] = b"fake_mask_nifti_data"
        return payload

    def test_get_case_writes_files_to_temp_dir(self) -> None:
        """Test that get_case writes NIfTI bytes to temp files."""
        ds = self._make_dataset()
        mock_data = self._make_payload(with_mask=True)

        try:
            with patch.object(ds, "_download_case_from_parquet", return_value=mock_data):
                case = ds.get_case(0)

                assert "dwi" in case
                assert "adc" in case
                assert case["dwi"].exists()
                assert case["adc"].exists()
                assert case["dwi"].read_bytes() == b"fake_dwi_nifti_data"
                assert case["adc"].read_bytes() == b"fake_adc_nifti_data"
        finally:
            ds.cleanup()

    def test_get_case_includes_ground_truth_when_available(self) -> None:
        """Test that ground truth is included only when mask bytes are present."""
        ds = self._make_dataset()

        try:
            # Case with mask
            with patch.object(
                ds,
                "_download_case_from_parquet",
                return_value=self._make_payload(with_mask=True),
            ):
                case = ds.get_case(0)
                assert "ground_truth" in case
                assert case["ground_truth"].read_bytes() == b"fake_mask_nifti_data"

            # Case without mask (a different index, so the first result is not reused)
            with patch.object(
                ds,
                "_download_case_from_parquet",
                return_value=self._make_payload(),
            ):
                case_no_mask = ds.get_case(2)
                assert "ground_truth" not in case_no_mask
        finally:
            ds.cleanup()

    def test_get_case_caches_results(self) -> None:
        """Test that get_case returns cached paths on subsequent calls."""
        ds = self._make_dataset()

        try:
            with patch.object(
                ds, "_download_case_from_parquet", return_value=self._make_payload()
            ) as mock_download:
                case1 = ds.get_case(0)
                case2 = ds.get_case(0)

                # Same object returned (cached)
                assert case1 is case2

                # Download was only called once
                assert mock_download.call_count == 1
        finally:
            ds.cleanup()

    def test_context_manager_cleans_up_temp_files(self) -> None:
        """Test that using context manager cleans up temp files."""
        ds = self._make_dataset(num_cases=1)

        with (
            patch.object(ds, "_download_case_from_parquet", return_value=self._make_payload()),
            ds,
        ):
            case = ds.get_case(0)
            # The test treats the grandparent of the DWI file as the temp root.
            temp_dir = case["dwi"].parent.parent
            assert temp_dir.exists()

        # After context exit, temp dir should be gone
        assert not temp_dir.exists()

    def test_cleanup_clears_cache(self) -> None:
        """Test that cleanup clears the case cache."""
        ds = self._make_dataset(num_cases=1)

        # cleanup() runs in ``finally`` so a failing assertion inside the
        # patch block cannot leave the temp files behind; on the success path
        # it is still called exactly once before the final check.
        try:
            with patch.object(
                ds, "_download_case_from_parquet", return_value=self._make_payload()
            ):
                ds.get_case(0)
                assert len(ds._cached_cases) == 1
        finally:
            ds.cleanup()

        assert len(ds._cached_cases) == 0

    def test_get_case_by_string_id(self) -> None:
        """Test that get_case works with string case IDs."""
        ds = self._make_dataset()

        try:
            with patch.object(
                ds, "_download_case_from_parquet", return_value=self._make_payload()
            ) as mock_download:
                case = ds.get_case("sub-stroke0002")
                assert case["dwi"].exists()
                # Should have been called with index 1 (second case)
                mock_download.assert_called_once_with(1, "sub-stroke0002")
        finally:
            ds.cleanup()

    def test_get_case_raises_key_error_for_invalid_id(self) -> None:
        """Test that get_case raises KeyError for invalid case ID."""
        ds = self._make_dataset(num_cases=1)

        with pytest.raises(KeyError, match="not found in dataset"):
            ds.get_case("sub-stroke9999")

    def test_get_case_raises_index_error_for_out_of_range(self) -> None:
        """Test that get_case raises IndexError for out of range index."""
        ds = self._make_dataset(num_cases=1)

        with pytest.raises(IndexError, match="out of range"):
            ds.get_case(99)

class TestBuildHuggingFaceDataset:
    """Checks for the ``build_huggingface_dataset`` factory function."""

    def test_uses_precomputed_case_ids(self) -> None:
        """The default dataset ID yields a dataset with the expected case list."""
        dataset = build_huggingface_dataset("hugging-science/isles24-stroke")

        assert isinstance(dataset, HuggingFaceDataset)
        assert dataset.dataset_id == "hugging-science/isles24-stroke"

        known_ids = dataset._case_ids
        # The test expects exactly 149 case IDs for this dataset.
        assert len(known_ids) == 149
        assert "sub-stroke0001" in known_ids
        assert "sub-stroke0189" in known_ids

    def test_case_index_mapping_is_correct(self) -> None:
        """The ID-to-index mapping places the first and last cases correctly."""
        index_by_id = build_huggingface_dataset("hugging-science/isles24-stroke")._case_index

        # First case sits at position 0 ...
        assert index_by_id["sub-stroke0001"] == 0
        # ... and the last one at position 148.
        assert index_by_id["sub-stroke0189"] == 148

    def test_warns_for_different_dataset_id(self) -> None:
        """An unrecognised dataset ID triggers exactly one logger warning."""
        from stroke_deepisles_demo.data.adapter import logger

        with patch.object(logger, "warning") as warning_spy:
            build_huggingface_dataset("some-other/dataset")

        # The mock keeps its call record after the patch is undone.
        warning_spy.assert_called_once()
        logged_message = warning_spy.call_args.args[0]
        assert "does not match pre-computed constants" in logged_message


class TestDownloadCaseFromParquet:
    """Checks for ``HuggingFaceDataset._download_case_from_parquet``."""

    def test_raises_data_load_error_on_malformed_data(self) -> None:
        """A row whose image columns lack 'bytes' must raise DataLoadError."""
        import pandas as pd  # type: ignore[import-untyped]

        dataset = HuggingFaceDataset(
            dataset_id="test/dataset",
            _case_ids=["sub-stroke0001"],
            _case_index={"sub-stroke0001": 0},
        )

        # One row whose image columns are empty dicts, i.e. have no 'bytes' key.
        malformed_row = {
            "subject_id": "sub-stroke0001",
            "dwi": {},
            "adc": {},
            "lesion_mask": None,
        }
        frame = pd.DataFrame([malformed_row])

        # ParquetFile(...).read().to_pandas() hands back the malformed frame.
        parquet_file = MagicMock()
        parquet_file.read.return_value.to_pandas.return_value = frame

        # File handle usable as a context manager that yields itself and does
        # not suppress exceptions.
        handle = MagicMock()
        handle.__enter__.return_value = handle
        handle.__exit__.return_value = False

        filesystem = MagicMock()
        filesystem.open.return_value = handle

        # Patch at the source module where they're imported, not where they're used
        fs_patch = patch("huggingface_hub.HfFileSystem", return_value=filesystem)
        parquet_patch = patch("pyarrow.parquet.ParquetFile", return_value=parquet_file)

        with fs_patch, parquet_patch:
            with pytest.raises(DataLoadError, match="Malformed HuggingFace data"):
                dataset._download_case_from_parquet(0, "sub-stroke0001")