Spaces:
Sleeping
Sleeping
File size: 2,853 Bytes
73401e0 9b860d1 3a105f9 f489a18 15524ff f489a18 15524ff f489a18 3a105f9 b6fad86 73401e0 9b860d1 73401e0 9b860d1 8344437 9b860d1 8344437 9b860d1 8344437 9b860d1 73401e0 9b860d1 73401e0 f489a18 23d4613 2df3397 23d4613 2df3397 f489a18 2df3397 23d4613 2df3397 23d4613 73401e0 9b860d1 23d4613 426e9e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import joblib
import pandas as pd
from App.schemas import EmployeeFeatures
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
# Import SQLAlchemy only if available; the API must still start without a DB.
try:
    from sqlalchemy.orm import Session
    from App.database import SessionLocal
    from App.model import Input, Predictions
    SQLALCHEMY_AVAILABLE = True
except ModuleNotFoundError:
    SQLALCHEMY_AVAILABLE = False
    # Fall back to None so later `SessionLocal is not None` checks are safe.
    # (Original had a typo `essionLocal`, leaving SessionLocal undefined and
    # causing a NameError at prediction time when SQLAlchemy was missing.)
    SessionLocal = None
    Input = None
    Predictions = None
# Hugging Face Hub repository holding the trained model artifacts.
MODEL_REPO = "Diaure/xgb_model"
# Lazily-loaded globals (populated by files_load on first prediction).
model = None
classes_mapping = None
# Feature column order expected by the model, taken from the Pydantic schema.
Features = list(EmployeeFeatures.model_fields.keys())
# File loading: lets the API start even if the artifacts are not yet present.
def files_load():
    """Download (if needed) and load the model and the class-label mapping.

    Populates the module-level globals ``model`` and ``classes_mapping`` on
    first call; subsequent calls are no-ops. ``hf_hub_download`` caches the
    files locally and returns their path.

    Raises:
        Any exception from ``hf_hub_download`` / ``joblib.load`` /
        ``json.load`` if an artifact is missing or corrupt.
    """
    global model, classes_mapping
    if model is None:
        chemin_model = Path(hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib"))
        model = joblib.load(chemin_model)
    if classes_mapping is None:
        chemin_mapping = Path(hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json"))
        # Explicit encoding: the mapping file may contain accented labels.
        with open(chemin_mapping, encoding="utf-8") as f:
            classes_mapping = json.load(f)
# Prediction function
def predict_employee(data: dict):
    """Run the attrition model on one employee record and return the result.

    Lazily loads the model artifacts, predicts the class label and the
    departure probability, and — when a database is configured — persists
    the raw inputs and the linked prediction row.

    Args:
        data: mapping of feature name -> value; must provide every key
            declared in ``EmployeeFeatures``.

    Returns:
        dict with the mapped class label under ``"Prediction"`` and the
        departure probability under ``"Probabilite_depart"``.

    Raises:
        Exception: re-raised after logging if the DB persistence fails.
    """
    files_load()
    # Order the columns exactly as the model was trained on.
    df = pd.DataFrame([data])[Features]
    print("Colonnes API :", df.columns.tolist())
    print("Nombre colonnes API :", len(df.columns))
    pred = model.predict(df)[0]
    proba = model.predict_proba(df)[0][1]
    # DB disabled if SQLAlchemy is unavailable or SessionLocal is None.
    if SQLALCHEMY_AVAILABLE and SessionLocal is not None:
        db: Session = SessionLocal()
    else:
        db = None
    if db is not None:
        try:
            # Store the user's inputs first on each POST /predict call...
            input_row = Input(**data)
            db.add(input_row)
            db.commit()
            db.refresh(input_row)  # fetch the autogenerated primary key
            # ...then record the prediction linked to that input id.
            pred_row = Predictions(
                input_id=input_row.id,
                prediction_label=classes_mapping[str(pred)],
                prediction_proba=float(proba),
                model_version="v1",
            )
            db.add(pred_row)
            db.commit()
        except Exception as e:
            print("ERREUR DB:", e)
            raise  # bare raise keeps the original traceback intact
        finally:
            db.close()
    # Finally return the API response.
    return {
        "Prediction": classes_mapping[str(pred)],
        "Probabilite_depart": float(proba)}
|