"""End-to-end pipeline orchestration."""

from __future__ import annotations

import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder

if TYPE_CHECKING:
    from collections.abc import Sequence

    from stroke_deepisles_demo.core.types import CaseFiles

logger = get_logger(__name__)


@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    All file paths in this result point to valid, accessible files in results_dir.
    Callers are responsible for cleaning up results_dir when done (if desired).
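
    Example:
        A cleanup sketch once the result has been consumed (illustrative):

            shutil.rmtree(result.results_dir, ignore_errors=True)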
    """

    case_id: str
    input_files: CaseFiles  # Copied to results_dir; always valid paths
    results_dir: Path  # Directory containing all result files (for cleanup)
    prediction_mask: Path
    ground_truth: Path | None
    dice_score: float | None  # None if ground truth unavailable or not computed
    elapsed_seconds: float


@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs."""

    num_cases: int
    num_successful: int
    num_failed: int
    mean_dice: float | None
    std_dice: float | None
    min_dice: float | None
    max_dice: float | None
    mean_elapsed_seconds: float


def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool = True,
    gpu: bool = True,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: Hugging Face dataset ID (default from settings).
            Currently ignored; data is loaded locally.
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (ISLES'22 winner, DWI+ADC only, no FLAIR needed)
        gpu: Use GPU acceleration
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference

    Returns:
        PipelineResult with all paths and optional metrics
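
    Example:
        Illustrative sketch; assumes a local ISLES dataset with at least one
        case (the index and flags here are arbitrary):

            result = run_pipeline_on_case(0, fast=True, gpu=False)
            print(result.case_id, result.dice_score)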
    """
    # Note: dataset_id is currently unused as we default to local loading.
    # It's kept for interface compatibility with future cloud mode.
    _ = dataset_id

    start_time = time.time()

    # Use context manager to ensure HuggingFace temp files are cleaned up
    # This prevents unbounded disk usage from accumulating temp NIfTI files
    with load_isles_dataset() as dataset:
        # Resolve ID if integer
        if isinstance(case_id, int):
            all_ids = dataset.list_case_ids()
            if case_id < 0 or case_id >= len(all_ids):
                raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
            resolved_case_id = all_ids[case_id]
        else:
            resolved_case_id = case_id

        # Set up output directories (now that we have resolved_case_id)
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            staging_root = output_dir / "staging" / resolved_case_id
            results_dir = output_dir / resolved_case_id
        else:
            base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
            staging_root = base_temp / "staging"
            results_dir = base_temp / "results"

        # Get case files
        case_files = dataset.get_case(resolved_case_id)

        # Stage files (copies DWI/ADC to staging directory)
        staged = stage_case_for_deepisles(case_files, staging_root)

        # Copy input files to results_dir before dataset cleanup
        # (HuggingFace mode stores files in temp dirs that get cleaned up)
        # This ensures all paths in PipelineResult remain valid after function returns
        results_dir.mkdir(parents=True, exist_ok=True)

        # Copy DWI (required for UI visualization)
        dwi_dest = results_dir / f"{resolved_case_id}_dwi.nii.gz"
        shutil.copy2(case_files["dwi"], dwi_dest)

        # Copy ADC
        adc_dest = results_dir / f"{resolved_case_id}_adc.nii.gz"
        shutil.copy2(case_files["adc"], adc_dest)

        # Copy ground truth if available
        ground_truth: Path | None = None
        original_ground_truth = case_files.get("ground_truth")
        if original_ground_truth and original_ground_truth.exists():
            ground_truth = results_dir / f"{resolved_case_id}_ground_truth.nii.gz"
            shutil.copy2(original_ground_truth, ground_truth)

        # Build input_files with copied paths (always valid after function returns)
        preserved_input_files: CaseFiles = {
            "dwi": dwi_dest,
            "adc": adc_dest,
        }
        if ground_truth:
            preserved_input_files["ground_truth"] = ground_truth

    # Dataset temp files cleaned up here (context manager __exit__)

    # Run inference on the staged inputs
    inference_result = run_deepisles_on_folder(
        staged.input_dir,
        output_dir=results_dir,
        fast=fast,
        gpu=gpu,
    )

    # Compute metrics (using the copied ground truth)
    dice_score: float | None = None
    if compute_dice and ground_truth and ground_truth.exists():
        try:
            dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
        except Exception:
            logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)

    # Optional cleanup of the staging directory
    if cleanup_staging:
        shutil.rmtree(staging_root, ignore_errors=True)

    elapsed = time.time() - start_time

    return PipelineResult(
        case_id=resolved_case_id,
        input_files=preserved_input_files,
        results_dir=results_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )


def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run pipeline on multiple cases.

    Note: Parallel execution would require multiple GPUs; currently only
    sequential execution is implemented, so max_workers is ignored.

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
                     Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case
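
    Example:
        Illustrative sketch; assumes at least two local cases:

            results = run_pipeline_on_batch([0, 1], fast=True, gpu=False)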
    """
    # Currently only sequential execution is supported.
    # max_workers is accepted for API compatibility but ignored.
    _ = max_workers

    results: list[PipelineResult] = []
    for case_id in case_ids:
        result = run_pipeline_on_case(case_id, **kwargs)  # type: ignore[arg-type]
        results.append(result)

    return results


def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Returns:
        Summary with mean Dice, success rate, etc.
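
    Example:
        Illustrative sketch, fed by run_pipeline_on_batch:

            results = run_pipeline_on_batch([0, 1])
            summary = get_pipeline_summary(results)
            print(summary.mean_dice, summary.mean_elapsed_seconds)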
    """
    # Filter results with valid dice scores
    dice_scores = [r.dice_score for r in results if r.dice_score is not None]
    elapsed_times = [r.elapsed_seconds for r in results]

    num_cases = len(results)
    # All results passed in are treated as successful runs; a failed run
    # raises in run_pipeline_on_case and never yields a PipelineResult.
    num_successful = num_cases
    num_failed = 0

    if dice_scores:
        mean_dice = statistics.mean(dice_scores)
        std_dice = statistics.stdev(dice_scores) if len(dice_scores) > 1 else 0.0
        min_dice = min(dice_scores)
        max_dice = max(dice_scores)
    else:
        mean_dice = None
        std_dice = None
        min_dice = None
        max_dice = None

    mean_elapsed = statistics.mean(elapsed_times) if elapsed_times else 0.0

    return PipelineSummary(
        num_cases=num_cases,
        num_successful=num_successful,
        num_failed=num_failed,
        mean_dice=mean_dice,
        std_dice=std_dice,
        min_dice=min_dice,
        max_dice=max_dice,
        mean_elapsed_seconds=mean_elapsed,
    )