# Hugging Face Spaces file view (Space status: Sleeping) · File size: 1,957 Bytes · e348dc0
from pathlib import Path
import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from .schemas import PredictRequest, EmployeeInput
# Decision cutoff: a predicted probability >= THRESHOLD is reported as class 1.
THRESHOLD = 0.33
# Local copy of the serialized pipeline; used in preference to the Hub when the file exists.
LOCAL_MODEL = Path("models/model.joblib")
# Hugging Face Hub repository and file name used as the fallback model source.
HF_REPO_ID = "veranoscience/attrition-model"
HF_FILENAME = "model.joblib"
def load_pipeline():
    """Load and return the serialized prediction pipeline.

    The file at ``LOCAL_MODEL`` is used when it exists; otherwise the file
    ``HF_FILENAME`` is fetched from the ``HF_REPO_ID`` repository with
    ``hf_hub_download`` and loaded from the path it returns.

    Returns:
        Whatever object ``joblib.load`` deserializes from the chosen file.
    """
    # NOTE(review): joblib.load is pickle-based and can execute arbitrary code
    # while loading, so both model sources must be trusted.
    model_path = (
        LOCAL_MODEL
        if LOCAL_MODEL.exists()
        else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    )
    return joblib.load(model_path)
# Load the pipeline once, at import time; the request handlers below reuse this object.
pipe = load_pipeline()
# Application instance that the route decorators below register against.
# The description is French for: "Prediction of resignation (attrition)
# probability via a scikit-learn pipeline."
app = FastAPI(
title="Attrition API",
description="Prédiction de probabilité de démission (attrition) via un pipeline scikit-learn.",
version="0.1.0",
)
@app.get("/health")
def health():
    """Report service status, the model source, and the decision threshold.

    The model source is derived by checking whether ``LOCAL_MODEL`` exists at
    the time of the request; it is not recorded when the pipeline is loaded.

    Returns:
        A dict with the keys ``status``, ``model_source`` and ``threshold``.
    """
    if LOCAL_MODEL.exists():
        model_source = str(LOCAL_MODEL)
    else:
        model_source = f"hub:{HF_REPO_ID}/{HF_FILENAME}"
    return {"status": "ok", "model_source": model_source, "threshold": THRESHOLD}
@app.post("/predict_proba")
def predict_proba(req: PredictRequest):
    """Score a batch of inputs and return probabilities and thresholded labels.

    Each item in ``req.inputs`` is converted to a dict with ``model_dump()``,
    the dicts are assembled into one DataFrame, and the module-level ``pipe``
    scores every row.

    Args:
        req: Request body whose ``inputs`` attribute is an iterable of items
            exposing ``model_dump()``.

    Returns:
        A dict with ``threshold`` (the cutoff used), ``probas`` (one float per
        row) and ``preds`` (1 where the probability is >= ``THRESHOLD``, else 0).

    Raises:
        HTTPException: With status 400 when building the frame or running the
            prediction raises any exception; the original exception is
            attached as the cause.
    """
    try:
        rows = [item.model_dump() for item in req.inputs]
        X = pd.DataFrame(rows)
        # Column 1 of predict_proba — presumably the attrition class;
        # TODO confirm against the pipeline's class order.
        probas = pipe.predict_proba(X)[:, 1]
        preds = (probas >= THRESHOLD).astype(int)
        return {
            "threshold": THRESHOLD,
            "probas": [float(p) for p in probas],
            "preds": preds.tolist(),
        }
    except Exception as e:
        # Chain explicitly so the underlying error stays visible in tracebacks.
        raise HTTPException(
            status_code=400, detail=f"Erreur de prédiction: {e}"
        ) from e
@app.post("/predict_one")
def predict_one(emp: EmployeeInput):
    """Score a single input and return its probability and thresholded label.

    Args:
        emp: Request body; its ``model_dump()`` output becomes the single row
            of the DataFrame passed to the module-level ``pipe``.

    Returns:
        A dict with ``threshold`` (the cutoff used), ``proba`` (float) and
        ``pred`` (1 when ``proba`` is >= ``THRESHOLD``, else 0).

    Raises:
        HTTPException: With status 400 when building the frame or running the
            prediction raises any exception; the original exception is
            attached as the cause.
    """
    try:
        X = pd.DataFrame([emp.model_dump()])
        # Column 1 of predict_proba for the only row — presumably the
        # attrition class; TODO confirm against the pipeline's class order.
        proba = float(pipe.predict_proba(X)[:, 1][0])
        pred = int(proba >= THRESHOLD)
        return {"threshold": THRESHOLD, "proba": proba, "pred": pred}
    except Exception as e:
        # Chain explicitly so the underlying error stays visible in tracebacks.
        raise HTTPException(
            status_code=400, detail=f"Erreur de prédiction: {e}"
        ) from e
from fastapi.responses import RedirectResponse
@app.get("/")
def root():
    """Redirect requests for the root URL to ``/docs``.

    Returns:
        A ``RedirectResponse`` pointing at ``/docs``.
    """
    return RedirectResponse(url="/docs")