Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| from fastapi import FastAPI, UploadFile, File, Request | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from PIL import Image | |
| from src.models.predictor import Predictor | |
# Runtime configuration; both values can be overridden via environment variables.
# Path of the checkpoint handed to Predictor.
MODEL_PATH = os.environ.get("MODEL_PATH", "artifacts/model_best.pt")
# Decision threshold handed to Predictor and echoed back in prediction responses.
# A non-numeric value makes float() raise ValueError at import time.
THRESHOLD = float(os.environ.get("THRESHOLD", "0.5"))

# Web application and its Jinja2 template loader. The template directory is a
# relative path, so it is resolved against the process working directory.
app = FastAPI(title="Brand Logo Binary Classifier")
templates = Jinja2Templates(directory="app/templates")

# Built once at import time and shared by every request handler
# (presumably this loads the model from MODEL_PATH; confirm in Predictor).
predictor = Predictor(MODEL_PATH, threshold=THRESHOLD)
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    """Render the landing page from ``index.html``.

    The template receives the configured model path, the decision threshold
    and whatever ``predictor.info()`` reports.

    Note: without the route decorator this handler was never registered on
    ``app``, so the page could not be reached.
    """
    context = {
        # Starlette's (name, context) TemplateResponse form requires the request.
        "request": request,
        "model_path": MODEL_PATH,
        "threshold": THRESHOLD,
        "info": predictor.info(),
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/health")
def health():
    """Report service status.

    Returns:
        A JSON-serialisable dict containing a constant ``"status": "ok"``,
        the output of ``predictor.info()`` and the configured model path.

    Note: without the route decorator this endpoint was never registered on
    ``app``, so it could not be reached.
    """
    return {
        "status": "ok",
        "info": predictor.info(),
        "model_path": MODEL_PATH,
    }
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as JSON.

    The whole upload is read into memory, decoded with PIL and passed to
    ``predictor.predict_pil``.

    Returns:
        200 with ``pred``, ``prob``, ``threshold`` and ``device`` on success.
        400 with a ``detail`` message when the upload (including an empty
        one) cannot be decoded as an image, instead of an unhandled 500.
    """
    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data))
        # Image.open only parses the header; force full decoding here so
        # truncated/corrupt payloads are reported as a client error rather
        # than failing later inside the predictor.
        img.load()
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised data and OSError for truncated image data.
        return JSONResponse(
            {"detail": "Uploaded file is not a valid image."},
            status_code=400,
        )

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop; consider a threadpool once Predictor's
    # thread-safety is confirmed.
    out = predictor.predict_pil(img)
    return JSONResponse({
        "pred": out.pred,
        "prob": out.prob,
        "threshold": THRESHOLD,
        "device": predictor.info()["device"],
    })