Spaces:
Paused
Paused
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request,Query | |
| from app.voting import soft_voting | |
| import io | |
| from PIL import Image | |
| from io import BytesIO | |
| from pydantic import BaseModel | |
| from typing import Union | |
| from io import BytesIO | |
| import base64 | |
| import logging | |
| import logging | |
| from app.model import load_models, predict_with_model | |
# Root logger setup. DEBUG is the most verbose standard level, so records of
# every level (DEBUG, INFO, WARNING, ERROR, ...) are emitted.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(name=__name__)

# Loaded once at import time; passed to soft_voting on every prediction call.
model_configs = load_models()
app = FastAPI()
| # @app.post("/predict") | |
| # async def predict(file: UploadFile = File(...)): | |
| # image_bytes = await file.read() | |
| # prediction = await soft_voting(image_bytes) | |
| # return {"prediction": int(prediction)} | |
class ImagePayload(BaseModel):
    """JSON request body carrying a single image.

    ``image`` holds the image encoded as a base64 string.
    """

    image: str  # base64-encoded image data
async def _read_base64_image(request: Request) -> bytes:
    """Extract and decode the image from a JSON body ``{"image": "<base64>"}``.

    Returns:
        The decoded image bytes.

    Raises:
        HTTPException: 400 if the body is empty, falsy, or cannot be read as
            JSON; 422 if it is not an object with an ``image`` field, or if
            that field is not decodable base64.
    """
    try:
        body = await request.json()
    except (ValueError, RuntimeError):
        # ValueError covers empty / non-JSON bodies (json.JSONDecodeError and
        # UnicodeDecodeError both subclass it). RuntimeError is what Starlette
        # raises when the body stream was already consumed -- presumably by form
        # parsing of a multipart request with no `file` part; verify on upgrade.
        body = None
    if not body:
        logger.info("⚠️ Aucune image reçue")
        raise HTTPException(status_code=400, detail="Format de requête non supporté.")
    if not isinstance(body, dict) or "image" not in body:
        raise HTTPException(status_code=422, detail="Champ 'image' manquant.")
    try:
        image_bytes = base64.b64decode(body["image"])
    except (TypeError, ValueError) as exc:
        # binascii.Error (e.g. bad padding) subclasses ValueError; TypeError
        # covers non-string values such as numbers, lists or objects.
        raise HTTPException(status_code=422, detail="Champ 'image' : base64 invalide.") from exc
    logger.debug("✅ Image décodée depuis base64 : %d octets", len(image_bytes))
    return image_bytes


# NOTE(review): no route decorator is visible on `predict` in this copy, so as
# shown it is never registered on `app`; the commented-out version above used
# `@app.post("/predict")` -- confirm the intended route against the deployed source.
async def predict(request: Request,
                  file: UploadFile = File(None),
                  payload: Union[ImagePayload, None] = None,
                  mode: str = Query("single", enum=["single", "voting", "automatic"], description="Mode de prédiction"),
                  show_heatmap: bool = Query(False, description="Afficher la heatmap"),
                  default_model: str = Query("efficientnetv2m", enum=["efficientnetv2m", "resnet50"], description="Model par défaut")
                  ):
    """Run a prediction on one image sent as a multipart file or as base64 JSON.

    Input forms are checked in this order:

    1. a multipart ``file`` upload;
    2. a JSON body ``{"image": "<base64 string>"}`` read from ``request``.

    Args:
        request: Raw request, used to read the JSON body when no file is uploaded.
        file: Optional uploaded image file.
        payload: Declared but not read in this function; the JSON body is
            parsed from ``request`` directly. Kept for interface compatibility.
        mode: Prediction mode forwarded to ``soft_voting``.
        show_heatmap: Flag forwarded to ``soft_voting``.
        default_model: Model name forwarded to ``soft_voting``.

    Returns:
        The value returned by ``soft_voting``, unchanged.

    Raises:
        HTTPException: 400 when no usable input is found, 422 when the image
            is missing, not valid base64, or empty, 500 for any other error.
    """
    logger.info("🔁 Requête reçue")
    logger.info(f"✅ Mode : {mode}")
    logger.info(f"✅ Default model : {default_model}")
    logger.info(f"✅ Show heatmap : {show_heatmap}")
    try:
        if file is not None:
            # Case 1: multipart upload.
            image_bytes = await file.read()
            # Lazy %-args: extra positional args without placeholders make
            # logging report a formatting error instead of the message.
            logger.debug("✅ Image reçue via multipart : %s %d octets", file.filename, len(image_bytes))
        else:
            # Case 2: JSON body with a base64-encoded image.
            image_bytes = await _read_base64_image(request)
        if not image_bytes:
            raise HTTPException(status_code=422, detail="Image vide.")
        logger.debug("🔍 Appel du vote multi-modèles...")
        prediction = await soft_voting(model_configs, image_bytes, mode, show_heatmap, default_model)
        return prediction
    except HTTPException:
        # Client errors raised above must keep their own status code rather
        # than being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error("❌ Une erreur s'est produite", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this function in this copy;
# if it is meant to be an HTTP health endpoint it must be registered on `app`
# (e.g. with `@app.get(...)`) -- confirm against the deployed source.
def health_check():
    """Report that the service is alive.

    Returns:
        A fresh ``{"status": "ok"}`` mapping on every call.
    """
    status = {"status": "ok"}
    return status