"""FastAPI service exposing a binary brand-logo image classifier.

Configuration is read from environment variables at import time:

* ``MODEL_PATH`` -- path to the model checkpoint passed to ``Predictor``
  (default ``artifacts/model_best.pt``).
* ``THRESHOLD`` -- decision threshold passed to ``Predictor`` and echoed in
  responses (default ``0.5``).
"""

import io
import os

from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image

from src.models.predictor import Predictor

MODEL_PATH = os.getenv("MODEL_PATH", "artifacts/model_best.pt")
THRESHOLD = float(os.getenv("THRESHOLD", "0.5"))

app = FastAPI(title="Brand Logo Binary Classifier")
templates = Jinja2Templates(directory="app/templates")
# Constructed once at import time and shared by every request handler below.
predictor = Predictor(MODEL_PATH, threshold=THRESHOLD)


@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render ``index.html`` with the model path, threshold and predictor info."""
    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "model_path": MODEL_PATH,
            "threshold": THRESHOLD,
            "info": predictor.info(),
        },
    )


@app.get("/health")
def health():
    """Return a static ``ok`` status together with predictor info and model path."""
    return {"status": "ok", "info": predictor.info(), "model_path": MODEL_PATH}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        JSON with the predictor's ``pred`` and ``prob`` values, the configured
        ``threshold`` and the ``device`` entry of ``predictor.info()``.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image (including images Pillow rejects as decompression bombs).
    """
    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    try:
        img = Image.open(io.BytesIO(data))
        # Image.open is lazy; load() forces a full decode so truncated or
        # corrupt uploads are rejected here as a client error instead of
        # failing later inside the model as a 500.
        img.load()
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unrecognised format) is an OSError subclass.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # NOTE(review): predict_pil runs synchronously inside an async handler and
    # therefore blocks the event loop for the duration of inference; consider
    # a threadpool once the predictor is confirmed safe for concurrent use.
    with img:
        out = predictor.predict_pil(img)

    return JSONResponse(
        {
            "pred": out.pred,
            "prob": out.prob,
            "threshold": THRESHOLD,
            "device": predictor.info()["device"],
        }
    )