Spaces:
Runtime error
Runtime error
| #main.py | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request,Query | |
| from pydantic import BaseModel | |
| from typing import Union | |
| import base64 | |
| from app.model import load_model, predict_with_model,get_heatmap,cache | |
| import os | |
| import threading | |
| import time | |
| from app.utils import heartbeat,register_forever | |
| from app.log import logger | |
| from app.config import MODEL_NAME, ENV,MODEL_TYPE | |
| from typing import Optional | |
| import psutil | |
# Record the configured environment in the log as soon as the module is imported.
logger.info(f"ENV :{ENV}")

# FastAPI application object for this service.
app = FastAPI()

# Wall-clock time captured at import; get_stats() subtracts it to report uptime.
start_time: float = time.time()

# Running request total: incremented by count_requests(), reported by get_stats().
request_count: int = 0
def load_models_once():
    """Call ``load_model()`` and discard whatever it returns.

    NOTE(review): whether a repeated call reloads anything depends on
    ``load_model`` itself, which is not visible in this file.
    """
    load_model()
def startup():
    """Load the models, then start the two background daemon threads.

    ``register_forever`` is started first and ``heartbeat`` second; both run as
    daemon threads, so they do not keep the process alive on shutdown.

    NOTE(review): presumably registered as an application startup hook; no
    registration is visible in this view, confirm.
    """
    load_models_once()
    for background_task in (register_forever, heartbeat):
        worker = threading.Thread(target=background_task, daemon=True)
        worker.start()
class ImagePayload(BaseModel):
    # JSON request body accepted by predict() and predict_heatmap().

    # Base64 text of the image; both handlers pass it to base64.b64decode().
    image: str

    # Class index that predict_heatmap() forwards to get_heatmap().
    # NOTE(review): the field has no default, so predict() requires it as well
    # even though predict() never reads it.
    predicted_class_index: int
async def count_requests(request, call_next):
    """Increment the global request counter, then forward the request.

    Requests whose path is one of the listed ``/admin/...`` paths are forwarded
    without being counted. The counter is bumped before ``call_next`` runs, so a
    request is counted even if handling it fails afterwards.

    NOTE(review): the ``(request, call_next)`` signature matches an HTTP
    middleware; its registration is not visible in this view.

    Returns:
        Whatever ``await call_next(request)`` returns.
    """
    global request_count
    uncounted_paths = (
        "/admin/clear-predictions",
        "/admin/clear-heatmaps",
        "/admin/stats",
        "/admin/logs",
        "/admin/reset-state",
    )
    is_admin_call = request.url.path in uncounted_paths
    if not is_admin_call:
        request_count += 1
    response = await call_next(request)
    return response
def clear_predictions():
    """Delete every cache entry whose key ends with ``"_pred"``.

    Returns:
        A dict with a single ``"message"`` entry stating how many entries were
        deleted.
    """
    removed = 0
    # Snapshot the keys first so entries are never deleted while iterating the cache.
    snapshot = list(cache.iterkeys())
    for entry_key in snapshot:
        if not entry_key.endswith("_pred"):
            continue
        del cache[entry_key]
        removed += 1
    return {"message": f"✅ {removed} prédictions supprimées"}
def clear_heatmaps():
    """Delete every cache entry whose key ends with ``"_heatmap"``.

    Returns:
        A dict with a single ``"message"`` entry stating how many entries were
        deleted.
    """
    removed = 0
    # Snapshot the keys first so entries are never deleted while iterating the cache.
    snapshot = list(cache.iterkeys())
    for entry_key in snapshot:
        if not entry_key.endswith("_heatmap"):
            continue
        del cache[entry_key]
        removed += 1
    return {"message": f"🔥 {removed} heatmaps supprimées"}
def get_stats():
    """Report uptime, request count, cache size and memory use of this process.

    Returns:
        A dict with:
        - ``uptime_seconds``: whole seconds elapsed since ``start_time``.
        - ``uptime_human``: the same duration as ``HH:MM:SS``; the hour field
          keeps growing past 23 instead of wrapping around.
        - ``request_count``: current value of the global request counter.
        - ``cache_items``: ``len(cache)``.
        - ``memory_usage_mb``: ``memory_info().rss`` of the current process
          divided by 1024**2, formatted with two decimals.
    """
    uptime_sec = int(time.time() - start_time)

    # Split the duration by hand rather than via time.gmtime()/strftime("%H"):
    # %H is an hour of the day (00-23), so the display wrapped back to 00:00:00
    # after 24 hours of uptime. Clamped at 0 so a backwards clock adjustment
    # cannot produce negative fields.
    minutes, seconds = divmod(max(uptime_sec, 0), 60)
    hours, minutes = divmod(minutes, 60)

    rss_bytes = psutil.Process().memory_info().rss
    return {
        "uptime_seconds": uptime_sec,
        "uptime_human": f"{hours:02d}:{minutes:02d}:{seconds:02d}",
        "request_count": request_count,
        "cache_items": len(cache),
        "memory_usage_mb": f"{rss_bytes / 1024**2:.2f}",
    }
def get_logs(lines: Optional[int] = Query(50, ge=1, le=500)):
    """Return the last ``lines`` lines of the log file.

    Args:
        lines: Number of trailing lines to return. The query declaration
            constrains it to 1..500 with a default of 50; ``None`` is treated
            as 50.

    Returns:
        ``{"logs": [...]}`` holding the trailing lines (line endings kept), or a
        single placeholder message when the log file does not exist.
    """
    from collections import deque  # local import: only this function needs it

    log_path = "app.log"  # or the actual path of the log file
    if lines is None:
        lines = 50

    try:
        # Assumes the log is written as UTF-8 (TODO confirm against the logging
        # setup); errors="replace" keeps an encoding mismatch from raising.
        with open(log_path, "r", encoding="utf-8", errors="replace") as f:
            # A bounded deque keeps only the requested tail in memory instead of
            # materialising the whole file with readlines().
            tail = deque(f, maxlen=lines)
    except FileNotFoundError:
        # Opening directly (rather than checking existence first) avoids the
        # window in which the file could disappear between check and open.
        return {"logs": ["Aucun fichier log disponible."]}
    return {"logs": list(tail)}
async def predict(
    request: Request,
    file: UploadFile = File(None),
    payload: Union[ImagePayload, None] = None,
):
    """Run the first loaded model on the received image and return its prediction.

    The image bytes come from ``file`` (multipart upload) when it is provided,
    otherwise from ``payload.image`` (base64 text in a JSON body).

    Args:
        request: Incoming request object; not read by this function.
        file: Optional uploaded image.
        payload: Optional JSON body carrying the base64-encoded image.

    Returns:
        Whatever ``predict_with_model`` returns for the first model.

    Raises:
        HTTPException: 400 when no image was sent or the base64 text cannot be
            decoded; 500 when ``load_model()`` returns nothing or any other
            error occurs.
    """
    logger.info("🔁 Requête reçue")
    try:
        # Case 1: multipart upload.
        if file is not None:
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename} — {len(image_bytes)} octets")
        # Case 2: base64 image inside a JSON body.
        elif payload is not None:
            try:
                image_bytes = base64.b64decode(payload.image)
            except ValueError as exc:
                # binascii.Error is a ValueError subclass; malformed client
                # input is a client error, not a server failure.
                raise HTTPException(status_code=400, detail="Image base64 invalide.") from exc
            logger.debug(f"✅ Image décodée depuis base64 : {len(image_bytes)} octets)")
        else:
            logger.info("⚠️ Aucune image reçue")
            raise HTTPException(status_code=400, detail="Format de requête non supporté.")

        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")
        return predict_with_model(models[0], image_bytes)
    except HTTPException:
        # HTTPException is itself an Exception: without this clause the generic
        # handler below re-wrapped the deliberate 400 responses as 500 errors.
        raise
    except Exception as e:
        logger.error("❌ Une erreur s'est produite", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def predict_heatmap(
    request: Request,
    payload: Union[ImagePayload, None] = None,
    file: UploadFile = File(None),
    predicted_class_index: int = Query(None)
):
    """Compute a heatmap for the received image with the first loaded model.

    With a multipart upload the class index must come from the
    ``predicted_class_index`` query parameter; with a JSON body it is taken from
    ``payload.predicted_class_index``.

    Args:
        request: Incoming request object; not read by this function.
        payload: Optional JSON body with the base64 image and the class index.
        file: Optional uploaded image.
        predicted_class_index: Class index, required alongside ``file``.

    Returns:
        ``{"heatmap": ...}`` wrapping the value returned by ``get_heatmap``.

    Raises:
        HTTPException: 400 when no image was sent, when the class index is
            missing for a multipart upload, or when the base64 text cannot be
            decoded; 500 when ``load_model()`` returns nothing or any other
            error occurs.
    """
    logger.info("🔁 Requête reçue pour heatmap")
    try:
        if file is not None:
            image_bytes = await file.read()
            logger.debug(f"✅ Image reçue via multipart : {file.filename} — {len(image_bytes)} octets")
            if predicted_class_index is None:
                raise HTTPException(status_code=400, detail="predicted_class_index requis en query avec fichier multipart")
        elif payload is not None:
            try:
                image_bytes = base64.b64decode(payload.image)
            except ValueError as exc:
                # binascii.Error is a ValueError subclass; malformed client
                # input is a client error, not a server failure.
                raise HTTPException(status_code=400, detail="Image base64 invalide.") from exc
            predicted_class_index = payload.predicted_class_index
            logger.debug(f"✅ Image reçue en JSON base64 : {len(image_bytes)} octets, class={predicted_class_index}")
        else:
            raise HTTPException(status_code=400, detail="Aucune image reçue")

        models = load_model()
        if not models:
            raise HTTPException(status_code=500, detail="Aucun modèle chargé.")
        heatmap = get_heatmap(models[0], image_bytes, predicted_class_index)
        return {"heatmap": heatmap}
    except HTTPException:
        # HTTPException is itself an Exception: without this clause the generic
        # handler below re-wrapped the deliberate 400 responses as 500 errors.
        raise
    except Exception as e:
        logger.error("❌ Erreur heatmap", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
def health_check():
    """Return a small status report for this service.

    Returns:
        A dict with a constant ``"status"`` of ``"ok"``, the configured
        ``MODEL_NAME`` and ``MODEL_TYPE``, and the current ``time.time()`` value
        under ``"timestamp"``.
    """
    report = {
        "status": "ok",
        "model_name": MODEL_NAME,
        "model_type": MODEL_TYPE,
    }
    report["timestamp"] = time.time()
    return report