|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Request,Query |
|
|
|
|
|
from pydantic import BaseModel |
|
|
from typing import Union |
|
|
import base64 |
|
|
|
|
|
from app.model import load_model, predict_with_model |
|
|
import os |
|
|
import threading |
|
|
import time |
|
|
from app.utils import heartbeat,register_forever |
|
|
from app.log import logger |
|
|
from app.config import MODEL_NAME, ENV |
|
|
|
|
|
|
|
|
|
|
|
# Log the configured environment once, when this module is imported.
logger.info(f"ENV :{ENV}")

# ASGI application instance; the startup hook and routes in this module are
# registered on it through its decorators.
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
def load_models_once():
    """Invoke ``load_model`` once, ignoring its return value.

    Intended as a warm-up so the first request does not pay the loading cost.
    NOTE(review): this only helps if ``load_model`` caches what it loads —
    confirm in ``app.model``.
    """
    load_model()
|
|
|
|
|
@app.on_event("startup")
def startup():
    """Preload the models, then start the registration and heartbeat workers.

    ``register_forever`` is started before ``heartbeat``. Each runs in its own
    daemon thread, so neither keeps the process alive at shutdown.
    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of a ``lifespan`` handler; consider migrating.
    """
    load_models_once()
    for background_job in (register_forever, heartbeat):
        worker = threading.Thread(target=background_job, daemon=True)
        worker.start()
|
|
|
|
|
|
|
|
class ImagePayload(BaseModel):
    """Request body model carrying a base64-encoded image for ``/predict``."""

    # Base64 text of the image; predict() decodes it with base64.b64decode.
    image: str
|
|
|
|
|
|
|
|
def _decode_base64_image(encoded: str) -> bytes:
    """Decode a base64 image string, turning malformed input into an HTTP 400."""
    try:
        return base64.b64decode(encoded)
    except ValueError as exc:
        # binascii.Error (bad padding) and non-ASCII str input both subclass
        # ValueError; either way the client sent bad data, not a server fault.
        raise HTTPException(status_code=400, detail="Image base64 invalide.") from exc


async def _read_json_image(request: Request) -> Union[str, None]:
    """Return the ``image`` string from a raw JSON request body, or None.

    ``predict`` declares ``file`` with ``File(...)``, so FastAPI parses the
    request body as form data and never binds a JSON body to ``payload``
    (see FastAPI's "Request Files" docs). This reads the JSON body directly.
    """
    content_type = request.headers.get("content-type", "")
    if "application/json" not in content_type.lower():
        return None
    try:
        data = await request.json()
    except (ValueError, RuntimeError):
        # ValueError: invalid JSON / undecodable bytes. RuntimeError: Starlette
        # reports an already-consumed body stream. Both mean "no usable image".
        return None
    image = data.get("image") if isinstance(data, dict) else None
    return image if isinstance(image, str) else None


@app.post("/predict")
async def predict(request: Request,
                  file: UploadFile = File(None),
                  payload: Union[ImagePayload, None] = None,
                  show_heatmap: bool = Query(False, description="Afficher la heatmap"),
                  ):
    """Run the first loaded model on the submitted image and return its prediction.

    The image is taken, in order of preference, from:
    - the multipart ``file`` upload;
    - the ``payload`` model;
    - the ``image`` field of a raw JSON body.

    The last two are base64-encoded. ``show_heatmap`` is forwarded to
    ``predict_with_model`` unchanged.

    Raises:
        HTTPException: 400 when no image, an empty image or invalid base64 is
            sent; 500 when no model is loaded or prediction fails.
    """
    logger.info("🔁 Requête reçue")
    logger.info(f"✅ Show heatmap : {show_heatmap}")

    try:
        if file is not None:
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename} — {len(image_bytes)} octets")
        else:
            encoded = payload.image if payload is not None else await _read_json_image(request)
            if encoded is None:
                logger.info("⚠️ Aucune image reçue")
                raise HTTPException(status_code=400, detail="Format de requête non supporté.")
            image_bytes = _decode_base64_image(encoded)
            logger.debug(f"✅ Image décodée depuis base64 : {len(image_bytes)} octets")

        if not image_bytes:
            raise HTTPException(status_code=400, detail="Image vide.")

        logger.debug("🔍 Appel du vote multi-modèles...")
        # Presumably returns a cached list after the startup warm-up — verify in app.model.
        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")

        # NOTE(review): only the first model is used, despite the "vote" log above.
        # NOTE(review): predict_with_model runs synchronously inside an async
        # endpoint, so it blocks the event loop (including /health) while it runs;
        # consider fastapi.concurrency.run_in_threadpool once the model code is
        # confirmed thread-safe.
        return predict_with_model(models[0], image_bytes, show_heatmap)

    except HTTPException:
        # Deliberate HTTP errors must keep their own status and detail; without
        # this clause the handler below would rewrap them as a 500.
        raise
    except Exception as e:
        logger.error("❌ Une erreur s'est produite", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
@app.get("/health")
def health_check():
    """Liveness probe reporting a fixed "ok" status, the configured model name
    and the current Unix time in seconds."""
    return dict(
        status="ok",
        model_name=MODEL_NAME,
        timestamp=time.time(),
    )
|
|
|