File size: 5,996 Bytes
e7416bc
3b97d72
221bf32
3b97d72
 
 
4ba8e3d
3b97d72
 
 
 
 
 
 
4ba8e3d
3b97d72
e7416bc
3b97d72
 
 
 
 
 
e7416bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ba8e3d
 
e7416bc
 
4ba8e3d
3b97d72
4ba8e3d
 
 
3b97d72
4ba8e3d
 
3b97d72
 
 
 
e7416bc
 
 
 
3b97d72
 
 
 
 
 
 
 
 
4ba8e3d
 
3b97d72
 
 
4ba8e3d
 
 
 
 
 
a8a75f6
4ba8e3d
 
 
 
 
 
3b97d72
 
4ba8e3d
3b97d72
 
 
 
 
e7416bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from fastapi import APIRouter, Depends, HTTPException, Body, status

from config.db import get_db
from models.ml import MLModel 

from models.ml_inputs import MLInput
from models.ml_output import MLOutput

import pandas as pd
from model_loader import load_model           
from features import compute_features
from schemas.PredictItemResult import PredictItemResult
from schemas.PredictResponse import PredictResponse
from schemas.PredictRequest import PredictRequest
from sqlalchemy.orm import Session

# Router for the prediction endpoints; every route declared on it is served
# under the "/predict" prefix and tagged "Prédiction" in the OpenAPI schema.
router = APIRouter(prefix="/predict", tags=["Prédiction"])

# Human-readable label for each predicted class, keyed by the class value
# rendered as a string. batch_predict looks labels up with
# LABELS.get(key, key), so an unknown class falls back to its raw string key.
LABELS: dict[str, str] = {
    "0": "reste_dans_l_entreprise",
    "1": "parti_de_l_entreprise",
}

@router.post(
    "/",
    response_model=PredictResponse,
    status_code=status.HTTP_200_OK,
    summary="Prédire l’attrition d’un employé",
    description=(
        "Calcule la probabilité d’attrition pour chaque entrée fournie.\n\n"
        "**Notes**\n"
        "- `model_name` doit référencer un modèle *actif* en base (`MLModel`).\n"
        "- Les données d’entrée sont persistées (`MLInput`) puis les sorties (`MLOutput`) sont enregistrées.\n"
        "- En cas d’erreur de features ou de prédiction, la requête retourne **400**.\n"
    ),
    responses={
        200: {"description": "Prédictions calculées avec succès."},
        400: {"description": "Erreur pendant la préparation des features ou la prédiction."},
        404: {"description": "Modèle introuvable ou inactif."},
        500: {"description": "Impossible de charger le modèle/erreur serveur."},
    },
)
def batch_predict(
    payload: PredictRequest = Body(
        ...,
        # NOTE(review): named examples are passed as a dict through `examples`;
        # recent FastAPI releases expect this shape under `openapi_examples`.
        # Confirm against the installed FastAPI version before changing it.
        examples={
            "cas-minimal": {
                "summary": "Exemple minimal",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200
                        }
                    ],
                },
            },
            "cas-complet": {
                "summary": "Exemple complet",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200,
                            "statut_marital": "Célibataire",
                            "departement": "Ventes",
                            "poste": "Commercial",
                            "nombre_experiences_precedentes": 2,
                            "nombre_heures_travailless": 40,
                            "annee_experience_totale": 5,
                            "annees_dans_l_entreprise": 2,
                            "annees_dans_le_poste_actuel": 1,
                            "nombre_participation_pee": 1,
                            "nb_formations_suivies": 3,
                            "nombre_employee_sous_responsabilite": 0,
                            "code_sondage": 7,
                            "distance_domicile_travail": 12,
                            "niveau_education": 3,
                            "domaine_etude": "Marketing",
                            "ayant_enfants": "Non",
                            "frequence_deplacement": "Rarement",
                            "annees_depuis_la_derniere_promotion": 0,
                            "annes_sous_responsable_actuel": 1,
                            "satisfaction_employee_environnement": 3,
                            "note_evaluation_precedente": 4,
                            "niveau_hierarchique_poste": 2,
                            "satisfaction_employee_nature_travail": 3,
                            "satisfaction_employee_equipe": 4,
                            "satisfaction_employee_equilibre_pro_perso": 3,
                            "eval_number": "E2",
                            "note_evaluation_actuelle": 4,
                            "heure_supplementaires": "Non",
                            "augementation_salaire_precedente": 11
                        }
                    ],
                },
            },
        },
    ),
    db: Session = Depends(get_db),
) -> PredictResponse:
    """Predict a label and probability for every input in the request.

    Steps, in order:
      1. Look up ``payload.model_name`` in ``MLModel`` and reject the request
         if the row is missing or flagged inactive.
      2. Persist every input as an ``MLInput`` row (committed on its own, so
         the inputs are kept even if a later step fails).
      3. Load the model, compute the features and call ``predict_proba``.
      4. Store one ``MLOutput`` per input and return the predictions.

    Args:
        payload: Request body carrying the model name and the list of inputs.
        db: SQLAlchemy session provided by the ``get_db`` dependency.

    Returns:
        A ``PredictResponse`` with one ``PredictItemResult`` per input, in the
        same order as ``payload.inputs``.

    Raises:
        HTTPException: 404 when the model is unknown or inactive; 500 when the
            model cannot be loaded; 400 when feature computation, prediction
            or the persistence of the outputs fails.
    """
    model_row = (
        db.query(MLModel)
        .filter(MLModel.name == payload.model_name)
        .first()
    )

    # Validate the model BEFORE writing anything: a request that ends in 404
    # must not leave orphan MLInput rows behind.
    # Rows without an `is_active` attribute are treated as active.
    if not model_row or getattr(model_row, "is_active", True) is False:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Modèle introuvable ou inactif",
        )

    # Dump each input once; the same dicts feed both the ORM rows and the
    # DataFrame used for feature computation.
    records = [item.model_dump() for item in payload.inputs]
    objs = [MLInput(**record) for record in records]
    db.add_all(objs)
    db.commit()

    try:
        model = load_model(payload.model_name)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Chargement du modèle '{payload.model_name}' impossible: {exc}",
        ) from exc

    try:
        features = compute_features(pd.DataFrame(records))
        probas = model.predict_proba(features)

        # Outputs are paired to inputs by position, so a row-count mismatch
        # would attach predictions to the wrong input (or drop some silently).
        if len(probas) != len(objs):
            raise ValueError(
                f"{len(probas)} prédictions pour {len(objs)} entrées"
            )

        # Models that do not expose `classes_` fall back to the column index.
        classes = getattr(model, "classes_", None)

        results: list[PredictItemResult] = []
        for input_obj, row_probas in zip(objs, probas):
            best = int(row_probas.argmax())
            key = str(classes[best]) if classes is not None else str(best)
            label = LABELS.get(key, key)
            proba = float(row_probas[best])

            results.append(PredictItemResult(label=label, proba=proba))
            db.add(
                MLOutput(
                    input_id=input_obj.id,
                    prediction=label,
                    prob=proba,
                )
            )

        db.commit()

    except Exception as exc:
        # Discard any pending MLOutput rows; the inputs committed above stay.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Erreur pendant la prédiction: {exc}",
        ) from exc

    return PredictResponse(
        model_name=payload.model_name,
        results=results,
    )