File size: 5,221 Bytes
262b3cb 363ba14 262b3cb 363ba14 262b3cb 363ba14 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 363ba14 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 363ba14 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 363ba14 262b3cb 363ba14 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb 80cbb1a 262b3cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""Unit tests for HuggingFace dataset wrapper."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from stroke_deepisles_demo.data.loader import HuggingFaceDatasetWrapper
class TestHuggingFaceDatasetWrapper:
    """Tests for HuggingFaceDatasetWrapper class.

    Every test hands the wrapper a ``MagicMock`` in place of a dataset, so
    nothing is downloaded; only the wrapper's own index building, file
    materialization and temp-dir cleanup are exercised.
    """

    @staticmethod
    def _serve_single_row(dataset: MagicMock, row: dict[str, Any]) -> None:
        """Make ``dataset[...]`` return ``row`` for ints and its id column otherwise.

        Integer keys yield ``row`` itself; any other key (in practice the
        ``"subject_id"`` column lookup) yields a one-element list holding the
        row's subject id.

        NOTE(review): ``__len__`` on the fixture still reports 3 while only one
        id is served here; none of the tests using this helper assert on length.
        """
        subject_id = row["subject_id"]
        dataset.__getitem__.side_effect = (
            lambda key: row if isinstance(key, int) else [subject_id]
        )

    @pytest.fixture
    def mock_hf_dataset(self) -> MagicMock:
        """Create a mock HuggingFace dataset with three subjects."""
        dataset = MagicMock()

        # Mock dataset length
        dataset.__len__.return_value = 3

        # Mock column access for fast index building.
        # This simulates dataset["subject_id"]; any other key gets a MagicMock.
        dataset.__getitem__.side_effect = lambda key: (
            ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
            if key == "subject_id"
            else MagicMock()
        )
        return dataset

    def test_init_builds_index_correctly(self, mock_hf_dataset: MagicMock) -> None:
        """Test that initialization builds the subject ID index."""
        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        assert len(wrapper) == 3
        assert wrapper.list_case_ids() == ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        assert wrapper._case_id_to_index["sub-stroke0001"] == 0
        assert wrapper._case_id_to_index["sub-stroke0003"] == 2

    def test_get_case_materializes_files(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case materializes NIfTI objects to files."""
        # Setup row return for get_case
        mock_dwi = MagicMock()
        mock_adc = MagicMock()
        mock_mask = MagicMock()
        row_data = {
            "subject_id": "sub-stroke0001",
            "dwi": mock_dwi,
            "adc": mock_adc,
            "lesion_mask": mock_mask,
        }
        # Replace the fixture's side_effect so integer indexing returns the row
        self._serve_single_row(mock_hf_dataset, row_data)

        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        # All checks stay inside the context: _temp_dir is inspected below and
        # is expected to be reset once the wrapper is cleaned up.
        with wrapper:
            case = wrapper.get_case("sub-stroke0001")

            # Verify file paths
            assert case["dwi"].name == "sub-stroke0001_dwi.nii.gz"
            assert case["adc"].name == "sub-stroke0001_adc.nii.gz"
            assert case["ground_truth"].name == "sub-stroke0001_lesion-msk.nii.gz"

            # Verify to_filename called
            mock_dwi.to_filename.assert_called_once()
            mock_adc.to_filename.assert_called_once()
            mock_mask.to_filename.assert_called_once()

            # Verify temporary directory usage
            assert wrapper._temp_dir is not None
            assert case["dwi"].parent == wrapper._temp_dir / "sub-stroke0001"

    def test_get_case_handles_missing_mask(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case handles cases without lesion mask."""
        row_data = {
            "subject_id": "sub-stroke0002",
            "dwi": MagicMock(),
            "adc": MagicMock(),
            "lesion_mask": None,
        }
        self._serve_single_row(mock_hf_dataset, row_data)

        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        with wrapper:
            case = wrapper.get_case("sub-stroke0002")

            assert "dwi" in case
            assert "adc" in case
            assert "ground_truth" not in case

    def test_cleanup_removes_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
        """Test that cleanup removes the temporary directory."""
        row_data = {
            "subject_id": "sub-stroke0001",
            "dwi": MagicMock(),
            "adc": MagicMock(),
            "lesion_mask": None,
        }
        self._serve_single_row(mock_hf_dataset, row_data)

        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        try:
            # Create temp dir by accessing a case.
            # NOTE(review): an integer is passed here rather than a case-id
            # string; presumably get_case accepts both -- confirm in loader.
            wrapper.get_case(0)
            temp_dir = wrapper._temp_dir
            assert temp_dir is not None
            assert temp_dir.exists()
        finally:
            # Always clean up so a failed precondition above cannot leave the
            # temporary directory behind on disk.
            wrapper.cleanup()

        assert not temp_dir.exists()
        assert wrapper._temp_dir is None

    def test_fallback_iteration(self) -> None:
        """Test fallback to iteration if column access fails."""
        dataset = MagicMock()
        dataset.__len__.return_value = 2

        # Configure iteration for fallback. A list (not a one-shot iterator) is
        # used so that every pass over the mock starts from the first row.
        dataset.__iter__.return_value = [{"subject_id": "sub-0"}, {"subject_id": "sub-1"}]

        # Fail column access
        def getitem(key: Any) -> Any:
            if key == "subject_id":
                raise ValueError("No column access")
            if isinstance(key, int):
                return {"subject_id": f"sub-{key}"}
            return MagicMock()

        dataset.__getitem__.side_effect = getitem

        wrapper = HuggingFaceDatasetWrapper(dataset, "test/dataset")

        assert wrapper._case_id_to_index["sub-0"] == 0
        assert wrapper._case_id_to_index["sub-1"] == 1
|