Spaces:
Sleeping
Sleeping
File size: 5,996 Bytes
e7416bc 3b97d72 221bf32 3b97d72 4ba8e3d 3b97d72 4ba8e3d 3b97d72 e7416bc 3b97d72 e7416bc 4ba8e3d e7416bc 4ba8e3d 3b97d72 4ba8e3d 3b97d72 4ba8e3d 3b97d72 e7416bc 3b97d72 4ba8e3d 3b97d72 4ba8e3d a8a75f6 4ba8e3d 3b97d72 4ba8e3d 3b97d72 e7416bc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | from fastapi import APIRouter, Depends, HTTPException, Body, status
from config.db import get_db
from models.ml import MLModel
from models.ml_inputs import MLInput
from models.ml_output import MLOutput
import pandas as pd
from model_loader import load_model
from features import compute_features
from schemas.PredictItemResult import PredictItemResult
from schemas.PredictResponse import PredictResponse
from schemas.PredictRequest import PredictRequest
from sqlalchemy.orm import Session
# Router for the prediction endpoints: routes registered on it are served
# under the "/predict" URL prefix and tagged "Prédiction".
router = APIRouter(prefix="/predict", tags=["Prédiction"])
# Human-readable label for each predicted class, keyed by the class value
# rendered as a string. Keys missing from this table are used verbatim as the
# label by the lookup in `batch_predict`.
LABELS: dict[str, str] = {
    "0": "reste_dans_l_entreprise",  # class 0
    "1": "parti_de_l_entreprise",  # class 1
}
@router.post(
    "/",
    response_model=PredictResponse,
    status_code=status.HTTP_200_OK,
    summary="Prédire l’attrition d’un employé",
    description=(
        "Calcule la probabilité d’attrition pour chaque entrée fournie.\n\n"
        "**Notes**\n"
        "- `model_name` doit référencer un modèle *actif* en base (`MLModel`).\n"
        "- Les données d’entrée sont persistées (`MLInput`) puis les sorties (`MLOutput`) sont enregistrées.\n"
        "- En cas d’erreur de features ou de prédiction, la requête retourne **400**.\n"
    ),
    responses={
        200: {"description": "Prédictions calculées avec succès."},
        400: {"description": "Erreur pendant la préparation des features ou la prédiction."},
        404: {"description": "Modèle introuvable ou inactif."},
        500: {"description": "Impossible de charger le modèle/erreur serveur."},
    },
)
def batch_predict(
    payload: PredictRequest = Body(
        ...,
        examples={
            "cas-minimal": {
                "summary": "Exemple minimal",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200,
                        }
                    ],
                },
            },
            "cas-complet": {
                "summary": "Exemple complet",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200,
                            "statut_marital": "Célibataire",
                            "departement": "Ventes",
                            "poste": "Commercial",
                            "nombre_experiences_precedentes": 2,
                            "nombre_heures_travailless": 40,
                            "annee_experience_totale": 5,
                            "annees_dans_l_entreprise": 2,
                            "annees_dans_le_poste_actuel": 1,
                            "nombre_participation_pee": 1,
                            "nb_formations_suivies": 3,
                            "nombre_employee_sous_responsabilite": 0,
                            "code_sondage": 7,
                            "distance_domicile_travail": 12,
                            "niveau_education": 3,
                            "domaine_etude": "Marketing",
                            "ayant_enfants": "Non",
                            "frequence_deplacement": "Rarement",
                            "annees_depuis_la_derniere_promotion": 0,
                            "annes_sous_responsable_actuel": 1,
                            "satisfaction_employee_environnement": 3,
                            "note_evaluation_precedente": 4,
                            "niveau_hierarchique_poste": 2,
                            "satisfaction_employee_nature_travail": 3,
                            "satisfaction_employee_equipe": 4,
                            "satisfaction_employee_equilibre_pro_perso": 3,
                            "eval_number": "E2",
                            "note_evaluation_actuelle": 4,
                            "heure_supplementaires": "Non",
                            "augementation_salaire_precedente": 11,
                        }
                    ],
                },
            },
        },
    ),
    db: Session = Depends(get_db),
) -> PredictResponse:
    """Predict attrition for every input of the request and store the results.

    Steps, in order:

    1. Look up the ``MLModel`` row named ``payload.model_name`` and reject the
       request when it is missing or flagged inactive.
    2. Persist one ``MLInput`` row per request input.
    3. Load the model, compute the features and call ``predict_proba``.
    4. Persist one ``MLOutput`` row per prediction, linked to its input row.

    Args:
        payload: Request body carrying ``model_name`` and the list of
            ``inputs`` to score.
        db: SQLAlchemy session provided by the ``get_db`` dependency.

    Returns:
        A ``PredictResponse`` with the model name and, for each input, the
        most probable label and its probability.

    Raises:
        HTTPException: 404 when the model is unknown or inactive, 500 when the
            model cannot be loaded, 400 when feature computation, prediction
            or the storage of the outputs fails.
    """
    row = (
        db.query(MLModel)
        .filter(MLModel.name == payload.model_name)
        .first()
    )
    # Validate the model BEFORE writing anything: a request that targets an
    # unknown or inactive model must not leave input rows behind.
    # A row without an `is_active` attribute is treated as active.
    if not row or getattr(row, "is_active", True) is False:
        raise HTTPException(status_code=404, detail="Modèle introuvable ou inactif")

    # Dump each input once; the same dicts feed both the ORM rows and the
    # DataFrame used for feature computation.
    rows = [x.model_dump() for x in payload.inputs]
    objs = [MLInput(**r) for r in rows]
    db.add_all(objs)
    # Committed here so the inputs are kept even when a later step fails.
    db.commit()

    try:
        m = load_model(payload.model_name)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Chargement du modèle '{payload.model_name}' impossible: {e}",
        ) from e

    try:
        df = pd.DataFrame(rows)
        X = compute_features(df)
        results: list[PredictItemResult] = []
        probas = m.predict_proba(X)
        # Models that expose no `classes_` fall back to the column index.
        classes = getattr(m, "classes_", None)
        for idx, p in enumerate(probas):
            i = int(p.argmax())  # column of the most probable class
            key = str(classes[i]) if classes is not None else str(i)
            label = LABELS.get(key, key)  # unknown classes keep their raw key
            proba = float(p[i])
            results.append(PredictItemResult(label=label, proba=proba))
            db.add(
                MLOutput(
                    # presumably populated by the commit above — confirm
                    # against the MLInput model definition
                    input_id=objs[idx].id,
                    prediction=label,
                    prob=proba,
                )
            )
        db.commit()
    except Exception as e:
        # Discard any pending MLOutput rows; the committed inputs remain.
        db.rollback()
        raise HTTPException(
            status_code=400,
            detail=f"Erreur pendant la prédiction: {e}",
        ) from e

    return PredictResponse(
        model_name=payload.model_name,
        results=results,
    )