"""API route handlers for stroke segmentation.
This module implements an async job queue pattern to handle long-running ML inference:
1. POST /api/segment creates a job and returns immediately (202 Accepted)
2. Background task runs the inference
3. Frontend polls GET /api/jobs/{job_id} for status/results
This pattern avoids HuggingFace Spaces' ~60s gateway timeout.
"""
from __future__ import annotations

import contextlib
import os
import uuid
from pathlib import Path

from fastapi import APIRouter, BackgroundTasks, HTTPException, Request

from stroke_deepisles_demo.api.job_store import JobStatus, get_job_store
from stroke_deepisles_demo.api.schemas import (
    CasesResponse,
    CreateJobResponse,
    JobStatusResponse,
    SegmentRequest,
    SegmentResponse,
)
from stroke_deepisles_demo.core.logging import get_logger
from stroke_deepisles_demo.data import list_case_ids
from stroke_deepisles_demo.metrics import compute_volume_ml
from stroke_deepisles_demo.pipeline import run_pipeline_on_case

logger = get_logger(__name__)
router = APIRouter()

# Base directory for results
RESULTS_BASE = Path("/tmp/stroke-results")
def get_backend_base_url(request: Request) -> str:
    """Get the backend's public URL for building absolute file URLs.

    Priority:
    1. BACKEND_PUBLIC_URL env var (for production HF Spaces)
    2. Request's base URL (for local development)
    """
    env_url = os.environ.get("BACKEND_PUBLIC_URL", "").rstrip("/")
    if env_url:
        return env_url
    return str(request.base_url).rstrip("/")
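# Example behaviour (hypothetical values, not taken from this repo):
#
#     os.environ["BACKEND_PUBLIC_URL"] = "https://user-demo.hf.space/"
#     get_backend_base_url(request)   # -> "https://user-demo.hf.space"
#     del os.environ["BACKEND_PUBLIC_URL"]
#     get_backend_base_url(request)   # -> str(request.base_url) without trailing "/"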
@router.get("/cases", response_model=CasesResponse)
def get_cases() -> CasesResponse:
"""List available cases from dataset.
Note: This is a sync def (not async) because list_case_ids() is synchronous.
FastAPI automatically runs sync endpoints in a threadpool to avoid blocking.
"""
try:
cases = list_case_ids()
return CasesResponse(cases=cases)
except HTTPException:
raise
except Exception:
logger.exception("Failed to list cases")
raise HTTPException(status_code=500, detail="Failed to retrieve cases") from None
@router.post(
    "/segment",
    response_model=CreateJobResponse,
    status_code=202,
    responses={
        202: {"description": "Job created successfully"},
        400: {"description": "Invalid request"},
        500: {"description": "Internal server error"},
    },
)
def create_segment_job(
    request: Request,
    body: SegmentRequest,
    background_tasks: BackgroundTasks,
) -> CreateJobResponse:
    """Create an async segmentation job.

    Returns immediately with a job ID. The actual ML inference runs in the background.
    Poll GET /api/jobs/{job_id} for status updates and results.

    This async pattern is required because:
    - DeepISLES inference takes 30-60 seconds
    - HuggingFace Spaces has a ~60s gateway timeout
    - Returning immediately avoids timeout errors
    """
    try:
        # Use full UUID hex for uniqueness (no truncation)
        job_id = uuid.uuid4().hex
        store = get_job_store()
        backend_url = get_backend_base_url(request)

        # Create job record
        store.create_job(job_id, body.case_id, body.fast_mode)

        # Queue background task
        background_tasks.add_task(
            run_segmentation_job,
            job_id=job_id,
            case_id=body.case_id,
            fast_mode=body.fast_mode,
            backend_url=backend_url,
        )

        # Note: Don't log case_id as it may be sensitive (medical domain)
        logger.info("Created segmentation job %s", job_id)

        return CreateJobResponse(
            jobId=job_id,
            status="pending",
            message=f"Segmentation job queued for {body.case_id}",
        )
    except Exception:
        logger.exception("Failed to create segmentation job")
        raise HTTPException(status_code=500, detail="Failed to create segmentation job") from None
@router.get(
    "/jobs/{job_id}",
    response_model=JobStatusResponse,
    responses={
        200: {"description": "Job status retrieved"},
        404: {"description": "Job not found"},
    },
)
def get_job_status(job_id: str) -> JobStatusResponse:
    """Get the status of a segmentation job.

    Poll this endpoint to track job progress and retrieve results.

    Returns:
        Job status including progress percentage and results when completed.

    Raises:
        404: Job not found (may have expired or never existed)
    """
    store = get_job_store()
    job = store.get_job(job_id)

    if job is None:
        raise HTTPException(
            status_code=404,
            detail=f"Job not found: {job_id}. Jobs expire after 1 hour.",
        )

    # Build response from job data
    response = JobStatusResponse(
        jobId=job.id,
        status=job.status.value,
        progress=job.progress,
        progressMessage=job.progress_message,
        elapsedSeconds=round(job.elapsed_seconds, 2) if job.started_at else None,
        result=None,
        error=None,
    )

    # Include result if completed
    if job.status == JobStatus.COMPLETED and job.result:
        response.result = SegmentResponse(**job.result)

    # Include error if failed
    if job.status == JobStatus.FAILED and job.error:
        response.error = job.error

    return response
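# Illustrative polled payload while a job is running (field names mirror
# JobStatusResponse above; the concrete values and the exact status strings produced
# by JobStatus.*.value are assumptions):
#
#     {"jobId": "3f2a9c...", "status": "running", "progress": 30,
#      "progressMessage": "Running DeepISLES inference...",
#      "elapsedSeconds": 12.5, "result": null, "error": null}
#
# Once the job completes, "result" carries the SegmentResponse fields built in
# run_segmentation_job below; on failure, "error" carries the sanitized message.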
def run_segmentation_job(
    job_id: str,
    case_id: str,
    fast_mode: bool,
    backend_url: str,
) -> None:
    """Execute segmentation in background thread.

    This function runs in a threadpool (not the main event loop) because
    the ML inference is CPU/GPU-bound and blocking.

    Updates job status and progress throughout execution, allowing the
    frontend to show meaningful progress updates.

    Args:
        job_id: Unique job identifier
        case_id: Case to process
        fast_mode: Whether to use fast inference mode
        backend_url: Base URL for constructing result file URLs
    """
    store = get_job_store()
    job = store.get_job(job_id)
    if job is None:
        logger.error("Job %s not found when starting execution", job_id)
        return

    try:
        # Mark as running
        store.start_job(job_id)
        store.update_progress(job_id, 10, "Loading case data...")

        # Set up output directory
        output_dir = RESULTS_BASE / job_id
        store.update_progress(job_id, 20, "Staging files for DeepISLES...")

        # Run the pipeline
        store.update_progress(job_id, 30, "Running DeepISLES inference...")
        result = run_pipeline_on_case(
            case_id,
            output_dir=output_dir,
            fast=fast_mode,
            compute_dice=True,
            cleanup_staging=True,
        )
        store.update_progress(job_id, 85, "Computing metrics...")

        # Compute volume (may fail for edge cases)
        volume_ml = None
        with contextlib.suppress(Exception):
            volume_ml = round(compute_volume_ml(result.prediction_mask, threshold=0.5), 2)

        store.update_progress(job_id, 95, "Preparing results...")

        # Build result data
        dwi_filename = result.input_files["dwi"].name
        pred_filename = result.prediction_mask.name
        file_path_prefix = f"/files/{job_id}/{result.case_id}"
        result_data = {
            "caseId": result.case_id,
            "diceScore": result.dice_score,
            "volumeMl": volume_ml,
            "elapsedSeconds": round(result.elapsed_seconds, 2),
            "dwiUrl": f"{backend_url}{file_path_prefix}/{dwi_filename}",
            "predictionUrl": f"{backend_url}{file_path_prefix}/{pred_filename}",
        }

        # Mark as completed
        store.complete_job(job_id, result_data)
        # As with job creation, avoid logging case_id (may be sensitive in a medical domain)
        logger.info(
            "Job %s completed: dice=%.3f, time=%.1fs",
            job_id,
            result.dice_score or 0,
            result.elapsed_seconds,
        )
    except Exception:
        logger.exception("Job %s failed", job_id)
        # Sanitize error message - don't expose internal details to clients
        store.fail_job(job_id, "Segmentation failed")
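# How these handlers are served is decided by the application factory, which is not
# part of this module. A minimal sketch of the wiring, under the assumptions that the
# router is mounted under /api and RESULTS_BASE is exposed under /files so the
# dwiUrl/predictionUrl values resolve (the real app setup may differ):
#
#     from fastapi import FastAPI
#     from fastapi.staticfiles import StaticFiles
#
#     app = FastAPI()
#     app.include_router(router, prefix="/api")
#     RESULTS_BASE.mkdir(parents=True, exist_ok=True)
#     app.mount("/files", StaticFiles(directory=RESULTS_BASE), name="files")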