# Scraped file-view header (kept as a comment so the module parses):
# VibecoderMcSwaggins — feat(phase-1): implement data access layer with TDD (#2)
# Revision 3c4c67b (unverified), 2.81 kB
"""Shared test fixtures."""
from __future__ import annotations
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
import nibabel as nib
import numpy as np
import pytest
from stroke_deepisles_demo.core.types import CaseFiles
if TYPE_CHECKING:
from collections.abc import Generator, Iterator
@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Yield a fresh, empty scratch directory as a ``Path``.

    The directory and everything written into it are deleted during fixture
    teardown, i.e. after the requesting test has finished.
    """
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        # Same cleanup the ``with`` form would perform on exit.
        scratch.cleanup()
@pytest.fixture
def synthetic_nifti_3d(temp_dir: Path) -> Path:
    """Create a minimal synthetic 3D NIfTI file.

    Writes a 10x10x10 float32 volume of uniform noise in [0, 1) with an
    identity affine to ``temp_dir / "synthetic.nii.gz"``. A fixed seed keeps
    the file contents identical on every run.

    Returns:
        Path to the written ``.nii.gz`` file.
    """
    # Seeded local generator: reproducible, and independent of the global
    # ``np.random`` state that other tests may seed or consume.
    rng = np.random.default_rng(0)
    data = rng.random((10, 10, 10), dtype=np.float32)
    img = nib.Nifti1Image(data, affine=np.eye(4))  # type: ignore
    path = temp_dir / "synthetic.nii.gz"
    nib.save(img, path)  # type: ignore
    return path
@pytest.fixture
def synthetic_case_files(temp_dir: Path) -> CaseFiles:
    """Create a complete set of synthetic case files.

    Writes three 64x64x30 NIfTI volumes with an identity affine into
    ``temp_dir``:

    * ``dwi.nii.gz``  - float32 uniform noise in [0, 1)
    * ``adc.nii.gz``  - float32 uniform noise scaled by 2000
    * ``mask.nii.gz`` - uint8 binary mask with roughly 10% of voxels set

    All data come from one fixed-seed generator, so every run produces the
    same volumes and failures are reproducible.

    Returns:
        ``CaseFiles`` with ``dwi``, ``adc`` and ``ground_truth`` paths.
    """
    # Seeded local generator instead of the global ``np.random`` state.
    rng = np.random.default_rng(42)
    shape = (64, 64, 30)
    affine = np.eye(4)

    # Create DWI
    dwi_data = rng.random(shape, dtype=np.float32)
    dwi_img = nib.Nifti1Image(dwi_data, affine=affine)  # type: ignore
    dwi_path = temp_dir / "dwi.nii.gz"
    nib.save(dwi_img, dwi_path)  # type: ignore

    # Create ADC
    adc_data = rng.random(shape, dtype=np.float32) * 2000
    adc_img = nib.Nifti1Image(adc_data, affine=affine)  # type: ignore
    adc_path = temp_dir / "adc.nii.gz"
    nib.save(adc_img, adc_path)  # type: ignore

    # Create mask: threshold uniform noise so about 10% of voxels are 1.
    mask_data = (rng.random(shape) > 0.9).astype(np.uint8)
    mask_img = nib.Nifti1Image(mask_data, affine=affine)  # type: ignore
    mask_path = temp_dir / "mask.nii.gz"
    nib.save(mask_img, mask_path)  # type: ignore

    return CaseFiles(
        dwi=dwi_path,
        adc=adc_path,
        ground_truth=mask_path,
    )
@pytest.fixture
def mock_hf_dataset(synthetic_case_files: CaseFiles) -> object:
    """Create a mock HF Dataset-like object.

    The mock wraps a single row for participant ``sub-001``. Its image
    columns (``dwi``, ``adc``, ``flair``, ``mask``) hold path strings taken
    from ``synthetic_case_files``, or ``None`` for an absent image. It
    supports ``len()``, integer indexing and iteration, and exposes a
    ``features`` mapping keyed by the image column names.
    """
    ground_truth = synthetic_case_files.get("ground_truth")
    # A missing ground truth must stay ``None`` (like ``flair``); ``str(None)``
    # would yield the string "None", which looks like a file path.
    mask_path = None if ground_truth is None else str(ground_truth)

    # Simple list-based mock that mimics dataset behavior
    class MockDataset:
        def __init__(self) -> None:
            self.data: list[dict[str, str | None]] = [
                {
                    "participant_id": "sub-001",
                    "dwi": str(synthetic_case_files["dwi"]),
                    "adc": str(synthetic_case_files["adc"]),
                    "flair": None,
                    "mask": mask_path,
                }
            ]
            # Placeholder for ``Dataset.features``: column names only, values None.
            self.features = {"dwi": None, "adc": None, "flair": None, "mask": None}

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int) -> dict[str, str | None]:
            return self.data[idx]

        def __iter__(self) -> Iterator[dict[str, str | None]]:
            return iter(self.data)

    return MockDataset()