File size: 9,093 Bytes
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
bfe80c5
3f8bf9c
 
 
 
 
 
 
 
bfe80c5
3f8bf9c
 
 
 
4a455a4
 
 
 
 
3f8bf9c
 
4a455a4
 
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa1717e
 
 
3f8bf9c
26f14be
3f8bf9c
 
 
 
 
 
fa1717e
3f8bf9c
fa1717e
 
 
3f8bf9c
 
 
 
 
 
fa1717e
 
 
 
 
 
 
 
 
 
 
3f8bf9c
 
 
878d2e7
 
fa1717e
 
878d2e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a455a4
 
 
 
 
 
 
 
 
 
 
 
 
 
878d2e7
 
 
 
 
 
4a455a4
 
 
 
 
 
 
 
878d2e7
3f8bf9c
 
 
 
 
 
 
fa1717e
3f8bf9c
 
878d2e7
3f8bf9c
 
 
 
a544a50
 
3f8bf9c
 
 
be12b50
 
 
 
3f8bf9c
 
 
 
 
4a455a4
 
3f8bf9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""End-to-end pipeline orchestration."""

from __future__ import annotations

import shutil
import statistics
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from stroke_deepisles_demo import metrics
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import load_isles_dataset, stage_case_for_deepisles
from stroke_deepisles_demo.inference import run_deepisles_on_folder

if TYPE_CHECKING:
    from collections.abc import Sequence

    from stroke_deepisles_demo.core.types import CaseFiles

logger = get_logger(__name__)


@dataclass(frozen=True)
class PipelineResult:
    """Complete result of running the pipeline on a case.

    All file paths in this result point to valid, accessible files in results_dir.
    Callers are responsible for cleaning up results_dir when done (if desired).
    """

    case_id: str  # Resolved string identifier (integer indices are resolved before construction)
    input_files: CaseFiles  # Copied to results_dir; always valid paths
    results_dir: Path  # Directory containing all result files (for cleanup)
    prediction_mask: Path  # Segmentation mask produced by inference
    ground_truth: Path | None  # Copy of the ground-truth mask, or None if the case has none
    dice_score: float | None  # None if ground truth unavailable or not computed
    elapsed_seconds: float  # Wall-clock time for the full pipeline run


@dataclass(frozen=True)
class PipelineSummary:
    """Summary statistics from multiple pipeline runs."""

    num_cases: int  # Total number of results summarized
    num_successful: int  # Runs that completed (failed runs raise, so this equals num_cases)
    num_failed: int  # Always 0 under the current "failures raise" convention
    mean_dice: float | None  # None when no result carried a Dice score
    std_dice: float | None  # 0.0 for a single score; None when no scores
    min_dice: float | None  # None when no result carried a Dice score
    max_dice: float | None  # None when no result carried a Dice score
    mean_elapsed_seconds: float  # 0.0 when summarizing an empty result list


def run_pipeline_on_case(
    case_id: str | int,
    *,
    dataset_id: str | None = None,
    output_dir: Path | None = None,
    fast: bool | None = None,
    gpu: bool | None = None,
    timeout: float | None = None,
    compute_dice: bool = True,
    cleanup_staging: bool = True,
) -> PipelineResult:
    """
    Run the complete segmentation pipeline on a single case.

    Pipeline stages: resolve case -> stage inputs -> preserve copies in
    results_dir -> run inference -> compute Dice (optional) -> cleanup staging.
    NOTE: input files are copied to results_dir *before* the dataset context
    manager exits, because the loader may hold them in temp dirs it deletes.

    Args:
        case_id: Case identifier (string) or index (int)
        dataset_id: HF dataset ID (default from Settings.hf_dataset_id)
        output_dir: Directory for results (default: temp dir)
        fast: Use SEALS-only mode (default from Settings.deepisles_fast_mode)
        gpu: Use GPU acceleration (default from Settings.deepisles_use_gpu)
        timeout: Maximum inference time in seconds (default from Settings.deepisles_timeout_seconds)
        compute_dice: Compute Dice score if ground truth available
        cleanup_staging: Remove staging directory after inference

    Returns:
        PipelineResult with all paths and optional metrics

    Raises:
        IndexError: If an integer case_id is out of range for the dataset.
    """
    # Local import to avoid import-time settings resolution at module load.
    from stroke_deepisles_demo.core.config import get_settings

    settings = get_settings()

    # Apply settings defaults if not specified
    if fast is None:
        fast = settings.deepisles_fast_mode
    if gpu is None:
        gpu = settings.deepisles_use_gpu
    if timeout is None:
        timeout = settings.deepisles_timeout_seconds

    start_time = time.time()

    # Use context manager to ensure HuggingFace temp files are cleaned up
    # This prevents unbounded disk usage from accumulating temp NIfTI files
    # dataset_id is wired through to loader (defaults to Settings.hf_dataset_id)
    with load_isles_dataset(dataset_id) as dataset:
        # Resolve ID if integer
        if isinstance(case_id, int):
            all_ids = dataset.list_case_ids()
            if case_id < 0 or case_id >= len(all_ids):
                raise IndexError(f"Case index {case_id} out of range (0-{len(all_ids) - 1})")
            resolved_case_id = all_ids[case_id]
        else:
            resolved_case_id = case_id

        # Set up output directories (now that we have resolved_case_id)
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            staging_root = output_dir / "staging" / resolved_case_id
            results_dir = output_dir / resolved_case_id
        else:
            # NOTE(review): base_temp itself is never removed here; callers
            # cleaning only results_dir leave the (mostly empty) parent behind.
            base_temp = Path(tempfile.mkdtemp(prefix="deepisles_pipeline_"))
            staging_root = base_temp / "staging"
            results_dir = base_temp / "results"

        # Get case files
        case_files = dataset.get_case(resolved_case_id)

        # Stage files (copies DWI/ADC to staging directory)
        staged = stage_case_for_deepisles(case_files, staging_root)

        # Copy input files to results_dir before dataset cleanup
        # (HuggingFace mode stores files in temp dirs that get cleaned up)
        # This ensures all paths in PipelineResult remain valid after function returns
        results_dir.mkdir(parents=True, exist_ok=True)

        # Copy DWI (required for UI visualization)
        dwi_dest = results_dir / f"{resolved_case_id}_dwi.nii.gz"
        shutil.copy2(case_files["dwi"], dwi_dest)

        # Copy ADC
        adc_dest = results_dir / f"{resolved_case_id}_adc.nii.gz"
        shutil.copy2(case_files["adc"], adc_dest)

        # Copy ground truth if available
        ground_truth: Path | None = None
        original_ground_truth = case_files.get("ground_truth")
        if original_ground_truth and original_ground_truth.exists():
            ground_truth = results_dir / f"{resolved_case_id}_ground_truth.nii.gz"
            shutil.copy2(original_ground_truth, ground_truth)

        # Build input_files with copied paths (always valid after function returns)
        preserved_input_files: CaseFiles = {
            "dwi": dwi_dest,
            "adc": adc_dest,
        }
        if ground_truth:
            preserved_input_files["ground_truth"] = ground_truth

    # Dataset temp files cleaned up here (context manager __exit__)

    # 3. Run Inference
    # Runs outside the dataset context: it only needs the staged copies.
    inference_result = run_deepisles_on_folder(
        staged.input_dir,
        output_dir=results_dir,
        fast=fast,
        gpu=gpu,
        timeout=timeout,
    )

    # 4. Compute Metrics (using copied ground truth)
    # Best-effort: a metrics failure is logged, not raised, so the run still
    # returns a usable result with dice_score=None.
    dice_score: float | None = None
    if compute_dice and ground_truth and ground_truth.exists():
        try:
            dice_score = metrics.compute_dice(inference_result.prediction_path, ground_truth)
        except Exception:
            logger.warning("Failed to compute Dice score for %s", resolved_case_id, exc_info=True)

    # 5. Cleanup (Optional)
    # Best-effort removal of the staging copies; results_dir is kept for the caller.
    if cleanup_staging:
        try:
            shutil.rmtree(staging_root)
        except OSError as e:
            logger.warning("Failed to cleanup staging directory %s: %s", staging_root, e)

    elapsed = time.time() - start_time

    return PipelineResult(
        case_id=resolved_case_id,
        input_files=preserved_input_files,
        results_dir=results_dir,
        prediction_mask=inference_result.prediction_path,
        ground_truth=ground_truth,
        dice_score=dice_score,
        elapsed_seconds=elapsed,
    )


def run_pipeline_on_batch(
    case_ids: Sequence[str | int],
    *,
    max_workers: int = 1,
    **kwargs: object,
) -> list[PipelineResult]:
    """
    Run the pipeline over several cases, one after another.

    Note: Parallel execution requires multiple GPUs or sequential mode.
    Currently only sequential execution is implemented (max_workers is ignored).

    Args:
        case_ids: List of case identifiers or indices
        max_workers: Number of parallel workers (default 1 for sequential).
                     Currently ignored - reserved for future parallel support.
        **kwargs: Passed to run_pipeline_on_case

    Returns:
        List of PipelineResult, one per case
    """
    # Accepted only for forward API compatibility; no parallelism exists yet.
    del max_workers

    return [
        run_pipeline_on_case(cid, **kwargs)  # type: ignore[arg-type]
        for cid in case_ids
    ]


def get_pipeline_summary(results: Sequence[PipelineResult]) -> PipelineSummary:
    """
    Compute summary statistics from multiple pipeline results.

    Returns:
        Summary with mean Dice, success rate, etc.
    """
    # Only results that actually carry a Dice score contribute to the stats.
    valid_dice = [r.dice_score for r in results if r.dice_score is not None]
    timings = [r.elapsed_seconds for r in results]

    # Failed runs raise instead of returning a result, so every entry here
    # counts as a success by construction.
    total = len(results)

    if not valid_dice:
        mean_d = std_d = low_d = high_d = None
    else:
        mean_d = statistics.mean(valid_dice)
        # stdev needs >= 2 samples; define a lone score's spread as 0.0.
        std_d = statistics.stdev(valid_dice) if len(valid_dice) > 1 else 0.0
        low_d, high_d = min(valid_dice), max(valid_dice)

    avg_elapsed = statistics.mean(timings) if timings else 0.0

    return PipelineSummary(
        num_cases=total,
        num_successful=total,
        num_failed=0,
        mean_dice=mean_d,
        std_dice=std_d,
        min_dice=low_d,
        max_dice=high_d,
        mean_elapsed_seconds=avg_elapsed,
    )