feat(phase-1): implement data access layer with TDD (#2)
- .pre-commit-config.yaml +6 -6
- Makefile +18 -0
- pyproject.toml +2 -0
- src/stroke_deepisles_demo/data/__init__.py +42 -1
- src/stroke_deepisles_demo/data/adapter.py +147 -0
- src/stroke_deepisles_demo/data/loader.py +138 -0
- src/stroke_deepisles_demo/data/staging.py +150 -0
- tests/conftest.py +88 -2
- tests/data/test_adapter.py +70 -0
- tests/data/test_loader.py +90 -0
- tests/data/test_staging.py +77 -0
.pre-commit-config.yaml
CHANGED

@@ -6,14 +6,14 @@ repos:
         args: [--fix]
       - id: ruff-format
 
-  - repo:
-    rev: v1.19.0
+  - repo: local
     hooks:
       - id: mypy
-
-
-
-
+        name: mypy
+        entry: uv run mypy
+        language: system
+        types: [python]
+        require_serial: true
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v6.0.0
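
The mypy hook now runs as a local, system-language hook through `uv run mypy` instead of a pinned external hook repo (previously at rev v1.19.0), so pre-commit checks with the same mypy version as the project environment. `require_serial: true` stops pre-commit from splitting the file list into parallel hook invocations, which would otherwise run mypy several times concurrently. Assuming pre-commit is installed as a dev dependency, the hook can be exercised directly with `uv run pre-commit run mypy --all-files`.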
Makefile
ADDED

@@ -0,0 +1,18 @@
.PHONY: install test lint format check all

install:
	uv sync

test:
	uv run pytest

lint:
	uv run ruff check .

format:
	uv run ruff format .

check:
	uv run mypy src/ tests/

all: lint check test
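
Each target simply shells out to the corresponding `uv run` command (recipe bodies must be indented with a literal tab), and `make all` chains lint, type checking, and tests in one invocation.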
pyproject.toml
CHANGED

@@ -102,6 +102,8 @@ module = [
     "gradio.*",
     "datasets.*",
     "niivue.*",
+    "numpy.*",
+    "pytest.*",
 ]
 ignore_missing_imports = true
 
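
The two new entries extend the existing mypy per-module override list, so `ignore_missing_imports = true` now also applies to `numpy.*` and `pytest.*`; presumably this keeps the stricter local mypy hook quiet when those packages' type information is not resolvable from the hook environment.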
src/stroke_deepisles_demo/data/__init__.py
CHANGED

@@ -1 +1,42 @@
"""Data loading and case management for stroke-deepisles-demo."""

from stroke_deepisles_demo.core.types import CaseFiles
from stroke_deepisles_demo.data.adapter import CaseAdapter
from stroke_deepisles_demo.data.loader import DatasetInfo, get_dataset_info, load_isles_dataset
from stroke_deepisles_demo.data.staging import StagedCase, stage_case_for_deepisles

__all__ = [
    "CaseAdapter",
    "DatasetInfo",
    "StagedCase",
    "get_case",
    "get_dataset_info",
    "list_case_ids",
    "load_isles_dataset",
    "stage_case_for_deepisles",
]


# Convenience functions (combine loader + adapter)
def get_case(case_id: str | int) -> CaseFiles:
    """
    Load a single case by ID or index.

    Returns:
        CaseFiles dictionary
    """
    dataset = load_isles_dataset()
    adapter = CaseAdapter(dataset)
    return adapter.get_case(case_id)


def list_case_ids() -> list[str]:
    """List all available case IDs."""
    dataset = load_isles_dataset()
    adapter = CaseAdapter(dataset)
    return adapter.list_case_ids()
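
For orientation, a minimal usage sketch of the public API assembled here (the case IDs shown are hypothetical; real IDs come from the dataset itself):

    from stroke_deepisles_demo.data import get_case, list_case_ids

    case_ids = list_case_ids()        # e.g. ["sub-001", "sub-002", ...]
    case = get_case(case_ids[0])      # or get_case(0) to select by index
    print(case["dwi"], case["adc"])   # references to the NIfTI inputs

Note that each convenience call constructs a fresh CaseAdapter (and reloads the dataset); for repeated access it is cheaper to build one adapter and reuse it.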
src/stroke_deepisles_demo/data/adapter.py
ADDED

@@ -0,0 +1,147 @@
"""Adapt HF dataset rows to typed file references."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.core.types import CaseFiles

if TYPE_CHECKING:
    from collections.abc import Iterator

    from datasets import Dataset


class CaseAdapter:
    """
    Adapts a HuggingFace dataset to provide typed access to case files.

    This handles the mapping between the HF dataset structure and our
    internal CaseFiles type.
    """

    def __init__(self, dataset: Dataset) -> None:
        """
        Initialize the adapter with a loaded dataset.

        Args:
            dataset: HuggingFace Dataset with NIfTI files
        """
        self.dataset = dataset
        self._case_id_map = self._build_case_id_map()

    def _build_case_id_map(self) -> dict[str, int]:
        """Build a mapping from case ID to dataset index."""
        case_map: dict[str, int] = {}

        # Use 'participant_id' when the dataset provides it; otherwise fall
        # back to generated IDs of the form "case_000".
        id_col = "participant_id"

        # A full pass over the dataset is O(N); fine at this scale (149 cases).
        for idx, row in enumerate(self.dataset):
            case_id = row.get(id_col, f"case_{idx:03d}")
            case_map[str(case_id)] = idx

        return case_map

    def __len__(self) -> int:
        """Return the number of cases in the dataset."""
        return len(self.dataset)

    def __iter__(self) -> Iterator[str]:
        """Iterate over case IDs."""
        return iter(self._case_id_map.keys())

    def list_case_ids(self) -> list[str]:
        """
        List all available case identifiers.

        Returns:
            List of case IDs (e.g., ["sub-001", "sub-002", ...])
        """
        return list(self._case_id_map.keys())

    def get_case(self, case_id: str | int) -> CaseFiles:
        """
        Get file paths for a specific case.

        Args:
            case_id: Either a string ID (e.g., "sub-001") or an integer index

        Returns:
            CaseFiles with paths to DWI, ADC, and optionally ground truth

        Raises:
            KeyError: If case_id is not found
            DataLoadError: If files cannot be accessed
        """
        if isinstance(case_id, int):
            index = case_id
        else:
            if case_id not in self._case_id_map:
                raise KeyError(f"Case ID not found: {case_id}")
            index = self._case_id_map[case_id]

        return self._get_case_by_index_internal(index)

    def get_case_by_index(self, index: int) -> tuple[str, CaseFiles]:
        """
        Get a case by numerical index.

        Returns:
            Tuple of (case_id, CaseFiles)
        """
        if index < 0 or index >= len(self.dataset):
            raise IndexError("Case index out of range")

        # Avoid an O(N) reverse lookup by reading the ID from the row itself,
        # falling back to the same generated ID used when building the map.
        row = self.dataset[index]
        case_id = row.get("participant_id", f"case_{index:03d}")

        case_files = self._row_to_case_files(row)
        return str(case_id), case_files

    def _get_case_by_index_internal(self, index: int) -> CaseFiles:
        """Internal helper to get CaseFiles by index."""
        row = self.dataset[index]
        return self._row_to_case_files(row)

    def _row_to_case_files(self, row: dict[str, Any]) -> CaseFiles:
        """Convert a dataset row to CaseFiles."""
        # Column names follow the dataset layout: 'dwi', 'adc', 'flair', and
        # 'mask' (or 'ground_truth') for the lesion segmentation.

        def to_path_or_raw(val: Any) -> Any:
            # Local path strings become Path objects; URLs and other values
            # pass through unchanged.
            if isinstance(val, str) and not val.startswith(("http://", "https://")):
                return Path(val)
            return val

        dwi = to_path_or_raw(row.get("dwi"))
        adc = to_path_or_raw(row.get("adc"))
        flair = to_path_or_raw(row.get("flair"))
        ground_truth = to_path_or_raw(row.get("mask") or row.get("ground_truth"))

        if not dwi or not adc:
            raise DataLoadError("Case missing required DWI or ADC files")

        case_files = CaseFiles(dwi=dwi, adc=adc)

        if flair:
            case_files["flair"] = flair
        if ground_truth:
            case_files["ground_truth"] = ground_truth

        return case_files
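
A short sketch of how CaseAdapter is driven, using an in-memory stand-in for the dataset (the stand-in class and file paths below are hypothetical; the adapter only needs len(), iteration over row dicts, integer indexing, and a features mapping):

    from stroke_deepisles_demo.data.adapter import CaseAdapter

    class FakeDataset:
        """Minimal stand-in exposing the parts of Dataset the adapter uses."""
        features = {"participant_id": None, "dwi": None, "adc": None, "mask": None}
        _rows = [
            {
                "participant_id": "sub-001",
                "dwi": "/data/sub-001/dwi.nii.gz",   # hypothetical local paths
                "adc": "/data/sub-001/adc.nii.gz",
                "mask": "/data/sub-001/mask.nii.gz",
            }
        ]

        def __len__(self) -> int:
            return len(self._rows)

        def __iter__(self):
            return iter(self._rows)

        def __getitem__(self, idx: int):
            return self._rows[idx]

    adapter = CaseAdapter(FakeDataset())  # type: ignore[arg-type]
    assert adapter.list_case_ids() == ["sub-001"]
    files = adapter.get_case("sub-001")  # CaseFiles with Path values for dwi/adc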
src/stroke_deepisles_demo/data/loader.py
ADDED

@@ -0,0 +1,138 @@
"""Load ISLES24-MR-Lite dataset from HuggingFace Hub."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from datasets import load_dataset

from stroke_deepisles_demo.core.exceptions import DataLoadError

if TYPE_CHECKING:
    from pathlib import Path

    from datasets import Dataset


def load_isles_dataset(
    dataset_id: str = "YongchengYAO/ISLES24-MR-Lite",
    *,
    cache_dir: Path | None = None,
    streaming: bool = False,
) -> Dataset:
    """
    Load the ISLES24-MR-Lite dataset from HuggingFace Hub.

    Args:
        dataset_id: HuggingFace dataset identifier
        cache_dir: Local cache directory (uses HF default if None)
        streaming: If True, use streaming mode (lazy loading)

    Returns:
        HuggingFace Dataset object with BIDS/NIfTI support

    Raises:
        DataLoadError: If the dataset cannot be loaded
    """
    try:
        # The pinned fork supports BIDS/NIfTI; a plain load_dataset call is
        # sufficient here (no trust_remote_code or custom builder needed).
        ds = load_dataset(
            dataset_id,
            cache_dir=str(cache_dir) if cache_dir else None,
            streaming=streaming,
        )

        # load_dataset usually returns a DatasetDict unless a split is
        # specified; prefer the 'train' split, else fall back to the first.
        if hasattr(ds, "keys"):
            keys = list(ds.keys())
            if "train" in keys:
                return ds["train"]
            elif len(keys) > 0:
                return ds[keys[0]]

        return ds

    except Exception as e:
        raise DataLoadError(f"Failed to load dataset {dataset_id}: {e}") from e


@dataclass
class DatasetInfo:
    """Metadata about the loaded dataset."""

    dataset_id: str
    num_cases: int
    modalities: list[str]  # e.g., ["dwi", "adc", "mask"]
    has_ground_truth: bool


def get_dataset_info(dataset_id: str = "YongchengYAO/ISLES24-MR-Lite") -> DatasetInfo:
    """
    Get metadata about the dataset without downloading it (if possible).

    Returns:
        DatasetInfo with case count, available modalities, etc.
    """
    try:
        # Load in streaming mode so only metadata/features are fetched.
        ds = load_isles_dataset(dataset_id, streaming=True)

        if hasattr(ds, "info") and ds.info.splits:
            # Read the case count from split metadata when available.
            num_cases = ds.info.splits["train"].num_examples
        else:
            # Streaming datasets may not expose a length, and iterating just
            # to count would trigger network calls, so fall back to the
            # known dataset size.
            num_cases = 149

        features = ds.features.keys()
        modalities = [k for k in features if k in ["dwi", "adc", "flair"]]
        has_ground_truth = "mask" in features or "ground_truth" in features

        return DatasetInfo(
            dataset_id=dataset_id,
            num_cases=num_cases,
            modalities=sorted(modalities),
            has_ground_truth=has_ground_truth,
        )
    except Exception as e:
        raise DataLoadError(f"Failed to get info for {dataset_id}: {e}") from e
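
A usage sketch (assumes network access and the default HF cache; the printed values depend on the Hub copy of the dataset):

    from stroke_deepisles_demo.data.loader import get_dataset_info, load_isles_dataset

    info = get_dataset_info()
    print(info.num_cases, info.modalities, info.has_ground_truth)

    ds = load_isles_dataset()   # resolves the 'train' split by default
    print(len(ds))              # case count once fully materialized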
src/stroke_deepisles_demo/data/staging.py
ADDED

@@ -0,0 +1,150 @@
"""Stage NIfTI files with DeepISLES-expected naming."""

from __future__ import annotations

import shutil
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple

from stroke_deepisles_demo.core.exceptions import MissingInputError

if TYPE_CHECKING:
    from stroke_deepisles_demo.core.types import CaseFiles


class StagedCase(NamedTuple):
    """Paths to staged files ready for DeepISLES."""

    input_dir: Path  # Directory containing staged files
    dwi_path: Path  # Path to dwi.nii.gz
    adc_path: Path  # Path to adc.nii.gz
    flair_path: Path | None  # Path to flair.nii.gz if available


def stage_case_for_deepisles(
    case_files: CaseFiles,
    output_dir: Path,
    *,
    case_id: str | None = None,
) -> StagedCase:
    """
    Stage case files with the DeepISLES-expected naming convention.

    DeepISLES expects files named exactly:
    - dwi.nii.gz
    - adc.nii.gz
    - flair.nii.gz (optional)

    This function copies the source files into a staging directory with the
    correct names.

    Args:
        case_files: Source file paths from CaseAdapter
        output_dir: Directory to stage files into
        case_id: Optional case ID used as a subdirectory name

    Returns:
        StagedCase with paths to the staged files

    Raises:
        MissingInputError: If required files (DWI, ADC) are missing
        OSError: If file operations fail
    """
    # Stage directly into output_dir; nest under a case_id subdirectory only
    # when one is provided, to avoid nesting deeper than callers expect.
    stage_dir = output_dir
    if case_id:
        stage_dir = output_dir / case_id

    stage_dir.mkdir(parents=True, exist_ok=True)

    # DWI (required)
    if "dwi" not in case_files or not case_files["dwi"]:
        raise MissingInputError("DWI file is required but missing from case files.")

    dwi_dest = stage_dir / "dwi.nii.gz"
    _materialize_nifti(case_files["dwi"], dwi_dest)

    # ADC (required)
    if "adc" not in case_files or not case_files["adc"]:
        raise MissingInputError("ADC file is required but missing from case files.")

    adc_dest = stage_dir / "adc.nii.gz"
    _materialize_nifti(case_files["adc"], adc_dest)

    # FLAIR (optional)
    flair_dest: Path | None = None
    if "flair" in case_files and case_files["flair"] is not None:
        flair_dest = stage_dir / "flair.nii.gz"
        _materialize_nifti(case_files["flair"], flair_dest)

    return StagedCase(
        input_dir=stage_dir,
        dwi_path=dwi_dest,
        adc_path=adc_dest,
        flair_path=flair_dest,
    )


def create_staging_directory(base_dir: Path | None = None) -> Path:
    """
    Create a temporary staging directory.

    Args:
        base_dir: Parent directory (uses the system temp dir if None)

    Returns:
        Path to the created staging directory
    """
    if base_dir:
        base_dir.mkdir(parents=True, exist_ok=True)
        return Path(tempfile.mkdtemp(dir=base_dir))
    return Path(tempfile.mkdtemp())


def _materialize_nifti(source: Path | str | bytes | Any, dest: Path) -> None:
    """
    Materialize a NIfTI file to a local path.

    Handles:
    - Local Path: copy
    - URL string: download (not implemented yet, placeholder)
    - bytes: write directly
    - NIfTI object: serialize with nibabel
    """
    if isinstance(source, Path):
        if not source.exists():
            raise MissingInputError(f"Source file does not exist: {source}")
        # Use copy2 to preserve metadata
        shutil.copy2(source, dest)
    elif isinstance(source, str):
        if source.startswith(("http://", "https://")):
            # TODO: Implement download logic (e.g., with requests).
            # Offline tests are not expected to hit this branch.
            raise NotImplementedError("URL download not yet implemented")
        else:
            # Assume a local path string
            src_path = Path(source)
            if not src_path.exists():
                raise MissingInputError(f"Source file does not exist: {source}")
            shutil.copy2(src_path, dest)
    elif isinstance(source, bytes):
        dest.write_bytes(source)
    elif hasattr(source, "to_filename"):
        # nibabel image: serialize directly to the destination path
        source.to_filename(dest)
    elif hasattr(source, "to_bytes"):
        # Bytes-like wrapper (e.g., a lazy object from datasets)
        dest.write_bytes(source.to_bytes())
    else:
        raise MissingInputError(f"Cannot materialize source of type: {type(source)}")
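
Putting loader, adapter, and staging together, a minimal end-to-end sketch (the case index is arbitrary; the dataset is downloaded on first access):

    from stroke_deepisles_demo.data import get_case
    from stroke_deepisles_demo.data.staging import (
        create_staging_directory,
        stage_case_for_deepisles,
    )

    case = get_case(0)                       # CaseFiles for the first case
    staging_root = create_staging_directory()
    staged = stage_case_for_deepisles(case, staging_root, case_id="case_000")
    # staged.input_dir now contains dwi.nii.gz and adc.nii.gz
    # (plus flair.nii.gz when the case provides FLAIR)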
tests/conftest.py
CHANGED

@@ -1,5 +1,91 @@
"""Shared test fixtures."""

from __future__ import annotations

import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

import nibabel as nib
import numpy as np
import pytest

from stroke_deepisles_demo.core.types import CaseFiles

if TYPE_CHECKING:
    from collections.abc import Generator, Iterator


@pytest.fixture
def temp_dir() -> Generator[Path, None, None]:
    """Create a temporary directory for test outputs."""
    with tempfile.TemporaryDirectory() as td:
        yield Path(td)


@pytest.fixture
def synthetic_nifti_3d(temp_dir: Path) -> Path:
    """Create a minimal synthetic 3D NIfTI file."""
    data = np.random.rand(10, 10, 10).astype(np.float32)
    img = nib.Nifti1Image(data, affine=np.eye(4))  # type: ignore
    path = temp_dir / "synthetic.nii.gz"
    nib.save(img, path)  # type: ignore
    return path


@pytest.fixture
def synthetic_case_files(temp_dir: Path) -> CaseFiles:
    """Create a complete set of synthetic case files."""
    # Create DWI
    dwi_data = np.random.rand(64, 64, 30).astype(np.float32)
    dwi_img = nib.Nifti1Image(dwi_data, affine=np.eye(4))  # type: ignore
    dwi_path = temp_dir / "dwi.nii.gz"
    nib.save(dwi_img, dwi_path)  # type: ignore

    # Create ADC
    adc_data = np.random.rand(64, 64, 30).astype(np.float32) * 2000
    adc_img = nib.Nifti1Image(adc_data, affine=np.eye(4))  # type: ignore
    adc_path = temp_dir / "adc.nii.gz"
    nib.save(adc_img, adc_path)  # type: ignore

    # Create mask
    mask_data = (np.random.rand(64, 64, 30) > 0.9).astype(np.uint8)
    mask_img = nib.Nifti1Image(mask_data, affine=np.eye(4))  # type: ignore
    mask_path = temp_dir / "mask.nii.gz"
    nib.save(mask_img, mask_path)  # type: ignore

    return CaseFiles(
        dwi=dwi_path,
        adc=adc_path,
        ground_truth=mask_path,
    )


@pytest.fixture
def mock_hf_dataset(synthetic_case_files: CaseFiles) -> object:
    """Create a mock HF Dataset-like object."""

    # Simple list-based mock that mimics dataset behavior
    class MockDataset:
        def __init__(self) -> None:
            self.data = [
                {
                    "participant_id": "sub-001",
                    "dwi": str(synthetic_case_files["dwi"]),
                    "adc": str(synthetic_case_files["adc"]),
                    "flair": None,
                    "mask": str(synthetic_case_files.get("ground_truth")),
                }
            ]
            self.features = {"dwi": None, "adc": None, "flair": None, "mask": None}

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int) -> dict[str, str | None]:
            return self.data[idx]

        def __iter__(self) -> Iterator[dict[str, str | None]]:
            return iter(self.data)

    return MockDataset()
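
The fixtures deliberately synthesize small NIfTI volumes with nibabel rather than touching the real dataset, so the unit tests stay offline and fast; mock_hf_dataset mirrors just enough of the datasets.Dataset surface (__len__, __iter__, __getitem__, features) for CaseAdapter to work against it.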
tests/data/test_adapter.py
ADDED

@@ -0,0 +1,70 @@
"""Tests for case adapter module."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import pytest

from stroke_deepisles_demo.data.adapter import CaseAdapter

if TYPE_CHECKING:
    from unittest.mock import MagicMock


class TestCaseAdapter:
    """Tests for CaseAdapter."""

    def test_list_case_ids_returns_strings(self, mock_hf_dataset: MagicMock) -> None:
        """list_case_ids returns a list of string identifiers."""
        adapter = CaseAdapter(mock_hf_dataset)
        case_ids = adapter.list_case_ids()

        assert isinstance(case_ids, list)
        assert all(isinstance(cid, str) for cid in case_ids)
        assert case_ids == ["sub-001"]

    def test_len_matches_dataset_size(self, mock_hf_dataset: MagicMock) -> None:
        """len(adapter) equals the number of cases in the dataset."""
        adapter = CaseAdapter(mock_hf_dataset)

        assert len(adapter) == len(mock_hf_dataset)

    def test_get_case_by_string_id(self, mock_hf_dataset: MagicMock) -> None:
        """Can retrieve a case by string identifier."""
        adapter = CaseAdapter(mock_hf_dataset)
        case_ids = adapter.list_case_ids()

        case = adapter.get_case(case_ids[0])

        assert isinstance(case, dict)
        assert "dwi" in case
        assert "adc" in case
        # Paths should be Path objects or convertible strings
        assert isinstance(case["dwi"], (Path, str))

    def test_get_case_by_index(self, mock_hf_dataset: MagicMock) -> None:
        """Can retrieve a case by integer index."""
        adapter = CaseAdapter(mock_hf_dataset)

        case_id, case = adapter.get_case_by_index(0)

        assert isinstance(case_id, str)
        assert case["dwi"] is not None

    def test_get_case_invalid_id_raises(self, mock_hf_dataset: MagicMock) -> None:
        """Raises KeyError for an invalid case ID."""
        adapter = CaseAdapter(mock_hf_dataset)

        with pytest.raises(KeyError):
            adapter.get_case("nonexistent-case-id")

    def test_iteration(self, mock_hf_dataset: MagicMock) -> None:
        """Can iterate over case IDs."""
        adapter = CaseAdapter(mock_hf_dataset)

        case_ids = list(adapter)

        assert len(case_ids) == len(adapter)
tests/data/test_loader.py
ADDED

@@ -0,0 +1,90 @@
"""Tests for data loader module."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.data.loader import (
    DatasetInfo,
    get_dataset_info,
    load_isles_dataset,
)


class TestLoadIslesDataset:
    """Tests for load_isles_dataset."""

    def test_calls_hf_load_dataset(self) -> None:
        """Calls datasets.load_dataset with the correct arguments."""
        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
            mock_load.return_value = MagicMock()

            load_isles_dataset("test/dataset")

            mock_load.assert_called_once()
            call_args = mock_load.call_args
            assert call_args.args[0] == "test/dataset"

    def test_returns_dataset_object(self) -> None:
        """Returns the loaded Dataset object."""
        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
            expected = MagicMock()
            mock_load.return_value = expected

            result = load_isles_dataset()

            assert result is expected

    def test_handles_load_error(self) -> None:
        """Wraps HF errors in DataLoadError."""
        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
            mock_load.side_effect = Exception("Network error")

            with pytest.raises(DataLoadError, match="Network error"):
                load_isles_dataset()


class TestGetDatasetInfo:
    """Tests for get_dataset_info."""

    def test_returns_datasetinfo(self) -> None:
        """Returns DatasetInfo with the expected fields."""
        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
            mock_ds = MagicMock()
            mock_ds.__len__ = MagicMock(return_value=149)
            # Mock info.splits['train'].num_examples
            mock_ds.info.splits.__getitem__.return_value.num_examples = 149
            # Mock features as a dict-like mapping
            mock_ds.features = {"dwi": None, "adc": None, "mask": None}
            mock_load.return_value = mock_ds

            info = get_dataset_info()

            assert isinstance(info, DatasetInfo)
            assert info.num_cases == 149
            assert "dwi" in info.modalities
            assert info.has_ground_truth is True


@pytest.mark.integration
class TestLoadIslesDatasetIntegration:
    """Integration tests that hit the real HuggingFace Hub."""

    @pytest.mark.slow
    def test_load_real_dataset(self) -> None:
        """Actually loads ISLES24-MR-Lite from the HF Hub."""
        # This test requires network access.
        # Run with: pytest -m integration
        # streaming=True avoids downloading the full dataset.
        try:
            dataset = load_isles_dataset(streaming=True)
            assert dataset is not None
            # Getting metadata/features confirms connectivity; iterating
            # could trigger heavy downloads or fail if the dataset is gated.
            assert hasattr(dataset, "features")
            assert len(dataset.features) > 0
        except Exception as e:
            pytest.fail(f"Failed to load real dataset: {e}")
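
The integration class is opt-in: assuming the `integration` and `slow` markers are registered in pyproject.toml, the network-dependent test runs only under an explicit `uv run pytest -m integration`, while the default suite stays offline.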
tests/data/test_staging.py
ADDED

@@ -0,0 +1,77 @@
"""Tests for data staging module."""

from __future__ import annotations

from typing import TYPE_CHECKING

import nibabel as nib
import pytest

from stroke_deepisles_demo.core.exceptions import MissingInputError
from stroke_deepisles_demo.core.types import CaseFiles
from stroke_deepisles_demo.data.staging import (
    create_staging_directory,
    stage_case_for_deepisles,
)

if TYPE_CHECKING:
    from pathlib import Path


class TestCreateStagingDirectory:
    """Tests for create_staging_directory."""

    def test_creates_directory(self, temp_dir: Path) -> None:
        """Staging directory is created and exists."""
        staging = create_staging_directory(base_dir=temp_dir)
        assert staging.exists()
        assert staging.is_dir()

    def test_uses_system_temp_when_no_base(self) -> None:
        """Uses the system temp directory when base_dir is None."""
        staging = create_staging_directory(base_dir=None)
        assert staging.exists()
        # Cleanup
        staging.rmdir()


class TestStageCaseForDeepIsles:
    """Tests for stage_case_for_deepisles."""

    def test_stages_required_files(self, synthetic_case_files: CaseFiles, temp_dir: Path) -> None:
        """DWI and ADC are staged with the correct names."""
        output_dir = temp_dir / "staged"
        staged = stage_case_for_deepisles(synthetic_case_files, output_dir)

        assert staged.dwi_path.name == "dwi.nii.gz"
        assert staged.adc_path.name == "adc.nii.gz"
        assert staged.dwi_path.exists()
        assert staged.adc_path.exists()

    def test_staged_files_are_readable(
        self, synthetic_case_files: CaseFiles, temp_dir: Path
    ) -> None:
        """Staged files can be read back as valid NIfTI."""
        output_dir = temp_dir / "staged"
        staged = stage_case_for_deepisles(synthetic_case_files, output_dir)

        dwi = nib.load(staged.dwi_path)  # type: ignore
        assert dwi.shape == (64, 64, 30)  # type: ignore

    def test_raises_when_dwi_missing(self, temp_dir: Path) -> None:
        """Raises MissingInputError when the DWI source file does not exist."""
        case_files = CaseFiles(
            dwi=temp_dir / "nonexistent.nii.gz",
            adc=temp_dir / "adc.nii.gz",
        )

        with pytest.raises(MissingInputError, match="Source file does not exist"):
            stage_case_for_deepisles(case_files, temp_dir)

    def test_flair_is_optional(self, synthetic_case_files: CaseFiles, temp_dir: Path) -> None:
        """Staging succeeds when FLAIR is absent."""
        output_dir = temp_dir / "staged"
        staged = stage_case_for_deepisles(synthetic_case_files, output_dir)

        assert staged.flair_path is None
|