VibecoderMcSwaggins commited on
Commit
262b3cb
·
unverified ·
1 Parent(s): ba32591

refactor(data): use standard datasets.load_dataset() with neuroimaging-go-brrrr

Browse files

Replaces hand-rolled HuggingFace adapter with standard datasets library using neuroimaging-go-brrrr for NIfTI support. Includes CI fixes and CodeRabbit feedback.

.github/workflows/ci.yml CHANGED
@@ -93,6 +93,17 @@ jobs:
93
  steps:
94
  - uses: actions/checkout@v4
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  - name: Install uv
97
  uses: astral-sh/setup-uv@v4
98
 
 
93
  steps:
94
  - uses: actions/checkout@v4
95
 
96
+ - name: Free disk space
97
+ uses: jlumbroso/free-disk-space@main
98
+ with:
99
+ tool-cache: false
100
+ android: true
101
+ dotnet: true
102
+ haskell: true
103
+ large-packages: false # Keep false to avoid long cleanup time
104
+ docker-images: false
105
+ swap-storage: false
106
+
107
  - name: Install uv
108
  uses: astral-sh/setup-uv@v4
109
 
src/stroke_deepisles_demo/data/adapter.py CHANGED
@@ -3,13 +3,10 @@
3
  from __future__ import annotations
4
 
5
  import re
6
- import shutil
7
- import tempfile
8
- from dataclasses import dataclass, field
9
- from pathlib import Path
10
  from typing import TYPE_CHECKING, Self
11
 
12
- from stroke_deepisles_demo.core.exceptions import DataLoadError
13
  from stroke_deepisles_demo.core.logging import get_logger
14
 
15
  if TYPE_CHECKING:
@@ -24,7 +21,7 @@ logger = get_logger(__name__)
24
  class LocalDataset:
25
  """File-based dataset for local ISLES24 data.
26
 
27
- Can be used as a context manager for consistency with HuggingFaceDataset,
28
  though no cleanup is needed for local files.
29
 
30
  Example:
@@ -133,246 +130,3 @@ def build_local_dataset(data_dir: Path) -> LocalDataset:
133
 
134
  logger.info("Loaded %d cases from %s", len(cases), data_dir)
135
  return LocalDataset(data_dir=data_dir, cases=cases)
136
-
137
-
138
- # =============================================================================
139
- # HuggingFace Dataset Adapter
140
- # =============================================================================
141
-
142
-
143
- @dataclass
144
- class HuggingFaceDataset:
145
- """Dataset adapter for HuggingFace ISLES24 dataset.
146
-
147
- Wraps the HuggingFace dataset and provides the same interface as LocalDataset.
148
- When get_case() is called, downloads NIfTI bytes from individual parquet files
149
- and writes them to temp files.
150
-
151
- This implementation bypasses `load_dataset()` entirely to avoid:
152
- 1. PyArrow streaming bug (apache/arrow#45214) that hangs on parquet iteration
153
- 2. Memory issues from downloading the full 99GB dataset
154
-
155
- IMPORTANT: Use as a context manager to ensure temp files are cleaned up:
156
-
157
- with build_huggingface_dataset(dataset_id) as ds:
158
- case = ds.get_case(0)
159
- # ... process case ...
160
- # temp files automatically cleaned up
161
-
162
- Or call cleanup() manually when done.
163
- """
164
-
165
- dataset_id: str
166
- _case_ids: list[str] = field(default_factory=list)
167
- _case_index: dict[str, int] = field(default_factory=dict)
168
- _temp_dir: Path | None = field(default=None, repr=False)
169
- _cached_cases: dict[str, CaseFiles] = field(default_factory=dict, repr=False)
170
-
171
- def __len__(self) -> int:
172
- return len(self._case_ids)
173
-
174
- def __iter__(self) -> Iterator[str]:
175
- return iter(self._case_ids)
176
-
177
- def __enter__(self) -> Self:
178
- return self
179
-
180
- def __exit__(self, *args: object) -> None:
181
- self.cleanup()
182
-
183
- def list_case_ids(self) -> list[str]:
184
- """Return sorted list of subject IDs."""
185
- return sorted(self._case_ids)
186
-
187
- def get_case(self, case_id: str | int) -> CaseFiles:
188
- """Get files for a case by ID or index.
189
-
190
- Downloads NIfTI bytes from the individual parquet file for this case
191
- and writes to temp files. Returns cached paths on subsequent calls.
192
-
193
- This uses HfFileSystem + pyarrow to download only the single case (~50MB)
194
- instead of the full dataset (99GB), completing in ~2 seconds.
195
-
196
- Raises:
197
- DataLoadError: If HuggingFace data is malformed or missing required fields.
198
- KeyError: If case_id is not found in the dataset.
199
- """
200
- # Resolve case_id to subject_id and file index
201
- if isinstance(case_id, int):
202
- if case_id < 0 or case_id >= len(self._case_ids):
203
- raise IndexError(f"Case index {case_id} out of range [0, {len(self._case_ids)})")
204
- subject_id = self._case_ids[case_id]
205
- file_idx = case_id
206
- else:
207
- subject_id = case_id
208
- if subject_id not in self._case_index:
209
- raise KeyError(f"Case ID '{subject_id}' not found in dataset")
210
- file_idx = self._case_index[subject_id]
211
-
212
- # Return cached case if already materialized
213
- if subject_id in self._cached_cases:
214
- return self._cached_cases[subject_id]
215
-
216
- # Create shared temp directory on first use
217
- if self._temp_dir is None:
218
- self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_"))
219
- logger.debug("Created temp directory: %s", self._temp_dir)
220
-
221
- # Download case data from individual parquet file
222
- logger.info("Downloading case %s from HuggingFace...", subject_id)
223
- case_data = self._download_case_from_parquet(file_idx, subject_id)
224
-
225
- # Create case subdirectory
226
- case_dir = self._temp_dir / subject_id
227
- case_dir.mkdir(exist_ok=True)
228
-
229
- # Write NIfTI files to temp directory
230
- dwi_path = case_dir / f"{subject_id}_ses-02_dwi.nii.gz"
231
- adc_path = case_dir / f"{subject_id}_ses-02_adc.nii.gz"
232
- mask_path = case_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz"
233
-
234
- # Write the gzipped NIfTI bytes
235
- dwi_path.write_bytes(case_data["dwi_bytes"])
236
- adc_path.write_bytes(case_data["adc_bytes"])
237
-
238
- case_files: CaseFiles = {
239
- "dwi": dwi_path,
240
- "adc": adc_path,
241
- }
242
-
243
- # Write lesion mask if available
244
- if case_data.get("mask_bytes"):
245
- mask_path.write_bytes(case_data["mask_bytes"])
246
- case_files["ground_truth"] = mask_path
247
-
248
- # Cache for subsequent calls
249
- self._cached_cases[subject_id] = case_files
250
- logger.info(
251
- "Case %s ready: DWI=%.1fMB, ADC=%.1fMB",
252
- subject_id,
253
- len(case_data["dwi_bytes"]) / 1024 / 1024,
254
- len(case_data["adc_bytes"]) / 1024 / 1024,
255
- )
256
-
257
- return case_files
258
-
259
- def _download_case_from_parquet(self, file_idx: int, subject_id: str) -> dict[str, bytes]:
260
- """Download case data directly from individual parquet file.
261
-
262
- Uses HfFileSystem + pyarrow to read only the columns we need from
263
- a single parquet file, avoiding the need to download the full dataset.
264
-
265
- Args:
266
- file_idx: Index of the parquet file (0-148)
267
- subject_id: Expected subject ID (for validation)
268
-
269
- Returns:
270
- Dict with dwi_bytes, adc_bytes, and optionally mask_bytes
271
- """
272
- import pyarrow.parquet as pq
273
- from huggingface_hub import HfFileSystem
274
-
275
- from stroke_deepisles_demo.data.constants import ISLES24_NUM_FILES
276
-
277
- # Construct path to the specific parquet file
278
- fpath = f"datasets/{self.dataset_id}/data/train-{file_idx:05d}-of-{ISLES24_NUM_FILES:05d}.parquet"
279
-
280
- try:
281
- fs = HfFileSystem()
282
- with fs.open(fpath, "rb") as f:
283
- pf = pq.ParquetFile(f)
284
- # Read only the columns we need
285
- table = pf.read(columns=["subject_id", "dwi", "adc", "lesion_mask"])
286
- df = table.to_pandas()
287
-
288
- if len(df) != 1:
289
- raise DataLoadError(f"Expected 1 row in parquet file, got {len(df)}: {fpath}")
290
-
291
- row = df.iloc[0]
292
-
293
- # Validate subject_id matches
294
- actual_subject_id = row["subject_id"]
295
- if actual_subject_id != subject_id:
296
- raise DataLoadError(
297
- f"Subject ID mismatch: expected {subject_id}, got {actual_subject_id} in {fpath}"
298
- )
299
-
300
- # Extract bytes with defensive error handling
301
- try:
302
- dwi_bytes = row["dwi"]["bytes"]
303
- adc_bytes = row["adc"]["bytes"]
304
- except (KeyError, TypeError) as e:
305
- raise DataLoadError(
306
- f"Malformed HuggingFace data for {subject_id}: missing 'dwi' or 'adc' bytes. "
307
- f"The dataset schema may have changed. Error: {e}"
308
- ) from e
309
-
310
- result: dict[str, bytes] = {
311
- "dwi_bytes": dwi_bytes,
312
- "adc_bytes": adc_bytes,
313
- }
314
-
315
- # Extract mask if available
316
- mask_data = row.get("lesion_mask")
317
- if mask_data is not None and isinstance(mask_data, dict) and mask_data.get("bytes"):
318
- result["mask_bytes"] = mask_data["bytes"]
319
-
320
- return result
321
-
322
- except Exception as e:
323
- if isinstance(e, DataLoadError):
324
- raise
325
- raise DataLoadError(f"Failed to download case {subject_id} from {fpath}: {e}") from e
326
-
327
- def cleanup(self) -> None:
328
- """Remove temp directory and clear cache."""
329
- if self._temp_dir is not None and self._temp_dir.exists():
330
- try:
331
- shutil.rmtree(self._temp_dir)
332
- logger.debug("Cleaned up temp directory: %s", self._temp_dir)
333
- except OSError as e:
334
- logger.warning("Failed to cleanup temp directory %s: %s", self._temp_dir, e)
335
- self._temp_dir = None
336
- self._cached_cases.clear()
337
-
338
-
339
- def build_huggingface_dataset(dataset_id: str) -> HuggingFaceDataset:
340
- """
341
- Build ISLES24 dataset adapter for HuggingFace Hub.
342
-
343
- Uses pre-computed case IDs to avoid streaming enumeration (which hangs
344
- due to PyArrow bug apache/arrow#45214). Actual data is downloaded lazily
345
- from individual parquet files when get_case() is called.
346
-
347
- Args:
348
- dataset_id: HuggingFace dataset identifier (e.g., "hugging-science/isles24-stroke")
349
-
350
- Returns:
351
- HuggingFaceDataset providing case access
352
- """
353
- from stroke_deepisles_demo.data.constants import (
354
- ISLES24_CASE_IDS,
355
- ISLES24_CASE_INDEX,
356
- ISLES24_DATASET_ID,
357
- )
358
-
359
- # Validate dataset_id matches our pre-computed constants
360
- if dataset_id != ISLES24_DATASET_ID:
361
- logger.warning(
362
- "Dataset ID '%s' does not match pre-computed constants for '%s'. "
363
- "Case IDs may be incorrect.",
364
- dataset_id,
365
- ISLES24_DATASET_ID,
366
- )
367
-
368
- logger.info(
369
- "Building HuggingFace dataset adapter: %s (%d cases, pre-computed)",
370
- dataset_id,
371
- len(ISLES24_CASE_IDS),
372
- )
373
-
374
- return HuggingFaceDataset(
375
- dataset_id=dataset_id,
376
- _case_ids=list(ISLES24_CASE_IDS),
377
- _case_index=dict(ISLES24_CASE_INDEX),
378
- )
 
3
  from __future__ import annotations
4
 
5
  import re
6
+ from dataclasses import dataclass
7
+ from pathlib import Path # noqa: TC003
 
 
8
  from typing import TYPE_CHECKING, Self
9
 
 
10
  from stroke_deepisles_demo.core.logging import get_logger
11
 
12
  if TYPE_CHECKING:
 
21
  class LocalDataset:
22
  """File-based dataset for local ISLES24 data.
23
 
24
+ Can be used as a context manager for consistency with HuggingFaceDatasetWrapper,
25
  though no cleanup is needed for local files.
26
 
27
  Example:
 
130
 
131
  logger.info("Loaded %d cases from %s", len(cases), data_dir)
132
  return LocalDataset(data_dir=data_dir, cases=cases)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/stroke_deepisles_demo/data/constants.py DELETED
@@ -1,181 +0,0 @@
1
- """Pre-computed constants for ISLES24 dataset.
2
-
3
- The ISLES24 challenge dataset is static (case IDs will never change).
4
- Pre-computing these values avoids:
5
- 1. PyArrow streaming bug (apache/arrow#45214) that hangs on parquet iteration
6
- 2. Memory issues from downloading the full 99GB dataset
7
-
8
- See docs/specs/08-bug-hf-spaces-dataset-loop.md for full investigation.
9
- """
10
-
11
- # Pre-computed case IDs for ISLES24 dataset
12
- # Extracted via HfFileSystem enumeration on 2025-12-08
13
- # Order matches parquet file indices (train-00000-of-00149.parquet = index 0)
14
- ISLES24_CASE_IDS: tuple[str, ...] = (
15
- "sub-stroke0001",
16
- "sub-stroke0002",
17
- "sub-stroke0003",
18
- "sub-stroke0004",
19
- "sub-stroke0005",
20
- "sub-stroke0006",
21
- "sub-stroke0007",
22
- "sub-stroke0008",
23
- "sub-stroke0009",
24
- "sub-stroke0010",
25
- "sub-stroke0011",
26
- "sub-stroke0012",
27
- "sub-stroke0013",
28
- "sub-stroke0014",
29
- "sub-stroke0015",
30
- "sub-stroke0016",
31
- "sub-stroke0017",
32
- "sub-stroke0019",
33
- "sub-stroke0020",
34
- "sub-stroke0021",
35
- "sub-stroke0022",
36
- "sub-stroke0025",
37
- "sub-stroke0026",
38
- "sub-stroke0027",
39
- "sub-stroke0028",
40
- "sub-stroke0030",
41
- "sub-stroke0033",
42
- "sub-stroke0036",
43
- "sub-stroke0037",
44
- "sub-stroke0038",
45
- "sub-stroke0040",
46
- "sub-stroke0043",
47
- "sub-stroke0045",
48
- "sub-stroke0047",
49
- "sub-stroke0048",
50
- "sub-stroke0049",
51
- "sub-stroke0052",
52
- "sub-stroke0053",
53
- "sub-stroke0054",
54
- "sub-stroke0055",
55
- "sub-stroke0057",
56
- "sub-stroke0062",
57
- "sub-stroke0066",
58
- "sub-stroke0068",
59
- "sub-stroke0070",
60
- "sub-stroke0071",
61
- "sub-stroke0073",
62
- "sub-stroke0074",
63
- "sub-stroke0075",
64
- "sub-stroke0076",
65
- "sub-stroke0077",
66
- "sub-stroke0078",
67
- "sub-stroke0079",
68
- "sub-stroke0080",
69
- "sub-stroke0081",
70
- "sub-stroke0082",
71
- "sub-stroke0083",
72
- "sub-stroke0084",
73
- "sub-stroke0085",
74
- "sub-stroke0086",
75
- "sub-stroke0087",
76
- "sub-stroke0088",
77
- "sub-stroke0089",
78
- "sub-stroke0090",
79
- "sub-stroke0091",
80
- "sub-stroke0092",
81
- "sub-stroke0093",
82
- "sub-stroke0094",
83
- "sub-stroke0095",
84
- "sub-stroke0096",
85
- "sub-stroke0097",
86
- "sub-stroke0098",
87
- "sub-stroke0099",
88
- "sub-stroke0100",
89
- "sub-stroke0101",
90
- "sub-stroke0102",
91
- "sub-stroke0103",
92
- "sub-stroke0104",
93
- "sub-stroke0105",
94
- "sub-stroke0106",
95
- "sub-stroke0107",
96
- "sub-stroke0108",
97
- "sub-stroke0109",
98
- "sub-stroke0110",
99
- "sub-stroke0111",
100
- "sub-stroke0112",
101
- "sub-stroke0113",
102
- "sub-stroke0114",
103
- "sub-stroke0115",
104
- "sub-stroke0116",
105
- "sub-stroke0117",
106
- "sub-stroke0118",
107
- "sub-stroke0119",
108
- "sub-stroke0133",
109
- "sub-stroke0134",
110
- "sub-stroke0135",
111
- "sub-stroke0136",
112
- "sub-stroke0137",
113
- "sub-stroke0138",
114
- "sub-stroke0139",
115
- "sub-stroke0140",
116
- "sub-stroke0141",
117
- "sub-stroke0142",
118
- "sub-stroke0143",
119
- "sub-stroke0144",
120
- "sub-stroke0145",
121
- "sub-stroke0146",
122
- "sub-stroke0147",
123
- "sub-stroke0148",
124
- "sub-stroke0149",
125
- "sub-stroke0150",
126
- "sub-stroke0151",
127
- "sub-stroke0152",
128
- "sub-stroke0153",
129
- "sub-stroke0154",
130
- "sub-stroke0155",
131
- "sub-stroke0156",
132
- "sub-stroke0157",
133
- "sub-stroke0158",
134
- "sub-stroke0159",
135
- "sub-stroke0161",
136
- "sub-stroke0162",
137
- "sub-stroke0163",
138
- "sub-stroke0164",
139
- "sub-stroke0165",
140
- "sub-stroke0166",
141
- "sub-stroke0167",
142
- "sub-stroke0168",
143
- "sub-stroke0169",
144
- "sub-stroke0170",
145
- "sub-stroke0171",
146
- "sub-stroke0172",
147
- "sub-stroke0173",
148
- "sub-stroke0174",
149
- "sub-stroke0175",
150
- "sub-stroke0176",
151
- "sub-stroke0177",
152
- "sub-stroke0178",
153
- "sub-stroke0179",
154
- "sub-stroke0180",
155
- "sub-stroke0181",
156
- "sub-stroke0182",
157
- "sub-stroke0183",
158
- "sub-stroke0184",
159
- "sub-stroke0185",
160
- "sub-stroke0186",
161
- "sub-stroke0187",
162
- "sub-stroke0188",
163
- "sub-stroke0189",
164
- )
165
-
166
- # Mapping from case ID to parquet file index (0-indexed)
167
- # train-00000-of-00149.parquet contains sub-stroke0001
168
- # train-00001-of-00149.parquet contains sub-stroke0002
169
- # etc.
170
- ISLES24_CASE_INDEX: dict[str, int] = {case_id: idx for idx, case_id in enumerate(ISLES24_CASE_IDS)}
171
-
172
- # Total number of parquet files in the dataset
173
- ISLES24_NUM_FILES: int = 149
174
-
175
- # Sanity check: ensure constants are consistent
176
- assert len(ISLES24_CASE_IDS) == ISLES24_NUM_FILES, (
177
- f"ISLES24_CASE_IDS has {len(ISLES24_CASE_IDS)} entries but ISLES24_NUM_FILES is {ISLES24_NUM_FILES}"
178
- )
179
-
180
- # Dataset identifier on HuggingFace Hub
181
- ISLES24_DATASET_ID: str = "hugging-science/isles24-stroke"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/stroke_deepisles_demo/data/loader.py CHANGED
@@ -2,12 +2,19 @@
2
 
3
  from __future__ import annotations
4
 
5
- from dataclasses import dataclass
 
 
6
  from pathlib import Path
7
  from typing import TYPE_CHECKING, Protocol, Self
8
 
 
 
 
9
  if TYPE_CHECKING:
10
- from stroke_deepisles_demo.core.types import CaseFiles
 
 
11
 
12
 
13
  class Dataset(Protocol):
@@ -39,6 +46,103 @@ class DatasetInfo:
39
  has_ground_truth: bool
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # Default HuggingFace dataset ID
43
  DEFAULT_HF_DATASET = "hugging-science/isles24-stroke"
44
 
@@ -93,7 +197,14 @@ def load_isles_dataset(
93
  return build_local_dataset(Path(source))
94
 
95
  # HuggingFace mode
96
- from stroke_deepisles_demo.data.adapter import build_huggingface_dataset
 
 
 
 
 
 
 
 
97
 
98
- dataset_id = source if source else DEFAULT_HF_DATASET
99
- return build_huggingface_dataset(str(dataset_id))
 
2
 
3
  from __future__ import annotations
4
 
5
+ import shutil
6
+ import tempfile
7
+ from dataclasses import dataclass, field
8
  from pathlib import Path
9
  from typing import TYPE_CHECKING, Protocol, Self
10
 
11
+ from stroke_deepisles_demo.core.logging import get_logger
12
+ from stroke_deepisles_demo.core.types import CaseFiles # noqa: TC001
13
+
14
  if TYPE_CHECKING:
15
+ from datasets import Dataset as HFDataset
16
+
17
+ logger = get_logger(__name__)
18
 
19
 
20
  class Dataset(Protocol):
 
46
  has_ground_truth: bool
47
 
48
 
49
+ @dataclass
50
+ class HuggingFaceDatasetWrapper:
51
+ """Wrapper for HuggingFace dataset to match the Dataset protocol.
52
+
53
+ Uses the standard datasets library (with neuroimaging-go-brrrr patched Nifti feature)
54
+ to load data. Materializes NIfTI images to temporary files on demand.
55
+ """
56
+
57
+ dataset: HFDataset
58
+ dataset_id: str
59
+ _temp_dir: Path | None = field(default=None, repr=False)
60
+ _case_id_to_index: dict[str, int] = field(default_factory=dict, repr=False)
61
+
62
+ def __post_init__(self) -> None:
63
+ """Build index of subject IDs for O(1) lookup."""
64
+ try:
65
+ # Efficiently build index from 'subject_id' column
66
+ self._case_id_to_index = {
67
+ sid: idx for idx, sid in enumerate(self.dataset["subject_id"])
68
+ }
69
+ except (KeyError, TypeError, ValueError) as e:
70
+ logger.warning(
71
+ "Failed to build index from subject_id column: %s. Fallback to iteration.", e
72
+ )
73
+ for idx, item in enumerate(self.dataset):
74
+ self._case_id_to_index[item["subject_id"]] = idx
75
+
76
+ def __len__(self) -> int:
77
+ return len(self.dataset)
78
+
79
+ def __enter__(self) -> Self:
80
+ return self
81
+
82
+ def __exit__(self, *args: object) -> None:
83
+ self.cleanup()
84
+
85
+ def list_case_ids(self) -> list[str]:
86
+ return sorted(self._case_id_to_index.keys())
87
+
88
+ def get_case(self, case_id: str | int) -> CaseFiles:
89
+ """Get files for a case by ID or index.
90
+
91
+ Materializes NIfTI objects to temporary files.
92
+ """
93
+ # Resolve case_id to index
94
+ if isinstance(case_id, int):
95
+ if case_id < 0 or case_id >= len(self.dataset):
96
+ raise IndexError(f"Case index {case_id} out of range")
97
+ idx = case_id
98
+ else:
99
+ if case_id not in self._case_id_to_index:
100
+ raise KeyError(f"Case ID {case_id} not found")
101
+ idx = self._case_id_to_index[case_id]
102
+
103
+ row = self.dataset[idx]
104
+ subject_id = row["subject_id"]
105
+
106
+ # Prepare temp dir
107
+ if self._temp_dir is None:
108
+ self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))
109
+
110
+ case_dir = self._temp_dir / subject_id
111
+ case_dir.mkdir(exist_ok=True)
112
+
113
+ dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
114
+ adc_path = case_dir / f"{subject_id}_adc.nii.gz"
115
+
116
+ # Materialize files if they don't exist
117
+ if not dwi_path.exists():
118
+ row["dwi"].to_filename(str(dwi_path))
119
+
120
+ if not adc_path.exists():
121
+ row["adc"].to_filename(str(adc_path))
122
+
123
+ case_files: CaseFiles = {
124
+ "dwi": dwi_path,
125
+ "adc": adc_path,
126
+ }
127
+
128
+ # Handle lesion mask (mapped to ground_truth)
129
+ if "lesion_mask" in row and row["lesion_mask"] is not None:
130
+ mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
131
+ if not mask_path.exists():
132
+ row["lesion_mask"].to_filename(str(mask_path))
133
+ case_files["ground_truth"] = mask_path
134
+
135
+ return case_files
136
+
137
+ def cleanup(self) -> None:
138
+ if self._temp_dir and self._temp_dir.exists():
139
+ try:
140
+ shutil.rmtree(self._temp_dir)
141
+ except OSError as e:
142
+ logger.warning("Failed to cleanup temp directory %s: %s", self._temp_dir, e)
143
+ self._temp_dir = None
144
+
145
+
146
  # Default HuggingFace dataset ID
147
  DEFAULT_HF_DATASET = "hugging-science/isles24-stroke"
148
 
 
197
  return build_local_dataset(Path(source))
198
 
199
  # HuggingFace mode
200
+ from datasets import load_dataset
201
+
202
+ dataset_id = str(source) if source else DEFAULT_HF_DATASET
203
+
204
+ # Load dataset, selecting only necessary columns to minimize decoding overhead
205
+ # We rely on neuroimaging-go-brrrr's Nifti feature for lazy loading if configured,
206
+ # but select_columns ensures we don't touch other modalities.
207
+ ds = load_dataset(dataset_id, split="train")
208
+ ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
209
 
210
+ return HuggingFaceDatasetWrapper(ds, dataset_id)
 
tests/api/test_endpoints.py CHANGED
@@ -84,31 +84,37 @@ class TestPostSegment:
84
 
85
  def test_creates_job_and_returns_202(self, client: TestClient) -> None:
86
  """POST /api/segment creates a job and returns 202 Accepted."""
87
- response = client.post(
88
- "/api/segment",
89
- json={"case_id": "sub-stroke0001", "fast_mode": True},
90
- )
91
 
92
- assert response.status_code == 202
93
- data = response.json()
94
- assert "jobId" in data
95
- assert data["status"] == "pending"
96
- assert "message" in data
 
 
 
 
 
97
 
98
  def test_returns_job_id_for_polling(self, client: TestClient) -> None:
99
  """POST /api/segment returns a job ID that can be used for polling."""
100
- response = client.post(
101
- "/api/segment",
102
- json={"case_id": "sub-stroke0001", "fast_mode": True},
103
- )
 
 
 
104
 
105
- job_id = response.json()["jobId"]
106
- assert job_id is not None
107
- assert len(job_id) > 0
108
 
109
- # Job should be retrievable via GET /api/jobs/{id}
110
- status_response = client.get(f"/api/jobs/{job_id}")
111
- assert status_response.status_code == 200
112
 
113
  def test_returns_422_on_missing_case_id(self, client: TestClient) -> None:
114
  """POST /api/segment returns 422 when case_id is missing."""
 
84
 
85
  def test_creates_job_and_returns_202(self, client: TestClient) -> None:
86
  """POST /api/segment creates a job and returns 202 Accepted."""
87
+ with patch("stroke_deepisles_demo.api.routes.list_case_ids") as mock_list:
88
+ mock_list.return_value = ["sub-stroke0001", "sub-stroke0002"]
 
 
89
 
90
+ response = client.post(
91
+ "/api/segment",
92
+ json={"case_id": "sub-stroke0001", "fast_mode": True},
93
+ )
94
+
95
+ assert response.status_code == 202
96
+ data = response.json()
97
+ assert "jobId" in data
98
+ assert data["status"] == "pending"
99
+ assert "message" in data
100
 
101
  def test_returns_job_id_for_polling(self, client: TestClient) -> None:
102
  """POST /api/segment returns a job ID that can be used for polling."""
103
+ with patch("stroke_deepisles_demo.api.routes.list_case_ids") as mock_list:
104
+ mock_list.return_value = ["sub-stroke0001", "sub-stroke0002"]
105
+
106
+ response = client.post(
107
+ "/api/segment",
108
+ json={"case_id": "sub-stroke0001", "fast_mode": True},
109
+ )
110
 
111
+ job_id = response.json()["jobId"]
112
+ assert job_id is not None
113
+ assert len(job_id) > 0
114
 
115
+ # Job should be retrievable via GET /api/jobs/{id}
116
+ status_response = client.get(f"/api/jobs/{job_id}")
117
+ assert status_response.status_code == 200
118
 
119
  def test_returns_422_on_missing_case_id(self, client: TestClient) -> None:
120
  """POST /api/segment returns 422 when case_id is missing."""
tests/core/test_config.py CHANGED
@@ -25,7 +25,8 @@ class TestSettings:
25
  assert settings.log_level == "INFO"
26
  assert settings.hf_dataset_id == "hugging-science/isles24-stroke"
27
  assert settings.deepisles_timeout_seconds == 1800
28
- assert settings.results_dir == Path("./results")
 
29
 
30
  def test_env_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
31
  """Environment variables override defaults."""
 
25
  assert settings.log_level == "INFO"
26
  assert settings.hf_dataset_id == "hugging-science/isles24-stroke"
27
  assert settings.deepisles_timeout_seconds == 1800
28
+ # Default is /tmp/stroke-results for HF Spaces compatibility (only /tmp is writable)
29
+ assert settings.results_dir == Path("/tmp/stroke-results")
30
 
31
  def test_env_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
32
  """Environment variables override defaults."""
tests/data/test_hf_adapter.py CHANGED
@@ -1,295 +1,151 @@
1
- """Unit tests for HuggingFace dataset adapter with mocked HF data access."""
2
 
3
  from __future__ import annotations
4
 
5
- from unittest.mock import MagicMock, patch
 
6
 
7
  import pytest
8
 
9
- from stroke_deepisles_demo.core.exceptions import DataLoadError
10
- from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, build_huggingface_dataset
11
 
12
 
13
- class TestHuggingFaceDataset:
14
- """Tests for HuggingFaceDataset class."""
15
 
16
- def test_get_case_writes_files_to_temp_dir(self) -> None:
17
- """Test that get_case writes NIfTI bytes to temp files."""
18
- case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
19
- case_index = {cid: idx for idx, cid in enumerate(case_ids)}
20
 
21
- ds = HuggingFaceDataset(
22
- dataset_id="test/dataset",
23
- _case_ids=case_ids,
24
- _case_index=case_index,
25
- )
26
-
27
- # Mock the download method
28
- mock_data = {
29
- "dwi_bytes": b"fake_dwi_nifti_data",
30
- "adc_bytes": b"fake_adc_nifti_data",
31
- "mask_bytes": b"fake_mask_nifti_data",
32
- }
33
 
34
- try:
35
- with patch.object(ds, "_download_case_from_parquet", return_value=mock_data):
36
- case = ds.get_case(0)
37
-
38
- assert "dwi" in case
39
- assert "adc" in case
40
- assert case["dwi"].exists()
41
- assert case["adc"].exists()
42
- assert case["dwi"].read_bytes() == b"fake_dwi_nifti_data"
43
- assert case["adc"].read_bytes() == b"fake_adc_nifti_data"
44
- finally:
45
- ds.cleanup()
46
-
47
- def test_get_case_includes_ground_truth_when_available(self) -> None:
48
- """Test that ground truth is included when lesion_mask is present."""
49
- case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
50
- case_index = {cid: idx for idx, cid in enumerate(case_ids)}
51
-
52
- ds = HuggingFaceDataset(
53
- dataset_id="test/dataset",
54
- _case_ids=case_ids,
55
- _case_index=case_index,
56
  )
57
 
58
- try:
59
- # Case with mask
60
- mock_data_with_mask = {
61
- "dwi_bytes": b"fake_dwi_nifti_data",
62
- "adc_bytes": b"fake_adc_nifti_data",
63
- "mask_bytes": b"fake_mask_nifti_data",
64
- }
65
- with patch.object(ds, "_download_case_from_parquet", return_value=mock_data_with_mask):
66
- case = ds.get_case(0)
67
- assert "ground_truth" in case
68
- assert case["ground_truth"].read_bytes() == b"fake_mask_nifti_data"
69
-
70
- # Case without mask
71
- mock_data_no_mask = {
72
- "dwi_bytes": b"fake_dwi_nifti_data",
73
- "adc_bytes": b"fake_adc_nifti_data",
74
- }
75
- with patch.object(ds, "_download_case_from_parquet", return_value=mock_data_no_mask):
76
- case_no_mask = ds.get_case(2)
77
- assert "ground_truth" not in case_no_mask
78
- finally:
79
- ds.cleanup()
80
-
81
- def test_get_case_caches_results(self) -> None:
82
- """Test that get_case returns cached paths on subsequent calls."""
83
- case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
84
- case_index = {cid: idx for idx, cid in enumerate(case_ids)}
85
-
86
- ds = HuggingFaceDataset(
87
- dataset_id="test/dataset",
88
- _case_ids=case_ids,
89
- _case_index=case_index,
90
- )
91
-
92
- mock_data = {
93
- "dwi_bytes": b"fake_dwi_nifti_data",
94
- "adc_bytes": b"fake_adc_nifti_data",
95
  }
96
 
97
- try:
98
- with patch.object(
99
- ds, "_download_case_from_parquet", return_value=mock_data
100
- ) as mock_download:
101
- case1 = ds.get_case(0)
102
- case2 = ds.get_case(0)
103
-
104
- # Same object returned (cached)
105
- assert case1 is case2
106
-
107
- # Download was only called once
108
- assert mock_download.call_count == 1
109
- finally:
110
- ds.cleanup()
111
-
112
- def test_context_manager_cleans_up_temp_files(self) -> None:
113
- """Test that using context manager cleans up temp files."""
114
- case_ids = ["sub-stroke0001"]
115
- case_index = {"sub-stroke0001": 0}
116
-
117
- ds = HuggingFaceDataset(
118
- dataset_id="test/dataset",
119
- _case_ids=case_ids,
120
- _case_index=case_index,
121
  )
122
 
123
- mock_data = {
124
- "dwi_bytes": b"fake_dwi_nifti_data",
125
- "adc_bytes": b"fake_adc_nifti_data",
126
- }
127
 
128
- with patch.object(ds, "_download_case_from_parquet", return_value=mock_data), ds:
129
- case = ds.get_case(0)
130
- temp_dir = case["dwi"].parent.parent
131
- assert temp_dir.exists()
132
 
133
- # After context exit, temp dir should be gone
134
- assert not temp_dir.exists()
 
 
135
 
136
- def test_cleanup_clears_cache(self) -> None:
137
- """Test that cleanup clears the case cache."""
138
- case_ids = ["sub-stroke0001"]
139
- case_index = {"sub-stroke0001": 0}
140
 
141
- ds = HuggingFaceDataset(
142
- dataset_id="test/dataset",
143
- _case_ids=case_ids,
144
- _case_index=case_index,
145
- )
146
 
147
- mock_data = {
148
- "dwi_bytes": b"fake_dwi_nifti_data",
149
- "adc_bytes": b"fake_adc_nifti_data",
 
 
 
 
150
  }
151
 
152
- with patch.object(ds, "_download_case_from_parquet", return_value=mock_data):
153
- ds.get_case(0)
154
- assert len(ds._cached_cases) == 1
155
-
156
- ds.cleanup()
157
- assert len(ds._cached_cases) == 0
158
-
159
- def test_get_case_by_string_id(self) -> None:
160
- """Test that get_case works with string case IDs."""
161
- case_ids = ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
162
- case_index = {cid: idx for idx, cid in enumerate(case_ids)}
163
-
164
- ds = HuggingFaceDataset(
165
- dataset_id="test/dataset",
166
- _case_ids=case_ids,
167
- _case_index=case_index,
168
  )
169
 
170
- mock_data = {
171
- "dwi_bytes": b"fake_dwi_nifti_data",
172
- "adc_bytes": b"fake_adc_nifti_data",
173
- }
174
-
175
- try:
176
- with patch.object(
177
- ds, "_download_case_from_parquet", return_value=mock_data
178
- ) as mock_download:
179
- case = ds.get_case("sub-stroke0002")
180
- assert case["dwi"].exists()
181
- # Should have been called with index 1 (second case)
182
- mock_download.assert_called_once_with(1, "sub-stroke0002")
183
- finally:
184
- ds.cleanup()
185
-
186
- def test_get_case_raises_key_error_for_invalid_id(self) -> None:
187
- """Test that get_case raises KeyError for invalid case ID."""
188
- case_ids = ["sub-stroke0001"]
189
- case_index = {"sub-stroke0001": 0}
190
-
191
- ds = HuggingFaceDataset(
192
- dataset_id="test/dataset",
193
- _case_ids=case_ids,
194
- _case_index=case_index,
195
- )
196
 
197
- with pytest.raises(KeyError, match="not found in dataset"):
198
- ds.get_case("sub-stroke9999")
199
 
200
- def test_get_case_raises_index_error_for_out_of_range(self) -> None:
201
- """Test that get_case raises IndexError for out of range index."""
202
- case_ids = ["sub-stroke0001"]
203
- case_index = {"sub-stroke0001": 0}
204
 
205
- ds = HuggingFaceDataset(
206
- dataset_id="test/dataset",
207
- _case_ids=case_ids,
208
- _case_index=case_index,
 
 
 
 
 
 
209
  )
210
 
211
- with pytest.raises(IndexError, match="out of range"):
212
- ds.get_case(99)
213
-
214
-
215
- class TestBuildHuggingFaceDataset:
216
- """Tests for build_huggingface_dataset function."""
217
-
218
- def test_uses_precomputed_case_ids(self) -> None:
219
- """Test that build_huggingface_dataset uses pre-computed case IDs."""
220
- result = build_huggingface_dataset("hugging-science/isles24-stroke")
221
-
222
- assert isinstance(result, HuggingFaceDataset)
223
- assert result.dataset_id == "hugging-science/isles24-stroke"
224
- # Should have 149 cases from pre-computed list
225
- assert len(result._case_ids) == 149
226
- assert "sub-stroke0001" in result._case_ids
227
- assert "sub-stroke0189" in result._case_ids
228
-
229
- def test_case_index_mapping_is_correct(self) -> None:
230
- """Test that case index mapping matches case IDs order."""
231
- result = build_huggingface_dataset("hugging-science/isles24-stroke")
232
-
233
- # First case should map to index 0
234
- assert result._case_index["sub-stroke0001"] == 0
235
- # Last case should map to index 148
236
- assert result._case_index["sub-stroke0189"] == 148
237
 
238
- def test_warns_for_different_dataset_id(self) -> None:
239
- """Test that a warning is logged for non-standard dataset IDs."""
240
- from stroke_deepisles_demo.data.adapter import logger
241
 
242
- with patch.object(logger, "warning") as mock_warning:
243
- build_huggingface_dataset("some-other/dataset")
244
- mock_warning.assert_called_once()
245
- assert "does not match pre-computed constants" in mock_warning.call_args[0][0]
246
 
 
 
247
 
248
- class TestDownloadCaseFromParquet:
249
- """Tests for _download_case_from_parquet method."""
250
-
251
- def test_raises_data_load_error_on_malformed_data(self) -> None:
252
- """Test that _download_case_from_parquet raises DataLoadError for malformed data."""
253
- import pandas as pd # type: ignore[import-untyped]
254
-
255
- case_ids = ["sub-stroke0001"]
256
- case_index = {"sub-stroke0001": 0}
257
-
258
- ds = HuggingFaceDataset(
259
- dataset_id="test/dataset",
260
- _case_ids=case_ids,
261
- _case_index=case_index,
262
- )
263
 
264
- # Create mock with missing 'bytes' key
265
- mock_df = pd.DataFrame(
266
- [
267
- {
268
- "subject_id": "sub-stroke0001",
269
- "dwi": {}, # Missing 'bytes'
270
- "adc": {},
271
- "lesion_mask": None,
272
- }
273
- ]
274
- )
275
 
276
- mock_table = MagicMock()
277
- mock_table.to_pandas.return_value = mock_df
278
 
279
- mock_pf = MagicMock()
280
- mock_pf.read.return_value = mock_table
 
 
 
 
 
281
 
282
- mock_file = MagicMock()
283
- mock_file.__enter__ = MagicMock(return_value=mock_file)
284
- mock_file.__exit__ = MagicMock(return_value=False)
285
 
286
- mock_fs = MagicMock()
287
- mock_fs.open.return_value = mock_file
288
 
289
- # Patch at the source module where they're imported, not where they're used
290
- with (
291
- patch("huggingface_hub.HfFileSystem", return_value=mock_fs),
292
- patch("pyarrow.parquet.ParquetFile", return_value=mock_pf),
293
- pytest.raises(DataLoadError, match="Malformed HuggingFace data"),
294
- ):
295
- ds._download_case_from_parquet(0, "sub-stroke0001")
 
1
+ """Unit tests for HuggingFace dataset wrapper."""
2
 
3
  from __future__ import annotations
4
 
5
+ from typing import Any
6
+ from unittest.mock import MagicMock
7
 
8
  import pytest
9
 
10
+ from stroke_deepisles_demo.data.loader import HuggingFaceDatasetWrapper
 
11
 
12
 
13
+ class TestHuggingFaceDatasetWrapper:
14
+ """Tests for HuggingFaceDatasetWrapper class."""
15
 
16
+ @pytest.fixture
17
+ def mock_hf_dataset(self) -> MagicMock:
18
+ """Create a mock HuggingFace dataset."""
19
+ dataset = MagicMock()
20
 
21
+ # Mock dataset length
22
+ dataset.__len__.return_value = 3
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # Mock column access for fast index building
25
+ # This simulates dataset["subject_id"]
26
+ dataset.__getitem__.side_effect = lambda key: (
27
+ ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
28
+ if key == "subject_id"
29
+ else MagicMock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
 
32
+ return dataset
33
+
34
+ def test_init_builds_index_correctly(self, mock_hf_dataset: MagicMock) -> None:
35
+ """Test that initialization builds the subject ID index."""
36
+ wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
37
+
38
+ assert len(wrapper) == 3
39
+ assert wrapper.list_case_ids() == ["sub-stroke0001", "sub-stroke0002", "sub-stroke0003"]
40
+ assert wrapper._case_id_to_index["sub-stroke0001"] == 0
41
+ assert wrapper._case_id_to_index["sub-stroke0003"] == 2
42
+
43
+ def test_get_case_materializes_files(self, mock_hf_dataset: MagicMock) -> None:
44
+ """Test that get_case materializes NIfTI objects to files."""
45
+ # Setup row return for get_case
46
+ mock_dwi = MagicMock()
47
+ mock_adc = MagicMock()
48
+ mock_mask = MagicMock()
49
+
50
+ row_data = {
51
+ "subject_id": "sub-stroke0001",
52
+ "dwi": mock_dwi,
53
+ "adc": mock_adc,
54
+ "lesion_mask": mock_mask,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  }
56
 
57
+ # Reset side_effect to return row for integer index
58
+ mock_hf_dataset.__getitem__.side_effect = (
59
+ lambda idx: row_data if isinstance(idx, int) else ["sub-stroke0001"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
 
62
+ wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
 
 
 
63
 
64
+ with wrapper:
65
+ case = wrapper.get_case("sub-stroke0001")
 
 
66
 
67
+ # Verify file paths
68
+ assert case["dwi"].name == "sub-stroke0001_dwi.nii.gz"
69
+ assert case["adc"].name == "sub-stroke0001_adc.nii.gz"
70
+ assert case["ground_truth"].name == "sub-stroke0001_lesion-msk.nii.gz"
71
 
72
+ # Verify to_filename called
73
+ mock_dwi.to_filename.assert_called_once()
74
+ mock_adc.to_filename.assert_called_once()
75
+ mock_mask.to_filename.assert_called_once()
76
 
77
+ # Verify temporary directory usage
78
+ assert wrapper._temp_dir is not None
79
+ assert case["dwi"].parent == wrapper._temp_dir / "sub-stroke0001"
 
 
80
 
81
+ def test_get_case_handles_missing_mask(self, mock_hf_dataset: MagicMock) -> None:
82
+ """Test that get_case handles cases without lesion mask."""
83
+ row_data = {
84
+ "subject_id": "sub-stroke0002",
85
+ "dwi": MagicMock(),
86
+ "adc": MagicMock(),
87
+ "lesion_mask": None,
88
  }
89
 
90
+ mock_hf_dataset.__getitem__.side_effect = (
91
+ lambda idx: row_data if isinstance(idx, int) else ["sub-stroke0002"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
 
94
+ wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ with wrapper:
97
+ case = wrapper.get_case("sub-stroke0002")
98
 
99
+ assert "dwi" in case
100
+ assert "adc" in case
101
+ assert "ground_truth" not in case
 
102
 
103
+ def test_cleanup_removes_temp_dir(self, mock_hf_dataset: MagicMock) -> None:
104
+ """Test that cleanup removes the temporary directory."""
105
+ row_data = {
106
+ "subject_id": "sub-stroke0001",
107
+ "dwi": MagicMock(),
108
+ "adc": MagicMock(),
109
+ "lesion_mask": None,
110
+ }
111
+ mock_hf_dataset.__getitem__.side_effect = (
112
+ lambda idx: row_data if isinstance(idx, int) else ["sub-stroke0001"]
113
  )
114
 
115
+ wrapper = HuggingFaceDatasetWrapper(mock_hf_dataset, "test/dataset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ # Create temp dir by accessing a case
118
+ wrapper.get_case(0)
119
+ temp_dir = wrapper._temp_dir
120
 
121
+ assert temp_dir is not None
122
+ assert temp_dir.exists()
 
 
123
 
124
+ # cleanup
125
+ wrapper.cleanup()
126
 
127
+ assert not temp_dir.exists()
128
+ assert wrapper._temp_dir is None
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ def test_fallback_iteration(self) -> None:
131
+ """Test fallback to iteration if column access fails."""
132
+ dataset = MagicMock()
133
+ dataset.__len__.return_value = 2
 
 
 
 
 
 
 
134
 
135
+ # Configure iteration for fallback
136
+ dataset.__iter__.return_value = iter([{"subject_id": "sub-0"}, {"subject_id": "sub-1"}])
137
 
138
+ # Fail column access
139
+ def getitem(key: Any) -> Any:
140
+ if key == "subject_id":
141
+ raise ValueError("No column access")
142
+ if isinstance(key, int):
143
+ return {"subject_id": f"sub-{key}"}
144
+ return MagicMock()
145
 
146
+ dataset.__getitem__.side_effect = getitem
 
 
147
 
148
+ wrapper = HuggingFaceDatasetWrapper(dataset, "test/dataset")
 
149
 
150
+ assert wrapper._case_id_to_index["sub-0"] == 0
151
+ assert wrapper._case_id_to_index["sub-1"] == 1
 
 
 
 
 
tests/data/test_loader.py CHANGED
@@ -4,12 +4,12 @@ from __future__ import annotations
4
 
5
  import os
6
  from typing import TYPE_CHECKING
7
- from unittest.mock import patch
8
 
9
  import pytest
10
 
11
- from stroke_deepisles_demo.data.adapter import HuggingFaceDataset, LocalDataset, logger
12
- from stroke_deepisles_demo.data.loader import load_isles_dataset
13
 
14
  if TYPE_CHECKING:
15
  from pathlib import Path
@@ -35,31 +35,31 @@ def test_load_from_local_finds_all_cases(synthetic_isles_dir: Path) -> None:
35
  assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
36
 
37
 
38
- def test_load_hf_warns_on_non_standard_dataset() -> None:
39
- """Test that loading a non-standard HF dataset logs a warning.
 
 
 
 
 
 
40
 
41
- Note: With pre-computed case IDs, the dataset ID mismatch is only detected
42
- at build time (warning logged), not at get_case() time. The actual 404 error
43
- would only occur when trying to download a case that doesn't exist.
44
- """
45
- with patch.object(logger, "warning") as mock_warning:
46
- ds = load_isles_dataset(source="fake/nonexistent-dataset", local_mode=False)
47
- mock_warning.assert_called_once()
48
- assert "does not match pre-computed constants" in mock_warning.call_args[0][0]
49
- # Dataset is still created with pre-computed case IDs
50
- assert isinstance(ds, HuggingFaceDataset)
51
- assert len(ds) == 149 # Uses pre-computed list
52
 
53
 
54
  @pytest.mark.integration
55
  @SKIP_IN_CI
56
  def test_load_from_huggingface_returns_hf_dataset() -> None:
57
- """Test that loading from HuggingFace returns a HuggingFaceDataset.
58
 
59
  Note: Skipped in CI due to large download size (~GB) and limited disk space.
60
  Run locally with: pytest -m integration tests/data/test_loader.py
61
  """
62
  with load_isles_dataset() as dataset: # Default is HuggingFace mode
63
- assert isinstance(dataset, HuggingFaceDataset)
64
- assert len(dataset) == 149
65
- assert dataset.list_case_ids()[0] == "sub-stroke0001"
 
4
 
5
  import os
6
  from typing import TYPE_CHECKING
7
+ from unittest.mock import MagicMock, patch
8
 
9
  import pytest
10
 
11
+ from stroke_deepisles_demo.data.adapter import LocalDataset
12
+ from stroke_deepisles_demo.data.loader import HuggingFaceDatasetWrapper, load_isles_dataset
13
 
14
  if TYPE_CHECKING:
15
  from pathlib import Path
 
35
  assert dataset.list_case_ids() == ["sub-stroke0001", "sub-stroke0002"]
36
 
37
 
38
+ def test_load_hf_calls_load_dataset() -> None:
39
+ """Test that loading from HF calls datasets.load_dataset."""
40
+ with patch("datasets.load_dataset") as mock_load:
41
+ mock_ds = MagicMock()
42
+ mock_ds.__len__.return_value = 0
43
+ # Mock column access for index building
44
+ mock_ds.__getitem__.side_effect = lambda key: [] if key == "subject_id" else MagicMock()
45
+ mock_load.return_value = mock_ds
46
 
47
+ ds = load_isles_dataset(source="my/dataset", local_mode=False)
48
+
49
+ assert isinstance(ds, HuggingFaceDatasetWrapper)
50
+ mock_load.assert_called_once()
51
+ assert mock_load.call_args[0][0] == "my/dataset"
 
 
 
 
 
 
52
 
53
 
54
  @pytest.mark.integration
55
  @SKIP_IN_CI
56
  def test_load_from_huggingface_returns_hf_dataset() -> None:
57
+ """Test that loading from HuggingFace returns a HuggingFaceDatasetWrapper.
58
 
59
  Note: Skipped in CI due to large download size (~GB) and limited disk space.
60
  Run locally with: pytest -m integration tests/data/test_loader.py
61
  """
62
  with load_isles_dataset() as dataset: # Default is HuggingFace mode
63
+ assert isinstance(dataset, HuggingFaceDatasetWrapper)
64
+ # We can't guarantee length if we don't mock, but we can check type
65
+ # Real test might fail if network issue or auth issue