# Author: yoavraytz
# Commit ab794cc: "Deploy FastAPI logo classifier demo"
import io
import os

from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image

from src.models.predictor import Predictor
# Runtime configuration; both values can be overridden via environment variables.
MODEL_PATH = os.environ.get("MODEL_PATH", "artifacts/model_best.pt")
THRESHOLD = float(os.environ.get("THRESHOLD", "0.5"))  # ValueError at import if not numeric

app = FastAPI(title="Brand Logo Binary Classifier")
# Relative directory: resolved against the process's current working directory.
templates = Jinja2Templates(directory="app/templates")

# A single Predictor is constructed at import time and shared by all handlers
# (presumably this is where the model weights get loaded — confirm in Predictor).
predictor = Predictor(MODEL_PATH, threshold=THRESHOLD)
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render ``index.html`` with the active model path, threshold and predictor info.

    Args:
        request: The incoming request; Jinja2Templates requires it in the context.

    Returns:
        The rendered HTML template response.
    """
    context = {
        "request": request,
        "model_path": MODEL_PATH,
        "threshold": THRESHOLD,
        "info": predictor.info(),
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/health")
def health():
    """Report service status together with predictor info and the configured model path.

    Returns:
        A dict with keys ``status`` (always ``"ok"``), ``info`` (from
        ``predictor.info()``) and ``model_path``.
    """
    payload = dict(status="ok", info=predictor.info(), model_path=MODEL_PATH)
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run the classifier on an uploaded image.

    Args:
        file: The image uploaded as multipart form data.

    Returns:
        JSON with the prediction (``pred``), its probability (``prob``), the
        decision ``threshold`` in use, and the ``device`` reported by
        ``predictor.info()``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image (including Pillow's decompression-bomb guard).
    """
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        img = Image.open(io.BytesIO(data))
        # Image.open is lazy: force full decoding here so corrupt or truncated
        # uploads fail inside this try block instead of inside the model.
        img.load()
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and "image file is truncated" are OSError
        # subclasses; DecompressionBombError guards against oversized images.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    out = predictor.predict_pil(img)
    return JSONResponse({
        "pred": out.pred,
        "prob": out.prob,
        "threshold": THRESHOLD,
        "device": predictor.info()["device"],
    })