feat(data): integrate HuggingFace dataset as primary data source (#11)
## Summary
Integrates HuggingFace dataset `hugging-science/isles24-stroke` as the primary data source.
### Key Changes:
- Added `HuggingFaceDataset` adapter with temp-file caching and cleanup
- Updated `load_isles_dataset()` to auto-detect local vs HF mode
- Added comprehensive mocked unit tests for HF adapter
- Extended `Dataset` protocol with context manager support
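
The temp-file lifecycle behind the new adapter can be sketched in isolation. This is a simplified toy class, not the project's actual `HuggingFaceDataset`; all names here are illustrative:

```python
import shutil
import tempfile
from pathlib import Path


class TempBackedDataset:
    """Toy sketch of the pattern the adapter follows: materialize bytes
    into a temp directory on first access, clean up on __exit__."""

    def __init__(self, blobs: dict[str, bytes]) -> None:
        self._blobs = blobs
        self._temp_dir: Path | None = None

    def __enter__(self) -> "TempBackedDataset":
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def get_case(self, case_id: str) -> Path:
        # Lazily create the shared temp directory on first use
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="demo_"))
        path = self._temp_dir / f"{case_id}.nii.gz"
        if not path.exists():
            path.write_bytes(self._blobs[case_id])
        return path

    def cleanup(self) -> None:
        # Remove the temp directory and everything in it
        if self._temp_dir and self._temp_dir.exists():
            shutil.rmtree(self._temp_dir, ignore_errors=True)
        self._temp_dir = None


with TempBackedDataset({"sub-stroke0001": b"nifti-bytes"}) as ds:
    p = ds.get_case("sub-stroke0001")
    assert p.read_bytes() == b"nifti-bytes"
assert not p.exists()  # temp files removed on exit
```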
### CodeRabbit Findings Addressed:
1. ✅ Sort `list_case_ids()` return value in HuggingFaceDataset
2. ✅ Simplified auto-detection heuristic (removed parent.exists() check)
3. ✅ Use context manager in integration test
4. ❌ Rejected: patch target change (lazy import makes current approach correct)
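
Finding 4 rests on mock's "patch where the name is looked up" rule: because the import happens lazily inside the function body, the name is resolved on the source module at call time, so patching the source module is the correct target. A toy sketch using the stdlib `json` module in place of `datasets`:

```python
from unittest.mock import patch


def parse_config() -> dict:
    # Lazy import, mirroring how the adapter imports `datasets`
    # inside the function body rather than at module top level.
    import json

    return json.loads('{"a": 1}')


# The lookup happens on the json module at call time, so patching
# "json.loads" (the source module) intercepts the lazy import too.
with patch("json.loads", return_value={"mocked": True}):
    assert parse_config() == {"mocked": True}

# Outside the patch, the real function is back.
assert parse_config() == {"a": 1}
```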
### Test Results:
- 125 tests pass
- ruff clean
- mypy clean
### Files Changed:
- .gitignore +3 -0
- README.md +2 -2
- data/README.md +38 -25
- docs/dataset-card/isles24-stroke.md +179 -0
- src/stroke_deepisles_demo/data/adapter.py +177 -4
- src/stroke_deepisles_demo/data/loader.py +66 -14
- tests/data/test_hf_adapter.py +182 -0
- tests/data/test_loader.py +15 -5
### .gitignore (changed)

```diff
@@ -212,3 +212,6 @@ data/isles24/
 # Discovery artifacts (schema reports, samples)
 data/discovery/
 data/scratch/
+
+# macOS
+.DS_Store
```
### README.md (changed)

```diff
@@ -12,7 +12,7 @@ short_description: Ischemic stroke lesion segmentation using DeepISLES
 models:
 - isleschallenge/deepisles
 datasets:
--
+- hugging-science/isles24-stroke
 tags:
 - medical-imaging
 - stroke
@@ -29,7 +29,7 @@ tags:
 [](https://github.com/astral-sh/ruff)
 [](http://mypy-lang.org/)

-A demonstration pipeline and UI for ischemic stroke lesion segmentation using **DeepISLES** and **
+A demonstration pipeline and UI for ischemic stroke lesion segmentation using **DeepISLES** and **ISLES'24** data.

 This project provides a complete end-to-end workflow:
 1. **Data Loading**: Lazy-loading of NIfTI neuroimaging data from HuggingFace.
```
### data/README.md (rewritten)

The old content (local-only instructions for extracting `Images-DWI.zip`, `Images-ADC.zip`, and `Masks.zip` into `data/isles24/`) is replaced with:

````markdown
# Data Directory

This folder is for local development data only. The primary data source is HuggingFace.

## Data Source

**Primary**: [hugging-science/isles24-stroke](https://huggingface.co/datasets/hugging-science/isles24-stroke)

The dataset is automatically downloaded and cached by HuggingFace when you run:

```python
from stroke_deepisles_demo.data import load_isles_dataset

# Loads from HuggingFace (default)
dataset = load_isles_dataset()

# Access cases
case = dataset.get_case(0)  # or dataset.get_case("sub-stroke0001")
```

## HuggingFace Cache Location

Data is cached at: `~/.cache/huggingface/datasets/hugging-science___isles24-stroke/`

## Dataset Contents

149 acute ischemic stroke cases with:

- **Imaging**: DWI, ADC, CT, CTA, perfusion maps (tmax, mtt, cbf, cbv)
- **Masks**: lesion_mask, lvo_mask, cow_segmentation
- **Clinical**: age, sex, nihss_admission, mrs_admission, mrs_3month

## Local Development (Optional)

For offline development, you can still use a local directory:

```python
dataset = load_isles_dataset("path/to/local/data", local_mode=True)
```

Expected structure for local mode:

```text
data/
├── Images-DWI/   # DWI volumes
├── Images-ADC/   # ADC maps
└── Masks/        # Ground truth lesion masks
```

## Notes

- All data files are gitignored
- On HuggingFace Spaces, data loads automatically from the HF cache
- See dataset card for citation requirements
````
### docs/dataset-card/isles24-stroke.md (added, +179)

````markdown
---
license: cc-by-nc-sa-4.0
task_categories:
- image-segmentation
tags:
- medical
- neuroimaging
- stroke
- CT
- MRI
- perfusion
- ISLES
- BIDS
size_categories:
- n<1K
---

# ISLES'24 Stroke Training Dataset

Multi-center longitudinal multimodal acute ischemic stroke training dataset from the ISLES'24 Challenge.

## Dataset Description

- **Source:** [Zenodo Record 17652035](https://zenodo.org/records/17652035) (v7, November 2025)
- **Challenge:** [ISLES 2024](https://isles-24.grand-challenge.org/)
- **Paper:** [Riedel et al., arXiv:2408.11142](https://arxiv.org/abs/2408.11142)
- **License:** CC BY-NC-SA 4.0
- **Size:** 99 GB (compressed)

## Overview

149 acute ischemic stroke training cases with:

- **Admission imaging (ses-01):** Non-contrast CT, CT angiography, 4D CT perfusion
- **Follow-up imaging (ses-02):** Post-treatment MRI (DWI, ADC)
- **Clinical data:** Demographics, patient history, admission NIHSS, 3-month mRS outcomes
- **Annotations:** Infarct masks, large vessel occlusion masks, Circle of Willis anatomy

> **Note:** The ISLES'24 paper describes a training set of 150 cases; the Zenodo v7 training archive contains 149 publicly released subjects.

## Dataset Structure

### Imaging Modalities

| Session | Modality | Description |
|---------|----------|-------------|
| ses-01 (Acute) | `ncct` | Non-contrast CT |
| ses-01 (Acute) | `cta` | CT Angiography |
| ses-01 (Acute) | `ctp` | 4D CT Perfusion time series |
| ses-01 (Acute) | `tmax` | Time-to-maximum perfusion map |
| ses-01 (Acute) | `mtt` | Mean transit time map |
| ses-01 (Acute) | `cbf` | Cerebral blood flow map |
| ses-01 (Acute) | `cbv` | Cerebral blood volume map |
| ses-02 (Follow-up) | `dwi` | Diffusion-weighted MRI |
| ses-02 (Follow-up) | `adc` | Apparent diffusion coefficient |

### Derivative Masks

| Mask | Description |
|------|-------------|
| `lesion_mask` | Binary infarct segmentation (from follow-up MRI) |
| `lvo_mask` | Large vessel occlusion mask (from CTA) |
| `cow_mask` | Circle of Willis anatomy (multi-label, auto-generated from CTA) |

### Clinical Variables

Clinical variables are extracted from per-subject XLSX files in the `phenotype/` directory:

| Variable | Source File | Description |
|----------|-------------|-------------|
| `age` | demographic_baseline.xlsx | Patient age at admission |
| `sex` | demographic_baseline.xlsx | Patient sex (M/F) |
| `nihss_admission` | demographic_baseline.xlsx | NIH Stroke Scale score at admission |
| `mrs_admission` | demographic_baseline.xlsx | Modified Rankin Scale at admission |
| `mrs_3month` | outcome.xlsx | Modified Rankin Scale at 3 months (primary outcome) |

## Usage

```python
from datasets import load_dataset

ds = load_dataset("hugging-science/isles24-stroke", split="train")

# Access a subject
example = ds[0]
print(example["subject_id"])       # "sub-stroke0001"
print(example["ncct"])             # Non-contrast CT array
print(example["dwi"])              # Diffusion-weighted MRI
print(example["lesion_mask"])      # Ground truth segmentation
print(example["nihss_admission"])  # NIH Stroke Scale at admission
print(example["mrs_3month"])       # Modified Rankin Scale at 3 months
```

## Data Organization

The source data follows BIDS structure. This tree shows the actual Zenodo v7 layout:

```text
train/
├── clinical_data-description.xlsx
├── raw_data/
│   └── sub-stroke0001/
│       └── ses-01/
│           ├── sub-stroke0001_ses-01_ncct.nii.gz
│           ├── sub-stroke0001_ses-01_cta.nii.gz
│           ├── sub-stroke0001_ses-01_ctp.nii.gz
│           └── perfusion-maps/
│               ├── sub-stroke0001_ses-01_tmax.nii.gz
│               ├── sub-stroke0001_ses-01_mtt.nii.gz
│               ├── sub-stroke0001_ses-01_cbf.nii.gz
│               └── sub-stroke0001_ses-01_cbv.nii.gz
├── derivatives/
│   └── sub-stroke0001/
│       ├── ses-01/
│       │   ├── perfusion-maps/
│       │   │   ├── sub-stroke0001_ses-01_space-ncct_tmax.nii.gz
│       │   │   ├── sub-stroke0001_ses-01_space-ncct_mtt.nii.gz
│       │   │   ├── sub-stroke0001_ses-01_space-ncct_cbf.nii.gz
│       │   │   └── sub-stroke0001_ses-01_space-ncct_cbv.nii.gz
│       │   ├── sub-stroke0001_ses-01_space-ncct_cta.nii.gz
│       │   ├── sub-stroke0001_ses-01_space-ncct_ctp.nii.gz
│       │   ├── sub-stroke0001_ses-01_space-ncct_lvo-msk.nii.gz
│       │   └── sub-stroke0001_ses-01_space-ncct_cow-msk.nii.gz
│       └── ses-02/
│           ├── sub-stroke0001_ses-02_space-ncct_dwi.nii.gz
│           ├── sub-stroke0001_ses-02_space-ncct_adc.nii.gz
│           └── sub-stroke0001_ses-02_space-ncct_lesion-msk.nii.gz
└── phenotype/
    └── sub-stroke0001/
        ├── ses-01/
        └── ses-02/
```

## Citation

When using this dataset, please cite:

```bibtex
@article{riedel2024isles,
  title={ISLES'24 -- A Real-World Longitudinal Multimodal Stroke Dataset},
  author={Riedel, Evamaria Olga and de la Rosa, Ezequiel and Baran, The Anh and
          Hernandez Petzsche, Moritz and Baazaoui, Hakim and Yang, Kaiyuan and
          Musio, Fabio Antonio and Huang, Houjing and Robben, David and
          Seia, Joaquin Oscar and Wiest, Roland and Reyes, Mauricio and
          Su, Ruisheng and Zimmer, Claus and Boeckh-Behrens, Tobias and
          Berndt, Maria and Menze, Bjoern and Rueckert, Daniel and
          Wiestler, Benedikt and Wegener, Susanne and Kirschke, Jan Stefan},
  journal={arXiv preprint arXiv:2408.11142},
  year={2024}
}

@article{delarosa2024isles,
  title={ISLES'24: Final Infarct Prediction with Multimodal Imaging and Clinical Data. Where Do We Stand?},
  author={de la Rosa, Ezequiel and Su, Ruisheng and Reyes, Mauricio and
          Wiest, Roland and Riedel, Evamaria Olga and Kofler, Florian and
          others and Menze, Bjoern},
  journal={arXiv preprint arXiv:2408.10966},
  year={2024}
}
```

If using Circle of Willis masks, also cite:

```bibtex
@article{yang2023benchmarking,
  title={Benchmarking the CoW with the TopCoW Challenge: Topology-Aware Anatomical
         Segmentation of the Circle of Willis for CTA and MRA},
  author={Yang, Kaiyuan and Musio, Fabio and Ma, Yue and Juchler, Norman and
          Paetzold, Johannes C and Al-Maskari, Rami and others and Menze, Bjoern},
  journal={arXiv preprint arXiv:2312.17670},
  year={2023}
}
```

## Related Resources

- [ISLES 2024 Challenge](https://isles-24.grand-challenge.org/)
- [Zenodo Dataset (DOI: 10.5281/zenodo.17652035)](https://doi.org/10.5281/zenodo.17652035)
- [Dataset Paper (arXiv:2408.11142)](https://arxiv.org/abs/2408.11142)
- [Challenge Paper (arXiv:2408.10966)](https://arxiv.org/abs/2408.10966)
````
### src/stroke_deepisles_demo/data/adapter.py (changed)

Updated module, shown as an excerpt of the new version; unchanged code is elided with `# ...`:

```python
from __future__ import annotations

import re
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.core.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Iterator

    from stroke_deepisles_demo.core.types import CaseFiles

logger = get_logger(__name__)


@dataclass
class LocalDataset:
    """File-based dataset for local ISLES24 data.

    Can be used as a context manager for consistency with HuggingFaceDataset,
    though no cleanup is needed for local files.

    Example:
        with build_local_dataset(path) as ds:
            case = ds.get_case(0)
    """

    data_dir: Path
    cases: dict[str, CaseFiles]  # subject_id -> files

    # ...

    def __iter__(self) -> Iterator[str]:
        return iter(self.cases.keys())

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        # No cleanup needed for local files
        pass

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self.cases.keys())

    # ... (get_case unchanged)

    def cleanup(self) -> None:
        """No-op for local dataset (files are not temporary)."""
        pass


# Subject ID extraction
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")

# ... (build_local_dataset unchanged)


# =============================================================================
# HuggingFace Dataset Adapter
# =============================================================================


@dataclass
class HuggingFaceDataset:
    """Dataset adapter for HuggingFace ISLES24 dataset.

    Wraps the HuggingFace dataset and provides the same interface as LocalDataset.
    When get_case() is called, writes NIfTI bytes to temp files and returns paths.

    IMPORTANT: Use as a context manager to ensure temp files are cleaned up:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # temp files automatically cleaned up

    Or call cleanup() manually when done.
    """

    dataset_id: str
    _hf_dataset: Any = field(repr=False)
    _case_ids: list[str] = field(default_factory=list)
    _temp_dir: Path | None = field(default=None, repr=False)
    _cached_cases: dict[str, CaseFiles] = field(default_factory=dict, repr=False)

    def __len__(self) -> int:
        return len(self._hf_dataset)

    def __iter__(self) -> Iterator[str]:
        return iter(self._case_ids)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self._case_ids)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index.

        Writes NIfTI bytes to temp files on first access; returns cached paths
        on subsequent calls for the same case.

        Raises:
            DataLoadError: If HuggingFace data is malformed or missing required fields.
        """
        if isinstance(case_id, int):
            idx = case_id
            subject_id = self._case_ids[idx]
        else:
            subject_id = case_id
            idx = self._case_ids.index(subject_id)

        # Return cached case if already materialized
        if subject_id in self._cached_cases:
            return self._cached_cases[subject_id]

        # Create shared temp directory on first use
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_"))
            logger.debug("Created temp directory: %s", self._temp_dir)

        # Get the HuggingFace example
        example = self._hf_dataset[idx]

        # Create case subdirectory
        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        # Write NIfTI files to temp directory
        dwi_path = case_dir / f"{subject_id}_ses-02_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_ses-02_adc.nii.gz"
        mask_path = case_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz"

        # Extract bytes with defensive error handling
        try:
            dwi_bytes = example["dwi"]["bytes"]
            adc_bytes = example["adc"]["bytes"]
        except (KeyError, TypeError) as e:
            raise DataLoadError(
                f"Malformed HuggingFace data for {subject_id}: missing 'dwi' or 'adc' bytes. "
                f"The dataset schema may have changed. Error: {e}"
            ) from e

        # Write the gzipped NIfTI bytes
        dwi_path.write_bytes(dwi_bytes)
        adc_path.write_bytes(adc_bytes)

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        # Write lesion mask if available
        try:
            mask_data = example.get("lesion_mask")
            if mask_data and mask_data.get("bytes"):
                mask_path.write_bytes(mask_data["bytes"])
                case_files["ground_truth"] = mask_path
        except (KeyError, TypeError):
            # Mask is optional, log and continue
            logger.debug("No lesion mask available for %s", subject_id)

        # Cache for subsequent calls
        self._cached_cases[subject_id] = case_files

        return case_files

    def cleanup(self) -> None:
        """Remove temp directory and clear cache."""
        if self._temp_dir and self._temp_dir.exists():
            shutil.rmtree(self._temp_dir, ignore_errors=True)
            logger.debug("Cleaned up temp directory: %s", self._temp_dir)
        self._temp_dir = None
        self._cached_cases.clear()


def build_huggingface_dataset(dataset_id: str) -> HuggingFaceDataset:
    """
    Load ISLES24 dataset from HuggingFace Hub.

    Args:
        dataset_id: HuggingFace dataset identifier (e.g., "hugging-science/isles24-stroke")

    Returns:
        HuggingFaceDataset providing case access
    """
    from datasets import load_dataset

    logger.info("Loading HuggingFace dataset: %s", dataset_id)
    hf_dataset = load_dataset(dataset_id, split="train")

    # Extract case IDs
    case_ids = [example["subject_id"] for example in hf_dataset]

    logger.info("Loaded %d cases from HuggingFace: %s", len(case_ids), dataset_id)

    return HuggingFaceDataset(
        dataset_id=dataset_id,
        _hf_dataset=hf_dataset,
        _case_ids=case_ids,
    )
```
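
The `SUBJECT_PATTERN` regex used by the adapter can be sanity-checked against a filename from the dataset tree:

```python
import re

# The adapter's subject-ID pattern (as defined in the module above)
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")

# Matches a real derivatives filename and captures the subject ID
m = SUBJECT_PATTERN.match("sub-stroke0001_ses-02_space-ncct_dwi.nii.gz")
assert m is not None
assert m.group(1) == "stroke0001"

# Non-NIfTI files are ignored
assert SUBJECT_PATTERN.match("clinical_data-description.xlsx") is None
```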
### src/stroke_deepisles_demo/data/loader.py (changed)

Updated module, shown as an excerpt of the new version; unchanged code is elided with `# ...`:

```python
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Protocol, Self

if TYPE_CHECKING:
    from stroke_deepisles_demo.core.types import CaseFiles


class Dataset(Protocol):
    """Protocol for dataset access.

    All dataset implementations support context manager usage for proper cleanup:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # cleanup happens automatically
    """

    def __len__(self) -> int: ...
    def __enter__(self) -> Self: ...
    def __exit__(self, *args: object) -> None: ...
    def list_case_ids(self) -> list[str]: ...
    def get_case(self, case_id: str | int) -> CaseFiles: ...
    def cleanup(self) -> None: ...


@dataclass
class DatasetInfo:
    # ... (unchanged fields)
    has_ground_truth: bool


# Default HuggingFace dataset ID
DEFAULT_HF_DATASET = "hugging-science/isles24-stroke"


def load_isles_dataset(
    source: str | Path | None = None,
    *,
    local_mode: bool | None = None,
) -> Dataset:
    """
    Load ISLES24 dataset from local directory or HuggingFace Hub.

    Args:
        source: Local directory path or HuggingFace dataset ID.
            If None, uses HuggingFace dataset by default.
        local_mode: If True, treat source as local directory.
            If None, auto-detect based on source type.

    Returns:
        Dataset-like object providing case access. Use as context manager
        for automatic cleanup of temp files (important for HuggingFace mode).

    Examples:
        # Load from HuggingFace with automatic cleanup (recommended)
        with load_isles_dataset() as ds:
            case = ds.get_case(0)

        # Load from local directory
        ds = load_isles_dataset("data/isles24", local_mode=True)

        # Load specific HuggingFace dataset
        ds = load_isles_dataset("hugging-science/isles24-stroke")
    """
    # Auto-detect mode if not specified
    if local_mode is None:
        if source is None:
            local_mode = False  # Default to HuggingFace
        elif isinstance(source, Path):
            local_mode = True
        else:
            # String: check if it's an existing local path
            # Only select local mode if the path itself exists
            # (avoids misclassifying HF dataset IDs like "org/dataset")
            source_path = Path(source)
            local_mode = source_path.exists()

    if local_mode:
        from stroke_deepisles_demo.data.adapter import build_local_dataset

        if source is None:
            source = "data/isles24"
        return build_local_dataset(Path(source))

    # HuggingFace mode
    from stroke_deepisles_demo.data.adapter import build_huggingface_dataset

    dataset_id = source if source else DEFAULT_HF_DATASET
    return build_huggingface_dataset(str(dataset_id))
```
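
The simplified auto-detection heuristic (CodeRabbit finding 2) can be exercised on its own; this sketch mirrors the branch logic in `load_isles_dataset`, with a hypothetical helper name:

```python
from pathlib import Path


def detect_local_mode(source: "str | Path | None") -> bool:
    # Illustrative standalone copy of the loader's auto-detect branches
    if source is None:
        return False  # default to HuggingFace
    if isinstance(source, Path):
        return True   # explicit Path objects mean local data
    # Strings count as local only if the path actually exists on disk,
    # so HF IDs like "org/dataset" fall through to HuggingFace mode.
    return Path(source).exists()


assert detect_local_mode(None) is False
assert detect_local_mode(Path("data/isles24")) is True
assert detect_local_mode("hugging-science/isles24-stroke") is False
```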
tests/data/test_hf_adapter.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/data/test_hf_adapter.py
ADDED

```python
"""Unit tests for HuggingFace dataset adapter with mocked HF dataset."""

from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, build_huggingface_dataset


def create_mock_hf_example(subject_id: str, include_mask: bool = True) -> dict[str, Any]:
    """Create a mock HuggingFace dataset example."""
    example: dict[str, Any] = {
        "subject_id": subject_id,
        "dwi": {"bytes": b"fake_dwi_nifti_data", "path": f"{subject_id}_dwi.nii.gz"},
        "adc": {"bytes": b"fake_adc_nifti_data", "path": f"{subject_id}_adc.nii.gz"},
    }
    if include_mask:
        example["lesion_mask"] = {
            "bytes": b"fake_mask_nifti_data",
            "path": f"{subject_id}_lesion-msk.nii.gz",
        }
    else:
        example["lesion_mask"] = None
    return example


@pytest.fixture
def mock_hf_dataset() -> MagicMock:
    """Create a mock HuggingFace dataset with 3 subjects."""
    examples = [
        create_mock_hf_example("sub-stroke0001"),
        create_mock_hf_example("sub-stroke0002"),
        create_mock_hf_example("sub-stroke0003", include_mask=False),
    ]

    mock_ds = MagicMock()
    mock_ds.__len__ = MagicMock(return_value=len(examples))
    mock_ds.__iter__ = MagicMock(return_value=iter(examples))
    mock_ds.__getitem__ = MagicMock(side_effect=lambda i: examples[i])

    return mock_ds


class TestHuggingFaceDataset:
    """Tests for HuggingFaceDataset class."""

    def test_get_case_writes_files_to_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case writes NIfTI bytes to temp files."""
        case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        try:
            case = ds.get_case(0)

            assert "dwi" in case
            assert "adc" in case
            assert case["dwi"].exists()
            assert case["adc"].exists()
            assert case["dwi"].read_bytes() == b"fake_dwi_nifti_data"
            assert case["adc"].read_bytes() == b"fake_adc_nifti_data"
        finally:
            ds.cleanup()

    def test_get_case_includes_ground_truth_when_available(
        self, mock_hf_dataset: MagicMock
    ) -> None:
        """Test that ground truth is included when lesion_mask is present."""
        case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        try:
            case = ds.get_case(0)  # Has mask
            assert "ground_truth" in case
            assert case["ground_truth"].read_bytes() == b"fake_mask_nifti_data"

            case_no_mask = ds.get_case(2)  # No mask
            assert "ground_truth" not in case_no_mask
        finally:
            ds.cleanup()

    def test_get_case_caches_results(self, mock_hf_dataset: MagicMock) -> None:
        """Test that get_case returns cached paths on subsequent calls."""
        case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        try:
            case1 = ds.get_case(0)
            case2 = ds.get_case(0)

            # Same object returned (cached)
            assert case1 is case2

            # Dataset was only accessed once
            assert mock_hf_dataset.__getitem__.call_count == 1
        finally:
            ds.cleanup()

    def test_context_manager_cleans_up_temp_files(self, mock_hf_dataset: MagicMock) -> None:
        """Test that using context manager cleans up temp files."""
        case_ids = ["sub-stroke0001"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        with ds:
            case = ds.get_case(0)
            temp_dir = case["dwi"].parent.parent
            assert temp_dir.exists()

        # After context exit, temp dir should be gone
        assert not temp_dir.exists()

    def test_cleanup_clears_cache(self, mock_hf_dataset: MagicMock) -> None:
        """Test that cleanup clears the case cache."""
        case_ids = ["sub-stroke0001"]
        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_hf_dataset,
            _case_ids=case_ids,
        )

        ds.get_case(0)
        assert len(ds._cached_cases) == 1

        ds.cleanup()
        assert len(ds._cached_cases) == 0

    def test_get_case_raises_data_load_error_on_malformed_data(self) -> None:
        """Test that get_case raises DataLoadError for malformed HF data."""
        # Create mock with missing 'bytes' key
        malformed_example = {"subject_id": "sub-stroke0001", "dwi": {}, "adc": {}}
        mock_ds = MagicMock()
        mock_ds.__len__ = MagicMock(return_value=1)
        mock_ds.__getitem__ = MagicMock(return_value=malformed_example)

        ds = HuggingFaceDataset(
            dataset_id="test/dataset",
            _hf_dataset=mock_ds,
            _case_ids=["sub-stroke0001"],
        )

        try:
            with pytest.raises(DataLoadError, match="Malformed HuggingFace data"):
                ds.get_case(0)
        finally:
            ds.cleanup()


class TestBuildHuggingFaceDataset:
    """Tests for build_huggingface_dataset function."""

    @patch("datasets.load_dataset")
    def test_loads_dataset_from_hub(self, mock_load_dataset: MagicMock) -> None:
        """Test that build_huggingface_dataset calls load_dataset correctly."""
        mock_ds = MagicMock()
        mock_ds.__iter__ = MagicMock(return_value=iter([{"subject_id": "sub-stroke0001"}]))
        mock_load_dataset.return_value = mock_ds

        result = build_huggingface_dataset("test/my-dataset")

        mock_load_dataset.assert_called_once_with("test/my-dataset", split="train")
        assert isinstance(result, HuggingFaceDataset)
        assert result.dataset_id == "test/my-dataset"
        assert result._case_ids == ["sub-stroke0001"]
```
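The adapter lifecycle these tests pin down — decode bytes to temp files once per case, return cached paths on repeat access, and delete everything on `cleanup()` or context-manager exit — can be sketched in isolation. `TempFileCache` below is a hypothetical stand-in written for illustration, not the project's actual `HuggingFaceDataset` implementation:

```python
import shutil
import tempfile
from pathlib import Path
from typing import Any


class TempFileCache:
    """Sketch of the temp-file caching pattern the tests above exercise."""

    def __init__(self) -> None:
        # One temp root per instance; all cases live under it.
        self._tmp = Path(tempfile.mkdtemp(prefix="hf_cache_"))
        self._cached: dict[int, dict[str, Path]] = {}

    def get_case(self, index: int, example: dict[str, Any]) -> dict[str, Path]:
        # Cached result short-circuits: the dataset is only decoded once.
        if index in self._cached:
            return self._cached[index]
        case_dir = self._tmp / f"case_{index}"
        case_dir.mkdir()
        paths: dict[str, Path] = {}
        for key, blob in example.items():
            # Only dict fields carrying raw bytes become files (e.g. NIfTI blobs).
            if isinstance(blob, dict) and "bytes" in blob:
                out = case_dir / Path(blob["path"]).name
                out.write_bytes(blob["bytes"])
                paths[key] = out
        self._cached[index] = paths
        return paths

    def cleanup(self) -> None:
        # Remove the temp tree and forget cached paths (now dangling).
        shutil.rmtree(self._tmp, ignore_errors=True)
        self._cached.clear()

    def __enter__(self) -> "TempFileCache":
        return self

    def __exit__(self, *exc: object) -> None:
        self.cleanup()
```

Returning cached `Path` objects by identity is what makes the `case1 is case2` and `__getitem__.call_count == 1` assertions above cheap to satisfy.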
tests/data/test_loader.py
CHANGED

```diff
@@ -5,8 +5,9 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 import pytest
+from datasets.exceptions import DatasetNotFoundError

-from stroke_deepisles_demo.data.adapter import LocalDataset
+from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, LocalDataset
 from stroke_deepisles_demo.data.loader import load_isles_dataset

 if TYPE_CHECKING:
@@ -27,7 +28,16 @@ def test_load_from_local_finds_all_cases(synthetic_isles_dir: Path) -> None:
     assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]


-def
-    """Test that HF
-    with pytest.raises(
-        load_isles_dataset(source="fake/dataset", local_mode=False)
+def test_load_hf_raises_on_invalid_dataset() -> None:
+    """Test that loading a non-existent HF dataset raises DatasetNotFoundError."""
+    with pytest.raises(DatasetNotFoundError):
+        load_isles_dataset(source="fake/nonexistent-dataset", local_mode=False)
+
+
+@pytest.mark.integration
+def test_load_from_huggingface_returns_hf_dataset() -> None:
+    """Test that loading from HuggingFace returns a HuggingFaceDataset."""
+    with load_isles_dataset() as dataset:  # Default is HuggingFace mode
+        assert isinstance(dataset, HuggingFaceDataset)
+        assert len(dataset) == 149
+        assert dataset.list_case_ids()[0] == "sub-stroke0001"
```

(The removed lines are truncated in the rendered diff; they belonged to the old version of the invalid-dataset test that this hunk replaces.)
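For context, the local-vs-HF auto-detection that these loader tests rely on can be approximated as below. `is_local_source` is a hypothetical name, and the real heuristic inside `load_isles_dataset()` may differ in detail; the sketch only reflects the simplification noted in review (path exists on disk → local, otherwise treat as a Hub dataset id, with no extra `parent.exists()` check):

```python
from pathlib import Path


def is_local_source(source: str) -> bool:
    # A source that exists on disk is treated as a local dataset directory;
    # anything else is assumed to be a HuggingFace dataset id like "org/name".
    return Path(source).exists()
```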