ericjedha commited on
Commit
ad42521
·
verified ·
1 Parent(s): c12fff2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import mlflow
4
+ import mlflow.pyfunc
5
+ import logging
6
+ import os
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+ from contextlib import asynccontextmanager
10
+
11
+ # --- Configuration ---
12
+
13
+ # Configuration des logs
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Configuration de MLflow (à faire une seule fois)
18
+ # Assurez-vous que cette variable d'environnement est bien définie où l'API tourne
19
+ # Pour Hugging Face, il peut être nécessaire de s'authentifier avec un token.
20
+ # os.environ['HF_TOKEN'] = "votre_token_hugging_face"
21
+ os.environ["MLFLOW_TRACKING_URI"] = "https://ericjedha-getaround.hf.space"
22
+ # Alternativement, vous pouvez le mettre en dur ici si c'est plus simple pour le debug
23
+ # mlflow.set_tracking_uri("https://ericjedha-getaround.hf.space")
24
+
25
+ # ID du modèle à charger
26
+ LOGGED_MODEL_URI = 'runs:/3b12ae0ccd3648d3888571e25e22f500/model'
27
+
28
+ # Dictionnaire pour contenir notre modèle chargé
29
+ ml_models = {}
30
+
31
+ # --- Événements de démarrage et d'arrêt de l'API (Lifespan) ---
32
+
33
+ @asynccontextmanager
34
+ async def lifespan(app: FastAPI):
35
+ # Code exécuté au démarrage de l'application
36
+ logger.info("Démarrage de l'application API...")
37
+ try:
38
+ logger.info(f"Chargement du modèle depuis: {LOGGED_MODEL_URI}")
39
+ # Charge le modèle et le stocke dans le dictionnaire
40
+ ml_models['getaround_predictor'] = mlflow.pyfunc.load_model(LOGGED_MODEL_URI)
41
+ logger.info("Modèle chargé avec succès.")
42
+ except Exception as e:
43
+ logger.error(f"Erreur lors du chargement du modèle au démarrage : {e}")
44
+ # Si le modèle ne charge pas, l'API ne devrait pas démarrer correctement
45
+ # ou au moins signaler qu'elle est dans un état dégradé.
46
+ ml_models['getaround_predictor'] = None
47
+
48
+ yield
49
+
50
+ # Code exécuté à l'arrêt de l'application
51
+ logger.info("Arrêt de l'application API.")
52
+ ml_models.clear()
53
+
54
+
55
+ # --- Initialisation de l'API avec le lifespan ---
56
+
57
+ app = FastAPI(lifespan=lifespan)
58
+
59
+ # --- Modèle de données Pydantic pour la requête ---
60
+
61
+ class Item(BaseModel):
62
+ model_key: str
63
+ mileage: int
64
+ engine_power: int
65
+ fuel: str
66
+ paint_color: str
67
+ car_type: str
68
+ private_parking_available: int
69
+ has_gps: int
70
+ has_air_conditioning: int
71
+ automatic_car: int
72
+ has_getaround_connect: int
73
+ has_speed_regulator: int
74
+ winter_tires: int
75
+
76
+ # --- Endpoint de prédiction ---
77
+
78
+ @app.post("/predict/")
79
+ async def predict(item: Item):
80
+ # Vérifier si le modèle a bien été chargé au démarrage
81
+ if 'getaround_predictor' not in ml_models or ml_models['getaround_predictor'] is None:
82
+ logger.error("Le modèle n'a pas pu être chargé. L'endpoint de prédiction n'est pas disponible.")
83
+ raise HTTPException(status_code=503, detail="Service indisponible: le modèle de prédiction n'est pas chargé.")
84
+
85
+ try:
86
+ # Créer un DataFrame à partir des données de la requête
87
+ # La méthode `model_dump()` de Pydantic est plus sûre que de reconstruire le dict à la main
88
+ car_df = pd.DataFrame([item.model_dump()])
89
+
90
+ logger.info(f"Données reçues pour la prédiction : \n{car_df.to_string()}")
91
+
92
+ # Utiliser le modèle déjà en mémoire pour faire la prédiction
93
+ prediction = ml_models['getaround_predictor'].predict(car_df)
94
+
95
+ # Formater la réponse
96
+ # `.tolist()[0]` est une bonne pratique pour extraire la première valeur d'un array numpy
97
+ response = {"prediction": prediction.tolist()[0]}
98
+
99
+ logger.info(f"Prédiction effectuée : {response}")
100
+ return response
101
+
102
+ except Exception as e:
103
+ logger.error(f"Erreur lors de la prédiction : {e}")
104
+ # Il est utile de logguer l'erreur complète pour le débogage
105
+ import traceback
106
+ logger.error(traceback.format_exc())
107
+ raise HTTPException(status_code=500, detail=f"Erreur serveur lors de la prédiction : {str(e)}")
108
+
109
+ @app.get("/")
110
+ def read_root():
111
+ return {"message": "Bienvenue sur l'API de prédiction GetAround"}