Spaces:
Runtime error
Runtime error
| """End-to-end pipeline orchestration.""" | |
| from __future__ import annotations | |
| import shutil | |
| import statistics | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| from stroke_deepisles_demo import metrics | |
| from stroke_deepisles_demo.core.logging import get_logger | |
| from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles | |
| from stroke_deepisles_demo.inference import run_deepisles_on_folder | |
| if TYPE_CHECKING: | |
| from collections.abc import Sequence | |
| from stroke_deepisles_demo.core.types import CaseFiles | |
| logger = get_logger(__name__) | |
@dataclass
class PipelineResult:
    """Complete result of running the pipeline on a case.

    All file paths in this result point to valid, accessible files in
    results_dir. Callers are responsible for cleaning up results_dir
    when done (if desired).
    """

    case_id: str
    input_files: CaseFiles  # Copied to results_dir; always valid paths
    results_dir: Path  # Directory containing all result files (for cleanup)
    prediction_mask: Path
    ground_truth: Path | None
    dice_score: float | None  # None if ground truth unavailable or not computed
    elapsed_seconds: float
@dataclass
class PipelineSummary:
    """Summary statistics from multiple pipeline runs."""

    num_cases: int
    num_successful: int
    num_failed: int
    mean_dice: float | None  # None when no run produced a Dice score
    std_dice: float | None
    min_dice: float | None
    max_dice: float | None
    mean_elapsed_seconds: float
def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from settings - currently ignored/local)
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If an integer case_id is out of range for the dataset.
    """
    # Note: dataset_id is currently unused as we default to local loading.
    # It's kept for interface compatibility with future cloud mode.
    _ = dataset_id
    # perf_counter is monotonic: elapsed time stays correct even if the
    # system clock is adjusted mid-run (time.time() could go backwards).
    start_time = time.perf_counter()
    # Use context manager to ensure HuggingFace temp files are cleaned up
    # This prevents unbounded disk usage from accumulating temp NIfTI files
    with load_isles_dataset() as dataset:
        # Resolve ID if integer
        if isinstance(case_id, int):
            all_ids = dataset.list_case_ids()
            if case_id < 0 or case_id >= len(all_ids):
                raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
            resolved_case_id = all_ids[case_id]
        else:
            resolved_case_id = case_id
        # Set up output directories (now that we have resolved_case_id)
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            staging_root = output_dir / "staging" / resolved_case_id
            results_dir = output_dir / resolved_case_id
        else:
            base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
            staging_root = base_temp / "staging"
            results_dir = base_temp / "results"
        # Get case files
        case_files = dataset.get_case(resolved_case_id)
        # Stage files (copies DWI/ADC to staging directory)
        staged = stage_case_for_deepisles(case_files, staging_root)
        # Copy input files to results_dir before dataset cleanup
        # (HuggingFace mode stores files in temp dirs that get cleaned up)
        # This ensures all paths in PipelineResult remain valid after function returns
        results_dir.mkdir(parents=True, exist_ok=True)
        # Copy DWI (required for UI visualization)
        dwi_dest = results_dir / f"{resolved_case_id}_dwi.nii.gz"
        shutil.copy2(case_files["dwi"], dwi_dest)
        # Copy ADC
        adc_dest = results_dir / f"{resolved_case_id}_adc.nii.gz"
        shutil.copy2(case_files["adc"], adc_dest)
        # Copy ground truth if available
        ground_truth: Path | None = None
        original_ground_truth = case_files.get("ground_truth")
        if original_ground_truth and original_ground_truth.exists():
            ground_truth = results_dir / f"{resolved_case_id}_ground_truth.nii.gz"
            shutil.copy2(original_ground_truth, ground_truth)
        # Build input_files with copied paths (always valid after function returns)
        preserved_input_files: CaseFiles = {
            "dwi": dwi_dest,
            "adc": adc_dest,
        }
        if ground_truth:
            preserved_input_files["ground_truth"] = ground_truth
        # Dataset temp files cleaned up here (context manager __exit__)
    # 3. Run Inference
    inference_result = run_deepisles_on_folder(
        staged.input_dir,
        output_dir=results_dir,
        fast=fast,
        gpu=gpu,
    )
    # 4. Compute Metrics (using copied ground truth)
    dice_score: float | None = None
    if compute_dice and ground_truth and ground_truth.exists():
        try:
            dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
        except Exception:
            # Best-effort: a Dice failure should not discard an otherwise
            # successful segmentation run.
            logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)
    # 5. Cleanup (Optional)
    if cleanup_staging:
        try:
            shutil.rmtree(staging_root)
        except OSError as e:
            logger.warning("Failed to cleanup staging directory %s: %s", staging_root, e)
    elapsed = time.perf_counter() - start_time
    return PipelineResult(
        case_id=resolved_case_id,
        input_files=preserved_input_files,
        results_dir=results_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )
def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run pipeline on multiple cases.

    Note: Parallel execution requires multiple GPUs or sequential mode.
    Currently only sequential execution is implemented (max_workers is ignored).

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
            Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case
    """
    # max_workers is accepted purely for forward API compatibility;
    # execution is strictly sequential for now.
    _ = max_workers
    return [
        run_pipeline_on_case(cid, **kwargs)  # type: ignore[arg-type]
        for cid in case_ids
    ]
def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Returns:
        Summary with mean Dice, success rate, etc.
    """
    # Only results that actually carry a Dice score contribute to the stats.
    scores = [r.dice_score for r in results if r.dice_score is not None]
    durations = [r.elapsed_seconds for r in results]
    total = len(results)

    # Every result passed in is treated as a successful run; failed runs
    # raise exceptions upstream and never reach this function.
    mean_dice: float | None = None
    std_dice: float | None = None
    min_dice: float | None = None
    max_dice: float | None = None
    if scores:
        mean_dice = statistics.mean(scores)
        # stdev needs at least two samples; report 0.0 for a single score.
        std_dice = statistics.stdev(scores) if len(scores) > 1 else 0.0
        min_dice = min(scores)
        max_dice = max(scores)

    return PipelineSummary(
        num_cases=total,
        num_successful=total,
        num_failed=0,
        mean_dice=mean_dice,
        std_dice=std_dice,
        min_dice=min_dice,
        max_dice=max_dice,
        mean_elapsed_seconds=statistics.mean(durations) if durations else 0.0,
    )