File size: 11,910 Bytes
aef1f5a 3c4c67b 785d976 262b3cb aef1f5a 363ba14 3c4c67b 262b3cb 7e5ddec 262b3cb 785d976 aef1f5a 262b3cb 363ba14 3c4c67b aef1f5a 3c4c67b aef1f5a 3c4c67b 262b3cb 785d976 262b3cb 7e5ddec 3c4c67b 363ba14 3c4c67b 363ba14 fa1717e 363ba14 3c4c67b 363ba14 3c4c67b 363ba14 fa1717e 363ba14 fa1717e 3c4c67b 363ba14 3c4c67b 363ba14 fa1717e 3c4c67b 363ba14 aef1f5a 3c4c67b 363ba14 aef1f5a 3c4c67b 363ba14 262b3cb fa1717e 262b3cb 7e5ddec 262b3cb fa1717e 262b3cb 363ba14 262b3cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 |
"""Load ISLES24 data from local directory or HuggingFace Hub."""
from __future__ import annotations
import re
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Protocol, Self
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.core.types import CaseFiles # noqa: TC001
from stroke_deepisles_demo.data.isles24_manifest import (
ISLES24_DATASET_ID,
ISLES24_DATASET_REVISION,
ISLES24_TRAIN_CASE_IDS,
isles24_train_data_file,
)
# Security: Regex for valid ISLES24 subject IDs (defense-in-depth)
# Expected format: sub-strokeXXXX (e.g., sub-stroke0001)
_SAFE_SUBJECT_ID_PATTERN = re.compile(r"^sub-stroke\d{4}$")
if TYPE_CHECKING:
    # Only needed for annotations; `datasets` itself is imported lazily at runtime
    # inside the functions that actually use it.
    from datasets import Dataset as HFDataset
logger = get_logger(__name__)
class Dataset(Protocol):
    """Protocol for dataset access.

    All dataset implementations support context manager usage for proper cleanup:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # cleanup happens automatically
    """

    def __len__(self) -> int:
        """Return the number of cases available."""
        ...

    def __enter__(self) -> Self:
        """Enter the context manager; returns the dataset itself."""
        ...

    def __exit__(self, *args: object) -> None:
        """Exit the context manager; implementations call cleanup() here."""
        ...

    def list_case_ids(self) -> list[str]:
        """Return all case identifiers (e.g. "sub-stroke0001")."""
        ...

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Resolve a case (by string ID or 0-based index) to on-disk file paths."""
        ...

    def cleanup(self) -> None:
        """Release any temporary resources held by the dataset."""
        ...
@dataclass
class DatasetInfo:
    """Metadata about the dataset."""

    source: str  # "local" or HF dataset ID
    num_cases: int  # total number of cases available
    modalities: list[str]  # imaging modalities present (e.g. "dwi", "adc")
    has_ground_truth: bool  # whether lesion masks are included
@dataclass
class HuggingFaceDatasetWrapper:
"""Wrapper for HuggingFace dataset to match the Dataset protocol.
Uses the standard datasets library (with neuroimaging-go-brrrr patched Nifti feature)
to load data. Materializes NIfTI images to temporary files on demand.
"""
dataset: HFDataset
dataset_id: str
_temp_dir: Path | None = field(default=None, repr=False)
_case_id_to_index: dict[str, int] = field(default_factory=dict, repr=False)
def __post_init__(self) -> None:
"""Build index of subject IDs for O(1) lookup."""
try:
# Efficiently build index from 'subject_id' column
self._case_id_to_index = {
sid: idx for idx, sid in enumerate(self.dataset["subject_id"])
}
except (KeyError, TypeError, ValueError) as e:
logger.warning(
"Failed to build index from subject_id column: %s. Fallback to iteration.", e
)
for idx, item in enumerate(self.dataset):
self._case_id_to_index[item["subject_id"]] = idx
def __len__(self) -> int:
return len(self.dataset)
def __enter__(self) -> Self:
return self
def __exit__(self, *args: object) -> None:
self.cleanup()
def list_case_ids(self) -> list[str]:
return sorted(self._case_id_to_index.keys())
def get_case(self, case_id: str | int) -> CaseFiles:
"""Get files for a case by ID or index.
Materializes NIfTI objects to temporary files.
"""
# Resolve case_id to index
if isinstance(case_id, int):
if case_id < 0 or case_id >= len(self.dataset):
raise IndexError(f"Case index {case_id} out of range")
idx = case_id
else:
if case_id not in self._case_id_to_index:
raise KeyError(f"Case ID {case_id} not found")
idx = self._case_id_to_index[case_id]
row = self.dataset[idx]
subject_id = row["subject_id"]
# Security: Validate subject_id before using in path (defense-in-depth)
if not _SAFE_SUBJECT_ID_PATTERN.match(subject_id):
raise ValueError(
f"Invalid subject_id format: {subject_id!r}. Expected format: sub-strokeXXXX"
)
# Prepare temp dir
if self._temp_dir is None:
self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))
case_dir = self._temp_dir / subject_id
case_dir.mkdir(exist_ok=True)
dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
adc_path = case_dir / f"{subject_id}_adc.nii.gz"
# Materialize files if they don't exist
if not dwi_path.exists():
row["dwi"].to_filename(str(dwi_path))
if not adc_path.exists():
row["adc"].to_filename(str(adc_path))
case_files: CaseFiles = {
"dwi": dwi_path,
"adc": adc_path,
}
# Handle lesion mask (mapped to ground_truth)
if "lesion_mask" in row and row["lesion_mask"] is not None:
mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
if not mask_path.exists():
row["lesion_mask"].to_filename(str(mask_path))
case_files["ground_truth"] = mask_path
return case_files
def cleanup(self) -> None:
if self._temp_dir and self._temp_dir.exists():
try:
shutil.rmtree(self._temp_dir)
except OSError as e:
logger.warning("Failed to cleanup temp directory %s: %s", self._temp_dir, e)
self._temp_dir = None
@dataclass
class Isles24HuggingFaceDataset:
    """ISLES24 dataset access optimized for HF Spaces.

    Key behavior:
    - `list_case_ids()` returns from a pinned manifest (no dataset download).
    - `get_case()` loads exactly one Parquet shard via `data_files=...` (no 27GB eager download).

    This class exists because `datasets.load_dataset(dataset_id, split="train")` can
    trigger an eager full-dataset download/prepare on cold starts, which is not viable
    for API endpoints like `/api/cases` on Hugging Face Spaces.
    """

    dataset_id: str = ISLES24_DATASET_ID  # HF Hub dataset ID
    token: str | None = None  # HF token for private/gated access
    revision: str = ISLES24_DATASET_REVISION  # pinned revision matching the manifest
    _temp_dir: Path | None = field(default=None, repr=False)  # lazy temp root for materialized files

    def __len__(self) -> int:
        return len(ISLES24_TRAIN_CASE_IDS)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        """Return the pinned manifest's case IDs without touching the Hub."""
        return list(ISLES24_TRAIN_CASE_IDS)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Load files for a single ISLES24 case.

        Downloads only the single Parquet shard for the case, materializes its
        NIfTI images into a temp directory, and returns the file paths.

        Args:
            case_id: Case identifier (e.g., "sub-stroke0102") or 0-based integer index.

        Returns:
            Mapping with "dwi" and "adc" paths, plus "ground_truth" when a
            lesion mask is present for the case.

        Raises:
            IndexError: If an integer index is out of range.
            ValueError: If the case ID fails the safety pattern.
            RuntimeError: If the shard content does not match the manifest.
        """
        from datasets import load_dataset

        if isinstance(case_id, int):
            if case_id < 0 or case_id >= len(ISLES24_TRAIN_CASE_IDS):
                raise IndexError(f"Case index {case_id} out of range")
            resolved_case_id = ISLES24_TRAIN_CASE_IDS[case_id]
        else:
            resolved_case_id = case_id
        # Security: Validate subject_id before using in path (defense-in-depth)
        if not _SAFE_SUBJECT_ID_PATTERN.match(resolved_case_id):
            raise ValueError(
                f"Invalid subject_id format: {resolved_case_id!r}. Expected format: sub-strokeXXXX"
            )
        # Load exactly one shard (1 case per parquet file in this dataset)
        data_file = isles24_train_data_file(resolved_case_id)
        ds = load_dataset(
            self.dataset_id,
            data_files={"train": data_file},
            split="train",
            token=self.token,
            revision=self.revision,
        )
        ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
        # Sanity-check the shard really contains exactly the requested case.
        if len(ds) != 1:
            raise RuntimeError(f"Expected 1 row for {resolved_case_id}, got {len(ds)}")
        row = ds[0]
        subject_id = row["subject_id"]
        if subject_id != resolved_case_id:
            raise RuntimeError(
                f"Unexpected subject_id {subject_id!r} in {data_file} (expected {resolved_case_id!r})"
            )
        # Prepare temp dir lazily. parents=True guards against the temp root
        # having been removed externally between calls.
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))
        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(parents=True, exist_ok=True)
        dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_adc.nii.gz"
        if not dwi_path.exists():
            row["dwi"].to_filename(str(dwi_path))
        if not adc_path.exists():
            row["adc"].to_filename(str(adc_path))
        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }
        if row.get("lesion_mask") is not None:
            mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
            if not mask_path.exists():
                row["lesion_mask"].to_filename(str(mask_path))
            case_files["ground_truth"] = mask_path
        return case_files

    def cleanup(self) -> None:
        """Remove the temporary materialization directory. Idempotent."""
        if self._temp_dir is not None:
            if self._temp_dir.exists():
                try:
                    shutil.rmtree(self._temp_dir)
                except OSError as e:
                    logger.warning("Failed to cleanup temp directory %s: %s", self._temp_dir, e)
            # Always drop the reference so a stale or deleted path is never
            # reused, even when the directory is already gone or rmtree failed.
            self._temp_dir = None
def load_isles_dataset(
    source: str | Path | None = None,
    *,
    local_mode: bool | None = None,
    token: str | None = None,
) -> Dataset:
    """Load ISLES24 dataset from local directory or HuggingFace Hub.

    Args:
        source: Local directory path or HuggingFace dataset ID. When None,
            falls back to Settings.hf_dataset_id from config.
        local_mode: Force local-directory mode when True; when None, the mode
            is inferred from the type and existence of `source`.
        token: HuggingFace token for private/gated datasets. When None,
            Settings.hf_token from config is used.

    Returns:
        Dataset-like object providing case access. Use as context manager
        for automatic cleanup of temp files (important for HuggingFace mode).

    Examples:
        # Load from HuggingFace with automatic cleanup (recommended)
        with load_isles_dataset() as ds:
            case = ds.get_case(0)

        # Load from local directory
        ds = load_isles_dataset("data/isles24", local_mode=True)

        # Load specific HuggingFace dataset with token
        ds = load_isles_dataset("org/private-dataset", token="hf_xxx")
    """
    if local_mode is None:
        # Auto-detect: a Path means local; a string is local only when it names
        # an existing path (so HF IDs like "org/dataset" are not misclassified);
        # no source at all defaults to HuggingFace.
        if isinstance(source, Path):
            local_mode = True
        elif source is None:
            local_mode = False
        else:
            local_mode = Path(source).exists()

    if local_mode:
        from stroke_deepisles_demo.data.adapter import build_local_dataset

        local_root = Path(source) if source is not None else Path("data/isles24")
        return build_local_dataset(local_root)

    # HuggingFace mode
    from datasets import load_dataset

    from stroke_deepisles_demo.core.config import get_settings

    settings = get_settings()
    # Fall back to configured defaults when arguments were omitted.
    dataset_id = str(source) if source else settings.hf_dataset_id
    hf_token = settings.hf_token if token is None else token

    # The pinned ISLES24 dataset gets the lazy per-shard loader, which avoids
    # an eager full-dataset download on cold starts.
    if dataset_id == ISLES24_DATASET_ID:
        return Isles24HuggingFaceDataset(dataset_id=dataset_id, token=hf_token)

    # Generic HF dataset: load eagerly, restricting to the columns we use so
    # other modalities are never decoded. Token enables private/gated access.
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
    return HuggingFaceDatasetWrapper(ds, dataset_id)
|