File size: 10,070 Bytes
722753e
 
 
 
 
 
 
 
 
 
 
66404dc
 
 
722753e
66404dc
722753e
 
 
 
 
 
 
 
ba32591
722753e
66404dc
 
 
 
722753e
 
66404dc
 
 
 
 
 
 
ba32591
66404dc
 
ba32591
 
 
66404dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722753e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52703e6
fa1717e
52703e6
8290bc9
 
 
 
 
 
 
 
 
633a315
 
 
 
 
 
 
722753e
 
 
 
fa1717e
8290bc9
fa1717e
 
 
 
 
 
 
 
722753e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52703e6
 
 
 
722753e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66404dc
722753e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66404dc
722753e
66404dc
722753e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66404dc
722753e
 
 
 
 
 
 
66404dc
722753e
 
 
 
 
ba32591
722753e
 
 
 
 
66404dc
fa1717e
66404dc
722753e
66404dc
722753e
fa1717e
66404dc
 
 
 
722753e
 
633a315
66404dc
633a315
66404dc
633a315
 
 
 
 
 
66404dc
722753e
 
 
66404dc
 
722753e
 
 
 
 
 
 
 
 
a3c22e6
722753e
66404dc
722753e
 
66404dc
52703e6
722753e
52703e6
722753e
 
 
66404dc
722753e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""API route handlers for stroke segmentation.

This module implements an async job queue pattern to handle long-running ML inference:
1. POST /api/segment creates a job and returns immediately (202 Accepted)
2. Background task runs the inference
3. Frontend polls GET /api/jobs/{job_id} for status/results

This pattern avoids HuggingFace Spaces' ~60s gateway timeout.
"""

from __future__ import annotations

import uuid

from fastapi import APIRouter, BackgroundTasks, HTTPException, Request

from stroke_deepisles_demo.api.job_store import JobStatus, get_job_store
from stroke_deepisles_demo.api.schemas import (
    CasesResponse,
    CreateJobResponse,
    JobStatusResponse,
    SegmentRequest,
    SegmentResponse,
)
from stroke_deepisles_demo.core.config import get_settings
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import list_case_ids
from stroke_deepisles_demo.metrics import compute_volume_ml
from stroke_deepisles_demo.pipeline import run_pipeline_on_case

logger = get_logger(__name__)

router = APIRouter()


def get_backend_base_url(request: Request) -> str:
    """Resolve the backend's public base URL for building absolute file URLs.

    An explicitly configured ``backend_public_url`` setting (env var / config)
    takes precedence; otherwise the URL the current request arrived on is used,
    which is the right answer for local development. Any trailing slash is
    stripped so callers can safely append paths.
    """
    configured = get_settings().backend_public_url
    base = configured if configured else str(request.base_url)
    return base.rstrip("/")


@router.get("/cases", response_model=CasesResponse)
def get_cases() -> CasesResponse:
    """List available cases from dataset.

    Note: This is a sync def (not async) because list_case_ids() is synchronous.
    FastAPI automatically runs sync endpoints in a threadpool to avoid blocking.
    """
    try:
        cases = list_case_ids()
        return CasesResponse(cases=cases)
    except HTTPException:
        raise
    except Exception:
        logger.exception("Failed to list cases")
        raise HTTPException(status_code=500, detail="Failed to retrieve cases") from None


@router.post(
    "/segment",
    response_model=CreateJobResponse,
    status_code=202,
    responses={
        202: {"description": "Job created successfully"},
        400: {"description": "Invalid request"},
        500: {"description": "Internal server error"},
    },
)
def create_segment_job(
    request: Request,
    body: SegmentRequest,
    background_tasks: BackgroundTasks,
) -> CreateJobResponse:
    """Create an async segmentation job.

    Responds immediately with a job ID while the ML inference runs in the
    background; clients poll GET /api/jobs/{jobId} for progress and results.

    The async pattern exists because DeepISLES inference takes 30-60 seconds
    and HuggingFace Spaces enforces a ~60s gateway timeout — a synchronous
    endpoint would time out before inference finished.
    """

    def _busy_error() -> HTTPException:
        # Same 503 is raised from both the cheap pre-check and the
        # authoritative atomic check below.
        return HTTPException(
            status_code=503,
            detail="Server busy: too many active jobs. Please try again later.",
        )

    try:
        job_store = get_job_store()
        app_settings = get_settings()

        # Cheap, non-authoritative concurrency pre-check so we can reject
        # before doing the more expensive case validation. The atomic check
        # below is the one that actually enforces the limit.
        if job_store.get_active_job_count() >= app_settings.max_concurrent_jobs:
            raise _busy_error()

        # Only validate the case_id once the pre-check has passed.
        if body.case_id not in list_case_ids():
            raise HTTPException(
                status_code=400,
                detail=f"Invalid case ID: '{body.case_id}'. Use GET /api/cases for available cases.",
            )

        # Full (untruncated) UUID hex keeps collision risk negligible.
        new_job_id = uuid.uuid4().hex
        backend_url = get_backend_base_url(request)

        # Atomic limit-check + creation closes the TOCTOU window left open by
        # the pre-check above.
        created = job_store.create_job_if_under_limit(
            new_job_id, body.case_id, body.fast_mode, app_settings.max_concurrent_jobs
        )
        if created is None:
            raise _busy_error()

        # Hand the actual inference off to a background task.
        background_tasks.add_task(
            run_segmentation_job,
            job_id=new_job_id,
            case_id=body.case_id,
            fast_mode=body.fast_mode,
            backend_url=backend_url,
        )

        # Note: Don't log case_id as it may be sensitive (medical domain)
        logger.info("Created segmentation job %s", new_job_id)

        return CreateJobResponse(
            jobId=new_job_id,
            status="pending",
            message=f"Segmentation job queued for {body.case_id}",
        )

    except HTTPException:
        # Re-raise deliberate HTTP errors (400, 503, ...) untouched so the
        # catch-all below doesn't flatten them into a 500.
        raise
    except Exception:
        logger.exception("Failed to create segmentation job")
        raise HTTPException(status_code=500, detail="Failed to create segmentation job") from None


@router.get(
    "/jobs/{job_id}",
    response_model=JobStatusResponse,
    responses={
        200: {"description": "Job status retrieved"},
        404: {"description": "Job not found"},
    },
)
def get_job_status(job_id: str) -> JobStatusResponse:
    """Report the current state of a segmentation job.

    Clients poll this endpoint for progress, and receive the result payload
    once the job completes (or the error message if it failed).

    Returns:
        Job status including progress percentage and results when completed.

    Raises:
        404: Job not found (may have expired or never existed)
    """
    job = get_job_store().get_job(job_id)
    if job is None:
        raise HTTPException(
            status_code=404,
            detail=f"Job not found: {job_id}. Jobs expire after 1 hour.",
        )

    # Elapsed time only makes sense once the job has actually started.
    elapsed = round(job.elapsed_seconds, 2) if job.started_at else None

    # Result payload is attached only for completed jobs with data;
    # error text only for failed jobs with a recorded message.
    result_payload = (
        SegmentResponse(**job.result)
        if job.status == JobStatus.COMPLETED and job.result
        else None
    )
    error_message = job.error if job.status == JobStatus.FAILED and job.error else None

    return JobStatusResponse(
        jobId=job.id,
        status=job.status.value,
        progress=job.progress,
        progressMessage=job.progress_message,
        elapsedSeconds=elapsed,
        result=result_payload,
        error=error_message,
    )


def run_segmentation_job(
    job_id: str,
    case_id: str,
    fast_mode: bool,
    backend_url: str,
) -> None:
    """Run the segmentation pipeline for one job in the background.

    Executes in a threadpool rather than the event loop, since the ML
    inference is blocking and CPU/GPU-bound. Job status and progress are
    pushed to the job store along the way so polling clients see updates.

    Args:
        job_id: Unique job identifier
        case_id: Case to process
        fast_mode: Whether to use fast inference mode
        backend_url: Base URL for constructing result file URLs
    """
    store = get_job_store()
    if store.get_job(job_id) is None:
        logger.error("Job %s not found when starting execution", job_id)
        return

    try:
        store.start_job(job_id)
        store.update_progress(job_id, 10, "Loading case data...")

        # Each job writes into its own subdirectory of the results dir.
        job_output_dir = get_settings().results_dir / job_id

        store.update_progress(job_id, 20, "Staging files for DeepISLES...")
        store.update_progress(job_id, 30, "Running DeepISLES inference...")

        # gpu and timeout fall back to Settings defaults inside the pipeline.
        pipeline_result = run_pipeline_on_case(
            case_id,
            output_dir=job_output_dir,
            fast=fast_mode,
            compute_dice=True,
            cleanup_staging=True,
        )

        store.update_progress(job_id, 85, "Computing metrics...")

        # Volume computation is best-effort: failures are logged but must not
        # fail the whole job (BUG-011 fix).
        volume_ml = None
        try:
            volume_ml = round(
                compute_volume_ml(pipeline_result.prediction_mask, threshold=0.5), 2
            )
        except (FileNotFoundError, ValueError) as e:
            # Expected failures: missing mask file or invalid threshold
            logger.warning("Could not compute volume for job %s: %s", job_id, e)
        except Exception:
            # Unexpected failures - log full traceback for debugging
            logger.exception("Unexpected error computing volume for job %s", job_id)

        store.update_progress(job_id, 95, "Preparing results...")

        # Assemble the client-facing result payload, with absolute file URLs.
        url_prefix = (
            f"{backend_url}/files/{job_id}/{pipeline_result.case_id}"
        )
        store.complete_job(
            job_id,
            {
                "caseId": pipeline_result.case_id,
                "diceScore": pipeline_result.dice_score,
                "volumeMl": volume_ml,
                "elapsedSeconds": round(pipeline_result.elapsed_seconds, 2),
                "dwiUrl": f"{url_prefix}/{pipeline_result.input_files['dwi'].name}",
                "predictionUrl": f"{url_prefix}/{pipeline_result.prediction_mask.name}",
                "warning": "Results are temporary and will be lost if the Space restarts. Download promptly.",
            },
        )

        # Note: Don't log case_id as it may be sensitive (medical domain)
        logger.info(
            "Job %s completed: dice=%.3f, time=%.1fs",
            job_id,
            pipeline_result.dice_score or 0,
            pipeline_result.elapsed_seconds,
        )

    except Exception:
        logger.exception("Job %s failed", job_id)
        # Sanitize error message - don't expose internal details to clients
        store.fail_job(job_id, "Segmentation failed")