ASI-Engineer commited on
Commit
6f61606
·
verified ·
1 Parent(s): 2d2a6e8

Upload folder using huggingface_hub

Browse files
.env.example CHANGED
@@ -1,6 +1,11 @@
1
  # Configuration de l'API Employee Turnover
2
  # Copiez ce fichier vers .env et remplissez les valeurs
3
 
 
 
 
 
 
4
  # ===== SÉCURITÉ =====
5
  # Clé API pour protéger l'endpoint /predict
6
  # Générez une clé forte : python -c "import secrets; print(secrets.token_urlsafe(32))"
 
1
  # Configuration de l'API Employee Turnover
2
  # Copiez ce fichier vers .env et remplissez les valeurs
3
 
4
+ # ===== BASE DE DONNÉES =====
5
+ # URL de connexion PostgreSQL (avec credentials)
6
+ # Format: postgresql://username:password@host:port/database
7
+ DATABASE_URL=postgresql://ml_user:your-password-here@localhost:5432/oc_p5_db
8
+
9
  # ===== SÉCURITÉ =====
10
  # Clé API pour protéger l'endpoint /predict
11
  # Générez une clé forte : python -c "import secrets; print(secrets.token_urlsafe(32))"
.huggingface/space.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ sdk: docker
2
+ dockerfile: src/Dockerfile
3
+ build_context: .
README.md CHANGED
@@ -1,106 +1,324 @@
1
- ---
2
- title: Employee Turnover Prediction API
3
- emoji: 👔
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- pinned: true
8
- license: mit
9
- app_port: 7860
10
- ---
11
 
 
12
 
13
- # Employee Turnover Prediction API 🚀 (v3.2.1)
14
 
15
- API de prédiction du turnover des employés (XGBoost + SMOTE) avec endpoints batch, validation stricte et documentation à jour.
16
 
17
- ## 🎯 Fonctionnalités
18
-
19
- - ✅ Prédiction de turnover (0 = reste, 1 = part)
20
  - 📦 Endpoint batch CSV (3 fichiers bruts)
21
- - 🎛️ Sliders Gradio et schémas Pydantic alignés sur les min/max réels
22
- - 📊 Probabilités et niveau de risque (Low/Medium/High)
23
- - 🔐 Authentification API Key (obligatoire)
24
- - 📝 Logs structurés JSON
25
- - 🛡️ Rate limiting (20 req/min)
26
- - 📚 Documentation OpenAPI/Swagger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
 
 
 
 
 
 
 
28
 
29
- ## 🔗 Endpoints
 
 
30
 
31
- | Endpoint | Description |
32
- |----------|-------------|
33
- | `/docs` | Documentation interactive Swagger |
34
- | `/health` | Status de l'API |
35
- | `/ui` | Interface Gradio interactive |
36
- | `/predict` | Prédiction unitaire (JSON, contraintes réelles) |
37
- | `/predict/batch` | Prédiction batch (3 fichiers CSV bruts) |
38
 
 
 
 
39
 
40
- ## 🚀 Utilisation
41
 
42
- ### Prédiction unitaire (toutes contraintes appliquées)
43
  ```bash
44
- curl -X POST https://asi-engineer-oc-p5-dev.hf.space/predict \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  -H "Content-Type: application/json" \
46
- -H "X-API-Key: your-key" \
47
- -d '{
48
- "nombre_participation_pee": 0,
49
- "nb_formations_suivies": 2,
50
- "nombre_employee_sous_responsabilite": 1,
51
- "distance_domicile_travail": 15,
52
- "niveau_education": 3,
53
- "domaine_etude": "Infra & Cloud",
54
- "ayant_enfants": "Y",
55
- "frequence_deplacement": "Occasionnel",
56
- "annees_depuis_la_derniere_promotion": 2,
57
- "annes_sous_responsable_actuel": 5,
58
- "satisfaction_employee_environnement": 3,
59
- "note_evaluation_precedente": 4,
60
- "niveau_hierarchique_poste": 2,
61
- "satisfaction_employee_nature_travail": 3,
62
- "satisfaction_employee_equipe": 3,
63
- "satisfaction_employee_equilibre_pro_perso": 2,
64
- "note_evaluation_actuelle": 4,
65
- "heure_supplementaires": "Non",
66
- "augementation_salaire_precedente": 5.5,
67
- "age": 35,
68
- "genre": "M",
69
- "revenu_mensuel": 4500.0,
70
- "statut_marital": "Marié(e)",
71
- "departement": "Commercial",
72
- "poste": "Manager",
73
- "nombre_experiences_precedentes": 3,
74
- "nombre_heures_travailless": 80,
75
- "annee_experience_totale": 10,
76
- "annees_dans_l_entreprise": 5,
77
- "annees_dans_le_poste_actuel": 2
78
- }'
79
  ```
80
 
81
- ### Prédiction batch (3 fichiers CSV bruts)
 
 
 
82
  ```bash
83
- curl -X POST https://asi-engineer-oc-p5-dev.hf.space/predict/batch \
84
- -H "X-API-Key: your-key" \
85
- -F "sondage_file=@extrait_sondage.csv" \
86
- -F "eval_file=@extrait_eval.csv" \
87
- -F "sirh_file=@extrait_sirh.csv"
 
 
 
 
88
  ```
89
 
90
- **Réponse :**
91
- ```json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  {
93
  "total_employees": 1470,
94
- "predictions": [...],
 
 
 
95
  "summary": {
96
  "total_stay": 1169,
97
  "total_leave": 301,
98
- "high_risk_count": 222
 
 
99
  }
100
  }
101
  ```
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- ## 📚 Documentation complète
105
 
106
- Voir [docs/API.md](docs/API.md) ou le [GitHub Repository](https://github.com/chaton59/OC_P5) pour la documentation complète et les contraintes détaillées (min/max, enums, etc).
 
 
1
+ # 🚀 Employee Turnover Prediction API - v3.2.1
 
 
 
 
 
 
 
 
 
2
 
3
+ ## 📊 Vue d'ensemble
4
 
5
+ API REST de prédiction du turnover des employés basée sur un modèle XGBoost avec SMOTE.
6
 
 
7
 
8
+ **✨ Nouveautés v3.2.1** :
9
+ - 🎛️ Sliders Gradio et schémas Pydantic alignés sur les min/max réels des données d'entraînement
 
10
  - 📦 Endpoint batch CSV (3 fichiers bruts)
11
+ - 🔑 Authentification API Key (prod)
12
+ - 🔧 Correction preprocessing (scaling, ordre des colonnes)
13
+ - 📝 Documentation et exemples mis à jour
14
+
15
+ ## 🏗️ Architecture
16
+
17
+ ```
18
+ OC_P5/
19
+ ├── app.py # Point d'entrée FastAPI
20
+ ├── src/
21
+ │ ├── auth.py # Authentification API Key
22
+ │ ├── config.py # Configuration centralisée
23
+ │ ├── logger.py # Logging structuré (NOUVEAU)
24
+ │ ├── models.py # Chargement modèle HF Hub
25
+ │ ├── preprocessing.py # Pipeline preprocessing
26
+ │ ├── rate_limit.py # Rate limiting (NOUVEAU)
27
+ │ └── schemas.py # Validation Pydantic
28
+ ├── tests/ # Suite pytest (33 tests, 88% couverture)
29
+ ├── logs/ # Logs JSON (NOUVEAU)
30
+ │ ├── api.log # Tous les logs
31
+ │ └── error.log # Erreurs uniquement
32
+ ├── docs/ # Documentation
33
+ ├── ml_model/ # Scripts training
34
+ └── data/ # Données sources
35
+ ## 🗄️ Schéma de la Base de Données (PostgreSQL)
36
+
37
+ Schéma UML pour traçabilité ML (basé sur P5 prédiction turnover employé) :
38
+ ![Schéma BDD](docs/schema.png)
39
+
40
+ - **dataset** : Dataset original (référence pour tests/retraining). Colonnes adaptées au modèle de prédiction turnover.
41
+ - **ml_logs** : Logs inputs/outputs (JSON pour flexibilité, timestamp pour audits).
42
+
43
+ Choix : Structure relationnelle pour efficacité volume data ; sécurité via user dédié (ml_user).
44
+ Instructions : Voir create_db.py pour création.
45
+
46
+ 📖 **Guide complet pour débutants** : [docs/database_guide.md](docs/database_guide.md)
47
+
48
+ ### 💾 Insertion du Dataset
49
+ ```bash
50
+ # Insérer le dataset complet (1470 employés)
51
+ poetry run python scripts/insert_dataset.py
52
+
53
+ # Vérifier l'insertion
54
+ psql -h localhost -U ml_user -d oc_p5_db -c "SELECT COUNT(*) FROM dataset;"
55
+ ```
56
+
57
+ ### Prérequis
58
+ - Python 3.12+
59
+ - Poetry 1.7+
60
+ - Git
61
+
62
+ ### Setup rapide
63
 
64
+ ```bash
65
+ # 1. Cloner le repo
66
+ git clone https://github.com/chaton59/OC_P5.git
67
+ cd OC_P5
68
+
69
+ # 2. Installer les dépendances
70
+ poetry install
71
 
72
+ # 3. Configurer l'environnement
73
+ cp .env.example .env
74
+ # Éditer .env avec vos valeurs
75
 
76
+ # 4. Lancer l'API
77
+ poetry run uvicorn app:app --reload
 
 
 
 
 
78
 
79
+ # 5. Accéder à la documentation
80
+ # http://localhost:8000/docs
81
+ ```
82
 
83
+ ## 📝 Configuration (.env)
84
 
 
85
  ```bash
86
+ # Mode développement (désactive auth + active logs détaillés)
87
+ DEBUG=true
88
+
89
+ # API Key (requis en production)
90
+ API_KEY=your-secret-key-here
91
+
92
+ # Logging (DEBUG, INFO, WARNING, ERROR, CRITICAL)
93
+ LOG_LEVEL=INFO
94
+
95
+ # HuggingFace Model
96
+ HF_MODEL_REPO=ASI-Engineer/employee-turnover-model
97
+ MODEL_FILENAME=model/model.pkl
98
+ ```
99
+
100
+ ## 🔒 Authentification
101
+
102
+ ### Mode DEBUG (développement)
103
+ ```bash
104
+ # L'API Key n'est PAS requise
105
+ curl http://localhost:8000/predict -H "Content-Type: application/json" -d '{...}'
106
+ ```
107
+
108
+ ### Mode PRODUCTION
109
+ ```bash
110
+ # L'API Key est REQUISE
111
+ curl http://localhost:8000/predict \
112
+ -H "X-API-Key: your-secret-key" \
113
  -H "Content-Type: application/json" \
114
+ -d '{...}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  ```
116
 
117
+
118
+ ## 📡 Endpoints
119
+
120
+ ### 🏥 Health Check
121
  ```bash
122
+ GET /health
123
+
124
+ # Réponse
125
+ {
126
+ "status": "healthy",
127
+ "model_loaded": true,
128
+ "model_type": "Pipeline",
129
+ "version": "3.2.1"
130
+ }
131
  ```
132
 
133
+ ### 🔮 Prédiction unitaire
134
+ ```bash
135
+ POST /predict
136
+ Content-Type: application/json
137
+ X-API-Key: your-key (en production)
138
+
139
+ # Payload (exemple, contraintes réelles appliquées)
140
+ {
141
+ "nombre_participation_pee": 0,
142
+ "nb_formations_suivies": 2,
143
+ "nombre_employee_sous_responsabilite": 1,
144
+ "distance_domicile_travail": 15,
145
+ "niveau_education": 3,
146
+ "domaine_etude": "Infra & Cloud",
147
+ "ayant_enfants": "Y",
148
+ "frequence_deplacement": "Occasionnel",
149
+ "annees_depuis_la_derniere_promotion": 2,
150
+ "annes_sous_responsable_actuel": 5,
151
+ "satisfaction_employee_environnement": 3,
152
+ "note_evaluation_precedente": 4,
153
+ "niveau_hierarchique_poste": 2,
154
+ "satisfaction_employee_nature_travail": 3,
155
+ "satisfaction_employee_equipe": 3,
156
+ "satisfaction_employee_equilibre_pro_perso": 2,
157
+ "note_evaluation_actuelle": 4,
158
+ "heure_supplementaires": "Non",
159
+ "augementation_salaire_precedente": 5.5,
160
+ "age": 35,
161
+ "genre": "M",
162
+ "revenu_mensuel": 4500.0,
163
+ "statut_marital": "Marié(e)",
164
+ "departement": "Commercial",
165
+ "poste": "Manager",
166
+ "nombre_experiences_precedentes": 3,
167
+ "nombre_heures_travailless": 80,
168
+ "annee_experience_totale": 10,
169
+ "annees_dans_l_entreprise": 5,
170
+ "annees_dans_le_poste_actuel": 2
171
+ }
172
+
173
+ # Réponse
174
+ {
175
+ "prediction": 0, # 0 = reste, 1 = part
176
+ "probability_0": 0.85, # Probabilité de rester
177
+ "probability_1": 0.15, # Probabilité de partir
178
+ "risk_level": "Low" # Low, Medium, High
179
+ }
180
+ ```
181
+
182
+ ### 📦 Prédiction batch (CSV)
183
+ ```bash
184
+ POST /predict/batch
185
+ X-API-Key: your-key (en production)
186
+
187
+ # Envoi des 3 fichiers CSV bruts
188
+ curl -X POST "http://localhost:8000/predict/batch" \
189
+ -H "X-API-Key: your-key" \
190
+ -F "sondage_file=@data/extrait_sondage.csv" \
191
+ -F "eval_file=@data/extrait_eval.csv" \
192
+ -F "sirh_file=@data/extrait_sirh.csv"
193
+
194
+ # Réponse
195
  {
196
  "total_employees": 1470,
197
+ "predictions": [
198
+ {"employee_id": 1, "prediction": 1, "probability_leave": 0.84, "risk_level": "High"},
199
+ {"employee_id": 2, "prediction": 0, "probability_leave": 0.11, "risk_level": "Low"}
200
+ ],
201
  "summary": {
202
  "total_stay": 1169,
203
  "total_leave": 301,
204
+ "high_risk_count": 222,
205
+ "medium_risk_count": 233,
206
+ "low_risk_count": 1015
207
  }
208
  }
209
  ```
210
 
211
+ ## 📊 Logging
212
+
213
+ ### Logs structurés JSON
214
+
215
+ **Fichiers** :
216
+ - `logs/api.log` : Tous les logs
217
+ - `logs/error.log` : Erreurs uniquement
218
+
219
+ **Format** :
220
+ ```json
221
+ {
222
+ "timestamp": "2025-12-26T10:30:45",
223
+ "level": "INFO",
224
+ "logger": "employee_turnover_api",
225
+ "message": "Request POST /predict",
226
+ "method": "POST",
227
+ "path": "/predict",
228
+ "status_code": 200,
229
+ "duration_ms": 23.45,
230
+ "client_host": "127.0.0.1"
231
+ }
232
+ ```
233
+
234
+ ## 🛡️ Rate Limiting
235
+
236
+ **Configuration** :
237
+ - **Développement** : Désactivé (DEBUG=true)
238
+ - **Production** : 20 requêtes/minute par IP ou API Key
239
+
240
+ **En cas de dépassement** :
241
+ ```json
242
+ {
243
+ "error": "Rate limit exceeded",
244
+ "message": "20 per 1 minute"
245
+ }
246
+ ```
247
+
248
+ ## ✅ Tests
249
+
250
+ ```bash
251
+ # Tous les tests
252
+ poetry run pytest tests/ -v
253
+
254
+ # Avec couverture
255
+ poetry run pytest tests/ --cov --cov-report=html
256
+
257
+ # Voir rapport HTML
258
+ open htmlcov/index.html
259
+ ```
260
+
261
+ **Résultats** :
262
+ - ✅ 33 tests passés
263
+ - 📊 88% de couverture globale
264
+
265
+ ## 🚀 Déploiement
266
+
267
+ ### Variables d'environnement requises
268
+ ```bash
269
+ DEBUG=false
270
+ API_KEY=<votre-clé-sécurisée>
271
+ LOG_LEVEL=INFO
272
+ ```
273
+
274
+ ### HuggingFace Spaces
275
+ Prêt pour déploiement avec `app.py` et `requirements.txt`
276
+
277
+ ## 📚 Documentation
278
+
279
+ - **API Interactive** : http://localhost:8000/docs
280
+ - **ReDoc** : http://localhost:8000/redoc
281
+ - **Guide complet** : [docs/API_GUIDE.md](docs/API_GUIDE.md)
282
+ - **Standards** : [docs/standards.md](docs/standards.md)
283
+ - **Couverture tests** : [docs/TEST_COVERAGE.md](docs/TEST_COVERAGE.md)
284
+
285
+ ## 📦 Dépendances principales
286
+
287
+ - **FastAPI** 0.115.14 : Framework web
288
+ - **Pydantic** 2.12.5 : Validation données
289
+ - **XGBoost** 2.1.3 : Modèle ML
290
+ - **SlowAPI** 0.1.9 : Rate limiting
291
+ - **python-json-logger** 4.0.0 : Logs structurés
292
+ - **pytest** 9.0.2 : Tests
293
+
294
+
295
+ ## 🔄 Changelog
296
+
297
+ ### v3.2.1 (janvier 2026)
298
+ - 🎛️ Sliders Gradio et schémas Pydantic alignés sur les min/max réels des données d'entraînement
299
+ - 📦 Endpoint batch CSV (3 fichiers bruts)
300
+ - 🔑 Authentification API Key (prod)
301
+ - 🔧 Correction preprocessing (scaling, ordre des colonnes)
302
+ - 📝 Documentation et exemples mis à jour
303
+
304
+ ### v2.2.0 (27 décembre 2025)
305
+ - 📦 Nouvel endpoint `/predict/batch` pour traitement CSV direct
306
+ - 🔧 Fix preprocessing : ajout du scaling des features
307
+ - 🔧 Fix preprocessing : correction de l'ordre des colonnes
308
+ - 📊 Amélioration précision des prédictions (~90%)
309
+
310
+ ### v2.1.0 (26 décembre 2025)
311
+ - ✨ Système de logging structuré JSON
312
+ - 🛡️ Rate limiting avec SlowAPI
313
+ - ⚡ Amélioration gestion d'erreurs
314
+ - 📊 Monitoring des performances
315
+
316
+ ### v2.0.0 (26 décembre 2025)
317
+ - ✅ Suite de tests complète (36 tests)
318
+ - 🔐 Authentification API Key
319
+ - 📊 88% de couverture de code
320
 
321
+ ## 👥 Auteurs
322
 
323
+ - **Projet** : OpenClassrooms P5
324
+ - **Repo** : [github.com/chaton59/OC_P5](https://github.com/chaton59/OC_P5)
README_HF.md CHANGED
@@ -3,7 +3,7 @@ title: Employee Turnover Prediction API
3
  emoji: 👔
4
  colorFrom: blue
5
  colorTo: purple
6
- sdk: docker
7
  pinned: true
8
  license: mit
9
  app_port: 7860
 
3
  emoji: 👔
4
  colorFrom: blue
5
  colorTo: purple
6
+ sdk: gradio
7
  pinned: true
8
  license: mit
9
  app_port: 7860
api.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ API FastAPI pour le modèle Employee Turnover.
4
+
5
+ Cette API expose le modèle de prédiction de départ des employés avec :
6
+ - Validation stricte des inputs via Pydantic
7
+ - Preprocessing automatique
8
+ - Health check pour monitoring
9
+ - Documentation OpenAPI/Swagger automatique
10
+ - Interface Gradio pour utilisation interactive
11
+ - Endpoint batch pour traitement de fichiers CSV
12
+ """
13
+ import io
14
+ import time
15
+ from contextlib import asynccontextmanager
16
+
17
+ import gradio as gr
18
+ import pandas as pd
19
+ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from slowapi import _rate_limit_exceeded_handler
22
+ from slowapi.errors import RateLimitExceeded
23
+
24
+ from src.auth import verify_api_key
25
+ from src.config import get_settings
26
+ from src.gradio_ui import create_gradio_interface
27
+ from src.logger import logger, log_model_load, log_request
28
+ from src.models import get_model_info, load_model
29
+ from src.preprocessing import (
30
+ merge_csv_dataframes,
31
+ preprocess_dataframe_for_prediction,
32
+ preprocess_for_prediction,
33
+ )
34
+ from src.rate_limit import limiter
35
+ from src.schemas import (
36
+ BatchPredictionOutput,
37
+ EmployeeInput,
38
+ EmployeePrediction,
39
+ HealthCheck,
40
+ PredictionOutput,
41
+ )
42
+
43
+ # Charger la configuration
44
+ settings = get_settings()
45
+ API_VERSION = settings.API_VERSION
46
+
47
+
48
+ @asynccontextmanager
49
+ async def lifespan(app: FastAPI):
50
+ """
51
+ Gestion du cycle de vie de l'application.
52
+
53
+ Charge le modèle au démarrage et le garde en cache.
54
+ """
55
+ logger.info(
56
+ "🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
57
+ )
58
+
59
+ start_time = time.time()
60
+ try:
61
+ # Pré-charger le modèle au démarrage
62
+ model = load_model()
63
+ duration_ms = (time.time() - start_time) * 1000
64
+
65
+ model_type = type(model).__name__
66
+ log_model_load(model_type, duration_ms, True)
67
+ logger.info("✅ Modèle chargé avec succès")
68
+ except Exception as e:
69
+ duration_ms = (time.time() - start_time) * 1000
70
+ log_model_load("Unknown", duration_ms, False)
71
+ logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})
72
+
73
+ yield # L'application tourne
74
+
75
+ logger.info("🛑 Arrêt de l'API")
76
+
77
+
78
+ # Créer l'application FastAPI
79
+ app = FastAPI(
80
+ title="Employee Turnover Prediction API",
81
+ description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
82
+ version=API_VERSION,
83
+ lifespan=lifespan,
84
+ docs_url="/docs",
85
+ redoc_url="/redoc",
86
+ )
87
+
88
+ # Ajouter rate limiting
89
+ app.state.limiter = limiter
90
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
91
+
92
+ # Configurer CORS (autoriser tous les domaines en dev)
93
+ app.add_middleware(
94
+ CORSMiddleware,
95
+ allow_origins=["*"],
96
+ allow_credentials=True,
97
+ allow_methods=["*"],
98
+ allow_headers=["*"],
99
+ )
100
+
101
+
102
+ # Middleware de logging des requêtes
103
+ @app.middleware("http")
104
+ async def log_requests(request: Request, call_next):
105
+ """
106
+ Middleware pour logger toutes les requêtes HTTP.
107
+ """
108
+ start_time = time.time()
109
+
110
+ # Traiter la requête
111
+ response = await call_next(request)
112
+
113
+ # Calculer la durée
114
+ duration_ms = (time.time() - start_time) * 1000
115
+
116
+ # Logger
117
+ log_request(
118
+ method=request.method,
119
+ path=request.url.path,
120
+ status_code=response.status_code,
121
+ duration_ms=duration_ms,
122
+ client_host=request.client.host if request.client else None,
123
+ )
124
+
125
+ return response
126
+
127
+
128
+ @app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
129
+ async def health_check():
130
+ """
131
+ Health check endpoint pour monitoring.
132
+
133
+ Vérifie que l'API est opérationnelle et que le modèle est chargé.
134
+
135
+ Returns:
136
+ HealthCheck: Status de l'API et du modèle.
137
+
138
+ Raises:
139
+ HTTPException: 503 si le modèle n'est pas disponible.
140
+ """
141
+ try:
142
+ model_info = get_model_info()
143
+
144
+ return HealthCheck(
145
+ status="healthy",
146
+ model_loaded=model_info.get("cached", False),
147
+ model_type=model_info.get("model_type", "Unknown"),
148
+ version=API_VERSION,
149
+ )
150
+ except Exception as e:
151
+ raise HTTPException(
152
+ status_code=503,
153
+ detail={
154
+ "status": "unhealthy",
155
+ "error": "Model not available",
156
+ "message": str(e),
157
+ },
158
+ )
159
+
160
+
161
+ @app.post(
162
+ "/predict",
163
+ response_model=PredictionOutput,
164
+ tags=["Prediction"],
165
+ dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
166
+ )
167
+ @limiter.limit("20/minute")
168
+ async def predict(request: Request, employee: EmployeeInput):
169
+ """
170
+ Endpoint de prédiction du turnover d'un employé.
171
+
172
+ **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
173
+
174
+ Prend en entrée les données d'un employé, applique le preprocessing
175
+ et retourne la prédiction avec les probabilités.
176
+
177
+ Args:
178
+ employee: Données de l'employé validées par Pydantic.
179
+
180
+ Returns:
181
+ PredictionOutput: Prédiction et probabilités.
182
+
183
+ Raises:
184
+ HTTPException: 401 si API key invalide ou manquante.
185
+ HTTPException: 500 si erreur lors de la prédiction.
186
+
187
+ Examples:
188
+ ```bash
189
+ # Avec authentification
190
+ curl -X POST http://localhost:8000/predict \\
191
+ -H "X-API-Key: your-secret-key" \\
192
+ -H "Content-Type: application/json" \\
193
+ -d '{...}'
194
+ ```
195
+ """
196
+ try:
197
+ # 1. Charger le modèle
198
+ model = load_model()
199
+
200
+ # 2. Préprocessing
201
+ X = preprocess_for_prediction(employee)
202
+
203
+ # 3. Prédiction
204
+ prediction = int(model.predict(X)[0])
205
+
206
+ # 4. Probabilités (si le modèle supporte predict_proba)
207
+ try:
208
+ probabilities = model.predict_proba(X)[0]
209
+ prob_0 = float(probabilities[0])
210
+ prob_1 = float(probabilities[1])
211
+ except AttributeError:
212
+ # Si le modèle ne supporte pas predict_proba
213
+ prob_0 = 1.0 if prediction == 0 else 0.0
214
+ prob_1 = 1.0 if prediction == 1 else 0.0
215
+
216
+ # 5. Niveau de risque
217
+ if prob_1 < 0.3:
218
+ risk_level = "Low"
219
+ elif prob_1 < 0.7:
220
+ risk_level = "Medium"
221
+ else:
222
+ risk_level = "High"
223
+
224
+ # 6. Enregistrer dans la base de données
225
+ try:
226
+ from sqlalchemy import create_engine
227
+ from sqlalchemy.orm import sessionmaker
228
+ from db_models import MLLog
229
+
230
+ engine = create_engine(settings.DATABASE_URL)
231
+ Session = sessionmaker(bind=engine)
232
+ session = Session()
233
+
234
+ log_entry = MLLog(
235
+ input_json=employee.model_dump(),
236
+ prediction="Oui" if prediction == 1 else "Non",
237
+ )
238
+ session.add(log_entry)
239
+ session.commit()
240
+ session.close()
241
+
242
+ logger.info(f"Prediction logged to database: {prediction}")
243
+ except Exception as db_error:
244
+ logger.warning(f"Failed to log prediction to database: {db_error}")
245
+
246
+ return PredictionOutput(
247
+ prediction=prediction,
248
+ probability_0=prob_0,
249
+ probability_1=prob_1,
250
+ risk_level=risk_level,
251
+ )
252
+
253
+ except Exception:
254
+ logger.exception("Unexpected error during prediction")
255
+ raise HTTPException(
256
+ status_code=500,
257
+ detail={
258
+ "error": "Prediction failed",
259
+ "message": "An unexpected error occurred. Please contact support.",
260
+ },
261
+ )
262
+
263
+
264
+ @app.post(
265
+ "/predict/batch",
266
+ response_model=BatchPredictionOutput,
267
+ tags=["Prediction"],
268
+ dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
269
+ )
270
+ @limiter.limit("5/minute")
271
+ async def predict_batch(
272
+ request: Request,
273
+ sondage_file: UploadFile = File(..., description="Fichier CSV du sondage"),
274
+ eval_file: UploadFile = File(..., description="Fichier CSV des évaluations"),
275
+ sirh_file: UploadFile = File(..., description="Fichier CSV SIRH"),
276
+ ):
277
+ """
278
+ Endpoint de prédiction batch à partir de fichiers CSV.
279
+
280
+ **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
281
+
282
+ Prend en entrée les 3 fichiers CSV (sondage, évaluation, SIRH),
283
+ les fusionne, applique le preprocessing et retourne les prédictions
284
+ pour tous les employés.
285
+
286
+ Args:
287
+ sondage_file: Fichier CSV contenant les données de sondage.
288
+ eval_file: Fichier CSV contenant les données d'évaluation.
289
+ sirh_file: Fichier CSV contenant les données SIRH.
290
+
291
+ Returns:
292
+ BatchPredictionOutput: Prédictions pour tous les employés.
293
+
294
+ Raises:
295
+ HTTPException: 400 si les fichiers sont invalides.
296
+ HTTPException: 500 si erreur lors du traitement.
297
+ """
298
+ try:
299
+ # 1. Lire les fichiers CSV
300
+ sondage_content = await sondage_file.read()
301
+ eval_content = await eval_file.read()
302
+ sirh_content = await sirh_file.read()
303
+
304
+ sondage_df = pd.read_csv(io.BytesIO(sondage_content))
305
+ eval_df = pd.read_csv(io.BytesIO(eval_content))
306
+ sirh_df = pd.read_csv(io.BytesIO(sirh_content))
307
+
308
+ logger.info(
309
+ f"Fichiers CSV chargés: sondage={len(sondage_df)}, "
310
+ f"eval={len(eval_df)}, sirh={len(sirh_df)} lignes"
311
+ )
312
+
313
+ # 2. Fusionner les DataFrames
314
+ merged_df = merge_csv_dataframes(sondage_df, eval_df, sirh_df)
315
+ employee_ids = merged_df["original_employee_id"].tolist()
316
+ merged_df = merged_df.drop(columns=["original_employee_id"])
317
+
318
+ # Supprimer la colonne cible si présente
319
+ if "a_quitte_l_entreprise" in merged_df.columns:
320
+ merged_df = merged_df.drop(columns=["a_quitte_l_entreprise"])
321
+
322
+ logger.info(f"DataFrame fusionné: {len(merged_df)} employés")
323
+
324
+ # 3. Preprocessing
325
+ X = preprocess_dataframe_for_prediction(merged_df)
326
+
327
+ # 4. Charger le modèle et prédire
328
+ model = load_model()
329
+ predictions = model.predict(X.values)
330
+ probabilities = model.predict_proba(X.values)
331
+
332
+ # 5. Construire la réponse
333
+ results = []
334
+ risk_counts = {"Low": 0, "Medium": 0, "High": 0}
335
+ leave_count = 0
336
+
337
+ for i, emp_id in enumerate(employee_ids):
338
+ prob_stay = float(probabilities[i][0])
339
+ prob_leave = float(probabilities[i][1])
340
+ pred = int(predictions[i])
341
+
342
+ if prob_leave < 0.3:
343
+ risk = "Low"
344
+ elif prob_leave < 0.7:
345
+ risk = "Medium"
346
+ else:
347
+ risk = "High"
348
+
349
+ risk_counts[risk] += 1
350
+ if pred == 1:
351
+ leave_count += 1
352
+
353
+ results.append(
354
+ EmployeePrediction(
355
+ employee_id=int(emp_id),
356
+ prediction=pred,
357
+ probability_stay=prob_stay,
358
+ probability_leave=prob_leave,
359
+ risk_level=risk,
360
+ )
361
+ )
362
+
363
+ summary = {
364
+ "total_stay": len(results) - leave_count,
365
+ "total_leave": leave_count,
366
+ "high_risk_count": risk_counts["High"],
367
+ "medium_risk_count": risk_counts["Medium"],
368
+ "low_risk_count": risk_counts["Low"],
369
+ }
370
+
371
+ logger.info(f"Prédictions terminées: {summary}")
372
+
373
+ return BatchPredictionOutput(
374
+ total_employees=len(results),
375
+ predictions=results,
376
+ summary=summary,
377
+ )
378
+
379
+ except pd.errors.EmptyDataError:
380
+ raise HTTPException(
381
+ status_code=400,
382
+ detail={
383
+ "error": "Empty CSV file",
384
+ "message": "Un des fichiers CSV est vide.",
385
+ },
386
+ )
387
+ except KeyError as e:
388
+ raise HTTPException(
389
+ status_code=400,
390
+ detail={
391
+ "error": "Missing column",
392
+ "message": f"Colonne manquante dans les CSV: {e}",
393
+ },
394
+ )
395
+ except Exception as e:
396
+ logger.exception("Unexpected error during batch prediction")
397
+ raise HTTPException(
398
+ status_code=500,
399
+ detail={
400
+ "error": "Batch prediction failed",
401
+ "message": str(e),
402
+ },
403
+ )
404
+
405
+
406
+ # Monter l'interface Gradio sur / (racine pour HuggingFace Spaces)
407
+ gradio_app = create_gradio_interface()
408
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
409
+
410
+
411
+ if __name__ == "__main__":
412
+ import uvicorn
413
+
414
+ print("\U0001f680 Lancement de l'API en mode d\u00e9veloppement...")
415
+ print("\U0001f4d6 Documentation : http://localhost:8000/docs")
416
+ print("\U0001f3a8 Interface Gradio : http://localhost:8000/")
417
+
418
+ uvicorn.run(
419
+ "app:app",
420
+ host="0.0.0.0",
421
+ port=8000,
422
+ reload=True,
423
+ log_level="info",
424
+ )
app.py CHANGED
@@ -1,402 +1,16 @@
1
  #!/usr/bin/env python3
2
  """
3
- API FastAPI pour le modèle Employee Turnover.
4
 
5
- Cette API expose le modèle de prédiction de départ des employés avec :
6
- - Validation stricte des inputs via Pydantic
7
- - Preprocessing automatique
8
- - Health check pour monitoring
9
- - Documentation OpenAPI/Swagger automatique
10
- - Interface Gradio pour utilisation interactive
11
- - Endpoint batch pour traitement de fichiers CSV
12
  """
13
- import io
14
- import time
15
- from contextlib import asynccontextmanager
16
 
17
- import gradio as gr
18
- import pandas as pd
19
- from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
20
- from fastapi.middleware.cors import CORSMiddleware
21
- from slowapi import _rate_limit_exceeded_handler
22
- from slowapi.errors import RateLimitExceeded
23
-
24
- from src.auth import verify_api_key
25
- from src.config import get_settings
26
- from src.gradio_ui import create_gradio_interface
27
- from src.logger import logger, log_model_load, log_request
28
- from src.models import get_model_info, load_model
29
- from src.preprocessing import (
30
- merge_csv_dataframes,
31
- preprocess_dataframe_for_prediction,
32
- preprocess_for_prediction,
33
- )
34
- from src.rate_limit import limiter
35
- from src.schemas import (
36
- BatchPredictionOutput,
37
- EmployeeInput,
38
- EmployeePrediction,
39
- HealthCheck,
40
- PredictionOutput,
41
- )
42
-
43
- # Charger la configuration
44
- settings = get_settings()
45
- API_VERSION = settings.API_VERSION
46
-
47
-
48
- @asynccontextmanager
49
- async def lifespan(app: FastAPI):
50
- """
51
- Gestion du cycle de vie de l'application.
52
-
53
- Charge le modèle au démarrage et le garde en cache.
54
- """
55
- logger.info(
56
- "🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
57
- )
58
-
59
- start_time = time.time()
60
- try:
61
- # Pré-charger le modèle au démarrage
62
- model = load_model()
63
- duration_ms = (time.time() - start_time) * 1000
64
-
65
- model_type = type(model).__name__
66
- log_model_load(model_type, duration_ms, True)
67
- logger.info("✅ Modèle chargé avec succès")
68
- except Exception as e:
69
- duration_ms = (time.time() - start_time) * 1000
70
- log_model_load("Unknown", duration_ms, False)
71
- logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})
72
-
73
- yield # L'application tourne
74
-
75
- logger.info("🛑 Arrêt de l'API")
76
-
77
-
78
- # Créer l'application FastAPI
79
- app = FastAPI(
80
- title="Employee Turnover Prediction API",
81
- description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
82
- version=API_VERSION,
83
- lifespan=lifespan,
84
- docs_url="/docs",
85
- redoc_url="/redoc",
86
- )
87
-
88
- # Ajouter rate limiting
89
- app.state.limiter = limiter
90
- app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
91
-
92
- # Configurer CORS (autoriser tous les domaines en dev)
93
- app.add_middleware(
94
- CORSMiddleware,
95
- allow_origins=["*"],
96
- allow_credentials=True,
97
- allow_methods=["*"],
98
- allow_headers=["*"],
99
- )
100
-
101
-
102
- # Middleware de logging des requêtes
103
- @app.middleware("http")
104
- async def log_requests(request: Request, call_next):
105
- """
106
- Middleware pour logger toutes les requêtes HTTP.
107
- """
108
- start_time = time.time()
109
-
110
- # Traiter la requête
111
- response = await call_next(request)
112
-
113
- # Calculer la durée
114
- duration_ms = (time.time() - start_time) * 1000
115
-
116
- # Logger
117
- log_request(
118
- method=request.method,
119
- path=request.url.path,
120
- status_code=response.status_code,
121
- duration_ms=duration_ms,
122
- client_host=request.client.host if request.client else None,
123
- )
124
-
125
- return response
126
-
127
-
128
- @app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
129
- async def health_check():
130
- """
131
- Health check endpoint pour monitoring.
132
-
133
- Vérifie que l'API est opérationnelle et que le modèle est chargé.
134
-
135
- Returns:
136
- HealthCheck: Status de l'API et du modèle.
137
-
138
- Raises:
139
- HTTPException: 503 si le modèle n'est pas disponible.
140
- """
141
- try:
142
- model_info = get_model_info()
143
-
144
- return HealthCheck(
145
- status="healthy",
146
- model_loaded=model_info.get("cached", False),
147
- model_type=model_info.get("model_type", "Unknown"),
148
- version=API_VERSION,
149
- )
150
- except Exception as e:
151
- raise HTTPException(
152
- status_code=503,
153
- detail={
154
- "status": "unhealthy",
155
- "error": "Model not available",
156
- "message": str(e),
157
- },
158
- )
159
-
160
-
161
- @app.post(
162
- "/predict",
163
- response_model=PredictionOutput,
164
- tags=["Prediction"],
165
- dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
166
- )
167
- @limiter.limit("20/minute")
168
- async def predict(request: Request, employee: EmployeeInput):
169
- """
170
- Endpoint de prédiction du turnover d'un employé.
171
-
172
- **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
173
-
174
- Prend en entrée les données d'un employé, applique le preprocessing
175
- et retourne la prédiction avec les probabilités.
176
-
177
- Args:
178
- employee: Données de l'employé validées par Pydantic.
179
-
180
- Returns:
181
- PredictionOutput: Prédiction et probabilités.
182
-
183
- Raises:
184
- HTTPException: 401 si API key invalide ou manquante.
185
- HTTPException: 500 si erreur lors de la prédiction.
186
-
187
- Examples:
188
- ```bash
189
- # Avec authentification
190
- curl -X POST http://localhost:8000/predict \\
191
- -H "X-API-Key: your-secret-key" \\
192
- -H "Content-Type: application/json" \\
193
- -d '{...}'
194
- ```
195
- """
196
- try:
197
- # 1. Charger le modèle
198
- model = load_model()
199
-
200
- # 2. Préprocessing
201
- X = preprocess_for_prediction(employee)
202
-
203
- # 3. Prédiction
204
- prediction = int(model.predict(X)[0])
205
-
206
- # 4. Probabilités (si le modèle supporte predict_proba)
207
- try:
208
- probabilities = model.predict_proba(X)[0]
209
- prob_0 = float(probabilities[0])
210
- prob_1 = float(probabilities[1])
211
- except AttributeError:
212
- # Si le modèle ne supporte pas predict_proba
213
- prob_0 = 1.0 if prediction == 0 else 0.0
214
- prob_1 = 1.0 if prediction == 1 else 0.0
215
-
216
- # 5. Niveau de risque
217
- if prob_1 < 0.3:
218
- risk_level = "Low"
219
- elif prob_1 < 0.7:
220
- risk_level = "Medium"
221
- else:
222
- risk_level = "High"
223
-
224
- return PredictionOutput(
225
- prediction=prediction,
226
- probability_0=prob_0,
227
- probability_1=prob_1,
228
- risk_level=risk_level,
229
- )
230
-
231
- except Exception:
232
- logger.exception("Unexpected error during prediction")
233
- raise HTTPException(
234
- status_code=500,
235
- detail={
236
- "error": "Prediction failed",
237
- "message": "An unexpected error occurred. Please contact support.",
238
- },
239
- )
240
-
241
-
242
- @app.post(
243
- "/predict/batch",
244
- response_model=BatchPredictionOutput,
245
- tags=["Prediction"],
246
- dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
247
- )
248
- @limiter.limit("5/minute")
249
- async def predict_batch(
250
- request: Request,
251
- sondage_file: UploadFile = File(..., description="Fichier CSV du sondage"),
252
- eval_file: UploadFile = File(..., description="Fichier CSV des évaluations"),
253
- sirh_file: UploadFile = File(..., description="Fichier CSV SIRH"),
254
- ):
255
- """
256
- Endpoint de prédiction batch à partir de fichiers CSV.
257
-
258
- **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
259
-
260
- Prend en entrée les 3 fichiers CSV (sondage, évaluation, SIRH),
261
- les fusionne, applique le preprocessing et retourne les prédictions
262
- pour tous les employés.
263
-
264
- Args:
265
- sondage_file: Fichier CSV contenant les données de sondage.
266
- eval_file: Fichier CSV contenant les données d'évaluation.
267
- sirh_file: Fichier CSV contenant les données SIRH.
268
-
269
- Returns:
270
- BatchPredictionOutput: Prédictions pour tous les employés.
271
-
272
- Raises:
273
- HTTPException: 400 si les fichiers sont invalides.
274
- HTTPException: 500 si erreur lors du traitement.
275
- """
276
- try:
277
- # 1. Lire les fichiers CSV
278
- sondage_content = await sondage_file.read()
279
- eval_content = await eval_file.read()
280
- sirh_content = await sirh_file.read()
281
-
282
- sondage_df = pd.read_csv(io.BytesIO(sondage_content))
283
- eval_df = pd.read_csv(io.BytesIO(eval_content))
284
- sirh_df = pd.read_csv(io.BytesIO(sirh_content))
285
-
286
- logger.info(
287
- f"Fichiers CSV chargés: sondage={len(sondage_df)}, "
288
- f"eval={len(eval_df)}, sirh={len(sirh_df)} lignes"
289
- )
290
-
291
- # 2. Fusionner les DataFrames
292
- merged_df = merge_csv_dataframes(sondage_df, eval_df, sirh_df)
293
- employee_ids = merged_df["original_employee_id"].tolist()
294
- merged_df = merged_df.drop(columns=["original_employee_id"])
295
-
296
- # Supprimer la colonne cible si présente
297
- if "a_quitte_l_entreprise" in merged_df.columns:
298
- merged_df = merged_df.drop(columns=["a_quitte_l_entreprise"])
299
-
300
- logger.info(f"DataFrame fusionné: {len(merged_df)} employés")
301
-
302
- # 3. Preprocessing
303
- X = preprocess_dataframe_for_prediction(merged_df)
304
-
305
- # 4. Charger le modèle et prédire
306
- model = load_model()
307
- predictions = model.predict(X.values)
308
- probabilities = model.predict_proba(X.values)
309
-
310
- # 5. Construire la réponse
311
- results = []
312
- risk_counts = {"Low": 0, "Medium": 0, "High": 0}
313
- leave_count = 0
314
-
315
- for i, emp_id in enumerate(employee_ids):
316
- prob_stay = float(probabilities[i][0])
317
- prob_leave = float(probabilities[i][1])
318
- pred = int(predictions[i])
319
-
320
- if prob_leave < 0.3:
321
- risk = "Low"
322
- elif prob_leave < 0.7:
323
- risk = "Medium"
324
- else:
325
- risk = "High"
326
-
327
- risk_counts[risk] += 1
328
- if pred == 1:
329
- leave_count += 1
330
-
331
- results.append(
332
- EmployeePrediction(
333
- employee_id=int(emp_id),
334
- prediction=pred,
335
- probability_stay=prob_stay,
336
- probability_leave=prob_leave,
337
- risk_level=risk,
338
- )
339
- )
340
-
341
- summary = {
342
- "total_stay": len(results) - leave_count,
343
- "total_leave": leave_count,
344
- "high_risk_count": risk_counts["High"],
345
- "medium_risk_count": risk_counts["Medium"],
346
- "low_risk_count": risk_counts["Low"],
347
- }
348
-
349
- logger.info(f"Prédictions terminées: {summary}")
350
-
351
- return BatchPredictionOutput(
352
- total_employees=len(results),
353
- predictions=results,
354
- summary=summary,
355
- )
356
-
357
- except pd.errors.EmptyDataError:
358
- raise HTTPException(
359
- status_code=400,
360
- detail={
361
- "error": "Empty CSV file",
362
- "message": "Un des fichiers CSV est vide.",
363
- },
364
- )
365
- except KeyError as e:
366
- raise HTTPException(
367
- status_code=400,
368
- detail={
369
- "error": "Missing column",
370
- "message": f"Colonne manquante dans les CSV: {e}",
371
- },
372
- )
373
- except Exception as e:
374
- logger.exception("Unexpected error during batch prediction")
375
- raise HTTPException(
376
- status_code=500,
377
- detail={
378
- "error": "Batch prediction failed",
379
- "message": str(e),
380
- },
381
- )
382
-
383
-
384
- # Monter l'interface Gradio sur / (racine pour HuggingFace Spaces)
385
- gradio_app = create_gradio_interface()
386
- app = gr.mount_gradio_app(app, gradio_app, path="/")
387
 
 
 
388
 
389
  if __name__ == "__main__":
390
- import uvicorn
391
-
392
- print("\U0001f680 Lancement de l'API en mode d\u00e9veloppement...")
393
- print("\U0001f4d6 Documentation : http://localhost:8000/docs")
394
- print("\U0001f3a8 Interface Gradio : http://localhost:8000/")
395
-
396
- uvicorn.run(
397
- "app:app",
398
- host="0.0.0.0",
399
- port=8000,
400
- reload=True,
401
- log_level="info",
402
- )
 
1
  #!/usr/bin/env python3
2
  """
3
+ App Gradio pour Hugging Face Spaces.
4
 
5
+ Lance l'interface Gradio pour la prédiction de turnover.
 
 
 
 
 
 
6
  """
7
+ import sys
8
+ import os
 
9
 
10
+ from src.gradio_ui import launch_standalone
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Ajouter le répertoire src au path
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
14
 
15
  if __name__ == "__main__":
16
+ launch_standalone()
 
 
 
 
 
 
 
 
 
 
 
 
db_models.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, Integer, String, JSON, DateTime, func
2
+ from sqlalchemy.ext.declarative import declarative_base
3
+
4
+ Base = declarative_base()
5
+
6
+
7
+ class Dataset(Base):
8
+ __tablename__ = "dataset"
9
+ id = Column(Integer, primary_key=True)
10
+ features_json = Column(JSON) # Features from sondage, eval, sirh data
11
+ target = Column(String) # Target: 'Oui' or 'Non' for turnover
12
+
13
+
14
+ class MLLog(Base):
15
+ __tablename__ = "ml_logs"
16
+ id = Column(Integer, primary_key=True)
17
+ input_json = Column(JSON) # Inputs flexibles (JSON for features variables)
18
+ prediction = Column(String) # Output ML ('Oui' or 'Non')
19
+ created_at = Column(DateTime, default=func.now()) # Timestamp auto pour traçabilité
requirements.txt CHANGED
@@ -1,122 +1,27 @@
1
- aiofiles==24.1.0 ; python_version >= "3.12" and python_version < "4.0"
2
- alembic==1.17.2 ; python_version >= "3.12" and python_version < "4.0"
3
- annotated-doc==0.0.4 ; python_version >= "3.12" and python_version < "4.0"
4
- annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
5
- anyio==4.12.0 ; python_version >= "3.12" and python_version < "4.0"
6
- audioop-lts==0.2.2 ; python_version >= "3.13" and python_version < "4.0"
7
- blinker==1.9.0 ; python_version >= "3.12" and python_version < "4.0"
8
- brotli==1.2.0 ; python_version >= "3.12" and python_version < "4.0"
9
- cachetools==6.2.4 ; python_version >= "3.12" and python_version < "4.0"
10
- certifi==2025.11.12 ; python_version >= "3.12" and python_version < "4.0"
11
- cffi==2.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy"
12
- charset-normalizer==3.4.4 ; python_version >= "3.12" and python_version < "4.0"
13
- click==8.3.1 ; python_version >= "3.12" and python_version < "4.0"
14
- cloudpickle==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
15
- colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
16
- contourpy==1.3.3 ; python_version >= "3.12" and python_version < "4.0"
17
- cryptography==46.0.3 ; python_version >= "3.12" and python_version < "4.0"
18
- cycler==0.12.1 ; python_version >= "3.12" and python_version < "4.0"
19
- databricks-sdk==0.76.0 ; python_version >= "3.12" and python_version < "4.0"
20
- deprecated==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
21
- docker==7.1.0 ; python_version >= "3.12" and python_version < "4.0"
22
- fastapi==0.127.1 ; python_version >= "3.12" and python_version < "4.0"
23
- ffmpy==1.0.0 ; python_version >= "3.12" and python_version < "4.0"
24
- filelock==3.20.1 ; python_version >= "3.12" and python_version < "4.0"
25
- flask-cors==6.0.2 ; python_version >= "3.12" and python_version < "4.0"
26
- flask==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
27
- fonttools==4.61.1 ; python_version >= "3.12" and python_version < "4.0"
28
- fsspec==2025.12.0 ; python_version >= "3.12" and python_version < "4.0"
29
- gitdb==4.0.12 ; python_version >= "3.12" and python_version < "4.0"
30
- gitpython==3.1.45 ; python_version >= "3.12" and python_version < "4.0"
31
- google-auth==2.45.0 ; python_version >= "3.12" and python_version < "4.0"
32
- gradio-client==2.0.2 ; python_version >= "3.12" and python_version < "4.0"
33
- gradio==6.2.0 ; python_version >= "3.12" and python_version < "4.0"
34
- graphene==3.4.3 ; python_version >= "3.12" and python_version < "4.0"
35
- graphql-core==3.2.7 ; python_version >= "3.12" and python_version < "4.0"
36
- graphql-relay==3.2.0 ; python_version >= "3.12" and python_version < "4.0"
37
- greenlet==3.3.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32")
38
- groovy==0.1.2 ; python_version >= "3.12" and python_version < "4.0"
39
- gunicorn==23.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_system != "Windows"
40
- h11==0.16.0 ; python_version >= "3.12" and python_version < "4.0"
41
- hf-xet==1.2.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "arm64" or platform_machine == "aarch64")
42
- httpcore==1.0.9 ; python_version >= "3.12" and python_version < "4.0"
43
- httptools==0.7.1 ; python_version >= "3.12" and python_version < "4.0"
44
- httpx==0.28.1 ; python_version >= "3.12" and python_version < "4.0"
45
- huey==2.5.5 ; python_version >= "3.12" and python_version < "4.0"
46
- huggingface-hub==1.2.3 ; python_version >= "3.12" and python_version < "4.0"
47
- idna==3.11 ; python_version >= "3.12" and python_version < "4.0"
48
- imbalanced-learn==0.13.0 ; python_version >= "3.12" and python_version < "4.0"
49
- importlib-metadata==8.7.1 ; python_version >= "3.12" and python_version < "4.0"
50
- itsdangerous==2.2.0 ; python_version >= "3.12" and python_version < "4.0"
51
- jinja2==3.1.6 ; python_version >= "3.12" and python_version < "4.0"
52
- joblib==1.5.3 ; python_version >= "3.12" and python_version < "4.0"
53
- kiwisolver==1.4.9 ; python_version >= "3.12" and python_version < "4.0"
54
- limits==5.6.0 ; python_version >= "3.12" and python_version < "4.0"
55
- mako==1.3.10 ; python_version >= "3.12" and python_version < "4.0"
56
- markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
57
- markupsafe==3.0.3 ; python_version >= "3.12" and python_version < "4.0"
58
- matplotlib==3.10.8 ; python_version >= "3.12" and python_version < "4.0"
59
- mdurl==0.1.2 ; python_version >= "3.12" and python_version < "4.0"
60
- mlflow-skinny==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
61
- mlflow-tracing==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
62
- mlflow==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
63
- numpy==2.4.0 ; python_version >= "3.12" and python_version < "4.0"
64
- nvidia-nccl-cu12==2.28.9 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine != "aarch64"
65
- opentelemetry-api==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
66
- opentelemetry-proto==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
67
- opentelemetry-sdk==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
68
- opentelemetry-semantic-conventions==0.60b1 ; python_version >= "3.12" and python_version < "4.0"
69
- orjson==3.11.5 ; python_version >= "3.12" and python_version < "4.0"
70
- packaging==25.0 ; python_version >= "3.12" and python_version < "4.0"
71
- pandas==2.3.3 ; python_version >= "3.12" and python_version < "4.0"
72
- pillow==12.0.0 ; python_version >= "3.12" and python_version < "4.0"
73
- protobuf==6.33.2 ; python_version >= "3.12" and python_version < "4.0"
74
- pyarrow==22.0.0 ; python_version >= "3.12" and python_version < "4.0"
75
- pyasn1-modules==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
76
- pyasn1==0.6.1 ; python_version >= "3.12" and python_version < "4.0"
77
- pycparser==2.23 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy" and implementation_name != "PyPy"
78
- pydantic-core==2.41.5 ; python_version >= "3.12" and python_version < "4.0"
79
- pydantic==2.12.5 ; python_version >= "3.12" and python_version < "4.0"
80
- pydub==0.25.1 ; python_version >= "3.12" and python_version < "4.0"
81
- pygments==2.19.2 ; python_version >= "3.12" and python_version < "4.0"
82
- pyparsing==3.3.1 ; python_version >= "3.12" and python_version < "4.0"
83
- python-dateutil==2.9.0.post0 ; python_version >= "3.12" and python_version < "4.0"
84
- python-dotenv==1.2.1 ; python_version >= "3.12" and python_version < "4.0"
85
- python-json-logger==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
86
- python-multipart==0.0.21 ; python_version >= "3.12" and python_version < "4.0"
87
- pytz==2025.2 ; python_version >= "3.12" and python_version < "4.0"
88
- pywin32==311 ; python_version >= "3.12" and python_version < "4.0" and sys_platform == "win32"
89
- pyyaml==6.0.3 ; python_version >= "3.12" and python_version < "4.0"
90
- requests==2.32.5 ; python_version >= "3.12" and python_version < "4.0"
91
- rich==14.2.0 ; python_version >= "3.12" and python_version < "4.0"
92
- rsa==4.9.1 ; python_version >= "3.12" and python_version < "4.0"
93
- safehttpx==0.1.7 ; python_version >= "3.12" and python_version < "4.0"
94
- scikit-learn==1.6.1 ; python_version >= "3.12" and python_version < "4.0"
95
- scipy==1.16.3 ; python_version >= "3.12" and python_version < "4.0"
96
- semantic-version==2.10.0 ; python_version >= "3.12" and python_version < "4.0"
97
- shellingham==1.5.4 ; python_version >= "3.12" and python_version < "4.0"
98
- six==1.17.0 ; python_version >= "3.12" and python_version < "4.0"
99
- sklearn-compat==0.1.5 ; python_version >= "3.12" and python_version < "4.0"
100
- slowapi==0.1.9 ; python_version >= "3.12" and python_version < "4.0"
101
- smmap==5.0.2 ; python_version >= "3.12" and python_version < "4.0"
102
- sqlalchemy==2.0.45 ; python_version >= "3.12" and python_version < "4.0"
103
- sqlparse==0.5.5 ; python_version >= "3.12" and python_version < "4.0"
104
- starlette==0.50.0 ; python_version >= "3.12" and python_version < "4.0"
105
- threadpoolctl==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
106
- tomlkit==0.13.3 ; python_version >= "3.12" and python_version < "4.0"
107
- tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
108
- typer-slim==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
109
- typer==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
110
- typing-extensions==4.15.0 ; python_version >= "3.12" and python_version < "4.0"
111
- typing-inspection==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
112
- tzdata==2025.3 ; python_version >= "3.12" and python_version < "4.0"
113
- urllib3==2.6.2 ; python_version >= "3.12" and python_version < "4.0"
114
- uvicorn==0.32.1 ; python_version >= "3.12" and python_version < "4.0"
115
- uvloop==0.22.1 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy"
116
- waitress==3.0.2 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
117
- watchfiles==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
118
- websockets==15.0.1 ; python_version >= "3.12" and python_version < "4.0"
119
- werkzeug==3.1.4 ; python_version >= "3.12" and python_version < "4.0"
120
- wrapt==2.0.1 ; python_version >= "3.12" and python_version < "4.0"
121
- xgboost==2.1.4 ; python_version >= "3.12" and python_version < "4.0"
122
- zipp==3.23.0 ; python_version >= "3.12" and python_version < "4.0"
 
1
+ # Requirements for Hugging Face Spaces (Gradio app)
2
+ # Minimal dependencies needed for the Gradio interface
3
+ # Generated from pyproject.toml with essential packages only
4
+
5
+ # Core ML libraries
6
+ scikit-learn>=1.6.0,<1.7.0
7
+ xgboost>=2.1.0,<3.0.0
8
+ numpy>=2.0.0,<3.0.0
9
+ pandas>=2.2.0,<3.0.0
10
+ joblib>=1.4.0,<2.0.0
11
+ scipy>=1.14.0,<2.0.0
12
+
13
+ # Gradio and web framework
14
+ gradio>=6.2.0,<7.0.0
15
+ fastapi>=0.127.0,<1.0.0
16
+ uvicorn[standard]>=0.32.0,<1.0.0
17
+ pydantic>=2.10.0,<3.0.0
18
+
19
+ # Data processing
20
+ imbalanced-learn>=0.13.0,<1.0.0
21
+
22
+ # Hugging Face
23
+ huggingface-hub>=1.2.0,<2.0.0
24
+
25
+ # Utilities
26
+ python-dotenv>=1.0.0,<2.0.0
27
+ python-json-logger>=4.0.0,<5.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements_full.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0 ; python_version >= "3.12" and python_version < "4.0"
2
+ alembic==1.17.2 ; python_version >= "3.12" and python_version < "4.0"
3
+ annotated-doc==0.0.4 ; python_version >= "3.12" and python_version < "4.0"
4
+ annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
5
+ anyio==4.12.0 ; python_version >= "3.12" and python_version < "4.0"
6
+ audioop-lts==0.2.2 ; python_version >= "3.13" and python_version < "4.0"
7
+ blinker==1.9.0 ; python_version >= "3.12" and python_version < "4.0"
8
+ brotli==1.2.0 ; python_version >= "3.12" and python_version < "4.0"
9
+ cachetools==6.2.4 ; python_version >= "3.12" and python_version < "4.0"
10
+ certifi==2025.11.12 ; python_version >= "3.12" and python_version < "4.0"
11
+ cffi==2.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy"
12
+ charset-normalizer==3.4.4 ; python_version >= "3.12" and python_version < "4.0"
13
+ click==8.3.1 ; python_version >= "3.12" and python_version < "4.0"
14
+ cloudpickle==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
16
+ contourpy==1.3.3 ; python_version >= "3.12" and python_version < "4.0"
17
+ cryptography==46.0.3 ; python_version >= "3.12" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.12" and python_version < "4.0"
19
+ databricks-sdk==0.76.0 ; python_version >= "3.12" and python_version < "4.0"
20
+ deprecated==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
21
+ docker==7.1.0 ; python_version >= "3.12" and python_version < "4.0"
22
+ fastapi==0.127.1 ; python_version >= "3.12" and python_version < "4.0"
23
+ ffmpy==1.0.0 ; python_version >= "3.12" and python_version < "4.0"
24
+ filelock==3.20.1 ; python_version >= "3.12" and python_version < "4.0"
25
+ flask-cors==6.0.2 ; python_version >= "3.12" and python_version < "4.0"
26
+ flask==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
27
+ fonttools==4.61.1 ; python_version >= "3.12" and python_version < "4.0"
28
+ fsspec==2025.12.0 ; python_version >= "3.12" and python_version < "4.0"
29
+ gitdb==4.0.12 ; python_version >= "3.12" and python_version < "4.0"
30
+ gitpython==3.1.45 ; python_version >= "3.12" and python_version < "4.0"
31
+ google-auth==2.45.0 ; python_version >= "3.12" and python_version < "4.0"
32
+ gradio-client==2.0.2 ; python_version >= "3.12" and python_version < "4.0"
33
+ gradio==6.2.0 ; python_version >= "3.12" and python_version < "4.0"
34
+ graphene==3.4.3 ; python_version >= "3.12" and python_version < "4.0"
35
+ graphql-core==3.2.7 ; python_version >= "3.12" and python_version < "4.0"
36
+ graphql-relay==3.2.0 ; python_version >= "3.12" and python_version < "4.0"
37
+ greenlet==3.3.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32")
38
+ groovy==0.1.2 ; python_version >= "3.12" and python_version < "4.0"
39
+ gunicorn==23.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_system != "Windows"
40
+ h11==0.16.0 ; python_version >= "3.12" and python_version < "4.0"
41
+ hf-xet==1.2.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "arm64" or platform_machine == "aarch64")
42
+ httpcore==1.0.9 ; python_version >= "3.12" and python_version < "4.0"
43
+ httptools==0.7.1 ; python_version >= "3.12" and python_version < "4.0"
44
+ httpx==0.28.1 ; python_version >= "3.12" and python_version < "4.0"
45
+ huey==2.5.5 ; python_version >= "3.12" and python_version < "4.0"
46
+ huggingface-hub==1.2.3 ; python_version >= "3.12" and python_version < "4.0"
47
+ idna==3.11 ; python_version >= "3.12" and python_version < "4.0"
48
+ imbalanced-learn==0.13.0 ; python_version >= "3.12" and python_version < "4.0"
49
+ importlib-metadata==8.7.1 ; python_version >= "3.12" and python_version < "4.0"
50
+ itsdangerous==2.2.0 ; python_version >= "3.12" and python_version < "4.0"
51
+ jinja2==3.1.6 ; python_version >= "3.12" and python_version < "4.0"
52
+ joblib==1.5.3 ; python_version >= "3.12" and python_version < "4.0"
53
+ kiwisolver==1.4.9 ; python_version >= "3.12" and python_version < "4.0"
54
+ limits==5.6.0 ; python_version >= "3.12" and python_version < "4.0"
55
+ mako==1.3.10 ; python_version >= "3.12" and python_version < "4.0"
56
+ markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
57
+ markupsafe==3.0.3 ; python_version >= "3.12" and python_version < "4.0"
58
+ matplotlib==3.10.8 ; python_version >= "3.12" and python_version < "4.0"
59
+ mdurl==0.1.2 ; python_version >= "3.12" and python_version < "4.0"
60
+ mlflow-skinny==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
61
+ mlflow-tracing==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
62
+ mlflow==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
63
+ numpy==2.4.0 ; python_version >= "3.12" and python_version < "4.0"
64
+ nvidia-nccl-cu12==2.28.9 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine != "aarch64"
65
+ opentelemetry-api==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
66
+ opentelemetry-proto==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
67
+ opentelemetry-sdk==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
68
+ opentelemetry-semantic-conventions==0.60b1 ; python_version >= "3.12" and python_version < "4.0"
69
+ orjson==3.11.5 ; python_version >= "3.12" and python_version < "4.0"
70
+ packaging==25.0 ; python_version >= "3.12" and python_version < "4.0"
71
+ pandas==2.3.3 ; python_version >= "3.12" and python_version < "4.0"
72
+ pillow==12.0.0 ; python_version >= "3.12" and python_version < "4.0"
73
+ protobuf==6.33.2 ; python_version >= "3.12" and python_version < "4.0"
74
+ psycopg2-binary==2.9.9 ; python_version >= "3.12" and python_version < "4.0"
75
+ pyarrow==22.0.0 ; python_version >= "3.12" and python_version < "4.0"
76
+ pyasn1-modules==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
77
+ pyasn1==0.6.1 ; python_version >= "3.12" and python_version < "4.0"
78
+ pycparser==2.23 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy" and implementation_name != "PyPy"
79
+ pydantic-core==2.41.5 ; python_version >= "3.12" and python_version < "4.0"
80
+ pydantic==2.12.5 ; python_version >= "3.12" and python_version < "4.0"
81
+ pydub==0.25.1 ; python_version >= "3.12" and python_version < "4.0"
82
+ pygments==2.19.2 ; python_version >= "3.12" and python_version < "4.0"
83
+ pyparsing==3.3.1 ; python_version >= "3.12" and python_version < "4.0"
84
+ python-dateutil==2.9.0.post0 ; python_version >= "3.12" and python_version < "4.0"
85
+ python-dotenv==1.0.0 ; python_version >= "3.12" and python_version < "4.0"
86
+ python-json-logger==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
87
+ python-multipart==0.0.21 ; python_version >= "3.12" and python_version < "4.0"
88
+ pytz==2025.2 ; python_version >= "3.12" and python_version < "4.0"
89
+ pywin32==311 ; python_version >= "3.12" and python_version < "4.0" and sys_platform == "win32"
90
+ pyyaml==6.0.3 ; python_version >= "3.12" and python_version < "4.0"
91
+ requests==2.32.5 ; python_version >= "3.12" and python_version < "4.0"
92
+ rich==14.2.0 ; python_version >= "3.12" and python_version < "4.0"
93
+ rsa==4.9.1 ; python_version >= "3.12" and python_version < "4.0"
94
+ safehttpx==0.1.7 ; python_version >= "3.12" and python_version < "4.0"
95
+ scikit-learn==1.6.1 ; python_version >= "3.12" and python_version < "4.0"
96
+ scipy==1.16.3 ; python_version >= "3.12" and python_version < "4.0"
97
+ semantic-version==2.10.0 ; python_version >= "3.12" and python_version < "4.0"
98
+ shellingham==1.5.4 ; python_version >= "3.12" and python_version < "4.0"
99
+ six==1.17.0 ; python_version >= "3.12" and python_version < "4.0"
100
+ sklearn-compat==0.1.5 ; python_version >= "3.12" and python_version < "4.0"
101
+ slowapi==0.1.9 ; python_version >= "3.12" and python_version < "4.0"
102
+ smmap==5.0.2 ; python_version >= "3.12" and python_version < "4.0"
103
+ sqlalchemy==2.0.23 ; python_version >= "3.12" and python_version < "4.0"
104
+ sqlparse==0.5.5 ; python_version >= "3.12" and python_version < "4.0"
105
+ starlette==0.50.0 ; python_version >= "3.12" and python_version < "4.0"
106
+ threadpoolctl==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
107
+ tomlkit==0.13.3 ; python_version >= "3.12" and python_version < "4.0"
108
+ tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
109
+ typer-slim==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
110
+ typer==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
111
+ typing-extensions==4.15.0 ; python_version >= "3.12" and python_version < "4.0"
112
+ typing-inspection==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
113
+ tzdata==2025.3 ; python_version >= "3.12" and python_version < "4.0"
114
+ urllib3==2.6.2 ; python_version >= "3.12" and python_version < "4.0"
115
+ uvicorn==0.32.1 ; python_version >= "3.12" and python_version < "4.0"
116
+ uvloop==0.22.1 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy"
117
+ waitress==3.0.2 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
118
+ watchfiles==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
119
+ websockets==15.0.1 ; python_version >= "3.12" and python_version < "4.0"
120
+ werkzeug==3.1.4 ; python_version >= "3.12" and python_version < "4.0"
121
+ wrapt==2.0.1 ; python_version >= "3.12" and python_version < "4.0"
122
+ xgboost==2.1.4 ; python_version >= "3.12" and python_version < "4.0"
123
+ zipp==3.23.0 ; python_version >= "3.12" and python_version < "4.0"
Dockerfile → src/Dockerfile RENAMED
@@ -2,19 +2,26 @@ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
 
 
 
5
  # Installer les dépendances système
6
  RUN apt-get update && apt-get install -y \
7
  curl \
8
  && rm -rf /var/lib/apt/lists/*
9
 
10
- # Copier les fichiers de dépendances
11
- COPY requirements.txt .
 
 
 
12
 
13
- # Installer les dépendances Python
14
- RUN pip install --no-cache-dir -r requirements.txt
15
 
16
  # Copier le code de l'application
17
  COPY app.py .
 
18
  COPY src/ ./src/
19
  COPY .env.example .env
20
 
 
2
 
3
  WORKDIR /app
4
 
5
+ # Installer Poetry
6
+ RUN pip install poetry
7
+
8
  # Installer les dépendances système
9
  RUN apt-get update && apt-get install -y \
10
  curl \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
+ # Copier les fichiers de dépendances Poetry
14
+ COPY pyproject.toml poetry.lock ./
15
+
16
+ # Configurer Poetry pour ne pas créer d'environnement virtuel
17
+ RUN poetry config virtualenvs.create false
18
 
19
+ # Installer les dépendances Python via Poetry
20
+ RUN poetry install --no-dev --no-interaction --no-ansi
21
 
22
  # Copier le code de l'application
23
  COPY app.py .
24
+ COPY db_models.py .
25
  COPY src/ ./src/
26
  COPY .env.example .env
27
 
src/config.py CHANGED
@@ -40,6 +40,11 @@ class Settings:
40
  DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
41
  LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
42
 
 
 
 
 
 
43
  @property
44
  def is_api_key_required(self) -> bool:
45
  """
 
40
  DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
41
  LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
42
 
43
+ # ===== BASE DE DONNÉES =====
44
+ DATABASE_URL: str = os.getenv(
45
+ "DATABASE_URL", "postgresql://ml_user:15975359320@localhost:5432/oc_p5_db"
46
+ )
47
+
48
  @property
49
  def is_api_key_required(self) -> bool:
50
  """
src/gradio_ui.py CHANGED
@@ -8,6 +8,7 @@ Cette interface permet de:
8
  - Comprendre les champs requis
9
  """
10
  import gradio as gr
 
11
 
12
  from src.models import get_model_info, load_model
13
  from src.preprocessing import preprocess_for_prediction
@@ -123,6 +124,36 @@ def predict_turnover(
123
 
124
  confidence = max(prob_0, prob_1) * 100
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  result = f"""
127
  ## {risk_emoji}
128
 
@@ -132,6 +163,9 @@ def predict_turnover(
132
  - **Probabilité de départ**: {prob_1 * 100:.1f}%
133
  - **Probabilité de maintien**: {prob_0 * 100:.1f}%
134
 
 
 
 
135
  ### Interprétation
136
  {"⚠️ Cet employé présente des facteurs de risque de départ. Il est recommandé d'engager un dialogue pour comprendre ses attentes." if prediction == 1 else "✅ Cet employé semble stable. Continuez à maintenir un environnement de travail positif."}
137
  """
@@ -567,7 +601,7 @@ def launch_standalone():
567
  demo.launch(
568
  server_name="0.0.0.0",
569
  server_port=7860,
570
- share=False, # Pas de tunnel Gradio sur HF Spaces
571
  show_error=True,
572
  )
573
 
 
8
  - Comprendre les champs requis
9
  """
10
  import gradio as gr
11
+ import os
12
 
13
  from src.models import get_model_info, load_model
14
  from src.preprocessing import preprocess_for_prediction
 
124
 
125
  confidence = max(prob_0, prob_1) * 100
126
 
127
+ # Enregistrer dans la base de données (uniquement en local)
128
+ db_status = "ℹ️ DB désactivée sur HF Spaces"
129
+ try:
130
+ # Vérifier si on est sur HF Spaces (variable d'environnement)
131
+ if os.getenv("SPACE_ID") is None: # Pas sur HF Spaces
132
+ from sqlalchemy import create_engine
133
+ from sqlalchemy.orm import sessionmaker
134
+ from src.config import get_settings
135
+
136
+ settings = get_settings()
137
+ engine = create_engine(settings.DATABASE_URL)
138
+ Session = sessionmaker(bind=engine)
139
+ session = Session()
140
+
141
+ # Importer le modèle MLLog
142
+ from db_models import MLLog
143
+
144
+ # Créer le log
145
+ log_entry = MLLog(
146
+ input_json=employee.dict(), # Convertir Pydantic en dict
147
+ prediction="Oui" if prediction == 1 else "Non",
148
+ )
149
+ session.add(log_entry)
150
+ session.commit()
151
+ session.close()
152
+
153
+ db_status = "✅ Enregistré en DB"
154
+ except Exception as db_error:
155
+ db_status = f"⚠️ Erreur DB: {str(db_error)}"
156
+
157
  result = f"""
158
  ## {risk_emoji}
159
 
 
163
  - **Probabilité de départ**: {prob_1 * 100:.1f}%
164
  - **Probabilité de maintien**: {prob_0 * 100:.1f}%
165
 
166
+ ### Base de données
167
+ {db_status}
168
+
169
  ### Interprétation
170
  {"⚠️ Cet employé présente des facteurs de risque de départ. Il est recommandé d'engager un dialogue pour comprendre ses attentes." if prediction == 1 else "✅ Cet employé semble stable. Continuez à maintenir un environnement de travail positif."}
171
  """
 
601
  demo.launch(
602
  server_name="0.0.0.0",
603
  server_port=7860,
604
+ share=False,
605
  show_error=True,
606
  )
607
 
src/schemas.py CHANGED
@@ -8,7 +8,7 @@ permettant une validation stricte des inputs avec messages d'erreur clairs.
8
  from enum import Enum
9
  from typing import Literal
10
 
11
- from pydantic import BaseModel, Field, field_validator
12
 
13
 
14
  # Enums pour les valeurs catégorielles
@@ -172,10 +172,8 @@ class EmployeeInput(BaseModel):
172
  v = float(v.replace(" %", "").replace("%", ""))
173
  return v
174
 
175
- class Config:
176
- """Configuration Pydantic."""
177
-
178
- json_schema_extra = {
179
  "example": {
180
  # Exemple basé sur la première ligne des CSV
181
  "nombre_participation_pee": 0,
@@ -210,6 +208,7 @@ class EmployeeInput(BaseModel):
210
  "annees_dans_le_poste_actuel": 4,
211
  }
212
  }
 
213
 
214
 
215
  class PredictionOutput(BaseModel):
@@ -224,10 +223,8 @@ class PredictionOutput(BaseModel):
224
  )
225
  risk_level: str = Field(..., description="Niveau de risque (Low/Medium/High)")
226
 
227
- class Config:
228
- """Configuration Pydantic."""
229
-
230
- json_schema_extra = {
231
  "example": {
232
  "prediction": 1,
233
  "probability_0": 0.35,
@@ -235,6 +232,7 @@ class PredictionOutput(BaseModel):
235
  "risk_level": "High",
236
  }
237
  }
 
238
 
239
 
240
  class HealthCheck(BaseModel):
@@ -245,10 +243,8 @@ class HealthCheck(BaseModel):
245
  model_type: str = Field(..., description="Type du modèle")
246
  version: str = Field(..., description="Version de l'API")
247
 
248
- class Config:
249
- """Configuration Pydantic."""
250
-
251
- json_schema_extra = {
252
  "example": {
253
  "status": "healthy",
254
  "model_loaded": True,
@@ -256,6 +252,7 @@ class HealthCheck(BaseModel):
256
  "version": "1.0.0",
257
  }
258
  }
 
259
 
260
 
261
  class EmployeePrediction(BaseModel):
@@ -281,10 +278,8 @@ class BatchPredictionOutput(BaseModel):
281
  )
282
  summary: dict = Field(..., description="Résumé des prédictions")
283
 
284
- class Config:
285
- """Configuration Pydantic."""
286
-
287
- json_schema_extra = {
288
  "example": {
289
  "total_employees": 100,
290
  "predictions": [
@@ -305,3 +300,4 @@ class BatchPredictionOutput(BaseModel):
305
  },
306
  }
307
  }
 
 
8
  from enum import Enum
9
  from typing import Literal
10
 
11
+ from pydantic import BaseModel, Field, field_validator, ConfigDict
12
 
13
 
14
  # Enums pour les valeurs catégorielles
 
172
  v = float(v.replace(" %", "").replace("%", ""))
173
  return v
174
 
175
+ model_config = ConfigDict(
176
+ json_schema_extra={
 
 
177
  "example": {
178
  # Exemple basé sur la première ligne des CSV
179
  "nombre_participation_pee": 0,
 
208
  "annees_dans_le_poste_actuel": 4,
209
  }
210
  }
211
+ )
212
 
213
 
214
  class PredictionOutput(BaseModel):
 
223
  )
224
  risk_level: str = Field(..., description="Niveau de risque (Low/Medium/High)")
225
 
226
+ model_config = ConfigDict(
227
+ json_schema_extra={
 
 
228
  "example": {
229
  "prediction": 1,
230
  "probability_0": 0.35,
 
232
  "risk_level": "High",
233
  }
234
  }
235
+ )
236
 
237
 
238
  class HealthCheck(BaseModel):
 
243
  model_type: str = Field(..., description="Type du modèle")
244
  version: str = Field(..., description="Version de l'API")
245
 
246
+ model_config = ConfigDict(
247
+ json_schema_extra={
 
 
248
  "example": {
249
  "status": "healthy",
250
  "model_loaded": True,
 
252
  "version": "1.0.0",
253
  }
254
  }
255
+ )
256
 
257
 
258
  class EmployeePrediction(BaseModel):
 
278
  )
279
  summary: dict = Field(..., description="Résumé des prédictions")
280
 
281
+ model_config = ConfigDict(
282
+ json_schema_extra={
 
 
283
  "example": {
284
  "total_employees": 100,
285
  "predictions": [
 
300
  },
301
  }
302
  }
303
+ )