ASI-Engineer commited on
Commit
d987b13
·
verified ·
1 Parent(s): 45f0e10

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. Dockerfile +7 -7
  2. README.md +259 -34
  3. README_HF.md +34 -10
  4. app.py +163 -22
  5. src/config.py +1 -1
  6. src/gradio_ui.py +57 -23
  7. src/preprocessing.py +227 -16
  8. src/schemas.py +70 -13
Dockerfile CHANGED
@@ -21,17 +21,17 @@ COPY .env.example .env
21
  # Créer le dossier logs
22
  RUN mkdir -p logs
23
 
24
- # Exposer le port
25
- EXPOSE 8000
26
 
27
  # Variables d'environnement par défaut
28
  ENV DEBUG=false
29
  ENV LOG_LEVEL=INFO
30
  ENV API_KEY=change-me-in-production
31
 
32
- # Healthcheck
33
- HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
34
- CMD curl -f http://localhost:8000/health || exit 1
35
 
36
- # Commande de démarrage
37
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
 
21
  # Créer le dossier logs
22
  RUN mkdir -p logs
23
 
24
+ # Exposer le port (7860 = Gradio par défaut sur HuggingFace Spaces)
25
+ EXPOSE 7860
26
 
27
  # Variables d'environnement par défaut
28
  ENV DEBUG=false
29
  ENV LOG_LEVEL=INFO
30
  ENV API_KEY=change-me-in-production
31
 
32
+ # Healthcheck - vérifier que FastAPI répond sur /health
33
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
34
+ CMD curl -f http://localhost:7860/health || exit 1
35
 
36
+ # Commande de démarrage - FastAPI avec Gradio monté sur /ui
37
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,49 +1,274 @@
1
- ---
2
- title: Employee Turnover Prediction API
3
- emoji: 👔
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- pinned: true
8
- license: mit
9
- app_port: 8000
10
- ---
11
 
12
- # Employee Turnover Prediction API 🚀
13
 
14
- API de prédiction du turnover des employés avec XGBoost + SMOTE.
15
 
16
- ## 🎯 Fonctionnalités
 
 
 
17
 
18
- - ✅ Prédiction de turnover (0 = reste, 1 = part)
19
- - 📊 Probabilités et niveau de risque (Low/Medium/High)
 
 
20
  - 🔐 Authentification API Key
21
- - 📝 Logs structurés JSON
22
- - 🛡️ Rate limiting (20 req/min)
23
- - 📚 Documentation OpenAPI/Swagger
24
 
25
- ## 🔗 Endpoints
26
 
27
- - **Docs** : `/docs` - Documentation interactive
28
- - **Health** : `/health` - Status de l'API
29
- - **Predict** : `/predict` - Prédiction de turnover
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- ## 🚀 Utilisation
 
 
 
 
 
32
 
33
  ```bash
34
- # Health check
35
- curl https://asi-engineer-employee-turnover-api.hf.space/health
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Prédiction
38
- curl -X POST https://asi-engineer-employee-turnover-api.hf.space/predict \
 
 
 
 
 
 
 
 
 
39
  -H "Content-Type: application/json" \
40
- -d '{
41
- "satisfaction_employee_environnement": 3,
42
- "satisfaction_employee_nature_travail": 4,
43
- ...
44
- }'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ```
46
 
47
- ## 📚 Documentation complète
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- Voir [GitHub Repository](https://github.com/chaton59/OC_P5) pour la documentation complète.
 
 
1
+ # 🚀 Employee Turnover Prediction API - v2.2.0
 
 
 
 
 
 
 
 
 
2
 
3
+ ## 📊 Vue d'ensemble
4
 
5
+ API REST de prédiction du turnover des employés basée sur un modèle XGBoost avec SMOTE.
6
 
7
+ **✨ Nouveautés v2.2.0** :
8
+ - 📦 **Endpoint batch CSV** : Envoyez directement vos 3 fichiers CSV bruts
9
+ - 🔧 Correction du preprocessing (scaling + ordre des colonnes)
10
+ - 📊 Prédictions plus précises (~90% accuracy)
11
 
12
+ **✨ v2.1.0** :
13
+ - 📝 Logging structuré JSON
14
+ - 🛡️ Rate limiting (20 req/min par IP)
15
+ - ⚡ Gestion d'erreurs améliorée
16
  - 🔐 Authentification API Key
 
 
 
17
 
18
+ ## 🏗️ Architecture
19
 
20
+ ```
21
+ OC_P5/
22
+ ├── app.py # Point d'entrée FastAPI
23
+ ├── src/
24
+ │ ├── auth.py # Authentification API Key
25
+ │ ├── config.py # Configuration centralisée
26
+ │ ├── logger.py # Logging structuré (NOUVEAU)
27
+ │ ├── models.py # Chargement modèle HF Hub
28
+ │ ├── preprocessing.py # Pipeline preprocessing
29
+ │ ├── rate_limit.py # Rate limiting (NOUVEAU)
30
+ │ └── schemas.py # Validation Pydantic
31
+ ├── tests/ # Suite pytest (33 tests, 88% couverture)
32
+ ├── logs/ # Logs JSON (NOUVEAU)
33
+ │ ├── api.log # Tous les logs
34
+ │ └── error.log # Erreurs uniquement
35
+ ├── docs/ # Documentation
36
+ ├── ml_model/ # Scripts training
37
+ └── data/ # Données sources
38
+ ```
39
+
40
+ ## 🚀 Installation
41
 
42
+ ### Prérequis
43
+ - Python 3.12+
44
+ - Poetry 1.7+
45
+ - Git
46
+
47
+ ### Setup rapide
48
 
49
  ```bash
50
+ # 1. Cloner le repo
51
+ git clone https://github.com/chaton59/OC_P5.git
52
+ cd OC_P5
53
+
54
+ # 2. Installer les dépendances
55
+ poetry install
56
+
57
+ # 3. Configurer l'environnement
58
+ cp .env.example .env
59
+ # Éditer .env avec vos valeurs
60
+
61
+ # 4. Lancer l'API
62
+ poetry run uvicorn app:app --reload
63
+
64
+ # 5. Accéder à la documentation
65
+ # http://localhost:8000/docs
66
+ ```
67
+
68
+ ## 📝 Configuration (.env)
69
+
70
+ ```bash
71
+ # Mode développement (désactive auth + active logs détaillés)
72
+ DEBUG=true
73
+
74
+ # API Key (requis en production)
75
+ API_KEY=your-secret-key-here
76
+
77
+ # Logging (DEBUG, INFO, WARNING, ERROR, CRITICAL)
78
+ LOG_LEVEL=INFO
79
+
80
+ # HuggingFace Model
81
+ HF_MODEL_REPO=ASI-Engineer/employee-turnover-model
82
+ MODEL_FILENAME=model/model.pkl
83
+ ```
84
+
85
+ ## 🔒 Authentification
86
 
87
+ ### Mode DEBUG (développement)
88
+ ```bash
89
+ # L'API Key n'est PAS requise
90
+ curl http://localhost:8000/predict -H "Content-Type: application/json" -d '{...}'
91
+ ```
92
+
93
+ ### Mode PRODUCTION
94
+ ```bash
95
+ # L'API Key est REQUISE
96
+ curl http://localhost:8000/predict \
97
+ -H "X-API-Key: your-secret-key" \
98
  -H "Content-Type: application/json" \
99
+ -d '{...}'
100
+ ```
101
+
102
+ ## 📡 Endpoints
103
+
104
+ ### 🏥 Health Check
105
+ ```bash
106
+ GET /health
107
+
108
+ # Réponse
109
+ {
110
+ "status": "healthy",
111
+ "model_loaded": true,
112
+ "model_type": "Pipeline",
113
+ "version": "2.2.0"
114
+ }
115
+ ```
116
+
117
+ ### 🔮 Prédiction unitaire
118
+ ```bash
119
+ POST /predict
120
+ Content-Type: application/json
121
+ X-API-Key: your-key (en production)
122
+
123
+ # Payload (tous les champs d'un employé)
124
+ {
125
+ "nombre_participation_pee": 0,
126
+ "nb_formations_suivies": 2,
127
+ "satisfaction_employee_environnement": 3,
128
+ ...
129
+ }
130
+
131
+ # Réponse
132
+ {
133
+ "prediction": 0, # 0 = reste, 1 = part
134
+ "probability_0": 0.85, # Probabilité de rester
135
+ "probability_1": 0.15, # Probabilité de partir
136
+ "risk_level": "Low" # Low, Medium, High
137
+ }
138
+ ```
139
+
140
+ ### 📦 Prédiction batch (NOUVEAU)
141
+ ```bash
142
+ POST /predict/batch
143
+ X-API-Key: your-key (en production)
144
+
145
+ # Envoi des 3 fichiers CSV bruts
146
+ curl -X POST "http://localhost:8000/predict/batch" \
147
+ -H "X-API-Key: your-key" \
148
+ -F "sondage_file=@data/extrait_sondage.csv" \
149
+ -F "eval_file=@data/extrait_eval.csv" \
150
+ -F "sirh_file=@data/extrait_sirh.csv"
151
+
152
+ # Réponse
153
+ {
154
+ "total_employees": 1470,
155
+ "predictions": [
156
+ {"employee_id": 1, "prediction": 1, "probability_leave": 0.84, "risk_level": "High"},
157
+ {"employee_id": 2, "prediction": 0, "probability_leave": 0.11, "risk_level": "Low"}
158
+ ],
159
+ "summary": {
160
+ "total_stay": 1169,
161
+ "total_leave": 301,
162
+ "high_risk_count": 222,
163
+ "medium_risk_count": 233,
164
+ "low_risk_count": 1015
165
+ }
166
+ }
167
  ```
168
 
169
+ ## 📊 Logging
170
+
171
+ ### Logs structurés JSON
172
+
173
+ **Fichiers** :
174
+ - `logs/api.log` : Tous les logs
175
+ - `logs/error.log` : Erreurs uniquement
176
+
177
+ **Format** :
178
+ ```json
179
+ {
180
+ "timestamp": "2025-12-26T10:30:45",
181
+ "level": "INFO",
182
+ "logger": "employee_turnover_api",
183
+ "message": "Request POST /predict",
184
+ "method": "POST",
185
+ "path": "/predict",
186
+ "status_code": 200,
187
+ "duration_ms": 23.45,
188
+ "client_host": "127.0.0.1"
189
+ }
190
+ ```
191
+
192
+ ## 🛡️ Rate Limiting
193
+
194
+ **Configuration** :
195
+ - **Développement** : Désactivé (DEBUG=true)
196
+ - **Production** : 20 requêtes/minute par IP ou API Key
197
+
198
+ **En cas de dépassement** :
199
+ ```json
200
+ {
201
+ "error": "Rate limit exceeded",
202
+ "message": "20 per 1 minute"
203
+ }
204
+ ```
205
+
206
+ ## ✅ Tests
207
+
208
+ ```bash
209
+ # Tous les tests
210
+ poetry run pytest tests/ -v
211
+
212
+ # Avec couverture
213
+ poetry run pytest tests/ --cov --cov-report=html
214
+
215
+ # Voir rapport HTML
216
+ open htmlcov/index.html
217
+ ```
218
+
219
+ **Résultats** :
220
+ - ✅ 33 tests passés
221
+ - 📊 88% de couverture globale
222
+
223
+ ## 🚀 Déploiement
224
+
225
+ ### Variables d'environnement requises
226
+ ```bash
227
+ DEBUG=false
228
+ API_KEY=<votre-clé-sécurisée>
229
+ LOG_LEVEL=INFO
230
+ ```
231
+
232
+ ### HuggingFace Spaces
233
+ Prêt pour déploiement avec `app.py` et `requirements.txt`
234
+
235
+ ## 📚 Documentation
236
+
237
+ - **API Interactive** : http://localhost:8000/docs
238
+ - **ReDoc** : http://localhost:8000/redoc
239
+ - **Guide complet** : [docs/API_GUIDE.md](docs/API_GUIDE.md)
240
+ - **Standards** : [docs/standards.md](docs/standards.md)
241
+ - **Couverture tests** : [docs/TEST_COVERAGE.md](docs/TEST_COVERAGE.md)
242
+
243
+ ## 📦 Dépendances principales
244
+
245
+ - **FastAPI** 0.115.14 : Framework web
246
+ - **Pydantic** 2.12.5 : Validation données
247
+ - **XGBoost** 2.1.3 : Modèle ML
248
+ - **SlowAPI** 0.1.9 : Rate limiting
249
+ - **python-json-logger** 4.0.0 : Logs structurés
250
+ - **pytest** 9.0.2 : Tests
251
+
252
+ ## 🔄 Changelog
253
+
254
+ ### v2.2.0 (27 décembre 2025)
255
+ - 📦 Nouvel endpoint `/predict/batch` pour traitement CSV direct
256
+ - 🔧 Fix preprocessing : ajout du scaling des features
257
+ - 🔧 Fix preprocessing : correction de l'ordre des colonnes
258
+ - 📊 Amélioration précision des prédictions (~90%)
259
+
260
+ ### v2.1.0 (26 décembre 2025)
261
+ - ✨ Système de logging structuré JSON
262
+ - 🛡️ Rate limiting avec SlowAPI
263
+ - ⚡ Amélioration gestion d'erreurs
264
+ - 📊 Monitoring des performances
265
+
266
+ ### v2.0.0 (26 décembre 2025)
267
+ - ✅ Suite de tests complète (36 tests)
268
+ - 🔐 Authentification API Key
269
+ - 📊 88% de couverture de code
270
+
271
+ ## 👥 Auteurs
272
 
273
+ - **Projet** : OpenClassrooms P5
274
+ - **Repo** : [github.com/chaton59/OC_P5](https://github.com/chaton59/OC_P5)
README_HF.md CHANGED
@@ -6,7 +6,7 @@ colorTo: purple
6
  sdk: docker
7
  pinned: true
8
  license: mit
9
- app_port: 8000
10
  ---
11
 
12
  # Employee Turnover Prediction API 🚀
@@ -16,6 +16,7 @@ API de prédiction du turnover des employés avec XGBoost + SMOTE.
16
  ## 🎯 Fonctionnalités
17
 
18
  - ✅ Prédiction de turnover (0 = reste, 1 = part)
 
19
  - 📊 Probabilités et niveau de risque (Low/Medium/High)
20
  - 🔐 Authentification API Key
21
  - 📝 Logs structurés JSON
@@ -24,26 +25,49 @@ API de prédiction du turnover des employés avec XGBoost + SMOTE.
24
 
25
  ## 🔗 Endpoints
26
 
27
- - **Docs** : `/docs` - Documentation interactive
28
- - **Health** : `/health` - Status de l'API
29
- - **Predict** : `/predict` - Prédiction de turnover
 
 
 
 
30
 
31
  ## 🚀 Utilisation
32
 
 
33
  ```bash
34
- # Health check
35
- curl https://asi-engineer-employee-turnover-api.hf.space/health
36
-
37
- # Prédiction
38
- curl -X POST https://asi-engineer-employee-turnover-api.hf.space/predict \
39
  -H "Content-Type: application/json" \
40
  -d '{
 
 
41
  "satisfaction_employee_environnement": 3,
42
- "satisfaction_employee_nature_travail": 4,
43
  ...
44
  }'
45
  ```
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ## 📚 Documentation complète
48
 
49
  Voir [GitHub Repository](https://github.com/chaton59/OC_P5) pour la documentation complète.
 
6
  sdk: docker
7
  pinned: true
8
  license: mit
9
+ app_port: 7860
10
  ---
11
 
12
  # Employee Turnover Prediction API 🚀
 
16
  ## 🎯 Fonctionnalités
17
 
18
  - ✅ Prédiction de turnover (0 = reste, 1 = part)
19
+ - 📦 **Nouveau** : Endpoint batch pour traiter vos fichiers CSV directement
20
  - 📊 Probabilités et niveau de risque (Low/Medium/High)
21
  - 🔐 Authentification API Key
22
  - 📝 Logs structurés JSON
 
25
 
26
  ## 🔗 Endpoints
27
 
28
+ | Endpoint | Description |
29
+ |----------|-------------|
30
+ | `/docs` | Documentation interactive Swagger |
31
+ | `/health` | Status de l'API |
32
+ | `/ui` | Interface Gradio interactive |
33
+ | `/predict` | Prédiction unitaire (JSON) |
34
+ | `/predict/batch` | Prédiction batch (3 fichiers CSV) |
35
 
36
  ## 🚀 Utilisation
37
 
38
+ ### Prédiction unitaire
39
  ```bash
40
+ curl -X POST https://asi-engineer-oc-p5-dev.hf.space/predict \
 
 
 
 
41
  -H "Content-Type: application/json" \
42
  -d '{
43
+ "nombre_participation_pee": 0,
44
+ "nb_formations_suivies": 2,
45
  "satisfaction_employee_environnement": 3,
 
46
  ...
47
  }'
48
  ```
49
 
50
+ ### Prédiction batch (fichiers CSV)
51
+ ```bash
52
+ curl -X POST https://asi-engineer-oc-p5-dev.hf.space/predict/batch \
53
+ -F "sondage_file=@extrait_sondage.csv" \
54
+ -F "eval_file=@extrait_eval.csv" \
55
+ -F "sirh_file=@extrait_sirh.csv"
56
+ ```
57
+
58
+ **Réponse :**
59
+ ```json
60
+ {
61
+ "total_employees": 1470,
62
+ "predictions": [...],
63
+ "summary": {
64
+ "total_stay": 1169,
65
+ "total_leave": 301,
66
+ "high_risk_count": 222
67
+ }
68
+ }
69
+ ```
70
+
71
  ## 📚 Documentation complète
72
 
73
  Voir [GitHub Repository](https://github.com/chaton59/OC_P5) pour la documentation complète.
app.py CHANGED
@@ -8,12 +8,15 @@ Cette API expose le modèle de prédiction de départ des employés avec :
8
  - Health check pour monitoring
9
  - Documentation OpenAPI/Swagger automatique
10
  - Interface Gradio pour utilisation interactive
 
11
  """
 
12
  import time
13
  from contextlib import asynccontextmanager
14
 
15
  import gradio as gr
16
- from fastapi import Depends, FastAPI, HTTPException, Request
 
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from slowapi import _rate_limit_exceeded_handler
19
  from slowapi.errors import RateLimitExceeded
@@ -23,9 +26,19 @@ from src.config import get_settings
23
  from src.gradio_ui import create_gradio_interface
24
  from src.logger import logger, log_model_load, log_request
25
  from src.models import get_model_info, load_model
26
- from src.preprocessing import preprocess_for_prediction
 
 
 
 
27
  from src.rate_limit import limiter
28
- from src.schemas import EmployeeInput, HealthCheck, PredictionOutput
 
 
 
 
 
 
29
 
30
  # Charger la configuration
31
  settings = get_settings()
@@ -112,20 +125,6 @@ async def log_requests(request: Request, call_next):
112
  return response
113
 
114
 
115
- @app.get("/", tags=["Root"])
116
- async def root():
117
- """
118
- Endpoint racine avec informations sur l'API.
119
- """
120
- return {
121
- "message": "Employee Turnover Prediction API",
122
- "version": API_VERSION,
123
- "docs": "/docs",
124
- "health": "/health",
125
- "predict": "/predict (POST)",
126
- }
127
-
128
-
129
  @app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
130
  async def health_check():
131
  """
@@ -240,17 +239,159 @@ async def predict(request: Request, employee: EmployeeInput):
240
  )
241
 
242
 
243
- # Monter l'interface Gradio sur /ui
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  gradio_app = create_gradio_interface()
245
- app = gr.mount_gradio_app(app, gradio_app, path="/ui")
246
 
247
 
248
  if __name__ == "__main__":
249
  import uvicorn
250
 
251
- print("🚀 Lancement de l'API en mode développement...")
252
- print("📖 Documentation : http://localhost:8000/docs")
253
- print("🎨 Interface Gradio : http://localhost:8000/ui")
254
 
255
  uvicorn.run(
256
  "app:app",
 
8
  - Health check pour monitoring
9
  - Documentation OpenAPI/Swagger automatique
10
  - Interface Gradio pour utilisation interactive
11
+ - Endpoint batch pour traitement de fichiers CSV
12
  """
13
+ import io
14
  import time
15
  from contextlib import asynccontextmanager
16
 
17
  import gradio as gr
18
+ import pandas as pd
19
+ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
20
  from fastapi.middleware.cors import CORSMiddleware
21
  from slowapi import _rate_limit_exceeded_handler
22
  from slowapi.errors import RateLimitExceeded
 
26
  from src.gradio_ui import create_gradio_interface
27
  from src.logger import logger, log_model_load, log_request
28
  from src.models import get_model_info, load_model
29
+ from src.preprocessing import (
30
+ merge_csv_dataframes,
31
+ preprocess_dataframe_for_prediction,
32
+ preprocess_for_prediction,
33
+ )
34
  from src.rate_limit import limiter
35
+ from src.schemas import (
36
+ BatchPredictionOutput,
37
+ EmployeeInput,
38
+ EmployeePrediction,
39
+ HealthCheck,
40
+ PredictionOutput,
41
+ )
42
 
43
  # Charger la configuration
44
  settings = get_settings()
 
125
  return response
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  @app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
129
  async def health_check():
130
  """
 
239
  )
240
 
241
 
242
+ @app.post(
243
+ "/predict/batch",
244
+ response_model=BatchPredictionOutput,
245
+ tags=["Prediction"],
246
+ dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
247
+ )
248
+ @limiter.limit("5/minute")
249
+ async def predict_batch(
250
+ request: Request,
251
+ sondage_file: UploadFile = File(..., description="Fichier CSV du sondage"),
252
+ eval_file: UploadFile = File(..., description="Fichier CSV des évaluations"),
253
+ sirh_file: UploadFile = File(..., description="Fichier CSV SIRH"),
254
+ ):
255
+ """
256
+ Endpoint de prédiction batch à partir de fichiers CSV.
257
+
258
+ **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
259
+
260
+ Prend en entrée les 3 fichiers CSV (sondage, évaluation, SIRH),
261
+ les fusionne, applique le preprocessing et retourne les prédictions
262
+ pour tous les employés.
263
+
264
+ Args:
265
+ sondage_file: Fichier CSV contenant les données de sondage.
266
+ eval_file: Fichier CSV contenant les données d'évaluation.
267
+ sirh_file: Fichier CSV contenant les données SIRH.
268
+
269
+ Returns:
270
+ BatchPredictionOutput: Prédictions pour tous les employés.
271
+
272
+ Raises:
273
+ HTTPException: 400 si les fichiers sont invalides.
274
+ HTTPException: 500 si erreur lors du traitement.
275
+ """
276
+ try:
277
+ # 1. Lire les fichiers CSV
278
+ sondage_content = await sondage_file.read()
279
+ eval_content = await eval_file.read()
280
+ sirh_content = await sirh_file.read()
281
+
282
+ sondage_df = pd.read_csv(io.BytesIO(sondage_content))
283
+ eval_df = pd.read_csv(io.BytesIO(eval_content))
284
+ sirh_df = pd.read_csv(io.BytesIO(sirh_content))
285
+
286
+ logger.info(
287
+ f"Fichiers CSV chargés: sondage={len(sondage_df)}, "
288
+ f"eval={len(eval_df)}, sirh={len(sirh_df)} lignes"
289
+ )
290
+
291
+ # 2. Fusionner les DataFrames
292
+ merged_df = merge_csv_dataframes(sondage_df, eval_df, sirh_df)
293
+ employee_ids = merged_df["original_employee_id"].tolist()
294
+ merged_df = merged_df.drop(columns=["original_employee_id"])
295
+
296
+ # Supprimer la colonne cible si présente
297
+ if "a_quitte_l_entreprise" in merged_df.columns:
298
+ merged_df = merged_df.drop(columns=["a_quitte_l_entreprise"])
299
+
300
+ logger.info(f"DataFrame fusionné: {len(merged_df)} employés")
301
+
302
+ # 3. Preprocessing
303
+ X = preprocess_dataframe_for_prediction(merged_df)
304
+
305
+ # 4. Charger le modèle et prédire
306
+ model = load_model()
307
+ predictions = model.predict(X.values)
308
+ probabilities = model.predict_proba(X.values)
309
+
310
+ # 5. Construire la réponse
311
+ results = []
312
+ risk_counts = {"Low": 0, "Medium": 0, "High": 0}
313
+ leave_count = 0
314
+
315
+ for i, emp_id in enumerate(employee_ids):
316
+ prob_stay = float(probabilities[i][0])
317
+ prob_leave = float(probabilities[i][1])
318
+ pred = int(predictions[i])
319
+
320
+ if prob_leave < 0.3:
321
+ risk = "Low"
322
+ elif prob_leave < 0.7:
323
+ risk = "Medium"
324
+ else:
325
+ risk = "High"
326
+
327
+ risk_counts[risk] += 1
328
+ if pred == 1:
329
+ leave_count += 1
330
+
331
+ results.append(
332
+ EmployeePrediction(
333
+ employee_id=int(emp_id),
334
+ prediction=pred,
335
+ probability_stay=prob_stay,
336
+ probability_leave=prob_leave,
337
+ risk_level=risk,
338
+ )
339
+ )
340
+
341
+ summary = {
342
+ "total_stay": len(results) - leave_count,
343
+ "total_leave": leave_count,
344
+ "high_risk_count": risk_counts["High"],
345
+ "medium_risk_count": risk_counts["Medium"],
346
+ "low_risk_count": risk_counts["Low"],
347
+ }
348
+
349
+ logger.info(f"Prédictions terminées: {summary}")
350
+
351
+ return BatchPredictionOutput(
352
+ total_employees=len(results),
353
+ predictions=results,
354
+ summary=summary,
355
+ )
356
+
357
+ except pd.errors.EmptyDataError:
358
+ raise HTTPException(
359
+ status_code=400,
360
+ detail={
361
+ "error": "Empty CSV file",
362
+ "message": "Un des fichiers CSV est vide.",
363
+ },
364
+ )
365
+ except KeyError as e:
366
+ raise HTTPException(
367
+ status_code=400,
368
+ detail={
369
+ "error": "Missing column",
370
+ "message": f"Colonne manquante dans les CSV: {e}",
371
+ },
372
+ )
373
+ except Exception as e:
374
+ logger.exception("Unexpected error during batch prediction")
375
+ raise HTTPException(
376
+ status_code=500,
377
+ detail={
378
+ "error": "Batch prediction failed",
379
+ "message": str(e),
380
+ },
381
+ )
382
+
383
+
384
+ # Monter l'interface Gradio sur / (racine pour HuggingFace Spaces)
385
  gradio_app = create_gradio_interface()
386
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
387
 
388
 
389
  if __name__ == "__main__":
390
  import uvicorn
391
 
392
+ print("\U0001f680 Lancement de l'API en mode d\u00e9veloppement...")
393
+ print("\U0001f4d6 Documentation : http://localhost:8000/docs")
394
+ print("\U0001f3a8 Interface Gradio : http://localhost:8000/")
395
 
396
  uvicorn.run(
397
  "app:app",
src/config.py CHANGED
@@ -26,7 +26,7 @@ class Settings:
26
  API_KEY: str = os.getenv("API_KEY", "dev-key-change-me-in-production")
27
 
28
  # ===== API =====
29
- API_VERSION: str = os.getenv("API_VERSION", "1.0.0")
30
  API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
31
  API_PORT: int = int(os.getenv("API_PORT", "8000"))
32
 
 
26
  API_KEY: str = os.getenv("API_KEY", "dev-key-change-me-in-production")
27
 
28
  # ===== API =====
29
+ API_VERSION: str = os.getenv("API_VERSION", "2.2.0")
30
  API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
31
  API_PORT: int = int(os.getenv("API_PORT", "8000"))
32
 
src/gradio_ui.py CHANGED
@@ -198,7 +198,7 @@ curl -X POST https://asi-engineer-oc-p5-dev.hf.space/predict \\
198
  "departement": "Commercial",
199
  "poste": "Manager",
200
  "nombre_experiences_precedentes": 3,
201
- "nombre_heures_travailless": 45,
202
  "annee_experience_totale": 10,
203
  "annees_dans_l_entreprise": 5,
204
  "annees_dans_le_poste_actuel": 2
@@ -239,7 +239,7 @@ data = {
239
  "departement": "Commercial",
240
  "poste": "Manager",
241
  "nombre_experiences_precedentes": 3,
242
- "nombre_heures_travailless": 45,
243
  "annee_experience_totale": 10,
244
  "annees_dans_l_entreprise": 5,
245
  "annees_dans_le_poste_actuel": 2
@@ -316,16 +316,18 @@ def create_gradio_interface():
316
  with gr.Column():
317
  gr.Markdown("#### 📋 Données Sondage")
318
  nombre_participation_pee = gr.Slider(
319
- 0, 10, value=0, step=1, label="Participations PEE"
320
  )
321
  nb_formations_suivies = gr.Slider(
322
- 0, 10, value=2, step=1, label="Formations suivies"
323
  )
324
- nombre_employee_sous_responsabilite = gr.Slider(
325
- 0, 20, value=0, step=1, label="Employés sous responsabilité"
 
 
326
  )
327
  distance_domicile_travail = gr.Slider(
328
- 0, 50, value=15, step=1, label="Distance domicile (km)"
329
  )
330
  niveau_education = gr.Slider(
331
  1, 5, value=3, step=1, label="Niveau éducation (1-5)"
@@ -354,7 +356,7 @@ def create_gradio_interface():
354
  0, 15, value=2, step=1, label="Années depuis promotion"
355
  )
356
  annes_sous_responsable_actuel = gr.Slider(
357
- 0, 20, value=3, step=1, label="Années sous responsable"
358
  )
359
 
360
  # Colonne EVALUATION
@@ -364,7 +366,7 @@ def create_gradio_interface():
364
  1, 4, value=3, step=1, label="Satisfaction environnement"
365
  )
366
  note_evaluation_precedente = gr.Slider(
367
- 1, 5, value=3, step=1, label="Évaluation précédente"
368
  )
369
  niveau_hierarchique_poste = gr.Slider(
370
  1, 5, value=2, step=1, label="Niveau hiérarchique"
@@ -379,7 +381,7 @@ def create_gradio_interface():
379
  1, 4, value=3, step=1, label="Équilibre pro/perso"
380
  )
381
  note_evaluation_actuelle = gr.Slider(
382
- 1, 5, value=3, step=1, label="Évaluation actuelle"
383
  )
384
  heure_supplementaires = gr.Radio(
385
  ["Oui", "Non"], value="Non", label="Heures supplémentaires"
@@ -395,11 +397,11 @@ def create_gradio_interface():
395
  # Colonne SIRH
396
  with gr.Column():
397
  gr.Markdown("#### 👤 Données SIRH")
398
- age = gr.Slider(18, 65, value=35, step=1, label="Âge")
399
  genre = gr.Radio(["M", "F"], value="M", label="Genre")
400
  revenu_mensuel = gr.Slider(
401
- 1500,
402
- 15000,
403
  value=4500,
404
  step=100,
405
  label="Revenu mensuel (€)",
@@ -430,19 +432,19 @@ def create_gradio_interface():
430
  label="Poste",
431
  )
432
  nombre_experiences_precedentes = gr.Slider(
433
- 0, 10, value=2, step=1, label="Expériences précédentes"
434
  )
435
- nombre_heures_travailless = gr.Slider(
436
- 35, 80, value=40, step=1, label="Heures travaillées/sem"
437
  )
438
  annee_experience_totale = gr.Slider(
439
  0, 40, value=10, step=1, label="Années d'expérience totale"
440
  )
441
  annees_dans_l_entreprise = gr.Slider(
442
- 0, 30, value=5, step=1, label="Années dans l'entreprise"
443
  )
444
  annees_dans_le_poste_actuel = gr.Slider(
445
- 0, 20, value=2, step=1, label="Années dans le poste"
446
  )
447
 
448
  # Bouton et résultat
@@ -531,13 +533,45 @@ def create_gradio_interface():
531
  """
532
  )
533
 
534
- # Note: Pas de queue() car monté dans FastAPI via mount_gradio_app
535
- # La queue SSE v3 ne fonctionne pas correctement avec le montage FastAPI
536
-
537
  return demo
538
 
539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  # Pour lancer en standalone
541
  if __name__ == "__main__":
542
- demo = create_gradio_interface()
543
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
198
  "departement": "Commercial",
199
  "poste": "Manager",
200
  "nombre_experiences_precedentes": 3,
201
+ "nombre_heures_travailless": 80,
202
  "annee_experience_totale": 10,
203
  "annees_dans_l_entreprise": 5,
204
  "annees_dans_le_poste_actuel": 2
 
239
  "departement": "Commercial",
240
  "poste": "Manager",
241
  "nombre_experiences_precedentes": 3,
242
+ "nombre_heures_travailless": 80,
243
  "annee_experience_totale": 10,
244
  "annees_dans_l_entreprise": 5,
245
  "annees_dans_le_poste_actuel": 2
 
316
  with gr.Column():
317
  gr.Markdown("#### 📋 Données Sondage")
318
  nombre_participation_pee = gr.Slider(
319
+ 0, 3, value=0, step=1, label="Participations PEE"
320
  )
321
  nb_formations_suivies = gr.Slider(
322
+ 0, 6, value=2, step=1, label="Formations suivies"
323
  )
324
+ nombre_employee_sous_responsabilite = gr.Number(
325
+ value=1,
326
+ label="Employés sous responsabilité",
327
+ interactive=False,
328
  )
329
  distance_domicile_travail = gr.Slider(
330
+ 1, 30, value=10, step=1, label="Distance domicile (km)"
331
  )
332
  niveau_education = gr.Slider(
333
  1, 5, value=3, step=1, label="Niveau éducation (1-5)"
 
356
  0, 15, value=2, step=1, label="Années depuis promotion"
357
  )
358
  annes_sous_responsable_actuel = gr.Slider(
359
+ 0, 17, value=3, step=1, label="Années sous responsable"
360
  )
361
 
362
  # Colonne EVALUATION
 
366
  1, 4, value=3, step=1, label="Satisfaction environnement"
367
  )
368
  note_evaluation_precedente = gr.Slider(
369
+ 1, 4, value=3, step=1, label="Évaluation précédente"
370
  )
371
  niveau_hierarchique_poste = gr.Slider(
372
  1, 5, value=2, step=1, label="Niveau hiérarchique"
 
381
  1, 4, value=3, step=1, label="Équilibre pro/perso"
382
  )
383
  note_evaluation_actuelle = gr.Slider(
384
+ 3, 4, value=3, step=1, label="Évaluation actuelle"
385
  )
386
  heure_supplementaires = gr.Radio(
387
  ["Oui", "Non"], value="Non", label="Heures supplémentaires"
 
397
  # Colonne SIRH
398
  with gr.Column():
399
  gr.Markdown("#### 👤 Données SIRH")
400
+ age = gr.Slider(18, 60, value=35, step=1, label="Âge")
401
  genre = gr.Radio(["M", "F"], value="M", label="Genre")
402
  revenu_mensuel = gr.Slider(
403
+ 1000,
404
+ 20000,
405
  value=4500,
406
  step=100,
407
  label="Revenu mensuel (€)",
 
432
  label="Poste",
433
  )
434
  nombre_experiences_precedentes = gr.Slider(
435
+ 0, 9, value=2, step=1, label="Expériences précédentes"
436
  )
437
+ nombre_heures_travailless = gr.Number(
438
+ value=80, label="Heures travaillées/sem", interactive=False
439
  )
440
  annee_experience_totale = gr.Slider(
441
  0, 40, value=10, step=1, label="Années d'expérience totale"
442
  )
443
  annees_dans_l_entreprise = gr.Slider(
444
+ 0, 40, value=5, step=1, label="Années dans l'entreprise"
445
  )
446
  annees_dans_le_poste_actuel = gr.Slider(
447
+ 0, 18, value=2, step=1, label="Années dans le poste"
448
  )
449
 
450
  # Bouton et résultat
 
533
  """
534
  )
535
 
 
 
 
536
  return demo
537
 
538
 
539
+ def launch_standalone():
540
+ """Lance Gradio en mode standalone (pour HuggingFace Spaces)."""
541
+ import sys
542
+
543
+ print("🚀 Démarrage de l'application Gradio...", flush=True)
544
+ print(f"Python version: {sys.version}", flush=True)
545
+
546
+ # Pré-charger le modèle pour éviter le timeout au premier appel
547
+ print("📦 Pré-chargement du modèle...", flush=True)
548
+ try:
549
+ from src.models import load_model
550
+
551
+ model = load_model()
552
+ print(f"✅ Modèle chargé: {type(model).__name__}", flush=True)
553
+ except Exception as e:
554
+ print(f"⚠️ Erreur chargement modèle: {e}", flush=True)
555
+
556
+ print("🎨 Création de l'interface Gradio...", flush=True)
557
+ demo = create_gradio_interface()
558
+
559
+ # Configuration pour HuggingFace Spaces
560
+ # Ne pas utiliser queue() qui peut causer des problèmes sur HF Spaces
561
+ # car il nécessite un serveur websocket supplémentaire
562
+
563
+ print("🌐 Lancement du serveur sur 0.0.0.0:7860...", flush=True)
564
+ sys.stdout.flush()
565
+ sys.stderr.flush()
566
+
567
+ demo.launch(
568
+ server_name="0.0.0.0",
569
+ server_port=7860,
570
+ share=False, # Pas de tunnel Gradio sur HF Spaces
571
+ show_error=True,
572
+ )
573
+
574
+
575
  # Pour lancer en standalone
576
  if __name__ == "__main__":
577
+ launch_standalone()
 
src/preprocessing.py CHANGED
@@ -5,8 +5,7 @@ Module de preprocessing pour transformer les données d'entrée avant prédictio
5
  Ce module applique les mêmes transformations que le pipeline d'entraînement :
6
  - Feature engineering (ratios, moyennes)
7
  - Encoding (OneHot, Ordinal)
8
-
9
- Note: Pas de scaling car XGBoost est insensible à l'échelle des features.
10
  """
11
  import numpy as np
12
  import pandas as pd
@@ -14,6 +13,98 @@ from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
14
 
15
  from src.schemas import EmployeeInput
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def create_input_dataframe(employee: EmployeeInput) -> pd.DataFrame:
19
  """
@@ -119,7 +210,7 @@ def encode_and_scale(df: pd.DataFrame) -> pd.DataFrame:
119
  df: DataFrame avec features engineered.
120
 
121
  Returns:
122
- DataFrame transformé avec 50 colonnes (comme training).
123
  """
124
  df = df.copy()
125
 
@@ -184,10 +275,71 @@ def encode_and_scale(df: pd.DataFrame) -> pd.DataFrame:
184
  # Concaténer les encodages OneHot
185
  df = pd.concat([df, encoded_non_ord], axis=1)
186
 
187
- # NOTE: PAS de scaling !
188
- # XGBoost est un modèle basé sur des arbres, insensible à l'échelle.
189
- # Le scaling sur une seule observation causait des valeurs constantes
190
- # car StandardScaler.fit_transform() sur 1 ligne donne toujours 0.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  return df
193
 
@@ -221,12 +373,71 @@ def preprocess_for_prediction(employee: EmployeeInput) -> np.ndarray:
221
  return df.values
222
 
223
 
224
- # TODO: Implémenter le chargement des artifacts sauvegardés
225
- # def load_preprocessing_artifacts(run_id: str) -> dict:
226
- # """
227
- # Charge les encoders et scaler depuis MLflow.
228
- #
229
- # Returns:
230
- # dict avec keys: 'onehot_encoder', 'ordinal_encoder', 'scaler'
231
- # """
232
- # pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  Ce module applique les mêmes transformations que le pipeline d'entraînement :
6
  - Feature engineering (ratios, moyennes)
7
  - Encoding (OneHot, Ordinal)
8
+ - Scaling (StandardScaler avec paramètres sauvegardés)
 
9
  """
10
  import numpy as np
11
  import pandas as pd
 
13
 
14
  from src.schemas import EmployeeInput
15
 
16
+ # Paramètres du scaler sauvegardés depuis l'entraînement
17
+ # Ces valeurs doivent correspondre exactement à celles utilisées lors du training
18
+ SCALER_PARAMS = {
19
+ "columns": [
20
+ "nombre_participation_pee",
21
+ "nb_formations_suivies",
22
+ "nombre_employee_sous_responsabilite",
23
+ "distance_domicile_travail",
24
+ "niveau_education",
25
+ "annees_depuis_la_derniere_promotion",
26
+ "annes_sous_responsable_actuel",
27
+ "satisfaction_employee_environnement",
28
+ "note_evaluation_precedente",
29
+ "niveau_hierarchique_poste",
30
+ "satisfaction_employee_nature_travail",
31
+ "satisfaction_employee_equipe",
32
+ "satisfaction_employee_equilibre_pro_perso",
33
+ "note_evaluation_actuelle",
34
+ "augementation_salaire_precedente",
35
+ "age",
36
+ "revenu_mensuel",
37
+ "nombre_experiences_precedentes",
38
+ "nombre_heures_travailless",
39
+ "annee_experience_totale",
40
+ "annees_dans_l_entreprise",
41
+ "annees_dans_le_poste_actuel",
42
+ "revenu_par_anciennete",
43
+ "experience_par_anciennete",
44
+ "satisfaction_moyenne",
45
+ "promo_par_anciennete",
46
+ "frequence_deplacement",
47
+ ],
48
+ "mean": [
49
+ 0.7938775510204081,
50
+ 2.7993197278911564,
51
+ 1.0,
52
+ 9.19251700680272,
53
+ 2.912925170068027,
54
+ 2.1789115646258503,
55
+ 4.102721088435374,
56
+ 2.721768707482993,
57
+ 2.7299319727891156,
58
+ 2.0639455782312925,
59
+ 2.7285714285714286,
60
+ 2.7122448979591836,
61
+ 2.7612244897959184,
62
+ 3.1537414965986397,
63
+ 15.209523809523809,
64
+ 36.923809523809524,
65
+ 6502.931292517007,
66
+ 2.6931972789115646,
67
+ 80.0,
68
+ 11.268707482993197,
69
+ 6.980272108843537,
70
+ 4.214965986394557,
71
+ 1170.0019803036198,
72
+ 1.9285635921785853,
73
+ 2.730952380952381,
74
+ 0.23624418065415922,
75
+ 1.0863945578231293,
76
+ ],
77
+ "scale": [
78
+ 0.8517867966287158,
79
+ 1.2888320187689346,
80
+ 1.0,
81
+ 8.104106529671768,
82
+ 1.0238165299102608,
83
+ 3.1873417003246085,
84
+ 3.502524756587405,
85
+ 1.0927103547111134,
86
+ 0.7113190741884202,
87
+ 1.1065633247112856,
88
+ 1.1024709415085499,
89
+ 1.0808410657505316,
90
+ 0.7062354909319911,
91
+ 0.3607007746349458,
92
+ 3.658692627979528,
93
+ 9.132265690615387,
94
+ 4706.355164823003,
95
+ 2.497159198593844,
96
+ 1.0,
97
+ 7.7078836108215345,
98
+ 6.0028580432875085,
99
+ 3.575242796407657,
100
+ 1353.331540788815,
101
+ 2.2050718706188372,
102
+ 0.5056427624070211,
103
+ 0.2687717006578023,
104
+ 0.5319888822661019,
105
+ ],
106
+ }
107
+
108
 
109
  def create_input_dataframe(employee: EmployeeInput) -> pd.DataFrame:
110
  """
 
210
  df: DataFrame avec features engineered.
211
 
212
  Returns:
213
+ DataFrame transformé avec 50 colonnes dans l'ordre exact du modèle.
214
  """
215
  df = df.copy()
216
 
 
275
  # Concaténer les encodages OneHot
276
  df = pd.concat([df, encoded_non_ord], axis=1)
277
 
278
+ # === RÉORDONNER LES COLONNES SELON L'ORDRE DU MODÈLE ===
279
+ # Ordre exact des features attendues par le modèle (50 colonnes)
280
+ expected_columns = [
281
+ "nombre_participation_pee",
282
+ "nb_formations_suivies",
283
+ "nombre_employee_sous_responsabilite",
284
+ "distance_domicile_travail",
285
+ "niveau_education",
286
+ "annees_depuis_la_derniere_promotion",
287
+ "annes_sous_responsable_actuel",
288
+ "satisfaction_employee_environnement",
289
+ "note_evaluation_precedente",
290
+ "niveau_hierarchique_poste",
291
+ "satisfaction_employee_nature_travail",
292
+ "satisfaction_employee_equipe",
293
+ "satisfaction_employee_equilibre_pro_perso",
294
+ "note_evaluation_actuelle",
295
+ "augementation_salaire_precedente",
296
+ "age",
297
+ "revenu_mensuel",
298
+ "nombre_experiences_precedentes",
299
+ "nombre_heures_travailless",
300
+ "annee_experience_totale",
301
+ "annees_dans_l_entreprise",
302
+ "annees_dans_le_poste_actuel",
303
+ "revenu_par_anciennete",
304
+ "experience_par_anciennete",
305
+ "satisfaction_moyenne",
306
+ "promo_par_anciennete",
307
+ "genre_F",
308
+ "genre_M",
309
+ "statut_marital_Célibataire",
310
+ "statut_marital_Divorcé(e)",
311
+ "statut_marital_Marié(e)",
312
+ "departement_Commercial",
313
+ "departement_Consulting",
314
+ "departement_Ressources Humaines",
315
+ "poste_Assistant de Direction",
316
+ "poste_Cadre Commercial",
317
+ "poste_Consultant",
318
+ "poste_Directeur Technique",
319
+ "poste_Manager",
320
+ "poste_Représentant Commercial",
321
+ "poste_Ressources Humaines",
322
+ "poste_Senior Manager",
323
+ "poste_Tech Lead",
324
+ "domaine_etude_Autre",
325
+ "domaine_etude_Entrepreunariat",
326
+ "domaine_etude_Infra & Cloud",
327
+ "domaine_etude_Marketing",
328
+ "domaine_etude_Ressources Humaines",
329
+ "domaine_etude_Transformation Digitale",
330
+ "frequence_deplacement",
331
+ ]
332
+
333
+ # Réordonner les colonnes
334
+ df = df[expected_columns]
335
+
336
+ # === SCALING ===
337
+ # Appliquer le StandardScaler avec les paramètres sauvegardés
338
+ for i, col in enumerate(SCALER_PARAMS["columns"]):
339
+ if col in df.columns:
340
+ mean = SCALER_PARAMS["mean"][i]
341
+ scale = SCALER_PARAMS["scale"][i]
342
+ df[col] = (df[col] - mean) / scale
343
 
344
  return df
345
 
 
373
  return df.values
374
 
375
 
376
+ def preprocess_dataframe_for_prediction(df: pd.DataFrame) -> pd.DataFrame:
377
+ """
378
+ Préprocess un DataFrame complet (issu de CSV fusionnés) pour prédiction batch.
379
+
380
+ Args:
381
+ df: DataFrame avec toutes les colonnes nécessaires.
382
+
383
+ Returns:
384
+ DataFrame transformé prêt pour model.predict().
385
+ """
386
+ # Feature engineering
387
+ df_processed = engineer_features(df)
388
+
389
+ # Encoding et scaling
390
+ df_processed = encode_and_scale(df_processed)
391
+
392
+ return df_processed
393
+
394
+
395
+ def merge_csv_dataframes(
396
+ sondage_df: pd.DataFrame,
397
+ eval_df: pd.DataFrame,
398
+ sirh_df: pd.DataFrame,
399
+ ) -> pd.DataFrame:
400
+ """
401
+ Fusionne les 3 DataFrames CSV comme lors de l'entraînement.
402
+
403
+ Args:
404
+ sondage_df: DataFrame du fichier sondage.
405
+ eval_df: DataFrame du fichier évaluation.
406
+ sirh_df: DataFrame du fichier SIRH.
407
+
408
+ Returns:
409
+ DataFrame fusionné avec toutes les colonnes.
410
+ """
411
+ # Nettoyage de l'évaluation
412
+ eval_df = eval_df.copy()
413
+ eval_df["augementation_salaire_precedente"] = eval_df[
414
+ "augementation_salaire_precedente"
415
+ ].apply(lambda x: float(str(x).replace(" %", "")) if isinstance(x, str) else x)
416
+ eval_df["employee_id"] = eval_df["eval_number"].apply(
417
+ lambda x: int(str(x).replace("E_", "")) if isinstance(x, str) else x
418
+ )
419
+
420
+ # Nettoyage du sondage
421
+ sondage_df = sondage_df.copy()
422
+ sondage_df["employee_id"] = sondage_df["code_sondage"].apply(
423
+ lambda x: int(x) if isinstance(x, (str, int)) else None
424
+ )
425
+
426
+ # Fusion
427
+ central_df = pd.merge(sondage_df, eval_df, on="employee_id", how="inner")
428
+ central_df = pd.merge(
429
+ central_df, sirh_df, left_on="employee_id", right_on="id_employee", how="inner"
430
+ )
431
+
432
+ # Conserver l'ID pour le retour
433
+ central_df["original_employee_id"] = central_df["employee_id"]
434
+
435
+ # Supprimer les colonnes de jointure
436
+ central_df.drop(
437
+ ["code_sondage", "eval_number", "id_employee", "employee_id"],
438
+ axis=1,
439
+ inplace=True,
440
+ errors="ignore",
441
+ )
442
+
443
+ return central_df
src/schemas.py CHANGED
@@ -78,16 +78,19 @@ class EmployeeInput(BaseModel):
78
 
79
  # === Données SONDAGE ===
80
  nombre_participation_pee: int = Field(
81
- ..., ge=0, description="Nombre de participations au PEE"
82
  )
83
  nb_formations_suivies: int = Field(
84
- ..., ge=0, le=10, description="Nombre de formations suivies"
85
  )
86
  nombre_employee_sous_responsabilite: int = Field(
87
- ..., ge=0, description="Nombre d'employés sous responsabilité"
 
 
 
88
  )
89
  distance_domicile_travail: int = Field(
90
- ..., ge=0, le=50, description="Distance domicile-travail en km"
91
  )
92
  niveau_education: int = Field(
93
  ..., ge=1, le=5, description="Niveau d'éducation (1-5)"
@@ -101,7 +104,7 @@ class EmployeeInput(BaseModel):
101
  ..., ge=0, description="Années depuis la dernière promotion"
102
  )
103
  annes_sous_responsable_actuel: int = Field(
104
- ..., ge=0, description="Années sous le responsable actuel"
105
  )
106
 
107
  # === Données EVALUATION ===
@@ -109,7 +112,7 @@ class EmployeeInput(BaseModel):
109
  ..., ge=1, le=4, description="Satisfaction environnement (1-4)"
110
  )
111
  note_evaluation_precedente: int = Field(
112
- ..., ge=1, le=5, description="Note évaluation précédente (1-5)"
113
  )
114
  niveau_hierarchique_poste: int = Field(
115
  ..., ge=1, le=5, description="Niveau hiérarchique (1-5)"
@@ -124,7 +127,7 @@ class EmployeeInput(BaseModel):
124
  ..., ge=1, le=4, description="Satisfaction équilibre pro/perso (1-4)"
125
  )
126
  note_evaluation_actuelle: int = Field(
127
- ..., ge=1, le=5, description="Note évaluation actuelle (1-5)"
128
  )
129
  heure_supplementaires: Literal["Oui", "Non"] = Field(
130
  ..., description="Fait des heures supplémentaires"
@@ -134,26 +137,31 @@ class EmployeeInput(BaseModel):
134
  )
135
 
136
  # === Données SIRH ===
137
- age: int = Field(..., ge=18, le=70, description="Âge de l'employé")
138
  genre: GenreEnum = Field(..., description="Genre")
139
- revenu_mensuel: float = Field(..., ge=1000, description="Revenu mensuel (€)")
 
 
140
  statut_marital: StatutMaritalEnum = Field(..., description="Statut marital")
141
  departement: DepartementEnum = Field(..., description="Département")
142
  poste: PosteEnum = Field(..., description="Intitulé du poste")
143
  nombre_experiences_precedentes: int = Field(
144
- ..., ge=0, description="Nombre d'expériences précédentes"
145
  )
146
  nombre_heures_travailless: int = Field(
147
- ..., ge=35, le=80, description="Nombre d'heures travaillées par semaine"
 
 
 
148
  )
149
  annee_experience_totale: int = Field(
150
  ..., ge=0, description="Années d'expérience totale"
151
  )
152
  annees_dans_l_entreprise: int = Field(
153
- ..., ge=0, description="Années dans l'entreprise"
154
  )
155
  annees_dans_le_poste_actuel: int = Field(
156
- ..., ge=0, description="Années dans le poste actuel"
157
  )
158
 
159
  @field_validator("augementation_salaire_precedente")
@@ -248,3 +256,52 @@ class HealthCheck(BaseModel):
248
  "version": "1.0.0",
249
  }
250
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # === Données SONDAGE ===
80
  nombre_participation_pee: int = Field(
81
+ ..., ge=0, le=3, description="Nombre de participations au PEE (0-3)"
82
  )
83
  nb_formations_suivies: int = Field(
84
+ ..., ge=0, le=6, description="Nombre de formations suivies (0-6)"
85
  )
86
  nombre_employee_sous_responsabilite: int = Field(
87
+ default=1,
88
+ ge=1,
89
+ le=1,
90
+ description="Nombre d'employés sous responsabilité (fixe: 1)",
91
  )
92
  distance_domicile_travail: int = Field(
93
+ ..., ge=1, le=30, description="Distance domicile-travail en km (1-30)"
94
  )
95
  niveau_education: int = Field(
96
  ..., ge=1, le=5, description="Niveau d'éducation (1-5)"
 
104
  ..., ge=0, description="Années depuis la dernière promotion"
105
  )
106
  annes_sous_responsable_actuel: int = Field(
107
+ ..., ge=0, le=17, description="Années sous le responsable actuel (0-17)"
108
  )
109
 
110
  # === Données EVALUATION ===
 
112
  ..., ge=1, le=4, description="Satisfaction environnement (1-4)"
113
  )
114
  note_evaluation_precedente: int = Field(
115
+ ..., ge=1, le=4, description="Note évaluation précédente (1-4)"
116
  )
117
  niveau_hierarchique_poste: int = Field(
118
  ..., ge=1, le=5, description="Niveau hiérarchique (1-5)"
 
127
  ..., ge=1, le=4, description="Satisfaction équilibre pro/perso (1-4)"
128
  )
129
  note_evaluation_actuelle: int = Field(
130
+ ..., ge=3, le=4, description="Note évaluation actuelle (3-4)"
131
  )
132
  heure_supplementaires: Literal["Oui", "Non"] = Field(
133
  ..., description="Fait des heures supplémentaires"
 
137
  )
138
 
139
  # === Données SIRH ===
140
+ age: int = Field(..., ge=18, le=60, description="Âge de l'employé (18-60)")
141
  genre: GenreEnum = Field(..., description="Genre")
142
+ revenu_mensuel: float = Field(
143
+ ..., ge=1000, le=20000, description="Revenu mensuel (€) (1000-20000)"
144
+ )
145
  statut_marital: StatutMaritalEnum = Field(..., description="Statut marital")
146
  departement: DepartementEnum = Field(..., description="Département")
147
  poste: PosteEnum = Field(..., description="Intitulé du poste")
148
  nombre_experiences_precedentes: int = Field(
149
+ ..., ge=0, le=9, description="Nombre d'expériences précédentes (0-9)"
150
  )
151
  nombre_heures_travailless: int = Field(
152
+ default=80,
153
+ ge=80,
154
+ le=80,
155
+ description="Nombre d'heures travaillées par semaine (fixe: 80)",
156
  )
157
  annee_experience_totale: int = Field(
158
  ..., ge=0, description="Années d'expérience totale"
159
  )
160
  annees_dans_l_entreprise: int = Field(
161
+ ..., ge=0, le=40, description="Années dans l'entreprise (0-40)"
162
  )
163
  annees_dans_le_poste_actuel: int = Field(
164
+ ..., ge=0, le=18, description="Années dans le poste actuel (0-18)"
165
  )
166
 
167
  @field_validator("augementation_salaire_precedente")
 
256
  "version": "1.0.0",
257
  }
258
  }
259
+
260
+
261
+ class EmployeePrediction(BaseModel):
262
+ """Prédiction pour un employé dans le batch."""
263
+
264
+ employee_id: int = Field(..., description="ID de l'employé")
265
+ prediction: int = Field(..., description="Classe prédite (0=reste, 1=part)")
266
+ probability_stay: float = Field(
267
+ ..., ge=0, le=1, description="Probabilité de rester"
268
+ )
269
+ probability_leave: float = Field(
270
+ ..., ge=0, le=1, description="Probabilité de partir"
271
+ )
272
+ risk_level: str = Field(..., description="Niveau de risque (Low/Medium/High)")
273
+
274
+
275
+ class BatchPredictionOutput(BaseModel):
276
+ """Schéma de sortie pour les prédictions par lots (CSV)."""
277
+
278
+ total_employees: int = Field(..., description="Nombre total d'employés traités")
279
+ predictions: list[EmployeePrediction] = Field(
280
+ ..., description="Liste des prédictions"
281
+ )
282
+ summary: dict = Field(..., description="Résumé des prédictions")
283
+
284
+ class Config:
285
+ """Configuration Pydantic."""
286
+
287
+ json_schema_extra = {
288
+ "example": {
289
+ "total_employees": 100,
290
+ "predictions": [
291
+ {
292
+ "employee_id": 1,
293
+ "prediction": 0,
294
+ "probability_stay": 0.85,
295
+ "probability_leave": 0.15,
296
+ "risk_level": "Low",
297
+ }
298
+ ],
299
+ "summary": {
300
+ "total_stay": 80,
301
+ "total_leave": 20,
302
+ "high_risk_count": 15,
303
+ "medium_risk_count": 10,
304
+ "low_risk_count": 75,
305
+ },
306
+ }
307
+ }