"""API route handlers for stroke segmentation.

This module implements an async job queue pattern to handle long-running ML inference:
1. POST /api/segment creates a job and returns immediately (202 Accepted)
2. Background task runs the inference
3. Frontend polls GET /api/jobs/{job_id} for status/results

This pattern avoids HuggingFace Spaces' ~60s gateway timeout.
"""

from __future__ import annotations

import contextlib
import os
import uuid
from pathlib import Path

from fastapi import APIRouter, BackgroundTasks, HTTPException, Request

from stroke_deepisles_demo.api.job_store import JobStatus, get_job_store
from stroke_deepisles_demo.api.schemas import (
    CasesResponse,
    CreateJobResponse,
    JobStatusResponse,
    SegmentRequest,
    SegmentResponse,
)
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import list_case_ids
from stroke_deepisles_demo.metrics import compute_volume_ml
from stroke_deepisles_demo.pipeline import run_pipeline_on_case

logger = get_logger(__name__)

router = APIRouter()

# Base directory for results
RESULTS_BASE = Path("/tmp/stroke-results")


def get_backend_base_url(request: Request) -> str:
    """Get the backend's public URL for building absolute file URLs.

    Priority:
    1. BACKEND_PUBLIC_URL env var (for production HF Spaces)
    2. Request's base URL (for local development)
    """
    env_url = os.environ.get("BACKEND_PUBLIC_URL", "").rstrip("/")
    if env_url:
        return env_url
    return str(request.base_url).rstrip("/")
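
# Example (illustrative values): with BACKEND_PUBLIC_URL set to
# "https://user-space.hf.space", the file URLs built in run_segmentation_job()
# become absolute, e.g.
# "https://user-space.hf.space/files/<job_id>/<case_id>/<filename>".
# In local development, the request's own base URL
# (e.g. "http://127.0.0.1:8000") is used instead.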


@router.get("/cases", response_model=CasesResponse)
def get_cases() -> CasesResponse:
    """List available cases from dataset.

    Note: This is a sync def (not async) because list_case_ids() is synchronous.
    FastAPI automatically runs sync endpoints in a threadpool to avoid blocking.
    """
    try:
        cases = list_case_ids()
        return CasesResponse(cases=cases)
    except HTTPException:
        raise
    except Exception:
        logger.exception("Failed to list cases")
        raise HTTPException(status_code=500, detail="Failed to retrieve cases") from None
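
# Example response body (illustrative case ids):
#     {"cases": ["case-0001", "case-0002"]}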


@router.post(
    "/segment",
    response_model=CreateJobResponse,
    status_code=202,
    responses={
        202: {"description": "Job created successfully"},
        400: {"description": "Invalid request"},
        500: {"description": "Internal server error"},
    },
)
def create_segment_job(
    request: Request,
    body: SegmentRequest,
    background_tasks: BackgroundTasks,
) -> CreateJobResponse:
    """Create an async segmentation job.

    Returns immediately with a job ID. The actual ML inference runs in the background.
    Poll GET /api/jobs/{job_id} for status updates and results.

    This async pattern is required because:
    - DeepISLES inference takes 30-60 seconds
    - HuggingFace Spaces has a ~60s gateway timeout
    - Returning immediately avoids timeout errors
    """
    try:
        # Use full UUID hex for uniqueness (no truncation)
        job_id = uuid.uuid4().hex
        store = get_job_store()
        backend_url = get_backend_base_url(request)

        # Create job record
        store.create_job(job_id, body.case_id, body.fast_mode)

        # Queue background task
        background_tasks.add_task(
            run_segmentation_job,
            job_id=job_id,
            case_id=body.case_id,
            fast_mode=body.fast_mode,
            backend_url=backend_url,
        )

        # Note: Don't log case_id as it may be sensitive (medical domain)
        logger.info("Created segmentation job %s", job_id)

        return CreateJobResponse(
            jobId=job_id,
            status="pending",
            message=f"Segmentation job queued for {body.case_id}",
        )

    except Exception:
        logger.exception("Failed to create segmentation job")
        raise HTTPException(status_code=500, detail="Failed to create segmentation job") from None


@router.get(
    "/jobs/{job_id}",
    response_model=JobStatusResponse,
    responses={
        200: {"description": "Job status retrieved"},
        404: {"description": "Job not found"},
    },
)
def get_job_status(job_id: str) -> JobStatusResponse:
    """Get the status of a segmentation job.

    Poll this endpoint to track job progress and retrieve results.

    Returns:
        Job status including progress percentage and results when completed.

    Raises:
        404: Job not found (may have expired or never existed)
    """
    store = get_job_store()
    job = store.get_job(job_id)

    if job is None:
        raise HTTPException(
            status_code=404,
            detail=f"Job not found: {job_id}. Jobs expire after 1 hour.",
        )

    # Build response from job data
    response = JobStatusResponse(
        jobId=job.id,
        status=job.status.value,
        progress=job.progress,
        progressMessage=job.progress_message,
        elapsedSeconds=round(job.elapsed_seconds, 2) if job.started_at else None,
        result=None,
        error=None,
    )

    # Include result if completed
    if job.status == JobStatus.COMPLETED and job.result:
        response.result = SegmentResponse(**job.result)

    # Include error if failed
    if job.status == JobStatus.FAILED and job.error:
        response.error = job.error

    return response
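
# A typical poll sequence against this endpoint (payload shapes are
# illustrative; the authoritative fields come from JobStatusResponse and
# SegmentResponse):
#
#     GET /api/jobs/<job_id> -> {"jobId": "...", "status": "running",
#                                "progress": 30,
#                                "progressMessage": "Running DeepISLES inference...",
#                                "result": null, "error": null}
#     GET /api/jobs/<job_id> -> {"jobId": "...", "status": "completed",
#                                "result": {"caseId": "...", "diceScore": ...,
#                                           "dwiUrl": "...", "predictionUrl": "..."}}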


def run_segmentation_job(
    job_id: str,
    case_id: str,
    fast_mode: bool,
    backend_url: str,
) -> None:
    """Execute segmentation in background thread.

    This function runs in a threadpool (not the main event loop) because
    the ML inference is CPU/GPU-bound and blocking.

    Updates job status and progress throughout execution, allowing the
    frontend to show meaningful progress updates.

    Args:
        job_id: Unique job identifier
        case_id: Case to process
        fast_mode: Whether to use fast inference mode
        backend_url: Base URL for constructing result file URLs
    """
    store = get_job_store()
    job = store.get_job(job_id)

    if job is None:
        logger.error("Job %s not found when starting execution", job_id)
        return

    try:
        # Mark as running
        store.start_job(job_id)
        store.update_progress(job_id, 10, "Loading case data...")

        # Set up output directory
        output_dir = RESULTS_BASE / job_id

        store.update_progress(job_id, 20, "Staging files for DeepISLES...")

        # Run the pipeline
        store.update_progress(job_id, 30, "Running DeepISLES inference...")

        result = run_pipeline_on_case(
            case_id,
            output_dir=output_dir,
            fast=fast_mode,
            compute_dice=True,
            cleanup_staging=True,
        )

        store.update_progress(job_id, 85, "Computing metrics...")

        # Compute volume (may fail for edge cases)
        volume_ml = None
        with contextlib.suppress(Exception):
            volume_ml = round(compute_volume_ml(result.prediction_mask, threshold=0.5), 2)

        store.update_progress(job_id, 95, "Preparing results...")

        # Build result data
        dwi_filename = result.input_files["dwi"].name
        pred_filename = result.prediction_mask.name
        file_path_prefix = f"/files/{job_id}/{result.case_id}"

        result_data = {
            "caseId": result.case_id,
            "diceScore": result.dice_score,
            "volumeMl": volume_ml,
            "elapsedSeconds": round(result.elapsed_seconds, 2),
            "dwiUrl": f"{backend_url}{file_path_prefix}/{dwi_filename}",
            "predictionUrl": f"{backend_url}{file_path_prefix}/{pred_filename}",
        }

        # Mark as completed
        store.complete_job(job_id, result_data)

        logger.info(
            "Job %s completed: case=%s, dice=%.3f, time=%.1fs",
            job_id,
            case_id,
            result.dice_score or 0,
            result.elapsed_seconds,
        )

    except Exception:
        logger.exception("Job %s failed", job_id)
        # Sanitize error message - don't expose internal details to clients
        store.fail_job(job_id, "Segmentation failed")