"""Provide typed access to ISLES24 cases."""

from __future__ import annotations

import re
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self

from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.core.logging import get_logger

if TYPE_CHECKING:
    from collections.abc import Iterator

    from stroke_deepisles_demo.core.types import CaseFiles

logger = get_logger(__name__)


@dataclass
class LocalDataset:
    """File-based dataset for local ISLES24 data.

    Can be used as a context manager for consistency with HuggingFaceDataset,
    though no cleanup is needed for local files.

    Example:
        with build_local_dataset(path) as ds:
            case = ds.get_case(0)
    """

    data_dir: Path
    cases: dict[str, CaseFiles]  # subject_id -> files

    def __len__(self) -> int:
        return len(self.cases)

    def __iter__(self) -> Iterator[str]:
        return iter(self.cases.keys())

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        # No cleanup needed for local files
        pass

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self.cases.keys())

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index."""
        if isinstance(case_id, int):
            case_id = self.list_case_ids()[case_id]
        return self.cases[case_id]

    def cleanup(self) -> None:
        """No-op for local dataset (files are not temporary)."""
        pass


# Subject ID extraction
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")


def parse_subject_id(filename: str) -> str | None:
    """Extract subject ID from BIDS filename."""
    match = SUBJECT_PATTERN.match(filename)
    return f"sub-{match.group(1)}" if match else None
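
# Illustrative examples (these filenames are hypothetical, constructed to
# match SUBJECT_PATTERN above):
#     parse_subject_id("sub-stroke0001_ses-02_dwi.nii.gz")  # -> "sub-stroke0001"
#     parse_subject_id("notes.txt")                         # -> None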


def build_local_dataset(data_dir: Path) -> LocalDataset:
    """Scan directory and build case mapping.

    Matches DWI + ADC + Mask files by subject ID.
    Logs warnings for incomplete cases that are skipped.

    Raises:
        FileNotFoundError: If DWI subdirectory (Images-DWI) is missing
    """
    dwi_dir = data_dir / "Images-DWI"
    adc_dir = data_dir / "Images-ADC"
    mask_dir = data_dir / "Masks"
    if not dwi_dir.exists():
        raise FileNotFoundError(f"Data directory not found or invalid: {dwi_dir}")

    cases: dict[str, CaseFiles] = {}
    skipped_no_subject_id = 0
    skipped_no_adc: list[str] = []

    # Scan DWI files to get subject IDs
    for dwi_file in dwi_dir.glob("*.nii.gz"):
        subject_id = parse_subject_id(dwi_file.name)
        if not subject_id:
            skipped_no_subject_id += 1
            continue

        # Find matching ADC and Mask
        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")
        if not adc_file.exists():
            skipped_no_adc.append(subject_id)
            continue

        case_files: CaseFiles = {
            "dwi": dwi_file,
            "adc": adc_file,
        }
        if mask_file.exists():
            case_files["ground_truth"] = mask_file
        cases[subject_id] = case_files

    # Log skipped cases for debugging
    if skipped_no_subject_id > 0:
        logger.warning(
            "Skipped %d DWI files: could not parse subject ID from filename",
            skipped_no_subject_id,
        )
    if skipped_no_adc:
        logger.warning(
            "Skipped %d cases missing ADC file: %s",
            len(skipped_no_adc),
            ", ".join(skipped_no_adc[:5]) + ("..." if len(skipped_no_adc) > 5 else ""),
        )

    logger.info("Loaded %d cases from %s", len(cases), data_dir)
    return LocalDataset(data_dir=data_dir, cases=cases)
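
# Usage sketch for build_local_dataset; the path is hypothetical, and the
# directory is assumed to contain the Images-DWI / Images-ADC / Masks
# subdirectories scanned above:
#     ds = build_local_dataset(Path("/data/isles24"))
#     for case_id in ds.list_case_ids():
#         files = ds.get_case(case_id)  # {"dwi": ..., "adc": ..., optional "ground_truth"}
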
# =============================================================================
# HuggingFace Dataset Adapter
# =============================================================================

@dataclass
class HuggingFaceDataset:
    """Dataset adapter for the HuggingFace ISLES24 dataset.

    Wraps the HuggingFace dataset and provides the same interface as
    LocalDataset. When get_case() is called, it writes the NIfTI bytes to
    temp files and returns the paths.

    IMPORTANT: Use as a context manager to ensure temp files are cleaned up:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # temp files automatically cleaned up

    Or call cleanup() manually when done.
    """

    dataset_id: str
    _hf_dataset: Any = field(repr=False)
    _case_ids: list[str] = field(default_factory=list)
    _temp_dir: Path | None = field(default=None, repr=False)
    _cached_cases: dict[str, CaseFiles] = field(default_factory=dict, repr=False)
    def __len__(self) -> int:
        return len(self._hf_dataset)

    def __iter__(self) -> Iterator[str]:
        return iter(self._case_ids)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self._case_ids)
    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index.

        Writes NIfTI bytes to temp files on first access; returns cached paths
        on subsequent calls for the same case.

        Raises:
            DataLoadError: If HuggingFace data is malformed or missing
                required fields.
        """
        if isinstance(case_id, int):
            idx = case_id
            subject_id = self._case_ids[idx]
        else:
            subject_id = case_id
            idx = self._case_ids.index(subject_id)

        # Return cached case if already materialized
        if subject_id in self._cached_cases:
            return self._cached_cases[subject_id]

        # Create shared temp directory on first use
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_"))
            logger.debug("Created temp directory: %s", self._temp_dir)

        # Get the HuggingFace example
        example = self._hf_dataset[idx]

        # Create case subdirectory
        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        # Write NIfTI files to temp directory
        dwi_path = case_dir / f"{subject_id}_ses-02_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_ses-02_adc.nii.gz"
        mask_path = case_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz"

        # Extract bytes with defensive error handling
        try:
            dwi_bytes = example["dwi"]["bytes"]
            adc_bytes = example["adc"]["bytes"]
        except (KeyError, TypeError) as e:
            raise DataLoadError(
                f"Malformed HuggingFace data for {subject_id}: missing 'dwi' or 'adc' bytes. "
                f"The dataset schema may have changed. Error: {e}"
            ) from e

        # Write the gzipped NIfTI bytes
        dwi_path.write_bytes(dwi_bytes)
        adc_path.write_bytes(adc_bytes)

        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        # Write lesion mask if available
        try:
            mask_data = example.get("lesion_mask")
            if mask_data and mask_data.get("bytes"):
                mask_path.write_bytes(mask_data["bytes"])
                case_files["ground_truth"] = mask_path
        except (KeyError, TypeError):
            # Mask is optional; log and continue
            logger.debug("No lesion mask available for %s", subject_id)

        # Cache for subsequent calls
        self._cached_cases[subject_id] = case_files
        return case_files
    def cleanup(self) -> None:
        """Remove temp directory and clear cache."""
        if self._temp_dir and self._temp_dir.exists():
            shutil.rmtree(self._temp_dir, ignore_errors=True)
            logger.debug("Cleaned up temp directory: %s", self._temp_dir)
        self._temp_dir = None
        self._cached_cases.clear()


def build_huggingface_dataset(dataset_id: str) -> HuggingFaceDataset:
    """Load ISLES24 dataset from HuggingFace Hub.

    Args:
        dataset_id: HuggingFace dataset identifier
            (e.g., "hugging-science/isles24-stroke")

    Returns:
        HuggingFaceDataset providing case access
    """
    from datasets import load_dataset

    logger.info("Loading HuggingFace dataset: %s", dataset_id)
    hf_dataset = load_dataset(dataset_id, split="train")

    # Extract case IDs
    case_ids = [example["subject_id"] for example in hf_dataset]
    logger.info("Loaded %d cases from HuggingFace: %s", len(case_ids), dataset_id)

    return HuggingFaceDataset(
        dataset_id=dataset_id,
        _hf_dataset=hf_dataset,
        _case_ids=case_ids,
    )
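

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The dataset
    # id is taken from the docstring example above and is an assumption; swap
    # in whichever ISLES24 dataset id you actually use.
    with build_huggingface_dataset("hugging-science/isles24-stroke") as ds:
        print(f"{len(ds)} cases available")
        case = ds.get_case(0)  # materializes DWI/ADC (and mask, if present) as temp files
        print("DWI path:", case["dwi"])
    # cleanup() ran on __exit__, so the temp paths printed above are now gone.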