# Commit cfd4ead (author: tampee):
# fix: catch all inference exceptions and OSError in SSRF check to surface real 500 cause
import asyncio
import hmac
import io
import ipaddress
import logging
import os
import socket
from urllib.parse import urljoin, urlparse

import requests as http_requests
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from PIL import Image, UnidentifiedImageError

from app.model import MammogramModel
from app.schemas import AnalyzeRequest, PredictResponse
# Optional hostname allowlist for image URLs, read from the comma-separated
# ALLOWED_IMAGE_HOSTS env var (e.g. your Supabase storage host). Entries are
# trimmed and lower-cased and blank entries are dropped; an empty set means
# no allowlist is enforced.
_ALLOWED_HOSTS_ENV = os.getenv("ALLOWED_IMAGE_HOSTS", "")
ALLOWED_HOSTS: set[str] = set(
    filter(None, (entry.strip().lower() for entry in _ALLOWED_HOSTS_ENV.split(",")))
)
# Shared secret expected in the X-API-Key header; empty/unset disables auth.
_API_KEY = os.getenv("API_KEY", "")
# auto_error=False: a missing header arrives as None instead of being rejected
# by FastAPI itself, so _require_api_key decides the response (401).
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


def _require_api_key(key: str | None = Depends(_api_key_header)) -> None:
    """FastAPI dependency that enforces the ``X-API-Key`` header.

    Args:
        key: Value of the ``X-API-Key`` header, or ``None`` when it is absent.

    Raises:
        HTTPException: 401 when API_KEY is configured and ``key`` is missing
            or does not match it.
    """
    if not _API_KEY:
        return  # API_KEY not configured — open in dev/mock mode
    # Constant-time comparison, so response timing does not reveal how much of
    # a guessed key was correct. Compare UTF-8 bytes because compare_digest
    # raises TypeError for non-ASCII str arguments.
    if key is None or not hmac.compare_digest(
        key.encode("utf-8"), _API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
# Browser origins allowed by CORS, from the comma-separated ALLOWED_ORIGINS
# env var. Entries are trimmed and blank ones dropped, so an unset variable
# yields an empty list.
_CORS_ORIGINS = list(
    filter(None, (origin.strip() for origin in os.getenv("ALLOWED_ORIGINS", "").split(",")))
)

app = FastAPI(version="0.1.0", title="Mammogram Inference API")
app.add_middleware(
    CORSMiddleware,
    allow_headers=["Content-Type", "X-API-Key"],
    allow_methods=["POST", "GET"],
    allow_origins=_CORS_ORIGINS,
)

# Single model instance, created at import time and used by the route handlers.
model = MammogramModel()
@app.get("/health")
def health() -> dict:
    """Liveness check that also reports the loaded model's mode and version."""
    return dict(
        status="ok",
        model_mode=model.mode,
        model_version=model.version,
    )
@app.post("/predict", response_model=PredictResponse, dependencies=[Depends(_require_api_key)])
async def predict(file: UploadFile = File(...)) -> PredictResponse:
    """Run inference on an uploaded image file.

    Args:
        file: Multipart upload; its declared content type must start with ``image/``.

    Returns:
        The model's prediction wrapped in ``PredictResponse``.

    Raises:
        HTTPException: 400 for a non-image content type, an empty upload, or
            bytes Pillow cannot fully decode; 500 when inference fails.
    """
    logger = logging.getLogger(__name__)
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Upload must be an image file")
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Empty file")
    try:
        image = Image.open(io.BytesIO(data))
        # Image.open() only reads the header; decode fully here so truncated
        # or corrupt data is reported as a 400 rather than failing later
        # inside the model as a 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError also covers UnidentifiedImageError, which subclasses it.
        raise HTTPException(status_code=400, detail="Invalid image format") from exc
    try:
        # model.predict is a blocking call; run it in a worker thread so this
        # async endpoint does not stall the event loop for other requests.
        result = await asyncio.to_thread(model.predict, image)
        return PredictResponse(**result)
    except Exception as exc:
        logger.exception("Model inference failed")
        raise HTTPException(
            status_code=500,
            detail=f"Model inference error: {type(exc).__name__}: {exc}",
        ) from exc
def _is_public_address(addr: str) -> bool:
    """Return True only if ``addr`` is a globally routable, non-multicast IP.

    Anything not provably public returns False, so callers fail closed. That
    includes private, loopback, link-local, shared (100.64.0.0/10), reserved,
    unspecified and multicast addresses, and strings that do not parse as IPs.
    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are judged by their
    embedded IPv4 address, consistently across Python versions.
    """
    try:
        ip = ipaddress.ip_address(addr)
    except ValueError:
        return False
    if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
        ip = ip.ipv4_mapped
    return ip.is_global and not ip.is_multicast


def _validate_url(url: str) -> str:
    """Validate an image URL before the server fetches it (SSRF guard).

    The URL must use ``https`` and have a hostname. When ALLOWED_HOSTS is
    non-empty, the hostname must be in it. Every address the hostname
    resolves to must be globally routable.

    Args:
        url: URL supplied by the client.

    Returns:
        ``url`` unchanged when it passes every check.

    Raises:
        HTTPException: 400 when any check fails or the hostname cannot be resolved.
    """
    parsed = urlparse(url)
    if parsed.scheme != "https":
        raise HTTPException(status_code=400, detail="Only HTTPS URLs are allowed")
    hostname = (parsed.hostname or "").lower()
    if not hostname:
        raise HTTPException(status_code=400, detail="Invalid URL")
    if ALLOWED_HOSTS and hostname not in ALLOWED_HOSTS:
        raise HTTPException(status_code=400, detail="Image host not in allowlist")

    try:
        infos = socket.getaddrinfo(hostname, None)
    except (OSError, UnicodeError) as exc:
        # UnicodeError (a ValueError, not an OSError): IDNA encoding rejects
        # malformed names, e.g. a label longer than 63 characters.
        raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc

    # Block private/loopback/link-local and every other non-public address.
    # Unparseable addresses are rejected too (fail closed).
    for info in infos:
        if not _is_public_address(info[4][0]):
            raise HTTPException(status_code=400, detail="URL resolves to a private address")
    return url
@app.post("/analyze", response_model=PredictResponse, dependencies=[Depends(_require_api_key)])
def analyze(body: AnalyzeRequest) -> PredictResponse:
    """Accept a public image URL, download it, and run inference.

    The URL, and every redirect target, must pass ``_validate_url`` before
    it is fetched.

    Args:
        body: Request payload; only ``body.image_url`` is used.

    Returns:
        The model's prediction wrapped in ``PredictResponse``.

    Raises:
        HTTPException: 400 for any of these:
            - the URL or a redirect target fails validation;
            - the download fails;
            - more than ``max_redirects`` redirects are returned;
            - the response is not a fully decodable image.
            500 when inference fails.
    """
    logger = logging.getLogger(__name__)
    max_redirects = 5
    url = _validate_url(body.image_url)
    try:
        # Follow redirects by hand. If requests followed them itself, the
        # target would skip _validate_url, so an allowed host could bounce
        # the fetch to a private/internal address.
        # NOTE(review): requests resolves the hostname again after
        # _validate_url checked it, so DNS rebinding between the two lookups
        # is not covered here.
        for _ in range(max_redirects + 1):
            resp = http_requests.get(url, timeout=30, allow_redirects=False)
            if not resp.is_redirect:
                break
            location = resp.headers["Location"]
            resp.close()
            # Location may be relative; resolve it against the current URL.
            url = _validate_url(urljoin(url, location))
        else:
            raise HTTPException(
                status_code=400, detail="Failed to fetch image: too many redirects"
            )
        resp.raise_for_status()
    except http_requests.RequestException as exc:
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {exc}") from exc
    # NOTE(review): resp.content buffers the whole body in memory with no size cap.
    try:
        image = Image.open(io.BytesIO(resp.content))
        # Image.open() is lazy; decode now so a truncated body is a 400, not
        # an inference 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError also covers UnidentifiedImageError, which subclasses it.
        raise HTTPException(status_code=400, detail="URL did not return a valid image") from exc
    try:
        result = model.predict(image)
        return PredictResponse(**result)
    except Exception as exc:
        logger.exception("Model inference failed")
        raise HTTPException(
            status_code=500,
            detail=f"Model inference error: {type(exc).__name__}: {exc}",
        ) from exc