"""End-to-end pipeline orchestration."""
from __future__ import annotations
import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder
if TYPE_CHECKING:
from collections.abc import Sequence
from stroke_deepisles_demo.core.types import CaseFiles
logger = get_logger(__name__)
@dataclass(frozen=True)
class PipelineResult:
"""Complete result of running the pipeline on a case.
All file paths in this result point to valid, accessible files in results_dir.
Callers are responsible for cleaning up results_dir when done (if desired).
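
    Example (illustrative; `result` is a PipelineResult returned by run_pipeline_on_case):
        >>> import shutil
        >>> shutil.rmtree(result.results_dir)  # doctest: +SKIP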
"""
case_id: str
input_files: CaseFiles # Copied to results_dir; always valid paths
results_dir: Path # Directory containing all result files (for cleanup)
prediction_mask: Path
ground_truth: Path | None
dice_score: float | None # None if ground truth unavailable or not computed
elapsed_seconds: float
@dataclass(frozen=True)
class PipelineSummary:
"""Summary statistics from multiple pipeline runs."""
num_cases: int
num_successful: int
num_failed: int
mean_dice: float | None
std_dice: float | None
min_dice: float | None
max_dice: float | None
mean_elapsed_seconds: float
def run_pipeline_on_case(
case_id: str | int,
*,
dataset_id: str | None = None,
output_dir: Path | None = None,
fast: bool = True,
gpu: bool = True,
compute_dice: bool = True,
cleanup_staging: bool = True,
) -> PipelineResult:
"""
Run the complete segmentation pipeline on a single case.
Args:
case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from settings; currently unused, loading is local-only)
output_dir: Directory for results (default: temp dir)
fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
gpu: Use GPU acceleration
compute_dice: Compute Dice score if ground truth available
cleanup_staging: Remove staging directory after inference
Returns:
PipelineResult with all paths and optional metrics
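
    Example (illustrative sketch; assumes at least one case is available
    locally and runs the first one on CPU):
        >>> result = run_pipeline_on_case(0, gpu=False)  # doctest: +SKIP
        >>> print(result.case_id, result.dice_score)  # doctest: +SKIP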
"""
# Note: dataset_id is currently unused as we default to local loading.
# It's kept for interface compatibility with future cloud mode.
_ = dataset_id
    start_time = time.perf_counter()  # monotonic clock; wall-clock time can jump backwards
# Use context manager to ensure HuggingFace temp files are cleaned up
# This prevents unbounded disk usage from accumulating temp NIfTI files
with load_isles_dataset() as dataset:
# Resolve ID if integer
if isinstance(case_id, int):
all_ids = dataset.list_case_ids()
if case_id < 0 or case_id >= len(all_ids):
raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
resolved_case_id = all_ids[case_id]
else:
resolved_case_id = case_id
# Set up output directories (now that we have resolved_case_id)
if output_dir:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
staging_root = output_dir / "staging" / resolved_case_id
results_dir = output_dir / resolved_case_id
else:
base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
staging_root = base_temp / "staging"
results_dir = base_temp / "results"
# Get case files
case_files = dataset.get_case(resolved_case_id)
# Stage files (copies DWI/ADC to staging directory)
staged = stage_case_for_deepisles(case_files, staging_root)
# Copy input files to results_dir before dataset cleanup
# (HuggingFace mode stores files in temp dirs that get cleaned up)
# This ensures all paths in PipelineResult remain valid after function returns
results_dir.mkdir(parents=True, exist_ok=True)
# Copy DWI (required for UI visualization)
dwi_dest = results_dir / f"{resolved_case_id}_dwi.nii.gz"
shutil.copy2(case_files["dwi"], dwi_dest)
# Copy ADC
adc_dest = results_dir / f"{resolved_case_id}_adc.nii.gz"
shutil.copy2(case_files["adc"], adc_dest)
# Copy ground truth if available
ground_truth: Path | None = None
original_ground_truth = case_files.get("ground_truth")
if original_ground_truth and original_ground_truth.exists():
ground_truth = results_dir / f"{resolved_case_id}_ground_truth.nii.gz"
shutil.copy2(original_ground_truth, ground_truth)
# Build input_files with copied paths (always valid after function returns)
preserved_input_files: CaseFiles = {
"dwi": dwi_dest,
"adc": adc_dest,
}
if ground_truth:
preserved_input_files["ground_truth"] = ground_truth
# Dataset temp files cleaned up here (context manager __exit__)
    # Run inference on the staged inputs
inference_result = run_deepisles_on_folder(
staged.input_dir,
output_dir=results_dir,
fast=fast,
gpu=gpu,
)
    # Compute metrics (using the copied ground truth)
dice_score: float | None = None
    # ground_truth was copied above via copy2, so a separate exists() check is redundant
    if compute_dice and ground_truth is not None:
try:
dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
except Exception:
logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)
    # Optionally clean up the staging directory
if cleanup_staging:
try:
shutil.rmtree(staging_root)
except OSError as e:
logger.warning("Failed to cleanup staging directory %s: %s", staging_root, e)
    elapsed = time.perf_counter() - start_time
return PipelineResult(
case_id=resolved_case_id,
input_files=preserved_input_files,
results_dir=results_dir,
prediction_mask=inference_result.prediction_path,
ground_truth=ground_truth,
dice_score=dice_score,
elapsed_seconds=elapsed,
)
def run_pipeline_on_batch(
case_ids: Sequence[str | int],
*,
max_workers: int = 1,
**kwargs: object,
) -> list[PipelineResult]:
"""
Run pipeline on multiple cases.
    Note: Parallel execution would require multiple GPUs, so only sequential
    execution is implemented; max_workers is accepted for compatibility but ignored.
Args:
case_ids: List of case identifiers or indices
max_workers: Number of parallel workers (default 1 for sequential).
Currently ignored - reserved for future parallel support.
**kwargs: Passed to run_pipeline_on_case
Returns:
List of PipelineResult, one per case
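
    Example (illustrative sketch; assumes the first two local cases exist):
        >>> results = run_pipeline_on_batch([0, 1], gpu=False)  # doctest: +SKIP
        >>> [r.dice_score for r in results]  # doctest: +SKIP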
"""
# Currently only sequential execution is supported.
# max_workers is accepted for API compatibility but ignored.
_ = max_workers
results: list[PipelineResult] = []
for case_id in case_ids:
result = run_pipeline_on_case(case_id, **kwargs) # type: ignore[arg-type]
results.append(result)
return results
def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
"""
Compute summary statistics from multiple pipeline results.
Returns:
Summary with mean Dice, success rate, etc.
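
    Example (illustrative sketch; `results` comes from run_pipeline_on_batch):
        >>> summary = get_pipeline_summary(results)  # doctest: +SKIP
        >>> print(summary.mean_dice, summary.mean_elapsed_seconds)  # doctest: +SKIP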
"""
# Filter results with valid dice scores
dice_scores = [r.dice_score for r in results if r.dice_score is not None]
elapsed_times = [r.elapsed_seconds for r in results]
num_cases = len(results)
    # Every PipelineResult represents a successful run: failures raise inside
    # run_pipeline_on_case and never yield a result object.
num_successful = num_cases
num_failed = 0
if dice_scores:
mean_dice = statistics.mean(dice_scores)
std_dice = statistics.stdev(dice_scores) if len(dice_scores) > 1 else 0.0
min_dice = min(dice_scores)
max_dice = max(dice_scores)
else:
mean_dice = None
std_dice = None
min_dice = None
max_dice = None
mean_elapsed = statistics.mean(elapsed_times) if elapsed_times else 0.0
return PipelineSummary(
num_cases=num_cases,
num_successful=num_successful,
num_failed=num_failed,
mean_dice=mean_dice,
std_dice=std_dice,
min_dice=min_dice,
max_dice=max_dice,
mean_elapsed_seconds=mean_elapsed,
)
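

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption for illustration, not part of the
    # original API): run the first local case on CPU and log the summary.
    demo_results = run_pipeline_on_batch([0], gpu=False)
    demo_summary = get_pipeline_summary(demo_results)
    logger.info(
        "Ran %d case(s); mean Dice: %s", demo_summary.num_cases, demo_summary.mean_dice
    )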