import base64
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path

from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, UnidentifiedImageError

from model_service import MODEL_PATH, get_model_service

BASE_DIR = Path(__file__).resolve().parent
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Warm up the model on startup so the first request is not slow.

    An exception raised here aborts application startup, so a running app
    has already obtained its model service.
    """
    get_model_service()
    yield


app = FastAPI(
    title="Presence Detection API",
    description="Detect whether an image contains a person.",
    version="0.1.0",
    lifespan=lifespan,
)


def _build_demo_context(**overrides):
    """Return the template context for ``demo.html``.

    The defaults describe the empty page (no image, no prediction, no
    error); any keyword argument replaces the default of the same name.
    """
    context = {
        "image_data_url": None,
        "result_label": "Normal",
        "result_label_zh": "預測結果",
        "class_label": "-",
        "confidence": "-",
        "acc": "-",
        "error": None,
    }
    context.update(overrides)
    return context


def _decode_and_predict(data: bytes) -> dict:
    """Decode raw image bytes to RGB and run the model on them.

    This is blocking (decode + inference) and is meant to be called from a
    worker thread, not directly on the event loop.

    Raises:
        HTTPException: 400 if ``data`` cannot be decoded as an image.
    """
    try:
        # ``with`` closes the lazily-opened source image; convert() forces a
        # full decode and returns an independent image, so truncated or
        # corrupt data surfaces here (as OSError) rather than in the model.
        with Image.open(BytesIO(data)) as opened:
            image = opened.convert("RGB")
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
        # UnidentifiedImageError is an OSError subclass; listed for clarity.
        # DecompressionBombError is raised for images with excessive pixel
        # counts and is not an OSError, so it is caught explicitly.
        raise HTTPException(status_code=400, detail="Invalid image file.") from exc
    return get_model_service().predict_image(image)


async def _predict_upload(file: UploadFile) -> tuple[dict, bytes]:
    """Validate an uploaded image, run inference, and return the result.

    Returns:
        ``(result, data)`` where ``result`` is the model's prediction dict
        extended with the upload's ``filename`` and ``content_type``, and
        ``data`` is the raw uploaded bytes.

    Raises:
        HTTPException: 400 when the declared content type is not
            ``image/*``, the upload is empty, or the bytes cannot be decoded
            as an image.
    """
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Uploaded file must be an image.")

    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # Decoding and inference are blocking / CPU-bound; run them in a worker
    # thread so they do not stall the event loop and every other request.
    # NOTE(review): assumes predict_image is safe to call from worker
    # threads (as it would be for sync `def` endpoints) -- confirm.
    prediction = await run_in_threadpool(_decode_and_predict, data)

    # Build a new dict instead of mutating the model's return value, in case
    # the model service hands back a shared/cached object. Key order matches
    # the previous in-place assignment.
    result = {
        **prediction,
        "filename": file.filename,
        "content_type": file.content_type,
    }
    return result, data


@app.get("/")
def root():
    """Describe the API and point to the interactive docs."""
    return {
        "message": "Presence Detection API",
        "docs": "/docs",
        "model_path": str(MODEL_PATH.name),
    }


@app.get("/health")
def health():
    """Liveness probe.

    ``model_loaded`` is reported as constant ``True`` because ``lifespan``
    obtains the model service before the app serves requests, and a failure
    there aborts startup.
    """
    return {"status": "ok", "model_loaded": True}


@app.get("/demo", response_class=HTMLResponse)
def demo_page(request: Request):
    """Render the empty demo upload page."""
    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(),
    )


@app.post("/demo", response_class=HTMLResponse)
async def demo_predict(request: Request, file: UploadFile = File(...)):
    """Run a prediction on the uploaded image and render it in the demo page.

    Validation errors are rendered into the page (with the matching HTTP
    status) instead of being returned as JSON.
    """
    try:
        result, data = await _predict_upload(file)
    except HTTPException as exc:
        return templates.TemplateResponse(
            request,
            "demo.html",
            _build_demo_context(error=exc.detail),
            status_code=exc.status_code,
        )

    pred_label = result["label"]
    pred_conf = result["probabilities"][pred_label]
    # The MIME type is the one declared by the client; it is not verified
    # against the format PIL actually decoded.
    image_data_url = (
        f"data:{result['content_type']};base64,"
        f"{base64.b64encode(data).decode('ascii')}"
    )

    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(
            image_data_url=image_data_url,
            result_label=pred_label,
            # "有人" = someone present, "沒人" = nobody present.
            result_label_zh="有人" if pred_label == "person" else "沒人",
            class_label=pred_label,
            confidence=f"{pred_conf * 100:.2f}%",
            acc=f"{pred_conf * 100:.2f}%",
        ),
    )


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Return the JSON prediction for an uploaded image."""
    result, _ = await _predict_upload(file)
    return result