SecureMLAPI / app.py
yenslife's picture
feat: improve demo upload flow and model selection
6296cfb
import base64
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from urllib.parse import unquote_to_bytes

from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, UnidentifiedImageError
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.exceptions import HTTPException as StarletteHTTPException

from model_service import MODEL_CONFIGS, get_model_config, get_model_service
# Directory containing this module; templates are resolved relative to it.
BASE_DIR = Path(__file__).resolve().parent

# 10 MiB size cap used when parsing upload forms.
MAX_UPLOAD_SIZE = 10 << 20

templates = Jinja2Templates(directory=str(BASE_DIR.joinpath("templates")))

# Default model configuration; used whenever a request does not name a model.
ACTIVE_MODEL_CONFIG = get_model_config()
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook: create the default model service, then serve.

    Creating the service before ``yield`` moves model start-up cost to process
    start instead of the first request.
    """
    get_model_service()
    yield
# The ASGI application; ``lifespan`` runs the model warm-up before serving.
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="Presence Detection API",
    description="Detect whether an image contains a person.",
)
def _build_demo_context(**overrides):
    """Return the template context for ``demo.html``.

    Every key starts from a placeholder default; keyword arguments replace
    the defaults of the same name (or add new keys).
    """
    # One entry per configured model, for the model <select> on the page.
    model_options = [
        {"name": cfg.name, "backend": cfg.backend, "path": cfg.model_path.name}
        for cfg in MODEL_CONFIGS.values()
    ]
    defaults = {
        "image_data_url": None,
        "result_label": "Normal",
        # "預測結果" = "prediction result" (placeholder before any prediction).
        "result_label_zh": "預測結果",
        "class_label": "-",
        "confidence": "-",
        "acc": "-",
        "error": None,
        "selected_model": ACTIVE_MODEL_CONFIG.name,
        "model_options": model_options,
    }
    return {**defaults, **overrides}
def _parse_data_url(data_url: str) -> tuple[bytes, str]:
if not data_url.startswith("data:") or "," not in data_url:
raise HTTPException(status_code=400, detail="Invalid preview image data.")
header, encoded = data_url.split(",", 1)
content_type = header[5:].split(";")[0] or "image/png"
if not content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="Preview data must be an image.")
if ";base64" in header:
try:
return base64.b64decode(encoded), content_type
except ValueError as exc:
raise HTTPException(status_code=400, detail="Invalid preview image data.") from exc
return unquote_to_bytes(encoded), content_type
async def _read_image_data(
    file: UploadFile | None,
    existing_image_data_url: str | None = None,
) -> tuple[bytes, str, str | None]:
    """Return ``(data, content_type, filename)`` for the image to classify.

    A newly uploaded file (one with a filename) takes precedence. Otherwise
    ``existing_image_data_url`` is decoded as a ``data:`` URL, presumably the
    preview rendered by the demo page. In that case ``filename`` is ``None``.

    Raises:
        HTTPException: 400 if nothing was submitted, the upload is not
            declared as an image, or the payload is empty; 413 if the
            uploaded file exceeds ``MAX_UPLOAD_SIZE`` bytes.
    """
    if file and file.filename:
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
        # ``request.form(max_part_size=...)`` only bounds non-file fields, so
        # the size cap for uploaded files must be enforced here. Reading one
        # byte past the limit detects an oversized upload without loading the
        # whole file into memory.
        data = await file.read(MAX_UPLOAD_SIZE + 1)
        if len(data) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail="Uploaded file is too large.")
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        return data, file.content_type, file.filename

    if existing_image_data_url:
        data, content_type = _parse_data_url(existing_image_data_url)
        if not data:
            raise HTTPException(status_code=400, detail="Preview image is empty.")
        return data, content_type, None

    raise HTTPException(status_code=400, detail="Please upload an image first.")
async def _predict_upload(
    file: UploadFile | None,
    model_name: str | None = None,
    existing_image_data_url: str | None = None,
) -> tuple[dict, bytes]:
    """Decode the submitted image and run the selected model on it.

    Returns the model's result dict, augmented with ``filename`` and
    ``content_type``, together with the raw image bytes.

    Raises:
        HTTPException: 400 for missing, invalid, or oversized-dimension
            images, and for ``ValueError``s raised by the model service
            (for example, an unknown ``model_name``). Errors from
            ``_read_image_data`` propagate unchanged.
    """
    data, content_type, filename = await _read_image_data(file, existing_image_data_url)
    try:
        # ``with`` closes the decoder's source image; ``convert`` returns an
        # independent, fully loaded copy.
        with Image.open(BytesIO(data)) as source:
            image = source.convert("RGB")
    except Image.DecompressionBombError as exc:
        # Raised when the pixel count exceeds Pillow's safety limit. It is not
        # an OSError, so without this clause it would surface as a 500.
        raise HTTPException(status_code=400, detail="Image dimensions are too large.") from exc
    except (UnidentifiedImageError, OSError, ValueError) as exc:
        # ValueError covers malformed-tile decode errors on corrupt input.
        raise HTTPException(status_code=400, detail="Invalid image file.") from exc
    try:
        result = get_model_service(model_name).predict_image(image)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    result["filename"] = filename
    result["content_type"] = content_type
    return result, data
def _coerce_upload_file(value: object) -> UploadFile | None:
    """Narrow a raw form value to an uploaded file.

    Returns *value* unchanged when it is an ``UploadFile`` (FastAPI's or
    Starlette's class). Returns ``None`` for anything else, including a
    missing field.
    """
    is_upload = isinstance(value, (UploadFile, StarletteUploadFile))
    return value if is_upload else None
async def _parse_demo_form(request: Request) -> tuple[UploadFile | None, str, str | None]:
    """Parse the demo form into ``(file, model_name, existing_image_data_url)``.

    ``model_name`` falls back to the active model when the field is missing or
    empty. ``existing_image_data_url`` is ``None`` when the field is absent.

    Raises:
        HTTPException: for any failure while reading the form body. HTTP
            errors reported by Starlette keep their status code and detail;
            any other error becomes a 400.
    """
    try:
        form = await request.form(max_part_size=MAX_UPLOAD_SIZE)
    except HTTPException:
        raise
    except StarletteHTTPException as exc:
        # Starlette reports malformed or oversized multipart bodies with its
        # own HTTPException, which FastAPI's subclass above does not match.
        # Translating it keeps the original status and plain detail instead of
        # the stringified "400: ..." produced by the generic branch below.
        raise HTTPException(
            status_code=exc.status_code,
            detail=exc.detail,
            headers=exc.headers,
        ) from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    file = _coerce_upload_file(form.get("file"))
    model_name = str(form.get("model_name") or ACTIVE_MODEL_CONFIG.name)
    existing_image_data_url = form.get("existing_image_data_url")
    if existing_image_data_url is not None:
        existing_image_data_url = str(existing_image_data_url)
    return file, model_name, existing_image_data_url
async def _parse_predict_form(request: Request) -> tuple[UploadFile | None, str]:
    """Parse the ``/predict`` form into ``(file, model_name)``.

    Uses the same parsing and error handling as the demo form. The demo-only
    preview field is read there and discarded here.
    """
    upload, chosen_model, _unused_preview = await _parse_demo_form(request)
    return upload, chosen_model
@app.get("/")
def root():
    """Describe the service and its default model configuration."""
    cfg = ACTIVE_MODEL_CONFIG
    info = {"message": "Presence Detection API", "docs": "/docs"}
    info.update(
        model_name=cfg.name,
        model_backend=cfg.backend,
        model_path=str(cfg.model_path.name),
    )
    return info
@app.get("/health")
def health():
    """Liveness probe.

    ``model_loaded`` is reported unconditionally: the lifespan hook creates
    the default model service before the app starts serving requests.
    """
    status = {"status": "ok"}
    status["model_loaded"] = True
    return status
@app.get("/demo", response_class=HTMLResponse)
def demo_page(request: Request):
    """Render the demo page with default (empty) prediction fields."""
    context = _build_demo_context()
    return templates.TemplateResponse(request, "demo.html", context)
@app.post("/demo", response_class=HTMLResponse)
async def demo_predict(
    request: Request,
):
    """Handle a demo form submission and render the prediction or the error.

    Failures while parsing the form, as well as while validating or
    classifying the image, are rendered inline in ``demo.html`` with the
    matching HTTP status instead of a JSON error body.
    """
    # Fallback selection for the error page if the form cannot be parsed.
    model_name = ACTIVE_MODEL_CONFIG.name
    try:
        file, model_name, existing_image_data_url = await _parse_demo_form(request)
        result, data = await _predict_upload(file, model_name, existing_image_data_url)
    except HTTPException as exc:
        return templates.TemplateResponse(
            request,
            "demo.html",
            _build_demo_context(error=exc.detail, selected_model=model_name),
            status_code=exc.status_code,
        )
    pred_label = result["label"]
    pred_conf = result["probabilities"][pred_label]
    # Re-embed the image as a data URL so the page can show it as a preview.
    image_data_url = (
        f"data:{result['content_type']};base64,"
        f"{base64.b64encode(data).decode('ascii')}"
    )
    confidence_text = f"{pred_conf * 100:.2f}%"
    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(
            image_data_url=image_data_url,
            result_label=pred_label,
            # zh-TW labels: "有人" = "someone present", "沒人" = "no one".
            result_label_zh="有人" if pred_label == "person" else "沒人",
            class_label=pred_label,
            confidence=confidence_text,
            acc=confidence_text,
            selected_model=result["model_name"],
        ),
    )
@app.post("/predict")
async def predict(request: Request):
    """Classify an uploaded image and return the model's result dict.

    Unlike ``/demo``, this endpoint does not accept a preview data URL; only
    the uploaded ``file`` field is used.
    """
    upload, chosen_model = await _parse_predict_form(request)
    prediction, _raw_bytes = await _predict_upload(upload, chosen_model)
    return prediction