File size: 7,042 Bytes
3f8bf9c bfe80c5 3f8bf9c bfe80c5 3f8bf9c 26f14be 3f8bf9c a544a50 3f8bf9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
"""End-to-end pipeline orchestration."""
from __future__ import annotations
import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder
if TYPE_CHECKING:
from collections.abc import Sequence
from stroke_deepisles_demo.core.types import CaseFiles
logger = get_logger(__name__)
@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    Immutable record bundling the input, intermediate, and output
    artifacts of a single pipeline run.
    """

    # Resolved string identifier of the processed case.
    case_id: str
    # Input files as loaded from the dataset (project-defined mapping type).
    input_files: CaseFiles
    # Directory where inputs were staged for inference; may already be
    # removed if the run was invoked with cleanup enabled.
    staged_dir: Path
    # Path to the predicted segmentation mask produced by inference.
    prediction_mask: Path
    # Path to the reference mask, or None when the dataset provides none.
    ground_truth: Path | None
    dice_score: float | None  # None if ground truth unavailable or not computed
    # Wall-clock duration of the whole pipeline run, in seconds.
    elapsed_seconds: float
@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs.

    Dice fields are None when no run in the batch produced a Dice score.
    """

    # Total number of results aggregated.
    num_cases: int
    # Runs counted as successful (failed runs raise and never reach the summary).
    num_successful: int
    # Runs counted as failed; currently always 0 for the same reason.
    num_failed: int
    # Mean / sample-stdev / min / max over the available Dice scores.
    mean_dice: float | None
    std_dice: float | None
    min_dice: float | None
    max_dice: float | None
    # Mean elapsed wall-clock time per case, in seconds (0.0 for empty input).
    mean_elapsed_seconds: float
def _resolve_case_id(dataset, case_id: str | int) -> str:
    """Map an integer index to a case ID via the dataset; pass strings through.

    Raises:
        IndexError: If an integer index is outside the dataset's range.
    """
    if isinstance(case_id, int):
        all_ids = dataset.list_case_ids()
        if case_id < 0 or case_id >= len(all_ids):
            raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
        return all_ids[case_id]
    return case_id


def _resolve_output_dirs(output_dir: Path | None, case_id: str) -> tuple[Path, Path]:
    """Return ``(staging_root, results_dir)`` for a pipeline run.

    When *output_dir* is given, both live under it (and it is created if
    missing). Otherwise a persistent temp dir is created with ``mkdtemp`` so
    the returned paths outlive this call — a ``TemporaryDirectory`` context
    would delete them before the caller could use the result paths.
    """
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        return output_dir / "staging" / case_id, output_dir / case_id
    base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
    return base_temp / "staging", base_temp / "results"


def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from settings - currently ignored/local)
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If an integer ``case_id`` is out of range.
    """
    # dataset_id is currently unused (local loading only); kept for
    # interface compatibility with a future cloud mode.
    _ = dataset_id
    # perf_counter is monotonic — immune to wall-clock adjustments that
    # would corrupt an elapsed-time measurement taken with time.time().
    start_time = time.perf_counter()

    # 1. Load dataset and resolve the case.
    dataset = load_isles_dataset()  # Uses default local path for now
    resolved_case_id = _resolve_case_id(dataset, case_id)
    case_files = dataset.get_case(resolved_case_id)

    # 2. Stage files into the layout DeepIsles expects.
    staging_root, results_dir = _resolve_output_dirs(output_dir, resolved_case_id)
    staged = stage_case_for_deepisles(case_files, staging_root)

    # 3. Run inference.
    inference_result = run_deepisles_on_folder(
        staged.input_dir,
        output_dir=results_dir,
        fast=fast,
        gpu=gpu,
    )

    # 4. Compute metrics (best-effort: a metric failure is logged, not fatal).
    dice_score: float | None = None
    ground_truth = case_files.get("ground_truth")
    if compute_dice and ground_truth and ground_truth.exists():
        try:
            dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
        except Exception:
            logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)

    # 5. Optional cleanup of the staging copies (results are kept).
    if cleanup_staging:
        shutil.rmtree(staging_root, ignore_errors=True)

    elapsed = time.perf_counter() - start_time
    return PipelineResult(
        case_id=resolved_case_id,
        input_files=case_files,
        staged_dir=staged.input_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )
def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run pipeline on multiple cases.

    Note: Parallel execution requires multiple GPUs or sequential mode.
    Currently only sequential execution is implemented (max_workers is ignored).

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
            Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case
    """
    # Accepted for API compatibility; only sequential execution exists today.
    _ = max_workers
    return [
        run_pipeline_on_case(cid, **kwargs)  # type: ignore[arg-type]
        for cid in case_ids
    ]
def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Returns:
        Summary with mean Dice, success rate, etc.
    """
    # Only runs with a computed Dice score contribute to the Dice statistics.
    scored = [r.dice_score for r in results if r.dice_score is not None]
    durations = [r.elapsed_seconds for r in results]

    mean_dice: float | None = None
    std_dice: float | None = None
    min_dice: float | None = None
    max_dice: float | None = None
    if scored:
        mean_dice = statistics.mean(scored)
        # Sample stdev needs >= 2 points; report 0.0 for a single score.
        std_dice = statistics.stdev(scored) if len(scored) > 1 else 0.0
        min_dice = min(scored)
        max_dice = max(scored)

    return PipelineSummary(
        num_cases=len(results),
        # Every result handed in is treated as a successful run — failed
        # runs raise exceptions and never produce a PipelineResult.
        num_successful=len(results),
        num_failed=0,
        mean_dice=mean_dice,
        std_dice=std_dice,
        min_dice=min_dice,
        max_dice=max_dice,
        mean_elapsed_seconds=statistics.mean(durations) if durations else 0.0,
    )
|