"""FastAPI application exposing a person-presence image classifier.

Provides a JSON endpoint (``POST /predict``) and an HTML demo page
(``GET``/``POST /demo``) on top of the model service.
"""

import base64
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from urllib.parse import unquote_to_bytes

from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, UnidentifiedImageError
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.exceptions import HTTPException as StarletteHTTPException

from model_service import MODEL_CONFIGS, get_model_config, get_model_service

BASE_DIR = Path(__file__).resolve().parent
# Upper bound, in bytes, for a single form part and for the decoded image.
MAX_UPLOAD_SIZE = 10 * 1024 * 1024

templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
ACTIVE_MODEL_CONFIG = get_model_config()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Warm up the model on startup so the first request is not slow."""
    get_model_service()
    yield


app = FastAPI(
    title="Presence Detection API",
    description="Detect whether an image contains a person.",
    version="0.1.0",
    lifespan=lifespan,
)


def _build_demo_context(**overrides):
    """Return the template context for ``demo.html``.

    Starts from placeholder values (no image, no prediction, no error) and
    applies ``overrides`` on top, so callers only pass what changed.
    """
    context = {
        "image_data_url": None,
        "result_label": "Normal",
        "result_label_zh": "預測結果",
        "class_label": "-",
        "confidence": "-",
        "acc": "-",
        "error": None,
        "selected_model": ACTIVE_MODEL_CONFIG.name,
        "model_options": [
            {
                "name": config.name,
                "backend": config.backend,
                "path": config.model_path.name,
            }
            for config in MODEL_CONFIGS.values()
        ],
    }
    context.update(overrides)
    return context


def _ensure_within_size_limit(data: bytes) -> None:
    """Raise HTTP 413 when ``data`` is larger than ``MAX_UPLOAD_SIZE``."""
    if len(data) > MAX_UPLOAD_SIZE:
        limit_mb = MAX_UPLOAD_SIZE // (1024 * 1024)
        raise HTTPException(
            status_code=413,
            detail=f"Image is larger than the {limit_mb} MB limit.",
        )


def _parse_data_url(data_url: str) -> tuple[bytes, str]:
    """Decode an image ``data:`` URL into ``(raw_bytes, content_type)``.

    Both base64 and percent-encoded payloads are supported. A missing media
    type falls back to ``image/png``.

    Raises:
        HTTPException: 400 when the URL is malformed, is not an image, or its
            base64 payload cannot be decoded.
    """
    if not data_url.startswith("data:") or "," not in data_url:
        raise HTTPException(status_code=400, detail="Invalid preview image data.")

    header, encoded = data_url.split(",", 1)
    # header looks like "data:<media-type>[;param...][;base64]".
    content_type = header[5:].split(";")[0] or "image/png"
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Preview data must be an image.")

    if ";base64" in header:
        try:
            return base64.b64decode(encoded), content_type
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(
                status_code=400, detail="Invalid preview image data."
            ) from exc

    return unquote_to_bytes(encoded), content_type


async def _read_image_data(
    file: UploadFile | None,
    existing_image_data_url: str | None = None,
) -> tuple[bytes, str, str | None]:
    """Return ``(data, content_type, filename)`` for the image to classify.

    A freshly uploaded file takes precedence; otherwise the previously
    previewed image (sent back as a data URL) is reused, in which case the
    filename is ``None``.

    Raises:
        HTTPException: 400 when no usable image was supplied, 413 when the
            image exceeds ``MAX_UPLOAD_SIZE``.
    """
    if file and file.filename:
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
        # Read one byte past the limit so oversize uploads are detected
        # without loading the whole file into memory.
        data = await file.read(MAX_UPLOAD_SIZE + 1)
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        _ensure_within_size_limit(data)
        return data, file.content_type, file.filename

    if existing_image_data_url:
        data, content_type = _parse_data_url(existing_image_data_url)
        if not data:
            raise HTTPException(status_code=400, detail="Preview image is empty.")
        _ensure_within_size_limit(data)
        return data, content_type, None

    raise HTTPException(status_code=400, detail="Please upload an image first.")


async def _predict_upload(
    file: UploadFile | None,
    model_name: str | None = None,
    existing_image_data_url: str | None = None,
) -> tuple[dict, bytes]:
    """Run the selected model on the supplied image.

    Returns:
        ``(result, data)`` where ``result`` is the dict returned by the model
        service with ``filename`` and ``content_type`` added, and ``data`` is
        the raw image bytes.

    Raises:
        HTTPException: 400/413 for missing, oversize or undecodable images,
            and 400 when the model service raises ``ValueError``.
    """
    data, content_type, filename = await _read_image_data(file, existing_image_data_url)

    try:
        image = Image.open(BytesIO(data)).convert("RGB")
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
        # DecompressionBombError is not an OSError; without it an oversized
        # pixel count would surface as a 500.
        raise HTTPException(status_code=400, detail="Invalid image file.") from exc

    try:
        # NOTE(review): predict_image is called synchronously inside this
        # coroutine, so the event loop waits for it to return.
        result = get_model_service(model_name).predict_image(image)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    result["filename"] = filename
    result["content_type"] = content_type
    return result, data


def _coerce_upload_file(value: object) -> UploadFile | None:
    """Return ``value`` if it is an upload object, otherwise ``None``."""
    if isinstance(value, (UploadFile, StarletteUploadFile)):
        return value
    return None


async def _read_form(request: Request):
    """Parse the request form, reporting every failure as a FastAPI HTTPException.

    Raises:
        HTTPException: with the original status for HTTP errors raised while
            parsing, 400 for any other parsing failure.
    """
    try:
        return await request.form(max_part_size=MAX_UPLOAD_SIZE)
    except HTTPException:
        raise
    except StarletteHTTPException as exc:
        # Starlette raises its own HTTPException, which is the parent of
        # FastAPI's. Convert it so callers catching the FastAPI class see it
        # with its status and detail intact.
        raise HTTPException(status_code=exc.status_code, detail=exc.detail) from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


def _requested_model_name(form) -> str:
    """Return the ``model_name`` form field, defaulting to the active model."""
    return str(form.get("model_name") or ACTIVE_MODEL_CONFIG.name)


async def _parse_demo_form(request: Request) -> tuple[UploadFile | None, str, str | None]:
    """Extract ``(file, model_name, existing_image_data_url)`` from the demo form."""
    form = await _read_form(request)

    file = _coerce_upload_file(form.get("file"))
    model_name = _requested_model_name(form)
    existing_image_data_url = form.get("existing_image_data_url")
    if existing_image_data_url is not None:
        existing_image_data_url = str(existing_image_data_url)
    return file, model_name, existing_image_data_url


async def _parse_predict_form(request: Request) -> tuple[UploadFile | None, str]:
    """Extract ``(file, model_name)`` from the ``/predict`` form."""
    form = await _read_form(request)
    return _coerce_upload_file(form.get("file")), _requested_model_name(form)


@app.get("/")
def root():
    """Describe the service and the currently active model."""
    return {
        "message": "Presence Detection API",
        "docs": "/docs",
        "model_name": ACTIVE_MODEL_CONFIG.name,
        "model_backend": ACTIVE_MODEL_CONFIG.backend,
        "model_path": str(ACTIVE_MODEL_CONFIG.model_path.name),
    }


@app.get("/health")
def health():
    """Liveness probe; ``model_loaded`` is a fixed value, not a live check."""
    return {"status": "ok", "model_loaded": True}


@app.get("/demo", response_class=HTMLResponse)
def demo_page(request: Request):
    """Render the empty demo page."""
    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(),
    )


@app.post("/demo", response_class=HTMLResponse)
async def demo_predict(
    request: Request,
):
    """Classify the submitted image and render the result on the demo page.

    Any ``HTTPException`` raised while parsing the form or predicting is shown
    inside the page (with the matching status code) rather than as JSON.
    """
    # Fallback for the error page when the form itself cannot be parsed.
    selected_model = ACTIVE_MODEL_CONFIG.name
    try:
        file, selected_model, existing_image_data_url = await _parse_demo_form(request)
        result, data = await _predict_upload(file, selected_model, existing_image_data_url)
    except HTTPException as exc:
        return templates.TemplateResponse(
            request,
            "demo.html",
            _build_demo_context(error=exc.detail, selected_model=selected_model),
            status_code=exc.status_code,
        )

    pred_label = result["label"]
    pred_conf = result["probabilities"][pred_label]
    # Echo the image back as a data URL so the page can preview and resubmit it.
    image_data_url = (
        f"data:{result['content_type']};base64,"
        f"{base64.b64encode(data).decode('ascii')}"
    )

    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(
            image_data_url=image_data_url,
            result_label=pred_label,
            result_label_zh="有人" if pred_label == "person" else "沒人",
            class_label=pred_label,
            confidence=f"{pred_conf * 100:.2f}%",
            acc=f"{pred_conf * 100:.2f}%",
            selected_model=result["model_name"],
        ),
    )
@app.post("/predict")
async def predict(request: Request):
    """Classify the uploaded image and return the prediction as JSON.

    Reads the upload and the requested model name from the request form and
    returns the result dict produced by ``_predict_upload``; the raw image
    bytes it also returns are not needed here.
    """
    upload, requested_model = await _parse_predict_form(request)
    outcome = await _predict_upload(upload, requested_model)
    prediction = outcome[0]
    return prediction