"""Load ISLES24 data from local directory or HuggingFace Hub.""" from __future__ import annotations import re import shutil import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Protocol, Self from stroke_deepisles_demo.core.logging import get_logger from stroke_deepisles_demo.core.types import CaseFiles # noqa: TC001 from stroke_deepisles_demo.data.isles24_manifest import ( ISLES24_DATASET_ID, ISLES24_DATASET_REVISION, ISLES24_TRAIN_CASE_IDS, isles24_train_data_file, ) # Security: Regex for valid ISLES24 subject IDs (defense-in-depth) # Expected format: sub-strokeXXXX (e.g., sub-stroke0001) _SAFE_SUBJECT_ID_PATTERN = re.compile(r"^sub-stroke\d{4}$") if TYPE_CHECKING: from datasets import Dataset as HFDataset logger = get_logger(__name__) class Dataset(Protocol): """Protocol for dataset access. All dataset implementations support context manager usage for proper cleanup: with load_isles_dataset() as ds: case = ds.get_case(0) # ... process case ... # cleanup happens automatically """ def __len__(self) -> int: ... def __enter__(self) -> Self: ... def __exit__(self, *args: object) -> None: ... def list_case_ids(self) -> list[str]: ... def get_case(self, case_id: str | int) -> CaseFiles: ... def cleanup(self) -> None: ... @dataclass class DatasetInfo: """Metadata about the dataset.""" source: str # "local" or HF dataset ID num_cases: int modalities: list[str] has_ground_truth: bool @dataclass class HuggingFaceDatasetWrapper: """Wrapper for HuggingFace dataset to match the Dataset protocol. Uses the standard datasets library (with neuroimaging-go-brrrr patched Nifti feature) to load data. Materializes NIfTI images to temporary files on demand. """ dataset: HFDataset dataset_id: str _temp_dir: Path | None = field(default=None, repr=False) _case_id_to_index: dict[str, int] = field(default_factory=dict, repr=False) def __post_init__(self) -> None: """Build index of subject IDs for O(1) lookup.""" try: # Efficiently build index from 'subject_id' column self._case_id_to_index = { sid: idx for idx, sid in enumerate(self.dataset["subject_id"]) } except (KeyError, TypeError, ValueError) as e: logger.warning( "Failed to build index from subject_id column: %s. Fallback to iteration.", e ) for idx, item in enumerate(self.dataset): self._case_id_to_index[item["subject_id"]] = idx def __len__(self) -> int: return len(self.dataset) def __enter__(self) -> Self: return self def __exit__(self, *args: object) -> None: self.cleanup() def list_case_ids(self) -> list[str]: return sorted(self._case_id_to_index.keys()) def get_case(self, case_id: str | int) -> CaseFiles: """Get files for a case by ID or index. Materializes NIfTI objects to temporary files. """ # Resolve case_id to index if isinstance(case_id, int): if case_id < 0 or case_id >= len(self.dataset): raise IndexError(f"Case index {case_id} out of range") idx = case_id else: if case_id not in self._case_id_to_index: raise KeyError(f"Case ID {case_id} not found") idx = self._case_id_to_index[case_id] row = self.dataset[idx] subject_id = row["subject_id"] # Security: Validate subject_id before using in path (defense-in-depth) if not _SAFE_SUBJECT_ID_PATTERN.match(subject_id): raise ValueError( f"Invalid subject_id format: {subject_id!r}. 
        # Prepare temp dir
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))

        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_adc.nii.gz"

        # Materialize files if they don't exist
        if not dwi_path.exists():
            row["dwi"].to_filename(str(dwi_path))
        if not adc_path.exists():
            row["adc"].to_filename(str(adc_path))

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        # Handle lesion mask (mapped to ground_truth)
        if "lesion_mask" in row and row["lesion_mask"] is not None:
            mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
            if not mask_path.exists():
                row["lesion_mask"].to_filename(str(mask_path))
            case_files["ground_truth"] = mask_path

        return case_files

    def cleanup(self) -> None:
        if self._temp_dir and self._temp_dir.exists():
            try:
                shutil.rmtree(self._temp_dir)
            except OSError as e:
                logger.warning("Failed to clean up temp directory %s: %s", self._temp_dir, e)
            self._temp_dir = None

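
# Example construction of HuggingFaceDatasetWrapper, as a sketch: the dataset
# ID below is hypothetical, and the column selection mirrors what
# load_isles_dataset() does at the bottom of this module.
#
#     from datasets import load_dataset
#
#     hf_ds = load_dataset("org/isles24-mirror", split="train")
#     hf_ds = hf_ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
#     with HuggingFaceDatasetWrapper(hf_ds, "org/isles24-mirror") as ds:
#         files = ds.get_case("sub-stroke0001")
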
@dataclass
class Isles24HuggingFaceDataset:
    """ISLES24 dataset access optimized for HF Spaces.

    Key behavior:
    - `list_case_ids()` returns from a pinned manifest (no dataset download).
    - `get_case()` loads exactly one Parquet shard via `data_files=...`
      (no eager 27 GB download).

    This class exists because `datasets.load_dataset(dataset_id, split="train")`
    can trigger an eager full-dataset download/prepare on cold starts, which is
    not viable for API endpoints like `/api/cases` on Hugging Face Spaces.
    """

    dataset_id: str = ISLES24_DATASET_ID
    token: str | None = None
    revision: str = ISLES24_DATASET_REVISION
    _temp_dir: Path | None = field(default=None, repr=False)

    def __len__(self) -> int:
        return len(ISLES24_TRAIN_CASE_IDS)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        return list(ISLES24_TRAIN_CASE_IDS)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Load files for a single ISLES24 case.

        Args:
            case_id: Case identifier (e.g., "sub-stroke0102") or 0-based integer index.
        """
        from datasets import load_dataset

        if isinstance(case_id, int):
            if case_id < 0 or case_id >= len(ISLES24_TRAIN_CASE_IDS):
                raise IndexError(f"Case index {case_id} out of range")
            resolved_case_id = ISLES24_TRAIN_CASE_IDS[case_id]
        else:
            resolved_case_id = case_id

        # Security: Validate subject_id before using it in a path (defense-in-depth)
        if not _SAFE_SUBJECT_ID_PATTERN.match(resolved_case_id):
            raise ValueError(
                f"Invalid subject_id format: {resolved_case_id!r}. "
                "Expected format: sub-strokeXXXX"
            )

        # Load exactly one shard (1 case per Parquet file in this dataset)
        data_file = isles24_train_data_file(resolved_case_id)
        ds = load_dataset(
            self.dataset_id,
            data_files={"train": data_file},
            split="train",
            token=self.token,
            revision=self.revision,
        )
        ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
        if len(ds) != 1:
            raise RuntimeError(f"Expected 1 row for {resolved_case_id}, got {len(ds)}")

        row = ds[0]
        subject_id = row["subject_id"]
        if subject_id != resolved_case_id:
            raise RuntimeError(
                f"Unexpected subject_id {subject_id!r} in {data_file} "
                f"(expected {resolved_case_id!r})"
            )

        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_wrapper_"))

        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        dwi_path = case_dir / f"{subject_id}_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_adc.nii.gz"

        if not dwi_path.exists():
            row["dwi"].to_filename(str(dwi_path))
        if not adc_path.exists():
            row["adc"].to_filename(str(adc_path))

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        if row.get("lesion_mask") is not None:
            mask_path = case_dir / f"{subject_id}_lesion-msk.nii.gz"
            if not mask_path.exists():
                row["lesion_mask"].to_filename(str(mask_path))
            case_files["ground_truth"] = mask_path

        return case_files

    def cleanup(self) -> None:
        if self._temp_dir and self._temp_dir.exists():
            try:
                shutil.rmtree(self._temp_dir)
            except OSError as e:
                logger.warning("Failed to clean up temp directory %s: %s", self._temp_dir, e)
            self._temp_dir = None

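
# Usage sketch for the manifest-backed path (assumes network access to the
# pinned dataset revision; only the single Parquet shard for the requested
# case is fetched):
#
#     with Isles24HuggingFaceDataset() as ds:
#         print(len(ds))                         # manifest length, no download
#         files = ds.get_case("sub-stroke0102")  # downloads one shard
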
def load_isles_dataset(
    source: str | Path | None = None,
    *,
    local_mode: bool | None = None,
    token: str | None = None,
) -> Dataset:
    """Load the ISLES24 dataset from a local directory or the HuggingFace Hub.

    Args:
        source: Local directory path or HuggingFace dataset ID.
            If None, uses Settings.hf_dataset_id from config.
        local_mode: If True, treat source as a local directory.
            If None, auto-detect based on source type.
        token: HuggingFace token for private/gated datasets.
            If None, uses Settings.hf_token from config.

    Returns:
        Dataset-like object providing case access. Use as a context manager
        for automatic cleanup of temp files (important for HuggingFace mode).

    Examples:
        # Load from HuggingFace with automatic cleanup (recommended)
        with load_isles_dataset() as ds:
            case = ds.get_case(0)

        # Load from a local directory
        ds = load_isles_dataset("data/isles24", local_mode=True)

        # Load a specific HuggingFace dataset with a token
        ds = load_isles_dataset("org/private-dataset", token="hf_xxx")
    """
    # Auto-detect mode if not specified
    if local_mode is None:
        if source is None:
            local_mode = False  # Default to HuggingFace
        elif isinstance(source, Path):
            local_mode = True
        else:
            # String: check whether it's an existing local path.
            # Only select local mode if the path itself exists
            # (avoids misclassifying HF dataset IDs like "org/dataset").
            source_path = Path(source)
            local_mode = source_path.exists()

    if local_mode:
        from stroke_deepisles_demo.data.adapter import build_local_dataset

        if source is None:
            source = "data/isles24"
        return build_local_dataset(Path(source))

    # HuggingFace mode
    from datasets import load_dataset

    from stroke_deepisles_demo.core.config import get_settings

    settings = get_settings()

    # Use settings defaults if not specified
    dataset_id = str(source) if source else settings.hf_dataset_id
    hf_token = token if token is not None else settings.hf_token

    if dataset_id == ISLES24_DATASET_ID:
        return Isles24HuggingFaceDataset(dataset_id=dataset_id, token=hf_token)

    # Load the dataset, selecting only the necessary columns to minimize
    # decoding overhead. We rely on neuroimaging-go-brrrr's Nifti feature for
    # lazy loading if configured, but select_columns ensures we don't touch
    # other modalities. The token enables access to private/gated datasets.
    ds = load_dataset(dataset_id, split="train", token=hf_token)
    ds = ds.select_columns(["subject_id", "dwi", "adc", "lesion_mask"])
    return HuggingFaceDatasetWrapper(ds, dataset_id)
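

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the public API: assumes
    # the default HuggingFace-backed configuration is reachable and that the
    # dataset contains at least one case.
    with load_isles_dataset() as ds:
        case_ids = ds.list_case_ids()
        print(f"{len(ds)} cases available; first: {case_ids[0]}")
        files = ds.get_case(case_ids[0])
        for name, path in files.items():
            print(f"  {name}: {path}")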