oc_p5-dev / app.py
ASI-Engineer's picture
Upload folder using huggingface_hub
5709d57 verified
raw
history blame
7.37 kB
#!/usr/bin/env python3
"""
API FastAPI pour le modèle Employee Turnover.
Cette API expose le modèle de prédiction de départ des employés avec :
- Validation stricte des inputs via Pydantic
- Preprocessing automatique
- Health check pour monitoring
- Documentation OpenAPI/Swagger automatique
- Interface Gradio pour utilisation interactive
"""
import time
from contextlib import asynccontextmanager
import gradio as gr
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from src.auth import verify_api_key
from src.config import get_settings
from src.gradio_ui import create_gradio_interface
from src.logger import logger, log_model_load, log_request
from src.models import get_model_info, load_model
from src.preprocessing import preprocess_for_prediction
from src.rate_limit import limiter
from src.schemas import EmployeeInput, HealthCheck, PredictionOutput
# Charger la configuration
settings = get_settings()
API_VERSION = settings.API_VERSION
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Gestion du cycle de vie de l'application.
Charge le modèle au démarrage et le garde en cache.
"""
logger.info(
"🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
)
start_time = time.time()
try:
# Pré-charger le modèle au démarrage
model = load_model()
duration_ms = (time.time() - start_time) * 1000
model_type = type(model).__name__
log_model_load(model_type, duration_ms, True)
logger.info("✅ Modèle chargé avec succès")
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
log_model_load("Unknown", duration_ms, False)
logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})
yield # L'application tourne
logger.info("🛑 Arrêt de l'API")
# Créer l'application FastAPI
app = FastAPI(
title="Employee Turnover Prediction API",
description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
version=API_VERSION,
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
)
# Ajouter rate limiting
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Configurer CORS (autoriser tous les domaines en dev)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Middleware de logging des requêtes
@app.middleware("http")
async def log_requests(request: Request, call_next):
"""
Middleware pour logger toutes les requêtes HTTP.
"""
start_time = time.time()
# Traiter la requête
response = await call_next(request)
# Calculer la durée
duration_ms = (time.time() - start_time) * 1000
# Logger
log_request(
method=request.method,
path=request.url.path,
status_code=response.status_code,
duration_ms=duration_ms,
client_host=request.client.host if request.client else None,
)
return response
@app.get("/", tags=["Root"])
async def root():
"""
Endpoint racine avec informations sur l'API.
"""
return {
"message": "Employee Turnover Prediction API",
"version": API_VERSION,
"docs": "/docs",
"health": "/health",
"predict": "/predict (POST)",
}
@app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
async def health_check():
"""
Health check endpoint pour monitoring.
Vérifie que l'API est opérationnelle et que le modèle est chargé.
Returns:
HealthCheck: Status de l'API et du modèle.
Raises:
HTTPException: 503 si le modèle n'est pas disponible.
"""
try:
model_info = get_model_info()
return HealthCheck(
status="healthy",
model_loaded=model_info.get("cached", False),
model_type=model_info.get("model_type", "Unknown"),
version=API_VERSION,
)
except Exception as e:
raise HTTPException(
status_code=503,
detail={
"status": "unhealthy",
"error": "Model not available",
"message": str(e),
},
)
@app.post(
"/predict",
response_model=PredictionOutput,
tags=["Prediction"],
dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
)
@limiter.limit("20/minute")
async def predict(request: Request, employee: EmployeeInput):
"""
Endpoint de prédiction du turnover d'un employé.
**PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
Prend en entrée les données d'un employé, applique le preprocessing
et retourne la prédiction avec les probabilités.
Args:
employee: Données de l'employé validées par Pydantic.
Returns:
PredictionOutput: Prédiction et probabilités.
Raises:
HTTPException: 401 si API key invalide ou manquante.
HTTPException: 500 si erreur lors de la prédiction.
Examples:
```bash
# Avec authentification
curl -X POST http://localhost:8000/predict \\
-H "X-API-Key: your-secret-key" \\
-H "Content-Type: application/json" \\
-d '{...}'
```
"""
try:
# 1. Charger le modèle
model = load_model()
# 2. Préprocessing
X = preprocess_for_prediction(employee)
# 3. Prédiction
prediction = int(model.predict(X)[0])
# 4. Probabilités (si le modèle supporte predict_proba)
try:
probabilities = model.predict_proba(X)[0]
prob_0 = float(probabilities[0])
prob_1 = float(probabilities[1])
except AttributeError:
# Si le modèle ne supporte pas predict_proba
prob_0 = 1.0 if prediction == 0 else 0.0
prob_1 = 1.0 if prediction == 1 else 0.0
# 5. Niveau de risque
if prob_1 < 0.3:
risk_level = "Low"
elif prob_1 < 0.7:
risk_level = "Medium"
else:
risk_level = "High"
return PredictionOutput(
prediction=prediction,
probability_0=prob_0,
probability_1=prob_1,
risk_level=risk_level,
)
except Exception:
logger.exception("Unexpected error during prediction")
raise HTTPException(
status_code=500,
detail={
"error": "Prediction failed",
"message": "An unexpected error occurred. Please contact support.",
},
)
# Monter l'interface Gradio sur /ui
gradio_app = create_gradio_interface()
app = gr.mount_gradio_app(app, gradio_app, path="/ui")
if __name__ == "__main__":
import uvicorn
print("🚀 Lancement de l'API en mode développement...")
print("📖 Documentation : http://localhost:8000/docs")
print("🎨 Interface Gradio : http://localhost:8000/ui")
uvicorn.run(
"app:app",
host="0.0.0.0",
port=8000,
reload=True,
log_level="info",
)