# Futurisys_API_ML / App / predict.py
# Synced from GitHub main (commit 426e9e4, verified) — uploaded by Diaure.
import joblib
import pandas as pd
from App.schemas import EmployeeFeatures
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
# Import the SQLAlchemy-backed DB layer only if it is installed; the API can
# still serve predictions (without persistence) when these modules are absent.
try:
    from sqlalchemy.orm import Session
    from App.database import SessionLocal
    from App.model import Input, Predictions
    SQLALCHEMY_AVAILABLE = True
except ModuleNotFoundError:
    SQLALCHEMY_AVAILABLE = False
    # Fallback placeholders so later `is not None` checks work without NameError.
    # (Original had a typo `essionLocal = None`, which left SessionLocal
    # undefined and crashed the availability check in predict_employee.)
    Session = None
    SessionLocal = None
    Input = None
    Predictions = None
# Hugging Face Hub repository that hosts the trained model artifacts.
MODEL_REPO = "Diaure/xgb_model"

# Lazily-populated module-level caches, filled on first use by files_load().
model = None
classes_mapping = None

# Ordered feature column names, taken from the Pydantic schema's declared fields.
Features = [*EmployeeFeatures.model_fields]
def files_load():
    """Lazily download and load the model and the class-label mapping.

    Results are cached in the module-level ``model`` / ``classes_mapping``
    globals so the API can start even before the artifacts are fetched; each
    artifact is pulled from the Hugging Face Hub only on first use.
    """
    global model, classes_mapping
    if model is None:
        # hf_hub_download returns a local cache path (downloading on first call),
        # so no extra existence check is needed.
        model_path = hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib")
        model = joblib.load(model_path)
    if classes_mapping is None:
        mapping_path = hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json")
        with open(mapping_path, encoding="utf-8") as f:
            classes_mapping = json.load(f)
def predict_employee(data: dict) -> dict:
    """Run the model on one employee record and return the API payload.

    When SQLAlchemy and a DB session factory are available, the raw input and
    the resulting prediction are also persisted; otherwise the function is
    prediction-only.

    Parameters:
        data: feature values keyed by the names declared in ``EmployeeFeatures``.

    Returns:
        dict with the mapped class label ("Prediction") and the probability
        of departure ("Probabilite_depart").

    Raises:
        Exception: re-raises any database error after logging it.
    """
    files_load()
    # Column order must match the training feature order exactly.
    df = pd.DataFrame([data])[Features]
    print("Colonnes API :", df.columns.tolist())
    print("Nombre colonnes API :", len(df.columns))
    pred = model.predict(df)[0]
    # Probability of the positive class (index 1) — assumed to be "departure".
    proba = model.predict_proba(df)[0][1]

    # Persistence is skipped entirely when SQLAlchemy or the session factory
    # is unavailable (e.g. the DB layer was not importable at startup).
    if SQLALCHEMY_AVAILABLE and SessionLocal is not None:
        _store_prediction(data, pred, proba)

    # API response payload.
    return {
        "Prediction": classes_mapping[str(pred)],
        "Probabilite_depart": float(proba)}


def _store_prediction(data: dict, pred, proba) -> None:
    """Persist one input row and its linked prediction; re-raises on failure."""
    db: Session = SessionLocal()
    try:
        # Store the user's input first so the DB assigns it an id...
        input_row = Input(**data)
        db.add(input_row)
        db.commit()
        db.refresh(input_row)
        # ...then store the prediction referencing that generated id.
        pred_row = Predictions(
            input_id=input_row.id,
            prediction_label=classes_mapping[str(pred)],
            prediction_proba=float(proba),
            model_version="v1",
        )
        db.add(pred_row)
        db.commit()
    except Exception as e:
        print("ERREUR DB:", e)
        # Roll back the failed transaction before closing, then propagate
        # with the original traceback (bare `raise`, not `raise e`).
        db.rollback()
        raise
    finally:
        db.close()