File size: 7,370 Bytes
5ff57b0
 
bffe28b
572fd8e
bffe28b
 
 
 
 
5709d57
5ff57b0
bffe28b
 
572fd8e
5709d57
bffe28b
 
 
 
5ff57b0
bffe28b
 
5709d57
bffe28b
 
 
 
 
5ff57b0
bffe28b
 
 
 
 
 
 
04e665f
bffe28b
04e665f
bffe28b
04e665f
bffe28b
 
 
 
 
5ff57b0
bffe28b
 
 
3cabed9
bffe28b
 
 
5ff57b0
bffe28b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ff57b0
bffe28b
 
04e665f
bffe28b
 
5ff57b0
bffe28b
 
 
 
 
 
 
 
5ff57b0
bffe28b
5ff57b0
 
bffe28b
 
 
 
 
 
 
 
 
 
 
 
5ff57b0
 
bffe28b
 
5ff57b0
bffe28b
 
 
 
 
 
5ff57b0
bffe28b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ff57b0
 
bffe28b
 
 
 
 
 
 
 
 
 
5ff57b0
bffe28b
5ff57b0
bffe28b
 
5ff57b0
bffe28b
 
5ff57b0
bffe28b
 
5ff57b0
bffe28b
 
 
5ff57b0
bffe28b
 
 
 
 
 
 
 
5ff57b0
bffe28b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ff57b0
 
5709d57
 
 
 
 
5ff57b0
bffe28b
 
 
 
5709d57
bffe28b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
#!/usr/bin/env python3
"""
API FastAPI pour le modèle Employee Turnover.

Cette API expose le modèle de prédiction de départ des employés avec :
- Validation stricte des inputs via Pydantic
- Preprocessing automatique
- Health check pour monitoring
- Documentation OpenAPI/Swagger automatique
- Interface Gradio pour utilisation interactive
"""
import time
from contextlib import asynccontextmanager

import gradio as gr
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from src.auth import verify_api_key
from src.config import get_settings
from src.gradio_ui import create_gradio_interface
from src.logger import logger, log_model_load, log_request
from src.models import get_model_info, load_model
from src.preprocessing import preprocess_for_prediction
from src.rate_limit import limiter
from src.schemas import EmployeeInput, HealthCheck, PredictionOutput

# Charger la configuration
settings = get_settings()
API_VERSION = settings.API_VERSION


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Gestion du cycle de vie de l'application.

    Charge le modèle au démarrage et le garde en cache.
    """
    logger.info(
        "🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
    )

    start_time = time.time()
    try:
        # Pré-charger le modèle au démarrage
        model = load_model()
        duration_ms = (time.time() - start_time) * 1000

        model_type = type(model).__name__
        log_model_load(model_type, duration_ms, True)
        logger.info("✅ Modèle chargé avec succès")
    except Exception as e:
        duration_ms = (time.time() - start_time) * 1000
        log_model_load("Unknown", duration_ms, False)
        logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})

    yield  # L'application tourne

    logger.info("🛑 Arrêt de l'API")


# Créer l'application FastAPI
app = FastAPI(
    title="Employee Turnover Prediction API",
    description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
    version=API_VERSION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Ajouter rate limiting
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Configurer CORS (autoriser tous les domaines en dev)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Middleware de logging des requêtes
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """
    Middleware pour logger toutes les requêtes HTTP.
    """
    start_time = time.time()

    # Traiter la requête
    response = await call_next(request)

    # Calculer la durée
    duration_ms = (time.time() - start_time) * 1000

    # Logger
    log_request(
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        duration_ms=duration_ms,
        client_host=request.client.host if request.client else None,
    )

    return response


@app.get("/", tags=["Root"])
async def root():
    """
    Endpoint racine avec informations sur l'API.
    """
    return {
        "message": "Employee Turnover Prediction API",
        "version": API_VERSION,
        "docs": "/docs",
        "health": "/health",
        "predict": "/predict (POST)",
    }


@app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
async def health_check():
    """
    Health check endpoint pour monitoring.

    Vérifie que l'API est opérationnelle et que le modèle est chargé.

    Returns:
        HealthCheck: Status de l'API et du modèle.

    Raises:
        HTTPException: 503 si le modèle n'est pas disponible.
    """
    try:
        model_info = get_model_info()

        return HealthCheck(
            status="healthy",
            model_loaded=model_info.get("cached", False),
            model_type=model_info.get("model_type", "Unknown"),
            version=API_VERSION,
        )
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail={
                "status": "unhealthy",
                "error": "Model not available",
                "message": str(e),
            },
        )


@app.post(
    "/predict",
    response_model=PredictionOutput,
    tags=["Prediction"],
    dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
)
@limiter.limit("20/minute")
async def predict(request: Request, employee: EmployeeInput):
    """
    Endpoint de prédiction du turnover d'un employé.

    **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.

    Prend en entrée les données d'un employé, applique le preprocessing
    et retourne la prédiction avec les probabilités.

    Args:
        employee: Données de l'employé validées par Pydantic.

    Returns:
        PredictionOutput: Prédiction et probabilités.

    Raises:
        HTTPException: 401 si API key invalide ou manquante.
        HTTPException: 500 si erreur lors de la prédiction.

    Examples:
        ```bash
        # Avec authentification
        curl -X POST http://localhost:8000/predict \\
          -H "X-API-Key: your-secret-key" \\
          -H "Content-Type: application/json" \\
          -d '{...}'
        ```
    """
    try:
        # 1. Charger le modèle
        model = load_model()

        # 2. Préprocessing
        X = preprocess_for_prediction(employee)

        # 3. Prédiction
        prediction = int(model.predict(X)[0])

        # 4. Probabilités (si le modèle supporte predict_proba)
        try:
            probabilities = model.predict_proba(X)[0]
            prob_0 = float(probabilities[0])
            prob_1 = float(probabilities[1])
        except AttributeError:
            # Si le modèle ne supporte pas predict_proba
            prob_0 = 1.0 if prediction == 0 else 0.0
            prob_1 = 1.0 if prediction == 1 else 0.0

        # 5. Niveau de risque
        if prob_1 < 0.3:
            risk_level = "Low"
        elif prob_1 < 0.7:
            risk_level = "Medium"
        else:
            risk_level = "High"

        return PredictionOutput(
            prediction=prediction,
            probability_0=prob_0,
            probability_1=prob_1,
            risk_level=risk_level,
        )

    except Exception:
        logger.exception("Unexpected error during prediction")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Prediction failed",
                "message": "An unexpected error occurred. Please contact support.",
            },
        )


# Monter l'interface Gradio sur /ui
gradio_app = create_gradio_interface()
app = gr.mount_gradio_app(app, gradio_app, path="/ui")


if __name__ == "__main__":
    import uvicorn

    print("🚀 Lancement de l'API en mode développement...")
    print("📖 Documentation : http://localhost:8000/docs")
    print("🎨 Interface Gradio : http://localhost:8000/ui")

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )