File size: 7,042 Bytes
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
bfe80c5
3f8bf9c
 
 
 
 
 
 
 
bfe80c5
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26f14be
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a544a50
 
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""End-to-end pipeline orchestration."""

from __future__ import annotations

import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder

if TYPE_CHECKING:
    from collections.abc import Sequence

    from stroke_deepisles_demo.core.types import CaseFiles

logger = get_logger(__name__)


@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a single case.

    Instances are immutable (``frozen=True``) and are produced by
    ``run_pipeline_on_case``.
    """

    # Resolved case identifier; an int index passed by the caller is mapped to its ID.
    case_id: str
    # Case file mapping as returned by the dataset's ``get_case``.
    input_files: CaseFiles
    # Staged input directory handed to inference.
    # NOTE(review): likely already deleted when the run used cleanup_staging=True — confirm.
    staged_dir: Path
    # Path of the prediction mask reported by the inference step.
    prediction_mask: Path
    # Ground-truth path taken from the case files, or None if the case has none.
    ground_truth: Path | None
    # None if ground truth unavailable, Dice was not requested, or scoring failed.
    dice_score: float | None  # None if ground truth unavailable or not computed
    # Elapsed time of the whole pipeline run for this case, in seconds.
    elapsed_seconds: float


@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs.

    Built by ``get_pipeline_summary``. All Dice fields are None when none of
    the summarized runs produced a Dice score.
    """

    # Number of results summarized.
    num_cases: int
    # Runs counted as successful; get_pipeline_summary counts every result passed in.
    num_successful: int
    # Runs counted as failed; always 0 from get_pipeline_summary, because a failed
    # run raises instead of producing a result.
    num_failed: int
    # Mean Dice over the results that have a score, else None.
    mean_dice: float | None
    # Sample standard deviation of the Dice scores (0.0 for a single score), else None.
    std_dice: float | None
    # Lowest Dice score, else None.
    min_dice: float | None
    # Highest Dice score, else None.
    max_dice: float | None
    # Mean elapsed_seconds over all results (0.0 when there are no results).
    mean_elapsed_seconds: float


def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Loads the local ISLES dataset, resolves the case, stages its files for
    DeepISLES, runs inference, optionally scores the prediction against the
    ground truth, and optionally removes the staging directory.

    Args:
        case_id: Case identifier (string) or zero-based index (int) into
            ``dataset.list_case_ids()``
        dataset_id: HF dataset ID. Currently ignored: the dataset is always
            loaded from the default local path.
        output_dir: Directory for results. Staging goes to
            ``output_dir/staging/<case_id>`` and results to
            ``output_dir/<case_id>``. When omitted, a fresh temporary directory
            (prefix ``deepisles_pipeline_``) is created. It is NOT removed
            automatically, so the caller owns it.
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove the staging directory when the run ends. This
            also happens if staging or inference raises.

    Returns:
        PipelineResult with all paths and optional metrics. When
        ``cleanup_staging`` is True, ``staged_dir`` presumably refers to a
        directory that has already been removed (it is produced under the
        staging root).

    Raises:
        IndexError: If ``case_id`` is an int outside ``[0, number of cases)``.
        Exceptions from dataset loading, staging, or inference propagate
        unchanged. A Dice computation failure is logged, not raised.
    """
    # Note: dataset_id is currently unused as we default to local loading.
    # It's kept for interface compatibility with future cloud mode.
    _ = dataset_id

    # perf_counter is monotonic, so the measured duration is immune to clock adjustments.
    start_time = time.perf_counter()

    # 1. Load dataset (always the default local path for now).
    dataset = load_isles_dataset()

    # Resolve a positional index into the dataset's case ID.
    if isinstance(case_id, int):
        all_ids = dataset.list_case_ids()
        num_cases = len(all_ids)
        if not 0 <= case_id < num_cases:
            valid_range = f"0-{num_cases - 1}" if num_cases else "dataset has no cases"
            raise IndexError(f"Case index {case_id} out of range ({valid_range})")
        resolved_case_id = all_ids[case_id]
    else:
        resolved_case_id = case_id

    case_files = dataset.get_case(resolved_case_id)

    # 2. Choose staging and results locations.
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        staging_root = output_dir / "staging" / resolved_case_id
        results_dir = output_dir / resolved_case_id
    else:
        # mkdtemp (rather than a TemporaryDirectory context) so the returned
        # result paths outlive this call; the caller must delete base_temp.
        base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
        staging_root = base_temp / "staging"
        results_dir = base_temp / "results"

    ground_truth = case_files.get("ground_truth")
    dice_score: float | None = None

    try:
        staged = stage_case_for_deepisles(case_files, staging_root)

        # 3. Run inference.
        inference_result = run_deepisles_on_folder(
            staged.input_dir,
            output_dir=results_dir,
            fast=fast,
            gpu=gpu,
        )

        # 4. Compute metrics. This is best effort: a scoring failure must not
        # discard an otherwise successful prediction.
        if compute_dice and ground_truth and ground_truth.exists():
            try:
                dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
            except Exception:
                logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)
    finally:
        # 5. Cleanup (optional). Done in `finally` so a failed staging or
        # inference run does not leak the staged copies of the input files.
        if cleanup_staging:
            shutil.rmtree(staging_root, ignore_errors=True)

    elapsed = time.perf_counter() - start_time

    return PipelineResult(
        case_id=resolved_case_id,
        input_files=case_files,
        staged_dir=staged.input_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )


def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run the single-case pipeline over several cases, one after another.

    Cases are processed sequentially, in the order given. The first case that
    raises aborts the batch, and its exception propagates to the caller.

    Args:
        case_ids: Case identifiers (str) or dataset indices (int)
        max_workers: Reserved for future parallel execution (which would need
            multiple GPUs). Accepted for API compatibility but currently ignored.
        **kwargs: Forwarded unchanged to ``run_pipeline_on_case``

    Returns:
        One PipelineResult per entry of ``case_ids``, in the same order
    """
    del max_workers  # intentionally unused until parallel execution exists

    return [
        run_pipeline_on_case(cid, **kwargs)  # type: ignore[arg-type]
        for cid in case_ids
    ]


def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Aggregate a batch of pipeline results into summary statistics.

    Every result passed in counts as a success: failed runs raise instead of
    producing a result, so ``num_failed`` is always 0. Dice statistics cover
    only results that carry a score. They are all None when no result does,
    and the standard deviation is 0.0 when exactly one does.

    Returns:
        PipelineSummary with case counts, Dice statistics and mean runtime.
    """
    total = len(results)
    durations = [res.elapsed_seconds for res in results]
    scored = [res.dice_score for res in results if res.dice_score is not None]

    if not scored:
        dice_mean = dice_std = dice_min = dice_max = None
    else:
        dice_mean = statistics.mean(scored)
        # stdev needs at least two points; a single score has no spread.
        dice_std = 0.0 if len(scored) == 1 else statistics.stdev(scored)
        dice_min, dice_max = min(scored), max(scored)

    return PipelineSummary(
        num_cases=total,
        num_successful=total,
        num_failed=0,
        mean_dice=dice_mean,
        std_dice=dice_std,
        min_dice=dice_min,
        max_dice=dice_max,
        mean_elapsed_seconds=statistics.mean(durations) if durations else 0.0,
    )