Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException, Body, status | |
| from config.db import get_db | |
| from models.ml import MLModel | |
| from models.ml_inputs import MLInput | |
| from models.ml_output import MLOutput | |
| import pandas as pd | |
| from model_loader import load_model | |
| from features import compute_features | |
| from schemas.PredictItemResult import PredictItemResult | |
| from schemas.PredictResponse import PredictResponse | |
| from schemas.PredictRequest import PredictRequest | |
| from sqlalchemy.orm import Session | |
# Router grouping the prediction endpoints under the "/predict" URL prefix.
router = APIRouter(tags=["Prédiction"], prefix="/predict")
# Maps a model class identifier (as a string) to the label returned to clients:
# "0" -> employee stays, "1" -> employee left.
LABELS = dict(
    [
        ("0", "reste_dans_l_entreprise"),
        ("1", "parti_de_l_entreprise"),
    ]
)
def batch_predict(
    payload: PredictRequest = Body(
        ...,
        examples={
            "cas-minimal": {
                "summary": "Exemple minimal",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200
                        }
                    ],
                },
            },
            "cas-complet": {
                "summary": "Exemple complet",
                "value": {
                    "model_name": "best_model",
                    "inputs": [
                        {
                            "id_employee": 123,
                            "age": 35,
                            "genre": "Homme",
                            "revenu_mensuel": 4200,
                            "statut_marital": "Célibataire",
                            "departement": "Ventes",
                            "poste": "Commercial",
                            "nombre_experiences_precedentes": 2,
                            "nombre_heures_travailless": 40,
                            "annee_experience_totale": 5,
                            "annees_dans_l_entreprise": 2,
                            "annees_dans_le_poste_actuel": 1,
                            "nombre_participation_pee": 1,
                            "nb_formations_suivies": 3,
                            "nombre_employee_sous_responsabilite": 0,
                            "code_sondage": 7,
                            "distance_domicile_travail": 12,
                            "niveau_education": 3,
                            "domaine_etude": "Marketing",
                            "ayant_enfants": "Non",
                            "frequence_deplacement": "Rarement",
                            "annees_depuis_la_derniere_promotion": 0,
                            "annes_sous_responsable_actuel": 1,
                            "satisfaction_employee_environnement": 3,
                            "note_evaluation_precedente": 4,
                            "niveau_hierarchique_poste": 2,
                            "satisfaction_employee_nature_travail": 3,
                            "satisfaction_employee_equipe": 4,
                            "satisfaction_employee_equilibre_pro_perso": 3,
                            "eval_number": "E2",
                            "note_evaluation_actuelle": 4,
                            "heure_supplementaires": "Non",
                            "augementation_salaire_precedente": 11
                        }
                    ],
                },
            },
        },
    ),
    db: Session = Depends(get_db),
):
    """Run a batch prediction with a registered model and persist inputs/outputs.

    The model named ``payload.model_name`` must exist in the ``MLModel`` table
    and must not be flagged inactive. Each input row is stored as an
    ``MLInput``. Each prediction is stored as an ``MLOutput`` linked to the
    input it was computed from.

    Args:
        payload: Request body holding the model name and the input rows.
        db: SQLAlchemy session injected through ``Depends(get_db)``.

    Returns:
        PredictResponse with one ``PredictItemResult`` (label and probability
        of the most likely class) per input, in input order.

    Raises:
        HTTPException: 404 if the model is unknown or inactive; 500 if the
            model cannot be loaded; 400 if feature computation, prediction or
            saving the outputs fails (the output writes are rolled back).
    """
    # NOTE(review): no route decorator (e.g. ``@router.post(...)``) is visible
    # above this function; if it is really missing, the endpoint is never
    # registered on ``router`` -- TODO confirm.

    # Validate the model BEFORE writing anything, so requests targeting an
    # unknown/inactive model no longer leave orphan MLInput rows behind.
    row = (
        db.query(MLModel)
        .filter(MLModel.name == payload.model_name)
        .first()
    )
    # A model without an ``is_active`` attribute counts as active; only an
    # explicit False disables it.
    if not row or getattr(row, "is_active", True) is False:
        raise HTTPException(status_code=404, detail="Modèle introuvable ou inactif")

    try:
        m = load_model(payload.model_name)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Chargement du modèle '{payload.model_name}' impossible: {e}",
        ) from e

    # Serialize each input once; reused for the DB rows and the DataFrame.
    records = [x.model_dump() for x in payload.inputs]

    # Inputs are committed on their own so they stay stored even if the
    # prediction step below fails; the commit also makes their ids available.
    objs = [MLInput(**rec) for rec in records]
    db.add_all(objs)
    db.commit()

    try:
        df = pd.DataFrame(records)
        X = compute_features(df)
        probas = m.predict_proba(X)
        # Each probability row is linked to the input at the same position,
        # so the counts must match or outputs would be attached to the wrong
        # inputs.
        if len(probas) != len(objs):
            raise ValueError(
                f"{len(probas)} prédictions pour {len(objs)} entrées"
            )
        # Without ``classes_`` the column index itself is used as class key.
        classes = getattr(m, "classes_", None)
        results: list[PredictItemResult] = []
        for obj, p in zip(objs, probas):
            i = int(p.argmax())
            key = str(classes[i]) if classes is not None else str(i)
            label = LABELS.get(key, key)
            prob = float(p[i])
            results.append(PredictItemResult(label=label, proba=prob))
            db.add(MLOutput(input_id=obj.id, prediction=label, prob=prob))
        db.commit()
    except Exception as e:
        # Only the pending MLOutput rows are rolled back; inputs were
        # committed above.
        db.rollback()
        raise HTTPException(
            status_code=400, detail=f"Erreur pendant la prédiction: {e}"
        ) from e

    return PredictResponse(
        model_name=payload.model_name,
        results=results,
    )