File size: 5,048 Bytes
722753e 66404dc 722753e 66404dc f58bb83 66404dc f58bb83 66404dc 722753e 66404dc 722753e f58bb83 722753e 66404dc 722753e 66404dc f58bb83 66404dc f58bb83 66404dc f58bb83 66404dc 722753e 66404dc 722753e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""FastAPI application for stroke segmentation API.

This API provides async ML inference for stroke lesion segmentation using DeepISLES.
It implements a job queue pattern to handle long-running inference without timeouts:

1. POST /api/segment - Creates job, returns immediately (202)
2. GET /api/jobs/{id} - Poll for status/progress/results
3. GET /files/{job_id}/... - Download result NIfTI files

The architecture is designed to work within HuggingFace Spaces constraints:

- ~60s gateway timeout (avoided via async job pattern)
- Single worker (in-memory job store is sufficient)
- /tmp writable only (results stored there)
"""
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from stroke_deepisles_demo.api.job_store import init_job_store
from stroke_deepisles_demo.api.routes import router
from stroke_deepisles_demo.core.logging import get_logger
# Module-level logger, obtained from the project's logging helper under this
# module's name.
logger = get_logger(__name__)
# Results directory (must be in /tmp for HF Spaces)
# Created both at import time and during application startup, handed to the
# job store, and exposed over HTTP through the "/files" static mount.
RESULTS_DIR = Path("/tmp/stroke-results")
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan handler for startup/shutdown tasks.

    Startup:
    - Warn when no CUDA device is visible (skipped if torch is not installed)
    - Create results directory
    - Initialize job store with cleanup scheduler

    Shutdown:
    - Stop cleanup scheduler. This runs from a ``finally`` clause, so it also
      happens when the serving phase ends with an exception or cancellation
      thrown into this generator at the ``yield``.

    Args:
        _app: The FastAPI application. Unused; present because the lifespan
            callable is invoked with the application instance.

    Yields:
        None. Control is handed back to the server once startup has finished
        and returns here when the application is shutting down.
    """
    # Startup
    logger.info("Starting stroke segmentation API...")

    # Check for GPU availability (DeepISLES requires GPU).
    # Only the import is guarded: a missing torch is tolerated, anything else
    # is left to surface normally.
    try:
        import torch
    except ImportError:
        # torch may not be available in all environments
        logger.debug("torch is not installed; skipping GPU availability check")
    else:
        if not torch.cuda.is_available():
            logger.warning(
                "GPU not available! DeepISLES requires GPU for inference. "
                "This Space should be configured with t4-small or better hardware."
            )

    # Create results directory
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    # Initialize job store with cleanup scheduler
    job_store = init_job_store(results_dir=RESULTS_DIR)
    logger.info("Job store initialized with %d jobs", len(job_store))

    try:
        yield
    finally:
        # Shutdown: always stop the scheduler started by init_job_store so it
        # is not left running when the app exits abnormally.
        logger.info("Shutting down stroke segmentation API...")
        job_store.stop_cleanup_scheduler()
# The ASGI application object. Startup and shutdown work is delegated to the
# `lifespan` context manager; title, description and version are passed to
# FastAPI as application metadata.
# NOTE: the same version string is repeated in the root() response payload;
# keep the two in sync when bumping it.
app = FastAPI(
    title="Stroke Segmentation API",
    description="DeepISLES stroke lesion segmentation with async job queue",
    version="2.0.0",
    lifespan=lifespan,
)
# Cross-Origin Resource Policy middleware (required for COEP)
# This must be added BEFORE CORSMiddleware for proper header ordering
class CORPMiddleware(BaseHTTPMiddleware):
    """Stamp every response passing through with ``Cross-Origin-Resource-Policy``.

    A frontend served with COEP (Cross-Origin-Embedder-Policy: require-corp),
    which it needs in order to enable SharedArrayBuffer for WebGL performance
    optimizations, requires this header on the resources it loads from here.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Forward the request downstream and tag the resulting response.

        Args:
            request: The incoming request, passed through unchanged.
            call_next: Callable that runs the rest of the application and
                returns its response.

        Returns:
            The downstream response with the header set to ``cross-origin``,
            overwriting any value already present.
        """
        downstream_response = await call_next(request)
        downstream_response.headers["Cross-Origin-Resource-Policy"] = "cross-origin"
        return downstream_response
# CORS configuration
# FRONTEND_ORIGIN optionally names one extra allowed origin. Explicit origins
# are compared to the request's Origin header as exact strings, and browsers
# send that header without surrounding whitespace or a trailing slash, so the
# configured value is normalized the same way. Otherwise a value such as
# "https://example.org/ " would silently never match.
FRONTEND_ORIGIN = os.environ.get("FRONTEND_ORIGIN", "").strip().rstrip("/")
CORS_ORIGINS = [
    "http://localhost:5173",  # Vite dev server
    "http://localhost:3000",  # Alternative local port
]
# An unset or blank variable adds nothing; a value already listed is not
# duplicated.
if FRONTEND_ORIGIN and FRONTEND_ORIGIN not in CORS_ORIGINS:
    CORS_ORIGINS.append(FRONTEND_ORIGIN)

# Add CORP middleware first (for COEP compatibility)
app.add_middleware(CORPMiddleware)

# Add CORS middleware with strict security settings
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    # Anchored regex: only allow our specific HF Space (security fix for BUG-002)
    # NOTE(review): the pattern carries no explicit ^/$ anchors; it relies on
    # Starlette applying it with a full match against the Origin header.
    # Confirm against the pinned Starlette version.
    allow_origin_regex=r"https://vibecodermcswaggins-stroke-viewer-frontend\.hf\.space",
    allow_credentials=False,  # Not needed - no cookies/auth
    allow_methods=["GET", "POST"],  # Only methods we use
    allow_headers=["Content-Type"],  # Only headers we need
)
# API routes
# Every route declared on `router` is registered under the "/api" prefix.
app.include_router(router, prefix="/api")
# Static files for NIfTI results
# Note: Mount happens at import time; ensure directory exists here as well.
# The lifespan handler repeats this mkdir at startup; exist_ok=True makes both
# calls safe to repeat.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
# Serve the contents of RESULTS_DIR under the "/files" URL prefix.
app.mount("/files", StaticFiles(directory=str(RESULTS_DIR)), name="files")
@app.get("/")
async def root() -> dict[str, Any]:
    """Report basic liveness together with static service metadata.

    Returns:
        A dictionary with the fixed status, service name, version and the
        list of supported feature flags.
    """
    supported_features = ["async-jobs", "progress-tracking"]
    service_info: dict[str, Any] = {
        "status": "healthy",
        "service": "stroke-segmentation-api",
        "version": "2.0.0",
        "features": supported_features,
    }
    return service_info
@app.get("/health")
async def health() -> dict[str, Any]:
    """Report liveness plus job-store and results-directory details.

    Returns:
        A dictionary with the fixed status, the number of jobs currently held
        by the job store, the results directory path as a string, and whether
        that directory exists on disk right now.
    """
    # Imported inside the handler, as in the original.
    from stroke_deepisles_demo.api.job_store import get_job_store

    job_count = len(get_job_store())
    results_path = str(RESULTS_DIR)
    directory_present = RESULTS_DIR.exists()

    report: dict[str, Any] = {
        "status": "healthy",
        "jobs_in_memory": job_count,
        "results_dir": results_path,
        "results_dir_exists": directory_present,
    }
    return report
|