|
|
"""Provide typed access to ISLES24 cases.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import re |
|
|
import shutil |
|
|
import tempfile |
|
|
from dataclasses import dataclass, field |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING, Any, Self |
|
|
|
|
|
from stroke_deepisles_demo.core.exceptions import DataLoadError |
|
|
from stroke_deepisles_demo.core.logging import get_logger |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Iterator |
|
|
|
|
|
from stroke_deepisles_demo.core.types import CaseFiles |
|
|
|
|
|
# Module logger (from the project's get_logger helper) shared by the local and
# HuggingFace loaders below for scan summaries and temp-dir lifecycle messages.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class LocalDataset:
    """Dataset of ISLES24 cases backed by files already on local disk.

    Instances are normally produced by ``build_local_dataset``. The
    context-manager protocol exists only so callers can treat this class and
    HuggingFaceDataset interchangeably; leaving the ``with`` block releases
    nothing because the files are not temporary.

    Example:
        with build_local_dataset(path) as ds:
            case = ds.get_case(0)
    """

    # Root directory the cases were discovered under.
    data_dir: Path
    # Subject ID -> files belonging to that case.
    cases: dict[str, CaseFiles]

    def __len__(self) -> int:
        """Return the number of cases."""
        return len(self.cases)

    def __iter__(self) -> Iterator[str]:
        """Iterate subject IDs in mapping insertion order (not sorted)."""
        return iter(self.cases)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        # Delegate to cleanup(), which is itself a no-op for local files.
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        """Return the subject IDs in ascending sorted order."""
        ids = list(self.cases)
        ids.sort()
        return ids

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Look up the files for one case.

        An ``int`` is treated as a position in ``list_case_ids()`` (sorted
        order); a ``str`` is used directly as the subject ID.

        Raises:
            IndexError: If an integer position is out of range.
            KeyError: If a subject ID is unknown.
        """
        key = self.list_case_ids()[case_id] if isinstance(case_id, int) else case_id
        return self.cases[key]

    def cleanup(self) -> None:
        """Do nothing: local files are not temporary and are left in place."""
        return None
|
|
|
|
|
|
|
|
|
|
|
# ISLES24 BIDS-style NIfTI filename, e.g. "sub-stroke0001_ses-02_dwi.nii.gz".
# Group 1 captures the "stroke####" part of the subject label.
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")


def parse_subject_id(filename: str) -> str | None:
    """Return the ``sub-stroke####`` subject ID encoded in a BIDS filename.

    The pattern is matched from the start of ``filename``; ``None`` is
    returned when the name does not look like an ISLES24 NIfTI file.

    >>> parse_subject_id("sub-stroke0001_ses-02_dwi.nii.gz")
    'sub-stroke0001'
    """
    found = SUBJECT_PATTERN.match(filename)
    if found is None:
        return None
    return "sub-" + found.group(1)
|
|
|
|
|
|
|
|
def _preview_ids(ids: list[str], limit: int = 5) -> str:
    """Join up to ``limit`` IDs for a log message, appending "..." if truncated."""
    return ", ".join(ids[:limit]) + ("..." if len(ids) > limit else "")


def build_local_dataset(data_dir: Path | str) -> LocalDataset:
    """
    Scan directory and build case mapping.

    Expected layout under ``data_dir``:

        Images-DWI/  DWI volumes (``*.nii.gz``)
        Images-ADC/  ADC volumes: DWI filename with ``_dwi.`` -> ``_adc.``
        Masks/       optional masks: DWI filename with ``_dwi.`` -> ``_lesion-msk.``

    Matches DWI + ADC + Mask files by subject ID. A case requires a DWI and an
    ADC file; the mask is attached as ``ground_truth`` only when it exists.
    Logs warnings for files/cases that are skipped and for duplicate subject IDs.

    Args:
        data_dir: Root directory of the local dataset. ``str`` paths are
            accepted and converted to ``Path``.

    Returns:
        LocalDataset mapping subject IDs to their files.

    Raises:
        FileNotFoundError: If DWI subdirectory (Images-DWI) is missing or is
            not a directory.
    """
    data_dir = Path(data_dir)
    dwi_dir = data_dir / "Images-DWI"
    adc_dir = data_dir / "Images-ADC"
    mask_dir = data_dir / "Masks"

    # is_dir() rather than exists(): a regular file with this name would
    # otherwise glob to nothing and yield an empty dataset without any error.
    if not dwi_dir.is_dir():
        raise FileNotFoundError(f"Data directory not found or invalid: {dwi_dir}")

    cases: dict[str, CaseFiles] = {}
    skipped_no_subject_id = 0
    skipped_no_adc: list[str] = []
    duplicate_ids: list[str] = []

    # Sort so the scan result (including which file wins for a duplicated
    # subject) does not depend on filesystem enumeration order.
    for dwi_file in sorted(dwi_dir.glob("*.nii.gz")):
        subject_id = parse_subject_id(dwi_file.name)
        if not subject_id:
            skipped_no_subject_id += 1
            continue

        # Companion files share the DWI filename with the modality suffix swapped.
        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")

        if not adc_file.exists():
            skipped_no_adc.append(subject_id)
            continue

        case_files: CaseFiles = {"dwi": dwi_file, "adc": adc_file}
        if mask_file.exists():
            case_files["ground_truth"] = mask_file

        if subject_id in cases:
            # Several DWI files map to one subject: the later file in sorted
            # order replaces the earlier one; reported in a warning below.
            duplicate_ids.append(subject_id)
        cases[subject_id] = case_files

    if skipped_no_subject_id > 0:
        logger.warning(
            "Skipped %d DWI files: could not parse subject ID from filename",
            skipped_no_subject_id,
        )
    if skipped_no_adc:
        logger.warning(
            "Skipped %d cases missing ADC file: %s",
            len(skipped_no_adc),
            _preview_ids(skipped_no_adc),
        )
    if duplicate_ids:
        logger.warning(
            "Found %d duplicate subject IDs; kept the last file in sorted order: %s",
            len(duplicate_ids),
            _preview_ids(duplicate_ids),
        )

    logger.info("Loaded %d cases from %s", len(cases), data_dir)
    return LocalDataset(data_dir=data_dir, cases=cases)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class HuggingFaceDataset:
    """Dataset adapter for HuggingFace ISLES24 dataset.

    Wraps the HuggingFace dataset and provides the same interface as LocalDataset.
    When get_case() is called, writes NIfTI bytes to temp files and returns paths.

    IMPORTANT: Use as a context manager to ensure temp files are cleaned up:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # temp files automatically cleaned up

    Or call cleanup() manually when done.

    NOTE: Unlike LocalDataset, an integer passed to get_case() indexes the
    underlying dataset's row order (the order of ``_case_ids``), not the sorted
    order returned by list_case_ids(); and an unknown string ID raises
    ValueError rather than KeyError.
    """

    # HuggingFace Hub identifier the data was loaded from.
    dataset_id: str
    # Underlying HF dataset object; rows are indexed positionally in get_case().
    _hf_dataset: Any = field(repr=False)
    # Subject IDs, in the same order as the rows of _hf_dataset.
    _case_ids: list[str] = field(default_factory=list)
    # Lazily created directory holding the extracted NIfTI files.
    _temp_dir: Path | None = field(default=None, repr=False)
    # subject_id -> already-extracted file paths, so repeat calls skip re-writing.
    _cached_cases: dict[str, CaseFiles] = field(default_factory=dict, repr=False)

    def __len__(self) -> int:
        return len(self._hf_dataset)

    def __iter__(self) -> Iterator[str]:
        return iter(self._case_ids)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self._case_ids)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index.

        Writes NIfTI bytes to temp files on first access; returns cached paths
        on subsequent calls for the same case.

        Args:
            case_id: Subject ID, or an integer row index into the underlying
                dataset (dataset order, not sorted order).

        Raises:
            IndexError: If an integer index is out of range.
            ValueError: If a string subject ID is not in the dataset.
            DataLoadError: If HuggingFace data is malformed or missing required fields.
        """
        if isinstance(case_id, int):
            idx = case_id
            subject_id = self._case_ids[idx]
        else:
            subject_id = case_id
            idx = self._case_ids.index(subject_id)

        if subject_id in self._cached_cases:
            return self._cached_cases[subject_id]

        example = self._hf_dataset[idx]

        # Validate required fields before touching the filesystem so a bad row
        # leaves no half-populated case directory behind.
        try:
            dwi_bytes = example["dwi"]["bytes"]
            adc_bytes = example["adc"]["bytes"]
        except (KeyError, TypeError) as e:
            raise DataLoadError(
                f"Malformed HuggingFace data for {subject_id}: missing 'dwi' or 'adc' bytes. "
                f"The dataset schema may have changed. Error: {e}"
            ) from e

        # A present-but-null/empty payload (e.g. {"bytes": None}) would otherwise
        # surface as an opaque TypeError from write_bytes, or as an empty NIfTI file.
        for modality, payload in (("dwi", dwi_bytes), ("adc", adc_bytes)):
            if not isinstance(payload, (bytes, bytearray, memoryview)) or not payload:
                raise DataLoadError(
                    f"Malformed HuggingFace data for {subject_id}: '{modality}' has no "
                    f"NIfTI bytes (got {type(payload).__name__}). "
                    "The dataset schema may have changed."
                )

        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_"))
            logger.debug("Created temp directory: %s", self._temp_dir)

        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        # Filenames mirror the BIDS naming used by the local layout.
        dwi_path = case_dir / f"{subject_id}_ses-02_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_ses-02_adc.nii.gz"
        mask_path = case_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz"

        dwi_path.write_bytes(dwi_bytes)
        adc_path.write_bytes(adc_bytes)

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        # The lesion mask is optional: a missing column, a null value, or an
        # unexpected (non-dict / non-bytes) payload all mean "no ground truth".
        mask_data = example.get("lesion_mask")
        mask_bytes = mask_data.get("bytes") if isinstance(mask_data, dict) else None
        if isinstance(mask_bytes, (bytes, bytearray, memoryview)) and mask_bytes:
            mask_path.write_bytes(mask_bytes)
            case_files["ground_truth"] = mask_path
        else:
            logger.debug("No lesion mask available for %s", subject_id)

        self._cached_cases[subject_id] = case_files
        return case_files

    def cleanup(self) -> None:
        """Remove temp directory and clear cache."""
        if self._temp_dir and self._temp_dir.exists():
            shutil.rmtree(self._temp_dir, ignore_errors=True)
            logger.debug("Cleaned up temp directory: %s", self._temp_dir)
        self._temp_dir = None
        self._cached_cases.clear()
|
|
|
|
|
|
|
|
def build_huggingface_dataset(dataset_id: str) -> HuggingFaceDataset:
    """
    Load ISLES24 dataset from HuggingFace Hub.

    Args:
        dataset_id: HuggingFace dataset identifier (e.g., "hugging-science/isles24-stroke")

    Returns:
        HuggingFaceDataset providing case access

    Raises:
        DataLoadError: If the loaded "train" split has no ``subject_id`` column.
    """
    # Imported inside the function (presumably so `datasets` is only required
    # when loading from the Hub).
    from datasets import load_dataset

    logger.info("Loading HuggingFace dataset: %s", dataset_id)
    hf_dataset = load_dataset(dataset_id, split="train")

    # Read only the subject_id column. Iterating rows (`for example in
    # hf_dataset`) would format every column of every row, including the
    # NIfTI byte payloads, just to collect the IDs.
    try:
        case_ids = list(hf_dataset["subject_id"])
    except KeyError as e:
        raise DataLoadError(
            f"HuggingFace dataset {dataset_id} has no 'subject_id' column. "
            f"The dataset schema may have changed. Error: {e}"
        ) from e

    logger.info("Loaded %d cases from HuggingFace: %s", len(case_ids), dataset_id)

    return HuggingFaceDataset(
        dataset_id=dataset_id,
        _hf_dataset=hf_dataset,
        _case_ids=case_ids,
    )
|
|
|