stroke-deepisles-demo / tests /data /test_hf_adapter.py
VibecoderMcSwaggins's picture
refactor(data): use standard datasets.load_dataset() with neuroimaging-go-brrrr
262b3cb unverified
"""Unit tests for HuggingFace dataset wrapper."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from stroke_deepisles_demo.data.loader import HuggingFaceDatasetWrapper
class TestHuggingFaceDatasetWrapper:
"""Tests for HuggingFaceDatasetWrapper class."""
@pytest.fixture
def mock_hf_dataset(self) -> MagicMock:
"""Create a mock HuggingFace dataset."""
dataset = MagicMock()
# Mock dataset length
dataset.__len__.return_value = 3
# Mock column access for fast index building
# This simulates dataset["subject_id"]
dataset.__getitem__.side_effect = lambda key: (
["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
if key == "subject_id"
else MagicMock()
)
return dataset
def test_init_builds_index_correctly(self, mock_hf_dataset: MagicMock) -> None:
"""Test that initialization builds the subject ID index."""
wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
assert len(wrapper) == 3
assert wrapper.list_case_ids() == ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
assert wrapper._case_id_to_index["sub-stroke0001"] == 0
assert wrapper._case_id_to_index["sub-stroke0003"] == 2
def test_get_case_materializes_files(self, mock_hf_dataset: MagicMock) -> None:
"""Test that get_case materializes NIfTI objects to files."""
# Setup row return for get_case
mock_dwi = MagicMock()
mock_adc = MagicMock()
mock_mask = MagicMock()
row_data = {
"subject_id": "sub-stroke0001",
"dwi": mock_dwi,
"adc": mock_adc,
"lesion_mask": mock_mask,
}
# Reset side_effect to return row for integer index
mock_hf_dataset.__getitem__.side_effect = (
lambda idx: row_data if isinstance(idx, int) else ["sub-stroke0001"]
)
wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
with wrapper:
case = wrapper.get_case("sub-stroke0001")
# Verify file paths
assert case["dwi"].name == "sub-stroke0001_dwi.nii.gz"
assert case["adc"].name == "sub-stroke0001_adc.nii.gz"
assert case["ground_truth"].name == "sub-stroke0001_lesion-msk.nii.gz"
# Verify to_filename called
mock_dwi.to_filename.assert_called_once()
mock_adc.to_filename.assert_called_once()
mock_mask.to_filename.assert_called_once()
# Verify temporary directory usage
assert wrapper._temp_dir is not None
assert case["dwi"].parent == wrapper._temp_dir / "sub-stroke0001"
def test_get_case_handles_missing_mask(self, mock_hf_dataset: MagicMock) -> None:
"""Test that get_case handles cases without lesion mask."""
row_data = {
"subject_id": "sub-stroke0002",
"dwi": MagicMock(),
"adc": MagicMock(),
"lesion_mask": None,
}
mock_hf_dataset.__getitem__.side_effect = (
lambda idx: row_data if isinstance(idx, int) else ["sub-stroke0002"]
)
wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
with wrapper:
case = wrapper.get_case("sub-stroke0002")
assert "dwi" in case
assert "adc" in case
assert "ground_truth" not in case
def test_cleanup_removes_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
"""Test that cleanup removes the temporary directory."""
row_data = {
"subject_id": "sub-stroke0001",
"dwi": MagicMock(),
"adc": MagicMock(),
"lesion_mask": None,
}
mock_hf_dataset.__getitem__.side_effect = (
lambda idx: row_data if isinstance(idx, int) else ["sub-stroke0001"]
)
wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
# Create temp dir by accessing a case
wrapper.get_case(0)
temp_dir = wrapper._temp_dir
assert temp_dir is not None
assert temp_dir.exists()
# cleanup
wrapper.cleanup()
assert not temp_dir.exists()
assert wrapper._temp_dir is None
def test_fallback_iteration(self) -> None:
"""Test fallback to iteration if column access fails."""
dataset = MagicMock()
dataset.__len__.return_value = 2
# Configure iteration for fallback
dataset.__iter__.return_value = iter([{"subject_id": "sub-0"}, {"subject_id": "sub-1"}])
# Fail column access
def getitem(key: Any) -> Any:
if key == "subject_id":
raise ValueError("No column access")
if isinstance(key, int):
return {"subject_id": f"sub-{key}"}
return MagicMock()
dataset.__getitem__.side_effect = getitem
wrapper = HuggingFaceDatasetWrapper(dataset, "test/dataset")
assert wrapper._case_id_to_index["sub-0"] == 0
assert wrapper._case_id_to_index["sub-1"] == 1