VibecoderMcSwaggins's picture
fix: resolve technical debt (P2/P3) with TDD validation (#9)
26f14be unverified
raw
history blame
7.04 kB
"""End-to-end pipeline orchestration."""
from __future__ import annotations
import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder
if TYPE_CHECKING:
from collections.abc import Sequence
from stroke_deepisles_demo.core.types import CaseFiles
# Module-level logger; used below to report non-fatal Dice-computation failures.
logger = get_logger(__name__)
@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    Instances are immutable (frozen dataclass) and are returned by
    run_pipeline_on_case.
    """

    # Resolved case ID; integer indices are mapped to IDs before the run.
    case_id: str
    # Case files as returned by dataset.get_case() for this case.
    input_files: CaseFiles
    # Input folder produced by staging. May no longer exist when the run used
    # cleanup_staging=True (presumably it lives under the removed staging
    # root; confirm against stage_case_for_deepisles).
    staged_dir: Path
    # Prediction file path reported by the inference step.
    prediction_mask: Path
    # Ground-truth mask path from the case files, or None if the case has none.
    ground_truth: Path | None
    # None if ground truth is unavailable or not computed: compute_dice=False,
    # a missing/nonexistent file, or the Dice computation itself raised.
    dice_score: float | None
    # Seconds spent in the whole run, including dataset loading and cleanup.
    elapsed_seconds: float
@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs.

    Produced by get_pipeline_summary. Dice statistics cover only the runs
    that carried a Dice score.
    """

    # Number of results summarized.
    num_cases: int
    # As currently computed, equals num_cases: failed runs raise instead of
    # returning a result, so every summarized result counts as a success.
    num_successful: int
    # As currently computed, always 0 (see num_successful).
    num_failed: int
    # Dice statistics; all four are None when no run produced a Dice score.
    mean_dice: float | None
    # Sample standard deviation; 0.0 when exactly one Dice score exists.
    std_dice: float | None
    min_dice: float | None
    max_dice: float | None
    # Mean of the per-run elapsed_seconds; 0.0 when there are no results.
    mean_elapsed_seconds: float
def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Steps: load the dataset, resolve the case, stage its files for DeepISLES,
    run inference, optionally score the prediction against ground truth, and
    optionally remove the staging directory.

    Args:
        case_id: Case identifier (string) or index (int) into the dataset's
            list_case_ids()
        dataset_id: HF dataset ID (default from settings - currently ignored/local)
        output_dir: Directory for results (default: a fresh temp dir created
            with mkdtemp and NOT deleted, so returned paths stay valid)
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove the staging directory when the run finishes,
            including when it fails. If True, PipelineResult.staged_dir may
            refer to a path that no longer exists.

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If an integer case_id is outside the dataset's range.
    """
    # dataset_id is currently unused because we always load locally; it is
    # kept for interface compatibility with a future cloud mode.
    _ = dataset_id

    # perf_counter is monotonic, so the duration is immune to clock changes.
    start_time = time.perf_counter()

    # 1. Load dataset (default local path for now)
    dataset = load_isles_dataset()

    # Map an integer index onto the dataset's ordered list of case IDs.
    if isinstance(case_id, int):
        all_ids = dataset.list_case_ids()
        if not 0 <= case_id < len(all_ids):
            if not all_ids:
                raise IndexError(f"Case index {case_id} out of range (dataset is empty)")
            raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
        resolved_case_id = all_ids[case_id]
    else:
        resolved_case_id = case_id

    case_files = dataset.get_case(resolved_case_id)

    # 2. Choose staging and results locations.
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        staging_root = output_dir / "staging" / resolved_case_id
        results_dir = output_dir / resolved_case_id
    else:
        # mkdtemp rather than TemporaryDirectory: the caller receives paths
        # into this directory, so it must outlive this call.
        base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
        staging_root = base_temp / "staging"
        results_dir = base_temp / "results"

    ground_truth = case_files.get("ground_truth")
    dice_score: float | None = None

    try:
        staged = stage_case_for_deepisles(case_files, staging_root)

        # 3. Run inference
        inference_result = run_deepisles_on_folder(
            staged.input_dir,
            output_dir=results_dir,
            fast=fast,
            gpu=gpu,
        )

        # 4. Compute metrics. This is best-effort: a failure is logged and
        # leaves dice_score as None rather than failing the whole run.
        if compute_dice and ground_truth and ground_truth.exists():
            try:
                dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
            except Exception:
                logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)
    finally:
        # 5. Cleanup (optional). This runs in `finally` so that a failed
        # staging or inference step does not leak the staged copies.
        if cleanup_staging:
            shutil.rmtree(staging_root, ignore_errors=True)

    elapsed = time.perf_counter() - start_time

    return PipelineResult(
        case_id=resolved_case_id,
        input_files=case_files,
        staged_dir=staged.input_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )
def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run the pipeline over several cases, one after another.

    Parallel execution would need multiple GPUs, so only sequential execution
    exists today. max_workers is accepted but has no effect.

    Args:
        case_ids: Case identifiers (str) or dataset indices (int)
        max_workers: Reserved for future parallel support (default 1).
            Currently ignored.
        **kwargs: Forwarded unchanged to run_pipeline_on_case for every case

    Returns:
        One PipelineResult per entry of case_ids, in the same order. An
        exception raised for any case propagates immediately; later cases
        are not run.
    """
    del max_workers  # accepted for API compatibility only

    return [
        run_pipeline_on_case(cid, **kwargs)  # type: ignore[arg-type]
        for cid in case_ids
    ]
def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Aggregate pipeline results into summary statistics.

    Only results that carry a Dice score feed the Dice statistics. Every
    result feeds the elapsed-time mean.

    Args:
        results: Results from run_pipeline_on_case / run_pipeline_on_batch

    Returns:
        PipelineSummary. The Dice fields are None when no result has a score.
        std_dice is the sample standard deviation, or 0.0 for a single score.
        mean_elapsed_seconds is 0.0 for an empty input.
    """
    total = len(results)
    scored = [res.dice_score for res in results if res.dice_score is not None]
    durations = [res.elapsed_seconds for res in results]

    dice_mean: float | None = None
    dice_std: float | None = None
    dice_min: float | None = None
    dice_max: float | None = None
    if scored:
        dice_mean = statistics.mean(scored)
        # stdev needs at least two samples; a lone score has no spread.
        dice_std = 0.0 if len(scored) == 1 else statistics.stdev(scored)
        dice_min, dice_max = min(scored), max(scored)

    # Failed runs raise upstream instead of producing a result, so every
    # result handed to us counts as a success.
    return PipelineSummary(
        num_cases=total,
        num_successful=total,
        num_failed=0,
        mean_dice=dice_mean,
        std_dice=dice_std,
        min_dice=dice_min,
        max_dice=dice_max,
        mean_elapsed_seconds=statistics.mean(durations) if durations else 0.0,
    )