File size: 2,853 Bytes
73401e0
 
 
 
9b860d1
3a105f9
f489a18
 
 
 
 
15524ff
 
f489a18
 
 
15524ff
 
 
f489a18
3a105f9
b6fad86
73401e0
9b860d1
 
 
 
73401e0
 
9b860d1
 
 
8344437
9b860d1
8344437
 
 
9b860d1
 
 
8344437
 
 
9b860d1
 
 
 
73401e0
9b860d1
 
 
73401e0
 
 
 
 
 
 
f489a18
 
 
 
23d4613
2df3397
 
 
 
 
 
 
23d4613
2df3397
 
 
 
 
 
f489a18
2df3397
23d4613
2df3397
 
23d4613
 
73401e0
9b860d1
23d4613
426e9e4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import joblib
import pandas as pd
from App.schemas import EmployeeFeatures
import json
from pathlib import Path
from huggingface_hub import hf_hub_download


# Import SQLAlchemy and the project's DB layer only when they are available,
# so this module can still be imported when the database stack is missing.
try:
    from sqlalchemy.orm import Session
    from App.database import SessionLocal
    from App.model import Input, Predictions
    SQLALCHEMY_AVAILABLE = True
except ModuleNotFoundError:
    SQLALCHEMY_AVAILABLE = False
    # Define every optional name so later references never raise NameError.
    Session = None
    SessionLocal = None
    Input = None
    Predictions = None
    

# Repository id passed to hf_hub_download when fetching the artifacts.
MODEL_REPO = "Diaure/xgb_model"

# Field names of the EmployeeFeatures schema, used to select the model's input columns.
Features = list(EmployeeFeatures.model_fields)

# Lazily loaded state: both remain None until files_load() populates them.
classes_mapping = None
model = None


def files_load():
    """Load the model and the class mapping on first use.

    Loading is deferred to the first call instead of happening at import
    time, so the API can start even when the artifacts are not available yet.
    Each artifact is fetched with ``hf_hub_download`` from ``MODEL_REPO`` and
    stored in the module-level ``model`` / ``classes_mapping`` globals. An
    artifact that is already loaded is not fetched again.
    """
    global model, classes_mapping

    if model is None:
        chemin_model = Path(hf_hub_download(repo_id=MODEL_REPO, filename="modele_final_xgb.joblib"))
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code; confirm MODEL_REPO is a trusted source.
        model = joblib.load(chemin_model)

    if classes_mapping is None:
        chemin_mapping = Path(hf_hub_download(repo_id=MODEL_REPO, filename="mapping_classes.json"))
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(chemin_mapping, encoding="utf-8") as f:
            classes_mapping = json.load(f)

def _save_prediction(data: dict, label, proba: float) -> None:
    """Store the request inputs and their prediction in a single transaction.

    Either both rows are committed or neither is, so a failure while writing
    the prediction cannot leave an input row without its prediction.

    Raises:
        Exception: any error raised while writing is printed and re-raised.
    """
    db = SessionLocal()
    try:
        # Store the user's inputs first.
        input_row = Input(**data)
        db.add(input_row)
        # flush() emits the INSERT so the auto-generated id is available
        # without committing yet.
        db.flush()

        # Then store the prediction linked to that input id.
        pred_row = Predictions(
            input_id=input_row.id,
            prediction_label=label,
            prediction_proba=proba,
            model_version="v1",
        )
        db.add(pred_row)
        db.commit()

    except Exception as e:
        print("ERREUR DB:", e)
        db.rollback()
        raise

    finally:
        db.close()


def predict_employee(data: dict) -> dict:
    """Run the model on one employee record and return the prediction.

    Args:
        data: mapping that must contain every name listed in ``Features``;
            it is also passed as keyword arguments to ``Input`` when the
            database layer is enabled.

    Returns:
        A dict with ``"Prediction"`` (the ``classes_mapping`` entry for the
        predicted class) and ``"Probabilite_depart"`` (the probability at
        index 1 of ``predict_proba``, as a float).

    Raises:
        KeyError: if a feature is missing from ``data`` or the predicted
            class has no entry in ``classes_mapping``.
        Exception: any database error is re-raised after being printed.
    """
    files_load()

    df = pd.DataFrame([data])[Features]

    print("Colonnes API :", df.columns.tolist())
    print("Nombre colonnes API :", len(df.columns))

    pred = model.predict(df)[0]
    proba = float(model.predict_proba(df)[0][1])
    # Resolve the label once, before any DB work, so a missing mapping entry
    # is not reported as a database error.
    label = classes_mapping[str(pred)]

    # Persistence is skipped when SQLAlchemy is unavailable or SessionLocal is None.
    if SQLALCHEMY_AVAILABLE and SessionLocal is not None:
        _save_prediction(data, label, proba)

    # Finally return the API response.
    return {
        "Prediction": label,
        "Probabilite_depart": proba}