|
|
"""Provide typed access to ISLES24 cases.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from stroke_deepisles_demo.core.logging import get_logger |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Iterator |
|
|
from pathlib import Path |
|
|
|
|
|
from stroke_deepisles_demo.core.types import CaseFiles |
|
|
|
|
|
# Module-level logger, obtained from the project's get_logger helper using
# this module's __name__.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class LocalDataset:
    """File-based dataset for local ISLES24 data.

    Attributes:
        data_dir: Directory associated with this dataset.
        cases: Mapping from subject ID to the files belonging to that case.

    Note:
        Iterating the dataset yields subject IDs in the insertion order of
        ``cases``, whereas integer indices passed to :meth:`get_case` refer
        to the *sorted* ID list returned by :meth:`list_case_ids`.
    """

    data_dir: Path
    cases: dict[str, CaseFiles]

    def __len__(self) -> int:
        """Return the number of cases in the dataset."""
        return len(self.cases)

    def __iter__(self) -> Iterator[str]:
        """Iterate over subject IDs in the insertion order of ``cases``."""
        return iter(self.cases)

    def list_case_ids(self) -> list[str]:
        """Return sorted list of subject IDs."""
        return sorted(self.cases)

    def get_case(self, case_id: str | int) -> CaseFiles:
        """Get files for a case by ID or index.

        Args:
            case_id: Either a subject ID (a key of ``cases``) or an integer
                position into the sorted ID list from :meth:`list_case_ids`.
                Negative integers index from the end, as with any list.

        Returns:
            The files stored for the selected case.

        Raises:
            IndexError: If an integer ``case_id`` is out of range.
            KeyError: If a string ``case_id`` is not present in ``cases``.
        """
        if isinstance(case_id, int):
            # Integer positions are resolved against the sorted ID list so
            # the index-to-case mapping does not depend on insertion order.
            case_id = self.list_case_ids()[case_id]
        return self.cases[case_id]
|
|
|
|
|
|
|
|
|
|
|
# Matches NIfTI filenames of the form "sub-strokeNNNN_ses-<digits>_<rest>.nii.gz",
# e.g. "sub-stroke0001_ses-02_dwi.nii.gz". Group 1 captures "stroke" plus the
# four-digit subject number.
SUBJECT_PATTERN = re.compile(r"sub-(stroke\d{4})_ses-\d+_.*\.nii\.gz")


def parse_subject_id(filename: str) -> str | None:
    """Extract subject ID from BIDS filename.

    Args:
        filename: Bare file name (no directory part), e.g.
            ``"sub-stroke0001_ses-02_dwi.nii.gz"``.

    Returns:
        The subject ID with its ``sub-`` prefix (e.g. ``"sub-stroke0001"``),
        or ``None`` if the name does not match ``SUBJECT_PATTERN``.
    """
    # fullmatch rather than match: the pattern carries no end anchor, so
    # match() would also accept names with trailing junk after ".nii.gz"
    # (for example "....nii.gz.bak").
    match = SUBJECT_PATTERN.fullmatch(filename)
    return f"sub-{match.group(1)}" if match else None
|
|
|
|
|
|
|
|
def build_local_dataset(data_dir: Path) -> LocalDataset:
    """Scan directory and build case mapping.

    Matches DWI + ADC + Mask files by subject ID. Every ``*.nii.gz`` entry in
    ``data_dir / "Images-DWI"`` is visited in sorted order. For each one the
    companion paths are derived from the DWI file name:

    * ADC:  ``Images-ADC/<name with "_dwi." replaced by "_adc.">`` (required)
    * Mask: ``Masks/<name with "_dwi." replaced by "_lesion-msk.">`` (optional,
      stored under the ``"ground_truth"`` key when the file exists)

    Logs warnings for incomplete cases that are skipped: DWI files whose name
    yields no subject ID, and cases whose ADC file is missing.

    Args:
        data_dir: Root directory containing the ``Images-DWI``,
            ``Images-ADC`` and ``Masks`` subdirectories.

    Returns:
        A ``LocalDataset`` holding ``data_dir`` and the subject-ID -> files
        mapping for every complete case found.

    Raises:
        FileNotFoundError: If DWI subdirectory (Images-DWI) is missing or is
            not a directory.
    """
    dwi_dir = data_dir / "Images-DWI"
    adc_dir = data_dir / "Images-ADC"
    mask_dir = data_dir / "Masks"

    # is_dir() rather than exists(): a regular file named "Images-DWI" would
    # pass an existence check and silently produce an empty dataset.
    if not dwi_dir.is_dir():
        raise FileNotFoundError(f"Data directory not found or invalid: {dwi_dir}")

    cases: dict[str, CaseFiles] = {}
    skipped_no_subject_id = 0
    skipped_no_adc: list[str] = []

    # glob() yields entries in arbitrary order. Sorting makes the result
    # reproducible: when several DWI files map to the same subject ID, the
    # path that sorts last is the one kept, and the skip log is stable.
    for dwi_file in sorted(dwi_dir.glob("*.nii.gz")):
        subject_id = parse_subject_id(dwi_file.name)
        if not subject_id:
            skipped_no_subject_id += 1
            continue

        # Companion files share the DWI file name apart from the modality tag.
        adc_file = adc_dir / dwi_file.name.replace("_dwi.", "_adc.")
        mask_file = mask_dir / dwi_file.name.replace("_dwi.", "_lesion-msk.")

        # ADC is mandatory; a case without it is skipped and reported below.
        if not adc_file.exists():
            skipped_no_adc.append(subject_id)
            continue

        case_files: CaseFiles = {
            "dwi": dwi_file,
            "adc": adc_file,
        }
        # The lesion mask is optional, so the key is only set when present.
        if mask_file.exists():
            case_files["ground_truth"] = mask_file

        cases[subject_id] = case_files

    if skipped_no_subject_id > 0:
        logger.warning(
            "Skipped %d DWI files: could not parse subject ID from filename",
            skipped_no_subject_id,
        )
    if skipped_no_adc:
        # Only the first five IDs are listed to keep the log line short.
        logger.warning(
            "Skipped %d cases missing ADC file: %s",
            len(skipped_no_adc),
            ", ".join(skipped_no_adc[:5]) + ("..." if len(skipped_no_adc) > 5 else ""),
        )

    logger.info("Loaded %d cases from %s", len(cases), data_dir)
    return LocalDataset(data_dir=data_dir, cases=cases)
|
|
|