"""End-to-end pipeline orchestration."""

from __future__ import annotations

import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder

if TYPE_CHECKING:
    from collections.abc import Sequence

    from stroke_deepisles_demo.core.types import CaseFiles

logger = get_logger(__name__)


@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    All file paths in this result point to valid, accessible files in results_dir.
    Callers are responsible for cleaning up results_dir when done (if desired).
    """

    case_id: str
    input_files: CaseFiles  # Copied to results_dir; always valid paths
    results_dir: Path  # Directory containing all result files (for cleanup)
    prediction_mask: Path
    ground_truth: Path | None
    dice_score: float | None  # None if ground truth unavailable or not computed
    elapsed_seconds: float


@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs."""

    num_cases: int
    num_successful: int
    num_failed: int
    mean_dice: float | None
    std_dice: float | None
    min_dice: float | None
    max_dice: float | None
    mean_elapsed_seconds: float


def _preserve_case_inputs(
    case_files: CaseFiles,
    results_dir: Path,
    case_id: str,
) -> tuple[CaseFiles, Path | None]:
    """Copy a case's input volumes into results_dir.

    Args:
        case_files: Mapping with "dwi" and "adc" paths and an optional
            "ground_truth" path.
        results_dir: Destination directory (created if missing).
        case_id: Case identifier used as the file-name prefix of the copies.

    Returns:
        A tuple ``(preserved_files, ground_truth)`` where ``preserved_files``
        maps the same keys to the copied paths and ``ground_truth`` is the
        copied ground-truth path, or None when the case has no ground-truth
        file on disk.
    """
    results_dir.mkdir(parents=True, exist_ok=True)

    # DWI and ADC are mandatory inputs; a missing key/file raises here.
    dwi_dest = results_dir / f"{case_id}_dwi.nii.gz"
    shutil.copy2(case_files["dwi"], dwi_dest)

    adc_dest = results_dir / f"{case_id}_adc.nii.gz"
    shutil.copy2(case_files["adc"], adc_dest)

    # Ground truth is optional: copy it only when present and on disk.
    ground_truth: Path | None = None
    original_ground_truth = case_files.get("ground_truth")
    if original_ground_truth and original_ground_truth.exists():
        ground_truth = results_dir / f"{case_id}_ground_truth.nii.gz"
        shutil.copy2(original_ground_truth, ground_truth)

    preserved: CaseFiles = {
        "dwi": dwi_dest,
        "adc": adc_dest,
    }
    if ground_truth:
        preserved["ground_truth"] = ground_truth
    return preserved, ground_truth


def _remove_staging_dir(staging_root: Path) -> None:
    """Best-effort removal of the staging directory; failures are only logged."""
    try:
        # The directory may not exist yet if staging itself failed early.
        if staging_root.exists():
            shutil.rmtree(staging_root)
    except OSError as e:
        logger.warning("Failed to cleanup staging directory %s: %s", staging_root, e)


def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from settings - currently ignored/local)
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference. The
            directory is removed even when the pipeline fails part-way.

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If case_id is an integer outside the range of available cases.
    """
    # Note: dataset_id is currently unused as we default to local loading.
    # It's kept for interface compatibility with future cloud mode.
    _ = dataset_id

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments during a long inference run.
    start_time = time.perf_counter()

    # Stays None until the staging location is known, so the cleanup in
    # `finally` has nothing to do if we fail before that point.
    staging_root: Path | None = None

    try:
        # Use context manager to ensure HuggingFace temp files are cleaned up
        # This prevents unbounded disk usage from accumulating temp NIfTI files
        with load_isles_dataset() as dataset:
            # Resolve ID if integer
            if isinstance(case_id, int):
                all_ids = dataset.list_case_ids()
                if case_id < 0 or case_id >= len(all_ids):
                    raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
                resolved_case_id = all_ids[case_id]
            else:
                resolved_case_id = case_id

            # Set up output directories (now that we have resolved_case_id)
            if output_dir:
                output_dir = Path(output_dir)
                output_dir.mkdir(parents=True, exist_ok=True)
                staging_root = output_dir / "staging" / resolved_case_id
                results_dir = output_dir / resolved_case_id
            else:
                base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
                staging_root = base_temp / "staging"
                results_dir = base_temp / "results"

            # Get case files
            case_files = dataset.get_case(resolved_case_id)

            # Stage files (copies DWI/ADC to staging directory)
            staged = stage_case_for_deepisles(case_files, staging_root)

            # Copy input files to results_dir before dataset cleanup
            # (HuggingFace mode stores files in temp dirs that get cleaned up)
            # This ensures all paths in PipelineResult remain valid after
            # the function returns.
            preserved_input_files, ground_truth = _preserve_case_inputs(
                case_files, results_dir, resolved_case_id
            )

        # Dataset temp files cleaned up here (context manager __exit__)

        # Run inference on the staged copies
        inference_result = run_deepisles_on_folder(
            staged.input_dir,
            output_dir=results_dir,
            fast=fast,
            gpu=gpu,
        )

        # Compute metrics (using copied ground truth). A metric failure is
        # logged and leaves dice_score as None rather than failing the run.
        dice_score: float | None = None
        if compute_dice and ground_truth and ground_truth.exists():
            try:
                dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
            except Exception:
                logger.warning(
                    "Failed to compute Dice score for %s", resolved_case_id, exc_info=True
                )
    finally:
        # Optional cleanup. Running it in `finally` prevents staged copies
        # from leaking on disk when staging, copying or inference raises.
        if cleanup_staging and staging_root is not None:
            _remove_staging_dir(staging_root)

    elapsed = time.perf_counter() - start_time

    return PipelineResult(
        case_id=resolved_case_id,
        input_files=preserved_input_files,
        results_dir=results_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )


def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run pipeline on multiple cases.

    Note: Parallel execution requires multiple GPUs or sequential mode.
    Currently only sequential execution is implemented (max_workers is ignored).

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
                     Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case, in the order of case_ids. An
        exception raised for any case propagates and aborts the batch.
    """
    # Currently only sequential execution is supported.
    # max_workers is accepted for API compatibility but ignored.
    _ = max_workers

    return [
        run_pipeline_on_case(case_id, **kwargs)  # type: ignore[arg-type]
        for case_id in case_ids
    ]


def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Dice statistics only consider results whose dice_score is not None; they
    are all None when no result carries a score. std_dice is 0.0 when exactly
    one score is available. mean_elapsed_seconds is 0.0 for empty input.

    Returns:
        Summary with mean Dice, success rate, etc.
    """
    # Filter results with valid dice scores
    dice_scores = [r.dice_score for r in results if r.dice_score is not None]
    elapsed_times = [r.elapsed_seconds for r in results]

    # We assume all passed results are "successful" runs
    # (failed runs raise exceptions and never produce a PipelineResult)
    num_cases = len(results)

    mean_dice: float | None = None
    std_dice: float | None = None
    min_dice: float | None = None
    max_dice: float | None = None
    if dice_scores:
        mean_dice = statistics.mean(dice_scores)
        # Sample standard deviation needs at least two data points.
        std_dice = statistics.stdev(dice_scores) if len(dice_scores) > 1 else 0.0
        min_dice = min(dice_scores)
        max_dice = max(dice_scores)

    mean_elapsed = statistics.mean(elapsed_times) if elapsed_times else 0.0

    return PipelineSummary(
        num_cases=num_cases,
        num_successful=num_cases,
        num_failed=0,
        mean_dice=mean_dice,
        std_dice=std_dice,
        min_dice=min_dice,
        max_dice=max_dice,
        mean_elapsed_seconds=mean_elapsed,
    )