File size: 6,398 Bytes
363ba14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
"""Unit tests for HuggingFace dataset adapter with mocked HF dataset."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, build_huggingface_dataset
def create_mock_hf_example(subject_id: str, include_mask: bool = True) -> dict[str, Any]:
    """Build a fake HuggingFace dataset row for *subject_id*.

    The row mimics the HF image/file feature layout: each modality maps to a
    dict with raw ``bytes`` and a source ``path``. When ``include_mask`` is
    False the ``lesion_mask`` entry is present but ``None``, matching rows
    without ground truth.
    """
    # Ground-truth entry is None (not absent) for mask-less subjects.
    mask_entry = (
        {
            "bytes": b"fake_mask_nifti_data",
            "path": f"{subject_id}_lesion-msk.nii.gz",
        }
        if include_mask
        else None
    )
    return {
        "subject_id": subject_id,
        "dwi": {"bytes": b"fake_dwi_nifti_data", "path": f"{subject_id}_dwi.nii.gz"},
        "adc": {"bytes": b"fake_adc_nifti_data", "path": f"{subject_id}_adc.nii.gz"},
        "lesion_mask": mask_entry,
    }
@pytest.fixture
def mock_hf_dataset() -> MagicMock:
    """Provide a mock HF dataset of 3 subjects (the third has no lesion mask)."""
    rows = [
        create_mock_hf_example(f"sub-stroke000{n}", include_mask=(n != 3))
        for n in (1, 2, 3)
    ]
    dataset = MagicMock()
    dataset.__len__ = MagicMock(return_value=len(rows))
    # NOTE: iter() is created once, so the mock is exhaustible after one pass —
    # fine here since each test gets a fresh fixture instance.
    dataset.__iter__ = MagicMock(return_value=iter(rows))
    dataset.__getitem__ = MagicMock(side_effect=rows.__getitem__)
    return dataset
class TestHuggingFaceDataset:
    """Behavioral tests for the HuggingFaceDataset adapter."""

    @staticmethod
    def _dataset(hf_mock: MagicMock, ids: list[str]) -> HuggingFaceDataset:
        """Wire a HuggingFaceDataset to the given mock backend and case ids."""
        return HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=hf_mock,
            _case_ids=ids,
        )

    def test_get_case_writes_files_to_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
        """get_case must materialize the NIfTI bytes as real temp files."""
        dataset = self._dataset(
            mock_hf_dataset, ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        )
        try:
            paths = dataset.get_case(0)
            for modality, payload in (
                ("dwi", b"fake_dwi_nifti_data"),
                ("adc", b"fake_adc_nifti_data"),
            ):
                assert modality in paths
                assert paths[modality].exists()
                assert paths[modality].read_bytes() == payload
        finally:
            dataset.cleanup()

    def test_get_case_includes_ground_truth_when_available(
        self, mock_hf_dataset: MagicMock
    ) -> None:
        """Ground truth appears only for subjects that carry a lesion_mask."""
        dataset = self._dataset(
            mock_hf_dataset, ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        )
        try:
            # Subject at index 0 was built with a mask.
            masked = dataset.get_case(0)
            assert "ground_truth" in masked
            assert masked["ground_truth"].read_bytes() == b"fake_mask_nifti_data"
            # Subject at index 2 was built without one.
            unmasked = dataset.get_case(2)
            assert "ground_truth" not in unmasked
        finally:
            dataset.cleanup()

    def test_get_case_caches_results(self, mock_hf_dataset: MagicMock) -> None:
        """Repeated get_case calls for one index reuse the cached mapping."""
        dataset = self._dataset(
            mock_hf_dataset, ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        )
        try:
            first = dataset.get_case(0)
            second = dataset.get_case(0)
            # Identity check: the cached dict itself is handed back.
            assert first is second
            # And the backing HF dataset was consulted exactly once.
            assert mock_hf_dataset.__getitem__.call_count == 1
        finally:
            dataset.cleanup()

    def test_context_manager_cleans_up_temp_files(self, mock_hf_dataset: MagicMock) -> None:
        """Exiting the `with` block must remove the temp directory."""
        dataset = self._dataset(mock_hf_dataset, ["sub-stroke0001"])
        with dataset:
            scratch_dir = dataset.get_case(0)["dwi"].parent.parent
            assert scratch_dir.exists()
        # The directory must not survive context exit.
        assert not scratch_dir.exists()

    def test_cleanup_clears_cache(self, mock_hf_dataset: MagicMock) -> None:
        """cleanup() must empty the internal case cache."""
        dataset = self._dataset(mock_hf_dataset, ["sub-stroke0001"])
        dataset.get_case(0)
        assert len(dataset._cached_cases) == 1
        dataset.cleanup()
        assert len(dataset._cached_cases) == 0

    def test_get_case_raises_data_load_error_on_malformed_data(self) -> None:
        """Rows lacking the 'bytes' payload must surface as DataLoadError."""
        # Modality dicts here are missing the required 'bytes' key entirely.
        broken_row = {"subject_id": "sub-stroke0001", "dwi": {}, "adc": {}}
        broken_backend = MagicMock()
        broken_backend.__len__ = MagicMock(return_value=1)
        broken_backend.__getitem__ = MagicMock(return_value=broken_row)
        dataset = self._dataset(broken_backend, ["sub-stroke0001"])
        try:
            with pytest.raises(DataLoadError, match="Malformed HuggingFace data"):
                dataset.get_case(0)
        finally:
            dataset.cleanup()
class TestBuildHuggingFaceDataset:
    """Tests for the build_huggingface_dataset factory."""

    @patch("datasets.load_dataset")
    def test_loads_dataset_from_hub(self, mock_load_dataset: MagicMock) -> None:
        """The factory must fetch split='train' and collect subject ids."""
        hub_ds = MagicMock()
        hub_ds.__iter__ = MagicMock(
            return_value=iter([{"subject_id": "sub-stroke0001"}])
        )
        mock_load_dataset.return_value = hub_ds
        built = build_huggingface_dataset("test/my-dataset")
        # Exactly one hub fetch, with the expected split.
        mock_load_dataset.assert_called_once_with("test/my-dataset", split="train")
        assert isinstance(built, HuggingFaceDataset)
        assert built.dataset_id == "test/my-dataset"
        assert built._case_ids == ["sub-stroke0001"]
|