File size: 6,398 Bytes
363ba14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""Unit tests for HuggingFace dataset adapter with mocked HF dataset."""

from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, build_huggingface_dataset


def create_mock_hf_example(subject_id: str, include_mask: bool = True) -> dict[str, Any]:
    """Build a dict shaped like one row of the mocked HuggingFace dataset.

    Args:
        subject_id: Stored under ``"subject_id"`` and used as the prefix of
            every fake file path.
        include_mask: When true, ``"lesion_mask"`` holds a bytes/path dict;
            otherwise it is ``None``.

    Returns:
        A dict with ``"subject_id"``, ``"dwi"``, ``"adc"`` and
        ``"lesion_mask"`` entries.
    """
    # The mask entry is always present; only its value depends on the flag.
    mask_entry = (
        {
            "bytes": b"fake_mask_nifti_data",
            "path": f"{subject_id}_lesion-msk.nii.gz",
        }
        if include_mask
        else None
    )
    return {
        "subject_id": subject_id,
        "dwi": {"bytes": b"fake_dwi_nifti_data", "path": f"{subject_id}_dwi.nii.gz"},
        "adc": {"bytes": b"fake_adc_nifti_data", "path": f"{subject_id}_adc.nii.gz"},
        "lesion_mask": mask_entry,
    }


@pytest.fixture
def mock_hf_dataset() -> MagicMock:
    """Create a mock HuggingFace dataset with 3 subjects.

    The first two subjects carry a lesion mask; the third has its
    ``"lesion_mask"`` entry set to ``None``.

    Returns:
        A ``MagicMock`` that supports ``len()``, iteration and indexing over
        the three example dicts.
    """
    examples = [
        create_mock_hf_example("sub-stroke0001"),
        create_mock_hf_example("sub-stroke0002"),
        create_mock_hf_example("sub-stroke0003", include_mask=False),
    ]

    mock_ds = MagicMock()
    mock_ds.__len__ = MagicMock(return_value=len(examples))
    # Hand out a fresh iterator on every call. A single iter(...) stored as
    # return_value would be exhausted after the first pass, so any later
    # iteration over the mock would silently yield nothing.
    mock_ds.__iter__ = MagicMock(side_effect=lambda: iter(examples))
    mock_ds.__getitem__ = MagicMock(side_effect=lambda i: examples[i])

    return mock_ds


class TestHuggingFaceDataset:
    """Tests for HuggingFaceDataset class.

    Each test builds a ``HuggingFaceDataset`` around a mocked HF dataset and
    releases its temp files via ``cleanup()`` (or the context manager), even
    when an assertion fails.
    """

    def test_get_case_writes_files_to_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case writes NIfTI bytes to temp files."""
        case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        try:
            case = ds.get_case(0)

            assert "dwi" in case
            assert "adc" in case
            assert case["dwi"].exists()
            assert case["adc"].exists()
            # File contents must be exactly the bytes supplied by the mock.
            assert case["dwi"].read_bytes() == b"fake_dwi_nifti_data"
            assert case["adc"].read_bytes() == b"fake_adc_nifti_data"
        finally:
            ds.cleanup()

    def test_get_case_includes_ground_truth_when_available(
        self, mock_hf_dataset: MagicMock
    ) -> None:
        """Test that ground truth is included when lesion_mask is present."""
        case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        try:
            case = ds.get_case(0)  # Has mask
            assert "ground_truth" in case
            assert case["ground_truth"].read_bytes() == b"fake_mask_nifti_data"

            case_no_mask = ds.get_case(2)  # No mask
            assert "ground_truth" not in case_no_mask
        finally:
            ds.cleanup()

    def test_get_case_caches_results(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case returns cached paths on subsequent calls."""
        case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        try:
            case1 = ds.get_case(0)
            case2 = ds.get_case(0)

            # Same object returned (cached)
            assert case1 is case2

            # Dataset was only accessed once
            assert mock_hf_dataset.__getitem__.call_count == 1
        finally:
            ds.cleanup()

    def test_context_manager_cleans_up_temp_files(self, mock_hf_dataset: MagicMock) -> None:
        """Test that using context manager cleans up temp files."""
        case_ids = ["sub-stroke0001"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        with ds:
            case = ds.get_case(0)
            # Two levels above the DWI file is the directory that this test
            # expects to disappear once the context exits.
            temp_dir = case["dwi"].parent.parent
            assert temp_dir.exists()

        # After context exit, temp dir should be gone
        assert not temp_dir.exists()

    def test_cleanup_clears_cache(self, mock_hf_dataset: MagicMock) -> None:
        """Test that cleanup clears the case cache."""
        case_ids = ["sub-stroke0001"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        # cleanup() is the call under test, so it runs exactly once; placing
        # it in ``finally`` also releases temp files when the first assertion
        # (or get_case itself) fails, matching the other tests in this class.
        try:
            ds.get_case(0)
            assert len(ds._cached_cases) == 1
        finally:
            ds.cleanup()

        assert len(ds._cached_cases) == 0

    def test_get_case_raises_data_load_error_on_malformed_data(self) -> None:
        """Test that get_case raises DataLoadError for malformed HF data."""
        # Create mock with missing 'bytes' key
        malformed_example = {"subject_id": "sub-stroke0001", "dwi": {}, "adc": {}}
        mock_ds = MagicMock()
        mock_ds.__len__ = MagicMock(return_value=1)
        mock_ds.__getitem__ = MagicMock(return_value=malformed_example)

        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_ds,
            _case_ids=["sub-stroke0001"],
        )

        try:
            with pytest.raises(DataLoadError, match="Malformed HuggingFace data"):
                ds.get_case(0)
        finally:
            ds.cleanup()


class TestBuildHuggingFaceDataset:
    """Tests for build_huggingface_dataset function."""

    @patch("datasets.load_dataset")
    def test_loads_dataset_from_hub(self, mock_load_dataset: MagicMock) -> None:
        """Test that build_huggingface_dataset calls load_dataset correctly.

        ``datasets.load_dataset`` is patched to return a mock that yields a
        single example, so no network access is attempted.
        """
        rows = [{"subject_id": "sub-stroke0001"}]
        mock_ds = MagicMock()
        # Return a fresh iterator per call so the test does not depend on the
        # builder iterating the dataset exactly once (a single iter(...) used
        # as return_value is exhausted after the first pass).
        mock_ds.__iter__ = MagicMock(side_effect=lambda: iter(rows))
        mock_load_dataset.return_value = mock_ds

        result = build_huggingface_dataset("test/my-dataset")

        mock_load_dataset.assert_called_once_with("test/my-dataset", split="train")
        assert isinstance(result, HuggingFaceDataset)
        assert result.dataset_id == "test/my-dataset"
        assert result._case_ids == ["sub-stroke0001"]