|
|
"""Unit tests for HuggingFace dataset wrapper.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any |
|
|
from unittest.mock import MagicMock |
|
|
|
|
|
import pytest |
|
|
|
|
|
from stroke_deepisles_demo.data.loader import HuggingFaceDatasetWrapper |
|
|
|
|
|
|
|
|
class TestHuggingFaceDatasetWrapper:
    """Tests for HuggingFaceDatasetWrapper class."""

    @staticmethod
    def _stub_single_row(dataset: MagicMock, row: dict[str, Any]) -> None:
        """Configure ``dataset`` to mimic a one-row HuggingFace dataset.

        Integer indexing returns ``row``; any other key (such as the
        ``"subject_id"`` column) returns a one-element list holding the row's
        subject ID. ``len()`` is set to 1 so it agrees with that column rather
        than with the three rows configured by the ``mock_hf_dataset`` fixture.
        """
        dataset.__len__.return_value = 1
        dataset.__getitem__.side_effect = (
            lambda key: row if isinstance(key, int) else [row["subject_id"]]
        )

    @pytest.fixture
    def mock_hf_dataset(self) -> MagicMock:
        """Create a mock HuggingFace dataset with three subjects.

        ``dataset["subject_id"]`` returns the subject ID column; any other key
        yields a fresh ``MagicMock``.
        """
        subject_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]

        dataset = MagicMock()
        dataset.__len__.return_value = len(subject_ids)
        # Return a new list on every access so the wrapper cannot mutate
        # state shared across lookups.
        dataset.__getitem__.side_effect = lambda key: (
            list(subject_ids) if key == "subject_id" else MagicMock()
        )
        return dataset

    def test_init_builds_index_correctly(self, mock_hf_dataset: MagicMock) -> None:
        """Test that initialization builds the subject ID index."""
        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        assert len(wrapper) == 3
        assert wrapper.list_case_ids() == ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        assert wrapper._case_id_to_index["sub-stroke0001"] == 0
        assert wrapper._case_id_to_index["sub-stroke0003"] == 2

    def test_get_case_materializes_files(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case materializes NIfTI objects to files."""
        mock_dwi = MagicMock()
        mock_adc = MagicMock()
        mock_mask = MagicMock()
        row_data = {
            "subject_id": "sub-stroke0001",
            "dwi": mock_dwi,
            "adc": mock_adc,
            "lesion_mask": mock_mask,
        }
        self._stub_single_row(mock_hf_dataset, row_data)

        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        # All assertions stay inside the context manager: leaving it presumably
        # calls cleanup(), which resets ``_temp_dir`` to None (see
        # test_cleanup_removes_temp_dir), so the temp-dir checks would fail.
        with wrapper:
            case = wrapper.get_case("sub-stroke0001")

            # Output files are named "<subject_id>_<suffix>.nii.gz"; the
            # "lesion_mask" column is exposed under the "ground_truth" key.
            assert case["dwi"].name == "sub-stroke0001_dwi.nii.gz"
            assert case["adc"].name == "sub-stroke0001_adc.nii.gz"
            assert case["ground_truth"].name == "sub-stroke0001_lesion-msk.nii.gz"

            # Each image object is written to disk exactly once.
            mock_dwi.to_filename.assert_called_once()
            mock_adc.to_filename.assert_called_once()
            mock_mask.to_filename.assert_called_once()

            # Files land in a per-subject subdirectory of the wrapper's temp dir.
            assert wrapper._temp_dir is not None
            assert case["dwi"].parent == wrapper._temp_dir / "sub-stroke0001"

    def test_get_case_handles_missing_mask(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case handles cases without lesion mask."""
        row_data = {
            "subject_id": "sub-stroke0002",
            "dwi": MagicMock(),
            "adc": MagicMock(),
            "lesion_mask": None,
        }
        self._stub_single_row(mock_hf_dataset, row_data)

        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        with wrapper:
            case = wrapper.get_case("sub-stroke0002")

            assert "dwi" in case
            assert "adc" in case
            # A None lesion mask must not produce a ground-truth entry.
            assert "ground_truth" not in case

    def test_cleanup_removes_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
        """Test that cleanup removes the temporary directory."""
        row_data = {
            "subject_id": "sub-stroke0001",
            "dwi": MagicMock(),
            "adc": MagicMock(),
            "lesion_mask": None,
        }
        self._stub_single_row(mock_hf_dataset, row_data)

        wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")

        # Look the case up by integer index and deliberately skip the context
        # manager so cleanup() is exercised explicitly below.
        wrapper.get_case(0)
        temp_dir = wrapper._temp_dir

        assert temp_dir is not None
        assert temp_dir.exists()

        wrapper.cleanup()

        assert not temp_dir.exists()
        assert wrapper._temp_dir is None

    def test_fallback_iteration(self) -> None:
        """Test fallback to iteration if column access fails."""
        dataset = MagicMock()
        dataset.__len__.return_value = 2
        # A list (not an iterator) so the mock can be iterated more than once;
        # a one-shot iterator would look empty on a second pass.
        dataset.__iter__.return_value = [{"subject_id": "sub-0"}, {"subject_id": "sub-1"}]

        def getitem(key: Any) -> Any:
            # Column access is unsupported, which forces the wrapper onto its
            # row-by-row fallback path.
            if key == "subject_id":
                raise ValueError("No column access")
            if isinstance(key, int):
                return {"subject_id": f"sub-{key}"}
            return MagicMock()

        dataset.__getitem__.side_effect = getitem

        wrapper = HuggingFaceDatasetWrapper(dataset, "test/dataset")

        assert wrapper._case_id_to_index["sub-0"] == 0
        assert wrapper._case_id_to_index["sub-1"] == 1
|
|
|