"""Mammogram inference API: upload (`/predict`) and URL-fetch (`/analyze`) endpoints."""

import io
import ipaddress
import logging
import os
import secrets
import socket
import time
from urllib.parse import urljoin, urlparse

import requests as http_requests
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from PIL import Image, UnidentifiedImageError

from app.model import MammogramModel
from app.schemas import AnalyzeRequest, PredictResponse

logger = logging.getLogger(__name__)

# Comma-separated list of allowed URL hostnames (e.g. your Supabase storage host).
# When empty, any host is accepted as long as it resolves to public addresses.
_ALLOWED_HOSTS_ENV = os.getenv("ALLOWED_IMAGE_HOSTS", "")
ALLOWED_HOSTS: set[str] = {
    h.strip().lower() for h in _ALLOWED_HOSTS_ENV.split(",") if h.strip()
}

# Upper bound on image bytes accepted from an upload or a remote fetch, so a
# single request cannot exhaust process memory. Override via env.
MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", str(50 * 1024 * 1024)))

_FETCH_TIMEOUT_S = 30  # per connect/read timeout handed to requests
_FETCH_TOTAL_TIMEOUT_S = 120  # wall-clock budget for one whole remote download
_MAX_REDIRECTS = 3  # redirect hops followed (each one re-validated)
_CHUNK_SIZE = 64 * 1024

_API_KEY = os.getenv("API_KEY", "")
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

if not _API_KEY:
    logger.warning("API_KEY is not set; /predict and /analyze are unauthenticated")


def _require_api_key(key: str | None = Depends(_api_key_header)) -> None:
    """Reject the request with 401 unless ``X-API-Key`` matches ``API_KEY``.

    When ``API_KEY`` is unset the check is skipped entirely (dev/mock mode).

    Raises:
        HTTPException: 401 if the header is missing or does not match.
    """
    if not _API_KEY:
        return  # API_KEY not configured — open in dev/mock mode
    # Constant-time comparison so response timing does not leak the key.
    if key is None or not secrets.compare_digest(key.encode(), _API_KEY.encode()):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")


app = FastAPI(title="Mammogram Inference API", version="0.1.0")

_CORS_ORIGINS = [
    o.strip() for o in os.getenv("ALLOWED_ORIGINS", "").split(",") if o.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ORIGINS,
    allow_methods=["POST", "GET"],
    allow_headers=["Content-Type", "X-API-Key"],
)

model = MammogramModel()


def _open_image(data: bytes, invalid_detail: str) -> Image.Image:
    """Decode *data* into a fully loaded PIL image.

    ``Image.open`` is lazy, so ``load()`` is called here to surface truncated or
    corrupt data (and decompression bombs) as a client error instead of letting
    it blow up later inside the model.

    Raises:
        HTTPException: 400 with *invalid_detail* if the bytes are not a decodable image.
    """
    try:
        image = Image.open(io.BytesIO(data))
        image.load()
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
        raise HTTPException(status_code=400, detail=invalid_detail) from exc
    return image


def _infer(data: bytes, invalid_detail: str) -> PredictResponse:
    """Decode *data* and run the model on it (blocking; call off the event loop).

    Raises:
        HTTPException: 400 for undecodable input; 500 if inference fails. The
            500 detail is generic — the underlying error is only logged, so
            internals are not exposed to clients.
    """
    image = _open_image(data, invalid_detail)
    try:
        result = model.predict(image)
        return PredictResponse(**result)
    except Exception as exc:
        logger.exception("Model inference failed")
        raise HTTPException(status_code=500, detail="Model inference failed") from exc


@app.get("/health")
def health() -> dict:
    """Report liveness plus the loaded model's mode and version."""
    return {"status": "ok", "model_mode": model.mode, "model_version": model.version}


@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(_require_api_key)])
async def predict(file: UploadFile = File(...)) -> PredictResponse:
    """Run inference on an uploaded image file.

    Raises:
        HTTPException: 400 for non-image content type, empty or undecodable
            file; 413 if larger than ``MAX_IMAGE_BYTES``; 500 on inference failure.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Upload must be an image file")
    # Read one byte past the cap so oversize uploads are detected without
    # pulling an unbounded amount into memory.
    data = await file.read(MAX_IMAGE_BYTES + 1)
    if not data:
        raise HTTPException(status_code=400, detail="Empty file")
    if len(data) > MAX_IMAGE_BYTES:
        raise HTTPException(status_code=413, detail="Image exceeds size limit")
    # Decoding and inference are CPU-bound; running them inline in this async
    # handler would block the event loop for every other request.
    return await run_in_threadpool(_infer, data, "Invalid image format")


def _is_public_address(addr: str) -> bool:
    """Return True only if *addr* is a globally routable, non-multicast IP.

    IPv4 addresses embedded in IPv6 (IPv4-mapped, 6to4) are unwrapped and
    judged by their IPv4 form.

    Raises:
        ValueError: if *addr* is not a parseable IP address.
    """
    ip = ipaddress.ip_address(addr.split("%", 1)[0])  # drop any IPv6 scope id
    if isinstance(ip, ipaddress.IPv6Address):
        if ip.ipv4_mapped is not None:
            ip = ip.ipv4_mapped
        elif ip.sixtofour is not None:
            ip = ip.sixtofour
    # is_global excludes private, loopback, link-local, reserved, unspecified
    # and shared (100.64/10) space; multicast must be rejected separately.
    return ip.is_global and not ip.is_multicast


def _validate_url(url: str) -> str:
    """Validate an image URL to prevent SSRF attacks and return it unchanged.

    Requires HTTPS, a hostname in ``ALLOWED_HOSTS`` (when configured), and that
    every address the hostname resolves to is public.

    NOTE(review): the HTTP client resolves the hostname again when connecting,
    so a DNS-rebinding host could still swap in a private address between the
    check and the request. ``ALLOWED_IMAGE_HOSTS`` is the real mitigation.

    Raises:
        HTTPException: 400 describing which check failed.
    """
    parsed = urlparse(url)
    if parsed.scheme != "https":
        raise HTTPException(status_code=400, detail="Only HTTPS URLs are allowed")
    hostname = (parsed.hostname or "").lower()
    if not hostname:
        raise HTTPException(status_code=400, detail="Invalid URL")
    if ALLOWED_HOSTS and hostname not in ALLOWED_HOSTS:
        raise HTTPException(status_code=400, detail="Image host not in allowlist")

    try:
        infos = socket.getaddrinfo(hostname, None)
    except (OSError, UnicodeError) as exc:  # UnicodeError: IDNA-invalid hostname
        raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc

    for *_, sockaddr in infos:
        try:
            public = _is_public_address(sockaddr[0])
        except ValueError:
            public = False  # fail closed on anything we cannot classify
        if not public:
            raise HTTPException(status_code=400, detail="URL resolves to a private address")
    return url


def _read_capped(resp: http_requests.Response, deadline: float) -> bytes:
    """Read a streamed response body, enforcing ``MAX_IMAGE_BYTES`` and *deadline*.

    *deadline* is a ``time.monotonic()`` timestamp; the per-read timeout alone
    would let a slow-drip server hold the worker indefinitely.

    Raises:
        HTTPException: 400 if the body is too large or the deadline passes.
    """
    declared = resp.headers.get("Content-Length", "")
    if declared.isdigit() and int(declared) > MAX_IMAGE_BYTES:
        raise HTTPException(status_code=400, detail="Remote image exceeds size limit")
    buf = bytearray()
    for chunk in resp.iter_content(chunk_size=_CHUNK_SIZE):
        buf += chunk
        if len(buf) > MAX_IMAGE_BYTES:
            raise HTTPException(status_code=400, detail="Remote image exceeds size limit")
        if time.monotonic() > deadline:
            raise HTTPException(status_code=400, detail="Timed out fetching image")
    return bytes(buf)


def _fetch_image_bytes(url: str) -> bytes:
    """Download *url* with SSRF checks applied to every redirect hop.

    Redirects are followed manually (at most ``_MAX_REDIRECTS``) so each target
    is re-validated; the client's automatic redirect handling would let an
    allowed host bounce the request to an internal address.

    Raises:
        HTTPException: 400 on validation failure, network/HTTP error, oversize
            body, timeout, or too many redirects.
    """
    deadline = time.monotonic() + _FETCH_TOTAL_TIMEOUT_S
    for _ in range(_MAX_REDIRECTS + 1):
        _validate_url(url)
        try:
            with http_requests.get(
                url, timeout=_FETCH_TIMEOUT_S, stream=True, allow_redirects=False
            ) as resp:
                if resp.is_redirect:
                    url = urljoin(url, resp.headers["Location"])
                    continue
                resp.raise_for_status()
                return _read_capped(resp, deadline)
        except http_requests.RequestException as exc:
            # Log only host and error type: full URLs may carry signed tokens,
            # and echoing exception text to clients leaks network details.
            logger.warning(
                "Image fetch from %s failed: %s", urlparse(url).hostname, type(exc).__name__
            )
            raise HTTPException(status_code=400, detail="Failed to fetch image") from exc
    raise HTTPException(status_code=400, detail="Too many redirects")


@app.post("/analyze", response_model=PredictResponse, dependencies=[Depends(_require_api_key)])
def analyze(body: AnalyzeRequest) -> PredictResponse:
    """Accept a public image URL, download it, and run inference.

    Declared sync, so the framework already runs it off the event loop.
    ``str()`` tolerates ``image_url`` being either a plain string or a URL type.

    Raises:
        HTTPException: 400 for URL/fetch/decode problems; 500 on inference failure.
    """
    data = _fetch_image_bytes(str(body.image_url))
    return _infer(data, "URL did not return a valid image")