# Source: deploy-machine-learning/src/controllers/predict_controller.py
# marintosti12 - "final rework" (commit 92a1fd3)
from fastapi import APIRouter, Depends, HTTPException, Body, status
from config.db import get_db
from models.ml import MLModel
from models.ml_inputs import MLInput
from models.ml_output import MLOutput
import pandas as pd
from model_loader import load_model
from features import compute_features
from schemas.PredictItemResult import PredictItemResult
from schemas.PredictResponse import PredictResponse
from schemas.PredictRequest import PredictRequest
from sqlalchemy.orm import Session
# Router collecting the endpoints of this module under the "/predict" prefix.
router = APIRouter(
    prefix="/predict",
    tags=["Prédiction"],
)

# Maps the stringified predicted class to the label stored and returned to
# the caller; classes missing from this table are reported as-is.
LABELS: dict[str, str] = {
    "0": "reste_dans_l_entreprise",
    "1": "parti_de_l_entreprise",
}
@router.post(
    "/",
    response_model=PredictResponse,
    status_code=status.HTTP_200_OK,
    summary="Prédire l’attrition d’un employé",
    description=(
        "Calcule la probabilité d’attrition pour chaque entrée fournie.\n\n"
        "**Notes**\n"
        "- `model_name` doit référencer un modèle *actif* en base (`MLModel`).\n"
        "- Les données d’entrée sont persistées (`MLInput`) puis les sorties (`MLOutput`) sont enregistrées.\n"
        "- En cas d’erreur de features ou de prédiction, la requête retourne **400**.\n"
    ),
    responses={
        200: {"description": "Prédictions calculées avec succès."},
        400: {"description": "Erreur pendant la préparation des features ou la prédiction."},
        404: {"description": "Modèle introuvable ou inactif."},
        500: {"description": "Impossible de charger le modèle/erreur serveur."},
    },
)
def batch_predict(
    payload: PredictRequest = Body(
        ...,
        examples={
            "cas-minimal": {
                "summary": "Exemple minimal",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200,
                        }
                    ],
                },
            },
            "cas-complet": {
                "summary": "Exemple complet",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200,
                            "statut_marital": "Célibataire",
                            "departement": "Ventes",
                            "poste": "Commercial",
                            "nombre_experiences_precedentes": 2,
                            "nombre_heures_travailless": 40,
                            "annee_experience_totale": 5,
                            "annees_dans_l_entreprise": 2,
                            "annees_dans_le_poste_actuel": 1,
                            "nombre_participation_pee": 1,
                            "nb_formations_suivies": 3,
                            "nombre_employee_sous_responsabilite": 0,
                            "code_sondage": 7,
                            "distance_domicile_travail": 12,
                            "niveau_education": 3,
                            "domaine_etude": "Marketing",
                            "ayant_enfants": "Non",
                            "frequence_deplacement": "Rarement",
                            "annees_depuis_la_derniere_promotion": 0,
                            "annes_sous_responsable_actuel": 1,
                            "satisfaction_employee_environnement": 3,
                            "note_evaluation_precedente": 4,
                            "niveau_hierarchique_poste": 2,
                            "satisfaction_employee_nature_travail": 3,
                            "satisfaction_employee_equipe": 4,
                            "satisfaction_employee_equilibre_pro_perso": 3,
                            "eval_number": "E2",
                            "note_evaluation_actuelle": 4,
                            "heure_supplementaires": "Non",
                            "augementation_salaire_precedente": 11,
                        }
                    ],
                },
            },
        },
    ),
    db: Session = Depends(get_db),
) -> PredictResponse:
    """Predict attrition for each submitted input and record inputs/outputs.

    Steps, in order:
      1. Look up ``payload.model_name`` in ``MLModel`` and reject the request
         if no row exists or the row is flagged inactive.
      2. Persist every input as an ``MLInput`` row and commit.
      3. Load the model, build the feature matrix and call ``predict_proba``.
      4. For each row keep the most probable class, translate it through
         ``LABELS`` and stage an ``MLOutput`` row; commit once at the end.

    Args:
        payload: Request body carrying the model name and the list of inputs.
        db: SQLAlchemy session provided by the ``get_db`` dependency.

    Returns:
        A ``PredictResponse`` with the model name and one result per input.

    Raises:
        HTTPException: 404 when the model is unknown or inactive; 500 when
            the model cannot be loaded; 400 when feature computation,
            prediction or the recording of outputs fails.
    """
    model_row = (
        db.query(MLModel)
        .filter(MLModel.name == payload.model_name)
        .first()
    )
    # Reject unknown/inactive models BEFORE writing anything, so a request
    # that ends in 404 does not leave MLInput rows behind.
    if not model_row or getattr(model_row, "is_active", True) is False:
        raise HTTPException(status_code=404, detail="Modèle introuvable ou inactif")

    # Serialize each input once; the same records feed the ORM rows and the
    # DataFrame used for feature computation.
    records = [item.model_dump() for item in payload.inputs]
    input_rows = [MLInput(**record) for record in records]
    db.add_all(input_rows)
    # Committed on its own: the inputs stay recorded even if loading the
    # model or predicting fails afterwards.
    db.commit()

    try:
        model = load_model(payload.model_name)
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Chargement du modèle '{payload.model_name}' impossible: {exc}",
        ) from exc

    try:
        features = compute_features(pd.DataFrame(records))
        results: list[PredictItemResult] = []
        probas = model.predict_proba(features)
        # Estimators without ``classes_`` fall back to the column index.
        classes = getattr(model, "classes_", None)
        for idx, row_probas in enumerate(probas):
            best = int(row_probas.argmax())
            key = str(classes[best]) if classes is not None else str(best)
            label = LABELS.get(key, key)
            proba = float(row_probas[best])
            results.append(PredictItemResult(label=label, proba=proba))
            # NOTE(review): relies on the i-th probability row matching the
            # i-th input, and on ``id`` being populated by the commit above
            # - confirm that compute_features never drops or reorders rows.
            db.add(
                MLOutput(
                    input_id=input_rows[idx].id,
                    prediction=label,
                    prob=proba,
                )
            )
        db.commit()
    except Exception as exc:
        # Only the pending MLOutput rows are discarded; the inputs were
        # committed earlier.
        db.rollback()
        raise HTTPException(
            status_code=400,
            detail=f"Erreur pendant la prédiction: {exc}",
        ) from exc

    return PredictResponse(
        model_name=payload.model_name,
        results=results,
    )