Spaces:
Sleeping
Sleeping
File size: 5,409 Bytes
e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 e348dc0 9f0dbb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from __future__ import annotations
import os
import json
import uuid
from pathlib import Path
from typing import Any, Dict, List
import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from huggingface_hub import hf_hub_download
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from .schemas import PredictRequest, EmployeeInput # pydantic v2
# -----------------------
# Config & ML model loading
# -----------------------
# Decision threshold on the positive-class probability: pred = 1 iff
# proba >= THRESHOLD (chosen upstream, not derived here).
THRESHOLD = 0.33
# Local model artifact; when absent the file is downloaded from the Hub
# repo below (see load_pipeline).
LOCAL_MODEL = Path("models/model.joblib")
HF_REPO_ID = "veranoscience/attrition-model"
HF_FILENAME = "model.joblib"
def load_pipeline():
    """Load the scikit-learn pipeline (preprocessing + RF).

    Prefers the local ``models/model.joblib`` artifact when it exists;
    otherwise downloads (and caches) the file from the Hugging Face Hub.
    """
    model_path = (
        LOCAL_MODEL
        if LOCAL_MODEL.exists()
        else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    )
    return joblib.load(model_path)

# Loaded once at import time; every request reuses this single pipeline.
pipe = load_pipeline()
# -----------------------
# Database (optional)
# -----------------------
# When DATABASE_URL is unset or blank, db_engine stays None and prediction
# logging becomes a silent no-op (see log_prediction).
DATABASE_URL = os.getenv("DATABASE_URL", "").strip()
db_engine = create_engine(DATABASE_URL, future=True) if DATABASE_URL else None
def log_prediction(source: str, inputs: dict | list, proba: float | None, pred: int | None,
                   threshold: float, model_version: str = "rf_reg@hub",
                   status: str = "ok", error_message: str | None = None) -> None:
    """Best-effort insert of one row into ml.predictions_log.

    Does nothing when no DB engine is configured, and never raises:
    a logging failure must not break the API response (it is only
    printed to the console for debugging).
    """
    if db_engine is None:
        # No database configured → silent no-op.
        return

    insert_stmt = text("""
                INSERT INTO ml.predictions_log
                (request_id, source, input_payload, proba, pred, threshold, model_version, status, error_message)
                VALUES
                (:rid, :src, CAST(:payload AS JSONB), :proba, :pred, :thr, :version, :status, :err)
            """)
    try:
        # JSON serialization stays inside the try: an unserializable payload
        # must be swallowed like any other logging failure.
        params = {
            "rid": str(uuid.uuid4()),
            "src": source,
            "payload": json.dumps(inputs, ensure_ascii=False),
            "proba": proba,
            "pred": pred,
            "thr": threshold,
            "version": model_version,
            "status": status,
            "err": error_message,
        }
        with db_engine.begin() as conn:
            conn.execute(insert_stmt, params)
    except Exception as e:
        # Never let a DB issue break the API; surface it on the console only.
        print(f"[WARN] Échec log_prediction: {e}")
# -----------------------
# FastAPI application
# -----------------------
# Interactive docs are served at /docs (see root() redirect below).
app = FastAPI(
    title="Attrition API",
    description="Prédiction de probabilité de démission (attrition) via un pipeline scikit-learn.",
    version="0.2.0",
)
@app.get("/")
def root():
    """Redirect the bare root URL to the interactive OpenAPI docs."""
    return RedirectResponse(url="/docs")
@app.get("/health")
def health():
    """Liveness/diagnostic endpoint.

    Reports where the model was loaded from (local path or Hub repo),
    the decision threshold, and whether a DB engine is configured.
    """
    if LOCAL_MODEL.exists():
        model_source = str(LOCAL_MODEL)
    else:
        model_source = f"hub:{HF_REPO_ID}/{HF_FILENAME}"
    payload = {"status": "ok", "model_source": model_source}
    payload["threshold"] = THRESHOLD
    payload["db_connected"] = db_engine is not None
    return payload
@app.post("/predict_one")
def predict_one(emp: EmployeeInput):
    """Score one employee.

    Returns ``{"threshold", "proba", "pred"}`` where ``pred`` is 1 iff
    ``proba >= THRESHOLD``. Any prediction failure is logged (best
    effort) and surfaced as an HTTP 400.
    """
    record = emp.model_dump()
    try:
        frame = pd.DataFrame([record])
        proba = float(pipe.predict_proba(frame)[:, 1][0])
        pred = int(proba >= THRESHOLD)
    except Exception as e:
        # Best-effort DB trace of the failure, then report it to the client.
        log_prediction(
            source="api",
            inputs=record,
            proba=None, pred=None, threshold=THRESHOLD,
            status="error", error_message=str(e),
        )
        raise HTTPException(status_code=400, detail=f"Erreur de prédiction: {e}")
    # Success path: log_prediction never raises, so this is safe outside try.
    log_prediction(
        source="api",
        inputs=record,
        proba=proba,
        pred=pred,
        threshold=THRESHOLD,
        status="ok",
    )
    return {"threshold": THRESHOLD, "proba": proba, "pred": pred}
@app.post("/predict_proba")
def predict_proba(req: PredictRequest):
    """Batch scoring endpoint.

    Receives ``{"inputs": [EmployeeInput, ...]}`` and returns
    ``{"threshold", "probas": [...], "preds": [...]}`` aligned with the
    input order. Prediction failures are logged (best effort) and
    surfaced as HTTP 400.
    """
    rows = [item.model_dump() for item in req.inputs]
    # Fix: an empty batch used to reach pipe.predict_proba on an empty
    # DataFrame, which raises and came back as a confusing 400. Answer
    # directly with empty (shape-compatible) results instead.
    if not rows:
        return {"threshold": THRESHOLD, "probas": [], "preds": []}
    try:
        X = pd.DataFrame(rows)
        probas = pipe.predict_proba(X)[:, 1]
        preds = (probas >= THRESHOLD).astype(int)
        # Best-effort success log with a one-row batch summary:
        # mean probability and majority prediction.
        log_prediction(
            source="api",
            inputs=rows,
            proba=float(probas.mean()),
            pred=int(preds.mean() >= 0.5),
            threshold=THRESHOLD
        )
        return {
            "threshold": THRESHOLD,
            "probas": [float(p) for p in probas],
            "preds": preds.tolist(),
        }
    except Exception as e:
        # rows is already serialized above — no need to re-dump the models.
        log_prediction(
            source="api",
            inputs=rows,
            proba=None, pred=None, threshold=THRESHOLD,
            status="error", error_message=str(e)
        )
        raise HTTPException(status_code=400, detail=f"Erreur de prédiction: {e}")
|