File size: 1,228 Bytes
ab794cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import io
import os
from fastapi import FastAPI, UploadFile, File, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image

from src.models.predictor import Predictor


# Runtime configuration, read once from the environment at import time.
# MODEL_PATH is handed to Predictor below; defaults to a path relative to the
# process working directory.
MODEL_PATH: str = os.getenv("MODEL_PATH", "artifacts/model_best.pt")
# Parsed with float(), so a non-numeric THRESHOLD env value raises ValueError
# while the module is being imported. Passed to Predictor as `threshold`
# (presumably the probability cutoff for the positive class — confirm in
# src.models.predictor).
THRESHOLD: float = float(os.getenv("THRESHOLD", "0.5"))

app = FastAPI(title="Brand Logo Binary Classifier")
# Relative directory: resolved against the working directory the server is
# launched from, not against this file's location.
templates = Jinja2Templates(directory="app/templates")

# A single Predictor is built at import time and shared by every route below.
# NOTE(review): any failure constructing it (e.g. a missing MODEL_PATH) will
# presumably surface as an import/startup error — verify in Predictor.__init__.
predictor = Predictor(MODEL_PATH, threshold=THRESHOLD)


@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render ``index.html`` with the active model configuration.

    The template context carries the incoming request (Starlette's
    ``TemplateResponse`` expects a ``"request"`` key in this calling form),
    the configured model path and threshold, and whatever
    ``predictor.info()`` returns.

    NOTE(review): newer Starlette releases prefer the
    ``TemplateResponse(request, name, context)`` signature and warn on the
    name-first form used here — confirm the pinned version before switching.
    """
    page_context = dict(
        request=request,
        model_path=MODEL_PATH,
        threshold=THRESHOLD,
        info=predictor.info(),
    )
    return templates.TemplateResponse("index.html", page_context)


@app.get("/health")
def health():
    """Report service liveness together with the model's metadata.

    Returns a JSON object with ``status`` (always ``"ok"``), ``info`` (the
    value of ``predictor.info()``) and ``model_path`` (the configured path).
    """
    model_info = predictor.info()
    return dict(status="ok", info=model_info, model_path=MODEL_PATH)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the shared predictor.

    Returns:
        JSON with ``pred`` and ``prob`` (taken from ``predictor.predict_pil``),
        the configured ``threshold``, and the ``device`` reported by
        ``predictor.info()``.

        If the upload cannot be decoded as an image (empty body, unknown
        format, truncated data, or a Pillow decompression-bomb rejection),
        responds with HTTP 400 and ``{"detail": ...}`` instead of letting the
        decode error escape as a 500.
    """
    data = await file.read()

    try:
        img = Image.open(io.BytesIO(data))
        # Image.open is lazy: force a full decode here so corrupt/truncated
        # payloads are reported as a client error rather than failing later
        # inside the model code.
        img.load()
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError is a subclass of OSError.
        return JSONResponse(
            {"detail": "Uploaded file is not a valid image."},
            status_code=400,
        )

    # NOTE(review): inference runs directly on the event loop, blocking other
    # requests while it executes. Offloading it to a worker thread would fix
    # that, but first confirm Predictor is safe to call concurrently.
    with img:  # release the decoded pixel buffer once prediction is done
        out = predictor.predict_pil(img)

    return JSONResponse({
        "pred": out.pred,
        "prob": out.prob,
        "threshold": THRESHOLD,
        "device": predictor.info()["device"],
    })