import joblib import pandas as pd from App.schemas import EmployeeFeatures import json from pathlib import Path from huggingface_hub import hf_hub_download # Import SQLAlchemy uniquement si disponible try: from sqlalchemy.orm import Session from App.database import SessionLocal from App.model import Input, Predictions SQLALCHEMY_AVAILABLE = True except ModuleNotFoundError: SQLALCHEMY_AVAILABLE = False essionLocal = None Input = None Predictions = None MODEL_REPO = "Diaure/xgb_model" # Variables chargées model = None classes_mapping = None Features = list(EmployeeFeatures.model_fields.keys()) # Chargement des fichiers: fonction pour charger le modèle, le mapping afin de permettre à l'API de démarrer m^me si les éléments ne sont pas présents def files_load(): global model, classes_mapping if model is None: chemin_model = Path(hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib")) # if not chemin_model.exists(): # raise RuntimeError("Eléments du modèle introuvable.") model =joblib.load(chemin_model) if classes_mapping is None: chemin_mapping = Path(hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json")) # if not chemin_mapping.exists(): # raise RuntimeError("Mapping des classes introuvable.") with open(chemin_mapping) as f: classes_mapping = json.load(f) # Fonction prédiction def predict_employee(data: dict): files_load() df = pd.DataFrame([data])[Features] print("Colonnes API :", df.columns.tolist()) print("Nombre colonnes API :", len(df.columns)) pred = model.predict(df)[0] proba = model.predict_proba(df)[0][1] # DB désactivée si SQLAlchemy indisponible ou SessionLocal = None if SQLALCHEMY_AVAILABLE and SessionLocal is not None: db: Session = SessionLocal() else: db = None if db is not None: try: # enregistrer les inputs: à chaque appel de POST/predict, on stocke d'abord les entrées de l'utilisateur input_row = Input(**data) db.add(input_row) db.commit() db.refresh(input_row) # puis on récupère les ids générés automatiquement et enregistre les prédictions liés aux ids pred_row = Predictions(input_id = input_row.id, prediction_label = classes_mapping[str(pred)], prediction_proba = float(proba), model_version = "v1") db.add(pred_row) db.commit() except Exception as e: print("ERREUR DB:", e) raise e finally: db.close() # puis on renvoie la réponse API return { "Prediction": classes_mapping[str(pred)], "Probabilite_depart": float(proba)}