Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| API FastAPI pour le modèle Employee Turnover. | |
| Cette API expose le modèle de prédiction de départ des employés avec : | |
| - Validation stricte des inputs via Pydantic | |
| - Preprocessing automatique | |
| - Health check pour monitoring | |
| - Documentation OpenAPI/Swagger automatique | |
| - Interface Gradio pour utilisation interactive | |
| """ | |
| import time | |
| from contextlib import asynccontextmanager | |
| import gradio as gr | |
| from fastapi import Depends, FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from slowapi import _rate_limit_exceeded_handler | |
| from slowapi.errors import RateLimitExceeded | |
| from src.auth import verify_api_key | |
| from src.config import get_settings | |
| from src.gradio_ui import create_gradio_interface | |
| from src.logger import logger, log_model_load, log_request | |
| from src.models import get_model_info, load_model | |
| from src.preprocessing import preprocess_for_prediction | |
| from src.rate_limit import limiter | |
| from src.schemas import EmployeeInput, HealthCheck, PredictionOutput | |
| # Charger la configuration | |
| settings = get_settings() | |
| API_VERSION = settings.API_VERSION | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    Pre-loads the prediction model at startup (load_model keeps it cached)
    and logs startup/shutdown events. If loading fails, the API still comes
    up so that /health can report the unhealthy state.

    Fix: the @asynccontextmanager decorator was missing — FastAPI's
    ``lifespan=`` parameter requires an async context manager, and a bare
    async generator function raises at application startup.

    Args:
        app: The FastAPI application instance (required by the protocol,
            not used here).
    """
    logger.info(
        "🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
    )
    start_time = time.time()
    try:
        # Pre-load the model so the first request doesn't pay the load cost.
        model = load_model()
        duration_ms = (time.time() - start_time) * 1000
        model_type = type(model).__name__
        log_model_load(model_type, duration_ms, True)
        logger.info("✅ Modèle chargé avec succès")
    except Exception as e:
        # Deliberate best-effort: log the failure but let the app start.
        duration_ms = (time.time() - start_time) * 1000
        log_model_load("Unknown", duration_ms, False)
        logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})
    yield  # the application serves requests while suspended here
    logger.info("🛑 Arrêt de l'API")
# Create the FastAPI application (OpenAPI docs at /docs, ReDoc at /redoc).
app = FastAPI(
    title="Employee Turnover Prediction API",
    description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
    version=API_VERSION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)

# Attach the slowapi limiter and its 429 handler.
# NOTE(review): slowapi only throttles routes decorated with
# @limiter.limit(...); no route in this file carries that decorator,
# so this wiring alone does not rate-limit anything — confirm intent.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS: wide open (comment says dev only).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests per the CORS spec —
# restrict origins explicitly before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request-logging middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """
    HTTP middleware that logs every request with method, path, status
    code, wall-clock duration in milliseconds, and client host.

    Fix: the function was never registered on the app — without the
    @app.middleware("http") decorator it was dead code and no request
    was ever logged.
    """
    start_time = time.time()
    # Let the rest of the stack handle the request.
    response = await call_next(request)
    duration_ms = (time.time() - start_time) * 1000
    log_request(
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        duration_ms=duration_ms,
        # request.client is None for some test/ASGI clients — guard it.
        client_host=request.client.host if request.client else None,
    )
    return response
@app.get("/")
async def root():
    """
    Root endpoint with basic API metadata and useful links.

    Fix: the route was never registered — without @app.get("/") the
    function was unreachable dead code.

    Returns:
        dict: API name, version, and the main endpoint paths.
    """
    return {
        "message": "Employee Turnover Prediction API",
        "version": API_VERSION,
        "docs": "/docs",
        "health": "/health",
        "predict": "/predict (POST)",
    }
@app.get("/health", response_model=HealthCheck)
async def health_check():
    """
    Health-check endpoint for monitoring.

    Verifies that the API is up and reports whether the model is loaded.

    Fix: the route was never registered — without @app.get("/health")
    the function was unreachable dead code.

    Returns:
        HealthCheck: API status, model load state, model type and version.

    Raises:
        HTTPException: 503 if model information cannot be retrieved.
    """
    try:
        model_info = get_model_info()
        return HealthCheck(
            status="healthy",
            # "cached" reflects whether load_model() has already populated
            # the in-process cache.
            model_loaded=model_info.get("cached", False),
            model_type=model_info.get("model_type", "Unknown"),
            version=API_VERSION,
        )
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail={
                "status": "unhealthy",
                "error": "Model not available",
                "message": str(e),
            },
        )
@app.post("/predict", response_model=PredictionOutput)
async def predict(
    request: Request,
    employee: EmployeeInput,
    _api_key=Depends(verify_api_key),
):
    """
    Predict employee turnover.

    **PROTECTED BY API KEY**: requires the `X-API-Key` header in production.

    Takes validated employee data, applies preprocessing, and returns the
    prediction with class probabilities and a coarse risk level.

    Fixes: the route was never registered (@app.post was missing), and the
    API-key protection promised by the docstring was not enforced —
    `Depends`/`verify_api_key` were imported but unused. The auth check is
    restored as a defaulted dependency, which keeps the signature
    backward-compatible for direct callers.
    NOTE(review): no @limiter.limit(...) decorator here, so slowapi rate
    limiting does not apply to this route — confirm the intended rate.

    Args:
        request: Incoming HTTP request (kept for slowapi compatibility).
        employee: Employee data validated by Pydantic.
        _api_key: Result of the API-key verification dependency.

    Returns:
        PredictionOutput: Prediction, probabilities, and risk level.

    Raises:
        HTTPException: 401 if the API key is invalid or missing
            (raised by verify_api_key).
        HTTPException: 500 on any unexpected prediction error.

    Examples:
        ```bash
        # With authentication
        curl -X POST http://localhost:8000/predict \\
            -H "X-API-Key: your-secret-key" \\
            -H "Content-Type: application/json" \\
            -d '{...}'
        ```
    """
    try:
        # 1. Load the (cached) model.
        model = load_model()
        # 2. Preprocess the validated input into the model's feature matrix.
        X = preprocess_for_prediction(employee)
        # 3. Predict the class label.
        prediction = int(model.predict(X)[0])
        # 4. Probabilities, when the model supports predict_proba.
        try:
            probabilities = model.predict_proba(X)[0]
            prob_0 = float(probabilities[0])
            prob_1 = float(probabilities[1])
        except AttributeError:
            # Fall back to degenerate probabilities for hard classifiers.
            prob_0 = 1.0 if prediction == 0 else 0.0
            prob_1 = 1.0 if prediction == 1 else 0.0
        # 5. Map the positive-class probability to a coarse risk bucket.
        if prob_1 < 0.3:
            risk_level = "Low"
        elif prob_1 < 0.7:
            risk_level = "Medium"
        else:
            risk_level = "High"
        return PredictionOutput(
            prediction=prediction,
            probability_0=prob_0,
            probability_1=prob_1,
            risk_level=risk_level,
        )
    except Exception:
        # Log the full traceback server-side; return a generic message so
        # internals are not leaked to the client.
        logger.exception("Unexpected error during prediction")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Prediction failed",
                "message": "An unexpected error occurred. Please contact support.",
            },
        )
# Mount the interactive Gradio UI onto the same ASGI app under /ui.
gradio_app = create_gradio_interface()
app = gr.mount_gradio_app(app, gradio_app, path="/ui")

# Development entry point: run with auto-reload when executed directly.
if __name__ == "__main__":
    import uvicorn

    for banner_line in (
        "🚀 Lancement de l'API en mode développement...",
        "📖 Documentation : http://localhost:8000/docs",
        "🎨 Interface Gradio : http://localhost:8000/ui",
    ):
        print(banner_line)

    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True, log_level="info")