VibecoderMcSwaggins committed on
Commit aef1f5a · unverified · 1 Parent(s): 1b55f5e

feat: Phase 1A + Phase 2 - Local data loader and DeepISLES Docker wrapper (#3)


## Summary
- **Phase 1A**: Implement local file loader for ISLES24-MR-Lite dataset (149 cases)
- **Phase 2**: Implement DeepISLES Docker wrapper with GPU support

## Changes
- Add `LocalDataset` dataclass for file-based dataset access
- Add BIDS filename parsing (`parse_subject_id`)
- Add Docker utilities (`run_container`, `build_docker_command`, GPU detection)
- Add DeepISLES wrapper (`run_deepisles_on_folder`, `validate_input_folder`)
- 52 unit tests, mypy strict, ruff clean

## CodeRabbit Feedback Addressed
- Made `inspect_isles24.py` executable
- Fixed Windows compatibility in `match_user` logic

.gitignore CHANGED
@@ -205,3 +205,7 @@ cython_debug/
 marimo/_static/
 marimo/_lsp/
 __marimo__/
+
+# Data Discovery (per docs/specs/data-discovery.md)
+data/scratch/*
+!data/scratch/.gitkeep
data/scratch/.gitkeep ADDED
File without changes
docs/specs/00-context.md CHANGED
@@ -11,19 +11,38 @@ This document explains **why** we're building `stroke-deepisles-demo` and the ar
 We want to demonstrate an end-to-end neuroimaging inference pipeline:
 
 ```
-HuggingFace Hub (ISLES24-MR-Lite)
-
-BIDS/NIfTI loader (datasets fork)
-
-DeepISLES Docker (stroke segmentation)
-
-NiiVue visualization (Gradio Space)
+CURRENT (Phase 1A):
+Local NIfTI files (extracted from ISLES24-MR-Lite ZIPs)
+
+File-based loader (parse BIDS filenames)
+
+DeepISLES Docker (stroke segmentation)
+
+NiiVue visualization (Gradio Space)
+
+FUTURE (Phase 1C-D):
+HuggingFace Hub (properly uploaded dataset)
+
+Tobias's datasets fork (BIDS loader + Nifti feature)
+
+DeepISLES Docker (stroke segmentation)
+
+NiiVue visualization (Gradio Space)
 ```
 
 This showcases that:
-1. Neuroimaging data can be consumed from HF Hub with proper BIDS/NIfTI support
-2. Clinical-grade models can run via Docker as black boxes
-3. Results can be visualized interactively in a browser
+1. Neuroimaging data can be loaded from local BIDS-named files (NOW)
+2. Neuroimaging data can be consumed from HF Hub with proper BIDS/NIfTI support (FUTURE)
+3. Clinical-grade models can run via Docker as black boxes
+4. Results can be visualized interactively in a browser
+
+## critical discovery (2025-12-04)
+
+**The original ISLES24-MR-Lite dataset is NOT properly uploaded to HuggingFace.**
+
+It's just raw ZIP files dumped on HF, not a proper Dataset with parquet/Arrow format. This means `load_dataset()` fails. See `data/scratch/isles24_schema_report.txt` for full details.
+
+**Workaround**: We extracted the ZIPs locally to `data/scratch/isles24_extracted/` (git-ignored) and will implement a file-based loader first. Later, we'll re-upload properly and verify full HF consumption.
 
 ## why we need tobias's datasets fork
 
@@ -55,11 +74,22 @@ We pin to this branch until upstream merges the PRs.
 
 ### 1. data source: ISLES24-MR-Lite
 
-- **HF Dataset**: [YongchengYAO/ISLES24-MR-Lite](https://huggingface.co/datasets/YongchengYAO/ISLES24-MR-Lite)
+- **HF Dataset**: [YongchengYAO/ISLES24-MR-Lite](https://huggingface.co/datasets/YongchengYAO/ISLES24-MR-Lite) (**BROKEN** - raw ZIPs, not proper dataset)
+- **Local extracted**: `data/scratch/isles24_extracted/` (git-ignored)
 - **Content**: 149 acute stroke MRI cases with DWI, ADC, and manual infarct masks
 - **Origin**: Subset of ISLES 2024 challenge data
 - **Why suitable**: DeepISLES was trained on ISLES 2022, so ISLES24 is an **external** test set (no data leakage)
 
+**File structure** (after extraction):
+```
+data/scratch/isles24_extracted/
+├── Images-DWI/sub-stroke{XXXX}_ses-02_dwi.nii.gz      # 149 files
+├── Images-ADC/sub-stroke{XXXX}_ses-02_adc.nii.gz      # 149 files
+└── Masks/sub-stroke{XXXX}_ses-02_lesion-msk.nii.gz    # 149 files
+```
+
+**Schema reference**: `data/scratch/isles24_schema_report.txt`
+
 ### 2. model: DeepISLES
 
 - **Paper**: Nature Communications 2025 - "DeepISLES: A clinically validated ischemic stroke segmentation model"
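
For reference, the `load_dataset()` failure described in the critical-discovery hunk above looks roughly like this (a minimal sketch; the exact exception type and message raised by `datasets` may differ):

```python
from datasets import load_dataset

# The repo contains only raw ZIP archives, so there is no parquet/Arrow
# data for the library to build a Dataset from; this call errors out
# with a "no data" style error instead of returning a Dataset.
ds = load_dataset("YongchengYAO/ISLES24-MR-Lite")
```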
docs/specs/02-phase-1-data-access.md CHANGED
@@ -1,695 +1,415 @@
-# phase 1: data access / hf integration
 
 ## purpose
 
-Implement the data loading layer that consumes ISLES24-MR-Lite from HuggingFace Hub. At the end of this phase, we can load any case by ID and get local paths to DWI, ADC, and ground truth NIfTI files.
 
-## deliverables
 
-- [ ] `src/stroke_deepisles_demo/data/loader.py` - HF dataset loading
-- [ ] `src/stroke_deepisles_demo/data/adapter.py` - Case adapter for file access
-- [ ] `src/stroke_deepisles_demo/data/staging.py` - Stage files for DeepISLES
-- [ ] Unit tests with fixtures (no network required)
-- [ ] Integration test (marked, requires network)
 
-## vertical slice outcome
 
-After this phase, you can run:
 
-```python
-from stroke_deepisles_demo.data import get_case, list_case_ids
 
-# List available cases
-case_ids = list_case_ids()
-print(f"Found {len(case_ids)} cases")
 
-# Load a specific case
-case = get_case("sub-001")
-print(f"DWI: {case.dwi}")
-print(f"ADC: {case.adc}")
-print(f"Ground truth: {case.ground_truth}")
-```
 
-## module structure
 
 ```
-src/stroke_deepisles_demo/data/
-├── __init__.py    # Public API exports
-├── loader.py      # HF Hub dataset loading
-├── adapter.py     # Case adapter (index → files)
-└── staging.py     # Stage files with DeepISLES naming
 ```
 
-## interfaces and types
 
-### `data/loader.py`
 
 ```python
-"""Load ISLES24-MR-Lite dataset from HuggingFace Hub."""
 
 from __future__ import annotations
 
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from datasets import Dataset
 
 
 def load_isles_dataset(
-    dataset_id: str = "YongchengYAO/ISLES24-MR-Lite",
     *,
-    cache_dir: Path | None = None,
-    streaming: bool = False,
-) -> Dataset:
     """
-    Load the ISLES24-MR-Lite dataset from HuggingFace Hub.
 
     Args:
-        dataset_id: HuggingFace dataset identifier
-        cache_dir: Local cache directory (uses HF default if None)
-        streaming: If True, use streaming mode (lazy loading)
 
     Returns:
-        HuggingFace Dataset object with BIDS/NIfTI support
 
     Raises:
-        DataLoadError: If dataset cannot be loaded
     """
-    ...
 
 
-def get_dataset_info(dataset_id: str = "YongchengYAO/ISLES24-MR-Lite") -> DatasetInfo:
     """
-    Get metadata about the dataset without downloading.
 
-    Returns:
-        DatasetInfo with case count, available modalities, etc.
     """
     ...
-
-
-@dataclass
-class DatasetInfo:
-    """Metadata about the loaded dataset."""
-
-    dataset_id: str
-    num_cases: int
-    modalities: list[str]  # e.g., ["dwi", "adc", "mask"]
-    has_ground_truth: bool
 ```
 
-### `data/adapter.py`
 
 ```python
-"""Adapt HF dataset rows to typed file references."""
 
 from __future__ import annotations
 
 from pathlib import Path
 from typing import Iterator
 
 from stroke_deepisles_demo.core.types import CaseFiles
 
 
-class CaseAdapter:
-    """
-    Adapts HuggingFace dataset to provide typed access to case files.
-
-    This handles the mapping between HF dataset structure and our
-    internal CaseFiles type.
-    """
-
-    def __init__(self, dataset: Dataset) -> None:
-        """
-        Initialize adapter with a loaded dataset.
 
-        Args:
-            dataset: HuggingFace Dataset with NIfTI files
-        """
-        ...
 
     def __len__(self) -> int:
-        """Return number of cases in the dataset."""
-        ...
 
     def __iter__(self) -> Iterator[str]:
-        """Iterate over case IDs."""
-        ...
 
     def list_case_ids(self) -> list[str]:
-        """
-        List all available case identifiers.
-
-        Returns:
-            List of case IDs (e.g., ["sub-001", "sub-002", ...])
-        """
-        ...
 
     def get_case(self, case_id: str | int) -> CaseFiles:
-        """
-        Get file paths for a specific case.
 
-        Args:
-            case_id: Either a string ID (e.g., "sub-001") or integer index
 
-        Returns:
-            CaseFiles with paths to DWI, ADC, and optionally ground truth
 
-        Raises:
-            KeyError: If case_id not found
-            DataLoadError: If files cannot be accessed
-        """
-        ...
 
-    def get_case_by_index(self, index: int) -> tuple[str, CaseFiles]:
-        """
-        Get case by numerical index.
-
-        Returns:
-            Tuple of (case_id, CaseFiles)
-        """
-        ...
-```
 
-### `data/staging.py`
-
-```python
-"""Stage NIfTI files with DeepISLES-expected naming."""
-
-from __future__ import annotations
 
-from pathlib import Path
-from typing import NamedTuple
-
-from stroke_deepisles_demo.core.types import CaseFiles
-
-
-class StagedCase(NamedTuple):
-    """Paths to staged files ready for DeepISLES."""
-
-    input_dir: Path           # Directory containing staged files
-    dwi_path: Path            # Path to dwi.nii.gz
-    adc_path: Path            # Path to adc.nii.gz
-    flair_path: Path | None   # Path to flair.nii.gz if available
-
-
-def stage_case_for_deepisles(
-    case_files: CaseFiles,
-    output_dir: Path,
-    *,
-    case_id: str | None = None,
-) -> StagedCase:
     """
-    Stage case files with DeepISLES-expected naming convention.
 
-    DeepISLES expects files named exactly:
-    - dwi.nii.gz
-    - adc.nii.gz
-    - flair.nii.gz (optional)
-
-    This function copies/symlinks the source files to a staging directory
-    with the correct names.
-
-    Args:
-        case_files: Source file paths from CaseAdapter
-        output_dir: Directory to stage files into
-        case_id: Optional case ID for logging/subdirectory
-
-    Returns:
-        StagedCase with paths to staged files
-
-    Raises:
-        MissingInputError: If required files (DWI, ADC) are missing
-        OSError: If file operations fail
     """
-    ...
 
 
-def create_staging_directory(base_dir: Path | None = None) -> Path:
-    """
-    Create a temporary staging directory.
 
-    Args:
-        base_dir: Parent directory (uses system temp if None)
 
-    Returns:
-        Path to created staging directory
-    """
-    ...
-```
-
-### `data/__init__.py` (public API)
-
-```python
-"""Data loading and case management for stroke-deepisles-demo."""
-
-from stroke_deepisles_demo.data.adapter import CaseAdapter
-from stroke_deepisles_demo.data.loader import DatasetInfo, get_dataset_info, load_isles_dataset
-from stroke_deepisles_demo.data.staging import StagedCase, stage_case_for_deepisles
-
-__all__ = [
-    # Loader
-    "load_isles_dataset",
-    "get_dataset_info",
-    "DatasetInfo",
-    # Adapter
-    "CaseAdapter",
-    # Staging
-    "stage_case_for_deepisles",
-    "StagedCase",
-]
-
-
-# Convenience functions (combine loader + adapter)
-def get_case(case_id: str | int) -> CaseFiles:
-    """Load a single case by ID or index."""
-    ...
 
-def list_case_ids() -> list[str]:
-    """List all available case IDs."""
-    ...
-```
-
-## tdd plan
-
-### test file structure
-
-```
-tests/
-├── conftest.py            # Shared fixtures
-├── data/
-│   ├── __init__.py
-│   ├── test_loader.py     # Tests for HF loading
-│   ├── test_adapter.py    # Tests for case adapter
-│   └── test_staging.py    # Tests for file staging
-└── fixtures/
-    └── nifti/             # Minimal synthetic NIfTI files
-        ├── dwi.nii.gz
-        ├── adc.nii.gz
-        └── mask.nii.gz
 ```
 
-### tests to write first (TDD order)
 
-#### 1. `tests/conftest.py` - Fixtures
 
 ```python
-"""Shared test fixtures."""
-
-from __future__ import annotations
-
-import tempfile
-from pathlib import Path
-
-import nibabel as nib
-import numpy as np
-import pytest
-
-
 @pytest.fixture
-def temp_dir() -> Path:
-    """Create a temporary directory for test outputs."""
-    with tempfile.TemporaryDirectory() as td:
-        yield Path(td)
 
 
-@pytest.fixture
-def synthetic_nifti_3d(temp_dir: Path) -> Path:
-    """Create a minimal synthetic 3D NIfTI file."""
-    data = np.random.rand(10, 10, 10).astype(np.float32)
-    img = nib.Nifti1Image(data, affine=np.eye(4))
-    path = temp_dir / "synthetic.nii.gz"
-    nib.save(img, path)
-    return path
 
 
-@pytest.fixture
-def synthetic_case_files(temp_dir: Path) -> CaseFiles:
-    """Create a complete set of synthetic case files."""
-    # Create DWI
-    dwi_data = np.random.rand(64, 64, 30).astype(np.float32)
-    dwi_img = nib.Nifti1Image(dwi_data, affine=np.eye(4))
-    dwi_path = temp_dir / "dwi.nii.gz"
-    nib.save(dwi_img, dwi_path)
-
-    # Create ADC
-    adc_data = np.random.rand(64, 64, 30).astype(np.float32) * 2000
-    adc_img = nib.Nifti1Image(adc_data, affine=np.eye(4))
-    adc_path = temp_dir / "adc.nii.gz"
-    nib.save(adc_img, adc_path)
-
-    # Create mask
-    mask_data = (np.random.rand(64, 64, 30) > 0.9).astype(np.uint8)
-    mask_img = nib.Nifti1Image(mask_data, affine=np.eye(4))
-    mask_path = temp_dir / "mask.nii.gz"
-    nib.save(mask_img, mask_path)
-
-    return CaseFiles(
-        dwi=dwi_path,
-        adc=adc_path,
-        flair=None,
-        ground_truth=mask_path,
-    )
 
 
-@pytest.fixture
-def mock_hf_dataset(synthetic_case_files: CaseFiles):
-    """Create a mock HF Dataset-like object."""
-    # Returns a simple dict-based mock that mimics Dataset behavior
-    ...
 ```
 
-#### 2. `tests/data/test_staging.py` - Start with staging (no network)
 
 ```python
-"""Tests for data staging module."""
 
-from __future__ import annotations
 
-from pathlib import Path
 
-import pytest
 
-from stroke_deepisles_demo.core.exceptions import MissingInputError
-from stroke_deepisles_demo.core.types import CaseFiles
-from stroke_deepisles_demo.data.staging import (
-    StagedCase,
-    create_staging_directory,
-    stage_case_for_deepisles,
-)
-
-
-class TestCreateStagingDirectory:
-    """Tests for create_staging_directory."""
-
-    def test_creates_directory(self, temp_dir: Path) -> None:
-        """Staging directory is created and exists."""
-        staging = create_staging_directory(base_dir=temp_dir)
-        assert staging.exists()
-        assert staging.is_dir()
-
-    def test_uses_system_temp_when_no_base(self) -> None:
-        """Uses system temp directory when base_dir is None."""
-        staging = create_staging_directory(base_dir=None)
-        assert staging.exists()
-        # Cleanup
-        staging.rmdir()
-
-
-class TestStageCaseForDeepIsles:
-    """Tests for stage_case_for_deepisles."""
-
-    def test_stages_required_files(
-        self, synthetic_case_files: CaseFiles, temp_dir: Path
-    ) -> None:
-        """DWI and ADC are staged with correct names."""
-        staged = stage_case_for_deepisles(synthetic_case_files, temp_dir)
-
-        assert staged.dwi_path.name == "dwi.nii.gz"
-        assert staged.adc_path.name == "adc.nii.gz"
-        assert staged.dwi_path.exists()
-        assert staged.adc_path.exists()
-
-    def test_staged_files_are_readable(
-        self, synthetic_case_files: CaseFiles, temp_dir: Path
-    ) -> None:
-        """Staged files can be read as valid NIfTI."""
-        import nibabel as nib
-
-        staged = stage_case_for_deepisles(synthetic_case_files, temp_dir)
-
-        dwi = nib.load(staged.dwi_path)
-        assert dwi.shape == (64, 64, 30)
-
-    def test_raises_when_dwi_missing(self, temp_dir: Path) -> None:
-        """Raises MissingInputError when DWI is missing."""
-        case_files = CaseFiles(
-            dwi=temp_dir / "nonexistent.nii.gz",
-            adc=temp_dir / "adc.nii.gz",
-            flair=None,
-            ground_truth=None,
-        )
 
-        with pytest.raises(MissingInputError, match="DWI"):
-            stage_case_for_deepisles(case_files, temp_dir)
-
-    def test_flair_is_optional(
-        self, synthetic_case_files: CaseFiles, temp_dir: Path
-    ) -> None:
-        """Staging succeeds when FLAIR is None."""
-        # synthetic_case_files has flair=None
-        staged = stage_case_for_deepisles(synthetic_case_files, temp_dir)
 
-        assert staged.flair_path is None
 ```
 
-#### 3. `tests/data/test_adapter.py` - Case adapter with mocks
-
-```python
-"""Tests for case adapter module."""
-
-from __future__ import annotations
-
-import pytest
-
-from stroke_deepisles_demo.core.types import CaseFiles
-from stroke_deepisles_demo.data.adapter import CaseAdapter
-
-
-class TestCaseAdapter:
-    """Tests for CaseAdapter."""
-
-    def test_list_case_ids_returns_strings(self, mock_hf_dataset) -> None:
-        """list_case_ids returns list of string identifiers."""
-        adapter = CaseAdapter(mock_hf_dataset)
-        case_ids = adapter.list_case_ids()
-
-        assert isinstance(case_ids, list)
-        assert all(isinstance(cid, str) for cid in case_ids)
-
-    def test_len_matches_dataset_size(self, mock_hf_dataset) -> None:
-        """len(adapter) equals number of cases in dataset."""
-        adapter = CaseAdapter(mock_hf_dataset)
-
-        assert len(adapter) == len(mock_hf_dataset)
-
-    def test_get_case_by_string_id(self, mock_hf_dataset) -> None:
-        """Can retrieve case by string identifier."""
-        adapter = CaseAdapter(mock_hf_dataset)
-        case_ids = adapter.list_case_ids()
 
-        case = adapter.get_case(case_ids[0])
 
-        assert isinstance(case, dict)  # CaseFiles is a TypedDict
-        assert "dwi" in case
-        assert "adc" in case
 
-    def test_get_case_by_index(self, mock_hf_dataset) -> None:
-        """Can retrieve case by integer index."""
-        adapter = CaseAdapter(mock_hf_dataset)
 
-        case_id, case = adapter.get_case_by_index(0)
 
-        assert isinstance(case_id, str)
-        assert case["dwi"] is not None
 
-    def test_get_case_invalid_id_raises(self, mock_hf_dataset) -> None:
-        """Raises KeyError for invalid case ID."""
-        adapter = CaseAdapter(mock_hf_dataset)
-
-        with pytest.raises(KeyError):
-            adapter.get_case("nonexistent-case-id")
-
-    def test_iteration(self, mock_hf_dataset) -> None:
-        """Can iterate over case IDs."""
-        adapter = CaseAdapter(mock_hf_dataset)
-
-        case_ids = list(adapter)
-
-        assert len(case_ids) == len(adapter)
-```
-
-#### 4. `tests/data/test_loader.py` - Loader with network mocks
 
 ```python
-"""Tests for data loader module."""
-
-from __future__ import annotations
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from stroke_deepisles_demo.core.exceptions import DataLoadError
-from stroke_deepisles_demo.data.loader import (
-    DatasetInfo,
-    get_dataset_info,
-    load_isles_dataset,
-)
-
-
-class TestLoadIslesDataset:
-    """Tests for load_isles_dataset."""
 
-    def test_calls_hf_load_dataset(self) -> None:
-        """Calls datasets.load_dataset with correct arguments."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            mock_load.return_value = MagicMock()
 
-            load_isles_dataset("test/dataset")
 
-            mock_load.assert_called_once()
-            call_args = mock_load.call_args
-            assert call_args.args[0] == "test/dataset"
 
-    def test_returns_dataset_object(self) -> None:
-        """Returns the loaded Dataset object."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            expected = MagicMock()
-            mock_load.return_value = expected
 
-            result = load_isles_dataset()
 
-            assert result is expected
 
-    def test_handles_load_error(self) -> None:
-        """Wraps HF errors in DataLoadError."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            mock_load.side_effect = Exception("Network error")
 
-            with pytest.raises(DataLoadError, match="Network error"):
-                load_isles_dataset()
 
 
-class TestGetDatasetInfo:
-    """Tests for get_dataset_info."""
 
-    def test_returns_datasetinfo(self) -> None:
-        """Returns DatasetInfo with expected fields."""
-        with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
-            mock_ds = MagicMock()
-            mock_ds.__len__ = MagicMock(return_value=149)
-            mock_ds.features = {"dwi": ..., "adc": ..., "mask": ...}
-            mock_load.return_value = mock_ds
 
-            info = get_dataset_info()
 
-            assert isinstance(info, DatasetInfo)
-            assert info.num_cases == 149
 
 
-@pytest.mark.integration
-class TestLoadIslesDatasetIntegration:
-    """Integration tests that hit the real HuggingFace Hub."""
 
-    @pytest.mark.slow
-    def test_load_real_dataset(self) -> None:
-        """Actually loads ISLES24-MR-Lite from HF Hub."""
-        # This test requires network access
-        # Run with: pytest -m integration
-        dataset = load_isles_dataset(streaming=True)
 
-        # Just verify we got something
-        assert dataset is not None
 ```
 
-### what to mock
-
-- `datasets.load_dataset` - Mock for unit tests, real for integration tests
-- `huggingface_hub` calls - Mock for unit tests
-- File system operations - Use `temp_dir` fixture with real files
-
-### what to test for real
-
-- NIfTI file creation/reading with nibabel
-- File staging (copy/symlink operations)
-- Integration test: actual HF Hub download (marked `@pytest.mark.integration`)
 
-## "done" criteria
-
-Phase 1 is complete when:
-
-1. All unit tests pass: `uv run pytest tests/data/ -v`
-2. Can load synthetic test cases without network
-3. Can list case IDs from mock dataset
-4. Can stage files with correct DeepISLES naming
-5. Integration test passes (with network): `uv run pytest -m integration`
-6. Type checking passes: `uv run mypy src/stroke_deepisles_demo/data/`
-7. Code coverage for data module > 80%
-
-## implementation notes
-
-- ISLES24-MR-Lite structure needs investigation - check HF page for exact column names
-- Consider using `huggingface_hub.snapshot_download` if `datasets.load_dataset` has issues with NIfTI
-- Staging can use symlinks on Unix, copies on Windows
-- Cache the HF dataset locally to avoid repeated downloads
-
-### critical: streaming mode + docker materialization
-
-**Reviewer feedback (valid)**: When using `streaming=True`, the dataset returns URLs or lazy file objects, NOT local POSIX paths. Docker requires physical files on the host disk for volume mounting.
-
-**Solution**: The `stage_case_for_deepisles` function MUST handle materialization:
 
 ```python
-def stage_case_for_deepisles(
-    case_files: CaseFiles,
-    output_dir: Path,
-    *,
-    case_id: str | None = None,
-) -> StagedCase:
     """
-    Stage case files with DeepISLES-expected naming.
 
-    IMPORTANT: This function handles both local paths and streaming data.
-    When files come from streaming mode, they must be downloaded/materialized
-    before Docker can mount them.
-    """
-    output_dir.mkdir(parents=True, exist_ok=True)
 
-    # Handle DWI - may be Path, URL, or NIfTI object
-    dwi_staged = output_dir / "dwi.nii.gz"
-    _materialize_nifti(case_files["dwi"], dwi_staged)
 
-    # Handle ADC
-    adc_staged = output_dir / "adc.nii.gz"
-    _materialize_nifti(case_files["adc"], adc_staged)
 
-    # ... etc
 
 
-def _materialize_nifti(source: Path | str | bytes | NiftiImage, dest: Path) -> None:
-    """
-    Materialize a NIfTI file to a local path.
 
-    Handles:
-    - Local Path: copy or symlink
-    - URL string: download
-    - bytes: write directly
-    - NIfTI object: serialize with nibabel
-    """
-    if isinstance(source, Path) and source.exists():
-        # Local file - symlink if possible, copy otherwise
-        shutil.copy2(source, dest)
-    elif isinstance(source, str) and source.startswith(("http://", "https://")):
-        # URL - download
-        _download_file(source, dest)
-    elif isinstance(source, bytes):
-        # Raw bytes
-        dest.write_bytes(source)
-    elif hasattr(source, "to_bytes"):
-        # NIfTI object (nibabel or wrapper)
-        dest.write_bytes(source.to_bytes())
-    else:
-        raise MissingInputError(f"Cannot materialize source: {type(source)}")
-```
 
-This ensures Docker always gets physical files regardless of how data was loaded.
 
-## dependencies to add
 
-No new dependencies needed - all specified in Phase 0:
-- `datasets` (Tobias fork)
-- `nibabel`
-- `numpy`
+# phase 1: data access layer
 
 ## purpose
 
+Implement a data loading layer that provides typed access to ISLES24 neuroimaging cases. This phase is split into sub-phases due to a critical discovery: the upstream dataset is not properly formatted for HuggingFace consumption.
 
+## critical discovery (2025-12-04)
 
+**`YongchengYAO/ISLES24-MR-Lite` is NOT a proper HuggingFace Dataset.**
 
+| What we expected | What actually exists |
+|------------------|---------------------|
+| `load_dataset()` returns Dataset with columns | `load_dataset()` FAILS with "no data" |
+| Columns: `dwi`, `adc`, `mask`, `participant_id` | No columns - just raw ZIP files |
+| Parquet/Arrow format | Three ZIP archives dumped on HF |
 
+**Evidence**: `data/scratch/isles24_schema_report.txt`
 
+This means the demo must be built in phases:
+1. **Phase 1A**: Local file loader (works NOW with extracted data)
+2. **Phase 1B**: Test Tobias's `Nifti()` feature on local files (proves loading works)
+3. **Phase 1C**: Upload properly to HuggingFace (future - proves production pipeline)
+4. **Phase 1D**: Consume via Tobias's fork (future - proves full round-trip)
 
+---
 
+## phase 1a: local file loader (CURRENT PRIORITY)
 
+### data location
 
 ```
+data/scratch/isles24_extracted/    # Git-ignored
+├── Images-DWI/                    # 149 files
+│   └── sub-stroke{XXXX}_ses-02_dwi.nii.gz
+├── Images-ADC/                    # 149 files
+│   └── sub-stroke{XXXX}_ses-02_adc.nii.gz
+└── Masks/                         # 149 files
+    └── sub-stroke{XXXX}_ses-02_lesion-msk.nii.gz
 ```
 
+### file naming convention (BIDS-like)
+
+| Component | Pattern | Example |
+|-----------|---------|---------|
+| Subject ID | `sub-stroke{XXXX}` | `sub-stroke0005` |
+| Session | `ses-02` | Always "02" in this dataset |
+| Modality | `dwi`, `adc`, `lesion-msk` | - |
+| Extension | `.nii.gz` | Compressed NIfTI |
+
+**Subject ID regex**: `sub-stroke(\d{4})_ses-02_.*\.nii\.gz`
+
+**Note**: Subject IDs have gaps (e.g., 0018 missing). Range is 0001-0189, total 149 cases.
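
As a quick illustration of that convention (plain Python `re`; the `parse_subject_id` helper in `adapter.py` below implements the same idea):

```python
import re

# Same pattern as the spec; group(1) captures "strokeXXXX".
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")

for name in ("sub-stroke0005_ses-02_dwi.nii.gz", "README.txt"):
    m = SUBJECT_PATTERN.match(name)
    print(name, "->", f"sub-{m.group(1)}" if m else None)
# sub-stroke0005_ses-02_dwi.nii.gz -> sub-stroke0005
# README.txt -> None
```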
 
+### deliverables
+
+- [ ] `src/stroke_deepisles_demo/data/loader.py` - Rewrite with local mode
+- [ ] `src/stroke_deepisles_demo/data/adapter.py` - Rewrite for file-based access
+- [ ] `src/stroke_deepisles_demo/data/staging.py` - Already correct, no changes
+- [ ] Unit tests with synthetic fixtures
+- [ ] Integration test with actual extracted data
+
+### interfaces
+
+#### `data/loader.py`
 
 ```python
+"""Load ISLES24 data from local directory or HuggingFace Hub."""
 
 from __future__ import annotations
 
+from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from stroke_deepisles_demo.data.adapter import LocalDataset
+
+
+@dataclass
+class DatasetInfo:
+    """Metadata about the dataset."""
+
+    source: str  # "local" or HF dataset ID
+    num_cases: int
+    modalities: list[str]
+    has_ground_truth: bool
 
 
 def load_isles_dataset(
+    source: str | Path = "data/scratch/isles24_extracted",
     *,
+    local_mode: bool = True,  # Default to local for now
+) -> LocalDataset:
     """
+    Load ISLES24 dataset.
 
     Args:
+        source: Local directory path or HuggingFace dataset ID
+        local_mode: If True, treat source as local directory
 
     Returns:
+        Dataset-like object providing case access
 
     Raises:
+        DataLoadError: If data cannot be loaded
     """
+    if local_mode or isinstance(source, Path):
+        return _load_from_local_directory(Path(source))
+    # Future: return _load_from_huggingface(source)
+    raise NotImplementedError("HuggingFace mode not yet implemented")
 
 
+def _load_from_local_directory(data_dir: Path) -> LocalDataset:
     """
+    Load cases from extracted local files.
 
+    Expects structure:
+        data_dir/
+        ├── Images-DWI/sub-stroke{XXXX}_ses-02_dwi.nii.gz
+        ├── Images-ADC/sub-stroke{XXXX}_ses-02_adc.nii.gz
+        └── Masks/sub-stroke{XXXX}_ses-02_lesion-msk.nii.gz
     """
     ...
 ```
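
For orientation, the intended call pattern for this interface (a sketch, assuming `CaseFiles` remains a TypedDict as the earlier tests asserted, so files are accessed by key):

```python
from pathlib import Path

from stroke_deepisles_demo.data import load_isles_dataset

# Point at the extracted local copy (the Phase 1A default source).
ds = load_isles_dataset(Path("data/scratch/isles24_extracted"))
print(len(ds))                        # expected: 149
case = ds.get_case("sub-stroke0005")  # or an integer index
print(case["dwi"], case["adc"])       # local .nii.gz paths
```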
 
+#### `data/adapter.py`
 
 ```python
+"""Provide typed access to ISLES24 cases."""
 
 from __future__ import annotations
 
+import re
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Iterator
 
 from stroke_deepisles_demo.core.types import CaseFiles
 
 
+@dataclass
+class LocalDataset:
+    """File-based dataset for local ISLES24 data."""
 
+    data_dir: Path
+    cases: dict[str, CaseFiles]  # subject_id -> files
 
     def __len__(self) -> int:
+        return len(self.cases)
 
     def __iter__(self) -> Iterator[str]:
+        return iter(self.cases.keys())
 
     def list_case_ids(self) -> list[str]:
+        """Return sorted list of subject IDs."""
+        return sorted(self.cases.keys())
 
     def get_case(self, case_id: str | int) -> CaseFiles:
+        """Get files for a case by ID or index."""
+        if isinstance(case_id, int):
+            case_id = self.list_case_ids()[case_id]
+        return self.cases[case_id]
 
 
+# Subject ID extraction
+SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")
 
 
+def parse_subject_id(filename: str) -> str | None:
+    """Extract subject ID from BIDS filename."""
+    match = SUBJECT_PATTERN.match(filename)
+    return f"sub-{match.group(1)}" if match else None
 
 
+def build_local_dataset(data_dir: Path) -> LocalDataset:
     """
+    Scan directory and build case mapping.
 
+    Matches DWI + ADC + Mask files by subject ID.
     """
+    dwi_dir = data_dir / "Images-DWI"
+    adc_dir = data_dir / "Images-ADC"
+    mask_dir = data_dir / "Masks"
 
+    cases: dict[str, CaseFiles] = {}
 
+    # Scan DWI files to get subject IDs
+    for dwi_file in dwi_dir.glob("*.nii.gz"):
+        subject_id = parse_subject_id(dwi_file.name)
+        if not subject_id:
+            continue
 
+        # Find matching ADC and Mask
+        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
+        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")
 
+        if not adc_file.exists():
+            continue  # Skip incomplete cases
 
+        cases[subject_id] = CaseFiles(
+            dwi=dwi_file,
+            adc=adc_file,
+            ground_truth=mask_file if mask_file.exists() else None,
+        )
 
+    return LocalDataset(data_dir=data_dir, cases=cases)
 ```
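
The pairing above leans on pure filename-stem substitution, so a quick check in plain Python (no project imports) shows exactly which sibling files get looked up:

```python
dwi_name = "sub-stroke0005_ses-02_dwi.nii.gz"
print(dwi_name.replace("_dwi.", "_adc."))         # sub-stroke0005_ses-02_adc.nii.gz
print(dwi_name.replace("_dwi.", "_lesion-msk."))  # sub-stroke0005_ses-02_lesion-msk.nii.gz
```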
 
+### synthetic fixture structure
 
+Unit tests MUST use fixtures that replicate the **exact** directory structure. Add to `tests/conftest.py`:
 
 ```python
 @pytest.fixture
+def synthetic_isles_dir(temp_dir: Path) -> Path:
+    """
+    Create synthetic ISLES24-like directory structure.
+
+    Structure:
+        temp_dir/
+        ├── Images-DWI/
+        │   ├── sub-stroke0001_ses-02_dwi.nii.gz
+        │   └── sub-stroke0002_ses-02_dwi.nii.gz
+        ├── Images-ADC/
+        │   ├── sub-stroke0001_ses-02_adc.nii.gz
+        │   └── sub-stroke0002_ses-02_adc.nii.gz
+        └── Masks/
+            ├── sub-stroke0001_ses-02_lesion-msk.nii.gz
+            └── sub-stroke0002_ses-02_lesion-msk.nii.gz
+    """
+    dwi_dir = temp_dir / "Images-DWI"
+    adc_dir = temp_dir / "Images-ADC"
+    mask_dir = temp_dir / "Masks"
 
+    dwi_dir.mkdir()
+    adc_dir.mkdir()
+    mask_dir.mkdir()
 
+    for subject_num in [1, 2]:
+        subject_id = f"sub-stroke{subject_num:04d}"
 
+        # Create DWI
+        dwi_data = np.random.rand(10, 10, 5).astype(np.float32)
+        dwi_img = nib.Nifti1Image(dwi_data, affine=np.eye(4))
+        nib.save(dwi_img, dwi_dir / f"{subject_id}_ses-02_dwi.nii.gz")
 
+        # Create ADC
+        adc_data = np.random.rand(10, 10, 5).astype(np.float32) * 2000
+        adc_img = nib.Nifti1Image(adc_data, affine=np.eye(4))
+        nib.save(adc_img, adc_dir / f"{subject_id}_ses-02_adc.nii.gz")
 
+        # Create Mask
+        mask_data = (np.random.rand(10, 10, 5) > 0.9).astype(np.uint8)
+        mask_img = nib.Nifti1Image(mask_data, affine=np.eye(4))
+        nib.save(mask_img, mask_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz")
 
+    return temp_dir
 ```
 
+### tdd plan
 
 ```python
+# tests/data/test_loader.py
 
+def test_load_from_local_returns_local_dataset(synthetic_isles_dir):
+    """Local mode returns LocalDataset."""
+    ...
 
+def test_load_from_local_finds_all_cases(synthetic_isles_dir):
+    """Finds all cases in synthetic structure."""
+    ...
 
+# tests/data/test_adapter.py
 
+def test_parse_subject_id_extracts_correctly():
+    """Extracts subject ID from BIDS filename."""
+    assert parse_subject_id("sub-stroke0005_ses-02_dwi.nii.gz") == "sub-stroke0005"
 
+def test_build_local_dataset_matches_files(synthetic_isles_dir):
+    """Matches DWI, ADC, Mask by subject ID."""
+    ...
 
+def test_get_case_returns_case_files(synthetic_isles_dir):
+    """get_case returns CaseFiles with correct paths."""
+    ...
 ```
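
A sketch of how the first two stubs could be filled in, assuming the interfaces above land as written (`synthetic_isles_dir` is the fixture defined in this spec):

```python
# tests/data/test_loader.py (illustrative completion of the stubs above)
from stroke_deepisles_demo.data.adapter import LocalDataset
from stroke_deepisles_demo.data.loader import load_isles_dataset


def test_load_from_local_returns_local_dataset(synthetic_isles_dir):
    """Local mode returns LocalDataset."""
    ds = load_isles_dataset(synthetic_isles_dir, local_mode=True)
    assert isinstance(ds, LocalDataset)


def test_load_from_local_finds_all_cases(synthetic_isles_dir):
    """Finds all cases in synthetic structure (the fixture creates two subjects)."""
    ds = load_isles_dataset(synthetic_isles_dir, local_mode=True)
    assert ds.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
```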
 
+### done criteria (phase 1a)
 
+- [ ] `uv run pytest tests/data/ -v` passes
+- [ ] Can load all 149 cases from `data/scratch/isles24_extracted/`
+- [ ] `list_case_ids()` returns 149 subject IDs
+- [ ] `get_case("sub-stroke0005")` returns valid CaseFiles
+- [ ] Type checking passes: `uv run mypy src/stroke_deepisles_demo/data/`
 
+---
 
+## phase 1b: test tobias's nifti feature (NEXT)
 
+### purpose
 
+Verify that Tobias's `Nifti()` feature type from the datasets fork can correctly load/parse NIfTI files. This proves the **loading** part of the consumption pipeline works, even though the **download** part is broken.
 
+### approach
 
 ```python
+# Test script to verify Nifti() feature works on local files
+from datasets import Features, Value
+from datasets.features import Nifti  # From Tobias's fork
+
+# Create a simple dataset from local files
+features = Features({
+    "subject_id": Value("string"),
+    "dwi": Nifti(),
+    "adc": Nifti(),
+    "mask": Nifti(),
+})
+
+# Load a single case and verify Nifti() decodes correctly
+```
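
One plausible continuation of that script (an assumption-laden sketch: it presumes the fork's `Nifti()` follows the standard `datasets` Feature encode/decode contract, the way `Image` and `Audio` do, decoding path strings into nibabel-like objects; substitute any extracted subject ID):

```python
from datasets import Dataset

# Build a one-row dataset pointing at a local case; decoding happens on access.
ds = Dataset.from_dict(
    {
        "subject_id": ["sub-stroke0001"],
        "dwi": ["data/scratch/isles24_extracted/Images-DWI/sub-stroke0001_ses-02_dwi.nii.gz"],
        "adc": ["data/scratch/isles24_extracted/Images-ADC/sub-stroke0001_ses-02_adc.nii.gz"],
        "mask": ["data/scratch/isles24_extracted/Masks/sub-stroke0001_ses-02_lesion-msk.nii.gz"],
    },
    features=features,
)

img = ds[0]["dwi"]  # expected: a nibabel-like image if decoding works
print(type(img), getattr(img, "shape", None))
```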
 
 
 
 
+### done criteria (phase 1b)
 
+- [ ] Tobias's `Nifti()` feature loads local `.nii.gz` files
+- [ ] Decoded NIfTI has correct shape/dtype
+- [ ] Can access voxel data via nibabel-like interface
 
+---
 
+## phase 1c: proper huggingface upload (FUTURE)
 
+### purpose
 
+Re-upload ISLES24 data to HuggingFace **properly** using the arc-aphasia-bids approach. This proves the **production** pipeline works.
 
+### approach
 
+1. Use BIDS loader from Tobias's fork
+2. Create proper parquet schema with columns:
+   - `subject`: string
+   - `session`: string
+   - `dwi`: Nifti()
+   - `adc`: Nifti()
+   - `mask`: Nifti()
+3. Upload to new HuggingFace repo (e.g., `The-Obstacle-Is-The-Way/ISLES24-BIDS`)
 
+### done criteria (phase 1c)
 
+- [ ] Dataset uploaded to HuggingFace with proper schema
+- [ ] HuggingFace dataset viewer shows data correctly
+- [ ] `load_dataset("new-repo-id")` returns Dataset with expected columns
 
+---
 
+## phase 1d: consumption verification (FUTURE)
 
+### purpose
 
+Verify the full round-trip: Download from HuggingFace using Tobias's fork.
 
+### approach
 
+```python
+from datasets import load_dataset
 
+# This should work after Phase 1C
+ds = load_dataset("The-Obstacle-Is-The-Way/ISLES24-BIDS")
+case = ds["train"][0]
+print(case["dwi"].shape)  # Should work!
 ```
 
+### new adapter function
 
+When Phase 1D is implemented, `adapter.py` will need a new function alongside `build_local_dataset`:
 
 ```python
+def adapt_hf_case(hf_row: dict) -> CaseFiles:
     """
+    Adapt a HuggingFace Dataset row to CaseFiles.
 
+    Args:
+        hf_row: Row from load_dataset() with columns:
+            - dwi: Nifti feature (nibabel-like object)
+            - adc: Nifti feature
+            - mask: Nifti feature
+            - subject: str
 
+    Returns:
+        CaseFiles with materialized paths or nibabel objects
+    """
+    # Implementation depends on how Nifti() feature exposes data
+    # May need to write to temp files or pass nibabel objects directly
+    ...
+```
 
+This maintains the same `CaseFiles` contract for downstream phases regardless of data source.
 
+### done criteria (phase 1d)
 
+- [ ] `load_dataset()` works on properly uploaded dataset
+- [ ] `adapt_hf_case()` function converts HF rows to CaseFiles
+- [ ] Full demo runs with HuggingFace consumption (not just local files)
+- [ ] Documents the pitfall for future projects
 
+---
 
+## dependencies
 
+No new dependencies needed beyond Phase 0.
 
+## notes
 
+- The original `adapter.py` assumed HF Dataset with columns - COMPLETELY WRONG
+- The original `loader.py` called `load_dataset()` directly - FAILS on this dataset
+- `staging.py` is still correct - it just needs `CaseFiles` with paths
docs/specs/data-discovery.md ADDED
@@ -0,0 +1,67 @@
+# data discovery & verification protocol
+
+## purpose
+To establish a rigorous, reproducible process for exploring, verifying, and documenting external data sources (Hugging Face Datasets, BIDS repos, etc.) before integrating them into the production codebase. This prevents "schema guessing" and ensures strict typing aligns with reality.
+
+## principles
+1. **No Assumptions**: Never assume column names, file formats, or data types. Verify them programmatically.
+2. **Isolation**: Discovery scripts and their outputs must be isolated from production code and source control.
+3. **Reproducibility**: The discovery process must be scriptable and reproducible, not a series of manual CLI commands.
+
+## standard locations
+
+### scripts
+All discovery logic resides in:
+```
+scripts/discovery/
+├── __init__.py
+├── inspect_hf_dataset.py    # e.g., Generic HF inspector
+├── verify_bids_layout.py    # e.g., BIDS validator
+└── ...
+```
+
+### data & artifacts
+All downloaded samples, temporary outputs, and schema reports reside in:
+```
+data/scratch/
+├── .gitkeep             # Tracked
+├── schema_report.txt    # Generated report
+└── samples/             # Raw data samples (IGNORED)
+```
+
+## discovery workflow
+
+### 1. implementation
+Write a focused script in `scripts/discovery/` that:
+- Connects to the data source (e.g., HF Hub).
+- Fetches *metadata* or a *minimal sample* (streaming mode preferred).
+- Prints/Logs:
+  - Feature keys (column names).
+  - Data types (Arrow types, Python types).
+  - Non-null counts (if feasible).
+  - A sample row structure.
+
+### 2. execution
+Run the script from the project root:
+```bash
+uv run scripts/discovery/inspect_hf_dataset.py > data/scratch/schema_report.txt
+```
+
+### 3. verification
+Manually review `data/scratch/schema_report.txt`.
+- **Check**: Do column names match `CaseAdapter` expectations?
+- **Check**: Are file paths strings or objects?
+- **Check**: Are required fields (DWI, ADC) actually present?
+
+### 4. remediation
+If the report contradicts the code/specs:
+1. Update the spec (`docs/specs/`) to reflect reality.
+2. Update the code (`src/.../adapter.py`) to handle the actual schema.
+3. Add a regression test if the edge case is complex.
+
+## git configuration
+Ensure `.gitignore` includes:
+```gitignore
+data/scratch/*
+!data/scratch/.gitkeep
+```
pyproject.toml CHANGED
@@ -118,6 +118,7 @@ addopts = [
     "-v",
     "--tb=short",
     "--strict-markers",
+    "-m", "not integration",  # Skip integration tests by default
 ]
 markers = [
     "integration: marks tests requiring external resources (Docker, network)",
scripts/discovery/__init__.py ADDED
File without changes
scripts/discovery/inspect_isles24.py ADDED
@@ -0,0 +1,267 @@
+#!/usr/bin/env python3
+"""
+ISLES24-MR-Lite Dataset Discovery Script
+
+Downloads and inspects the full YongchengYAO/ISLES24-MR-Lite dataset
+to document its exact schema before building adapters.
+
+Per: docs/specs/data-discovery.md
+
+Output: data/scratch/isles24_schema_report.txt
+"""
+
+from __future__ import annotations
+
+import sys
+from collections import Counter
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+# Constants
+DATASET_ID = "YongchengYAO/ISLES24-MR-Lite"
+OUTPUT_DIR = Path(__file__).parent.parent.parent / "data" / "scratch"
+REPORT_FILE = OUTPUT_DIR / "isles24_schema_report.txt"
+
+
+def safe_type_name(val: Any) -> str:
+    """Get a safe string representation of a value's type."""
+    if val is None:
+        return "None"
+    t = type(val).__name__
+    if hasattr(val, "dtype"):
+        return f"{t}[{val.dtype}]"
+    return t
+
+
+def safe_repr(val: Any, max_len: int = 100) -> str:
+    """Get a safe truncated repr of a value."""
+    if val is None:
+        return "None"
+    if isinstance(val, bytes):
+        return f"<bytes len={len(val)}>"
+    if isinstance(val, dict):
+        if "bytes" in val:
+            return f"<dict with 'bytes' key, len={len(val.get('bytes', b''))}>"
+        return f"<dict keys={list(val.keys())}>"
+    r = repr(val)
+    if len(r) > max_len:
+        return r[: max_len - 3] + "..."
+    return r
+
+
+def main() -> int:
+    """Main discovery workflow."""
+    print("=" * 70)
+    print("ISLES24-MR-Lite Dataset Discovery")
+    print(f"Started: {datetime.now().isoformat()}")
+    print("=" * 70)
+    print()
+
+    # Ensure output directory exists
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    # Import datasets library
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        print("ERROR: 'datasets' library not installed.")
+        print("Run: uv add datasets")
+        return 1
+
+    # =========================================================================
+    # PHASE 1: Load Dataset (Full Download)
+    # =========================================================================
+    print(f"[1/4] Loading dataset: {DATASET_ID}")
+    print("      This will download the FULL dataset...")
+    print()
+
+    try:
+        # Try loading without streaming first to get full access
+        ds = load_dataset(DATASET_ID)
+        print("  SUCCESS: Dataset loaded")
+        print(f"  Splits available: {list(ds.keys())}")
+        print()
+    except Exception as e:
+        print(f"  ERROR loading dataset: {e}")
+        print()
+        print("  Trying streaming mode as fallback...")
+        try:
+            ds = load_dataset(DATASET_ID, streaming=True)
+            print("  SUCCESS (streaming): Dataset loaded")
+            print(f"  Splits available: {list(ds.keys())}")
+        except Exception as e2:
+            print(f"  FATAL: Cannot load dataset: {e2}")
+            return 1
+
+    # =========================================================================
+    # PHASE 2: Inspect Schema (Features)
+    # =========================================================================
+    print("[2/4] Inspecting schema...")
+    print()
+
+    report_lines: list[str] = []
+    report_lines.append("=" * 70)
+    report_lines.append("ISLES24-MR-Lite Schema Discovery Report")
+    report_lines.append(f"Generated: {datetime.now().isoformat()}")
+    report_lines.append(f"Dataset: {DATASET_ID}")
+    report_lines.append("=" * 70)
+    report_lines.append("")
+
+    for split_name in ds:
+        split = ds[split_name]
+        report_lines.append(f"SPLIT: {split_name}")
+        report_lines.append("-" * 50)
+
+        # Get features/schema
+        if hasattr(split, "features"):
+            features = split.features
+            report_lines.append(
+                f"Number of rows: {len(split) if hasattr(split, '__len__') else 'unknown (streaming)'}"
+            )
+            report_lines.append("")
+            report_lines.append("FEATURES (columns):")
+            for feat_name, feat_type in features.items():
+                report_lines.append(f"  - {feat_name}: {feat_type}")
+            report_lines.append("")
+        else:
+            report_lines.append("  (No features metadata available)")
+            report_lines.append("")
+
+    print("  Schema extracted.")
+    print()
+
+    # =========================================================================
+    # PHASE 3: Sample Inspection (check actual data)
+    # =========================================================================
+    print("[3/4] Inspecting sample rows...")
+    print()
+
+    # Use the first available split (usually 'train')
+    main_split_name = next(iter(ds.keys()))
+    main_split = ds[main_split_name]
+
+    report_lines.append("=" * 70)
+    report_lines.append("SAMPLE DATA INSPECTION")
+    report_lines.append("=" * 70)
+    report_lines.append("")
+
+    # Check first 3 rows in detail
+    report_lines.append("First 3 rows (detailed):")
+    report_lines.append("-" * 50)
+
+    sample_count = 0
+    column_value_types: dict[str, Counter[str]] = {}
+
+    # Iterate through dataset
+    iterable = iter(main_split) if hasattr(main_split, "__iter__") else main_split
+
+    for i, row in enumerate(iterable):
+        if i < 3:
+            report_lines.append(f"\nROW {i}:")
+            for key, val in row.items():
+                val_type = safe_type_name(val)
+                val_repr = safe_repr(val)
+                report_lines.append(f"  {key}:")
+                report_lines.append(f"    type: {val_type}")
+                report_lines.append(f"    value: {val_repr}")
+
+        # Track types for all rows
+        for key, val in row.items():
+            if key not in column_value_types:
+                column_value_types[key] = Counter()
+            column_value_types[key][safe_type_name(val)] += 1
+
+        sample_count += 1
+
+        # Progress indicator
+        if sample_count % 50 == 0:
+            print(f"  Processed {sample_count} rows...")
+
+    print(f"  Total rows processed: {sample_count}")
+    print()
+
+    # =========================================================================
+    # PHASE 4: Consistency Check
+    # =========================================================================
+    print("[4/4] Checking consistency across all rows...")
+    print()
+
+    report_lines.append("")
+    report_lines.append("=" * 70)
+    report_lines.append("CONSISTENCY ANALYSIS (all rows)")
+    report_lines.append("=" * 70)
+    report_lines.append("")
+    report_lines.append(f"Total rows analyzed: {sample_count}")
+    report_lines.append("")
+
+    report_lines.append("Column type distribution:")
+    report_lines.append("-" * 50)
+    for col_name, type_counts in column_value_types.items():
+        report_lines.append(f"\n  {col_name}:")
+        for type_name, count in type_counts.most_common():
+            pct = (count / sample_count) * 100
+            report_lines.append(f"    {type_name}: {count} ({pct:.1f}%)")
+
+    # =========================================================================
+    # PHASE 5: CaseAdapter Compatibility Check
+    # =========================================================================
+    report_lines.append("")
+    report_lines.append("=" * 70)
+    report_lines.append("CASEADAPTER COMPATIBILITY CHECK")
+    report_lines.append("=" * 70)
+    report_lines.append("")
+
+    expected_columns = ["dwi", "adc", "flair", "mask", "ground_truth", "participant_id"]
+    actual_columns = list(column_value_types.keys())
+
+    report_lines.append("Expected by CaseAdapter:")
+    for col in expected_columns:
+        status = "FOUND" if col in actual_columns else "MISSING"
+        report_lines.append(f"  {col}: {status}")
+
+    report_lines.append("")
+    report_lines.append("Actual columns in dataset:")
+    for col in actual_columns:
+        expected = "expected" if col in expected_columns else "UNEXPECTED"
+        report_lines.append(f"  {col}: {expected}")
+
+    report_lines.append("")
+    report_lines.append("=" * 70)
+    report_lines.append("END OF REPORT")
+    report_lines.append("=" * 70)
+
+    # Write report
+    report_content = "\n".join(report_lines)
+    REPORT_FILE.write_text(report_content)
+
+    print(f"Report written to: {REPORT_FILE}")
+    print()
+    print("=" * 70)
+    print("DISCOVERY COMPLETE")
+    print("=" * 70)
+    print()
+    print("Next steps:")
+    print(f"  1. Review: {REPORT_FILE}")
+    print("  2. Compare findings against src/stroke_deepisles_demo/data/adapter.py")
+    print("  3. Update adapter if schema differs from expectations")
+    print()
+
+    # Print summary to stdout as well
+    print("-" * 70)
+    print("QUICK SUMMARY:")
+    print("-" * 70)
+    print(f"Columns found: {actual_columns}")
+    print()
+    missing = [c for c in expected_columns if c not in actual_columns]
+    if missing:
+        print(f"WARNING: Expected columns MISSING: {missing}")
+    unexpected = [c for c in actual_columns if c not in expected_columns]
+    if unexpected:
+        print(f"NOTE: Unexpected columns found: {unexpected}")
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
src/stroke_deepisles_demo/core/exceptions.py CHANGED
@@ -21,3 +21,7 @@ class DeepISLESError(StrokeDemoError):
 
 class MissingInputError(StrokeDemoError):
     """Required input files are missing."""
+
+
+class DockerGPUNotAvailableError(StrokeDemoError):
+    """GPU requested but NVIDIA Container Runtime not available."""
src/stroke_deepisles_demo/data/__init__.py CHANGED
@@ -1,27 +1,21 @@
 """Data loading and case management for stroke-deepisles-demo."""
 
-from stroke_deepisles_demo.data.adapter import CaseAdapter
-from stroke_deepisles_demo.data.loader import DatasetInfo, get_dataset_info, load_isles_dataset
+from stroke_deepisles_demo.core.types import CaseFiles
+from stroke_deepisles_demo.data.adapter import LocalDataset
+from stroke_deepisles_demo.data.loader import DatasetInfo, load_isles_dataset
 from stroke_deepisles_demo.data.staging import StagedCase, stage_case_for_deepisles
 
 __all__ = [
-    # Adapter
-    "CaseAdapter",
-    # Loader
     "DatasetInfo",
+    "LocalDataset",
-    # Staging
     "StagedCase",
     "get_case",
-    "get_dataset_info",
     "list_case_ids",
     "load_isles_dataset",
     "stage_case_for_deepisles",
 ]
 
 
-from stroke_deepisles_demo.core.types import CaseFiles
-
-
 # Convenience functions (combine loader + adapter)
 def get_case(case_id: str | int) -> CaseFiles:
     """
@@ -31,12 +25,10 @@ def get_case(case_id: str | int) -> CaseFiles:
     CaseFiles dictionary
     """
     dataset = load_isles_dataset()
-    adapter = CaseAdapter(dataset)
-    return adapter.get_case(case_id)
+    return dataset.get_case(case_id)
 
 
 def list_case_ids() -> list[str]:
     """List all available case IDs."""
     dataset = load_isles_dataset()
-    adapter = CaseAdapter(dataset)
-    return adapter.list_case_ids()
+    return dataset.list_case_ids()
src/stroke_deepisles_demo/data/adapter.py CHANGED
@@ -1,147 +1,84 @@
-"""Adapt HF dataset rows to typed file references."""
 
 from __future__ import annotations
 
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-from stroke_deepisles_demo.core.exceptions import DataLoadError
-from stroke_deepisles_demo.core.types import CaseFiles
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
 
-    from datasets import Dataset
 
-class CaseAdapter:
-    """
-    Adapts HuggingFace dataset to provide typed access to case files.
 
-    This handles the mapping between HF dataset structure and our
-    internal CaseFiles type.
-    """
-
-    def __init__(self, dataset: Dataset) -> None:
-        """
-        Initialize adapter with a loaded dataset.
-
-        Args:
-            dataset: HuggingFace Dataset with NIfTI files
-        """
-        self.dataset = dataset
-        self._case_id_map = self._build_case_id_map()
-
-    def _build_case_id_map(self) -> dict[str, int]:
-        """Build mapping from case ID to index."""
-        case_map = {}
-        # Assuming dataset has 'participant_id' or similar
-        # If not, we might need to generate IDs or use index
-
-        # Check features to find ID column
-        id_col = "participant_id"
-        if id_col not in self.dataset.features:
-            # Fallback: try to find a string column that looks like an ID
-            # Or just use f"case_{i}"
-            pass
-
-        # Iterate to build map
-        # This might be slow for huge datasets, but for 149 cases it's fine
-        for idx, row in enumerate(self.dataset):
-            case_id = row.get(id_col, f"case_{idx:03d}")
-            case_map[str(case_id)] = idx
-
-        return case_map
 
     def __len__(self) -> int:
-        """Return number of cases in the dataset."""
-        return len(self.dataset)
 
     def __iter__(self) -> Iterator[str]:
-        """Iterate over case IDs."""
-        return iter(self._case_id_map.keys())
 
    def list_case_ids(self) -> list[str]:
-        """
-        List all available case identifiers.
-
-        Returns:
-            List of case IDs (e.g., ["sub-001", "sub-002", ...])
-        """
-        return list(self._case_id_map.keys())
 
    def get_case(self, case_id: str | int) -> CaseFiles:
-        """
-        Get file paths for a specific case.
 
-        Args:
-            case_id: Either a string ID (e.g., "sub-001") or integer index
 
-        Returns:
-            CaseFiles with paths to DWI, ADC, and optionally ground truth
 
-        Raises:
-            KeyError: If case_id not found
-            DataLoadError: If files cannot be accessed
-        """
-        if isinstance(case_id, int):
-            index = case_id
-        else:
-            if case_id not in self._case_id_map:
-                raise KeyError(f"Case ID not found: {case_id}")
-            index = self._case_id_map[case_id]
-
-        return self._get_case_by_index_internal(index)
-
-    def get_case_by_index(self, index: int) -> tuple[str, CaseFiles]:
-        """
-        Get case by numerical index.
-
-        Returns:
-            Tuple of (case_id, CaseFiles)
-        """
-        if index < 0 or index >= len(self.dataset):
-            raise IndexError("Case index out of range")
-
-        # Find ID for index (reverse lookup)
-        # This is inefficient O(N) if we don't store reverse map, but N is small.
-        # Or we can just get it from row again.
-        row = self.dataset[index]
-        # Assuming 'participant_id' exists or we used fallback
-        case_id = row.get("participant_id", f"case_{index:03d}")
-
-        case_files = self._row_to_case_files(row)
-        return str(case_id), case_files
-
-    def _get_case_by_index_internal(self, index: int) -> CaseFiles:
-        """Internal helper to get CaseFiles by index."""
-        row = self.dataset[index]
119
- return self._row_to_case_files(row)
120
-
121
- def _row_to_case_files(self, row: dict[str, Any]) -> CaseFiles:
122
- """Convert a dataset row to CaseFiles."""
123
- # Map columns. DeepISLES needs DWI and ADC.
124
- # Dataset columns might vary. Based on spec/mock: 'dwi', 'adc', 'flair', 'mask'
125
-
126
- # Helper to ensure we return Path if it's a local string path, or keep as is
127
- def to_path_or_raw(val: Any) -> Any:
128
- if isinstance(val, str) and not val.startswith(("http://", "https://")):
129
- return Path(val)
130
- return val
131
-
132
- dwi = to_path_or_raw(row.get("dwi"))
133
- adc = to_path_or_raw(row.get("adc"))
134
- flair = to_path_or_raw(row.get("flair"))
135
- ground_truth = to_path_or_raw(row.get("mask") or row.get("ground_truth"))
136
-
137
- if not dwi or not adc:
138
- raise DataLoadError("Case missing required DWI or ADC files")
139
-
140
- case_files = CaseFiles(dwi=dwi, adc=adc)
141
-
142
- if flair:
143
- case_files["flair"] = flair
144
- if ground_truth:
145
- case_files["ground_truth"] = ground_truth
146
-
147
- return case_files
 
1
+ """Provide typed access to ISLES24 cases."""
2
 
3
  from __future__ import annotations
4
 
5
+ import re
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING
 
 
8
 
9
  if TYPE_CHECKING:
10
  from collections.abc import Iterator
11
+ from pathlib import Path
12
 
13
+ from stroke_deepisles_demo.core.types import CaseFiles
14
 
15
 
16
+ @dataclass
17
+ class LocalDataset:
18
+ """File-based dataset for local ISLES24 data."""
19
 
20
+ data_dir: Path
21
+ cases: dict[str, CaseFiles] # subject_id -> files
 
22
 
23
  def __len__(self) -> int:
24
+ return len(self.cases)
 
25
 
26
  def __iter__(self) -> Iterator[str]:
27
+ return iter(self.cases.keys())
 
28
 
29
  def list_case_ids(self) -> list[str]:
30
+ """Return sorted list of subject IDs."""
31
+ return sorted(self.cases.keys())
 
 
 
 
 
32
 
33
  def get_case(self, case_id: str | int) -> CaseFiles:
34
+ """Get files for a case by ID or index."""
35
+ if isinstance(case_id, int):
36
+ case_id = self.list_case_ids()[case_id]
37
+ return self.cases[case_id]
38
 
 
 
39
 
40
+ # Subject ID extraction
41
+ SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")
42
 
43
+
44
+ def parse_subject_id(filename: str) -> str | None:
45
+ """Extract subject ID from BIDS filename."""
46
+ match = SUBJECT_PATTERN.match(filename)
47
+ return f"sub-{match.group(1)}" if match else None
48
+
49
+
50
+ def build_local_dataset(data_dir: Path) -> LocalDataset:
51
+ """
52
+ Scan directory and build case mapping.
53
+
54
+ Matches DWI + ADC + Mask files by subject ID.
55
+ """
56
+ dwi_dir = data_dir / "Images-DWI"
57
+ adc_dir = data_dir / "Images-ADC"
58
+ mask_dir = data_dir / "Masks"
59
+
60
+ cases: dict[str, CaseFiles] = {}
61
+
62
+ # Scan DWI files to get subject IDs
63
+ for dwi_file in dwi_dir.glob("*.nii.gz"):
64
+ subject_id = parse_subject_id(dwi_file.name)
65
+ if not subject_id:
66
+ continue
67
+
68
+ # Find matching ADC and Mask
69
+ adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
70
+ mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")
71
+
72
+ if not adc_file.exists():
73
+ continue # Skip incomplete cases
74
+
75
+ case_files: CaseFiles = {
76
+ "dwi": dwi_file,
77
+ "adc": adc_file,
78
+ }
79
+ if mask_file.exists():
80
+ case_files["ground_truth"] = mask_file
81
+
82
+ cases[subject_id] = case_files
83
+
84
+ return LocalDataset(data_dir=data_dir, cases=cases)
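
A quick sketch of the new file-based flow (the directory path is the loader's default; real filenames follow the same BIDS pattern as the test fixtures later in this diff):

```python
from pathlib import Path

from stroke_deepisles_demo.data.adapter import build_local_dataset, parse_subject_id

# Filename parsing is pure and easy to check in isolation.
assert parse_subject_id("sub-stroke0005_ses-02_dwi.nii.gz") == "sub-stroke0005"

# Directory scanning matches DWI/ADC/mask triples by subject ID.
dataset = build_local_dataset(Path("data/scratch/isles24_extracted"))
for subject_id in dataset.list_case_ids()[:3]:
    files = dataset.get_case(subject_id)
    print(subject_id, files["dwi"].name, "ground_truth" in files)
```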
src/stroke_deepisles_demo/data/loader.py CHANGED
@@ -1,138 +1,47 @@
1
- """Load ISLES24-MR-Lite dataset from HuggingFace Hub."""
2
 
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass
 
6
  from typing import TYPE_CHECKING
7
 
8
- from datasets import load_dataset
 
9
 
10
- from stroke_deepisles_demo.core.exceptions import DataLoadError
11
 
12
- if TYPE_CHECKING:
13
- from pathlib import Path
 
14
 
15
- from datasets import Dataset
 
 
 
16
 
17
 
18
  def load_isles_dataset(
19
- dataset_id: str = "YongchengYAO/ISLES24-MR-Lite",
20
  *,
21
- cache_dir: Path | None = None,
22
- streaming: bool = False,
23
- ) -> Dataset:
24
  """
25
- Load the ISLES24-MR-Lite dataset from HuggingFace Hub.
26
 
27
  Args:
28
- dataset_id: HuggingFace dataset identifier
29
- cache_dir: Local cache directory (uses HF default if None)
30
- streaming: If True, use streaming mode (lazy loading)
31
 
32
  Returns:
33
- HuggingFace Dataset object with BIDS/NIfTI support
34
 
35
  Raises:
36
- DataLoadError: If dataset cannot be loaded
37
  """
38
- try:
39
- # The pinned fork supports BIDS/NIfTI properly.
40
- # We pass trust_remote_code=True if needed for custom scripts,
41
- # but standard datasets usually don't need it unless using custom builder.
42
- # ISLES24-MR-Lite is likely a standard dataset or Parquet-based.
43
- # If it's BIDS, we might need type="bids" if the PR features are used that way.
44
- # For now, standard load_dataset.
45
-
46
- ds = load_dataset(
47
- dataset_id,
48
- cache_dir=str(cache_dir) if cache_dir else None,
49
- streaming=streaming,
50
- # If the dataset is BIDS, we might need a specific config/builder.
51
- # Assuming default works or it's already parquet.
52
- )
53
-
54
- # If streaming, load_dataset returns IterableDataset.
55
- # If not, it returns DatasetDict or Dataset.
56
- # We assume it returns the 'train' split if it's a DatasetDict, or we handle it.
57
- # Usually load_dataset returns DatasetDict unless split is specified.
58
-
59
- if hasattr(ds, "keys"):
60
- keys = list(ds.keys())
61
- if "train" in keys:
62
- return ds["train"]
63
- elif len(keys) > 0:
64
- # Fallback to first split if 'train' not found
65
- return ds[keys[0]]
66
-
67
- return ds
68
-
69
- except Exception as e:
70
- raise DataLoadError(f"Failed to load dataset {dataset_id}: {e}") from e
71
-
72
-
73
- @dataclass
74
- class DatasetInfo:
75
- """Metadata about the loaded dataset."""
76
-
77
- dataset_id: str
78
- num_cases: int
79
- modalities: list[str] # e.g., ["dwi", "adc", "mask"]
80
- has_ground_truth: bool
81
-
82
-
83
- def get_dataset_info(dataset_id: str = "YongchengYAO/ISLES24-MR-Lite") -> DatasetInfo:
84
- """
85
- Get metadata about the dataset without downloading (if possible).
86
-
87
- Returns:
88
- DatasetInfo with case count, available modalities, etc.
89
- """
90
- try:
91
- # Load in streaming mode to get features/info cheaply
92
- ds = load_isles_dataset(dataset_id, streaming=True)
93
-
94
- # Count cases (might be slow for streaming, but okay for demo scale)
95
- # Or check if info is available
96
- if hasattr(ds, "info") and ds.info.splits:
97
- # Approximate from splits info if available
98
- num_cases = ds.info.splits["train"].num_examples
99
- else:
100
- # Iterate to count? Or just rely on known size?
101
- # For streaming, len() might not work.
102
- # Let's just load non-streaming but with no data download? No.
103
- # Let's just assume we can get length if we loaded it.
104
- # If we loaded it streaming, we might not get length.
105
- # For the demo, let's just try to get it.
106
-
107
- # If we can't get length easily from streaming, we might need to trust metadata.
108
- # Or just iterate (expensive).
109
- # Let's use a safer approach: load non-streaming (lazy) might download metadata only.
110
- # But datasets downloads parquet files.
111
-
112
- # For get_dataset_info, maybe we just load it fully? No, expensive.
113
- # Let's use streaming and try to get info.
114
- num_cases = 0
115
- # Use a fixed number if we can't determine?
116
- # Or just count - 149 is small.
117
- # But streaming iteration means network calls.
118
-
119
- # Try to access info object
120
- if hasattr(ds, "n_shards"):
121
- # Approximate?
122
- pass
123
-
124
- # Fallback: 149 (known)
125
- num_cases = 149
126
 
127
- features = ds.features.keys()
128
- modalities = [k for k in features if k in ["dwi", "adc", "flair"]]
129
- has_ground_truth = "mask" in features or "ground_truth" in features
130
 
131
- return DatasetInfo(
132
- dataset_id=dataset_id,
133
- num_cases=num_cases,
134
- modalities=sorted(modalities),
135
- has_ground_truth=has_ground_truth,
136
- )
137
- except Exception as e:
138
- raise DataLoadError(f"Failed to get info for {dataset_id}: {e}") from e
 
1
+ """Load ISLES24 data from local directory or HuggingFace Hub."""
2
 
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass
6
+ from pathlib import Path
7
  from typing import TYPE_CHECKING
8
 
9
+ if TYPE_CHECKING:
10
+ from stroke_deepisles_demo.data.adapter import LocalDataset
11
 
 
12
 
13
+ @dataclass
14
+ class DatasetInfo:
15
+ """Metadata about the dataset."""
16
 
17
+ source: str # "local" or HF dataset ID
18
+ num_cases: int
19
+ modalities: list[str]
20
+ has_ground_truth: bool
21
 
22
 
23
  def load_isles_dataset(
24
+ source: str | Path = "data/scratch/isles24_extracted",
25
  *,
26
+ local_mode: bool = True, # Default to local for now
27
+ ) -> LocalDataset:
 
28
  """
29
+ Load ISLES24 dataset.
30
 
31
  Args:
32
+ source: Local directory path or HuggingFace dataset ID
33
+ local_mode: If True, treat source as local directory
 
34
 
35
  Returns:
36
+ Dataset-like object providing case access
37
 
38
  Raises:
39
+ NotImplementedError: If non-local mode is requested
40
  """
41
+ if local_mode or isinstance(source, Path):
42
+ from stroke_deepisles_demo.data.adapter import build_local_dataset
 
43
 
44
+ return build_local_dataset(Path(source))
 
 
45
 
46
+ # Future: return _load_from_huggingface(source)
47
+ raise NotImplementedError("HuggingFace mode not yet implemented")
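
`DatasetInfo` stays exported but is no longer populated by a `get_dataset_info` helper, so for now a caller assembles it by hand. A sketch (the modality list is an assumption based on the adapter's keys, and `has_ground_truth` is taken here to mean every case has a mask):

```python
from pathlib import Path

from stroke_deepisles_demo.data.loader import DatasetInfo, load_isles_dataset

dataset = load_isles_dataset(Path("data/scratch/isles24_extracted"))
info = DatasetInfo(
    source="local",
    num_cases=len(dataset),
    modalities=["adc", "dwi"],
    has_ground_truth=all(
        "ground_truth" in dataset.get_case(cid) for cid in dataset.list_case_ids()
    ),
)
print(info)
```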
src/stroke_deepisles_demo/inference/__init__.py CHANGED
@@ -1 +1,37 @@
1
- """DeepISLES inference module for stroke-deepisles-demo."""
 
1
+ """Inference module for stroke-deepisles-demo."""
2
+
3
+ from stroke_deepisles_demo.inference.deepisles import (
4
+ DEEPISLES_IMAGE,
5
+ DeepISLESResult,
6
+ find_prediction_mask,
7
+ run_deepisles_on_folder,
8
+ validate_input_folder,
9
+ )
10
+ from stroke_deepisles_demo.inference.docker import (
11
+ DockerRunResult,
12
+ build_docker_command,
13
+ check_docker_available,
14
+ check_nvidia_docker_available,
15
+ ensure_docker_available,
16
+ ensure_gpu_available_if_requested,
17
+ pull_image_if_missing,
18
+ run_container,
19
+ )
20
+
21
+ __all__ = [
22
+ # DeepISLES
23
+ "DEEPISLES_IMAGE",
24
+ "DeepISLESResult",
25
+ # Docker utilities
26
+ "DockerRunResult",
27
+ "build_docker_command",
28
+ "check_docker_available",
29
+ "check_nvidia_docker_available",
30
+ "ensure_docker_available",
31
+ "ensure_gpu_available_if_requested",
32
+ "find_prediction_mask",
33
+ "pull_image_if_missing",
34
+ "run_container",
35
+ "run_deepisles_on_folder",
36
+ "validate_input_folder",
37
+ ]
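
Typical pre-flight use of these exports before launching inference (a sketch; the first pull of `isleschallenge/deepisles` can be large, hence the generous timeout):

```python
from stroke_deepisles_demo.inference import (
    DEEPISLES_IMAGE,
    check_nvidia_docker_available,
    ensure_docker_available,
    pull_image_if_missing,
)

ensure_docker_available()                  # raises DockerNotAvailableError if not running
use_gpu = check_nvidia_docker_available()  # degrade to CPU when no NVIDIA runtime
pull_image_if_missing(DEEPISLES_IMAGE, timeout=3600)
```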
src/stroke_deepisles_demo/inference/deepisles.py ADDED
@@ -0,0 +1,193 @@
1
+ """DeepISLES stroke segmentation wrapper."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING
8
+
9
+ from stroke_deepisles_demo.core.exceptions import DeepISLESError, MissingInputError
10
+ from stroke_deepisles_demo.inference.docker import (
11
+ DockerRunResult,
12
+ ensure_gpu_available_if_requested,
13
+ run_container,
14
+ )
15
+
16
+ if TYPE_CHECKING:
17
+ from pathlib import Path
18
+
19
+ # Constants
20
+ DEEPISLES_IMAGE = "isleschallenge/deepisles"
21
+ EXPECTED_INPUT_FILES = ["dwi.nii.gz", "adc.nii.gz"]
22
+ OPTIONAL_INPUT_FILES = ["flair.nii.gz"]
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class DeepISLESResult:
27
+ """Result of DeepISLES inference."""
28
+
29
+ prediction_path: Path
30
+ docker_result: DockerRunResult
31
+ elapsed_seconds: float
32
+
33
+
34
+ def validate_input_folder(input_dir: Path) -> tuple[Path, Path, Path | None]:
35
+ """
36
+ Validate that input folder contains required files.
37
+
38
+ Args:
39
+ input_dir: Directory to validate
40
+
41
+ Returns:
42
+ Tuple of (dwi_path, adc_path, flair_path_or_none)
43
+
44
+ Raises:
45
+ MissingInputError: If required files are missing
46
+ """
47
+ dwi_path = input_dir / "dwi.nii.gz"
48
+ adc_path = input_dir / "adc.nii.gz"
49
+ flair_path = input_dir / "flair.nii.gz"
50
+
51
+ if not dwi_path.exists():
52
+ raise MissingInputError(f"Required file 'dwi.nii.gz' not found in {input_dir}")
53
+
54
+ if not adc_path.exists():
55
+ raise MissingInputError(f"Required file 'adc.nii.gz' not found in {input_dir}")
56
+
57
+ return dwi_path, adc_path, flair_path if flair_path.exists() else None
58
+
59
+
60
+ def find_prediction_mask(output_dir: Path) -> Path:
61
+ """
62
+ Find the prediction mask in DeepISLES output directory.
63
+
64
+ DeepISLES outputs may have varying names depending on version.
65
+ This function finds the most likely prediction file.
66
+
67
+ Args:
68
+ output_dir: DeepISLES output directory
69
+
70
+ Returns:
71
+ Path to the prediction mask NIfTI file
72
+
73
+ Raises:
74
+ DeepISLESError: If no prediction mask found
75
+ """
76
+ results_dir = output_dir / "results"
77
+
78
+ # Check common output patterns
79
+ possible_names = [
80
+ "prediction.nii.gz",
81
+ "pred.nii.gz",
82
+ "lesion_mask.nii.gz",
83
+ "output.nii.gz",
84
+ ]
85
+
86
+ for name in possible_names:
87
+ pred_path = results_dir / name
88
+ if pred_path.exists():
89
+ return pred_path
90
+
91
+ # Fall back to finding any .nii.gz in results dir
92
+ if results_dir.exists():
93
+ nifti_files = list(results_dir.glob("*.nii.gz"))
94
+ if nifti_files:
95
+ return nifti_files[0]
96
+
97
+ raise DeepISLESError(
98
+ f"No prediction mask found in {results_dir}. "
99
+ "Expected files like 'prediction.nii.gz' or similar."
100
+ )
101
+
102
+
103
+ def run_deepisles_on_folder(
104
+ input_dir: Path,
105
+ *,
106
+ output_dir: Path | None = None,
107
+ fast: bool = True,
108
+ gpu: bool = True,
109
+ timeout: float | None = 1800, # 30 minutes default
110
+ ) -> DeepISLESResult:
111
+ """
112
+ Run DeepISLES stroke segmentation on a folder of NIfTI files.
113
+
114
+ Args:
115
+ input_dir: Directory containing dwi.nii.gz, adc.nii.gz, [flair.nii.gz]
116
+ output_dir: Where to write results (default: input_dir/results)
117
+ fast: If True, use single-model mode (faster, slightly less accurate)
118
+ gpu: If True, use GPU acceleration
119
+ timeout: Maximum seconds to wait for inference
120
+
121
+ Returns:
122
+ DeepISLESResult with path to prediction mask
123
+
124
+ Raises:
125
+ DockerNotAvailableError: If Docker is not available
126
+ DockerGPUNotAvailableError: If GPU requested but not available
127
+ MissingInputError: If required input files are missing
128
+ DeepISLESError: If inference fails (non-zero exit, missing output)
129
+
130
+ Example:
131
+ >>> result = run_deepisles_on_folder(Path("/data/case001"), fast=True)
132
+ >>> print(result.prediction_path)
133
+ /data/case001/results/prediction.nii.gz
134
+ """
135
+ start_time = time.time()
136
+
137
+ # Validate inputs
138
+ _dwi_path, _adc_path, flair_path = validate_input_folder(input_dir)
139
+
140
+ # Check GPU availability (no-op when gpu is False)
141
+ ensure_gpu_available_if_requested(gpu)
143
+
144
+ # Set up output directory
145
+ if output_dir is None:
146
+ output_dir = input_dir
147
+
148
+ # Build command arguments
149
+ command: list[str] = [
150
+ "--dwi_file_name",
151
+ "dwi.nii.gz",
152
+ "--adc_file_name",
153
+ "adc.nii.gz",
154
+ ]
155
+
156
+ if flair_path is not None:
157
+ command.extend(["--flair_file_name", "flair.nii.gz"])
158
+
159
+ if fast:
160
+ command.extend(["--fast", "True"])
161
+
162
+ # Set up volume mounts
163
+ volumes = {
164
+ input_dir.resolve(): "/input",
165
+ output_dir.resolve(): "/output",
166
+ }
167
+
168
+ # Run the container
169
+ docker_result = run_container(
170
+ DEEPISLES_IMAGE,
171
+ command=command,
172
+ volumes=volumes,
173
+ gpu=gpu,
174
+ timeout=timeout,
175
+ )
176
+
177
+ # Check for failure
178
+ if docker_result.exit_code != 0:
179
+ raise DeepISLESError(
180
+ f"DeepISLES inference failed with exit code {docker_result.exit_code}. "
181
+ f"stderr: {docker_result.stderr}"
182
+ )
183
+
184
+ # Find the prediction mask
185
+ prediction_path = find_prediction_mask(output_dir)
186
+
187
+ elapsed = time.time() - start_time
188
+
189
+ return DeepISLESResult(
190
+ prediction_path=prediction_path,
191
+ docker_result=docker_result,
192
+ elapsed_seconds=elapsed,
193
+ )
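
End to end, the wrapper composes with the Phase 1A loader and the existing staging helper (a sketch mirroring the integration test below; the run directory is illustrative, and `sub-stroke0005` is a case known to exist per the real-data tests):

```python
from pathlib import Path

from stroke_deepisles_demo.data import get_case, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder

case = get_case("sub-stroke0005")  # CaseFiles from the local dataset
staged = stage_case_for_deepisles(case, Path("data/scratch/run-0005"))
result = run_deepisles_on_folder(staged.input_dir, fast=True, gpu=False)
print(f"mask: {result.prediction_path} ({result.elapsed_seconds:.1f}s)")
```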
src/stroke_deepisles_demo/inference/docker.py ADDED
@@ -0,0 +1,258 @@
1
+ """Docker execution utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from dataclasses import dataclass
9
+ from typing import TYPE_CHECKING
10
+
11
+ from stroke_deepisles_demo.core.exceptions import (
12
+ DockerGPUNotAvailableError,
13
+ DockerNotAvailableError,
14
+ )
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Sequence
18
+ from pathlib import Path
19
+
20
+
21
+ @dataclass(frozen=True)
22
+ class DockerRunResult:
23
+ """Result of a Docker container run."""
24
+
25
+ exit_code: int
26
+ stdout: str
27
+ stderr: str
28
+ elapsed_seconds: float
29
+
30
+
31
+ def check_docker_available() -> bool:
32
+ """
33
+ Check if Docker is installed and the daemon is running.
34
+
35
+ Returns:
36
+ True if Docker is available, False otherwise
37
+ """
38
+ try:
39
+ result = subprocess.run(
40
+ ["docker", "info"],
41
+ capture_output=True,
42
+ timeout=10,
43
+ check=False,
44
+ )
45
+ return result.returncode == 0
46
+ except (FileNotFoundError, subprocess.TimeoutExpired):
47
+ return False
48
+
49
+
50
+ def ensure_docker_available() -> None:
51
+ """
52
+ Ensure Docker is available, raising if not.
53
+
54
+ Raises:
55
+ DockerNotAvailableError: If Docker is not installed or not running
56
+ """
57
+ if not check_docker_available():
58
+ raise DockerNotAvailableError(
59
+ "Docker is not available. Please ensure Docker is installed and running."
60
+ )
61
+
62
+
63
+ def check_nvidia_docker_available() -> bool:
64
+ """
65
+ Check if NVIDIA Container Runtime is available for GPU support.
66
+
67
+ Returns:
68
+ True if nvidia-docker/nvidia-container-toolkit is configured
69
+ """
70
+ try:
71
+ result = subprocess.run(
72
+ [
73
+ "docker",
74
+ "run",
75
+ "--rm",
76
+ "--gpus",
77
+ "all",
78
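+ # NOTE: older CUDA base-image tags are removed from Docker Hub over
+ # time; if this probe fails with "manifest unknown", bump the tag.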
+ "nvidia/cuda:11.0-base",
79
+ "nvidia-smi",
80
+ ],
81
+ capture_output=True,
82
+ timeout=30,
83
+ check=False,
84
+ )
85
+ return result.returncode == 0
86
+ except (subprocess.TimeoutExpired, FileNotFoundError):
87
+ return False
88
+
89
+
90
+ def ensure_gpu_available_if_requested(gpu: bool) -> None:
91
+ """
92
+ Verify GPU is available if requested.
93
+
94
+ Args:
95
+ gpu: Whether GPU was requested
96
+
97
+ Raises:
98
+ DockerGPUNotAvailableError: If GPU requested but not available
99
+ """
100
+ if gpu and not check_nvidia_docker_available():
101
+ raise DockerGPUNotAvailableError(
102
+ "GPU requested but NVIDIA Container Runtime not available. "
103
+ "Either install nvidia-container-toolkit or set gpu=False."
104
+ )
105
+
106
+
107
+ def pull_image_if_missing(image: str, *, timeout: float = 600) -> bool:
108
+ """
109
+ Pull a Docker image if not present locally.
110
+
111
+ Args:
112
+ image: Docker image name (e.g., "isleschallenge/deepisles")
113
+ timeout: Maximum seconds to wait for pull
114
+
115
+ Returns:
116
+ True if image was pulled, False if already present
117
+
+ Raises:
+ subprocess.CalledProcessError: If the pull fails
+ subprocess.TimeoutExpired: If the pull exceeds the timeout
+ """
118
+ # Check if image exists locally
119
+ result = subprocess.run(
120
+ ["docker", "image", "inspect", image],
121
+ capture_output=True,
122
+ timeout=10,
123
+ check=False,
124
+ )
125
+ if result.returncode == 0:
126
+ return False # Image already present
127
+
128
+ # Pull the image
129
+ subprocess.run(
130
+ ["docker", "pull", image],
131
+ capture_output=True,
132
+ timeout=timeout,
133
+ check=True,
134
+ )
135
+ return True
136
+
137
+
138
+ def build_docker_command(
139
+ image: str,
140
+ *,
141
+ command: Sequence[str] | None = None,
142
+ volumes: dict[Path, str] | None = None,
143
+ environment: dict[str, str] | None = None,
144
+ gpu: bool = False,
145
+ remove: bool = True,
146
+ match_user: bool = True,
147
+ ) -> list[str]:
148
+ """
149
+ Build the docker run command without executing.
150
+
151
+ Args:
152
+ image: Docker image name
153
+ command: Command to run in container
154
+ volumes: Volume mounts (host path -> container path)
155
+ environment: Environment variables
156
+ gpu: If True, pass --gpus all
157
+ remove: If True, remove container after exit (--rm)
158
+ match_user: If True, match host user (Linux only)
159
+
160
+ Returns:
161
+ List of command arguments for subprocess
162
+ """
163
+ cmd: list[str] = ["docker", "run"]
164
+
165
+ if remove:
166
+ cmd.append("--rm")
167
+
168
+ if gpu:
169
+ cmd.extend(["--gpus", "all"])
170
+
171
+ # Match host user to avoid permission issues (Linux only).
172
+ # Guard against platforms (e.g. Windows, macOS) where os.getuid()/getgid()
173
+ # are absent or not meaningful.
174
+ if match_user:
175
+ import os
176
+
177
+ if (
178
+ os.name == "posix"
179
+ and sys.platform != "darwin"
180
+ and hasattr(os, "getuid")
181
+ and hasattr(os, "getgid")
182
+ ):
183
+ uid = os.getuid()
184
+ gid = os.getgid()
185
+ cmd.extend(["--user", f"{uid}:{gid}"])
186
+
187
+ if volumes:
188
+ for host_path, container_path in volumes.items():
189
+ cmd.extend(["-v", f"{host_path}:{container_path}"])
190
+
191
+ if environment:
192
+ for key, value in environment.items():
193
+ cmd.extend(["-e", f"{key}={value}"])
194
+
195
+ cmd.append(image)
196
+
197
+ if command:
198
+ cmd.extend(command)
199
+
200
+ return cmd
201
+
202
+
203
+ def run_container(
204
+ image: str,
205
+ *,
206
+ command: Sequence[str] | None = None,
207
+ volumes: dict[Path, str] | None = None,
208
+ environment: dict[str, str] | None = None,
209
+ gpu: bool = False,
210
+ remove: bool = True,
211
+ timeout: float | None = None,
212
+ ) -> DockerRunResult:
213
+ """
214
+ Run a Docker container and wait for completion.
215
+
216
+ Args:
217
+ image: Docker image name
218
+ command: Command to run in container
219
+ volumes: Volume mounts (host path -> container path)
220
+ environment: Environment variables
221
+ gpu: If True, pass --gpus all
222
+ remove: If True, remove container after exit (--rm)
223
+ timeout: Maximum seconds to wait (None = no timeout)
224
+
225
+ Returns:
226
+ DockerRunResult with exit code, stdout, stderr, elapsed time
227
+
228
+ Raises:
229
+ DockerNotAvailableError: If Docker is not available
230
+ subprocess.TimeoutExpired: If timeout exceeded
231
+ """
232
+ ensure_docker_available()
233
+
234
+ cmd = build_docker_command(
235
+ image,
236
+ command=command,
237
+ volumes=volumes,
238
+ environment=environment,
239
+ gpu=gpu,
240
+ remove=remove,
241
+ )
242
+
243
+ start_time = time.time()
244
+ result = subprocess.run(
245
+ cmd,
246
+ capture_output=True,
247
+ text=True,
248
+ timeout=timeout,
249
+ check=False,
250
+ )
251
+ elapsed = time.time() - start_time
252
+
253
+ return DockerRunResult(
254
+ exit_code=result.returncode,
255
+ stdout=result.stdout,
256
+ stderr=result.stderr,
257
+ elapsed_seconds=elapsed,
258
+ )
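
Because `build_docker_command` is pure, the exact argv is easy to inspect in tests or logs. A sketch (output shown for a Linux host with uid/gid 1000, where the `match_user` default injects `--user`):

```python
from pathlib import Path

from stroke_deepisles_demo.inference.docker import build_docker_command

cmd = build_docker_command(
    "isleschallenge/deepisles",
    command=["--dwi_file_name", "dwi.nii.gz"],
    volumes={Path("/data/case001"): "/input"},
    gpu=True,
)
print(" ".join(cmd))
# docker run --rm --gpus all --user 1000:1000 -v /data/case001:/input \
#   isleschallenge/deepisles --dwi_file_name dwi.nii.gz
```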
tests/conftest.py CHANGED
@@ -13,7 +13,7 @@ import pytest
13
  from stroke_deepisles_demo.core.types import CaseFiles
14
 
15
  if TYPE_CHECKING:
16
- from collections.abc import Generator, Iterator
17
 
18
 
19
  @pytest.fixture
@@ -62,30 +62,46 @@ def synthetic_case_files(temp_dir: Path) -> CaseFiles:
62
 
63
 
64
  @pytest.fixture
65
- def mock_hf_dataset(synthetic_case_files: CaseFiles) -> object:
66
- """Create a mock HF Dataset-like object."""
67
-
68
- # Simple list-based mock that mimics dataset behavior
69
- class MockDataset:
70
- def __init__(self) -> None:
71
- self.data = [
72
- {
73
- "participant_id": "sub-001",
74
- "dwi": str(synthetic_case_files["dwi"]),
75
- "adc": str(synthetic_case_files["adc"]),
76
- "flair": None,
77
- "mask": str(synthetic_case_files.get("ground_truth")),
78
- }
79
- ]
80
- self.features = {"dwi": None, "adc": None, "flair": None, "mask": None}
81
-
82
- def __len__(self) -> int:
83
- return len(self.data)
84
-
85
- def __getitem__(self, idx: int) -> dict[str, str | None]:
86
- return self.data[idx]
87
-
88
- def __iter__(self) -> Iterator[dict[str, str | None]]:
89
- return iter(self.data)
90
-
91
- return MockDataset()
 
13
  from stroke_deepisles_demo.core.types import CaseFiles
14
 
15
  if TYPE_CHECKING:
16
+ from collections.abc import Generator
17
 
18
 
19
  @pytest.fixture
 
62
 
63
 
64
  @pytest.fixture
65
+ def synthetic_isles_dir(temp_dir: Path) -> Path:
66
+ """
67
+ Create synthetic ISLES24-like directory structure.
68
+
69
+ Structure:
70
+ temp_dir/
71
+ ├── Images-DWI/
72
+ │ ├── sub-stroke0001_ses-02_dwi.nii.gz
73
+ │ └── sub-stroke0002_ses-02_dwi.nii.gz
74
+ ├── Images-ADC/
75
+ │ ├── sub-stroke0001_ses-02_adc.nii.gz
76
+ │ └── sub-stroke0002_ses-02_adc.nii.gz
77
+ └── Masks/
78
+ ├── sub-stroke0001_ses-02_lesion-msk.nii.gz
79
+ └── sub-stroke0002_ses-02_lesion-msk.nii.gz
80
+ """
81
+ dwi_dir = temp_dir / "Images-DWI"
82
+ adc_dir = temp_dir / "Images-ADC"
83
+ mask_dir = temp_dir / "Masks"
84
+
85
+ dwi_dir.mkdir()
86
+ adc_dir.mkdir()
87
+ mask_dir.mkdir()
88
+
89
+ for subject_num in [1, 2]:
90
+ subject_id = f"sub-stroke{subject_num:04d}"
91
+
92
+ # Create DWI
93
+ dwi_data = np.random.rand(10, 10, 5).astype(np.float32)
94
+ dwi_img = nib.Nifti1Image(dwi_data, affine=np.eye(4)) # type: ignore
95
+ nib.save(dwi_img, dwi_dir / f"{subject_id}_ses-02_dwi.nii.gz") # type: ignore
96
+
97
+ # Create ADC
98
+ adc_data = np.random.rand(10, 10, 5).astype(np.float32) * 2000
99
+ adc_img = nib.Nifti1Image(adc_data, affine=np.eye(4)) # type: ignore
100
+ nib.save(adc_img, adc_dir / f"{subject_id}_ses-02_adc.nii.gz") # type: ignore
101
+
102
+ # Create Mask
103
+ mask_data = (np.random.rand(10, 10, 5) > 0.9).astype(np.uint8)
104
+ mask_img = nib.Nifti1Image(mask_data, affine=np.eye(4)) # type: ignore
105
+ nib.save(mask_img, mask_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz") # type: ignore
106
+
107
+ return temp_dir
tests/data/test_adapter.py CHANGED
@@ -1,70 +1,94 @@
1
- """Tests for case adapter module."""
2
 
3
  from __future__ import annotations
4
 
5
  from typing import TYPE_CHECKING
6
 
7
- import pytest
8
-
9
- from stroke_deepisles_demo.data.adapter import CaseAdapter
 
 
10
 
11
  if TYPE_CHECKING:
12
- from unittest.mock import MagicMock
 
 
13
 
14
 
15
- class TestCaseAdapter:
16
- """Tests for CaseAdapter."""
 
17
 
18
- def test_list_case_ids_returns_strings(self, mock_hf_dataset: MagicMock) -> None:
19
- """list_case_ids returns list of string identifiers."""
20
- adapter = CaseAdapter(mock_hf_dataset)
21
- case_ids = adapter.list_case_ids()
22
 
23
- assert isinstance(case_ids, list)
24
- assert all(isinstance(cid, str) for cid in case_ids)
25
- assert case_ids == ["sub-001"]
 
 
 
26
 
27
- def test_len_matches_dataset_size(self, mock_hf_dataset: MagicMock) -> None:
28
- """len(adapter) equals number of cases in dataset."""
29
- adapter = CaseAdapter(mock_hf_dataset)
30
 
31
- assert len(adapter) == len(mock_hf_dataset)
 
 
32
 
33
- def test_get_case_by_string_id(self, mock_hf_dataset: MagicMock) -> None:
34
- """Can retrieve case by string identifier."""
35
- adapter = CaseAdapter(mock_hf_dataset)
36
- case_ids = adapter.list_case_ids()
 
37
 
38
- case = adapter.get_case(case_ids[0])
 
 
 
39
 
40
- assert isinstance(case, dict)
41
- assert "dwi" in case
42
- assert "adc" in case
43
- # Paths should be Path objects or convertible
44
- from pathlib import Path
45
 
46
- assert isinstance(case["dwi"], (Path, str))
47
 
48
- def test_get_case_by_index(self, mock_hf_dataset: MagicMock) -> None:
49
- """Can retrieve case by integer index."""
50
- adapter = CaseAdapter(mock_hf_dataset)
51
 
52
- case_id, case = adapter.get_case_by_index(0)
 
 
53
 
54
- assert isinstance(case_id, str)
55
- assert case["dwi"] is not None
56
 
57
- def test_get_case_invalid_id_raises(self, mock_hf_dataset: MagicMock) -> None:
58
- """Raises KeyError for invalid case ID."""
59
- adapter = CaseAdapter(mock_hf_dataset)
 
 
 
 
60
 
61
- with pytest.raises(KeyError):
62
- adapter.get_case("nonexistent-case-id")
 
63
 
64
- def test_iteration(self, mock_hf_dataset: MagicMock) -> None:
65
- """Can iterate over case IDs."""
66
- adapter = CaseAdapter(mock_hf_dataset)
67
 
68
- case_ids = list(adapter)
 
69
 
70
- assert len(case_ids) == len(adapter)
 
 
1
+ """Tests for the data adapter."""
2
 
3
  from __future__ import annotations
4
 
5
  from typing import TYPE_CHECKING
6
 
7
+ from stroke_deepisles_demo.data.adapter import (
8
+ LocalDataset,
9
+ build_local_dataset,
10
+ parse_subject_id,
11
+ )
12
 
13
  if TYPE_CHECKING:
14
+ from pathlib import Path
15
+
16
+
17
+ def test_parse_subject_id_extracts_correctly() -> None:
18
+ """Test extracting subject ID from BIDS filename."""
19
+ # Valid cases
20
+ assert parse_subject_id("sub-stroke0005_ses-02_dwi.nii.gz") == "sub-stroke0005"
21
+ assert parse_subject_id("sub-stroke0149_ses-02_adc.nii.gz") == "sub-stroke0149"
22
+ assert parse_subject_id("sub-stroke1234_ses-02_lesion-msk.nii.gz") == "sub-stroke1234"
23
+
24
+ # Invalid cases
25
+ assert parse_subject_id("random_file.nii.gz") is None
26
+ assert parse_subject_id("sub-strokeABC_ses-02_dwi.nii.gz") is None # Non-digit ID
27
 
28
 
29
+ def test_build_local_dataset_matches_files(synthetic_isles_dir: Path) -> None:
30
+ """Test that files are correctly matched by subject ID."""
31
+ dataset = build_local_dataset(synthetic_isles_dir)
32
 
33
+ assert isinstance(dataset, LocalDataset)
34
+ assert len(dataset) == 2 # synthetic_isles_dir creates 2 subjects
35
+ assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
 
36
 
37
+ # Verify matching logic
38
+ case1 = dataset.get_case("sub-stroke0001")
39
+ assert case1["dwi"].name == "sub-stroke0001_ses-02_dwi.nii.gz"
40
+ assert case1["adc"].name == "sub-stroke0001_ses-02_adc.nii.gz"
41
+ assert case1["ground_truth"] is not None
42
+ assert case1["ground_truth"].name == "sub-stroke0001_ses-02_lesion-msk.nii.gz"
43
 
 
 
 
44
 
45
+ def test_get_case_returns_case_files(synthetic_isles_dir: Path) -> None:
46
+ """Test retrieval of cases by ID and index."""
47
+ dataset = build_local_dataset(synthetic_isles_dir)
48
 
49
+ # By ID
50
+ case_by_id = dataset.get_case("sub-stroke0001")
51
+ assert isinstance(case_by_id, dict)
52
+ assert "dwi" in case_by_id
53
+ assert "adc" in case_by_id
54
 
55
+ # By Index
56
+ case_by_idx = dataset.get_case(0)
57
+ assert isinstance(case_by_idx, dict)
58
+ assert case_by_id == case_by_idx # Should be the same case
59
 
 
60
 
61
+ def test_build_local_dataset_skips_incomplete(
62
+ synthetic_isles_dir: Path,
63
+ ) -> None:
64
+ """Test that incomplete cases (missing ADC) are skipped."""
65
+ # Delete ADC for subject 2
66
+ adc_file = synthetic_isles_dir / "Images-ADC" / "sub-stroke0002_ses-02_adc.nii.gz"
67
+ adc_file.unlink()
68
 
69
+ dataset = build_local_dataset(synthetic_isles_dir)
 
 
70
 
71
+ # Subject 2 should be gone
72
+ assert len(dataset) == 1
73
+ assert dataset.list_case_ids() == ["sub-stroke0001"]
74
 
 
 
75
 
76
+ def test_build_local_dataset_handles_missing_mask(
77
+ synthetic_isles_dir: Path,
78
+ ) -> None:
79
+ """Test that missing mask results in ground_truth=None (if allowed)."""
80
+ # NOTE: Adapter currently allows missing mask?
81
+ # Spec says: "ground_truth=mask_file if mask_file.exists() else None"
82
+ # So yes, it should load but with None.
83
 
84
+ # Delete Mask for subject 2
85
+ mask_file = synthetic_isles_dir / "Masks" / "sub-stroke0002_ses-02_lesion-msk.nii.gz"
86
+ mask_file.unlink()
87
 
88
+ dataset = build_local_dataset(synthetic_isles_dir)
 
 
89
 
90
+ # Subject 2 should still exist
91
+ assert len(dataset) == 2
92
 
93
+ case2 = dataset.get_case("sub-stroke0002")
94
+ assert case2.get("ground_truth") is None
tests/data/test_integration_real_data.py ADDED
@@ -0,0 +1,42 @@
1
+ """Integration tests with real ISLES24 data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+ from stroke_deepisles_demo.data.loader import load_isles_dataset
10
+
11
+ REAL_DATA_PATH = Path("data/scratch/isles24_extracted")
12
+
13
+
14
+ @pytest.mark.skipif(not REAL_DATA_PATH.exists(), reason="Real data not found in data/scratch")
15
+ def test_load_real_data_count() -> None:
16
+ """Verify that we can load the expected number of cases from real data."""
17
+ dataset = load_isles_dataset(source=REAL_DATA_PATH)
18
+
19
+ # We expect 149 cases based on schema report
20
+ assert len(dataset) == 149
21
+
22
+ # Check a specific known case
23
+ case = dataset.get_case("sub-stroke0005")
24
+ assert case["dwi"].name == "sub-stroke0005_ses-02_dwi.nii.gz"
25
+ assert case["dwi"].exists()
26
+ assert case["adc"].exists()
27
+ assert case["ground_truth"] is not None
28
+ assert case["ground_truth"].exists()
29
+
30
+
31
+ @pytest.mark.skipif(not REAL_DATA_PATH.exists(), reason="Real data not found in data/scratch")
32
+ def test_real_data_subject_ids() -> None:
33
+ """Verify subject ID formatting on real data."""
34
+ dataset = load_isles_dataset(source=REAL_DATA_PATH)
35
+ ids = dataset.list_case_ids()
36
+
37
+ assert len(ids) == 149
38
+ assert ids[0] == "sub-stroke0001"
39
+ # We know there are gaps, so just check the format
40
+ for subject_id in ids:
41
+ assert subject_id.startswith("sub-stroke")
42
+ assert len(subject_id) == len("sub-strokeXXXX")
tests/data/test_loader.py CHANGED
@@ -1,90 +1,33 @@
1
- """Tests for data loader module."""
2
 
3
  from __future__ import annotations
4
 
5
- from unittest.mock import MagicMock, patch
6
 
7
  import pytest
8
 
9
- from stroke_deepisles_demo.core.exceptions import DataLoadError
10
- from stroke_deepisles_demo.data.loader import (
11
- DatasetInfo,
12
- get_dataset_info,
13
- load_isles_dataset,
14
- )
15
 
 
 
16
 
17
- class TestLoadIslesDataset:
18
- """Tests for load_isles_dataset."""
19
 
20
- def test_calls_hf_load_dataset(self) -> None:
21
- """Calls datasets.load_dataset with correct arguments."""
22
- with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
23
- mock_load.return_value = MagicMock()
 
24
 
25
- load_isles_dataset("test/dataset")
26
 
27
- mock_load.assert_called_once()
28
- call_args = mock_load.call_args
29
- assert call_args.args[0] == "test/dataset"
 
 
30
 
31
- def test_returns_dataset_object(self) -> None:
32
- """Returns the loaded Dataset object."""
33
- with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
34
- expected = MagicMock()
35
- mock_load.return_value = expected
36
 
37
- result = load_isles_dataset()
38
-
39
- assert result is expected
40
-
41
- def test_handles_load_error(self) -> None:
42
- """Wraps HF errors in DataLoadError."""
43
- with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
44
- mock_load.side_effect = Exception("Network error")
45
-
46
- with pytest.raises(DataLoadError, match="Network error"):
47
- load_isles_dataset()
48
-
49
-
50
- class TestGetDatasetInfo:
51
- """Tests for get_dataset_info."""
52
-
53
- def test_returns_datasetinfo(self) -> None:
54
- """Returns DatasetInfo with expected fields."""
55
- with patch("stroke_deepisles_demo.data.loader.load_dataset") as mock_load:
56
- mock_ds = MagicMock()
57
- mock_ds.__len__ = MagicMock(return_value=149)
58
- # Mock info.splits['train'].num_examples
59
- mock_ds.info.splits.__getitem__.return_value.num_examples = 149
60
- # Mock features as dict-like
61
- mock_ds.features = {"dwi": None, "adc": None, "mask": None}
62
- mock_load.return_value = mock_ds
63
-
64
- info = get_dataset_info()
65
-
66
- assert isinstance(info, DatasetInfo)
67
- assert info.num_cases == 149
68
- assert "dwi" in info.modalities
69
- assert info.has_ground_truth is True
70
-
71
-
72
- @pytest.mark.integration
73
- class TestLoadIslesDatasetIntegration:
74
- """Integration tests that hit the real HuggingFace Hub."""
75
-
76
- @pytest.mark.slow
77
- def test_load_real_dataset(self) -> None:
78
- """Actually loads ISLES24-MR-Lite from HF Hub."""
79
- # This test requires network access
80
- # Run with: pytest -m integration
81
- # Using streaming=True to avoid downloading everything
82
- try:
83
- dataset = load_isles_dataset(streaming=True)
84
- assert dataset is not None
85
- # Verify we got metadata/features - this confirms connectivity
86
- # Iterating might trigger heavy downloads or fail if dataset is empty/gated
87
- assert hasattr(dataset, "features")
88
- assert len(dataset.features) > 0
89
- except Exception as e:
90
- pytest.fail(f"Failed to load real dataset: {e}")
 
1
+ """Tests for the data loader."""
2
 
3
  from __future__ import annotations
4
 
5
+ from typing import TYPE_CHECKING
6
 
7
  import pytest
8
 
9
+ from stroke_deepisles_demo.data.adapter import LocalDataset
10
+ from stroke_deepisles_demo.data.loader import load_isles_dataset
 
 
 
 
11
 
12
+ if TYPE_CHECKING:
13
+ from pathlib import Path
14
 
 
 
15
 
16
+ def test_load_from_local_returns_local_dataset(synthetic_isles_dir: Path) -> None:
17
+ """Test that loading from local path returns a LocalDataset."""
18
+ dataset = load_isles_dataset(source=synthetic_isles_dir, local_mode=True)
19
+ assert isinstance(dataset, LocalDataset)
20
+ assert len(dataset) > 0
21
 
 
22
 
23
+ def test_load_from_local_finds_all_cases(synthetic_isles_dir: Path) -> None:
24
+ """Test that the loader correctly delegates finding cases to adapter."""
25
+ dataset = load_isles_dataset(source=synthetic_isles_dir)
26
+ assert len(dataset) == 2
27
+ assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
28
 
 
29
 
30
+ def test_load_raises_not_implemented_for_hf() -> None:
31
+ """Test that HF mode raises NotImplementedError."""
32
+ with pytest.raises(NotImplementedError):
33
+ load_isles_dataset(source="fake/dataset", local_mode=False)
tests/inference/__init__.py ADDED
File without changes
tests/inference/test_deepisles.py ADDED
@@ -0,0 +1,285 @@
1
+ """Tests for DeepISLES wrapper."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import pytest
9
+
10
+ from stroke_deepisles_demo.core.exceptions import DeepISLESError, MissingInputError
11
+ from stroke_deepisles_demo.inference.deepisles import (
12
+ DeepISLESResult,
13
+ find_prediction_mask,
14
+ run_deepisles_on_folder,
15
+ validate_input_folder,
16
+ )
17
+
18
+
19
+ class TestValidateInputFolder:
20
+ """Tests for validate_input_folder."""
21
+
22
+ def test_succeeds_with_required_files(self, temp_dir: Path) -> None:
23
+ """Returns paths when required files exist."""
24
+ (temp_dir / "dwi.nii.gz").touch()
25
+ (temp_dir / "adc.nii.gz").touch()
26
+
27
+ dwi, adc, flair = validate_input_folder(temp_dir)
28
+
29
+ assert dwi == temp_dir / "dwi.nii.gz"
30
+ assert adc == temp_dir / "adc.nii.gz"
31
+ assert flair is None
32
+
33
+ def test_includes_flair_when_present(self, temp_dir: Path) -> None:
34
+ """Returns FLAIR path when present."""
35
+ (temp_dir / "dwi.nii.gz").touch()
36
+ (temp_dir / "adc.nii.gz").touch()
37
+ (temp_dir / "flair.nii.gz").touch()
38
+
39
+ _dwi, _adc, flair = validate_input_folder(temp_dir)
40
+
41
+ assert flair == temp_dir / "flair.nii.gz"
42
+
43
+ def test_raises_when_dwi_missing(self, temp_dir: Path) -> None:
44
+ """Raises MissingInputError when DWI is missing."""
45
+ (temp_dir / "adc.nii.gz").touch()
46
+
47
+ with pytest.raises(MissingInputError, match="dwi"):
48
+ validate_input_folder(temp_dir)
49
+
50
+ def test_raises_when_adc_missing(self, temp_dir: Path) -> None:
51
+ """Raises MissingInputError when ADC is missing."""
52
+ (temp_dir / "dwi.nii.gz").touch()
53
+
54
+ with pytest.raises(MissingInputError, match="adc"):
55
+ validate_input_folder(temp_dir)
56
+
57
+
58
+ class TestFindPredictionMask:
59
+ """Tests for find_prediction_mask."""
60
+
61
+ def test_finds_prediction_file(self, temp_dir: Path) -> None:
62
+ """Finds prediction.nii.gz in output directory."""
63
+ results_dir = temp_dir / "results"
64
+ results_dir.mkdir()
65
+ pred_file = results_dir / "prediction.nii.gz"
66
+ pred_file.touch()
67
+
68
+ result = find_prediction_mask(temp_dir)
69
+
70
+ assert result == pred_file
71
+
72
+ def test_finds_alternate_name(self, temp_dir: Path) -> None:
73
+ """Finds alternate named prediction files."""
74
+ results_dir = temp_dir / "results"
75
+ results_dir.mkdir()
76
+ pred_file = results_dir / "pred.nii.gz"
77
+ pred_file.touch()
78
+
79
+ result = find_prediction_mask(temp_dir)
80
+
81
+ assert result == pred_file
82
+
83
+ def test_falls_back_to_any_nifti(self, temp_dir: Path) -> None:
84
+ """Falls back to any .nii.gz file if standard names not found."""
85
+ results_dir = temp_dir / "results"
86
+ results_dir.mkdir()
87
+ pred_file = results_dir / "some_output.nii.gz"
88
+ pred_file.touch()
89
+
90
+ result = find_prediction_mask(temp_dir)
91
+
92
+ assert result == pred_file
93
+
94
+ def test_raises_when_no_prediction(self, temp_dir: Path) -> None:
95
+ """Raises DeepISLESError when no prediction found."""
96
+ results_dir = temp_dir / "results"
97
+ results_dir.mkdir()
98
+
99
+ with pytest.raises(DeepISLESError, match="prediction"):
100
+ find_prediction_mask(temp_dir)
101
+
102
+ def test_raises_when_results_dir_missing(self, temp_dir: Path) -> None:
103
+ """Raises DeepISLESError when results directory missing."""
104
+ with pytest.raises(DeepISLESError, match="prediction"):
105
+ find_prediction_mask(temp_dir)
106
+
107
+
108
+ class TestRunDeepIslesOnFolder:
109
+ """Tests for run_deepisles_on_folder."""
110
+
111
+ @pytest.fixture
112
+ def valid_input_dir(self, temp_dir: Path) -> Path:
113
+ """Create a valid input directory with required files."""
114
+ (temp_dir / "dwi.nii.gz").touch()
115
+ (temp_dir / "adc.nii.gz").touch()
116
+ return temp_dir
117
+
118
+ def test_validates_input_files(self, temp_dir: Path) -> None:
119
+ """Validates input files before running Docker."""
120
+ # Missing required files
121
+ with pytest.raises(MissingInputError):
122
+ run_deepisles_on_folder(temp_dir)
123
+
124
+ def test_calls_docker_with_correct_image(self, valid_input_dir: Path) -> None:
125
+ """Calls Docker with DeepISLES image."""
126
+ with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
127
+ mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
128
+ with (
129
+ patch(
130
+ "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
131
+ ),
132
+ patch(
133
+ "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
134
+ ) as mock_find,
135
+ ):
136
+ mock_find.return_value = valid_input_dir / "results" / "pred.nii.gz"
137
+ run_deepisles_on_folder(valid_input_dir)
138
+
139
+ # Check image name
140
+ call_args = mock_run.call_args
141
+ assert call_args.args[0] == "isleschallenge/deepisles"
142
+
143
+ def test_passes_fast_flag(self, valid_input_dir: Path) -> None:
144
+ """Passes --fast True when fast=True."""
145
+ with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
146
+ mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
147
+ with (
148
+ patch(
149
+ "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
150
+ ),
151
+ patch(
152
+ "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
153
+ ) as mock_find,
154
+ ):
155
+ mock_find.return_value = valid_input_dir / "results" / "pred.nii.gz"
156
+
157
+ run_deepisles_on_folder(valid_input_dir, fast=True)
158
+
159
+ # Check --fast in command
160
+ call_kwargs = mock_run.call_args.kwargs
161
+ command = call_kwargs.get("command", [])
162
+ assert "--fast" in command
163
+ assert "True" in command
164
+
165
+ def test_includes_flair_when_present(self, valid_input_dir: Path) -> None:
166
+ """Includes FLAIR in command when present."""
167
+ (valid_input_dir / "flair.nii.gz").touch()
168
+
169
+ with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
170
+ mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
171
+ with (
172
+ patch(
173
+ "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
174
+ ),
175
+ patch(
176
+ "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
177
+ ) as mock_find,
178
+ ):
179
+ mock_find.return_value = valid_input_dir / "results" / "pred.nii.gz"
180
+
181
+ run_deepisles_on_folder(valid_input_dir)
182
+
183
+ call_kwargs = mock_run.call_args.kwargs
184
+ command = call_kwargs.get("command", [])
185
+ assert "--flair_file_name" in command
186
+ assert "flair.nii.gz" in command
187
+
188
+ def test_raises_on_docker_failure(self, valid_input_dir: Path) -> None:
189
+ """Raises DeepISLESError when Docker returns non-zero."""
190
+ with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
191
+ mock_run.return_value = MagicMock(exit_code=1, stdout="", stderr="Segmentation fault")
192
+ with (
193
+ patch(
194
+ "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
195
+ ),
196
+ pytest.raises(DeepISLESError, match="failed"),
197
+ ):
198
+ run_deepisles_on_folder(valid_input_dir)
199
+
200
+ def test_returns_result_with_prediction_path(self, valid_input_dir: Path) -> None:
201
+ """Returns DeepISLESResult with prediction path."""
202
+ with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
203
+ mock_run.return_value = MagicMock(
204
+ exit_code=0, stdout="", stderr="", elapsed_seconds=10.0
205
+ )
206
+ with (
207
+ patch(
208
+ "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
209
+ ),
210
+ patch(
211
+ "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
212
+ ) as mock_find,
213
+ ):
214
+ expected_path = valid_input_dir / "results" / "prediction.nii.gz"
215
+ mock_find.return_value = expected_path
216
+
217
+ result = run_deepisles_on_folder(valid_input_dir)
218
+
219
+ assert isinstance(result, DeepISLESResult)
220
+ assert result.prediction_path == expected_path
221
+
222
+ def test_passes_volume_mounts(self, valid_input_dir: Path, temp_dir: Path) -> None:
223
+ """Passes correct volume mounts to Docker."""
224
+ # Create a separate output directory
225
+ output_dir = temp_dir / "output"
226
+ output_dir.mkdir()
227
+
228
+ with patch("stroke_deepisles_demo.inference.deepisles.run_container") as mock_run:
229
+ mock_run.return_value = MagicMock(exit_code=0, stdout="", stderr="")
230
+ with (
231
+ patch(
232
+ "stroke_deepisles_demo.inference.deepisles.ensure_gpu_available_if_requested"
233
+ ),
234
+ patch(
235
+ "stroke_deepisles_demo.inference.deepisles.find_prediction_mask"
236
+ ) as mock_find,
237
+ ):
238
+ mock_find.return_value = output_dir / "results" / "pred.nii.gz"
239
+
240
+ run_deepisles_on_folder(valid_input_dir, output_dir=output_dir)
241
+
242
+ call_kwargs = mock_run.call_args.kwargs
243
+ volumes = call_kwargs.get("volumes", {})
244
+ # Should have input and output mounts (2 separate directories)
245
+ assert len(volumes) == 2
246
+ # Values should be container paths
247
+ assert "/input" in volumes.values()
248
+ assert "/output" in volumes.values()
249
+
250
+
251
+ @pytest.mark.integration
252
+ @pytest.mark.slow
253
+ class TestDeepIslesIntegration:
254
+ """Integration tests requiring real Docker and DeepISLES image."""
255
+
256
+ def test_real_inference(self, synthetic_case_files: dict[str, object]) -> None:
257
+ """Run actual DeepISLES inference on synthetic data."""
258
+ # This test requires:
259
+ # 1. Docker available
260
+ # 2. isleschallenge/deepisles image pulled
261
+ # 3. GPU (optional but recommended)
262
+ #
263
+ # Run with: pytest -m integration
264
+ import tempfile
265
+
266
+ from stroke_deepisles_demo.data.staging import stage_case_for_deepisles
267
+
268
+ # Create a separate staging directory
269
+ with tempfile.TemporaryDirectory() as staging_dir:
270
+ # Stage the synthetic files to the new directory
271
+ staged = stage_case_for_deepisles(
272
+ synthetic_case_files, # type: ignore[arg-type]
273
+ Path(staging_dir),
274
+ )
275
+
276
+ # Run inference
277
+ result = run_deepisles_on_folder(
278
+ staged.input_dir,
279
+ fast=True,
280
+ gpu=False, # Might not have GPU in CI
281
+ timeout=600,
282
+ )
283
+
284
+ # Verify output exists
285
+ assert result.prediction_path.exists()
tests/inference/test_docker.py ADDED
@@ -0,0 +1,202 @@
1
+ """Tests for Docker utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import pytest
9
+
10
+ from stroke_deepisles_demo.core.exceptions import DockerNotAvailableError
11
+ from stroke_deepisles_demo.inference.docker import (
12
+ build_docker_command,
13
+ check_docker_available,
14
+ ensure_docker_available,
15
+ run_container,
16
+ )
17
+
18
+ if TYPE_CHECKING:
19
+ from pathlib import Path
20
+
21
+
22
+ class TestCheckDockerAvailable:
23
+ """Tests for check_docker_available."""
24
+
25
+ def test_returns_true_when_docker_responds(self) -> None:
26
+ """Returns True when 'docker info' succeeds."""
27
+ with patch("subprocess.run") as mock_run:
28
+ mock_run.return_value = MagicMock(returncode=0)
29
+
30
+ result = check_docker_available()
31
+
32
+ assert result is True
33
+
34
+ def test_returns_false_when_docker_not_found(self) -> None:
35
+ """Returns False when docker command not found."""
36
+ with patch("subprocess.run") as mock_run:
37
+ mock_run.side_effect = FileNotFoundError()
38
+
39
+ result = check_docker_available()
40
+
41
+ assert result is False
42
+
43
+ def test_returns_false_when_daemon_not_running(self) -> None:
44
+ """Returns False when docker daemon not running."""
45
+ with patch("subprocess.run") as mock_run:
46
+ mock_run.return_value = MagicMock(returncode=1)
47
+
48
+ result = check_docker_available()
49
+
50
+ assert result is False
51
+
52
+
53
+ class TestEnsureDockerAvailable:
54
+ """Tests for ensure_docker_available."""
55
+
56
+ def test_raises_when_docker_not_available(self) -> None:
57
+ """Raises DockerNotAvailableError when Docker not available."""
58
+ with (
59
+ patch(
60
+ "stroke_deepisles_demo.inference.docker.check_docker_available",
61
+ return_value=False,
62
+ ),
63
+ pytest.raises(DockerNotAvailableError),
64
+ ):
65
+ ensure_docker_available()
66
+
67
+ def test_no_error_when_docker_available(self) -> None:
68
+ """No exception when Docker is available."""
69
+ with patch(
70
+ "stroke_deepisles_demo.inference.docker.check_docker_available",
71
+ return_value=True,
72
+ ):
73
+ ensure_docker_available() # Should not raise
74
+
75
+
76
+ class TestBuildDockerCommand:
77
+ """Tests for build_docker_command."""
78
+
79
+ def test_basic_command(self) -> None:
80
+ """Builds basic docker run command."""
81
+ cmd = build_docker_command("myimage:latest")
82
+
83
+ assert cmd[0] == "docker"
84
+ assert "run" in cmd
85
+ assert "myimage:latest" in cmd
86
+
87
+ def test_includes_rm_flag(self) -> None:
88
+ """Includes --rm when remove=True."""
89
+ cmd = build_docker_command("myimage", remove=True)
90
+
91
+ assert "--rm" in cmd
92
+
93
+ def test_excludes_rm_flag(self) -> None:
94
+ """Excludes --rm when remove=False."""
95
+ cmd = build_docker_command("myimage", remove=False)
96
+
97
+ assert "--rm" not in cmd
98
+
99
+ def test_includes_gpu_flag(self) -> None:
100
+ """Includes --gpus all when gpu=True."""
101
+ cmd = build_docker_command("myimage", gpu=True)
102
+
103
+ assert "--gpus" in cmd
104
+ gpu_index = cmd.index("--gpus")
105
+ assert cmd[gpu_index + 1] == "all"
106
+
107
+ def test_volume_mounts(self, temp_dir: Path) -> None:
108
+ """Includes volume mounts."""
109
+ volumes = {temp_dir: "/data"}
110
+ cmd = build_docker_command("myimage", volumes=volumes)
111
+
112
+ assert "-v" in cmd
113
+ # Find the volume argument
114
+ v_index = cmd.index("-v")
115
+ assert f"{temp_dir}:/data" in cmd[v_index + 1]
116
+
117
+ def test_custom_command(self) -> None:
118
+ """Appends custom command arguments."""
119
+ cmd = build_docker_command("myimage", command=["--input", "/data", "--fast", "True"])
120
+
121
+ assert "--input" in cmd
122
+ assert "--fast" in cmd
123
+
124
+ def test_environment_variables(self) -> None:
125
+ """Includes environment variables."""
126
+ env = {"MY_VAR": "value", "OTHER": "123"}
127
+ cmd = build_docker_command("myimage", environment=env)
128
+
129
+ assert "-e" in cmd
130
+ # Check both vars are present
131
+ cmd_str = " ".join(cmd)
132
+ assert "MY_VAR=value" in cmd_str
133
+ assert "OTHER=123" in cmd_str
134
+
135
+
136
+ class TestRunContainer:
137
+ """Tests for run_container."""
138
+
139
+ def test_calls_subprocess_with_built_command(self) -> None:
140
+ """Calls subprocess.run with built command."""
141
+ with patch("subprocess.run") as mock_run:
142
+ mock_run.return_value = MagicMock(returncode=0, stdout="output", stderr="")
143
+ with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
144
+ run_container("myimage")
145
+
146
+ mock_run.assert_called_once()
147
+
148
+ def test_returns_result_with_exit_code(self) -> None:
149
+ """Returns DockerRunResult with correct exit code."""
150
+ with patch("subprocess.run") as mock_run:
151
+ mock_run.return_value = MagicMock(returncode=42, stdout="out", stderr="err")
152
+ with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
153
+ result = run_container("myimage")
154
+
155
+ assert result.exit_code == 42
156
+
157
+ def test_captures_stdout_stderr(self) -> None:
158
+ """Captures stdout and stderr from container."""
159
+ with patch("subprocess.run") as mock_run:
160
+ mock_run.return_value = MagicMock(returncode=0, stdout="hello", stderr="warning")
161
+ with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
162
+ result = run_container("myimage")
163
+
164
+ assert result.stdout == "hello"
165
+ assert result.stderr == "warning"
166
+
167
+ def test_respects_timeout(self) -> None:
168
+ """Passes timeout to subprocess."""
169
+ with patch("subprocess.run") as mock_run:
170
+ mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
171
+ with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
172
+ run_container("myimage", timeout=60.0)
173
+
174
+ call_kwargs = mock_run.call_args.kwargs
175
+ assert call_kwargs.get("timeout") == 60.0
176
+
177
+ def test_tracks_elapsed_time(self) -> None:
178
+ """Tracks elapsed time in result."""
179
+ with patch("subprocess.run") as mock_run:
180
+ mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
181
+ with patch("stroke_deepisles_demo.inference.docker.ensure_docker_available"):
182
+ result = run_container("myimage")
183
+
184
+ # Should have some elapsed time (even if small)
185
+ assert result.elapsed_seconds >= 0
186
+
187
+
188
+ @pytest.mark.integration
189
+ class TestDockerIntegration:
190
+ """Integration tests requiring real Docker."""
191
+
192
+ def test_docker_actually_available(self) -> None:
193
+ """Docker is actually available on this system."""
194
+ # This test only runs with -m integration
195
+ assert check_docker_available() is True
196
+
197
+ def test_can_run_hello_world(self) -> None:
198
+ """Can run docker hello-world container."""
199
+ result = run_container("hello-world", timeout=60.0)
200
+
201
+ assert result.exit_code == 0
202
+ assert "Hello from Docker!" in result.stdout
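
Taken together, the unit tests above pin down the docker module's contract without requiring Docker itself: availability probing via `docker info`, declarative command construction, and a structured run result. Below is a minimal sketch of an implementation consistent with those tests — the `DockerRunResult` field names (`exit_code`, `stdout`, `stderr`, `elapsed_seconds`) and the import paths are taken from the test file; everything else is an assumption, not the PR's actual implementation:

```python
# Sketch of stroke_deepisles_demo/inference/docker.py, reconstructed from the
# tests above; a plausible shape, not the code shipped in this PR.
from __future__ import annotations

import subprocess
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from stroke_deepisles_demo.core.exceptions import DockerNotAvailableError


@dataclass(frozen=True)
class DockerRunResult:
    exit_code: int
    stdout: str
    stderr: str
    elapsed_seconds: float


def check_docker_available() -> bool:
    """True iff `docker info` runs and exits 0 (binary found, daemon up)."""
    try:
        proc = subprocess.run(["docker", "info"], capture_output=True, check=False)
    except FileNotFoundError:  # docker binary not on PATH
        return False
    return proc.returncode == 0


def ensure_docker_available() -> None:
    """Raise DockerNotAvailableError instead of returning False."""
    if not check_docker_available():
        raise DockerNotAvailableError("Docker is not available")


def build_docker_command(
    image: str,
    *,
    remove: bool = True,
    gpu: bool = False,
    volumes: dict[Path, str] | None = None,
    environment: dict[str, str] | None = None,
    command: list[str] | None = None,
) -> list[str]:
    """Assemble a `docker run` argv; flags mirror the unit-test assertions."""
    cmd = ["docker", "run"]
    if remove:
        cmd.append("--rm")
    if gpu:
        cmd += ["--gpus", "all"]
    for host, container in (volumes or {}).items():
        cmd += ["-v", f"{host}:{container}"]
    for key, value in (environment or {}).items():
        cmd += ["-e", f"{key}={value}"]
    cmd.append(image)
    cmd += command or []
    return cmd


def run_container(
    image: str, *, timeout: float | None = None, **kwargs: Any
) -> DockerRunResult:
    """Build the command, run it, and wrap the outcome in DockerRunResult."""
    ensure_docker_available()
    start = time.monotonic()
    proc = subprocess.run(
        build_docker_command(image, **kwargs),
        capture_output=True,
        text=True,
        timeout=timeout,
        check=False,
    )
    return DockerRunResult(
        exit_code=proc.returncode,
        stdout=proc.stdout,
        stderr=proc.stderr,
        elapsed_seconds=time.monotonic() - start,
    )
```

One design choice worth noting: because the sketch resolves `subprocess.run` at call time and `ensure_docker_available` through the module namespace, `patch("subprocess.run")` and `patch("stroke_deepisles_demo.inference.docker.ensure_docker_available")` in the tests above intercept exactly the calls they intend to.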