Spaces:
Sleeping
Sleeping
| from datetime import datetime, timezone | |
| from uuid import uuid4 | |
| import math | |
| from fastapi import APIRouter, Depends, HTTPException, Body, status | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy import insert | |
| import pandas as pd | |
| import numpy as np | |
| from src.config.db import get_db | |
| from src.models.ml import MLModel | |
| from src.models.ml_inputs import MLInput | |
| from src.models.ml_output import MLOutput | |
| from src.model_loader import load_model | |
| from src.features import compute_features | |
| from src.schemas.PredictItemResult import PredictItemResult | |
| from src.schemas.PredictResponse import PredictResponse | |
| from src.schemas.PredictRequest import PredictRequest | |
| from time import perf_counter | |
| router = APIRouter(prefix="/predict", tags=["Solvabilité"]) | |
| LABELS = { | |
| "0": "non_solvable", | |
| "1": "solvable", | |
| } | |
| def series_to_jsonable(s: pd.Series) -> dict: | |
| cleaned: dict = {} | |
| for k, v in s.items(): | |
| if v is pd.NaT: | |
| cleaned[k] = None | |
| continue | |
| if isinstance(v, np.floating): | |
| if np.isnan(v): | |
| cleaned[k] = None | |
| else: | |
| cleaned[k] = float(v) | |
| continue | |
| if isinstance(v, np.integer): | |
| cleaned[k] = int(v) | |
| continue | |
| if isinstance(v, float): | |
| if math.isnan(v) or math.isinf(v): | |
| cleaned[k] = None | |
| else: | |
| cleaned[k] = v | |
| continue | |
| cleaned[k] = v | |
| return cleaned | |
| def batch_predict( | |
| payload: PredictRequest = Body(...), | |
| db: Session = Depends(get_db), | |
| ): | |
| start_time = perf_counter() | |
| request_id = str(uuid4()) | |
| now = datetime.now(timezone.utc) | |
| row = db.query(MLModel).filter(MLModel.name == payload.model_name).first() | |
| if not row or getattr(row, "is_active", True) is False: | |
| raise HTTPException(status_code=404, detail="Modèle introuvable ou inactif") | |
| try: | |
| model = load_model(payload.model_name) | |
| classes = getattr(model, "classes_", [0, 1]) | |
| classes = [int(c) for c in classes] | |
| except Exception as e: | |
| print(f"[ERROR] Chargement modèle: {e}") | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f"Chargement du modèle '{payload.model_name}' impossible: {e}", | |
| ) | |
| try: | |
| df_raw = pd.DataFrame([x.model_dump() for x in payload.inputs]) | |
| try: | |
| X = compute_features(df_raw.copy()) | |
| except Exception: | |
| X = df_raw.copy() | |
| X = X.reset_index(drop=True) | |
| df_raw = df_raw.reset_index(drop=True) | |
| except Exception as e: | |
| print(f"[ERROR] Préparation features: {e}") | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Erreur de préparation des features: {e}", | |
| ) | |
| try: | |
| input_dicts = [] | |
| for i in range(len(df_raw)): | |
| raw_dict = series_to_jsonable(df_raw.iloc[i]) | |
| feat_dict = series_to_jsonable(X.iloc[i]) | |
| input_dicts.append({ | |
| "created_at": now, | |
| "model_name": payload.model_name, | |
| "raw_data": raw_dict, | |
| "features": feat_dict, | |
| }) | |
| stmt = insert(MLInput).returning(MLInput.id) | |
| result = db.execute(stmt, input_dicts) | |
| input_ids = [row[0] for row in result.fetchall()] | |
| except Exception as e: | |
| print(f"[ERROR] Bulk insert MLInput: {e}") | |
| db.rollback() | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Erreur lors de l'enregistrement des entrées: {e}", | |
| ) | |
| try: | |
| probas = model.predict_proba(X) | |
| i_def = classes.index(1) | |
| i_sol = classes.index(0) | |
| THRESH = 0.5 | |
| results: list[PredictItemResult] = [] | |
| output_dicts = [] | |
| elapsed_ms = int((perf_counter() - start_time) * 1000) | |
| for i, p in enumerate(probas): | |
| p_def = float(p[i_def]) | |
| p_sol = float(p[i_sol]) | |
| if p_def >= THRESH: | |
| label = "non_solvable" | |
| proba_retour = p_def | |
| else: | |
| label = "solvable" | |
| proba_retour = p_sol | |
| results.append( | |
| PredictItemResult(label=label, proba=proba_retour) | |
| ) | |
| output_dicts.append({ | |
| "input_id": input_ids[i], | |
| "model_name": payload.model_name, | |
| "model_version": getattr(row, "version", None), | |
| "prediction": label, | |
| "prob": proba_retour, | |
| "proba_defaut": p_def, | |
| "proba_solvable": p_sol, | |
| "threshold": THRESH, | |
| "classes": classes, | |
| "latency_ms": elapsed_ms, | |
| "meta": { | |
| "request_id": request_id, | |
| "elapsed_ms": elapsed_ms, | |
| }, | |
| "created_at": now, | |
| }) | |
| db.execute(insert(MLOutput), output_dicts) | |
| db.commit() | |
| except Exception as e: | |
| print(f"[ERROR] Prédiction/bulk insert: {e}") | |
| db.rollback() | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Erreur pendant la prédiction: {e}", | |
| ) | |
| return PredictResponse( | |
| model_name=payload.model_name, | |
| results=results, | |
| ) |