|
|
"""End-to-end pipeline orchestration.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import shutil |
|
|
import statistics |
|
|
import tempfile |
|
|
import time |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from stroke_deepisles_demo import metrics |
|
|
from stroke_deepisles_demo.core.logging import get_logger |
|
|
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles |
|
|
from stroke_deepisles_demo.inference import run_deepisles_on_folder |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Sequence |
|
|
|
|
|
from stroke_deepisles_demo.core.types import CaseFiles |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    All file paths in this result point to valid, accessible files in results_dir.
    Callers are responsible for cleaning up results_dir when done (if desired).
    """

    # Resolved dataset case identifier (never the integer index form).
    case_id: str
    # Copies of the inputs ("dwi", "adc", optionally "ground_truth") preserved in results_dir.
    input_files: CaseFiles
    # Directory holding every output for this case; the caller owns its lifetime.
    results_dir: Path
    # Predicted segmentation mask produced by DeepISLES inference.
    prediction_mask: Path
    # Copied ground-truth mask, or None when the dataset provides none.
    ground_truth: Path | None
    # Dice coefficient against ground truth; None when unavailable, disabled, or computation failed.
    dice_score: float | None
    # Total pipeline duration for this case, in seconds.
    elapsed_seconds: float
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs."""

    # Total number of pipeline results summarized.
    num_cases: int
    # Runs counted as successful (get_pipeline_summary currently counts all results as successful).
    num_successful: int
    # Runs counted as failed (currently always 0; failed runs never yield a result).
    num_failed: int
    # Mean Dice over cases that have a score; None when no scores are available.
    mean_dice: float | None
    # Sample standard deviation of Dice scores (0.0 for a single score); None when no scores.
    std_dice: float | None
    # Minimum Dice score; None when no scores are available.
    min_dice: float | None
    # Maximum Dice score; None when no scores are available.
    max_dice: float | None
    # Mean per-case duration in seconds (0.0 when there are no results).
    mean_elapsed_seconds: float
|
|
|
|
|
|
|
|
def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool | None = None,
    gpu: bool | None = None,
    timeout: float | None = None,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from Settings.hf_dataset_id)
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (default from Settings.deepisles_fast_mode)
        gpu: Use GPU acceleration (default from Settings.deepisles_use_gpu)
        timeout: Maximum inference time in seconds (default from Settings.deepisles_timeout_seconds)
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If an integer case_id is out of range for the dataset.
    """
    # Imported lazily here (not at module top) — presumably to avoid a
    # circular import between config and this module; confirm before moving.
    from stroke_deepisles_demo.core.config import get_settings

    settings = get_settings()

    # Fill in any unset options from global settings.
    if fast is None:
        fast = settings.deepisles_fast_mode
    if gpu is None:
        gpu = settings.deepisles_use_gpu
    if timeout is None:
        timeout = settings.deepisles_timeout_seconds

    # perf_counter is monotonic: the reported duration cannot be corrupted by
    # wall-clock adjustments (NTP sync, DST) during a long inference run.
    start_time = time.perf_counter()

    with load_isles_dataset(dataset_id) as dataset:
        # An integer case_id is an index into the dataset's case list.
        if isinstance(case_id, int):
            all_ids = dataset.list_case_ids()
            if case_id < 0 or case_id >= len(all_ids):
                raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
            resolved_case_id = all_ids[case_id]
        else:
            resolved_case_id = case_id

        # Choose output locations: caller-provided directory or a fresh temp dir.
        # Explicit None check — truthiness on a Path-or-None parameter is fragile.
        if output_dir is not None:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            staging_root = output_dir / "staging" / resolved_case_id
            results_dir = output_dir / resolved_case_id
        else:
            base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
            staging_root = base_temp / "staging"
            results_dir = base_temp / "results"

        case_files = dataset.get_case(resolved_case_id)

        # Lay the inputs out in the folder structure DeepISLES expects.
        staged = stage_case_for_deepisles(case_files, staging_root)

        results_dir.mkdir(parents=True, exist_ok=True)

        # Copy the inputs into results_dir so they remain valid after the
        # dataset context (and any files it owns) is torn down.
        dwi_dest = results_dir / f"{resolved_case_id}_dwi.nii.gz"
        shutil.copy2(case_files["dwi"], dwi_dest)

        adc_dest = results_dir / f"{resolved_case_id}_adc.nii.gz"
        shutil.copy2(case_files["adc"], adc_dest)

        # Ground truth is optional; copy it only if it exists on disk.
        ground_truth: Path | None = None
        original_ground_truth = case_files.get("ground_truth")
        if original_ground_truth and original_ground_truth.exists():
            ground_truth = results_dir / f"{resolved_case_id}_ground_truth.nii.gz"
            shutil.copy2(original_ground_truth, ground_truth)

        preserved_input_files: CaseFiles = {
            "dwi": dwi_dest,
            "adc": adc_dest,
        }
        if ground_truth:
            preserved_input_files["ground_truth"] = ground_truth

        # Run segmentation; outputs are written directly into results_dir.
        inference_result = run_deepisles_on_folder(
            staged.input_dir,
            output_dir=results_dir,
            fast=fast,
            gpu=gpu,
            timeout=timeout,
        )

        # Dice is best-effort: a metrics failure must not discard the prediction.
        dice_score: float | None = None
        if compute_dice and ground_truth and ground_truth.exists():
            try:
                dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
            except Exception:
                logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)

        # Staging copies are redundant after inference; removal is best-effort.
        if cleanup_staging:
            try:
                shutil.rmtree(staging_root)
            except OSError as e:
                logger.warning("Failed to cleanup staging directory %s: %s", staging_root, e)

        elapsed = time.perf_counter() - start_time

        return PipelineResult(
            case_id=resolved_case_id,
            input_files=preserved_input_files,
            results_dir=results_dir,
            prediction_mask=inference_result.prediction_path,
            ground_truth=ground_truth,
            dice_score=dice_score,
            elapsed_seconds=elapsed,
        )
|
|
|
|
|
|
|
|
def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run the pipeline over several cases, one after another.

    Note: Parallel execution requires multiple GPUs or sequential mode.
    Currently only sequential execution is implemented (max_workers is ignored).

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
            Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case
    """
    # Accepted for forward compatibility only; no parallelism is wired up yet.
    _ = max_workers

    return [run_pipeline_on_case(single_id, **kwargs) for single_id in case_ids]
|
|
|
|
|
|
|
|
def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Returns:
        Summary with mean Dice, success rate, etc.
    """
    scores = [res.dice_score for res in results if res.dice_score is not None]
    durations = [res.elapsed_seconds for res in results]

    # Every result present here is counted as successful; failures never
    # produce a PipelineResult in the first place.
    total = len(results)

    if not scores:
        mean_dice = std_dice = min_dice = max_dice = None
    else:
        mean_dice = statistics.mean(scores)
        # stdev needs at least two samples; a single score has zero spread.
        std_dice = statistics.stdev(scores) if len(scores) > 1 else 0.0
        min_dice = min(scores)
        max_dice = max(scores)

    avg_elapsed = statistics.mean(durations) if durations else 0.0

    return PipelineSummary(
        num_cases=total,
        num_successful=total,
        num_failed=0,
        mean_dice=mean_dice,
        std_dice=std_dice,
        min_dice=min_dice,
        max_dice=max_dice,
        mean_elapsed_seconds=avg_elapsed,
    )
|
|
|