File size: 1,957 Bytes
e348dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from pathlib import Path

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from huggingface_hub import hf_hub_download

from .schemas import PredictRequest, EmployeeInput

# Decision threshold on the positive-class probability: a score at or above
# this value is reported as a predicted departure (pred == 1).
THRESHOLD = 0.33

# Where the fitted pipeline is looked up. The local artefact wins when it
# exists; otherwise the same filename is fetched from the Hugging Face Hub repo.
LOCAL_MODEL = Path("models") / "model.joblib"
HF_REPO_ID = "veranoscience/attrition-model"
HF_FILENAME = "model.joblib"

def _load_pipeline_with_source():
    """Load the fitted pipeline and describe where it was loaded from.

    Returns:
        A ``(pipeline, source)`` tuple. ``source`` is ``str(LOCAL_MODEL)`` when
        the local artefact exists, otherwise ``"hub:<HF_REPO_ID>/<HF_FILENAME>"``.
    """
    if LOCAL_MODEL.exists():
        return joblib.load(LOCAL_MODEL), str(LOCAL_MODEL)
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only point HF_REPO_ID at a repository you trust.
    downloaded = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
    return joblib.load(downloaded), f"hub:{HF_REPO_ID}/{HF_FILENAME}"


def load_pipeline():
    """Return the fitted pipeline, preferring LOCAL_MODEL over the Hub copy."""
    pipeline, _source = _load_pipeline_with_source()
    return pipeline


# Loaded once at import time. MODEL_SOURCE is captured together with the
# pipeline so /health reports the model actually being served, even if the
# local file is added or removed after startup.
pipe, MODEL_SOURCE = _load_pipeline_with_source()

app = FastAPI(
    title="Attrition API",
    description="Prédiction de probabilité de démission (attrition) via un pipeline scikit-learn.",
    version="0.1.0",
)


@app.get("/health")
def health():
    """Report liveness, the origin of the loaded model, and the decision threshold."""
    return {"status": "ok", "model_source": MODEL_SOURCE, "threshold": THRESHOLD}

# Exceptions treated as problems with the submitted data (HTTP 400). Anything
# else is a server-side fault and is left to propagate as HTTP 500.
_CLIENT_ERRORS = (ValueError, TypeError, KeyError)


def _score(rows):
    """Score feature rows with the loaded pipeline.

    Args:
        rows: One dict of features per employee, as produced by
            ``model_dump()`` on the request schema.

    Returns:
        ``(probas, preds)``: a list of positive-class probabilities as Python
        floats, and a list of 0/1 ints (1 when the probability is >= THRESHOLD).
    """
    if not rows:
        # pd.DataFrame([]) has no rows or columns, which scikit-learn
        # estimators generally reject; an empty batch has no predictions.
        return [], []
    probas = pipe.predict_proba(pd.DataFrame(rows))[:, 1]
    preds = (probas >= THRESHOLD).astype(int)
    return [float(p) for p in probas], preds.tolist()


@app.post("/predict_proba")
def predict_proba(req: PredictRequest):
    """Score a batch of employees and return probabilities and 0/1 predictions.

    Raises:
        HTTPException: 400 when the pipeline rejects the submitted data.
    """
    rows = [item.model_dump() for item in req.inputs]
    try:
        probas, preds = _score(rows)
    except _CLIENT_ERRORS as e:
        raise HTTPException(status_code=400, detail=f"Erreur de prédiction: {e}") from e
    return {
        "threshold": THRESHOLD,
        "probas": probas,
        "preds": preds,
    }


@app.post("/predict_one")
def predict_one(emp: EmployeeInput):
    """Score a single employee and return its probability and 0/1 prediction.

    Raises:
        HTTPException: 400 when the pipeline rejects the submitted data.
    """
    try:
        probas, preds = _score([emp.model_dump()])
    except _CLIENT_ERRORS as e:
        raise HTTPException(status_code=400, detail=f"Erreur de prédiction: {e}") from e
    return {"threshold": THRESHOLD, "proba": probas[0], "pred": preds[0]}
    
@app.get("/")
def root():
    """Redirect the bare root URL to the interactive API docs at ``/docs``."""
    return RedirectResponse(url="/docs")