Diane.Aurélie commited on
Commit
9e27b61
·
2 Parent(s): 797be61 9b860d1

Merge pull request #19 from Diaure/feature/cd-huggingface

Browse files

docs: predict.py modified pour stabiliser l'api sur huggingface

Files changed (1) hide show
  1. App/predict.py +27 -6
App/predict.py CHANGED
@@ -2,16 +2,37 @@ import joblib
2
  import pandas as pd
3
  from App.schemas import EmployeeFeatures
4
  import json
 
5
 
 
 
 
6
 
7
- model = joblib.load("App/model/modele_final_xgb.joblib")
 
 
 
8
 
9
- FEATURES = list(EmployeeFeatures.model_fields.keys())
10
- with open("App/model/mapping_classes.json") as f:
11
- CLASS_MAPPING = json.load(f)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def predict_employee(data: dict):
14
- df = pd.DataFrame([data])[FEATURES]
 
 
15
 
16
  print("Colonnes API :", df.columns.tolist())
17
  print("Nombre colonnes API :", len(df.columns))
@@ -20,6 +41,6 @@ def predict_employee(data: dict):
20
  proba = model.predict_proba(df)[0][1]
21
 
22
  return {
23
- "Prediction": CLASS_MAPPING[str(pred)],
24
  "Probabilite_depart": float(proba)
25
  }
 
2
  import pandas as pd
3
  from App.schemas import EmployeeFeatures
4
  import json
5
+ from pathlib import Path
6
 
7
+ # Chemin des fichiers
8
+ chemin_model = Path("App/model/modele_final_xgb.joblib")
9
+ chemin_mapping = Path("App/model/mapping_classes.json")
10
 
11
+ # Variables chargées
12
+ model = None
13
+ classes_mapping = None
14
+ Features = list(EmployeeFeatures.model_fields.keys())
15
 
 
 
 
16
 
17
+ # Chargement des fichiers: fonction pour charger le modèle, le mapping afin de permettre à l'API de démarrer m^me si les éléments ne sont pas présents
18
+ def files_load():
19
+ global model, classes_mapping
20
+ if model is None:
21
+ if not chemin_model.exists():
22
+ raise RuntimeError("Eléments du modèle introuvable.")
23
+ model =joblib.load(chemin_model)
24
+
25
+ if classes_mapping is None:
26
+ if not chemin_mapping.exists():
27
+ raise RuntimeError("Mapping des classes introuvable.")
28
+ with open(chemin_mapping) as f:
29
+ classes_mapping = json.load(f)
30
+
31
+ # Fonction prédiction
32
  def predict_employee(data: dict):
33
+ files_load()
34
+
35
+ df = pd.DataFrame([data])[Features]
36
 
37
  print("Colonnes API :", df.columns.tolist())
38
  print("Nombre colonnes API :", len(df.columns))
 
41
  proba = model.predict_proba(df)[0][1]
42
 
43
  return {
44
+ "Prediction": classes_mapping[str(pred)],
45
  "Probabilite_depart": float(proba)
46
  }