ericjedha commited on
Commit
4e0eb62
·
verified ·
1 Parent(s): 2776de6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -38
app.py CHANGED
@@ -4,24 +4,63 @@ import mlflow
4
  import mlflow.pyfunc
5
  import logging
6
  import os
7
- from typing import Literal, List, Union
8
- from fastapi import FastAPI, HTTPException, File, UploadFile
9
- from pydantic import BaseModel
10
  from contextlib import asynccontextmanager
11
- import joblib
12
- import traceback
13
-
14
- # --- Configuration ---
15
- app = FastAPI()
16
-
17
 
18
- # Configuration des logs
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
 
 
 
22
 
23
- # --- Modèle de données Pydantic pour la requête ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
25
  class Item(BaseModel):
26
  model_key: str
27
  mileage: int
@@ -37,36 +76,29 @@ class Item(BaseModel):
37
  has_speed_regulator: int
38
  winter_tires: int
39
 
40
- # --- Endpoint de prédiction ---
 
 
 
41
 
42
  @app.post("/predict/")
43
  async def predict(item: Item):
44
-
 
 
 
 
 
 
45
  try:
46
  # Créer un DataFrame à partir des données de la requête
47
- # La méthode `model_dump()` de Pydantic est plus sûre que de reconstruire le dict à la main
48
  car_df = pd.DataFrame([item.model_dump()])
49
-
50
- os.environ["APP_URI"] = "https://ericjedha-getaroundml.hf.space"
51
- EXPERIMENT_NAME = "08_GETAROUND"
52
- # Set experiment's info
53
- mlflow.set_experiment(EXPERIMENT_NAME)
54
-
55
- # Get our experiment info
56
- experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
57
-
58
- # Charger le modèle depuis mlflow
59
- logged_model = 'runs:/8d6657ebb69943f298f1124df0db622f/xgboost_ridge_pipeline'
60
-
61
  logger.info(f"Données reçues pour la prédiction : \n{car_df.to_string()}")
62
 
63
- loaded_model = mlflow.pyfunc.load_model(logged_model)
64
-
65
- # Utiliser le modèle déjà en mémoire pour faire la prédiction
66
- prediction = loaded_model.predict(car_df)
67
 
68
  # Formater la réponse
69
- # `.tolist()[0]` est une bonne pratique pour extraire la première valeur d'un array numpy
70
  response = {"prediction": prediction.tolist()[0]}
71
 
72
  logger.info(f"Prédiction effectuée : {response}")
@@ -74,11 +106,5 @@ async def predict(item: Item):
74
 
75
  except Exception as e:
76
  logger.error(f"Erreur lors de la prédiction : {e}")
77
- # Il est utile de logguer l'erreur complète pour le débogage
78
- import traceback
79
  logger.error(traceback.format_exc())
80
- raise HTTPException(status_code=500, detail=f"Erreur serveur lors de la prédiction : {str(e)}")
81
-
82
- @app.get("/")
83
- def read_root():
84
- return {"message": "Bienvenue sur l'API de prédiction GetAround"}
 
4
  import mlflow.pyfunc
5
  import logging
6
  import os
7
+ import traceback
 
 
8
  from contextlib import asynccontextmanager
9
+ from fastapi import FastAPI, HTTPException
10
+ from pydantic import BaseModel
 
 
 
 
11
 
12
+ # --- Configuration des logs ---
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # --- Dictionnaire pour stocker les modèles chargés ---
17
+ # On le remplit au démarrage de l'application
18
+ ml_models = {}
19
 
20
+ # --- Configuration du Lifespan de l'application ---
21
+ @asynccontextmanager
22
+ async def lifespan(app: FastAPI):
23
+ # Code exécuté au démarrage de l'application
24
+ logger.info("Démarrage de l'application: chargement du modèle...")
25
+
26
+ # 1. Configurer l'URI du serveur MLflow (LA PARTIE LA PLUS IMPORTANTE)
27
+ # Cette variable doit être définie dans les "Secrets" de votre Space FastAPI
28
+ MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI")
29
+ if not MLFLOW_TRACKING_URI:
30
+ raise ValueError("La variable d'environnement MLFLOW_TRACKING_URI n'est pas définie !")
31
+ mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
32
+ logger.info(f"MLflow tracking URI configuré sur: {MLFLOW_TRACKING_URI}")
33
+
34
+ # 2. Configurer l'authentification si votre Space MLflow est privé
35
+ # Le token doit aussi être dans les "Secrets" du Space FastAPI
36
+ HF_TOKEN = os.getenv("HF_TOKEN")
37
+ if HF_TOKEN:
38
+ os.environ['MLFLOW_TRACKING_USERNAME'] = "ericjedha" # ou tout autre nom d'utilisateur
39
+ os.environ['MLFLOW_TRACKING_PASSWORD'] = HF_TOKEN
40
+ logger.info("Authentification MLflow configurée avec un token.")
41
+
42
+ # 3. Charger le modèle
43
+ try:
44
+ logged_model_uri = 'runs:/8d6657ebb69943f298f1124df0db622f/xgboost_ridge_pipeline'
45
+ # Charger le modèle et le stocker dans notre dictionnaire
46
+ ml_models["getaround_model"] = mlflow.pyfunc.load_model(logged_model_uri)
47
+ logger.info("Modèle chargé avec succès et prêt à être utilisé.")
48
+ except Exception as e:
49
+ logger.error(f"Erreur critique lors du chargement du modèle: {e}")
50
+ logger.error(traceback.format_exc())
51
+ # Si le modèle ne se charge pas, l'application ne peut pas fonctionner.
52
+ # On pourrait choisir d'arrêter l'application ici, mais pour l'instant on logue l'erreur.
53
+
54
+ yield
55
+
56
+ # Code exécuté à l'arrêt de l'application (cleanup)
57
+ logger.info("Arrêt de l'application: nettoyage...")
58
+ ml_models.clear()
59
+
60
+ # --- Initialisation de l'application FastAPI avec le lifespan ---
61
+ app = FastAPI(lifespan=lifespan)
62
 
63
+ # --- Modèle de données Pydantic pour la requête ---
64
  class Item(BaseModel):
65
  model_key: str
66
  mileage: int
 
76
  has_speed_regulator: int
77
  winter_tires: int
78
 
79
+ # --- Endpoints ---
80
+ @app.get("/")
81
+ def read_root():
82
+ return {"message": "Bienvenue sur l'API de prédiction GetAround"}
83
 
84
  @app.post("/predict/")
85
  async def predict(item: Item):
86
+ # Vérifier si le modèle est bien chargé
87
+ if "getaround_model" not in ml_models:
88
+ raise HTTPException(
89
+ status_code=503,
90
+ detail="Le modèle n'est pas disponible. L'application n'a pas pu le charger au démarrage."
91
+ )
92
+
93
  try:
94
  # Créer un DataFrame à partir des données de la requête
 
95
  car_df = pd.DataFrame([item.model_dump()])
 
 
 
 
 
 
 
 
 
 
 
 
96
  logger.info(f"Données reçues pour la prédiction : \n{car_df.to_string()}")
97
 
98
+ # Utiliser le modèle DÉJÀ en mémoire pour faire la prédiction
99
+ prediction = ml_models["getaround_model"].predict(car_df)
 
 
100
 
101
  # Formater la réponse
 
102
  response = {"prediction": prediction.tolist()[0]}
103
 
104
  logger.info(f"Prédiction effectuée : {response}")
 
106
 
107
  except Exception as e:
108
  logger.error(f"Erreur lors de la prédiction : {e}")
 
 
109
  logger.error(traceback.format_exc())
110
+ raise HTTPException(status_code=500, detail=f"Erreur serveur lors de la prédiction : {str(e)}")