|
|
"""Shared test fixtures.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
import nibabel as nib |
|
|
import numpy as np |
|
|
import pytest |
|
|
|
|
|
from stroke_deepisles_demo.core.types import CaseFiles |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Generator, Iterator |
|
|
|
|
|
|
|
|
@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Yield a fresh scratch directory as a ``Path`` for test outputs.

    The directory is created with ``tempfile.TemporaryDirectory`` and is
    removed when the context manager exits after the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as raw_location:
        scratch_root = Path(raw_location)
        yield scratch_root
|
|
|
|
|
|
|
|
@pytest.fixture
def synthetic_nifti_3d(temp_dir: Path) -> Path:
    """Write a minimal synthetic 3D NIfTI file and return its path.

    The volume is a 10x10x10 ``float32`` array with an identity affine,
    saved as ``synthetic.nii.gz`` inside ``temp_dir``.

    Args:
        temp_dir: Directory the file is written into.

    Returns:
        Path to the saved ``synthetic.nii.gz`` file.
    """
    # A fixed seed keeps the voxel values identical across runs, so a test
    # failure that depends on the data can be reproduced.
    rng = np.random.default_rng(seed=0)
    data = rng.random((10, 10, 10)).astype(np.float32)

    img = nib.Nifti1Image(data, affine=np.eye(4))
    path = temp_dir / "synthetic.nii.gz"
    nib.save(img, path)
    return path
|
|
|
|
|
|
|
|
@pytest.fixture
def synthetic_case_files(temp_dir: Path) -> CaseFiles:
    """Create a complete set of synthetic case files.

    Three 64x64x30 volumes with identity affines are written into
    ``temp_dir``:

    * ``dwi.nii.gz``  - ``float32`` values in ``[0, 1)``
    * ``adc.nii.gz``  - ``float32`` values in ``[0, 1)`` scaled by 2000
    * ``mask.nii.gz`` - ``uint8`` binary mask, set where a uniform draw
      exceeds 0.9

    Args:
        temp_dir: Directory the files are written into.

    Returns:
        A ``CaseFiles`` whose ``dwi``, ``adc`` and ``ground_truth`` entries
        point at the files above.
    """
    # A fixed seed keeps the generated volumes identical across runs, so a
    # test failure that depends on the data can be reproduced.
    rng = np.random.default_rng(seed=0)
    shape = (64, 64, 30)

    def _write_volume(filename: str, volume: np.ndarray) -> Path:
        # Wrap the array in a NIfTI image with an identity affine and save it.
        target = temp_dir / filename
        nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), target)
        return target

    dwi_path = _write_volume(
        "dwi.nii.gz",
        rng.random(shape).astype(np.float32),
    )
    adc_path = _write_volume(
        "adc.nii.gz",
        rng.random(shape).astype(np.float32) * 2000,
    )
    mask_path = _write_volume(
        "mask.nii.gz",
        (rng.random(shape) > 0.9).astype(np.uint8),
    )

    return CaseFiles(
        dwi=dwi_path,
        adc=adc_path,
        ground_truth=mask_path,
    )
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_hf_dataset(synthetic_case_files: CaseFiles) -> object:
    """Create a mock HF Dataset-like object.

    The mock holds a single record for participant ``sub-001`` whose
    ``dwi`` and ``adc`` entries are string paths taken from
    ``synthetic_case_files``. ``flair`` is always ``None``. ``mask`` is the
    string path of the ``ground_truth`` entry, or ``None`` when that entry
    is absent or ``None``.

    The returned object supports ``len()``, integer indexing and iteration
    over its records, and exposes a ``features`` mapping keyed by the four
    image column names.

    Args:
        synthetic_case_files: Case files providing the paths for the record.

    Returns:
        A ``MockDataset`` instance containing one record.
    """
    # Look the optional entry up once; str(None) would otherwise become the
    # literal string "None", which looks like a path to downstream code.
    ground_truth = synthetic_case_files.get("ground_truth")
    mask_value = None if ground_truth is None else str(ground_truth)

    class MockDataset:
        """In-memory list of records mimicking the dataset access protocol."""

        def __init__(self) -> None:
            self.data = [
                {
                    "participant_id": "sub-001",
                    "dwi": str(synthetic_case_files["dwi"]),
                    "adc": str(synthetic_case_files["adc"]),
                    "flair": None,
                    "mask": mask_value,
                }
            ]
            self.features = {"dwi": None, "adc": None, "flair": None, "mask": None}

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int) -> dict[str, str | None]:
            return self.data[idx]

        def __iter__(self) -> Iterator[dict[str, str | None]]:
            return iter(self.data)

    return MockDataset()
|
|
|