feat: Phase 1A + Phase 2 - Local data loader and DeepISLES Docker wrapper (#3)
## Summary
- **Phase 1A**: Implement local file loader for ISLES24-MR-Lite dataset (149 cases)
- **Phase 2**: Implement DeepISLES Docker wrapper with GPU support
## Changes
- Add `LocalDataset` dataclass for file-based dataset access
- Add BIDS filename parsing (`parse_subject_id`)
- Add Docker utilities (`run_container`, `build_docker_command`, GPU detection)
- Add DeepISLES wrapper (`run_deepisles_on_folder`, `validate_input_folder`); a usage sketch follows this list
- 52 unit tests, mypy strict, ruff clean
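
A usage sketch of the new surface area (a minimal sketch only — signatures are inferred from the specs in this PR, and exact names/arguments in the merged code may differ):

```python
from pathlib import Path

from stroke_deepisles_demo.data.loader import load_isles_dataset
from stroke_deepisles_demo.inference.deepisles import (
    run_deepisles_on_folder,
    validate_input_folder,
)

# Phase 1A: file-based access to the extracted ISLES24 cases
ds = load_isles_dataset(Path("data/scratch/isles24_extracted"))
case = ds.get_case("sub-stroke0005")

# Phase 2: run the DeepISLES container on a staged input folder
input_dir = Path("staging/sub-stroke0005")  # hypothetical folder holding dwi/adc files
validate_input_folder(input_dir)
result = run_deepisles_on_folder(input_dir)  # argument shape assumed, not verified
```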
## CodeRabbit Feedback Addressed
- Made `inspect_isles24.py` executable
- Fixed Windows compatibility in `match_user` logic
- .gitignore +4 -0
- data/scratch/.gitkeep +0 -0
- docs/specs/00-context.md +41 -11
- docs/specs/02-phase-1-data-access.md +270 -550
- docs/specs/data-discovery.md +67 -0
- pyproject.toml +1 -0
- scripts/discovery/__init__.py +0 -0
- scripts/discovery/inspect_isles24.py +267 -0
- src/stroke_deepisles_demo/core/exceptions.py +4 -0
- src/stroke_deepisles_demo/data/__init__.py +6 -14
- src/stroke_deepisles_demo/data/adapter.py +63 -126
- src/stroke_deepisles_demo/data/loader.py +24 -115
- src/stroke_deepisles_demo/inference/__init__.py +37 -1
- src/stroke_deepisles_demo/inference/deepisles.py +193 -0
- src/stroke_deepisles_demo/inference/docker.py +258 -0
- tests/conftest.py +44 -28
- tests/data/test_adapter.py +69 -45
- tests/data/test_integration_real_data.py +42 -0
- tests/data/test_loader.py +20 -77
- tests/inference/__init__.py +0 -0
- tests/inference/test_deepisles.py +285 -0
- tests/inference/test_docker.py +202 -0
.gitignore
CHANGED

@@ -205,3 +205,7 @@ cython_debug/
 marimo/_static/
 marimo/_lsp/
 __marimo__/
+
+# Data Discovery (per docs/specs/data-discovery.md)
+data/scratch/*
+!data/scratch/.gitkeep
data/scratch/.gitkeep
ADDED

File without changes
docs/specs/00-context.md
CHANGED

@@ -11,19 +11,38 @@ This document explains **why** we're building `stroke-deepisles-demo` and the ar

 We want to demonstrate an end-to-end neuroimaging inference pipeline:

 ```
-[…]
+CURRENT (Phase 1A):
+Local NIfTI files (extracted from ISLES24-MR-Lite ZIPs)
+    ↓
+File-based loader (parse BIDS filenames)
+    ↓
+DeepISLES Docker (stroke segmentation)
+    ↓
+NiiVue visualization (Gradio Space)
+
+FUTURE (Phase 1C-D):
+HuggingFace Hub (properly uploaded dataset)
+    ↓
+Tobias's datasets fork (BIDS loader + Nifti feature)
+    ↓
+DeepISLES Docker (stroke segmentation)
+    ↓
+NiiVue visualization (Gradio Space)
 ```

 This showcases that:
-1. Neuroimaging data can be […]
-2. […]
-3. […]
+1. Neuroimaging data can be loaded from local BIDS-named files (NOW)
+2. Neuroimaging data can be consumed from HF Hub with proper BIDS/NIfTI support (FUTURE)
+3. Clinical-grade models can run via Docker as black boxes
+4. Results can be visualized interactively in a browser
+
+## critical discovery (2025-12-04)
+
+**The original ISLES24-MR-Lite dataset is NOT properly uploaded to HuggingFace.**
+
+It's just raw ZIP files dumped on HF, not a proper Dataset with parquet/Arrow format. This means `load_dataset()` fails. See `data/scratch/isles24_schema_report.txt` for full details.
+
+**Workaround**: We extracted the ZIPs locally to `data/scratch/isles24_extracted/` (git-ignored) and will implement a file-based loader first. Later, we'll re-upload properly and verify full HF consumption.

 ## why we need tobias's datasets fork

@@ -55,11 +74,22 @@ We pin to this branch until upstream merges the PRs.

 ### 1. data source: ISLES24-MR-Lite

-- **HF Dataset**: [YongchengYAO/ISLES24-MR-Lite](https://huggingface.co/datasets/YongchengYAO/ISLES24-MR-Lite)
+- **HF Dataset**: [YongchengYAO/ISLES24-MR-Lite](https://huggingface.co/datasets/YongchengYAO/ISLES24-MR-Lite) (**BROKEN** - raw ZIPs, not proper dataset)
+- **Local extracted**: `data/scratch/isles24_extracted/` (git-ignored)
 - **Content**: 149 acute stroke MRI cases with DWI, ADC, and manual infarct masks
 - **Origin**: Subset of ISLES 2024 challenge data
 - **Why suitable**: DeepISLES was trained on ISLES 2022, so ISLES24 is an **external** test set (no data leakage)

+**File structure** (after extraction):
+```
+data/scratch/isles24_extracted/
+├── Images-DWI/sub-stroke{XXXX}_ses-02_dwi.nii.gz      # 149 files
+├── Images-ADC/sub-stroke{XXXX}_ses-02_adc.nii.gz      # 149 files
+└── Masks/sub-stroke{XXXX}_ses-02_lesion-msk.nii.gz    # 149 files
+```
+
+**Schema reference**: `data/scratch/isles24_schema_report.txt`
+
 ### 2. model: DeepISLES

 - **Paper**: Nature Communications 2025 - "DeepISLES: A clinically validated ischemic stroke segmentation model"
docs/specs/02-phase-1-data-access.md
CHANGED

@@ -1,695 +1,415 @@

Removed (old HF-first spec, abridged):

-# phase 1: data access
-
-Implement […]
-
-## […]
-
-- [ ] `src/stroke_deepisles_demo/data/adapter.py` - Case adapter for file access
-- [ ] `src/stroke_deepisles_demo/data/staging.py` - Stage files for DeepISLES
-- [ ] Unit tests with fixtures (no network required)
-- [ ] Integration test (marked, requires network)
-
-[usage example]
-case_ids = list_case_ids()
-print(f"Found {len(case_ids)} cases")
-case = get_case("sub-001")
-print(f"DWI: {case.dwi}")
-print(f"ADC: {case.adc}")
-print(f"Ground truth: {case.ground_truth}")
-
-[expected directory tree — entries not rendered]
-
-[old `data/loader.py`]
-"""Load ISLES24 […]"""
-
-def load_isles_dataset(
-    […],
-    *,
-    […],
-) -> Dataset:
-    """
-    Load […]
-
-    Args:
-        streaming: If True, use streaming mode (lazy loading)
-
-    Raises:
-        DataLoadError: If […]
-    """
-
-@dataclass
-class DatasetInfo:
-    """Metadata about the loaded dataset."""
-
-    dataset_id: str
-    num_cases: int
-    modalities: list[str]  # e.g., ["dwi", "adc", "mask"]
-    has_ground_truth: bool
-
-[old `data/adapter.py` — a `CaseAdapter` wrapping an HF `Dataset`]
-    """
-    […] This handles the mapping between HF dataset structure and our
-    internal CaseFiles type.
-    """
-
-    def __init__(self, dataset: Dataset) -> None:
-        """Initialize adapter with a loaded dataset."""
-
-    def __len__(self) -> int: ...
-    def __iter__(self) -> Iterator[str]: ...
-
-    def list_case_ids(self) -> list[str]:
-        """Returns: List of case IDs (e.g., ["sub-001", "sub-002", ...])"""
-
-    def get_case(self, case_id: str | int) -> CaseFiles:
-        """
-        Args:
-            case_id: Either a string ID (e.g., "sub-001") or integer index
-
-        Raises:
-            KeyError: If case_id not found
-            DataLoadError: If files cannot be accessed
-        """
-
-    def […](self) […]:
-        """Returns: Tuple of (case_id, CaseFiles)"""
-
-### `data/staging.py`
-
-"""Stage NIfTI files with DeepISLES-expected naming."""
-
-from __future__ import annotations
-
-from typing import NamedTuple
-
-from stroke_deepisles_demo.core.types import CaseFiles
-
-
-class StagedCase(NamedTuple):
-    """Paths to staged files ready for DeepISLES."""
-
-    input_dir: Path          # Directory containing staged files
-    dwi_path: Path           # Path to dwi.nii.gz
-    adc_path: Path           # Path to adc.nii.gz
-    flair_path: Path | None  # Path to flair.nii.gz if available
-
-
-def stage_case_for_deepisles(
-    case_files: CaseFiles,
-    output_dir: Path,
-    *,
-    case_id: str | None = None,
-) -> StagedCase:
-    """
-    […]
-    - dwi.nii.gz
-    - adc.nii.gz
-    - flair.nii.gz (optional)
-
-    This function copies/symlinks the source files to a staging directory
-    with the correct names.
-
-    Args:
-        case_files: Source file paths from CaseAdapter
-        output_dir: Directory to stage files into
-        case_id: Optional case ID for logging/subdirectory
-
-    Returns:
-        StagedCase with paths to staged files
-
-    Raises:
-        MissingInputError: If required files (DWI, ADC) are missing
-        OSError: If file operations fail
-    """
-    ...
-
-### `data/__init__.py` (public API)
-
-"""Data loading and case management for stroke-deepisles-demo."""
-
-from stroke_deepisles_demo.data.adapter import CaseAdapter
-from stroke_deepisles_demo.data.loader import DatasetInfo, get_dataset_info, load_isles_dataset
-from stroke_deepisles_demo.data.staging import StagedCase, stage_case_for_deepisles
-
-__all__ = [
-    # Loader
-    "load_isles_dataset",
-    "get_dataset_info",
-    "DatasetInfo",
-    # Adapter
-    "CaseAdapter",
-    # Staging
-    "stage_case_for_deepisles",
-    "StagedCase",
-]
-
-# Convenience functions (combine loader + adapter)
-def get_case(case_id: str | int) -> CaseFiles:
-    """Load a single case by ID or index."""
-    ...
-
-def […]:
-    """List all available case IDs."""
-    ...
-
-## tdd plan
-
-### test file structure
-
-tests/
-├── conftest.py              # Shared fixtures
-├── data/
-│   ├── __init__.py
-│   ├── test_loader.py       # Tests for HF loading
-│   ├── test_adapter.py      # Tests for case adapter
-│   └── test_staging.py      # Tests for file staging
-└── fixtures/
-    └── nifti/               # Minimal synthetic NIfTI files
-        ├── dwi.nii.gz
-        ├── adc.nii.gz
-        └── mask.nii.gz
-
-[old `tests/conftest.py`]
-"""Shared test fixtures."""
-import tempfile
-import nibabel as nib
-import numpy as np
-import pytest
-
-@pytest.fixture
-def temp_dir […]
-
-[synthetic NIfTI fixture]
-    """Create a minimal synthetic 3D NIfTI file."""
-    data = np.random.rand(10, 10, 10).astype(np.float32)
-    img = nib.Nifti1Image(data, affine=np.eye(4))
-    path = temp_dir / "synthetic.nii.gz"
-    nib.save(img, path)
-    return path
-
-[`synthetic_case_files` fixture: builds 64x64x30 DWI/ADC volumes plus a binary
-mask and returns CaseFiles(dwi=dwi_path, adc=adc_path, flair=None, ground_truth=mask_path)]
-
-def mock_hf_dataset(synthetic_case_files: CaseFiles):
-    """Create a mock HF Dataset-like object."""
-    # Returns a simple dict-based mock that mimics Dataset behavior
-    ...
-
-[old `tests/data/test_staging.py`]
-class TestCreateStagingDirectory:
-    def test_creates_directory            # staging dir exists and is a dir
-    def test_uses_system_temp_when_no_base
-
-class TestStageCaseForDeepIsles:
-    def test_stages_required_files        # dwi.nii.gz / adc.nii.gz staged and exist
-    def test_staged_files_are_readable    # nib.load(staged.dwi_path).shape == (64, 64, 30)
-    def test_raises_when_dwi_missing      # expects MissingInputError
-    def test_flair_is_optional            # staging succeeds with flair=None
-
-[old `tests/data/test_adapter.py`]
-class TestCaseAdapter:
-    def test_list_case_ids_returns_strings
-    def test_len_matches_dataset_size
-    def test_get_case_by_string_id / by integer index
-    def test_[…]                          # invalid ID raises KeyError
-    def test_iteration                    # list(adapter) has len(adapter) entries
-
-[old `tests/data/test_loader.py`]
-class TestLoadIslesDataset:               # patches stroke_deepisles_demo.data.loader.load_dataset
-    # "Calls datasets.load_dataset with correct arguments"
-    # "Returns the loaded Dataset object"
-    # "Wraps HF errors in DataLoadError"
-    # "Returns DatasetInfo with expected fields" (num_cases == 149)
-
-class TestLoadIslesDatasetIntegration:
-    """Integration tests that hit the real HuggingFace Hub."""
-    # Run with: pytest -m integration
-    dataset = load_isles_dataset(streaming=True)
-
-### […]
-
-- `datasets.load_dataset` - Mock for unit tests, real for integration tests
-- `huggingface_hub` calls - Mock for unit tests
-- File system operations - Use `temp_dir` fixture with real files
-
-### what to test for real
-
-- NIfTI file creation/reading with nibabel
-- File staging (copy/symlink operations)
-- Integration test: actual HF Hub download (marked `@pytest.mark.integration`)
-
-Phase 1 is complete when:
-
-1. All unit tests pass: `uv run pytest tests/data/ -v`
-2. Can load synthetic test cases without network
-3. Can list case IDs from mock dataset
-4. Can stage files with correct DeepISLES naming
-5. Integration test passes (with network): `uv run pytest -m integration`
-6. Type checking passes: `uv run mypy src/stroke_deepisles_demo/data/`
-7. Code coverage for data module > 80%
-
-## implementation notes
-
-- ISLES24-MR-Lite structure needs investigation - check HF page for exact column names
-- Consider using `huggingface_hub.snapshot_download` if `datasets.load_dataset` has issues with NIfTI
-- Staging can use symlinks on Unix, copies on Windows
-- Cache the HF dataset locally to avoid repeated downloads
-
-### critical: streaming mode + docker materialization
-
-**Reviewer feedback (valid)**: When using `streaming=True`, the dataset returns URLs or lazy file objects, NOT local POSIX paths. Docker requires physical files on the host disk for volume mounting.
-
-**Solution**: The `stage_case_for_deepisles` function MUST handle materialization:
-
-def stage_case_for_deepisles(
-    case_files: CaseFiles,
-    output_dir: Path,
-    *,
-    case_id: str | None = None,
-) -> StagedCase:
-    """
-    […]
-    """
-    […]
-    adc_staged = output_dir / "adc.nii.gz"
-    _materialize_nifti(case_files["adc"], adc_staged)
-    […]
-
-def […]:
-    """
-    Materialize a NIfTI file to a local path.
-
-    - Local Path: copy or symlink
-    - URL string: download
-    - bytes: write directly
-    - NIfTI object: serialize with nibabel
-    """
-    if isinstance(source, Path) and source.exists():
-        # Local file - symlink if possible, copy otherwise
-        shutil.copy2(source, dest)
-    elif isinstance(source, str) and source.startswith(("http://", "https://")):
-        # URL - download
-        _download_file(source, dest)
-    elif isinstance(source, bytes):
-        # Raw bytes
-        dest.write_bytes(source)
-    elif hasattr(source, "to_bytes"):
-        # NIfTI object (nibabel or wrapper)
-        dest.write_bytes(source.to_bytes())
-    else:
-        raise MissingInputError(f"Cannot materialize source: {type(source)}")
-
-## […]
-
-- `[…]`
-- `[…]`
-- `numpy`
Added (new spec):

+# phase 1: data access layer
+
+## purpose
+
+Implement a data loading layer that provides typed access to ISLES24 neuroimaging cases. This phase is split into sub-phases due to a critical discovery: the upstream dataset is not properly formatted for HuggingFace consumption.
+
+## critical discovery (2025-12-04)
+
+**`YongchengYAO/ISLES24-MR-Lite` is NOT a proper HuggingFace Dataset.**
+
+| What we expected | What actually exists |
+|------------------|---------------------|
+| `load_dataset()` returns Dataset with columns | `load_dataset()` FAILS with "no data" |
+| Columns: `dwi`, `adc`, `mask`, `participant_id` | No columns - just raw ZIP files |
+| Parquet/Arrow format | Three ZIP archives dumped on HF |
+
+**Evidence**: `data/scratch/isles24_schema_report.txt`
+
+This means the demo must be built in phases:
+1. **Phase 1A**: Local file loader (works NOW with extracted data)
+2. **Phase 1B**: Test Tobias's `Nifti()` feature on local files (proves loading works)
+3. **Phase 1C**: Upload properly to HuggingFace (future - proves production pipeline)
+4. **Phase 1D**: Consume via Tobias's fork (future - proves full round-trip)
+
+---
+
+## phase 1a: local file loader (CURRENT PRIORITY)
+
+### data location
+
+```
+data/scratch/isles24_extracted/                  # Git-ignored
+├── Images-DWI/                                  # 149 files
+│   └── sub-stroke{XXXX}_ses-02_dwi.nii.gz
+├── Images-ADC/                                  # 149 files
+│   └── sub-stroke{XXXX}_ses-02_adc.nii.gz
+└── Masks/                                       # 149 files
+    └── sub-stroke{XXXX}_ses-02_lesion-msk.nii.gz
+```
+
+### file naming convention (BIDS-like)
+
+| Component | Pattern | Example |
+|-----------|---------|---------|
+| Subject ID | `sub-stroke{XXXX}` | `sub-stroke0005` |
+| Session | `ses-02` | Always "02" in this dataset |
+| Modality | `dwi`, `adc`, `lesion-msk` | - |
+| Extension | `.nii.gz` | Compressed NIfTI |
+
+**Subject ID regex**: `sub-stroke(\d{4})_ses-02_.*\.nii\.gz`
+
+**Note**: Subject IDs have gaps (e.g., 0018 missing). Range is 0001-0189, total 149 cases.
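
A quick sanity check of the documented pattern (a throwaway snippet assuming only the stdlib `re` module; not part of the spec file):

```python
import re

SUBJECT_PATTERN = re.compile(r"sub-stroke(\d{4})_ses-02_.*\.nii\.gz")

for name in [
    "sub-stroke0005_ses-02_dwi.nii.gz",
    "sub-stroke0189_ses-02_lesion-msk.nii.gz",
    "sub-stroke0018_ses-01_dwi.nii.gz",  # wrong session -> no match
]:
    m = SUBJECT_PATTERN.match(name)
    print(name, "->", m.group(1) if m else None)
# sub-stroke0005_ses-02_dwi.nii.gz -> 0005
# sub-stroke0189_ses-02_lesion-msk.nii.gz -> 0189
# sub-stroke0018_ses-01_dwi.nii.gz -> None
```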
+### deliverables
+
+- [ ] `src/stroke_deepisles_demo/data/loader.py` - Rewrite with local mode
+- [ ] `src/stroke_deepisles_demo/data/adapter.py` - Rewrite for file-based access
+- [ ] `src/stroke_deepisles_demo/data/staging.py` - Already correct, no changes
+- [ ] Unit tests with synthetic fixtures
+- [ ] Integration test with actual extracted data
+
+### interfaces
+
+#### `data/loader.py`
+
+```python
+"""Load ISLES24 data from local directory or HuggingFace Hub."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from stroke_deepisles_demo.data.adapter import LocalDataset
+
+
+@dataclass
+class DatasetInfo:
+    """Metadata about the dataset."""
+
+    source: str  # "local" or HF dataset ID
+    num_cases: int
+    modalities: list[str]
+    has_ground_truth: bool
+
+
+def load_isles_dataset(
+    source: str | Path = "data/scratch/isles24_extracted",
+    *,
+    local_mode: bool = True,  # Default to local for now
+) -> LocalDataset:
+    """
+    Load ISLES24 dataset.
+
+    Args:
+        source: Local directory path or HuggingFace dataset ID
+        local_mode: If True, treat source as local directory
+
+    Returns:
+        Dataset-like object providing case access
+
+    Raises:
+        DataLoadError: If data cannot be loaded
+    """
+    if local_mode or isinstance(source, Path):
+        return _load_from_local_directory(Path(source))
+    # Future: return _load_from_huggingface(source)
+    raise NotImplementedError("HuggingFace mode not yet implemented")
+
+
+def _load_from_local_directory(data_dir: Path) -> LocalDataset:
+    """
+    Load cases from extracted local files.
+
+    Expects structure:
+        data_dir/
+        ├── Images-DWI/sub-stroke{XXXX}_ses-02_dwi.nii.gz
+        ├── Images-ADC/sub-stroke{XXXX}_ses-02_adc.nii.gz
+        └── Masks/sub-stroke{XXXX}_ses-02_lesion-msk.nii.gz
+    """
+    ...
+```
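
For orientation, the intended call-site shape (a sketch, assuming the defaults above):

```python
ds = load_isles_dataset()      # scans data/scratch/isles24_extracted by default
print(len(ds), "cases found")  # expected: 149 once all cases are extracted
```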
+
+#### `data/adapter.py`
+
+```python
+"""Provide typed access to ISLES24 cases."""
+
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterator
+
+from stroke_deepisles_demo.core.types import CaseFiles
+
+
+@dataclass
+class LocalDataset:
+    """File-based dataset for local ISLES24 data."""
+
+    data_dir: Path
+    cases: dict[str, CaseFiles]  # subject_id -> files
+
+    def __len__(self) -> int:
+        return len(self.cases)
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self.cases.keys())
+
+    def list_case_ids(self) -> list[str]:
+        """Return sorted list of subject IDs."""
+        return sorted(self.cases.keys())
+
+    def get_case(self, case_id: str | int) -> CaseFiles:
+        """Get files for a case by ID or index."""
+        if isinstance(case_id, int):
+            case_id = self.list_case_ids()[case_id]
+        return self.cases[case_id]
+
+
+# Subject ID extraction
+SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")
+
+
+def parse_subject_id(filename: str) -> str | None:
+    """Extract subject ID from BIDS filename."""
+    match = SUBJECT_PATTERN.match(filename)
+    return f"sub-{match.group(1)}" if match else None
+
+
+def build_local_dataset(data_dir: Path) -> LocalDataset:
+    """
+    Scan directory and build case mapping.
+
+    Matches DWI + ADC + Mask files by subject ID.
+    """
+    dwi_dir = data_dir / "Images-DWI"
+    adc_dir = data_dir / "Images-ADC"
+    mask_dir = data_dir / "Masks"
+
+    cases: dict[str, CaseFiles] = {}
+
+    # Scan DWI files to get subject IDs
+    for dwi_file in dwi_dir.glob("*.nii.gz"):
+        subject_id = parse_subject_id(dwi_file.name)
+        if not subject_id:
+            continue
+
+        # Find matching ADC and Mask
+        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
+        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")
+
+        if not adc_file.exists():
+            continue  # Skip incomplete cases
+
+        cases[subject_id] = CaseFiles(
+            dwi=dwi_file,
+            adc=adc_file,
+            ground_truth=mask_file if mask_file.exists() else None,
+        )
+
+    return LocalDataset(data_dir=data_dir, cases=cases)
+```
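
A sketch of how these pieces compose end to end (paths from the data-location section above; printed values are illustrative):

```python
from pathlib import Path

ds = build_local_dataset(Path("data/scratch/isles24_extracted"))
print(ds.list_case_ids()[:2])  # e.g. ['sub-stroke0001', 'sub-stroke0003'] (IDs have gaps)

case = ds.get_case(0)          # integer index resolves through the sorted ID list
print(case.dwi.name)           # e.g. sub-stroke0001_ses-02_dwi.nii.gz
```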
+
+### synthetic fixture structure
+
+Unit tests MUST use fixtures that replicate the **exact** directory structure. Add to `tests/conftest.py`:
+
+```python
+@pytest.fixture
+def synthetic_isles_dir(temp_dir: Path) -> Path:
+    """
+    Create synthetic ISLES24-like directory structure.
+
+    Structure:
+        temp_dir/
+        ├── Images-DWI/
+        │   ├── sub-stroke0001_ses-02_dwi.nii.gz
+        │   └── sub-stroke0002_ses-02_dwi.nii.gz
+        ├── Images-ADC/
+        │   ├── sub-stroke0001_ses-02_adc.nii.gz
+        │   └── sub-stroke0002_ses-02_adc.nii.gz
+        └── Masks/
+            ├── sub-stroke0001_ses-02_lesion-msk.nii.gz
+            └── sub-stroke0002_ses-02_lesion-msk.nii.gz
+    """
+    dwi_dir = temp_dir / "Images-DWI"
+    adc_dir = temp_dir / "Images-ADC"
+    mask_dir = temp_dir / "Masks"
+
+    dwi_dir.mkdir()
+    adc_dir.mkdir()
+    mask_dir.mkdir()
+
+    for subject_num in [1, 2]:
+        subject_id = f"sub-stroke{subject_num:04d}"
+
+        # Create DWI
+        dwi_data = np.random.rand(10, 10, 5).astype(np.float32)
+        dwi_img = nib.Nifti1Image(dwi_data, affine=np.eye(4))
+        nib.save(dwi_img, dwi_dir / f"{subject_id}_ses-02_dwi.nii.gz")
+
+        # Create ADC
+        adc_data = np.random.rand(10, 10, 5).astype(np.float32) * 2000
+        adc_img = nib.Nifti1Image(adc_data, affine=np.eye(4))
+        nib.save(adc_img, adc_dir / f"{subject_id}_ses-02_adc.nii.gz")
+
+        # Create Mask
+        mask_data = (np.random.rand(10, 10, 5) > 0.9).astype(np.uint8)
+        mask_img = nib.Nifti1Image(mask_data, affine=np.eye(4))
+        nib.save(mask_img, mask_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz")
+
+    return temp_dir
+```
+
+### tdd plan
+
+```python
+# tests/data/test_loader.py
+
+def test_load_from_local_returns_local_dataset(synthetic_isles_dir):
+    """Local mode returns LocalDataset."""
+    ...
+
+def test_load_from_local_finds_all_cases(synthetic_isles_dir):
+    """Finds all cases in synthetic structure."""
+    ...
+
+# tests/data/test_adapter.py
+
+def test_parse_subject_id_extracts_correctly():
+    """Extracts subject ID from BIDS filename."""
+    assert parse_subject_id("sub-stroke0005_ses-02_dwi.nii.gz") == "sub-stroke0005"
+
+def test_build_local_dataset_matches_files(synthetic_isles_dir):
+    """Matches DWI, ADC, Mask by subject ID."""
+    ...
+
+def test_get_case_returns_case_files(synthetic_isles_dir):
+    """get_case returns CaseFiles with correct paths."""
+    ...
+```
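
One stub filled in for illustration, against the `synthetic_isles_dir` fixture above (a sketch; the import path is assumed and the real tests may assert more):

```python
from pathlib import Path

from stroke_deepisles_demo.data.loader import load_isles_dataset


def test_load_from_local_finds_all_cases(synthetic_isles_dir: Path) -> None:
    """Finds all cases in synthetic structure."""
    ds = load_isles_dataset(synthetic_isles_dir)

    assert len(ds) == 2
    assert ds.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
```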
+
+### done criteria (phase 1a)
+
+- [ ] `uv run pytest tests/data/ -v` passes
+- [ ] Can load all 149 cases from `data/scratch/isles24_extracted/`
+- [ ] `list_case_ids()` returns 149 subject IDs
+- [ ] `get_case("sub-stroke0005")` returns valid CaseFiles
+- [ ] Type checking passes: `uv run mypy src/stroke_deepisles_demo/data/`
+
+---
+
+## phase 1b: test tobias's nifti feature (NEXT)
+
+### purpose
+
+Verify that Tobias's `Nifti()` feature type from the datasets fork can correctly load/parse NIfTI files. This proves the **loading** part of the consumption pipeline works, even though the **download** part is broken.
+
+### approach
+
+```python
+# Test script to verify Nifti() feature works on local files
+from datasets import Features, Value
+from datasets.features import Nifti  # From Tobias's fork
+
+# Create a simple dataset from local files
+features = Features({
+    "subject_id": Value("string"),
+    "dwi": Nifti(),
+    "adc": Nifti(),
+    "mask": Nifti(),
+})
+
+# Load a single case and verify Nifti() decodes correctly
+```
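
One way the snippet could continue — assuming (unverified, and exactly what this phase is meant to establish) that the fork's `Nifti()` decodes path strings the way upstream `datasets` does for `Image()`:

```python
from datasets import Dataset

root = "data/scratch/isles24_extracted"
ds = Dataset.from_dict({
    "subject_id": ["sub-stroke0005"],
    "dwi": [f"{root}/Images-DWI/sub-stroke0005_ses-02_dwi.nii.gz"],
    "adc": [f"{root}/Images-ADC/sub-stroke0005_ses-02_adc.nii.gz"],
    "mask": [f"{root}/Masks/sub-stroke0005_ses-02_lesion-msk.nii.gz"],
}).cast(features)  # `features` from the block above

row = ds[0]
print(type(row["dwi"]))  # expect a nibabel-like image if decoding works
print(row["dwi"].shape)  # and a plausible 3D shape
```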
+
+### done criteria (phase 1b)
+
+- [ ] Tobias's `Nifti()` feature loads local `.nii.gz` files
+- [ ] Decoded NIfTI has correct shape/dtype
+- [ ] Can access voxel data via nibabel-like interface
+
+---
+
+## phase 1c: proper huggingface upload (FUTURE)
+
+### purpose
+
+Re-upload ISLES24 data to HuggingFace **properly** using the arc-aphasia-bids approach. This proves the **production** pipeline works.
+
+### approach
+
+1. Use BIDS loader from Tobias's fork
+2. Create proper parquet schema with columns:
+   - `subject`: string
+   - `session`: string
+   - `dwi`: Nifti()
+   - `adc`: Nifti()
+   - `mask`: Nifti()
+3. Upload to new HuggingFace repo (e.g., `The-Obstacle-Is-The-Way/ISLES24-BIDS`)
+
+### done criteria (phase 1c)
+
+- [ ] Dataset uploaded to HuggingFace with proper schema
+- [ ] HuggingFace dataset viewer shows data correctly
+- [ ] `load_dataset("new-repo-id")` returns Dataset with expected columns
+
+---
+
+## phase 1d: consumption verification (FUTURE)
+
+### purpose
+
+Verify the full round-trip: Download from HuggingFace using Tobias's fork.
+
+### approach
+
+```python
+from datasets import load_dataset
+
+# This should work after Phase 1C
+ds = load_dataset("The-Obstacle-Is-The-Way/ISLES24-BIDS")
+case = ds["train"][0]
+print(case["dwi"].shape)  # Should work!
+```
+
+### new adapter function
+
+When Phase 1D is implemented, `adapter.py` will need a new function alongside `build_local_dataset`:
+
+```python
+def adapt_hf_case(hf_row: dict) -> CaseFiles:
+    """
+    Adapt a HuggingFace Dataset row to CaseFiles.
+
+    Args:
+        hf_row: Row from load_dataset() with columns:
+            - dwi: Nifti feature (nibabel-like object)
+            - adc: Nifti feature
+            - mask: Nifti feature
+            - subject: str
+
+    Returns:
+        CaseFiles with materialized paths or nibabel objects
+    """
+    # Implementation depends on how Nifti() feature exposes data
+    # May need to write to temp files or pass nibabel objects directly
+    ...
+```
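
A sketch of the materialization branch, assuming the `Nifti()` feature yields a nibabel image (if it yields paths instead, they can pass through unchanged); `_to_path` is a hypothetical helper:

```python
from pathlib import Path

import nibabel as nib


def _to_path(obj: object, name: str, tmp: Path) -> Path:
    """Write a decoded nibabel image to disk; pass existing paths through."""
    if isinstance(obj, (str, Path)):
        return Path(obj)
    out = tmp / name
    nib.save(obj, out)  # assumes a nibabel SpatialImage
    return out
```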
+
+This maintains the same `CaseFiles` contract for downstream phases regardless of data source.
+
+### done criteria (phase 1d)
+
+- [ ] `load_dataset()` works on properly uploaded dataset
+- [ ] `adapt_hf_case()` function converts HF rows to CaseFiles
+- [ ] Full demo runs with HuggingFace consumption (not just local files)
+- [ ] Documents the pitfall for future projects
+
+---
+
+## dependencies
+
+No new dependencies needed beyond Phase 0.
+
+## notes
+
+- The original `adapter.py` assumed HF Dataset with columns - COMPLETELY WRONG
+- The original `loader.py` called `load_dataset()` directly - FAILS on this dataset
+- `staging.py` is still correct - it just needs `CaseFiles` with paths
docs/specs/data-discovery.md
ADDED

@@ -0,0 +1,67 @@

# data discovery & verification protocol

## purpose
To establish a rigorous, reproducible process for exploring, verifying, and documenting external data sources (Hugging Face Datasets, BIDS repos, etc.) before integrating them into the production codebase. This prevents "schema guessing" and ensures strict typing aligns with reality.

## principles
1. **No Assumptions**: Never assume column names, file formats, or data types. Verify them programmatically.
2. **Isolation**: Discovery scripts and their outputs must be isolated from production code and source control.
3. **Reproducibility**: The discovery process must be scriptable and reproducible, not a series of manual CLI commands.

## standard locations

### scripts
All discovery logic resides in:
```
scripts/discovery/
├── __init__.py
├── inspect_hf_dataset.py   # e.g., Generic HF inspector
├── verify_bids_layout.py   # e.g., BIDS validator
└── ...
```

### data & artifacts
All downloaded samples, temporary outputs, and schema reports reside in:
```
data/scratch/
├── .gitkeep              # Tracked
├── schema_report.txt     # Generated report
└── samples/              # Raw data samples (IGNORED)
```

## discovery workflow

### 1. implementation
Write a focused script in `scripts/discovery/` that:
- Connects to the data source (e.g., HF Hub).
- Fetches *metadata* or a *minimal sample* (streaming mode preferred).
- Prints/Logs:
  - Feature keys (column names).
  - Data types (Arrow types, Python types).
  - Non-null counts (if feasible).
  - A sample row structure.

### 2. execution
Run the script from the project root:
```bash
uv run scripts/discovery/inspect_hf_dataset.py > data/scratch/schema_report.txt
```

### 3. verification
Manually review `data/scratch/schema_report.txt`.
- **Check**: Do column names match `CaseAdapter` expectations?
- **Check**: Are file paths strings or objects?
- **Check**: Are required fields (DWI, ADC) actually present?

### 4. remediation
If the report contradicts the code/specs:
1. Update the spec (`docs/specs/`) to reflect reality.
2. Update the code (`src/.../adapter.py`) to handle the actual schema.
3. Add a regression test if the edge case is complex.

## git configuration
Ensure `.gitignore` includes:
```gitignore
data/scratch/*
!data/scratch/.gitkeep
```
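
A minimal version of the generic inspector the protocol names (`inspect_hf_dataset.py` is the spec's placeholder name; this sketch assumes `load_dataset` can stream the target repo and only glances at the first split):

```python
#!/usr/bin/env python3
"""Minimal generic HF dataset inspector (sketch)."""
from __future__ import annotations

import sys

from datasets import load_dataset


def main(dataset_id: str) -> int:
    ds = load_dataset(dataset_id, streaming=True)
    for split_name, split in ds.items():
        print(f"SPLIT: {split_name}")
        print(f"  features: {split.features}")
        row = next(iter(split), None)
        if row is not None:
            for key, val in row.items():
                print(f"  {key}: {type(val).__name__}")
        break  # the first split is enough for a schema glance
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1]))
```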
pyproject.toml
CHANGED

@@ -118,6 +118,7 @@ addopts = [
     "-v",
     "--tb=short",
     "--strict-markers",
+    "-m", "not integration",  # Skip integration tests by default
 ]
 markers = [
     "integration: marks tests requiring external resources (Docker, network)",
scripts/discovery/__init__.py
ADDED

File without changes
scripts/discovery/inspect_isles24.py
ADDED

@@ -0,0 +1,267 @@

#!/usr/bin/env python3
"""
ISLES24-MR-Lite Dataset Discovery Script

Downloads and inspects the full YongchengYAO/ISLES24-MR-Lite dataset
to document its exact schema before building adapters.

Per: docs/specs/data-discovery.md

Output: data/scratch/isles24_schema_report.txt
"""

from __future__ import annotations

import sys
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Any

# Constants
DATASET_ID = "YongchengYAO/ISLES24-MR-Lite"
OUTPUT_DIR = Path(__file__).parent.parent.parent / "data" / "scratch"
REPORT_FILE = OUTPUT_DIR / "isles24_schema_report.txt"


def safe_type_name(val: Any) -> str:
    """Get a safe string representation of a value's type."""
    if val is None:
        return "None"
    t = type(val).__name__
    if hasattr(val, "dtype"):
        return f"{t}[{val.dtype}]"
    return t


def safe_repr(val: Any, max_len: int = 100) -> str:
    """Get a safe truncated repr of a value."""
    if val is None:
        return "None"
    if isinstance(val, bytes):
        return f"<bytes len={len(val)}>"
    if isinstance(val, dict):
        if "bytes" in val:
            return f"<dict with 'bytes' key, len={len(val.get('bytes', b''))}>"
        return f"<dict keys={list(val.keys())}>"
    r = repr(val)
    if len(r) > max_len:
        return r[: max_len - 3] + "..."
    return r


def main() -> int:
    """Main discovery workflow."""
    print("=" * 70)
    print("ISLES24-MR-Lite Dataset Discovery")
    print(f"Started: {datetime.now().isoformat()}")
    print("=" * 70)
    print()

    # Ensure output directory exists
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Import datasets library
    try:
        from datasets import load_dataset
    except ImportError:
        print("ERROR: 'datasets' library not installed.")
        print("Run: uv add datasets")
        return 1

    # =========================================================================
    # PHASE 1: Load Dataset (Full Download)
    # =========================================================================
    print(f"[1/4] Loading dataset: {DATASET_ID}")
    print("      This will download the FULL dataset...")
    print()

    try:
        # Try loading without streaming first to get full access
        ds = load_dataset(DATASET_ID)
        print("      SUCCESS: Dataset loaded")
        print(f"      Splits available: {list(ds.keys())}")
        print()
    except Exception as e:
        print(f"      ERROR loading dataset: {e}")
        print()
        print("      Trying streaming mode as fallback...")
        try:
            ds = load_dataset(DATASET_ID, streaming=True)
            print("      SUCCESS (streaming): Dataset loaded")
            print(f"      Splits available: {list(ds.keys())}")
        except Exception as e2:
            print(f"      FATAL: Cannot load dataset: {e2}")
            return 1

    # =========================================================================
    # PHASE 2: Inspect Schema (Features)
    # =========================================================================
    print("[2/4] Inspecting schema...")
    print()

    report_lines: list[str] = []
    report_lines.append("=" * 70)
    report_lines.append("ISLES24-MR-Lite Schema Discovery Report")
    report_lines.append(f"Generated: {datetime.now().isoformat()}")
    report_lines.append(f"Dataset: {DATASET_ID}")
    report_lines.append("=" * 70)
    report_lines.append("")

    for split_name in ds:
        split = ds[split_name]
        report_lines.append(f"SPLIT: {split_name}")
        report_lines.append("-" * 50)

        # Get features/schema
        if hasattr(split, "features"):
            features = split.features
            report_lines.append(
                f"Number of rows: {len(split) if hasattr(split, '__len__') else 'unknown (streaming)'}"
            )
            report_lines.append("")
            report_lines.append("FEATURES (columns):")
            for feat_name, feat_type in features.items():
                report_lines.append(f"  - {feat_name}: {feat_type}")
            report_lines.append("")
        else:
            report_lines.append("  (No features metadata available)")
            report_lines.append("")

    print("      Schema extracted.")
    print()

    # =========================================================================
    # PHASE 3: Sample Inspection (check actual data)
    # =========================================================================
    print("[3/4] Inspecting sample rows...")
    print()

    # Use the first available split (usually 'train')
    main_split_name = next(iter(ds.keys()))
    main_split = ds[main_split_name]

    report_lines.append("=" * 70)
    report_lines.append("SAMPLE DATA INSPECTION")
    report_lines.append("=" * 70)
    report_lines.append("")

    # Check first 3 rows in detail
    report_lines.append("First 3 rows (detailed):")
    report_lines.append("-" * 50)

    sample_count = 0
    column_value_types: dict[str, Counter[str]] = {}

    # Iterate through dataset
    iterable = iter(main_split) if hasattr(main_split, "__iter__") else main_split

    for i, row in enumerate(iterable):
        if i < 3:
            report_lines.append(f"\nROW {i}:")
            for key, val in row.items():
                val_type = safe_type_name(val)
                val_repr = safe_repr(val)
                report_lines.append(f"  {key}:")
                report_lines.append(f"    type: {val_type}")
                report_lines.append(f"    value: {val_repr}")

        # Track types for all rows
        for key, val in row.items():
            if key not in column_value_types:
                column_value_types[key] = Counter()
            column_value_types[key][safe_type_name(val)] += 1

        sample_count += 1

        # Progress indicator
        if sample_count % 50 == 0:
            print(f"      Processed {sample_count} rows...")

    print(f"      Total rows processed: {sample_count}")
    print()

    # =========================================================================
    # PHASE 4: Consistency Check
    # =========================================================================
    print("[4/4] Checking consistency across all rows...")
    print()

    report_lines.append("")
    report_lines.append("=" * 70)
    report_lines.append("CONSISTENCY ANALYSIS (all rows)")
    report_lines.append("=" * 70)
    report_lines.append("")
    report_lines.append(f"Total rows analyzed: {sample_count}")
    report_lines.append("")

    report_lines.append("Column type distribution:")
    report_lines.append("-" * 50)
    for col_name, type_counts in column_value_types.items():
        report_lines.append(f"\n  {col_name}:")
        for type_name, count in type_counts.most_common():
            pct = (count / sample_count) * 100
            report_lines.append(f"    {type_name}: {count} ({pct:.1f}%)")

    # =========================================================================
    # PHASE 5: CaseAdapter Compatibility Check
    # =========================================================================
    report_lines.append("")
    report_lines.append("=" * 70)
    report_lines.append("CASEADAPTER COMPATIBILITY CHECK")
    report_lines.append("=" * 70)
    report_lines.append("")

    expected_columns = ["dwi", "adc", "flair", "mask", "ground_truth", "participant_id"]
    actual_columns = list(column_value_types.keys())

    report_lines.append("Expected by CaseAdapter:")
    for col in expected_columns:
        status = "FOUND" if col in actual_columns else "MISSING"
        report_lines.append(f"  {col}: {status}")

    report_lines.append("")
    report_lines.append("Actual columns in dataset:")
    for col in actual_columns:
        expected = "expected" if col in expected_columns else "UNEXPECTED"
        report_lines.append(f"  {col}: {expected}")

    report_lines.append("")
    report_lines.append("=" * 70)
    report_lines.append("END OF REPORT")
    report_lines.append("=" * 70)
|
| 233 |
+
|
| 234 |
+
# Write report
|
| 235 |
+
report_content = "\n".join(report_lines)
|
| 236 |
+
REPORT_FILE.write_text(report_content)
|
| 237 |
+
|
| 238 |
+
print(f"Report written to: {REPORT_FILE}")
|
| 239 |
+
print()
|
| 240 |
+
print("=" * 70)
|
| 241 |
+
print("DISCOVERY COMPLETE")
|
| 242 |
+
print("=" * 70)
|
| 243 |
+
print()
|
| 244 |
+
print("Next steps:")
|
| 245 |
+
print(f" 1. Review: {REPORT_FILE}")
|
| 246 |
+
print(" 2. Compare findings against src/stroke_deepisles_demo/data/adapter.py")
|
| 247 |
+
print(" 3. Update adapter if schema differs from expectations")
|
| 248 |
+
print()
|
| 249 |
+
|
| 250 |
+
# Print summary to stdout as well
|
| 251 |
+
print("-" * 70)
|
| 252 |
+
print("QUICK SUMMARY:")
|
| 253 |
+
print("-" * 70)
|
| 254 |
+
print(f"Columns found: {actual_columns}")
|
| 255 |
+
print()
|
| 256 |
+
missing = [c for c in expected_columns if c not in actual_columns]
|
| 257 |
+
if missing:
|
| 258 |
+
print(f"WARNING: Expected columns MISSING: {missing}")
|
| 259 |
+
unexpected = [c for c in actual_columns if c not in expected_columns]
|
| 260 |
+
if unexpected:
|
| 261 |
+
print(f"NOTE: Unexpected columns found: {unexpected}")
|
| 262 |
+
|
| 263 |
+
return 0
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
sys.exit(main())
|
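For orientation, a minimal sketch of driving the discovery script above; the `uv run` entry point is an assumption based on the project's uv-managed environment, and the report lands wherever the script's `REPORT_FILE` points.

```python
# Hypothetical invocation sketch (not part of this commit). Running in a
# subprocess makes the script's exit code visible: 0 = success, 1 = load failure.
import subprocess

proc = subprocess.run(
    ["uv", "run", "python", "scripts/discovery/inspect_isles24.py"],
    check=False,
)
print(f"discovery finished with exit code {proc.returncode}")
```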
src/stroke_deepisles_demo/core/exceptions.py
CHANGED

@@ -21,3 +21,7 @@ class DeepISLESError(StrokeDemoError):


class MissingInputError(StrokeDemoError):
    """Required input files are missing."""
+
+
+class DockerGPUNotAvailableError(StrokeDemoError):
+    """GPU requested but NVIDIA Container Runtime not available."""
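A sketch of how the new exception is meant to be consumed downstream: fall back to CPU inference when the NVIDIA runtime is missing. The staged folder path is hypothetical; `run_deepisles_on_folder` is added later in this commit.

```python
# Sketch only: degrade gracefully to CPU when no GPU container runtime exists.
from pathlib import Path

from stroke_deepisles_demo.core.exceptions import DockerGPUNotAvailableError
from stroke_deepisles_demo.inference import run_deepisles_on_folder

case_dir = Path("data/scratch/case001")  # hypothetical staged case folder
try:
    result = run_deepisles_on_folder(case_dir, gpu=True)
except DockerGPUNotAvailableError:
    result = run_deepisles_on_folder(case_dir, gpu=False)
```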
src/stroke_deepisles_demo/data/__init__.py
CHANGED

@@ -1,27 +1,21 @@
"""Data loading and case management for stroke-deepisles-demo."""

-from stroke_deepisles_demo.
-from stroke_deepisles_demo.data.
+from stroke_deepisles_demo.core.types import CaseFiles
+from stroke_deepisles_demo.data.adapter import LocalDataset
+from stroke_deepisles_demo.data.loader import DatasetInfo, load_isles_dataset
from stroke_deepisles_demo.data.staging import StagedCase, stage_case_for_deepisles

__all__ = [
-    # Adapter
-    "CaseAdapter",
-    # Loader
    "DatasetInfo",
+    "LocalDataset",
    "StagedCase",
    "get_case",
-    "get_dataset_info",
    "list_case_ids",
    "load_isles_dataset",
    "stage_case_for_deepisles",
]


-from stroke_deepisles_demo.core.types import CaseFiles
-
-
# Convenience functions (combine loader + adapter)
def get_case(case_id: str | int) -> CaseFiles:
    """
@@ -31,12 +25,10 @@ def get_case(case_id: str | int) -> CaseFiles:
        CaseFiles dictionary
    """
    dataset = load_isles_dataset()
-
-    return adapter.get_case(case_id)
+    return dataset.get_case(case_id)


def list_case_ids() -> list[str]:
    """List all available case IDs."""
    dataset = load_isles_dataset()
-
-    return adapter.list_case_ids()
+    return dataset.list_case_ids()
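A usage sketch of the convenience layer this `__init__` now exposes; it assumes the dataset has already been extracted to the loader's default `data/scratch/isles24_extracted` location.

```python
# Sketch: flat convenience API; relies on the loader's default local path.
from stroke_deepisles_demo.data import get_case, list_case_ids

ids = list_case_ids()    # e.g. ["sub-stroke0001", "sub-stroke0002", ...]
case = get_case(ids[0])  # CaseFiles dict with "dwi", "adc", maybe "ground_truth"
print(case["dwi"], case["adc"])
```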
src/stroke_deepisles_demo/data/adapter.py
CHANGED

@@ -1,147 +1,84 @@
-"""
+"""Provide typed access to ISLES24 cases."""

from __future__ import annotations

-from 
-from stroke_deepisles_demo.core.exceptions import DataLoadError
-from stroke_deepisles_demo.core.types import CaseFiles
+import re
+from dataclasses import dataclass
+from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterator
-    from 
+    from pathlib import Path

+    from stroke_deepisles_demo.core.types import CaseFiles


-class CaseAdapter:
-    # … (class docstring elided in source)
-
-    def __init__(self, dataset: Dataset) -> None:
-        """
-        Initialize adapter with a loaded dataset.
-
-        Args:
-            dataset: HuggingFace Dataset with NIfTI files
-        """
-        self.dataset = dataset
-        self._case_id_map = self._build_case_id_map()
-
-    def _build_case_id_map(self) -> dict[str, int]:
-        """Build mapping from case ID to index."""
-        case_map = {}
-        # Assuming dataset has 'participant_id' or similar
-        # If not, we might need to generate IDs or use index
-
-        # Check features to find ID column
-        id_col = "participant_id"
-        if id_col not in self.dataset.features:
-            # Fallback: try to find a string column that looks like an ID
-            # Or just use f"case_{i}"
-            pass
-
-        # Iterate to build map
-        # This might be slow for huge datasets, but for 149 cases it's fine
-        for idx, row in enumerate(self.dataset):
-            case_id = row.get(id_col, f"case_{idx:03d}")
-            case_map[str(case_id)] = idx
-
-        return case_map
+@dataclass
+class LocalDataset:
+    """File-based dataset for local ISLES24 data."""
+
+    data_dir: Path
+    cases: dict[str, CaseFiles]  # subject_id -> files

    def __len__(self) -> int:
-        return len(self.dataset)
+        return len(self.cases)

    def __iter__(self) -> Iterator[str]:
-        return iter(self._case_id_map.keys())
+        return iter(self.cases.keys())

    def list_case_ids(self) -> list[str]:
-        """
-        Returns:
-            List of case IDs (e.g., ["sub-001", "sub-002", ...])
-        """
-        return list(self._case_id_map.keys())
+        """Return sorted list of subject IDs."""
+        return sorted(self.cases.keys())

    def get_case(self, case_id: str | int) -> CaseFiles:
-        """
-        Args:
-            case_id: Either a string ID (e.g., "sub-001") or integer index
-        # … (remainder of docstring and body elided in source)
-
-        # Helper to ensure we return Path if it's a local string path, or keep as is
-        def to_path_or_raw(val: Any) -> Any:
-            if isinstance(val, str) and not val.startswith(("http://", "https://")):
-                return Path(val)
-            return val
-
-        dwi = to_path_or_raw(row.get("dwi"))
-        adc = to_path_or_raw(row.get("adc"))
-        flair = to_path_or_raw(row.get("flair"))
-        ground_truth = to_path_or_raw(row.get("mask") or row.get("ground_truth"))
-
-        if not dwi or not adc:
-            raise DataLoadError("Case missing required DWI or ADC files")
-
-        case_files = CaseFiles(dwi=dwi, adc=adc)
-
-        if flair:
-            case_files["flair"] = flair
-        if ground_truth:
-            case_files["ground_truth"] = ground_truth
-
-        return case_files
+        """Get files for a case by ID or index."""
+        if isinstance(case_id, int):
+            case_id = self.list_case_ids()[case_id]
+        return self.cases[case_id]


+# Subject ID extraction
+SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")


+def parse_subject_id(filename: str) -> str | None:
+    """Extract subject ID from BIDS filename."""
+    match = SUBJECT_PATTERN.match(filename)
+    return f"sub-{match.group(1)}" if match else None


+def build_local_dataset(data_dir: Path) -> LocalDataset:
+    """
+    Scan directory and build case mapping.
+
+    Matches DWI + ADC + Mask files by subject ID.
+    """
+    dwi_dir = data_dir / "Images-DWI"
+    adc_dir = data_dir / "Images-ADC"
+    mask_dir = data_dir / "Masks"
+
+    cases: dict[str, CaseFiles] = {}
+
+    # Scan DWI files to get subject IDs
+    for dwi_file in dwi_dir.glob("*.nii.gz"):
+        subject_id = parse_subject_id(dwi_file.name)
+        if not subject_id:
+            continue
+
+        # Find matching ADC and Mask
+        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
+        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")
+
+        if not adc_file.exists():
+            continue  # Skip incomplete cases
+
+        case_files: CaseFiles = {
+            "dwi": dwi_file,
+            "adc": adc_file,
+        }
+        if mask_file.exists():
+            case_files["ground_truth"] = mask_file
+
+        cases[subject_id] = case_files
+
+    return LocalDataset(data_dir=data_dir, cases=cases)
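A quick sketch exercising the new adapter on an extracted tree; the path below is the scratch location from `docs/specs/data-discovery.md` and may differ locally.

```python
# Sketch: filename parsing plus directory scan over a local extraction.
from pathlib import Path

from stroke_deepisles_demo.data.adapter import build_local_dataset, parse_subject_id

assert parse_subject_id("sub-stroke0005_ses-02_dwi.nii.gz") == "sub-stroke0005"
assert parse_subject_id("not_a_bids_name.nii.gz") is None

dataset = build_local_dataset(Path("data/scratch/isles24_extracted"))
print(len(dataset), dataset.list_case_ids()[:3])
```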
src/stroke_deepisles_demo/data/loader.py
CHANGED

@@ -1,138 +1,47 @@
-"""Load ISLES24 
+"""Load ISLES24 data from local directory or HuggingFace Hub."""

from __future__ import annotations

from dataclasses import dataclass
+from pathlib import Path
from typing import TYPE_CHECKING

-from stroke_deepisles_demo.core.exceptions import DataLoadError
+if TYPE_CHECKING:
+    from stroke_deepisles_demo.data.adapter import LocalDataset


+@dataclass
+class DatasetInfo:
+    """Metadata about the dataset."""
+
+    source: str  # "local" or HF dataset ID
+    num_cases: int
+    modalities: list[str]
+    has_ground_truth: bool


def load_isles_dataset(
-    # … (old parameters elided in source)
+    source: str | Path = "data/scratch/isles24_extracted",
    *,
+    local_mode: bool = True,  # Default to local for now
-) -> Dataset:
+) -> LocalDataset:
    """
-    Load 
+    Load ISLES24 dataset.

    Args:
-        streaming: If True, use streaming mode (lazy loading)
+        source: Local directory path or HuggingFace dataset ID
+        local_mode: If True, treat source as local directory

    Returns:
+        Dataset-like object providing case access

    Raises:
+        NotImplementedError: If non-local mode is requested
    """
-    try:
-        # We pass trust_remote_code=True if needed for custom scripts,
-        # but standard datasets usually don't need it unless using custom builder.
-        # ISLES24-MR-Lite is likely a standard dataset or Parquet-based.
-        # If it's BIDS, we might need type="bids" if the PR features are used that way.
-        # For now, standard load_dataset.
-
-        ds = load_dataset(
-            dataset_id,
-            cache_dir=str(cache_dir) if cache_dir else None,
-            streaming=streaming,
-            # If the dataset is BIDS, we might need a specific config/builder.
-            # Assuming default works or it's already parquet.
-        )
-
-        # If streaming, load_dataset returns IterableDataset.
-        # If not, it returns DatasetDict or Dataset.
-        # We assume it returns the 'train' split if it's a DatasetDict, or we handle it.
-        # Usually load_dataset returns DatasetDict unless split is specified.
-
-        if hasattr(ds, "keys"):
-            keys = list(ds.keys())
-            if "train" in keys:
-                return ds["train"]
-            elif len(keys) > 0:
-                # Fallback to first split if 'train' not found
-                return ds[keys[0]]
-
-        return ds
-
-    except Exception as e:
-        raise DataLoadError(f"Failed to load dataset {dataset_id}: {e}") from e
-
-
-@dataclass
-class DatasetInfo:
-    """Metadata about the loaded dataset."""
-
-    dataset_id: str
-    num_cases: int
-    modalities: list[str]  # e.g., ["dwi", "adc", "mask"]
-    has_ground_truth: bool
-
-
-def get_dataset_info(dataset_id: str = "YongchengYAO/ISLES24-MR-Lite") -> DatasetInfo:
-    """
-    Get metadata about the dataset without downloading (if possible).
-
-    Returns:
-        DatasetInfo with case count, available modalities, etc.
-    """
-    try:
-        # Load in streaming mode to get features/info cheaply
-        ds = load_isles_dataset(dataset_id, streaming=True)
-
-        # Count cases (might be slow for streaming, but okay for demo scale)
-        # Or check if info is available
-        if hasattr(ds, "info") and ds.info.splits:
-            # Approximate from splits info if available
-            num_cases = ds.info.splits["train"].num_examples
-        else:
-            # Iterate to count? Or just rely on known size?
-            # For streaming, len() might not work.
-            # Let's just load non-streaming but with no data download? No.
-            # Let's just assume we can get length if we loaded it.
-            # If we loaded it streaming, we might not get length.
-            # For the demo, let's just try to get it.
-
-            # If we can't get length easily from streaming, we might need to trust metadata.
-            # Or just iterate (expensive).
-            # Let's use a safer approach: load non-streaming (lazy) might download metadata only.
-            # But datasets downloads parquet files.
-
-            # For get_dataset_info, maybe we just load it fully? No, expensive.
-            # Let's use streaming and try to get info.
-            num_cases = 0
-            # Use a fixed number if we can't determine?
-            # Or just count - 149 is small.
-            # But streaming iteration means network calls.
-
-            # Try to access info object
-            if hasattr(ds, "n_shards"):
-                # Approximate?
-                pass
-
-            # Fallback: 149 (known)
-            num_cases = 149
-
-        modalities = [k for k in features if k in ["dwi", "adc", "flair"]]
-        has_ground_truth = "mask" in features or "ground_truth" in features
-
-        return DatasetInfo(
-            num_cases=num_cases,
-            modalities=sorted(modalities),
-            has_ground_truth=has_ground_truth,
-        )
-    except Exception as e:
-        raise DataLoadError(f"Failed to get info for {dataset_id}: {e}") from e
+    if local_mode or isinstance(source, Path):
+        from stroke_deepisles_demo.data.adapter import build_local_dataset
+
+        return build_local_dataset(Path(source))
+
+    # Future: return _load_from_huggingface(source)
+    raise NotImplementedError("HuggingFace mode not yet implemented")
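Correspondingly, a loader sketch; local mode is the only path implemented in this commit, and the directory is assumed to exist.

```python
# Sketch: load from a local directory; integer indices resolve via sorted IDs.
from pathlib import Path

from stroke_deepisles_demo.data.loader import load_isles_dataset

dataset = load_isles_dataset(source=Path("data/scratch/isles24_extracted"))
first_case = dataset.get_case(0)
print(first_case["dwi"])
```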
src/stroke_deepisles_demo/inference/__init__.py
CHANGED

@@ -1 +1,37 @@
-"""
+"""Inference module for stroke-deepisles-demo."""
+
+from stroke_deepisles_demo.inference.deepisles import (
+    DEEPISLES_IMAGE,
+    DeepISLESResult,
+    find_prediction_mask,
+    run_deepisles_on_folder,
+    validate_input_folder,
+)
+from stroke_deepisles_demo.inference.docker import (
+    DockerRunResult,
+    build_docker_command,
+    check_docker_available,
+    check_nvidia_docker_available,
+    ensure_docker_available,
+    ensure_gpu_available_if_requested,
+    pull_image_if_missing,
+    run_container,
+)
+
+__all__ = [
+    # DeepISLES
+    "DEEPISLES_IMAGE",
+    "DeepISLESResult",
+    # Docker utilities
+    "DockerRunResult",
+    "build_docker_command",
+    "check_docker_available",
+    "check_nvidia_docker_available",
+    "ensure_docker_available",
+    "ensure_gpu_available_if_requested",
+    "find_prediction_mask",
+    "pull_image_if_missing",
+    "run_container",
+    "run_deepisles_on_folder",
+    "validate_input_folder",
+]
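The re-exports let callers stay on the package namespace instead of submodule paths, e.g.:

```python
# All inference entry points are importable from the package root.
from stroke_deepisles_demo.inference import (
    DEEPISLES_IMAGE,
    check_docker_available,
    run_deepisles_on_folder,
)
```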
src/stroke_deepisles_demo/inference/deepisles.py
ADDED

@@ -0,0 +1,193 @@
+"""DeepISLES stroke segmentation wrapper."""
+
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from stroke_deepisles_demo.core.exceptions import DeepISLESError, MissingInputError
+from stroke_deepisles_demo.inference.docker import (
+    DockerRunResult,
+    ensure_gpu_available_if_requested,
+    run_container,
+)
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+# Constants
+DEEPISLES_IMAGE = "isleschallenge/deepisles"
+EXPECTED_INPUT_FILES = ["dwi.nii.gz", "adc.nii.gz"]
+OPTIONAL_INPUT_FILES = ["flair.nii.gz"]
+
+
+@dataclass(frozen=True)
+class DeepISLESResult:
+    """Result of DeepISLES inference."""
+
+    prediction_path: Path
+    docker_result: DockerRunResult
+    elapsed_seconds: float
+
+
+def validate_input_folder(input_dir: Path) -> tuple[Path, Path, Path | None]:
+    """
+    Validate that input folder contains required files.
+
+    Args:
+        input_dir: Directory to validate
+
+    Returns:
+        Tuple of (dwi_path, adc_path, flair_path_or_none)
+
+    Raises:
+        MissingInputError: If required files are missing
+    """
+    dwi_path = input_dir / "dwi.nii.gz"
+    adc_path = input_dir / "adc.nii.gz"
+    flair_path = input_dir / "flair.nii.gz"
+
+    if not dwi_path.exists():
+        raise MissingInputError(f"Required file 'dwi.nii.gz' not found in {input_dir}")
+
+    if not adc_path.exists():
+        raise MissingInputError(f"Required file 'adc.nii.gz' not found in {input_dir}")
+
+    return dwi_path, adc_path, flair_path if flair_path.exists() else None
+
+
+def find_prediction_mask(output_dir: Path) -> Path:
+    """
+    Find the prediction mask in DeepISLES output directory.
+
+    DeepISLES outputs may have varying names depending on version.
+    This function finds the most likely prediction file.
+
+    Args:
+        output_dir: DeepISLES output directory
+
+    Returns:
+        Path to the prediction mask NIfTI file
+
+    Raises:
+        DeepISLESError: If no prediction mask found
+    """
+    results_dir = output_dir / "results"
+
+    # Check common output patterns
+    possible_names = [
+        "prediction.nii.gz",
+        "pred.nii.gz",
+        "lesion_mask.nii.gz",
+        "output.nii.gz",
+    ]
+
+    for name in possible_names:
+        pred_path = results_dir / name
+        if pred_path.exists():
+            return pred_path
+
+    # Fall back to finding any .nii.gz in results dir
+    if results_dir.exists():
+        nifti_files = list(results_dir.glob("*.nii.gz"))
+        if nifti_files:
+            return nifti_files[0]
+
+    raise DeepISLESError(
+        f"No prediction mask found in {results_dir}. "
+        "Expected files like 'prediction.nii.gz' or similar."
+    )
+
+
+def run_deepisles_on_folder(
+    input_dir: Path,
+    *,
+    output_dir: Path | None = None,
+    fast: bool = True,
+    gpu: bool = True,
+    timeout: float | None = 1800,  # 30 minutes default
+) -> DeepISLESResult:
+    """
+    Run DeepISLES stroke segmentation on a folder of NIfTI files.
+
+    Args:
+        input_dir: Directory containing dwi.nii.gz, adc.nii.gz, [flair.nii.gz]
+        output_dir: Where to write results (default: input_dir/results)
+        fast: If True, use single-model mode (faster, slightly less accurate)
+        gpu: If True, use GPU acceleration
+        timeout: Maximum seconds to wait for inference
+
+    Returns:
+        DeepISLESResult with path to prediction mask
+
+    Raises:
+        DockerNotAvailableError: If Docker is not available
+        DockerGPUNotAvailableError: If GPU requested but not available
+        MissingInputError: If required input files are missing
+        DeepISLESError: If inference fails (non-zero exit, missing output)
+
+    Example:
+        >>> result = run_deepisles_on_folder(Path("/data/case001"), fast=True)
+        >>> print(result.prediction_path)
+        /data/case001/results/prediction.nii.gz
+    """
+    start_time = time.time()
+
+    # Validate inputs
+    _dwi_path, _adc_path, flair_path = validate_input_folder(input_dir)
+
+    # Check GPU if requested
+    if gpu:
+        ensure_gpu_available_if_requested(gpu)
+
+    # Set up output directory
+    if output_dir is None:
+        output_dir = input_dir
+
+    # Build command arguments
+    command: list[str] = [
+        "--dwi_file_name",
+        "dwi.nii.gz",
+        "--adc_file_name",
+        "adc.nii.gz",
+    ]
+
+    if flair_path is not None:
+        command.extend(["--flair_file_name", "flair.nii.gz"])
+
+    if fast:
+        command.extend(["--fast", "True"])
+
+    # Set up volume mounts
+    volumes = {
+        input_dir.resolve(): "/input",
+        output_dir.resolve(): "/output",
+    }
+
+    # Run the container
+    docker_result = run_container(
+        DEEPISLES_IMAGE,
+        command=command,
+        volumes=volumes,
+        gpu=gpu,
+        timeout=timeout,
+    )
+
+    # Check for failure
+    if docker_result.exit_code != 0:
+        raise DeepISLESError(
+            f"DeepISLES inference failed with exit code {docker_result.exit_code}. "
+            f"stderr: {docker_result.stderr}"
+        )
+
+    # Find the prediction mask
+    prediction_path = find_prediction_mask(output_dir)
+
+    elapsed = time.time() - start_time
+
+    return DeepISLESResult(
+        prediction_path=prediction_path,
+        docker_result=docker_result,
+        elapsed_seconds=elapsed,
+    )
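An end-to-end usage sketch under this module's own assumptions (a staged folder holding `dwi.nii.gz` and `adc.nii.gz`, Docker daemon running); the folder path is hypothetical.

```python
# Sketch: validate early for a fast failure, then run without the GPU check.
from pathlib import Path

from stroke_deepisles_demo.inference.deepisles import (
    run_deepisles_on_folder,
    validate_input_folder,
)

staged = Path("data/scratch/case001")  # hypothetical staging folder
validate_input_folder(staged)          # raises MissingInputError if inputs are absent
result = run_deepisles_on_folder(staged, fast=True, gpu=False)
print(result.prediction_path, f"{result.elapsed_seconds:.1f}s")
```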
src/stroke_deepisles_demo/inference/docker.py
ADDED

@@ -0,0 +1,258 @@
+"""Docker execution utilities."""
+
+from __future__ import annotations
+
+import subprocess
+import sys
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from stroke_deepisles_demo.core.exceptions import (
+    DockerGPUNotAvailableError,
+    DockerNotAvailableError,
+)
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from pathlib import Path
+
+
+@dataclass(frozen=True)
+class DockerRunResult:
+    """Result of a Docker container run."""
+
+    exit_code: int
+    stdout: str
+    stderr: str
+    elapsed_seconds: float
+
+
+def check_docker_available() -> bool:
+    """
+    Check if Docker is installed and the daemon is running.
+
+    Returns:
+        True if Docker is available, False otherwise
+    """
+    try:
+        result = subprocess.run(
+            ["docker", "info"],
+            capture_output=True,
+            timeout=10,
+            check=False,
+        )
+        return result.returncode == 0
+    except (FileNotFoundError, subprocess.TimeoutExpired):
+        return False
+
+
+def ensure_docker_available() -> None:
+    """
+    Ensure Docker is available, raising if not.
+
+    Raises:
+        DockerNotAvailableError: If Docker is not installed or not running
+    """
+    if not check_docker_available():
+        raise DockerNotAvailableError(
+            "Docker is not available. Please ensure Docker is installed and running."
+        )
+
+
+def check_nvidia_docker_available() -> bool:
+    """
+    Check if NVIDIA Container Runtime is available for GPU support.
+
+    Returns:
+        True if nvidia-docker/nvidia-container-toolkit is configured
+    """
+    try:
+        result = subprocess.run(
+            [
+                "docker",
+                "run",
+                "--rm",
+                "--gpus",
+                "all",
+                "nvidia/cuda:11.0-base",
+                "nvidia-smi",
+            ],
+            capture_output=True,
+            timeout=30,
+            check=False,
+        )
+        return result.returncode == 0
+    except (subprocess.TimeoutExpired, FileNotFoundError):
+        return False
+
+
+def ensure_gpu_available_if_requested(gpu: bool) -> None:
+    """
+    Verify GPU is available if requested.
+
+    Args:
+        gpu: Whether GPU was requested
+
+    Raises:
+        DockerGPUNotAvailableError: If GPU requested but not available
+    """
+    if gpu and not check_nvidia_docker_available():
+        raise DockerGPUNotAvailableError(
+            "GPU requested but NVIDIA Container Runtime not available. "
+            "Either install nvidia-container-toolkit or set gpu=False."
+        )
+
+
+def pull_image_if_missing(image: str, *, timeout: float = 600) -> bool:
+    """
+    Pull a Docker image if not present locally.
+
+    Args:
+        image: Docker image name (e.g., "isleschallenge/deepisles")
+        timeout: Maximum seconds to wait for pull
+
+    Returns:
+        True if image was pulled, False if already present
+    """
+    # Check if image exists locally
+    result = subprocess.run(
+        ["docker", "image", "inspect", image],
+        capture_output=True,
+        timeout=10,
+        check=False,
+    )
+    if result.returncode == 0:
+        return False  # Image already present
+
+    # Pull the image
+    subprocess.run(
+        ["docker", "pull", image],
+        capture_output=True,
+        timeout=timeout,
+        check=True,
+    )
+    return True
+
+
+def build_docker_command(
+    image: str,
+    *,
+    command: Sequence[str] | None = None,
+    volumes: dict[Path, str] | None = None,
+    environment: dict[str, str] | None = None,
+    gpu: bool = False,
+    remove: bool = True,
+    match_user: bool = True,
+) -> list[str]:
+    """
+    Build the docker run command without executing.
+
+    Args:
+        image: Docker image name
+        command: Command to run in container
+        volumes: Volume mounts (host path -> container path)
+        environment: Environment variables
+        gpu: If True, pass --gpus all
+        remove: If True, remove container after exit (--rm)
+        match_user: If True, match host user (Linux only)
+
+    Returns:
+        List of command arguments for subprocess
+    """
+    cmd: list[str] = ["docker", "run"]
+
+    if remove:
+        cmd.append("--rm")
+
+    if gpu:
+        cmd.extend(["--gpus", "all"])
+
+    # Match host user to avoid permission issues (Linux only).
+    # Guard against platforms (e.g. Windows, macOS) where os.getuid()/getgid()
+    # are absent or not meaningful.
+    if match_user:
+        import os
+
+        if (
+            os.name == "posix"
+            and sys.platform != "darwin"
+            and hasattr(os, "getuid")
+            and hasattr(os, "getgid")
+        ):
+            uid = os.getuid()
+            gid = os.getgid()
+            cmd.extend(["--user", f"{uid}:{gid}"])
+
+    if volumes:
+        for host_path, container_path in volumes.items():
+            cmd.extend(["-v", f"{host_path}:{container_path}"])
+
+    if environment:
+        for key, value in environment.items():
+            cmd.extend(["-e", f"{key}={value}"])
+
+    cmd.append(image)
+
+    if command:
+        cmd.extend(command)
+
+    return cmd
+
+
+def run_container(
+    image: str,
+    *,
+    command: Sequence[str] | None = None,
+    volumes: dict[Path, str] | None = None,
+    environment: dict[str, str] | None = None,
+    gpu: bool = False,
+    remove: bool = True,
+    timeout: float | None = None,
+) -> DockerRunResult:
+    """
+    Run a Docker container and wait for completion.
+
+    Args:
+        image: Docker image name
+        command: Command to run in container
+        volumes: Volume mounts (host path -> container path)
+        environment: Environment variables
+        gpu: If True, pass --gpus all
+        remove: If True, remove container after exit (--rm)
+        timeout: Maximum seconds to wait (None = no timeout)
+
+    Returns:
+        DockerRunResult with exit code, stdout, stderr, elapsed time
+
+    Raises:
+        DockerNotAvailableError: If Docker is not available
+        subprocess.TimeoutExpired: If timeout exceeded
+    """
+    ensure_docker_available()
+
+    cmd = build_docker_command(
+        image,
+        command=command,
+        volumes=volumes,
+        environment=environment,
+        gpu=gpu,
+        remove=remove,
+    )
+
+    start_time = time.time()
+    result = subprocess.run(
+        cmd,
+        capture_output=True,
+        text=True,
+        timeout=timeout,
+        check=False,
+    )
+    elapsed = time.time() - start_time
+
+    return DockerRunResult(
+        exit_code=result.returncode,
+        stdout=result.stdout,
+        stderr=result.stderr,
+        elapsed_seconds=elapsed,
+    )
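Since `build_docker_command` is pure, the exact argv can be inspected without a Docker daemon; a sketch with hypothetical paths:

```python
# Sketch: preview the argv that run_container would execute.
from pathlib import Path

from stroke_deepisles_demo.inference.docker import build_docker_command

cmd = build_docker_command(
    "isleschallenge/deepisles",
    command=["--dwi_file_name", "dwi.nii.gz"],
    volumes={Path("/tmp/case"): "/input"},
    gpu=False,
    match_user=False,  # keep the sketch platform-independent
)
# cmd == ["docker", "run", "--rm", "-v", "/tmp/case:/input",
#         "isleschallenge/deepisles", "--dwi_file_name", "dwi.nii.gz"]
```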
tests/conftest.py
CHANGED

@@ -13,7 +13,7 @@ import pytest
from stroke_deepisles_demo.core.types import CaseFiles

if TYPE_CHECKING:
-    from collections.abc import Generator
+    from collections.abc import Generator


@pytest.fixture
@@ -62,30 +62,46 @@ def synthetic_case_files(temp_dir: Path) -> CaseFiles:


@pytest.fixture
-def 
-    """
-    # … (old fixture elided in source)
+def synthetic_isles_dir(temp_dir: Path) -> Path:
+    """
+    Create synthetic ISLES24-like directory structure.
+
+    Structure:
+        temp_dir/
+        ├── Images-DWI/
+        │   ├── sub-stroke0001_ses-02_dwi.nii.gz
+        │   └── sub-stroke0002_ses-02_dwi.nii.gz
+        ├── Images-ADC/
+        │   ├── sub-stroke0001_ses-02_adc.nii.gz
+        │   └── sub-stroke0002_ses-02_adc.nii.gz
+        └── Masks/
+            ├── sub-stroke0001_ses-02_lesion-msk.nii.gz
+            └── sub-stroke0002_ses-02_lesion-msk.nii.gz
+    """
+    dwi_dir = temp_dir / "Images-DWI"
+    adc_dir = temp_dir / "Images-ADC"
+    mask_dir = temp_dir / "Masks"
+
+    dwi_dir.mkdir()
+    adc_dir.mkdir()
+    mask_dir.mkdir()
+
+    for subject_num in [1, 2]:
+        subject_id = f"sub-stroke{subject_num:04d}"
+
+        # Create DWI
+        dwi_data = np.random.rand(10, 10, 5).astype(np.float32)
+        dwi_img = nib.Nifti1Image(dwi_data, affine=np.eye(4))  # type: ignore
+        nib.save(dwi_img, dwi_dir / f"{subject_id}_ses-02_dwi.nii.gz")  # type: ignore
+
+        # Create ADC
+        adc_data = np.random.rand(10, 10, 5).astype(np.float32) * 2000
+        adc_img = nib.Nifti1Image(adc_data, affine=np.eye(4))  # type: ignore
+        nib.save(adc_img, adc_dir / f"{subject_id}_ses-02_adc.nii.gz")  # type: ignore
+
+        # Create Mask
+        mask_data = (np.random.rand(10, 10, 5) > 0.9).astype(np.uint8)
+        mask_img = nib.Nifti1Image(mask_data, affine=np.eye(4))  # type: ignore
+        nib.save(mask_img, mask_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz")  # type: ignore
+
+    return temp_dir
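For reference, pytest injects the populated tree by fixture name, so a test only has to declare the parameter; a minimal sketch:

```python
# Sketch: any test naming the fixture receives the synthetic directory tree.
from pathlib import Path


def test_fixture_layout(synthetic_isles_dir: Path) -> None:
    assert (synthetic_isles_dir / "Images-DWI").is_dir()
    assert (synthetic_isles_dir / "Masks").is_dir()
```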
tests/data/test_adapter.py
CHANGED

@@ -1,70 +1,94 @@
-"""Tests for 
+"""Tests for the data adapter."""

from __future__ import annotations

from typing import TYPE_CHECKING

-import 
+from stroke_deepisles_demo.data.adapter import (
+    LocalDataset,
+    build_local_dataset,
+    parse_subject_id,
+)

if TYPE_CHECKING:
-    from 
+    from pathlib import Path

-# … (old mock-based CaseAdapter tests; many lines elided in source)
-        case_ids = adapter.list_case_ids()
-
-    def test_len_matches_dataset_size(self, mock_hf_dataset: MagicMock) -> None:
-        """len(adapter) equals number of cases in dataset."""
-        adapter = CaseAdapter(mock_hf_dataset)
-
-        assert isinstance(case, dict)
-        assert "dwi" in case
-        assert "adc" in case
-        # Paths should be Path objects or convertible
-        from pathlib import Path
-
-        """Can retrieve case by integer index."""
-        adapter = CaseAdapter(mock_hf_dataset)
-
-        assert isinstance(case_id, str)
-        assert case["dwi"] is not None
-
-        """Can iterate over case IDs."""
-        adapter = CaseAdapter(mock_hf_dataset)
+
+def test_parse_subject_id_extracts_correctly() -> None:
+    """Test extracting subject ID from BIDS filename."""
+    # Valid cases
+    assert parse_subject_id("sub-stroke0005_ses-02_dwi.nii.gz") == "sub-stroke0005"
+    assert parse_subject_id("sub-stroke0149_ses-02_adc.nii.gz") == "sub-stroke0149"
+    assert parse_subject_id("sub-stroke1234_ses-02_lesion-msk.nii.gz") == "sub-stroke1234"
+
+    # Invalid cases
+    assert parse_subject_id("random_file.nii.gz") is None
+    assert parse_subject_id("sub-strokeABC_ses-02_dwi.nii.gz") is None  # Non-digit ID
+
+
+def test_build_local_dataset_matches_files(synthetic_isles_dir: Path) -> None:
+    """Test that files are correctly matched by subject ID."""
+    dataset = build_local_dataset(synthetic_isles_dir)
+
+    assert isinstance(dataset, LocalDataset)
+    assert len(dataset) == 2  # synthetic_isles_dir creates 2 subjects
+    assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
+
+    # Verify matching logic
+    case1 = dataset.get_case("sub-stroke0001")
+    assert case1["dwi"].name == "sub-stroke0001_ses-02_dwi.nii.gz"
+    assert case1["adc"].name == "sub-stroke0001_ses-02_adc.nii.gz"
+    assert case1["ground_truth"] is not None
+    assert case1["ground_truth"].name == "sub-stroke0001_ses-02_lesion-msk.nii.gz"
+
+
+def test_get_case_returns_case_files(synthetic_isles_dir: Path) -> None:
+    """Test retrieval of cases by ID and index."""
+    dataset = build_local_dataset(synthetic_isles_dir)
+
+    # By ID
+    case_by_id = dataset.get_case("sub-stroke0001")
+    assert isinstance(case_by_id, dict)
+    assert "dwi" in case_by_id
+    assert "adc" in case_by_id
+
+    # By Index
+    case_by_idx = dataset.get_case(0)
+    assert isinstance(case_by_idx, dict)
+    assert case_by_id == case_by_idx  # Should be the same case
+
+
+def test_build_local_dataset_skips_incomplete(
+    synthetic_isles_dir: Path,
+) -> None:
+    """Test that incomplete cases (missing ADC) are skipped."""
+    # Delete ADC for subject 2
+    adc_file = synthetic_isles_dir / "Images-ADC" / "sub-stroke0002_ses-02_adc.nii.gz"
+    adc_file.unlink()
+
+    dataset = build_local_dataset(synthetic_isles_dir)
+
+    # Subject 2 should be gone
+    assert len(dataset) == 1
+    assert dataset.list_case_ids() == ["sub-stroke0001"]
+
+
+def test_build_local_dataset_handles_missing_mask(
+    synthetic_isles_dir: Path,
+) -> None:
+    """Test that missing mask results in ground_truth=None (if allowed)."""
+    # NOTE: Adapter currently allows missing mask?
+    # Spec says: "ground_truth=mask_file if mask_file.exists() else None"
+    # So yes, it should load but with None.
+
+    # Delete Mask for subject 2
+    mask_file = synthetic_isles_dir / "Masks" / "sub-stroke0002_ses-02_lesion-msk.nii.gz"
+    mask_file.unlink()
+
+    dataset = build_local_dataset(synthetic_isles_dir)
+
+    # Subject 2 should still exist
+    assert len(dataset) == 2
+
+    case2 = dataset.get_case("sub-stroke0002")
+    assert case2.get("ground_truth") is None
tests/data/test_integration_real_data.py
ADDED

@@ -0,0 +1,42 @@
+"""Integration tests with real ISLES24 data."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from stroke_deepisles_demo.data.loader import load_isles_dataset
+
+REAL_DATA_PATH = Path("data/scratch/isles24_extracted")
+
+
+@pytest.mark.skipif(not REAL_DATA_PATH.exists(), reason="Real data not found in data/scratch")
+def test_load_real_data_count() -> None:
+    """Verify that we can load the expected number of cases from real data."""
+    dataset = load_isles_dataset(source=REAL_DATA_PATH)
+
+    # We expect 149 cases based on schema report
+    assert len(dataset) == 149
+
+    # Check a specific known case
+    case = dataset.get_case("sub-stroke0005")
+    assert case["dwi"].name == "sub-stroke0005_ses-02_dwi.nii.gz"
+    assert case["dwi"].exists()
+    assert case["adc"].exists()
+    assert case["ground_truth"] is not None
+    assert case["ground_truth"].exists()
+
+
+@pytest.mark.skipif(not REAL_DATA_PATH.exists(), reason="Real data not found in data/scratch")
+def test_real_data_subject_ids() -> None:
+    """Verify subject ID formatting on real data."""
+    dataset = load_isles_dataset(source=REAL_DATA_PATH)
+    ids = dataset.list_case_ids()
+
+    assert len(ids) == 149
+    assert ids[0] == "sub-stroke0001"
+    # We know there are gaps, so just check the format
+    for subject_id in ids:
+        assert subject_id.startswith("sub-stroke")
+        assert len(subject_id) == len("sub-strokeXXXX")
tests/data/test_loader.py
CHANGED

@@ -1,90 +1,33 @@
-"""Tests for data loader
+"""Tests for the data loader."""

from __future__ import annotations

-from 
+from typing import TYPE_CHECKING

import pytest

-from stroke_deepisles_demo.
-from stroke_deepisles_demo.data.loader import (
-    DatasetInfo,
-    get_dataset_info,
-    load_isles_dataset,
-)
+from stroke_deepisles_demo.data.adapter import LocalDataset
+from stroke_deepisles_demo.data.loader import load_isles_dataset
+
+if TYPE_CHECKING:
+    from pathlib import Path

-class TestLoadIslesDataset:
-    """Tests for load_isles_dataset."""
-
-        # … (mock setup elided in source)
-        load_isles_dataset("test/dataset")
-
-    def test_returns_dataset_object(self) -> None:
-        """Returns the loaded Dataset object."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            expected = MagicMock()
-            mock_load.return_value = expected
-
-    def test_handles_load_error(self) -> None:
-        """Wraps HF errors in DataLoadError."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            mock_load.side_effect = Exception("Network error")
-
-            with pytest.raises(DataLoadError, match="Network error"):
-                load_isles_dataset()
-
-
-class TestGetDatasetInfo:
-    """Tests for get_dataset_info."""
-
-    def test_returns_datasetinfo(self) -> None:
-        """Returns DatasetInfo with expected fields."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            mock_ds = MagicMock()
-            mock_ds.__len__ = MagicMock(return_value=149)
-            # Mock info.splits['train'].num_examples
-            mock_ds.info.splits.__getitem__.return_value.num_examples = 149
-            # Mock features as dict-like
-            mock_ds.features = {"dwi": None, "adc": None, "mask": None}
-            mock_load.return_value = mock_ds
-
-            info = get_dataset_info()
-
-            assert isinstance(info, DatasetInfo)
-            assert info.num_cases == 149
-            assert "dwi" in info.modalities
-            assert info.has_ground_truth is True
-
-
-@pytest.mark.integration
-class TestLoadIslesDatasetIntegration:
-    """Integration tests that hit the real HuggingFace Hub."""
-
-    @pytest.mark.slow
-    def test_load_real_dataset(self) -> None:
-        """Actually loads ISLES24-MR-Lite from HF Hub."""
-        # This test requires network access
-        # Run with: pytest -m integration
-        # Using streaming=True to avoid downloading everything
-        try:
-            dataset = load_isles_dataset(streaming=True)
-            assert dataset is not None
-            # Verify we got metadata/features - this confirms connectivity
-            # Iterating might trigger heavy downloads or fail if dataset is empty/gated
-            assert hasattr(dataset, "features")
-            assert len(dataset.features) > 0
-        except Exception as e:
-            pytest.fail(f"Failed to load real dataset: {e}")
+
+def test_load_from_local_returns_local_dataset(synthetic_isles_dir: Path) -> None:
+    """Test that loading from local path returns a LocalDataset."""
+    dataset = load_isles_dataset(source=synthetic_isles_dir, local_mode=True)
+    assert isinstance(dataset, LocalDataset)
+    assert len(dataset) > 0
+
+
+def test_load_from_local_finds_all_cases(synthetic_isles_dir: Path) -> None:
+    """Test that the loader correctly delegates finding cases to adapter."""
+    dataset = load_isles_dataset(source=synthetic_isles_dir)
+    assert len(dataset) == 2
+    assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
+
+
+def test_load_raises_not_implemented_for_hf() -> None:
+    """Test that HF mode raises NotImplementedError."""
+    with pytest.raises(NotImplementedError):
+        load_isles_dataset(source="fake/dataset", local_mode=False)
tests/inference/__init__.py
ADDED
File without changes

tests/inference/test_deepisles.py
ADDED
@@ -0,0 +1,285 @@
+"""Tests for DeepISLES wrapper."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from stroke_deepisles_demo.core.exceptions import DeepISLESError, MissingInputError
+from stroke_deepisles_demo.inference.deepisles import (
+    DeepISLESResult,
+    find_prediction_mask,
+    run_deepisles_on_folder,
+    validate_input_folder,
+)
+
+
+class TestValidateInputFolder:
+    """Tests for validate_input_folder."""
+
+    def test_succeeds_with_required_files(self, temp_dir: Path) -> None:
+        """Returns paths when required files exist."""
+        (temp_dir / "dwi.nii.gz").touch()
+        (temp_dir / "adc.nii.gz").touch()
+
+        dwi, adc, flair = validate_input_folder(temp_dir)
+
+        assert dwi == temp_dir / "dwi.nii.gz"
+        assert adc == temp_dir / "adc.nii.gz"
+        assert flair is None
+
+    def test_includes_flair_when_present(self, temp_dir: Path) -> None:
+        """Returns FLAIR path when present."""
+        (temp_dir / "dwi.nii.gz").touch()
+        (temp_dir / "adc.nii.gz").touch()
+        (temp_dir / "flair.nii.gz").touch()
+
+        _dwi, _adc, flair = validate_input_folder(temp_dir)
+
+        assert flair == temp_dir / "flair.nii.gz"
+
+    def test_raises_when_dwi_missing(self, temp_dir: Path) -> None:
+        """Raises MissingInputError when DWI is missing."""
+        (temp_dir / "adc.nii.gz").touch()
+
+        with pytest.raises(MissingInputError, match="dwi"):
+            validate_input_folder(temp_dir)
+
+    def test_raises_when_adc_missing(self, temp_dir: Path) -> None:
+        """Raises MissingInputError when ADC is missing."""
+        (temp_dir / "dwi.nii.gz").touch()
+
+        with pytest.raises(MissingInputError, match="adc"):
+            validate_input_folder(temp_dir)
+
+
+class TestFindPredictionMask:
+    """Tests for find_prediction_mask."""
+
+    def test_finds_prediction_file(self, temp_dir: Path) -> None:
+        """Finds prediction.nii.gz in output directory."""
+        results_dir = temp_dir / "results"
+        results_dir.mkdir()
+        pred_file = results_dir / "prediction.nii.gz"
+        pred_file.touch()
+
+        result = find_prediction_mask(temp_dir)
+
+        assert result == pred_file
+
+    def test_finds_alternate_name(self, temp_dir: Path) -> None:
+        """Finds alternate named prediction files."""
+        results_dir = temp_dir / "results"
+        results_dir.mkdir()
+        pred_file = results_dir / "pred.nii.gz"
+        pred_file.touch()
+
+        result = find_prediction_mask(temp_dir)
+
+        assert result == pred_file
+
+    def test_falls_back_to_any_nifti(self, temp_dir: Path) -> None:
+        """Falls back to any .nii.gz file if standard names not found."""
+        results_dir = temp_dir / "results"
+        results_dir.mkdir()
+        pred_file = results_dir / "some_output.nii.gz"
+        pred_file.touch()
+
+        result = find_prediction_mask(temp_dir)
+
+        assert result == pred_file
+
+    def test_raises_when_no_prediction(self, temp_dir: Path) -> None:
+        """Raises DeepISLESError when no prediction found."""
+        results_dir = temp_dir / "results"
+        results_dir.mkdir()
+
+        with pytest.raises(DeepISLESError, match="prediction"):
+            find_prediction_mask(temp_dir)
+
+    def test_raises_when_results_dir_missing(self, temp_dir: Path) -> None:
+        """Raises DeepISLESError when results directory missing."""
+        with pytest.raises(DeepISLESError, match="prediction"):
+            find_prediction_mask(temp_dir)
+
+
+class TestRunDeepIslesOnFolder:
+    """Tests for run_deepisles_on_folder."""
+
+    @pytest.fixture
+    def valid_input_dir(self, temp_dir: Path) -> Path:
+        """Create a valid input directory with required files."""
+        (temp_dir / "dwi.nii.gz").touch()
+        (temp_dir / "adc.nii.gz").touch()
+        return temp_dir
+
+    def test_validates_input_files(self, temp_dir: Path) -> None:
+        """Validates input files before running Docker."""
+        # Missing required files
+        with pytest.raises(MissingInputError):
+            run_deepisles_on_folder(temp_dir)
+
+    def test_calls_docker_with_correct_image(self, valid_input_dir: Path) -> None:
+        """Calls Docker with DeepISLES image."""
+        with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
+            mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
+            with (
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
+                ),
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
+                ) as mock_find,
+            ):
+                mock_find.return_value = valid_input_dir / "results" / "pred.nii.gz"
+                run_deepisles_on_folder(valid_input_dir)
+
+            # Check image name
+            call_args = mock_run.call_args
+            assert call_args.args[0] == "isleschallenge/deepisles"
+
+    def test_passes_fast_flag(self, valid_input_dir: Path) -> None:
+        """Passes --fast True when fast=True."""
+        with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
+            mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
+            with (
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
+                ),
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
+                ) as mock_find,
+            ):
+                mock_find.return_value = valid_input_dir / "results" / "pred.nii.gz"
+
+                run_deepisles_on_folder(valid_input_dir, fast=True)
+
+            # Check --fast in command
+            call_kwargs = mock_run.call_args.kwargs
+            command = call_kwargs.get("command", [])
+            assert "--fast" in command
+            assert "True" in command
+
+    def test_includes_flair_when_present(self, valid_input_dir: Path) -> None:
+        """Includes FLAIR in command when present."""
+        (valid_input_dir / "flair.nii.gz").touch()
+
+        with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
+            mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
+            with (
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
+                ),
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
+                ) as mock_find,
+            ):
+                mock_find.return_value = valid_input_dir / "results" / "pred.nii.gz"
+
+                run_deepisles_on_folder(valid_input_dir)
+
+            call_kwargs = mock_run.call_args.kwargs
+            command = call_kwargs.get("command", [])
+            assert "--flair_file_name" in command
+            assert "flair.nii.gz" in command
+
+    def test_raises_on_docker_failure(self, valid_input_dir: Path) -> None:
+        """Raises DeepISLESError when Docker returns non-zero."""
+        with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
+            mock_run.return_value = MagicMock(exit_code=1, stdout="", stderr="Segmentation fault")
+            with (
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
+                ),
+                pytest.raises(DeepISLESError, match="failed"),
+            ):
+                run_deepisles_on_folder(valid_input_dir)
+
+    def test_returns_result_with_prediction_path(self, valid_input_dir: Path) -> None:
+        """Returns DeepISLESResult with prediction path."""
+        with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
+            mock_run.return_value = MagicMock(
+                exit_code=0, stdout="", stderr="", elapsed_seconds=10.0
+            )
+            with (
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
+                ),
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
+                ) as mock_find,
+            ):
+                expected_path = valid_input_dir / "results" / "prediction.nii.gz"
+                mock_find.return_value = expected_path
+
+                result = run_deepisles_on_folder(valid_input_dir)
+
+            assert isinstance(result, DeepISLESResult)
+            assert result.prediction_path == expected_path
+
+    def test_passes_volume_mounts(self, valid_input_dir: Path, temp_dir: Path) -> None:
+        """Passes correct volume mounts to Docker."""
+        # Create a separate output directory
+        output_dir = temp_dir / "output"
+        output_dir.mkdir()
+
+        with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
+            mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
+            with (
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
+                ),
+                patch(
+                    "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
+                ) as mock_find,
+            ):
+                mock_find.return_value = output_dir / "results" / "pred.nii.gz"
+
+                run_deepisles_on_folder(valid_input_dir, output_dir=output_dir)
+
+            call_kwargs = mock_run.call_args.kwargs
+            volumes = call_kwargs.get("volumes", {})
+            # Should have input and output mounts (2 separate directories)
+            assert len(volumes) == 2
+            # Values should be container paths
+            assert "/input" in volumes.values()
+            assert "/output" in volumes.values()
+
+
+@pytest.mark.integration
+@pytest.mark.slow
+class TestDeepIslesIntegration:
+    """Integration tests requiring real Docker and DeepISLES image."""
+
+    def test_real_inference(self, synthetic_case_files: dict[str, object]) -> None:
+        """Run actual DeepISLES inference on synthetic data."""
+        # This test requires:
+        # 1. Docker available
+        # 2. isleschallenge/deepisles image pulled
+        # 3. GPU (optional but recommended)
+        #
+        # Run with: pytest -m integration
+        import tempfile
+
+        from stroke_deepisles_demo.data.staging import stage_case_for_deepisles
+
+        # Create a separate staging directory
+        with tempfile.TemporaryDirectory() as staging_dir:
+            # Stage the synthetic files to the new directory
+            staged = stage_case_for_deepisles(
+                synthetic_case_files,  # type: ignore[arg-type]
+                Path(staging_dir),
+            )
+
+            # Run inference
+            result = run_deepisles_on_folder(
+                staged.input_dir,
+                fast=True,
+                gpu=False,  # Might not have GPU in CI
+                timeout=600,
+            )
+
+            # Verify output exists
+            assert result.prediction_path.exists()
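Taken together, these unit tests fix the wrapper's contract: validate that DWI and ADC exist (FLAIR optional), mount the input and output folders at `/input` and `/output`, run the `isleschallenge/deepisles` image, raise on a non-zero exit, and return the discovered mask. A compressed sketch of that flow follows, for orientation only; the tests assert only `--fast` and `--flair_file_name`, so the remaining structure and the `run_container` keywords are assumptions, not the shipped `deepisles.py`:

```python
from __future__ import annotations

from pathlib import Path

from stroke_deepisles_demo.core.exceptions import DeepISLESError
from stroke_deepisles_demo.inference.deepisles import find_prediction_mask, validate_input_folder
from stroke_deepisles_demo.inference.docker import run_container


def run_deepisles_sketch(input_dir: Path, output_dir: Path | None = None, *, fast: bool = False) -> Path:
    """Illustrative control flow only; mirrors the unit-test contract, not the real wrapper."""
    dwi, adc, flair = validate_input_folder(input_dir)  # raises MissingInputError early
    # Only --fast and --flair_file_name appear in the test assertions; anything else is assumed.
    command = ["--fast", str(fast)]
    if flair is not None:
        command += ["--flair_file_name", flair.name]
    volumes = {input_dir: "/input"}
    if output_dir is not None:
        volumes[output_dir] = "/output"  # separate host dir for results, as in test_passes_volume_mounts
    result = run_container(
        "isleschallenge/deepisles",  # image name asserted by the tests
        command=command,
        volumes=volumes,
    )
    if result.exit_code != 0:
        raise DeepISLESError(f"DeepISLES container failed: {result.stderr}")
    return find_prediction_mask(output_dir or input_dir)  # looks under <dir>/results/ per the tests
```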
tests/inference/test_docker.py
ADDED
@@ -0,0 +1,202 @@
+"""Tests for Docker utilities."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from stroke_deepisles_demo.core.exceptions import DockerNotAvailableError
+from stroke_deepisles_demo.inference.docker import (
+    build_docker_command,
+    check_docker_available,
+    ensure_docker_available,
+    run_container,
+)
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+
+class TestCheckDockerAvailable:
+    """Tests for check_docker_available."""
+
+    def test_returns_true_when_docker_responds(self) -> None:
+        """Returns True when 'docker info' succeeds."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=0)
+
+            result = check_docker_available()
+
+            assert result is True
+
+    def test_returns_false_when_docker_not_found(self) -> None:
+        """Returns False when docker command not found."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.side_effect = FileNotFoundError()
+
+            result = check_docker_available()
+
+            assert result is False
+
+    def test_returns_false_when_daemon_not_running(self) -> None:
+        """Returns False when docker daemon not running."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=1)
+
+            result = check_docker_available()
+
+            assert result is False
+
+
+class TestEnsureDockerAvailable:
+    """Tests for ensure_docker_available."""
+
+    def test_raises_when_docker_not_available(self) -> None:
+        """Raises DockerNotAvailableError when Docker not available."""
+        with (
+            patch(
+                "stroke_deepisles_demo.inference.docker.check_docker_available",
+                return_value=False,
+            ),
+            pytest.raises(DockerNotAvailableError),
+        ):
+            ensure_docker_available()
+
+    def test_no_error_when_docker_available(self) -> None:
+        """No exception when Docker is available."""
+        with patch(
+            "stroke_deepisles_demo.inference.docker.check_docker_available",
+            return_value=True,
+        ):
+            ensure_docker_available()  # Should not raise
+
+
+class TestBuildDockerCommand:
+    """Tests for build_docker_command."""
+
+    def test_basic_command(self) -> None:
+        """Builds basic docker run command."""
+        cmd = build_docker_command("myimage:latest")
+
+        assert cmd[0] == "docker"
+        assert "run" in cmd
+        assert "myimage:latest" in cmd
+
+    def test_includes_rm_flag(self) -> None:
+        """Includes --rm when remove=True."""
+        cmd = build_docker_command("myimage", remove=True)
+
+        assert "--rm" in cmd
+
+    def test_excludes_rm_flag(self) -> None:
+        """Excludes --rm when remove=False."""
+        cmd = build_docker_command("myimage", remove=False)
+
+        assert "--rm" not in cmd
+
+    def test_includes_gpu_flag(self) -> None:
+        """Includes --gpus all when gpu=True."""
+        cmd = build_docker_command("myimage", gpu=True)
+
+        assert "--gpus" in cmd
+        gpu_index = cmd.index("--gpus")
+        assert cmd[gpu_index + 1] == "all"
+
+    def test_volume_mounts(self, temp_dir: Path) -> None:
+        """Includes volume mounts."""
+        volumes = {temp_dir: "/data"}
+        cmd = build_docker_command("myimage", volumes=volumes)
+
+        assert "-v" in cmd
+        # Find the volume argument
+        v_index = cmd.index("-v")
+        assert f"{temp_dir}:/data" in cmd[v_index + 1]
+
+    def test_custom_command(self) -> None:
+        """Appends custom command arguments."""
+        cmd = build_docker_command("myimage", command=["--input", "/data", "--fast", "True"])
+
+        assert "--input" in cmd
+        assert "--fast" in cmd
+
+    def test_environment_variables(self) -> None:
+        """Includes environment variables."""
+        env = {"MY_VAR": "value", "OTHER": "123"}
+        cmd = build_docker_command("myimage", environment=env)
+
+        assert "-e" in cmd
+        # Check both vars are present
+        cmd_str = " ".join(cmd)
+        assert "MY_VAR=value" in cmd_str
+        assert "OTHER=123" in cmd_str
+
+
+class TestRunContainer:
+    """Tests for run_container."""
+
+    def test_calls_subprocess_with_built_command(self) -> None:
+        """Calls subprocess.run with built command."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=0, stdout="output", stderr="")
+            with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
+                run_container("myimage")
+
+            mock_run.assert_called_once()
+
+    def test_returns_result_with_exit_code(self) -> None:
+        """Returns DockerRunResult with correct exit code."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=42, stdout="out", stderr="err")
+            with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
+                result = run_container("myimage")
+
+            assert result.exit_code == 42
+
+    def test_captures_stdout_stderr(self) -> None:
+        """Captures stdout and stderr from container."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=0, stdout="hello", stderr="warning")
+            with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
+                result = run_container("myimage")
+
+            assert result.stdout == "hello"
+            assert result.stderr == "warning"
+
+    def test_respects_timeout(self) -> None:
+        """Passes timeout to subprocess."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
+            with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
+                run_container("myimage", timeout=60.0)
+
+            call_kwargs = mock_run.call_args.kwargs
+            assert call_kwargs.get("timeout") == 60.0
+
+    def test_tracks_elapsed_time(self) -> None:
+        """Tracks elapsed time in result."""
+        with patch("subprocess.run") as mock_run:
+            mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
+            with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
+                result = run_container("myimage")
+
+            # Should have some elapsed time (even if small)
+            assert result.elapsed_seconds >= 0
+
+
+@pytest.mark.integration
+class TestDockerIntegration:
+    """Integration tests requiring real Docker."""
+
+    def test_docker_actually_available(self) -> None:
+        """Docker is actually available on this system."""
+        # This test only runs with -m integration
+        assert check_docker_available() is True
+
+    def test_can_run_hello_world(self) -> None:
+        """Can run docker hello-world container."""
+        result = run_container("hello-world", timeout=60.0)
+
+        assert result.exit_code == 0
+        assert "Hello from Docker!" in result.stdout
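The `TestBuildDockerCommand` cases above effectively specify the argv layout. One assembly order that satisfies every assertion; the keyword signature here is inferred from the tests, not copied from `docker.py`:

```python
from __future__ import annotations

from pathlib import Path


def build_docker_command_sketch(
    image: str,
    *,
    remove: bool = True,
    gpu: bool = False,
    volumes: dict[Path, str] | None = None,
    environment: dict[str, str] | None = None,
    command: list[str] | None = None,
) -> list[str]:
    """Assemble a `docker run` argv; a reconstruction from the tests, not the module."""
    cmd = ["docker", "run"]
    if remove:
        cmd.append("--rm")  # auto-remove the container on exit
    if gpu:
        cmd += ["--gpus", "all"]  # expose all GPUs (requires the NVIDIA container runtime)
    for host, container in (volumes or {}).items():
        cmd += ["-v", f"{host}:{container}"]  # bind-mount host path into the container
    for key, value in (environment or {}).items():
        cmd += ["-e", f"{key}={value}"]
    cmd.append(image)  # image name comes after all docker-level flags
    cmd.extend(command or [])  # container arguments (e.g. --fast True) go last
    return cmd
```

Keeping the builder a pure list-of-strings function is what makes these unit tests cheap: no subprocess, no Docker daemon, just assertions on the argv.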