Spaces:
Sleeping
Sleeping
| import base64 | |
| from contextlib import asynccontextmanager | |
| from io import BytesIO | |
| from pathlib import Path | |
| from urllib.parse import unquote_to_bytes | |
| from fastapi import FastAPI, HTTPException, Request, UploadFile | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.templating import Jinja2Templates | |
| from PIL import Image, UnidentifiedImageError | |
| from starlette.datastructures import UploadFile as StarletteUploadFile | |
| from model_service import MODEL_CONFIGS, get_model_config, get_model_service | |
# Directory containing this module; the templates folder is resolved against it.
BASE_DIR = Path(__file__).resolve().parents[0]

# Byte budget (10 MiB) handed to ``request.form(max_part_size=...)`` below.
MAX_UPLOAD_SIZE = 10 * 1024**2

# Template loader rooted at the ``templates`` directory next to this module.
templates = Jinja2Templates(directory=str(BASE_DIR.joinpath("templates")))

# Configuration returned by ``get_model_config()`` when called with no arguments;
# used as the default model throughout this module.
ACTIVE_MODEL_CONFIG = get_model_config()
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    The code before ``yield`` runs once at startup; nothing is done after the
    ``yield`` on shutdown. The function is wrapped with
    ``contextlib.asynccontextmanager`` (already imported by this module)
    because the ``lifespan`` argument expects an async context manager
    factory rather than a bare async generator function.

    Args:
        _: The application instance supplied by the framework (unused).
    """
    # Warm up model on startup so the first request is not slow.
    get_model_service()
    yield
# ASGI application object; ``lifespan`` is handed to FastAPI so the model
# warm-up runs around application startup.
app = FastAPI(
    lifespan=lifespan,
    version="0.1.0",
    title="Presence Detection API",
    description="Detect whether an image contains a person.",
)
def _build_demo_context(**overrides):
    """Build the template context used to render ``demo.html``.

    The context starts from placeholder values (no image, ``"-"`` for the
    numeric fields, the active model preselected) plus one entry per
    configuration in ``MODEL_CONFIGS``. Keyword arguments replace or extend
    those defaults.

    Args:
        **overrides: Context keys to set on top of the defaults.

    Returns:
        A new ``dict`` holding the merged context.
    """
    # One selectable option per known model configuration.
    model_options = []
    for config in MODEL_CONFIGS.values():
        model_options.append(
            {
                "name": config.name,
                "backend": config.backend,
                "path": config.model_path.name,
            }
        )

    defaults = {
        "image_data_url": None,
        "result_label": "Normal",
        "result_label_zh": "預測結果",
        "class_label": "-",
        "confidence": "-",
        "acc": "-",
        "error": None,
        "selected_model": ACTIVE_MODEL_CONFIG.name,
        "model_options": model_options,
    }
    # Later mappings win, so caller-supplied values take precedence.
    return {**defaults, **overrides}
| def _parse_data_url(data_url: str) -> tuple[bytes, str]: | |
| if not data_url.startswith("data:") or "," not in data_url: | |
| raise HTTPException(status_code=400, detail="Invalid preview image data.") | |
| header, encoded = data_url.split(",", 1) | |
| content_type = header[5:].split(";")[0] or "image/png" | |
| if not content_type.startswith("image/"): | |
| raise HTTPException(status_code=400, detail="Preview data must be an image.") | |
| if ";base64" in header: | |
| try: | |
| return base64.b64decode(encoded), content_type | |
| except ValueError as exc: | |
| raise HTTPException(status_code=400, detail="Invalid preview image data.") from exc | |
| return unquote_to_bytes(encoded), content_type | |
async def _read_image_data(
    file: UploadFile | None,
    existing_image_data_url: str | None = None,
) -> tuple[bytes, str, str | None]:
    """Obtain the image bytes for a request from an upload or a preview URL.

    A freshly uploaded file (one with a filename) takes precedence; otherwise
    the previously rendered preview, sent back as a ``data:`` URL, is decoded.

    Args:
        file: The uploaded file, or ``None`` when no file part was sent.
        existing_image_data_url: Optional ``data:`` URL of an earlier image.

    Returns:
        ``(data, content_type, filename)``; ``filename`` is ``None`` when the
        bytes come from the data URL.

    Raises:
        HTTPException: 400 when the upload is not an image, is empty, the
            preview is empty or malformed, or no image was supplied at all;
            413 when the upload is larger than ``MAX_UPLOAD_SIZE`` bytes.
    """
    if file and file.filename:
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
        # Read at most one byte past the limit: enough to detect an oversized
        # upload without loading an arbitrarily large file into memory.
        # NOTE(review): the form parser's ``max_part_size`` appears to cap only
        # non-file parts, so the file itself is bounded here — confirm against
        # the installed Starlette version.
        data = await file.read(MAX_UPLOAD_SIZE + 1)
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        if len(data) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail="Uploaded file is too large.")
        return data, file.content_type, file.filename

    if existing_image_data_url:
        data, content_type = _parse_data_url(existing_image_data_url)
        if not data:
            raise HTTPException(status_code=400, detail="Preview image is empty.")
        return data, content_type, None

    raise HTTPException(status_code=400, detail="Please upload an image first.")
async def _predict_upload(
    file: UploadFile | None,
    model_name: str | None = None,
    existing_image_data_url: str | None = None,
) -> tuple[dict, bytes]:
    """Run the selected model on the image supplied with a request.

    Args:
        file: The uploaded file, or ``None``.
        model_name: Name forwarded to ``get_model_service``.
        existing_image_data_url: Optional ``data:`` URL used when no file is
            uploaded.

    Returns:
        ``(result, data)`` where ``result`` is the dict returned by
        ``predict_image`` with ``"filename"`` and ``"content_type"`` added,
        and ``data`` is the raw image bytes that were analysed.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image, when
            Pillow rejects the image as a decompression bomb, or when the
            model service raises ``ValueError``; plus anything raised by
            ``_read_image_data``.
    """
    data, content_type, filename = await _read_image_data(file, existing_image_data_url)

    try:
        image = Image.open(BytesIO(data)).convert("RGB")
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        # DecompressionBombError derives from Exception, not OSError, so it is
        # listed explicitly; otherwise a huge-pixel-count image surfaces as a 500.
        raise HTTPException(status_code=400, detail="Invalid image file.") from exc

    try:
        result = get_model_service(model_name).predict_image(image)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    result["filename"] = filename
    result["content_type"] = content_type
    return result, data
def _coerce_upload_file(value: object) -> UploadFile | None:
    """Return ``value`` if it is an upload object, otherwise ``None``.

    Form fields may hold plain strings or be absent; only FastAPI/Starlette
    ``UploadFile`` instances are passed through.
    """
    upload_types = (UploadFile, StarletteUploadFile)
    return value if isinstance(value, upload_types) else None
async def _parse_demo_form(request: Request) -> tuple[UploadFile | None, str, str | None]:
    """Extract the demo form fields from ``request``.

    Args:
        request: Incoming request whose body is parsed as form data.

    Returns:
        ``(file, model_name, existing_image_data_url)``. ``file`` is ``None``
        unless the ``file`` field is an upload object; ``model_name`` falls
        back to the active model's name when missing or empty; the data URL
        is stringified when present and ``None`` otherwise.

    Raises:
        HTTPException: Re-raised unchanged if form parsing raises one; any
            other parsing failure is reported as a 400 carrying the
            exception text.
    """
    try:
        form = await request.form(max_part_size=MAX_UPLOAD_SIZE)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    raw_preview = form.get("existing_image_data_url")
    preview = None if raw_preview is None else str(raw_preview)
    chosen_model = form.get("model_name") or ACTIVE_MODEL_CONFIG.name
    upload = _coerce_upload_file(form.get("file"))
    return upload, str(chosen_model), preview
async def _parse_predict_form(request: Request) -> tuple[UploadFile | None, str]:
    """Extract the upload and model name from a prediction request.

    Delegates to ``_parse_demo_form`` so the form parsing and its error
    mapping live in one place; the preview data URL that function also
    returns is not used by this endpoint and is discarded.

    Args:
        request: Incoming request whose body is parsed as form data.

    Returns:
        ``(file, model_name)`` with the same fallbacks as ``_parse_demo_form``.

    Raises:
        HTTPException: Propagated from ``_parse_demo_form`` when the form
            cannot be parsed.
    """
    file, model_name, _ = await _parse_demo_form(request)
    return file, model_name
def root():
    """Return a short description of the service and the active model."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file — confirm it is registered on ``app``.
    active = ACTIVE_MODEL_CONFIG
    info = {"message": "Presence Detection API", "docs": "/docs"}
    info["model_name"] = active.name
    info["model_backend"] = active.backend
    # Only the final path component is reported, not the full path.
    info["model_path"] = str(active.model_path.name)
    return info
def health():
    """Return a static health payload.

    ``model_loaded`` is a constant ``True``; this function does not inspect
    the model service.
    """
    payload = {"status": "ok"}
    payload["model_loaded"] = True
    return payload
def demo_page(request: Request):
    """Render ``demo.html`` with the default (empty-result) context."""
    initial_context = _build_demo_context()
    return templates.TemplateResponse(request, "demo.html", initial_context)
async def demo_predict(
    request: Request,
):
    """Handle a demo form submission and render the prediction in ``demo.html``.

    The form is parsed, the image is classified, and the page is rendered
    with the image embedded as a base64 ``data:`` URL next to the predicted
    label and its probability. Any ``HTTPException`` raised while parsing the
    form or predicting is rendered into the same template with the
    exception's status code, so the browser always receives the HTML page.

    Args:
        request: Incoming request carrying the demo form.

    Returns:
        A template response for ``demo.html``.
    """
    # Fallback so the error page can still preselect a model when form
    # parsing fails before the submitted model name is known.
    model_name = ACTIVE_MODEL_CONFIG.name
    try:
        file, model_name, existing_image_data_url = await _parse_demo_form(request)
        result, data = await _predict_upload(file, model_name, existing_image_data_url)
    except HTTPException as exc:
        return templates.TemplateResponse(
            request,
            "demo.html",
            _build_demo_context(error=exc.detail, selected_model=model_name),
            status_code=exc.status_code,
        )

    pred_label = result["label"]
    pred_conf = result["probabilities"][pred_label]
    # The same percentage string feeds both the "confidence" and "acc" fields.
    percent_text = f"{pred_conf * 100:.2f}%"
    # Re-embed the analysed bytes so the page can show the image and post it
    # back as the preview on a later submission.
    image_data_url = (
        f"data:{result['content_type']};base64,"
        f"{base64.b64encode(data).decode('ascii')}"
    )
    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(
            image_data_url=image_data_url,
            result_label=pred_label,
            result_label_zh="有人" if pred_label == "person" else "沒人",
            class_label=pred_label,
            confidence=percent_text,
            acc=percent_text,
            selected_model=result["model_name"],
        ),
    )
async def predict(request: Request):
    """Classify the uploaded image and return the raw prediction dict.

    Errors from form parsing or prediction propagate as ``HTTPException``.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file — confirm it is registered on ``app``.
    upload, requested_model = await _parse_predict_form(request)
    prediction, _image_bytes = await _predict_upload(upload, requested_model)
    return prediction