SecureMLAPI / app.py
yenslife's picture
Update predict result label
8ffefcd
raw
history blame
3.44 kB
import base64
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, UnidentifiedImageError
from model_service import MODEL_PATH, get_model_service
# Paths are anchored to this file's resolved location rather than the current
# working directory, so templates are found wherever the server is launched.
BASE_DIR = Path(__file__).resolve().parent
templates = Jinja2Templates(directory=str(BASE_DIR.joinpath("templates")))
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Lifespan hook handed to ``FastAPI(lifespan=...)``.

    Before the application starts serving, ``get_model_service()`` is called
    once so the model is already warmed up when the first request arrives.
    The returned service object is not kept here, and nothing is done on
    shutdown.
    """
    get_model_service()  # called purely for its warm-up side effect
    yield None
# ASGI application instance. ``lifespan`` runs the model warm-up at startup;
# the metadata fields feed the generated OpenAPI docs.
app = FastAPI(
    lifespan=lifespan,
    title="Presence Detection API",
    version="0.1.0",
    description="Detect whether an image contains a person.",
)
def _build_demo_context(**overrides):
context = {
"image_data_url": None,
"result_label": "Normal",
"result_label_zh": "預測結果",
"class_label": "-",
"confidence": "-",
"acc": "-",
"error": None,
}
context.update(overrides)
return context
async def _predict_upload(file: UploadFile) -> tuple[dict, bytes]:
    """Validate an uploaded image, run the model on it and return the result.

    Args:
        file: The multipart upload received by an endpoint.

    Returns:
        ``(result, data)``. ``result`` is the dict returned by the model
        service's ``predict_image``, extended with the upload's ``filename``
        and ``content_type``. ``data`` is the raw uploaded bytes, which the
        HTML demo echoes back to the browser.

    Raises:
        HTTPException: status 400 in any of these cases:
            - the declared content type is missing or not ``image/*``;
            - the upload is empty;
            - Pillow cannot decode the bytes (unrecognised format,
              truncated/corrupt pixel data, or an image large enough to
              trip Pillow's decompression-bomb guard).
    """
    # The declared content type is client-supplied, so this is only a cheap
    # first filter; the real validation is Pillow decoding the bytes below.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(
            status_code=400, detail="Uploaded file must be an image."
        )

    # NOTE(review): the whole upload is read into memory with no size cap.
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        # ``Image.open`` is lazy: it identifies the format from the header,
        # and pixel data is only decoded by ``convert``. Corrupt or truncated
        # data therefore fails inside ``convert`` with a plain OSError, not
        # UnidentifiedImageError. The ``with`` closes the source image
        # deterministically once the RGB copy exists.
        with Image.open(BytesIO(data)) as source:
            image = source.convert("RGB")
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError subclasses OSError; it is listed for clarity.
        raise HTTPException(status_code=400, detail="Invalid image file.") from exc

    # NOTE(review): predict_image is a synchronous call inside a coroutine,
    # so it blocks the event loop while it runs. Consider offloading it to a
    # worker thread if the model service is confirmed thread-safe.
    result = get_model_service().predict_image(image)
    result["filename"] = file.filename
    result["content_type"] = file.content_type
    return result, data
@app.get("/")
def root():
    """Describe the service: its name, where the docs live and the model file.

    Only the final path component of ``MODEL_PATH`` is exposed.
    """
    info = {"message": "Presence Detection API", "docs": "/docs"}
    info["model_path"] = str(MODEL_PATH.name)
    return info
@app.get("/health")
def health():
    """Health-check endpoint.

    ``model_loaded`` is a constant ``True``: this handler never queries the
    model service. NOTE(review): it presumably relies on the startup warm-up
    having succeeded; confirm this is the intended contract.
    """
    payload = dict(status="ok", model_loaded=True)
    return payload
@app.get("/demo", response_class=HTMLResponse)
def demo_page(request: Request):
    """Render the demo upload form with placeholder result values."""
    blank_context = _build_demo_context()
    return templates.TemplateResponse(request, "demo.html", blank_context)
@app.post("/demo", response_class=HTMLResponse)
async def demo_predict(request: Request, file: UploadFile = File(...)):
    """Handle a demo form submission and render the prediction page.

    Validation failures raised by ``_predict_upload`` are rendered back into
    ``demo.html`` as an error message, with the exception's HTTP status,
    rather than returned as a JSON error body. On success, the uploaded image
    is shown inline as a base64 ``data:`` URL, next to the predicted label and
    that label's probability formatted as a percentage.
    """
    try:
        result, data = await _predict_upload(file)
    except HTTPException as exc:
        error_context = _build_demo_context(error=exc.detail)
        return templates.TemplateResponse(
            request, "demo.html", error_context, status_code=exc.status_code
        )

    label = result["label"]
    # Probability the model assigned to the predicted label, as "NN.NN%".
    percent = f"{result['probabilities'][label] * 100:.2f}%"
    encoded = base64.b64encode(data).decode("ascii")
    context = _build_demo_context(
        image_data_url=f"data:{result['content_type']};base64,{encoded}",
        result_label=label,
        # Chinese: "有人" = "someone present", "沒人" = "nobody present".
        result_label_zh="有人" if label == "person" else "沒人",
        class_label=label,
        confidence=percent,
        # NOTE(review): "acc" shows the same value as "confidence"; confirm
        # that the template really intends this.
        acc=percent,
    )
    return templates.TemplateResponse(request, "demo.html", context)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """JSON prediction endpoint.

    Returns the dict built by ``_predict_upload``: the model service's result
    plus the upload's ``filename`` and ``content_type``. The raw bytes are
    discarded. Invalid uploads surface as 400 responses via the
    ``HTTPException`` raised by the helper.
    """
    prediction, _raw_bytes = await _predict_upload(file)
    return prediction