"""Provide typed access to ISLES24 cases."""
from __future__ import annotations
import re
import shutil
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self
from stroke_deepisles_demo.core.exceptions import DataLoadError
from stroke_deepisles_demo.core.logging import get_logger
if TYPE_CHECKING:
from collections.abc import Iterator
from stroke_deepisles_demo.core.types import CaseFiles
logger = get_logger(__name__)


@dataclass
class LocalDataset:
    """File-based dataset for local ISLES24 data.

    Can be used as a context manager for consistency with HuggingFaceDataset,
    though no cleanup is needed for local files.

    Example:
        with build_local_dataset(path) as ds:
            case = ds.get_case(0)
    """

    data_dir: Path
    cases: dict[str, CaseFiles]  # subject_id -> files

    def __len__(self) -> int:
        return len(self.cases)

    def __iter__(self) -> Iterator[str]:
        return iter(self.cases.keys())

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        # No cleanup needed for local files
        pass

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self.cases.keys())

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index."""
        if isinstance(case_id, int):
            case_id = self.list_case_ids()[case_id]
        return self.cases[case_id]

    def cleanup(self) -> None:
        """No-op for local dataset (files are not temporary)."""
        pass


# Subject ID extraction
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")


def parse_subject_id(filename: str) -> str | None:
    """Extract subject ID from BIDS filename."""
    match = SUBJECT_PATTERN.match(filename)
    return f"sub-{match.group(1)}" if match else None


def build_local_dataset(data_dir: Path) -> LocalDataset:
    """Scan directory and build case mapping.

    Matches DWI + ADC + Mask files by subject ID.
    Logs warnings for incomplete cases that are skipped.

    Raises:
        FileNotFoundError: If DWI subdirectory (Images-DWI) is missing
    """
    dwi_dir = data_dir / "Images-DWI"
    adc_dir = data_dir / "Images-ADC"
    mask_dir = data_dir / "Masks"
    if not dwi_dir.exists():
        raise FileNotFoundError(f"Data directory not found or invalid: {dwi_dir}")

    cases: dict[str, CaseFiles] = {}
    skipped_no_subject_id = 0
    skipped_no_adc: list[str] = []

    # Scan DWI files to get subject IDs
    for dwi_file in dwi_dir.glob("*.nii.gz"):
        subject_id = parse_subject_id(dwi_file.name)
        if not subject_id:
            skipped_no_subject_id += 1
            continue

        # Find matching ADC and Mask
        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")
        if not adc_file.exists():
            skipped_no_adc.append(subject_id)
            continue

        case_files: CaseFiles = {
            "dwi": dwi_file,
            "adc": adc_file,
        }
        if mask_file.exists():
            case_files["ground_truth"] = mask_file
        cases[subject_id] = case_files

    # Log skipped cases for debugging
    if skipped_no_subject_id > 0:
        logger.warning(
            "Skipped %d DWI files: could not parse subject ID from filename",
            skipped_no_subject_id,
        )
    if skipped_no_adc:
        logger.warning(
            "Skipped %d cases missing ADC file: %s",
            len(skipped_no_adc),
            ", ".join(skipped_no_adc[:5]) + ("..." if len(skipped_no_adc) > 5 else ""),
        )

    logger.info("Loaded %d cases from %s", len(cases), data_dir)
    return LocalDataset(data_dir=data_dir, cases=cases)
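

# The scanner above expects the local directory to follow this layout
# (the filename is illustrative; the subdirectory names and the
# _dwi/_adc/_lesion-msk suffixes come from the code above):
#
#     data_dir/
#         Images-DWI/sub-stroke0001_ses-02_dwi.nii.gz
#         Images-ADC/sub-stroke0001_ses-02_adc.nii.gz
#         Masks/sub-stroke0001_ses-02_lesion-msk.nii.gz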


# =============================================================================
# HuggingFace Dataset Adapter
# =============================================================================


@dataclass
class HuggingFaceDataset:
    """Dataset adapter for HuggingFace ISLES24 dataset.

    Wraps the HuggingFace dataset and provides the same interface as LocalDataset.
    When get_case() is called, writes NIfTI bytes to temp files and returns paths.

    IMPORTANT: Use as a context manager to ensure temp files are cleaned up:

        with load_isles_dataset() as ds:
            case = ds.get_case(0)
            # ... process case ...
        # temp files automatically cleaned up

    Or call cleanup() manually when done.
    """

    dataset_id: str
    _hf_dataset: Any = field(repr=False)
    _case_ids: list[str] = field(default_factory=list)
    _temp_dir: Path | None = field(default=None, repr=False)
    _cached_cases: dict[str, CaseFiles] = field(default_factory=dict, repr=False)

    def __len__(self) -> int:
        return len(self._hf_dataset)

    def __iter__(self) -> Iterator[str]:
        return iter(self._case_ids)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args: object) -> None:
        self.cleanup()

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self._case_ids)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index.

        Writes NIfTI bytes to temp files on first access; returns cached paths
        on subsequent calls for the same case.

        Raises:
            DataLoadError: If HuggingFace data is malformed or missing required fields.
        """
        if isinstance(case_id, int):
            idx = case_id
            subject_id = self._case_ids[idx]
        else:
            subject_id = case_id
            idx = self._case_ids.index(subject_id)

        # Return cached case if already materialized
        if subject_id in self._cached_cases:
            return self._cached_cases[subject_id]

        # Create shared temp directory on first use
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="isles24_hf_"))
            logger.debug("Created temp directory: %s", self._temp_dir)

        # Get the HuggingFace example
        example = self._hf_dataset[idx]

        # Create case subdirectory
        case_dir = self._temp_dir / subject_id
        case_dir.mkdir(exist_ok=True)

        # Write NIfTI files to temp directory
        dwi_path = case_dir / f"{subject_id}_ses-02_dwi.nii.gz"
        adc_path = case_dir / f"{subject_id}_ses-02_adc.nii.gz"
        mask_path = case_dir / f"{subject_id}_ses-02_lesion-msk.nii.gz"

        # Extract bytes with defensive error handling
        try:
            dwi_bytes = example["dwi"]["bytes"]
            adc_bytes = example["adc"]["bytes"]
        except (KeyError, TypeError) as e:
            raise DataLoadError(
                f"Malformed HuggingFace data for {subject_id}: missing 'dwi' or 'adc' bytes. "
                f"The dataset schema may have changed. Error: {e}"
            ) from e

        # Write the gzipped NIfTI bytes
        dwi_path.write_bytes(dwi_bytes)
        adc_path.write_bytes(adc_bytes)
        case_files: CaseFiles = {
            "dwi": dwi_path,
            "adc": adc_path,
        }

        # Write lesion mask if available
        try:
            mask_data = example.get("lesion_mask")
            if mask_data and mask_data.get("bytes"):
                mask_path.write_bytes(mask_data["bytes"])
                case_files["ground_truth"] = mask_path
        except (KeyError, TypeError):
            # Mask is optional, log and continue
            logger.debug("No lesion mask available for %s", subject_id)

        # Cache for subsequent calls
        self._cached_cases[subject_id] = case_files
        return case_files

    def cleanup(self) -> None:
        """Remove temp directory and clear cache."""
        if self._temp_dir and self._temp_dir.exists():
            shutil.rmtree(self._temp_dir, ignore_errors=True)
            logger.debug("Cleaned up temp directory: %s", self._temp_dir)
        self._temp_dir = None
        self._cached_cases.clear()
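

# Each HuggingFace example is assumed to look roughly like the sketch below.
# This is inferred from the access patterns in get_case() above, not from a
# verified dataset schema:
#
#     {
#         "subject_id": "sub-stroke0001",
#         "dwi": {"bytes": b"..."},          # gzipped NIfTI payload
#         "adc": {"bytes": b"..."},
#         "lesion_mask": {"bytes": b"..."},  # optional; may be absent or empty
#     }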


def build_huggingface_dataset(dataset_id: str) -> HuggingFaceDataset:
    """Load ISLES24 dataset from HuggingFace Hub.

    Args:
        dataset_id: HuggingFace dataset identifier (e.g., "hugging-science/isles24-stroke")

    Returns:
        HuggingFaceDataset providing case access
    """
    from datasets import load_dataset

    logger.info("Loading HuggingFace dataset: %s", dataset_id)
    hf_dataset = load_dataset(dataset_id, split="train")

    # Extract case IDs
    case_ids = [example["subject_id"] for example in hf_dataset]
    logger.info("Loaded %d cases from HuggingFace: %s", len(case_ids), dataset_id)
    return HuggingFaceDataset(
        dataset_id=dataset_id,
        _hf_dataset=hf_dataset,
        _case_ids=case_ids,
    )
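

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: the dataset id is the example
    # id from the docstring above, network access is available, and the
    # `datasets` package is installed; for on-disk data, use
    # build_local_dataset(Path(...)) instead.
    with build_huggingface_dataset("hugging-science/isles24-stroke") as ds:
        print(f"{len(ds)} cases; first ids: {ds.list_case_ids()[:3]}")
        first_case = ds.get_case(0)  # materializes NIfTI bytes to temp files
        print(f"DWI written to: {first_case['dwi']}")
    # Temp files are removed when the context manager exits.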