File size: 7,168 Bytes
429d013
dcd4485
 
429d013
0066f5e
3f86103
0066f5e
429d013
 
dcd4485
6296cfb
3f86103
0066f5e
3f86103
429d013
0066f5e
429d013
896740b
429d013
dcd4485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429d013
 
 
 
8ffefcd
429d013
 
 
 
0066f5e
 
 
 
 
 
 
 
 
dcd4485
429d013
 
dcd4485
 
0066f5e
 
 
dcd4485
0066f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcd4485
 
 
0066f5e
dcd4485
 
0066f5e
 
 
 
 
 
429d013
 
 
6296cfb
 
 
 
 
 
0066f5e
 
 
 
 
 
 
 
6296cfb
0066f5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6296cfb
0066f5e
 
 
 
429d013
 
 
 
 
896740b
 
 
429d013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0066f5e
 
 
 
429d013
0066f5e
429d013
 
 
 
0066f5e
429d013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ffefcd
429d013
 
 
0066f5e
429d013
 
 
 
 
0066f5e
 
 
dcd4485
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import base64
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from urllib.parse import unquote_to_bytes

from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from PIL import Image, UnidentifiedImageError
from starlette.datastructures import UploadFile as StarletteUploadFile

from model_service import MODEL_CONFIGS, get_model_config, get_model_service

# Directory containing this module; used below to locate the templates folder.
BASE_DIR = Path(__file__).resolve().parent
# Size cap in bytes (10 MiB); passed as ``max_part_size`` when parsing forms.
MAX_UPLOAD_SIZE = 10 * 1024 * 1024
# Template loader rooted at ``<BASE_DIR>/templates``.
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
# Model configuration resolved once at import time; its name is the fallback
# when a request does not specify ``model_name``.
ACTIVE_MODEL_CONFIG = get_model_config()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Lifespan hook that calls ``get_model_service()`` before serving.

    The application argument is unused. Nothing runs after the ``yield``, so
    this hook performs no shutdown-time work.
    """
    # Warm up model on startup so the first request is not slow.
    get_model_service()
    yield


# Application instance; ``lifespan`` is registered so it runs around startup.
app = FastAPI(
    title="Presence Detection API",
    description="Detect whether an image contains a person.",
    version="0.1.0",
    lifespan=lifespan,
)


def _build_demo_context(**overrides):
    """Return the template context used to render ``demo.html``.

    The defaults describe an empty page (no image, placeholder result values,
    no error). Keyword arguments replace matching defaults or add new keys.

    Args:
        **overrides: Context entries that take precedence over the defaults.

    Returns:
        dict: The merged context. ``model_options`` holds one entry (name,
        backend, model file name) per configuration in ``MODEL_CONFIGS``.
    """
    defaults = {
        "image_data_url": None,
        "result_label": "Normal",
        "result_label_zh": "預測結果",
        "class_label": "-",
        "confidence": "-",
        "acc": "-",
        "error": None,
        "selected_model": ACTIVE_MODEL_CONFIG.name,
        "model_options": [
            {
                "name": cfg.name,
                "backend": cfg.backend,
                "path": cfg.model_path.name,
            }
            for cfg in MODEL_CONFIGS.values()
        ],
    }
    # Later keys win, so caller-supplied values shadow the defaults.
    return {**defaults, **overrides}


def _parse_data_url(data_url: str) -> tuple[bytes, str]:
    """Decode an image ``data:`` URL into its raw bytes and media type.

    The scheme, the media type and the ``;base64`` marker are matched
    case-insensitively, so ``DATA:Image/PNG;BASE64,...`` is accepted just like
    its lower-case form.

    Args:
        data_url: A URL shaped like ``data:[<media type>][;base64],<payload>``.

    Returns:
        A ``(payload_bytes, content_type)`` tuple. ``content_type`` is
        lower-cased and falls back to ``"image/png"`` when the URL omits it.

    Raises:
        HTTPException: 400 when the value is not a ``data:`` URL, when the
            media type is not ``image/*``, or when the base64 payload cannot
            be decoded.
    """
    if data_url[:5].lower() != "data:" or "," not in data_url:
        raise HTTPException(status_code=400, detail="Invalid preview image data.")

    header, encoded = data_url.split(",", 1)
    # Only the header is case-folded; the payload must be left untouched.
    header = header.lower()
    content_type = header[5:].split(";")[0].strip() or "image/png"
    if not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Preview data must be an image.")

    if ";base64" in header:
        try:
            return base64.b64decode(encoded), content_type
        except ValueError as exc:
            # binascii.Error (bad padding etc.) is a ValueError subclass.
            raise HTTPException(status_code=400, detail="Invalid preview image data.") from exc

    # Without the base64 marker the payload is percent-encoded.
    return unquote_to_bytes(encoded), content_type


async def _read_image_data(
    file: UploadFile | None,
    existing_image_data_url: str | None = None,
) -> tuple[bytes, str, str | None]:
    """Obtain the image bytes for a request from an upload or a preview URL.

    A freshly uploaded file takes precedence; otherwise the previously
    rendered preview (a ``data:`` URL) is decoded.

    Args:
        file: The uploaded file, if any. Ignored when it has no filename.
        existing_image_data_url: ``data:`` URL of an earlier preview image.

    Returns:
        ``(data, content_type, filename)``. ``filename`` is ``None`` when the
        bytes come from the preview URL.

    Raises:
        HTTPException: 400 when no image is supplied, the upload is not
            declared as ``image/*``, or the image is empty; 413 when the
            image is larger than ``MAX_UPLOAD_SIZE`` bytes.
    """
    if file and file.filename:
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="Uploaded file must be an image.")

        # Enforce the size cap here instead of relying on the form parser, and
        # read at most one byte past it so an oversized upload is detected
        # without pulling the whole file into memory.
        data = await file.read(MAX_UPLOAD_SIZE + 1)
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")
        if len(data) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail="Uploaded file is too large.")
        return data, file.content_type, file.filename

    if existing_image_data_url:
        data, content_type = _parse_data_url(existing_image_data_url)
        if not data:
            raise HTTPException(status_code=400, detail="Preview image is empty.")
        if len(data) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail="Preview image is too large.")
        return data, content_type, None

    raise HTTPException(status_code=400, detail="Please upload an image first.")


async def _predict_upload(
    file: UploadFile | None,
    model_name: str | None = None,
    existing_image_data_url: str | None = None,
) -> tuple[dict, bytes]:
    """Decode the submitted image and run the selected model on it.

    Args:
        file: The uploaded file, if any.
        model_name: Name passed to ``get_model_service``; ``None`` is
            forwarded unchanged.
        existing_image_data_url: ``data:`` URL used when no file is uploaded.

    Returns:
        ``(result, data)`` where ``result`` is the dict returned by
        ``predict_image`` with ``filename`` and ``content_type`` added, and
        ``data`` is the raw image bytes.

    Raises:
        HTTPException: 400 when the bytes are not a decodable image or when
            the model service raises ``ValueError``; errors from
            ``_read_image_data`` propagate unchanged.
    """
    data, content_type, filename = await _read_image_data(file, existing_image_data_url)

    try:
        # ``convert`` returns an independent image, so the source can be
        # closed as soon as the conversion is done.
        with Image.open(BytesIO(data)) as source:
            image = source.convert("RGB")
    except (
        UnidentifiedImageError,
        OSError,
        ValueError,
        Image.DecompressionBombError,
    ) as exc:
        # DecompressionBombError derives from Exception, not OSError, so it
        # has to be listed explicitly or hostile images become 500 responses.
        raise HTTPException(status_code=400, detail="Invalid image file.") from exc

    # NOTE(review): predict_image is called synchronously inside this
    # coroutine; if inference is slow it will hold up the event loop.
    try:
        result = get_model_service(model_name).predict_image(image)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    result["filename"] = filename
    result["content_type"] = content_type
    return result, data


def _coerce_upload_file(value: object) -> UploadFile | None:
    """Return *value* if it is an uploaded file, otherwise ``None``.

    Both the FastAPI and the Starlette ``UploadFile`` classes are accepted;
    anything else (for example a plain string field or a missing field) is
    treated as "no file".
    """
    is_upload = isinstance(value, (UploadFile, StarletteUploadFile))
    return value if is_upload else None


async def _parse_demo_form(request: Request) -> tuple[UploadFile | None, str, str | None]:
    """Extract the demo page's form fields from *request*.

    Returns:
        ``(file, model_name, existing_image_data_url)``. ``file`` is ``None``
        unless the ``file`` field is an upload, ``model_name`` falls back to
        the active model's name, and the preview URL is ``None`` when the
        field is absent.

    Raises:
        HTTPException: Re-raised as-is if form parsing raises one; any other
            parsing error becomes a 400 carrying the error text.
    """
    try:
        fields = await request.form(max_part_size=MAX_UPLOAD_SIZE)
    except HTTPException:
        # Already a client-facing error; pass it through untouched.
        raise
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    upload = _coerce_upload_file(fields.get("file"))
    chosen_model = str(fields.get("model_name") or ACTIVE_MODEL_CONFIG.name)

    preview = fields.get("existing_image_data_url")
    preview_url = None if preview is None else str(preview)
    return upload, chosen_model, preview_url


async def _parse_predict_form(request: Request) -> tuple[UploadFile | None, str]:
    """Extract the ``/predict`` form fields from *request*.

    Parsing and error mapping are shared with ``_parse_demo_form`` so the two
    endpoints cannot drift apart; the preview-URL field that only the demo
    page uses is discarded.

    Returns:
        ``(file, model_name)``. ``file`` is ``None`` unless the ``file`` field
        is an upload; ``model_name`` falls back to the active model's name.

    Raises:
        HTTPException: Re-raised as-is if form parsing raises one; any other
            parsing error becomes a 400 carrying the error text.
    """
    file, model_name, _ = await _parse_demo_form(request)
    return file, model_name


@app.get("/")
def root():
    """Describe the service and the model it uses by default."""
    active = ACTIVE_MODEL_CONFIG
    summary = {
        "message": "Presence Detection API",
        "docs": "/docs",
        "model_name": active.name,
        "model_backend": active.backend,
        # Only the file name is exposed, not the full path.
        "model_path": str(active.model_path.name),
    }
    return summary


@app.get("/health")
def health():
    """Report that the service is up."""
    # ``model_loaded`` is a fixed value here, not a live check of the model.
    payload = dict(status="ok", model_loaded=True)
    return payload


@app.get("/demo", response_class=HTMLResponse)
def demo_page(request: Request):
    """Render the demo page in its initial state (no image, no result)."""
    blank_context = _build_demo_context()
    return templates.TemplateResponse(request, "demo.html", blank_context)


@app.post("/demo", response_class=HTMLResponse)
async def demo_predict(
    request: Request,
):
    """Run a prediction for the demo form and render the result page.

    Any ``HTTPException`` raised while parsing the form or running the
    prediction is rendered as an error on the demo page, using the
    exception's status code, instead of escaping as a JSON error response.

    Args:
        request: The incoming form submission.

    Returns:
        The rendered ``demo.html`` response, showing either the prediction
        together with the submitted image or the error message.
    """
    # Fallback for the error page when the form itself cannot be parsed; it
    # is only overwritten once parsing has succeeded.
    model_name = ACTIVE_MODEL_CONFIG.name
    try:
        file, model_name, existing_image_data_url = await _parse_demo_form(request)
        result, data = await _predict_upload(file, model_name, existing_image_data_url)
    except HTTPException as exc:
        return templates.TemplateResponse(
            request,
            "demo.html",
            _build_demo_context(error=exc.detail, selected_model=model_name),
            status_code=exc.status_code,
        )

    pred_label = result["label"]
    pred_conf = result["probabilities"][pred_label]
    # Echo the image back as a data URL so the page can show it again and
    # resubmit it without a fresh upload.
    image_data_url = (
        f"data:{result['content_type']};base64,"
        f"{base64.b64encode(data).decode('ascii')}"
    )
    return templates.TemplateResponse(
        request,
        "demo.html",
        _build_demo_context(
            image_data_url=image_data_url,
            result_label=pred_label,
            result_label_zh="有人" if pred_label == "person" else "沒人",
            class_label=pred_label,
            confidence=f"{pred_conf * 100:.2f}%",
            # NOTE(review): ``acc`` is filled with the same value as
            # ``confidence`` — confirm this is what the template expects.
            acc=f"{pred_conf * 100:.2f}%",
            selected_model=result["model_name"],
        ),
    )


@app.post("/predict")
async def predict(request: Request):
    """Classify the image in the submitted form and return the result.

    The form may carry a ``file`` upload and an optional ``model_name``.
    Validation failures are raised as ``HTTPException``.
    """
    upload, chosen_model = await _parse_predict_form(request)
    prediction, _raw_bytes = await _predict_upload(upload, chosen_model)
    return prediction