"""Load ISLES24 data from local directory or HuggingFace Hub."""

from __future__ import annotations

import re
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Protocol, Self

from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.core.types import CaseFiles  # noqa: TC001
from stroke_deepisles_demo.data.isles24_manifest import (
    ISLES24_DATASET_ID,
    ISLES24_DATASET_REVISION,
    ISLES24_TRAIN_CASE_IDS,
    isles24_train_data_file,
)

if TYPE_CHECKING:
    from datasets import Dataset as HFDataset

# Security: Regex for valid ISLES24 subject IDs (defense-in-depth)
# Expected format: sub-strokeXXXX (e.g., sub-stroke0001)
_SAFE_SUBJECT_ID_PATTERN = re.compile(r"^sub-stroke\d{4}$")

logger = get_logger(__name__)


class Dataset(Protocol):
    """Protocol for dataset access.

    All dataset implementations support context manager usage for proper cleanup:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # cleanup happens automatically
    """

    def __len__(self) -> int: ...
    def __enter__(self) -> Self: ...
    def __exit__(self, *args: object) -> None: ...
    def list_case_ids(self) -> list[str]: ...
    def get_case(self, case_id: str | int) -> CaseFiles: ...
    def cleanup(self) -> None: ...


@dataclass
class DatasetInfo:
    """Metadata about the dataset."""

    source: str  # "local" or HF dataset ID
    num_cases: int
    modalities: list[str]
    has_ground_truth: bool


@dataclass
class HuggingFaceDatasetWrapper:
    """Wrapper for HuggingFace dataset to match the Dataset protocol.

    Uses the standard datasets library (with the neuroimaging-go-brrrr patched Nifti
    feature) to load data. Materializes NIfTI images to temporary files on demand.
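
    Usage sketch ("org/dataset" is a hypothetical dataset ID):

        ds = load_dataset("org/dataset", split="train")
        with HuggingFaceDatasetWrapper(ds, "org/dataset") as wrapper:
            case = wrapper.get_case(0)  # dict of paths to materialized .nii.gz files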
    """

    dataset: HFDataset
    dataset_id: str
    _temp_dir: Path | None = field(default=None, repr=False)
    _case_id_to_index: dict[str, int] = field(default_factory=dict, repr=False)

    def __post_init__(self) -> None:
        """Build index of subject IDs for O(1) lookup."""
        try:
            # Efficiently build index from 'subject_id' column
            self._case_id_to_index = {
                sid: idx for idx, sid in enumerate(self.dataset["subject_id"])
            }
        except (KeyError, TypeError, ValueError) as e:
            logger.warning(
                "Failed to build index from subject_id column: %s. Fallback to iteration.", e
            )
            for idx, item in enumerate(self.dataset):
                self._case_id_to_index[item["subject_id"]] = idx

    def __len__(self) -> int:
        return len(self.dataset)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        return sorted(self._case_id_to_index.keys())

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index.

        Materializes NIfTI objects to temporary files.
        """
        # Resolve case_id to index
        if isinstance(case_id, int):
            if case_id < 0 or case_id >= len(self.dataset):
                raise IndexError(f"Case index {case_id} out of range")
            idx = case_id
        else:
            if case_id not in self._case_id_to_index:
                raise KeyError(f"Case ID {case_id} not found")
            idx = self._case_id_to_index[case_id]

        row = self.dataset[idx]
        subject_id = row["subject_id"]

        # Security: Validate subject_id before using in path (defense-in-depth)
        if not _SAFE_SUBJECT_ID_PATTERN.match(subject_id):
            raise ValueError(
                f"Invalid subject_id format: {subject_id!r}. Expected format: sub-strokeXXXX"
            )

        # Prepare temp dir
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))

        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_adc.nii.gz"

        # Materialize files if they don't exist
        if not dwi_path.exists():
            row["dwi"].to_filename(str(dwi_path))

        if not adc_path.exists():
            row["adc"].to_filename(str(adc_path))

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        # Handle lesion mask (mapped to ground_truth)
        if "lesion_mask" in row and row["lesion_mask"] is not None:
            mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
            if not mask_path.exists():
                row["lesion_mask"].to_filename(str(mask_path))
            case_files["ground_truth"] = mask_path

        return case_files

    def cleanup(self) -> None:
        if self._temp_dir and self._temp_dir.exists():
            try:
                shutil.rmtree(self._temp_dir)
            except OSError as e:
                logger.warning("Failed to cleanup temp directory %s: %s", self._temp_dir, e)
        self._temp_dir = None


@dataclass
class Isles24HuggingFaceDataset:
    """ISLES24 dataset access optimized for HF Spaces.

    Key behavior:
    - `list_case_ids()` returns from a pinned manifest (no dataset download).
    - `get_case()` loads exactly one Parquet shard via `data_files=...` (no 27 GB eager download).

    This class exists because `datasets.load_dataset(dataset_id, split="train")` can
    trigger an eager full-dataset download/prepare on cold starts, which is not viable
    for API endpoints like `/api/cases` on Hugging Face Spaces.
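
    Usage sketch (listing reads the pinned manifest; get_case fetches one shard):

        with Isles24HuggingFaceDataset() as ds:
            ids = ds.list_case_ids()    # no download, manifest only
            case = ds.get_case(ids[0])  # downloads a single Parquet shard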
    """

    dataset_id: str = ISLES24_DATASET_ID
    token: str | None = None
    revision: str = ISLES24_DATASET_REVISION
    _temp_dir: Path | None = field(default=None, repr=False)

    def __len__(self) -> int:
        return len(ISLES24_TRAIN_CASE_IDS)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        return list(ISLES24_TRAIN_CASE_IDS)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Load files for a single ISLES24 case.

        Args:
            case_id: Case identifier (e.g., "sub-stroke0102") or 0-based integer index.
        """
        from datasets import load_dataset

        if isinstance(case_id, int):
            if case_id < 0 or case_id >= len(ISLES24_TRAIN_CASE_IDS):
                raise IndexError(f"Case index {case_id} out of range")
            resolved_case_id = ISLES24_TRAIN_CASE_IDS[case_id]
        else:
            resolved_case_id = case_id

        # Security: Validate subject_id before using in path (defense-in-depth)
        if not _SAFE_SUBJECT_ID_PATTERN.match(resolved_case_id):
            raise ValueError(
                f"Invalid subject_id format: {resolved_case_id!r}. Expected format: sub-strokeXXXX"
            )

        # Load exactly one shard (1 case per parquet file in this dataset)
        data_file = isles24_train_data_file(resolved_case_id)
        ds = load_dataset(
            self.dataset_id,
            data_files={"train": data_file},
            split="train",
            token=self.token,
            revision=self.revision,
        )
        ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
        if len(ds) != 1:
            raise RuntimeError(f"Expected 1 row for {resolved_case_id}, got {len(ds)}")

        row = ds[0]
        subject_id = row["subject_id"]
        if subject_id != resolved_case_id:
            raise RuntimeError(
                f"Unexpected subject_id {subject_id!r} in {data_file} (expected {resolved_case_id!r})"
            )

        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))

        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_adc.nii.gz"

        if not dwi_path.exists():
            row["dwi"].to_filename(str(dwi_path))
        if not adc_path.exists():
            row["adc"].to_filename(str(adc_path))

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        if row.get("lesion_mask") is not None:
            mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
            if not mask_path.exists():
                row["lesion_mask"].to_filename(str(mask_path))
            case_files["ground_truth"] = mask_path

        return case_files

    def cleanup(self) -> None:
        if self._temp_dir and self._temp_dir.exists():
            try:
                shutil.rmtree(self._temp_dir)
            except OSError as e:
                logger.warning("Failed to cleanup temp directory %s: %s", self._temp_dir, e)
        self._temp_dir = None


def load_isles_dataset(
    source: str | Path | None = None,
    *,
    local_mode: bool | None = None,
    token: str | None = None,
) -> Dataset:
    """
    Load ISLES24 dataset from local directory or HuggingFace Hub.

    Args:
        source: Local directory path or HuggingFace dataset ID.
                If None, uses Settings.hf_dataset_id from config.
        local_mode: If True, treat source as local directory.
                    If None, auto-detect based on source type.
        token: HuggingFace token for private/gated datasets.
               If None, uses Settings.hf_token from config.

    Returns:
        Dataset-like object providing case access. Use as context manager
        for automatic cleanup of temp files (important for HuggingFace mode).

    Examples:
        # Load from HuggingFace with automatic cleanup (recommended)
        with load_isles_dataset() as ds:
            case = ds.get_case(0)

        # Load from local directory
        ds = load_isles_dataset("data/isles24", local_mode=True)

        # Load specific HuggingFace dataset with token
        ds = load_isles_dataset("org/private-dataset", token="hf_xxx")
    """
    # Auto-detect mode if not specified
    if local_mode is None:
        if source is None:
            local_mode = False  # Default to HuggingFace
        elif isinstance(source, Path):
            local_mode = True
        else:
            # String: check if it's an existing local path
            # Only select local mode if the path itself exists
            # (avoids misclassifying HF dataset IDs like "org/dataset")
            source_path = Path(source)
            local_mode = source_path.exists()

    if local_mode:
        from stroke_deepisles_demo.data.adapter import build_local_dataset

        if source is None:
            source = "data/isles24"
        return build_local_dataset(Path(source))

    # HuggingFace mode
    from datasets import load_dataset

    from stroke_deepisles_demo.core.config import get_settings

    settings = get_settings()

    # Use settings defaults if not specified
    dataset_id = str(source) if source else settings.hf_dataset_id
    hf_token = token if token is not None else settings.hf_token

    if dataset_id == ISLES24_DATASET_ID:
        return Isles24HuggingFaceDataset(dataset_id=dataset_id, token=hf_token)

    # Load dataset, selecting only necessary columns to minimize decoding overhead
    # We rely on neuroimaging-go-brrrr's Nifti feature for lazy loading if configured,
    # but select_columns ensures we don't touch other modalities.
    # Token enables access to private/gated datasets
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])

    return HuggingFaceDatasetWrapper(ds, dataset_id)
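

# Minimal smoke-test sketch. Assumes default settings resolve to an accessible
# dataset (the default ISLES24 dataset may be gated, requiring an HF token), so
# treat this as illustrative rather than a guaranteed-runnable entry point.
if __name__ == "__main__":
    with load_isles_dataset() as demo_ds:
        print(f"Cases available: {len(demo_ds)}")
        first = demo_ds.get_case(0)
        print({name: str(path) for name, path in first.items()})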