File size: 8,504 Bytes
3f8bf9c bfe80c5 3f8bf9c bfe80c5 3f8bf9c 4a455a4 3f8bf9c 4a455a4 3f8bf9c 26f14be 3f8bf9c 878d2e7 4a455a4 878d2e7 4a455a4 878d2e7 3f8bf9c 878d2e7 3f8bf9c a544a50 3f8bf9c 4a455a4 3f8bf9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
"""End-to-end pipeline orchestration."""
from __future__ import annotations
import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder
if TYPE_CHECKING:
from collections.abc import Sequence
from stroke_deepisles_demo.core.types import CaseFiles
logger = get_logger(__name__)
@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    All file paths in this result point to valid, accessible files in results_dir.
    Callers are responsible for cleaning up results_dir when done (if desired).
    """

    case_id: str  # Resolved string identifier of the processed case
    input_files: CaseFiles  # Copied to results_dir; always valid paths
    results_dir: Path  # Directory containing all result files (for cleanup)
    prediction_mask: Path  # Segmentation mask produced by DeepISLES inference
    ground_truth: Path | None  # Copy of the reference mask, or None if the case has none
    dice_score: float | None  # None if ground truth unavailable or not computed
    elapsed_seconds: float  # Total pipeline duration for this case
@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs."""

    num_cases: int  # Total number of results summarized
    num_successful: int  # Runs counted as successful (failed runs raise and never appear)
    num_failed: int  # Always 0 under the current error model; kept for API completeness
    mean_dice: float | None  # None when no result carried a Dice score
    std_dice: float | None  # 0.0 with a single score; None with no scores
    min_dice: float | None  # None when no result carried a Dice score
    max_dice: float | None  # None when no result carried a Dice score
    mean_elapsed_seconds: float  # 0.0 when summarizing an empty result list
def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from settings - currently ignored/local)
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If case_id is an integer outside the dataset's index range.
    """
    # Note: dataset_id is currently unused as we default to local loading.
    # It's kept for interface compatibility with future cloud mode.
    _ = dataset_id
    # Monotonic clock: elapsed measurement is immune to wall-clock jumps
    # (NTP corrections, DST, manual changes) during long inference runs.
    start_time = time.monotonic()

    # Use context manager to ensure HuggingFace temp files are cleaned up.
    # This prevents unbounded disk usage from accumulating temp NIfTI files.
    with load_isles_dataset() as dataset:
        # Resolve ID if integer
        if isinstance(case_id, int):
            all_ids = dataset.list_case_ids()
            if case_id < 0 or case_id >= len(all_ids):
                raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
            resolved_case_id = all_ids[case_id]
        else:
            resolved_case_id = case_id

        # Set up output directories (now that we have resolved_case_id)
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            staging_root = output_dir / "staging" / resolved_case_id
            results_dir = output_dir / resolved_case_id
        else:
            base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
            staging_root = base_temp / "staging"
            results_dir = base_temp / "results"

        # Get case files
        case_files = dataset.get_case(resolved_case_id)

        # Stage files (copies DWI/ADC to staging directory)
        staged = stage_case_for_deepisles(case_files, staging_root)

        # Copy input files to results_dir before dataset cleanup
        # (HuggingFace mode stores files in temp dirs that get cleaned up).
        # This ensures all paths in PipelineResult remain valid after function returns.
        results_dir.mkdir(parents=True, exist_ok=True)

        # Copy DWI (required for UI visualization)
        dwi_dest = results_dir / f"{resolved_case_id}_dwi.nii.gz"
        shutil.copy2(case_files["dwi"], dwi_dest)

        # Copy ADC
        adc_dest = results_dir / f"{resolved_case_id}_adc.nii.gz"
        shutil.copy2(case_files["adc"], adc_dest)

        # Copy ground truth if available
        ground_truth: Path | None = None
        original_ground_truth = case_files.get("ground_truth")
        if original_ground_truth and original_ground_truth.exists():
            ground_truth = results_dir / f"{resolved_case_id}_ground_truth.nii.gz"
            shutil.copy2(original_ground_truth, ground_truth)

        # Build input_files with copied paths (always valid after function returns)
        preserved_input_files: CaseFiles = {
            "dwi": dwi_dest,
            "adc": adc_dest,
        }
        if ground_truth:
            preserved_input_files["ground_truth"] = ground_truth
        # Dataset temp files cleaned up here (context manager __exit__)

    # 3. Run Inference
    inference_result = run_deepisles_on_folder(
        staged.input_dir,
        output_dir=results_dir,
        fast=fast,
        gpu=gpu,
    )

    # 4. Compute Metrics (using copied ground truth)
    dice_score: float | None = None
    if compute_dice and ground_truth and ground_truth.exists():
        try:
            dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
        except Exception:
            # Best-effort metric: a failed Dice computation must not fail the run.
            logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)

    # 5. Cleanup (Optional)
    if cleanup_staging:
        shutil.rmtree(staging_root, ignore_errors=True)

    elapsed = time.monotonic() - start_time
    return PipelineResult(
        case_id=resolved_case_id,
        input_files=preserved_input_files,
        results_dir=results_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )
def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run pipeline on multiple cases.

    Note: Parallel execution requires multiple GPUs or sequential mode.
    Currently only sequential execution is implemented (max_workers is ignored).

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
            Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case
    """
    # max_workers is accepted purely for API compatibility; execution is
    # strictly sequential until multi-GPU scheduling lands.
    _ = max_workers
    return [
        run_pipeline_on_case(cid, **kwargs)  # type: ignore[arg-type]
        for cid in case_ids
    ]
def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Returns:
        Summary with mean Dice, success rate, etc.
    """
    scores = [r.dice_score for r in results if r.dice_score is not None]
    durations = [r.elapsed_seconds for r in results]
    total = len(results)

    # Failed runs raise before producing a PipelineResult, so every entry
    # handed to this function counts as successful.
    return PipelineSummary(
        num_cases=total,
        num_successful=total,
        num_failed=0,
        mean_dice=statistics.mean(scores) if scores else None,
        # stdev needs at least two samples; a lone score gets 0.0 spread.
        std_dice=(statistics.stdev(scores) if len(scores) > 1 else 0.0) if scores else None,
        min_dice=min(scores) if scores else None,
        max_dice=max(scores) if scores else None,
        mean_elapsed_seconds=statistics.mean(durations) if durations else 0.0,
    )
|