Spaces:
Sleeping
Sleeping
| import io | |
| import ipaddress | |
| import os | |
| import socket | |
| from urllib.parse import urlparse | |
| import requests as http_requests | |
| from fastapi import Depends, FastAPI, File, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security import APIKeyHeader | |
| from PIL import Image, UnidentifiedImageError | |
| from app.model import MammogramModel | |
| from app.schemas import AnalyzeRequest, PredictResponse | |
# Hostnames that client-supplied image URLs may point at, read from the
# comma-separated ALLOWED_IMAGE_HOSTS environment variable (for example your
# Supabase storage host). Entries are whitespace-trimmed and lower-cased and
# blank entries are dropped. The URL validator only enforces this allowlist
# when the resulting set is non-empty.
_ALLOWED_HOSTS_ENV = os.getenv("ALLOWED_IMAGE_HOSTS", "")
ALLOWED_HOSTS: set[str] = {
    host
    for host in (entry.strip().lower() for entry in _ALLOWED_HOSTS_ENV.split(","))
    if host
}
# Shared secret expected in the X-API-Key request header. When it is unset
# (empty), the dependency below lets every request through (dev/mock mode).
_API_KEY = os.getenv("API_KEY", "")
# auto_error=False: a missing header reaches the dependency as None, so the
# dependency itself decides how to respond.
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def _require_api_key(key: str | None = Depends(_api_key_header)) -> None:
    """FastAPI dependency that enforces the X-API-Key header.

    Args:
        key: Value of the ``X-API-Key`` header, or ``None`` when it is absent.

    Raises:
        HTTPException: 401 when an API key is configured and ``key`` is
            missing or does not match it.
    """
    import secrets

    if not _API_KEY:
        return  # API_KEY not configured — open in dev/mock mode
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Compare bytes because compare_digest raises
    # TypeError for str arguments containing non-ASCII characters.
    if key is None or not secrets.compare_digest(key.encode(), _API_KEY.encode()):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
app = FastAPI(title="Mammogram Inference API", version="0.1.0")

# Browser origins allowed to make cross-origin calls, taken from the
# comma-separated ALLOWED_ORIGINS environment variable. Entries are trimmed
# and blank ones are dropped, so an unset variable yields an empty list.
_CORS_ORIGINS = [
    origin
    for origin in map(str.strip, os.getenv("ALLOWED_ORIGINS", "").split(","))
    if origin
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_headers=["Content-Type", "X-API-Key"],
    allow_methods=["POST", "GET"],
)

# One model instance, created at import time and shared by all handlers below.
model = MammogramModel()
def health() -> dict:
    """Return an "ok" status along with the loaded model's mode and version."""
    return dict(
        status="ok",
        model_mode=model.mode,
        model_version=model.version,
    )
async def predict(file: UploadFile = File(...)) -> PredictResponse:
    """Run model inference on an uploaded image file.

    Args:
        file: Multipart upload. Its declared content type must start with
            ``image/``.

    Returns:
        The model's prediction wrapped in ``PredictResponse``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image, is
            empty, or cannot be decoded as an image.
    """
    from fastapi.concurrency import run_in_threadpool

    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Upload must be an image file")
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Empty file")
    try:
        image = Image.open(io.BytesIO(data))
        # Image.open only identifies the format. Decode the pixel data now so
        # truncated or corrupt uploads are reported as a 400 here instead of
        # failing later inside the model as an unhandled 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is a subclass of OSError, so it is caught here too.
        raise HTTPException(status_code=400, detail="Invalid image format") from exc
    # model.predict is synchronous. Calling it directly inside this coroutine
    # would block the event loop (and every other request) for the whole
    # inference, so run it in the worker thread pool instead.
    result = await run_in_threadpool(model.predict, image)
    return PredictResponse(**result)
def _validate_url(url: str) -> str:
    """Validate a client-supplied image URL before fetching it, to mitigate SSRF.

    Checks, in order:

    * the scheme is ``https``;
    * the URL has a hostname;
    * the hostname is in ``ALLOWED_HOSTS`` (only when that set is non-empty);
    * every address the hostname resolves to is publicly routable.

    Args:
        url: Absolute URL supplied by the client.

    Returns:
        ``url`` unchanged, so the call can be used inline.

    Raises:
        HTTPException: 400 if any check fails or the hostname cannot be
            resolved.

    Note:
        The hostname is resolved here and again by the HTTP client when it
        fetches, so a DNS answer that changes between the two lookups (DNS
        rebinding) is not covered by this check.
    """
    parsed = urlparse(url)
    if parsed.scheme != "https":
        raise HTTPException(status_code=400, detail="Only HTTPS URLs are allowed")
    hostname = (parsed.hostname or "").lower()
    if not hostname:
        raise HTTPException(status_code=400, detail="Invalid URL")
    if ALLOWED_HOSTS and hostname not in ALLOWED_HOSTS:
        raise HTTPException(status_code=400, detail="Image host not in allowlist")

    try:
        infos = socket.getaddrinfo(hostname, None)
    except OSError as exc:
        raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc

    # Block every resolved address that is not publicly routable.
    for info in infos:
        # Drop any IPv6 zone index (e.g. "fe80::1%eth0") so the address parses.
        addr = str(info[4][0]).split("%", 1)[0]
        try:
            ip = ipaddress.ip_address(addr)
        except ValueError:
            # Fail closed: an address we cannot classify could be internal.
            raise HTTPException(
                status_code=400, detail="URL resolves to a private address"
            ) from None
        # Judge IPv4-mapped IPv6 (::ffff:a.b.c.d) by its embedded IPv4 address.
        if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
            ip = ip.ipv4_mapped
        # is_global rejects private, loopback, link-local, shared (100.64/10)
        # and unspecified ranges. Multicast and reserved are checked
        # explicitly because is_global does not reject all of them.
        if not ip.is_global or ip.is_multicast or ip.is_reserved:
            raise HTTPException(status_code=400, detail="URL resolves to a private address")
    return url
def analyze(body: AnalyzeRequest) -> PredictResponse:
    """Accept a public image URL, download it, and run inference.

    The URL, and every redirect target, is checked with ``_validate_url``
    before any request is sent to it.

    Args:
        body: Request payload; ``body.image_url`` is the image location.

    Returns:
        The model's prediction wrapped in ``PredictResponse``.

    Raises:
        HTTPException: 400 if the URL is rejected, cannot be fetched,
            redirects too many times, or does not contain a decodable image.
            500 if model inference fails.
    """
    import logging
    from urllib.parse import urljoin

    logger = logging.getLogger(__name__)
    max_redirects = 5

    # str() guards against the schema typing image_url as a URL object rather
    # than a plain str. NOTE(review): confirm against app.schemas.AnalyzeRequest.
    url = str(body.image_url)
    try:
        # Follow redirects by hand. With requests' automatic redirect handling,
        # an allowed host could bounce the request to an internal or
        # non-allowlisted host without the SSRF checks being re-run.
        for _ in range(max_redirects + 1):
            _validate_url(url)
            resp = http_requests.get(url, timeout=30, allow_redirects=False)
            if not resp.is_redirect:
                break
            url = urljoin(url, resp.headers["Location"])
        else:
            raise HTTPException(
                status_code=400, detail="Failed to fetch image: too many redirects"
            )
        resp.raise_for_status()
    except http_requests.RequestException as exc:
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {exc}") from exc
    # NOTE(review): resp.content is read without a size cap. Consider
    # streaming with a byte limit if image hosts are not fully trusted.

    try:
        image = Image.open(io.BytesIO(resp.content))
        # Image.open only identifies the format. Decode now so truncated or
        # corrupt payloads surface as a 400 instead of a model-inference 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError is a subclass of OSError, so it is caught here too.
        raise HTTPException(status_code=400, detail="URL did not return a valid image") from exc

    try:
        result = model.predict(image)
        return PredictResponse(**result)
    except Exception as exc:
        logger.exception("Model inference failed")
        raise HTTPException(status_code=500, detail=f"Model inference error: {type(exc).__name__}: {exc}") from exc