File size: 2,506 Bytes
73401e0
 
 
 
9b860d1
3a105f9
23d4613
 
 
3a105f9
b6fad86
73401e0
9b860d1
 
 
 
73401e0
 
3a105f9
9b860d1
 
 
8344437
9b860d1
8344437
 
 
9b860d1
 
 
8344437
 
 
9b860d1
 
 
 
73401e0
9b860d1
 
 
73401e0
 
 
 
 
 
 
2df3397
23d4613
2df3397
 
 
 
 
 
 
23d4613
2df3397
 
 
 
 
 
 
 
23d4613
2df3397
 
23d4613
 
73401e0
9b860d1
23d4613
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import joblib
import pandas as pd
from App.schemas import EmployeeFeatures
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
from sqlalchemy.orm import Session
from App.database import SessionLocal
from App.model import Input, Predictions

MODEL_REPO = "Diaure/xgb_model"

# Variables chargées
model = None
classes_mapping = None
Features = list(EmployeeFeatures.model_fields.keys())



# Chargement des fichiers: fonction pour charger le modèle, le mapping afin de permettre à l'API de démarrer m^me si les éléments ne sont pas présents
def files_load():
    global model, classes_mapping

    if model is None:
        chemin_model = Path(hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib"))
        # if not chemin_model.exists():
        #     raise RuntimeError("Eléments du modèle introuvable.")
        model =joblib.load(chemin_model)

    if classes_mapping is None:
        chemin_mapping = Path(hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json"))
        # if not chemin_mapping.exists():
        #     raise RuntimeError("Mapping des classes introuvable.")
        with open(chemin_mapping) as f:
            classes_mapping = json.load(f)

# Fonction prédiction
def predict_employee(data: dict):
    files_load()

    df = pd.DataFrame([data])[Features]

    print("Colonnes API :", df.columns.tolist()) 
    print("Nombre colonnes API :", len(df.columns))
    
    pred = model.predict(df)[0]
    proba = model.predict_proba(df)[0][1]

    db: Session = SessionLocal() if SessionLocal is not None else None

    if db is not None:
        try:
            # enregistrer les inputs: à chaque appel de POST/predict, on stocke d'abord les entrées de l'utilisateur
            input_row = Input(**data) 
            db.add(input_row) 
            db.commit() 
            db.refresh(input_row)

            # puis on récupère les ids générés automatiquement et enregistre les prédictions liés aux ids
            pred_row = Predictions(input_id = input_row.id, prediction_label = classes_mapping[str(pred)], prediction_proba = float(proba), model_version = "v1") 
            db.add(pred_row) 
            db.commit()
        
        except Exception as e: 
            print("🔥 ERREUR DB :", e) 
            raise e

        finally:
            db.close()

    # puis on renvoie la réponse API
    return {
        "Prediction": classes_mapping[str(pred)],
        "Probabilite_depart": float(proba)}