# rkonan's picture
# nouvelle version
# 1bc4593
#main.py
import base64
import os
import threading
import time
from collections import deque
from typing import Optional, Union

import psutil
from fastapi import FastAPI, File, HTTPException, Query, Request, UploadFile
from pydantic import BaseModel

from app.config import ENV, MODEL_NAME, MODEL_TYPE
from app.log import logger
from app.model import cache, get_heatmap, load_model, predict_with_model
from app.utils import heartbeat, register_forever
# Module-level application state.
logger.info(f"ENV :{ENV}")
app = FastAPI()
start_time = time.time()  # process start timestamp, used for uptime in /admin/stats
request_count = 0  # incremented by the count_requests middleware (non-admin routes only)
def load_models_once():
    """Warm the model cache at startup; the loaded models are discarded here."""
    load_model()
@app.on_event("startup")
def startup():
    """Load models eagerly, then launch the background registration and
    heartbeat workers as daemon threads."""
    load_models_once()
    for worker in (register_forever, heartbeat):
        threading.Thread(target=worker, daemon=True).start()
class ImagePayload(BaseModel):
    """JSON request body carrying a base64-encoded image.

    `predicted_class_index` is only consumed by the /heatmap endpoint, so it
    is optional: /predict clients may send the image alone. (Previously it
    was required, forcing /predict callers to supply an unused field.)
    """
    # Base64-encoded image bytes.
    image: str
    # Class index to explain in /heatmap; ignored by /predict.
    predicted_class_index: Optional[int] = None
@app.middleware("http")
async def count_requests(request, call_next):
    """Increment the global request counter for every non-admin request."""
    global request_count
    admin_paths = {
        "/admin/clear-predictions",
        "/admin/clear-heatmaps",
        "/admin/stats",
        "/admin/logs",
        "/admin/reset-state",
    }
    if request.url.path not in admin_paths:
        request_count += 1
    return await call_next(request)
@app.post("/admin/clear-predictions")
def clear_predictions():
    """Delete every cached prediction (keys ending in '_pred')."""
    # Collect first, then delete, so we never mutate the cache while iterating.
    stale = [k for k in cache.iterkeys() if k.endswith("_pred")]
    for k in stale:
        del cache[k]
    return {"message": f"✅ {len(stale)} prédictions supprimées"}
@app.post("/admin/clear-heatmaps")
def clear_heatmaps():
    """Delete every cached heatmap (keys ending in '_heatmap')."""
    # Collect first, then delete, so we never mutate the cache while iterating.
    stale = [k for k in cache.iterkeys() if k.endswith("_heatmap")]
    for k in stale:
        del cache[k]
    return {"message": f"🔥 {len(stale)} heatmaps supprimées"}
@app.get("/admin/stats")
def get_stats():
    """Report process uptime, request count, cache size and resident memory."""
    elapsed = int(time.time() - start_time)
    rss_mb = psutil.Process().memory_info().rss / 1024 ** 2
    return {
        "uptime_seconds": elapsed,
        "uptime_human": time.strftime("%H:%M:%S", time.gmtime(elapsed)),
        "request_count": request_count,
        "cache_items": len(cache),
        "memory_usage_mb": f"{rss_mb:.2f}",
    }
@app.get("/admin/logs")
def get_logs(lines: Optional[int] = Query(50, ge=1, le=500)):
    """Return the last `lines` lines of the application log file.

    A bounded deque keeps only the requested tail in memory (the previous
    version read the whole file), and the file is decoded explicitly as
    UTF-8 with undecodable bytes replaced, instead of relying on the
    platform default encoding.
    """
    log_path = "app.log"  # NOTE(review): hard-coded; confirm it matches the logger's file handler
    if not os.path.exists(log_path):
        return {"logs": ["Aucun fichier log disponible."]}
    with open(log_path, "r", encoding="utf-8", errors="replace") as f:
        tail = deque(f, maxlen=lines)
    return {"logs": list(tail)}
@app.post("/predict")
async def predict(
    request: Request,
    file: UploadFile = File(None),
    payload: Union[ImagePayload, None] = None,
):
    """Run the first loaded model on an image supplied either as a
    multipart file upload or as a base64 string in a JSON body.

    Raises:
        HTTPException 400: no image was supplied in either form.
        HTTPException 500: no model is loaded, or inference failed.
    """
    logger.info("🔁 Requête reçue")
    try:
        # Case 1: multipart upload.
        if file is not None:
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename}{len(image_bytes)} octets")
        # Case 2: JSON body with a base64-encoded image.
        elif payload is not None:
            image_bytes = base64.b64decode(payload.image)
            logger.debug(f"✅ Image décodée depuis base64 : {len(image_bytes)} octets)")
        else:
            logger.info("⚠️ Aucune image reçue")
            raise HTTPException(status_code=400, detail="Format de requête non supporté.")

        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")
        # Single-model deployment: always use the first configuration.
        model_config = models[0]
        return predict_with_model(model_config, image_bytes)
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (e.g. the 400 above) propagate
        # unchanged instead of being re-wrapped as a 500 below.
        raise
    except Exception as e:
        logger.error("❌ Une erreur s'est produite", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/heatmap")
async def predict_heatmap(
    request: Request,
    payload: Union[ImagePayload, None] = None,
    file: UploadFile = File(None),
    predicted_class_index: int = Query(None)
):
    """Compute a class-activation heatmap for the supplied image.

    The image comes either as a multipart file (with the class index in the
    query string) or as a JSON body carrying both the base64 image and the
    class index.

    Raises:
        HTTPException 400: missing image or missing class index.
        HTTPException 500: no model is loaded, or heatmap computation failed.
    """
    logger.info("🔁 Requête reçue pour heatmap")
    try:
        if file is not None:
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename}{len(image_bytes)} octets")
            if predicted_class_index is None:
                raise HTTPException(status_code=400, detail="predicted_class_index requis en query avec fichier multipart")
        elif payload is not None:
            image_bytes = base64.b64decode(payload.image)
            predicted_class_index = payload.predicted_class_index
            # Defensive: reject a JSON body that omitted the class index.
            if predicted_class_index is None:
                raise HTTPException(status_code=400, detail="predicted_class_index requis")
            logger.debug(f"✅ Image reçue en JSON base64 : {len(image_bytes)} octets, class={predicted_class_index}")
        else:
            raise HTTPException(status_code=400, detail="Aucune image reçue")

        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")
        # Single-model deployment: always use the first configuration.
        model_config = models[0]
        heatmap = get_heatmap(model_config, image_bytes, predicted_class_index)
        return {"heatmap": heatmap}
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (the 400s above) propagate
        # unchanged instead of being re-wrapped as a 500 below.
        raise
    except Exception as e:
        logger.error("❌ Erreur heatmap", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
    """Liveness probe: reports the configured model and the current time."""
    status = {"status": "ok"}
    status["model_name"] = MODEL_NAME
    status["model_type"] = MODEL_TYPE
    status["timestamp"] = time.time()
    return status