ASI-Engineer commited on
Commit
0977180
·
verified ·
1 Parent(s): b11e918

Upload folder using huggingface_hub

Browse files
Files changed (15) hide show
  1. .gitignore +55 -20
  2. Dockerfile +37 -0
  3. README.md +207 -74
  4. README_HF.md +49 -0
  5. app.py +224 -101
  6. requirements.txt +121 -9
  7. src/__init__.py +1 -0
  8. src/auth.py +99 -0
  9. src/config.py +64 -0
  10. src/gradio_ui.py +513 -0
  11. src/logger.py +223 -0
  12. src/models.py +153 -0
  13. src/preprocessing.py +243 -0
  14. src/rate_limit.py +40 -0
  15. src/schemas.py +250 -0
.gitignore CHANGED
@@ -1,44 +1,79 @@
1
- # Environnements virtuels
 
 
2
  __pycache__/
3
  *.py[cod]
4
- *.class
5
- *.so
6
  .Python
 
 
 
 
 
 
 
 
 
 
7
  .venv/
8
  env/
 
9
  ENV/
10
 
11
- # Poetry
12
- .venv/
13
-
14
- # Logs et caches
15
- *.log
16
- local/
17
  .pytest_cache/
18
  .coverage
19
  htmlcov/
20
  .tox/
21
- .cache
22
 
23
- # IDE et éditeurs
24
- .vscode/settings.json # Optionnel : ignorez les paramètres sensibles VSCode
 
 
25
  .idea/
26
  *.swp
27
  *~
 
28
 
29
- # Système d'exploitation
 
 
30
  .DS_Store
31
  Thumbs.db
 
32
 
33
- # Secrets et données
 
 
34
  .env
 
 
35
  secrets.json
36
- data/raw/ # Pour datasets volumineux en data science (OC_P5)
37
- notebooks/*.ipynb_checkpoints/
38
 
39
- # MLflow (logs seulement, garder DB et runs pour déploiement HF)
 
 
 
 
 
 
 
 
 
 
 
40
  mlflow.db-shm
41
  mlflow.db-wal
42
- mlflow_ui.log
43
- mlflow_comparison.png
44
- nohup.out
 
 
 
 
 
1
+ # =====================
2
+ # Python
3
+ # =====================
4
  __pycache__/
5
  *.py[cod]
6
+ *.pyo
7
+ *.pyd
8
  .Python
9
+ *.so
10
+ *.egg
11
+ *.egg-info/
12
+ dist/
13
+ build/
14
+ eggs/
15
+
16
+ # =====================
17
+ # Environnements virtuels
18
+ # =====================
19
  .venv/
20
  env/
21
+ venv/
22
  ENV/
23
 
24
+ # =====================
25
+ # Tests & Coverage
26
+ # =====================
 
 
 
27
  .pytest_cache/
28
  .coverage
29
  htmlcov/
30
  .tox/
31
+ coverage.xml
32
 
33
+ # =====================
34
+ # IDE & Editeurs
35
+ # =====================
36
+ .vscode/settings.json
37
  .idea/
38
  *.swp
39
  *~
40
+ *.sublime-*
41
 
42
+ # =====================
43
+ # OS
44
+ # =====================
45
  .DS_Store
46
  Thumbs.db
47
+ *.tmp
48
 
49
+ # =====================
50
+ # Secrets & Données
51
+ # =====================
52
  .env
53
+ .env.*
54
+ !.env.example
55
  secrets.json
56
+ *.key
57
+ *.pem
58
 
59
+ # =====================
60
+ # Logs & Cache
61
+ # =====================
62
+ *.log
63
+ .cache/
64
+ local/
65
+ nohup.out
66
+
67
+ # =====================
68
+ # MLflow
69
+ # =====================
70
+ mlflow.db
71
  mlflow.db-shm
72
  mlflow.db-wal
73
+ mlruns/
74
+
75
+ # =====================
76
+ # Jupyter
77
+ # =====================
78
+ .ipynb_checkpoints/
79
+ *.ipynb_checkpoints/
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Installer les dépendances système
6
+ RUN apt-get update && apt-get install -y \
7
+ curl \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copier les fichiers de dépendances
11
+ COPY requirements.txt .
12
+
13
+ # Installer les dépendances Python
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ # Copier le code de l'application
17
+ COPY app.py .
18
+ COPY src/ ./src/
19
+ COPY .env.example .env
20
+
21
+ # Créer le dossier logs
22
+ RUN mkdir -p logs
23
+
24
+ # Exposer le port
25
+ EXPOSE 8000
26
+
27
+ # Variables d'environnement par défaut
28
+ ENV DEBUG=false
29
+ ENV LOG_LEVEL=INFO
30
+ ENV API_KEY=change-me-in-production
31
+
32
+ # Healthcheck
33
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
34
+ CMD curl -f http://localhost:8000/health || exit 1
35
+
36
+ # Commande de démarrage
37
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
README.md CHANGED
@@ -1,106 +1,239 @@
1
- ---
2
- title: OC P5 - API ML Déployée
3
- emoji: 🎯
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- # 🎯 Employee Turnover Prediction - DEV Environment
14
-
15
- Interface Gradio pour tester le modèle de prédiction de départ des employés (turnover).
16
-
17
- ## 🚀 Modèle ML
18
-
19
- - **Algorithme**: XGBoost optimisé avec RandomizedSearchCV
20
- - **Équilibrage**: SMOTE pour gérer le déséquilibre de classes (ratio 5:1)
21
- - **Tracking**: MLflow pour versioning et reproductibilité
22
- - **Métriques**: F1-Score optimisé (0.51), Accuracy 79%
23
- - **Stockage**: [Hugging Face Hub](https://huggingface.co/ASI-Engineer/employee-turnover-model)
24
-
25
- ## 📊 Fonctionnalités
26
-
27
- - **Status Checker**: Vérifier l'état du modèle et les métriques
28
- - **API Simple**: Interface Gradio pour tests rapides
29
- - **Chargement automatique**: Modèle téléchargé depuis HF Hub au démarrage
30
-
31
- ## 🔧 Architecture
32
-
33
- ```python
34
- # Chargement du modèle depuis HF Hub
35
- model_path = hf_hub_download(
36
- repo_id="ASI-Engineer/employee-turnover-model",
37
- filename="model/model.pkl"
38
- )
39
- model = mlflow.sklearn.load_model(str(Path(model_path).parent))
40
  ```
41
 
42
- ## 🛠️ Installation & Développement
43
 
44
  ### Prérequis
45
  - Python 3.12+
46
- - Poetry (gestionnaire de dépendances)
 
47
 
48
- ### Installation avec Poetry
49
 
50
  ```bash
51
- # Installer Poetry (si pas déjà fait)
52
- curl -sSL https://install.python-poetry.org | python3 -
 
53
 
54
- # Installer les dépendances
55
  poetry install
56
 
57
- # Activer l'environnement virtuel
58
- poetry shell
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Lancer le pipeline d'entraînement
61
- poetry run python main.py
62
 
63
- # Lancer l'interface Gradio
64
- poetry run python app.py
 
 
 
 
65
  ```
66
 
67
- ### Requirements.txt pour HF Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- Le fichier `requirements.txt` est **minimal et optimisé** pour HF Spaces (seulement gradio, huggingface-hub, joblib).
70
 
71
- Il est **généré automatiquement** par le CI/CD lors des déploiements.
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- Pour le générer manuellement :
74
  ```bash
75
- ./scripts/export_requirements.sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  ```
77
 
78
- ### Tests et Linting
79
 
80
  ```bash
81
- # Formater le code
82
- poetry run black .
83
 
84
- # Linter
85
- poetry run flake8 .
86
 
87
- # Tests
88
- poetry run pytest --cov=ml_model tests/
89
  ```
90
 
91
- ## 📈 Métriques
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- - **F1-Score**: 0.5136
94
- - **Accuracy**: 79%
95
- - **Données**: 1470 échantillons, 50 features
96
- - **Classes**: {0: 1233, 1: 237} - Ratio 5.20:1
97
 
98
- ## 🔗 Liens
 
 
 
 
99
 
100
- - **Modèle**: [employee-turnover-model](https://huggingface.co/ASI-Engineer/employee-turnover-model)
101
- - **GitHub**: [OC_P5](https://github.com/chaton59/OC_P5)
102
- - **CI/CD**: GitHub Actions avec déploiement automatique
 
103
 
104
- Ce Space est synchronisé automatiquement via CI/CD depuis la branche `dev` du repository GitHub.
105
 
106
- **Repository**: [chaton59/OC_P5](https://github.com/chaton59/OC_P5)
 
 
1
+ # 🚀 Employee Turnover Prediction API - v2.1.0
2
+
3
+ ## 📊 Vue d'ensemble
4
+
5
+ API REST de prédiction du turnover des employés basée sur un modèle XGBoost avec SMOTE.
6
+
7
+ **✨ Nouveautés v2.1.0** :
8
+ - 📝 Logging structuré JSON
9
+ - 🛡️ Rate limiting (20 req/min par IP)
10
+ - ⚡ Gestion d'erreurs améliorée
11
+ - 📊 Monitoring des performances
12
+ - 🔐 Authentification API Key
13
+
14
+ ## 🏗️ Architecture
15
+
16
+ ```
17
+ OC_P5/
18
+ ├── app.py # Point d'entrée FastAPI
19
+ ├── src/
20
+ │ ├── auth.py # Authentification API Key
21
+ │ ├── config.py # Configuration centralisée
22
+ │ ├── logger.py # Logging structuré (NOUVEAU)
23
+ │ ├── models.py # Chargement modèle HF Hub
24
+ │ ├── preprocessing.py # Pipeline preprocessing
25
+ │ ├── rate_limit.py # Rate limiting (NOUVEAU)
26
+ │ └── schemas.py # Validation Pydantic
27
+ ├── tests/ # Suite pytest (33 tests, 88% couverture)
28
+ ├── logs/ # Logs JSON (NOUVEAU)
29
+ │ ├── api.log # Tous les logs
30
+ │ └── error.log # Erreurs uniquement
31
+ ├── docs/ # Documentation
32
+ ├── ml_model/ # Scripts training
33
+ └── data/ # Données sources
 
 
 
 
 
 
34
  ```
35
 
36
+ ## 🚀 Installation
37
 
38
  ### Prérequis
39
  - Python 3.12+
40
+ - Poetry 1.7+
41
+ - Git
42
 
43
+ ### Setup rapide
44
 
45
  ```bash
46
+ # 1. Cloner le repo
47
+ git clone https://github.com/chaton59/OC_P5.git
48
+ cd OC_P5
49
 
50
+ # 2. Installer les dépendances
51
  poetry install
52
 
53
+ # 3. Configurer l'environnement
54
+ cp .env.example .env
55
+ # Éditer .env avec vos valeurs
56
+
57
+ # 4. Lancer l'API
58
+ poetry run uvicorn app:app --reload
59
+
60
+ # 5. Accéder à la documentation
61
+ # http://localhost:8000/docs
62
+ ```
63
+
64
+ ## 📝 Configuration (.env)
65
+
66
+ ```bash
67
+ # Mode développement (désactive auth + active logs détaillés)
68
+ DEBUG=true
69
 
70
+ # API Key (requis en production)
71
+ API_KEY=your-secret-key-here
72
 
73
+ # Logging (DEBUG, INFO, WARNING, ERROR, CRITICAL)
74
+ LOG_LEVEL=INFO
75
+
76
+ # HuggingFace Model
77
+ HF_MODEL_REPO=ASI-Engineer/employee-turnover-model
78
+ MODEL_FILENAME=model/model.pkl
79
  ```
80
 
81
+ ## 🔒 Authentification
82
+
83
+ ### Mode DEBUG (développement)
84
+ ```bash
85
+ # L'API Key n'est PAS requise
86
+ curl http://localhost:8000/predict -H "Content-Type: application/json" -d '{...}'
87
+ ```
88
+
89
+ ### Mode PRODUCTION
90
+ ```bash
91
+ # L'API Key est REQUISE
92
+ curl http://localhost:8000/predict \
93
+ -H "X-API-Key: your-secret-key" \
94
+ -H "Content-Type: application/json" \
95
+ -d '{...}'
96
+ ```
97
 
98
+ ## 📡 Endpoints
99
 
100
+ ### 🏥 Health Check
101
+ ```bash
102
+ GET /health
103
+
104
+ # Réponse
105
+ {
106
+ "status": "healthy",
107
+ "model_loaded": true,
108
+ "model_type": "Pipeline",
109
+ "version": "2.1.0"
110
+ }
111
+ ```
112
 
113
+ ### 🔮 Prédiction
114
  ```bash
115
+ POST /predict
116
+ Content-Type: application/json
117
+ X-API-Key: your-key (en production)
118
+
119
+ # Exemple payload (voir docs/API_GUIDE.md pour tous les champs)
120
+ {
121
+ "satisfaction_employee_environnement": 3,
122
+ "satisfaction_employee_nature_travail": 4,
123
+ "satisfaction_employee_equipe": 5,
124
+ "satisfaction_employee_equilibre_pro_perso": 3,
125
+ "note_evaluation_actuelle": 85,
126
+ "annees_depuis_la_derniere_promotion": 2,
127
+ "nombre_formations_realisees": 3,
128
+ ...
129
+ }
130
+
131
+ # Réponse
132
+ {
133
+ "prediction": 0, # 0 = reste, 1 = part
134
+ "probability_0": 0.85, # Probabilité de rester
135
+ "probability_1": 0.15, # Probabilité de partir
136
+ "risk_level": "Low" # Low, Medium, High
137
+ }
138
+ ```
139
+
140
+ ## 📊 Logging
141
+
142
+ ### Logs structurés JSON
143
+
144
+ **Fichiers** :
145
+ - `logs/api.log` : Tous les logs
146
+ - `logs/error.log` : Erreurs uniquement
147
+
148
+ **Format** :
149
+ ```json
150
+ {
151
+ "timestamp": "2025-12-26T10:30:45",
152
+ "level": "INFO",
153
+ "logger": "employee_turnover_api",
154
+ "message": "Request POST /predict",
155
+ "method": "POST",
156
+ "path": "/predict",
157
+ "status_code": 200,
158
+ "duration_ms": 23.45,
159
+ "client_host": "127.0.0.1"
160
+ }
161
+ ```
162
+
163
+ ## 🛡️ Rate Limiting
164
+
165
+ **Configuration** :
166
+ - **Développement** : Désactivé (DEBUG=true)
167
+ - **Production** : 20 requêtes/minute par IP ou API Key
168
+
169
+ **En cas de dépassement** :
170
+ ```json
171
+ {
172
+ "error": "Rate limit exceeded",
173
+ "message": "20 per 1 minute"
174
+ }
175
  ```
176
 
177
+ ## Tests
178
 
179
  ```bash
180
+ # Tous les tests
181
+ poetry run pytest tests/ -v
182
 
183
+ # Avec couverture
184
+ poetry run pytest tests/ --cov --cov-report=html
185
 
186
+ # Voir rapport HTML
187
+ open htmlcov/index.html
188
  ```
189
 
190
+ **Résultats** :
191
+ - ✅ 33 tests passés
192
+ - 📊 88% de couverture globale
193
+
194
+ ## 🚀 Déploiement
195
+
196
+ ### Variables d'environnement requises
197
+ ```bash
198
+ DEBUG=false
199
+ API_KEY=<votre-clé-sécurisée>
200
+ LOG_LEVEL=INFO
201
+ ```
202
+
203
+ ### HuggingFace Spaces
204
+ Prêt pour déploiement avec `app.py` et `requirements.txt`
205
+
206
+ ## 📚 Documentation
207
+
208
+ - **API Interactive** : http://localhost:8000/docs
209
+ - **ReDoc** : http://localhost:8000/redoc
210
+ - **Guide complet** : [docs/API_GUIDE.md](docs/API_GUIDE.md)
211
+ - **Standards** : [docs/standards.md](docs/standards.md)
212
+ - **Couverture tests** : [docs/TEST_COVERAGE.md](docs/TEST_COVERAGE.md)
213
+
214
+ ## 📦 Dépendances principales
215
+
216
+ - **FastAPI** 0.115.14 : Framework web
217
+ - **Pydantic** 2.12.5 : Validation données
218
+ - **XGBoost** 2.1.3 : Modèle ML
219
+ - **SlowAPI** 0.1.9 : Rate limiting
220
+ - **python-json-logger** 4.0.0 : Logs structurés
221
+ - **pytest** 9.0.2 : Tests
222
 
223
+ ## 🔄 Changelog
 
 
 
224
 
225
+ ### v2.1.0 (26 décembre 2025)
226
+ - ✨ Système de logging structuré JSON
227
+ - 🛡️ Rate limiting avec SlowAPI
228
+ - ⚡ Amélioration gestion d'erreurs
229
+ - 📊 Monitoring des performances
230
 
231
+ ### v2.0.0 (26 décembre 2025)
232
+ - Suite de tests complète (33 tests)
233
+ - 🔐 Authentification API Key
234
+ - 📊 88% de couverture de code
235
 
236
+ ## 👥 Auteurs
237
 
238
+ - **Projet** : OpenClassrooms P5
239
+ - **Repo** : [github.com/chaton59/OC_P5](https://github.com/chaton59/OC_P5)
README_HF.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Employee Turnover Prediction API
3
+ emoji: 👔
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: true
8
+ license: mit
9
+ app_port: 8000
10
+ ---
11
+
12
+ # Employee Turnover Prediction API 🚀
13
+
14
+ API de prédiction du turnover des employés avec XGBoost + SMOTE.
15
+
16
+ ## 🎯 Fonctionnalités
17
+
18
+ - ✅ Prédiction de turnover (0 = reste, 1 = part)
19
+ - 📊 Probabilités et niveau de risque (Low/Medium/High)
20
+ - 🔐 Authentification API Key
21
+ - 📝 Logs structurés JSON
22
+ - 🛡️ Rate limiting (20 req/min)
23
+ - 📚 Documentation OpenAPI/Swagger
24
+
25
+ ## 🔗 Endpoints
26
+
27
+ - **Docs** : `/docs` - Documentation interactive
28
+ - **Health** : `/health` - Status de l'API
29
+ - **Predict** : `/predict` - Prédiction de turnover
30
+
31
+ ## 🚀 Utilisation
32
+
33
+ ```bash
34
+ # Health check
35
+ curl https://asi-engineer-employee-turnover-api.hf.space/health
36
+
37
+ # Prédiction
38
+ curl -X POST https://asi-engineer-employee-turnover-api.hf.space/predict \
39
+ -H "Content-Type: application/json" \
40
+ -d '{
41
+ "satisfaction_employee_environnement": 3,
42
+ "satisfaction_employee_nature_travail": 4,
43
+ ...
44
+ }'
45
+ ```
46
+
47
+ ## 📚 Documentation complète
48
+
49
+ Voir [GitHub Repository](https://github.com/chaton59/OC_P5) pour la documentation complète.
app.py CHANGED
@@ -1,138 +1,261 @@
1
  #!/usr/bin/env python3
2
  """
3
- Interface Gradio pour tester le modèle Employee Turnover en production.
4
 
5
- Déploiement sur Hugging Face Spaces pour tests rapides.
6
- Version de démonstration - Interface complète en développement.
 
 
 
 
7
  """
 
 
 
8
  import gradio as gr
9
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- # Configuration
12
- HF_MODEL_REPO = "ASI-Engineer/employee-turnover-model"
 
13
 
14
 
15
- def load_model():
 
16
  """
17
- Charge le modèle depuis Hugging Face Hub.
18
 
19
- En production (HF Spaces), charge uniquement depuis HF Hub.
20
- Le fallback MLflow local n'est disponible qu'en développement local.
21
  """
 
 
 
 
 
22
  try:
23
- import joblib
 
 
24
 
25
- # Download model pickle from HF Hub
26
- model_path = hf_hub_download(
27
- repo_id=HF_MODEL_REPO, filename="model/model.pkl", repo_type="model"
28
- )
29
- model = joblib.load(model_path)
30
- print(f"✅ Modèle chargé depuis HF Hub: {HF_MODEL_REPO}")
31
- return model, "HF Hub"
32
  except Exception as e:
33
- print(f"❌ Erreur chargement depuis HF Hub: {e}")
34
- return None, "Error"
35
-
36
-
37
- # Charger le modèle au démarrage
38
- try:
39
- model, model_source = load_model()
40
- MODEL_LOADED = model is not None
41
- except Exception as e:
42
- print(f"❌ Erreur lors du chargement du modèle: {e}")
43
- MODEL_LOADED = False
44
- model = None
45
- model_source = "Error"
46
-
47
-
48
- def get_model_info():
49
- """Retourne les informations sur le modèle."""
50
- if not MODEL_LOADED:
51
- return {
52
- "status": "❌ Modèle non disponible",
53
- "error": "Le modèle n'a pas pu être chargé",
54
- "solution": "Vérifiez que le modèle est bien enregistré sur HF Hub ou entraîné localement",
55
- }
56
 
57
- try:
58
- info = {
59
- "status": "✅ Modèle chargé avec succès",
60
- "source": model_source,
61
- "model_type": type(model).__name__,
62
- "features": "~50 features (après preprocessing)",
63
- "algorithme": "XGBoost + SMOTE",
64
- "hf_hub_repo": HF_MODEL_REPO,
65
- }
66
-
67
- info["info"] = "Interface de prédiction en développement - API FastAPI à venir"
68
- return info
69
 
70
- except Exception as e:
71
- return {"status": "✅ Modèle chargé (info limitées)", "error": str(e)}
72
 
73
 
74
- # Interface Gradio
75
- with gr.Blocks( # type: ignore[attr-defined]
76
- title="Employee Turnover Prediction - DEV", theme=gr.themes.Soft() # type: ignore[attr-defined]
77
- ) as demo:
78
- gr.Markdown("# 🎯 Prédiction du Turnover - Employee Attrition") # type: ignore[attr-defined]
79
- gr.Markdown("## Environment DEV - Test de déploiement CI/CD") # type: ignore[attr-defined]
 
 
 
80
 
81
- gr.Markdown( # type: ignore[attr-defined]
82
- """
83
- ### 📊 Statut du projet
84
 
85
- Ce Space est synchronisé automatiquement depuis GitHub (branche `dev`).
 
 
 
 
 
 
 
86
 
87
- **Actuellement disponible :**
88
- - ✅ Pipeline d'entraînement MLflow complet (`main.py`)
89
- - ✅ Déploiement automatique CI/CD (GitHub Actions → HF Spaces)
90
- - ✅ Tests unitaires et linting automatisés
91
 
92
- **En développement :**
93
- - 🚧 Interface de prédiction interactive
94
- - 🚧 API FastAPI avec endpoints de prédiction
95
- - 🚧 Intégration PostgreSQL pour tracking des prédictions
 
96
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  )
98
 
99
- with gr.Row(): # type: ignore[attr-defined]
100
- with gr.Column(): # type: ignore[attr-defined]
101
- gr.Markdown("### 🔍 Informations sur le modèle") # type: ignore[attr-defined]
102
- check_btn = gr.Button("📊 Vérifier le statut du modèle", variant="primary") # type: ignore[attr-defined]
103
 
104
- with gr.Column(): # type: ignore[attr-defined]
105
- model_output = gr.JSON(label="Statut") # type: ignore[attr-defined]
106
 
107
- check_btn.click(fn=get_model_info, inputs=[], outputs=model_output)
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- gr.Markdown("---") # type: ignore[attr-defined]
110
 
111
- gr.Markdown( # type: ignore[attr-defined]
112
- """
113
- ### 🛠️ Prochaines étapes (selon etapes.txt)
 
114
 
115
- 1. **Étape 3** : Développement API FastAPI
116
- - Endpoints de prédiction avec validation Pydantic
117
- - Chargement dynamique des preprocessing artifacts (scaler, encoders)
118
- - Documentation Swagger/OpenAPI automatique
119
 
120
- 2. **Étape 4** : Intégration PostgreSQL
121
- - Stockage des inputs/outputs des prédictions
122
- - Traçabilité complète des requêtes
123
 
124
- 3. **Étape 5** : Tests unitaires et fonctionnels
125
- - Tests des endpoints API
126
- - Tests de charge et performance
127
- - Couverture de code avec pytest-cov
 
128
 
129
- ### 📚 Documentation
130
- - **Repository GitHub** : [chaton59/OC_P5](https://github.com/chaton59/OC_P5)
131
- - **MLflow Tracking** : Disponible en local (`./scripts/start_mlflow.sh`)
132
- - **Métriques** : F1-Score optimisé, gestion classes déséquilibrées (SMOTE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  """
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
 
137
  if __name__ == "__main__":
138
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ API FastAPI pour le modèle Employee Turnover.
4
 
5
+ Cette API expose le modèle de prédiction de départ des employés avec :
6
+ - Validation stricte des inputs via Pydantic
7
+ - Preprocessing automatique
8
+ - Health check pour monitoring
9
+ - Documentation OpenAPI/Swagger automatique
10
+ - Interface Gradio pour utilisation interactive
11
  """
12
+ import time
13
+ from contextlib import asynccontextmanager
14
+
15
  import gradio as gr
16
+ from fastapi import Depends, FastAPI, HTTPException, Request
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from slowapi import _rate_limit_exceeded_handler
19
+ from slowapi.errors import RateLimitExceeded
20
+
21
+ from src.auth import verify_api_key
22
+ from src.config import get_settings
23
+ from src.gradio_ui import create_gradio_interface
24
+ from src.logger import logger, log_model_load, log_request
25
+ from src.models import get_model_info, load_model
26
+ from src.preprocessing import preprocess_for_prediction
27
+ from src.rate_limit import limiter
28
+ from src.schemas import EmployeeInput, HealthCheck, PredictionOutput
29
 
30
+ # Charger la configuration
31
+ settings = get_settings()
32
+ API_VERSION = settings.API_VERSION
33
 
34
 
35
+ @asynccontextmanager
36
+ async def lifespan(app: FastAPI):
37
  """
38
+ Gestion du cycle de vie de l'application.
39
 
40
+ Charge le modèle au démarrage et le garde en cache.
 
41
  """
42
+ logger.info(
43
+ "🚀 Démarrage de l'API Employee Turnover...", extra={"version": API_VERSION}
44
+ )
45
+
46
+ start_time = time.time()
47
  try:
48
+ # Pré-charger le modèle au démarrage
49
+ model = load_model()
50
+ duration_ms = (time.time() - start_time) * 1000
51
 
52
+ model_type = type(model).__name__
53
+ log_model_load(model_type, duration_ms, True)
54
+ logger.info(" Modèle chargé avec succès")
 
 
 
 
55
  except Exception as e:
56
+ duration_ms = (time.time() - start_time) * 1000
57
+ log_model_load("Unknown", duration_ms, False)
58
+ logger.error("Le modèle n'a pas pu être chargé", extra={"error": str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ yield # L'application tourne
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ logger.info("🛑 Arrêt de l'API")
 
63
 
64
 
65
+ # Créer l'application FastAPI
66
+ app = FastAPI(
67
+ title="Employee Turnover Prediction API",
68
+ description="API de prédiction du turnover des employés avec XGBoost + SMOTE",
69
+ version=API_VERSION,
70
+ lifespan=lifespan,
71
+ docs_url="/docs",
72
+ redoc_url="/redoc",
73
+ )
74
 
75
+ # Ajouter rate limiting
76
+ app.state.limiter = limiter
77
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
78
 
79
+ # Configurer CORS (autoriser tous les domaines en dev)
80
+ app.add_middleware(
81
+ CORSMiddleware,
82
+ allow_origins=["*"],
83
+ allow_credentials=True,
84
+ allow_methods=["*"],
85
+ allow_headers=["*"],
86
+ )
87
 
 
 
 
 
88
 
89
+ # Middleware de logging des requêtes
90
+ @app.middleware("http")
91
+ async def log_requests(request: Request, call_next):
92
+ """
93
+ Middleware pour logger toutes les requêtes HTTP.
94
  """
95
+ start_time = time.time()
96
+
97
+ # Traiter la requête
98
+ response = await call_next(request)
99
+
100
+ # Calculer la durée
101
+ duration_ms = (time.time() - start_time) * 1000
102
+
103
+ # Logger
104
+ log_request(
105
+ method=request.method,
106
+ path=request.url.path,
107
+ status_code=response.status_code,
108
+ duration_ms=duration_ms,
109
+ client_host=request.client.host if request.client else None,
110
  )
111
 
112
+ return response
 
 
 
113
 
 
 
114
 
115
+ @app.get("/", tags=["Root"])
116
+ async def root():
117
+ """
118
+ Endpoint racine avec informations sur l'API.
119
+ """
120
+ return {
121
+ "message": "Employee Turnover Prediction API",
122
+ "version": API_VERSION,
123
+ "docs": "/docs",
124
+ "health": "/health",
125
+ "predict": "/predict (POST)",
126
+ }
127
 
 
128
 
129
+ @app.get("/health", response_model=HealthCheck, tags=["Monitoring"])
130
+ async def health_check():
131
+ """
132
+ Health check endpoint pour monitoring.
133
 
134
+ Vérifie que l'API est opérationnelle et que le modèle est chargé.
 
 
 
135
 
136
+ Returns:
137
+ HealthCheck: Status de l'API et du modèle.
 
138
 
139
+ Raises:
140
+ HTTPException: 503 si le modèle n'est pas disponible.
141
+ """
142
+ try:
143
+ model_info = get_model_info()
144
 
145
+ return HealthCheck(
146
+ status="healthy",
147
+ model_loaded=model_info.get("cached", False),
148
+ model_type=model_info.get("model_type", "Unknown"),
149
+ version=API_VERSION,
150
+ )
151
+ except Exception as e:
152
+ raise HTTPException(
153
+ status_code=503,
154
+ detail={
155
+ "status": "unhealthy",
156
+ "error": "Model not available",
157
+ "message": str(e),
158
+ },
159
+ )
160
+
161
+
162
+ @app.post(
163
+ "/predict",
164
+ response_model=PredictionOutput,
165
+ tags=["Prediction"],
166
+ dependencies=[Depends(verify_api_key)] if settings.is_api_key_required else [],
167
+ )
168
+ @limiter.limit("20/minute")
169
+ async def predict(request: Request, employee: EmployeeInput):
170
  """
171
+ Endpoint de prédiction du turnover d'un employé.
172
+
173
+ **PROTÉGÉ PAR API KEY** : Requiert le header `X-API-Key` en production.
174
+
175
+ Prend en entrée les données d'un employé, applique le preprocessing
176
+ et retourne la prédiction avec les probabilités.
177
+
178
+ Args:
179
+ employee: Données de l'employé validées par Pydantic.
180
+
181
+ Returns:
182
+ PredictionOutput: Prédiction et probabilités.
183
+
184
+ Raises:
185
+ HTTPException: 401 si API key invalide ou manquante.
186
+ HTTPException: 500 si erreur lors de la prédiction.
187
+
188
+ Examples:
189
+ ```bash
190
+ # Avec authentification
191
+ curl -X POST http://localhost:8000/predict \\
192
+ -H "X-API-Key: your-secret-key" \\
193
+ -H "Content-Type: application/json" \\
194
+ -d '{...}'
195
+ ```
196
+ """
197
+ try:
198
+ # 1. Charger le modèle
199
+ model = load_model()
200
+
201
+ # 2. Préprocessing
202
+ X = preprocess_for_prediction(employee)
203
+
204
+ # 3. Prédiction
205
+ prediction = int(model.predict(X)[0])
206
+
207
+ # 4. Probabilités (si le modèle supporte predict_proba)
208
+ try:
209
+ probabilities = model.predict_proba(X)[0]
210
+ prob_0 = float(probabilities[0])
211
+ prob_1 = float(probabilities[1])
212
+ except AttributeError:
213
+ # Si le modèle ne supporte pas predict_proba
214
+ prob_0 = 1.0 if prediction == 0 else 0.0
215
+ prob_1 = 1.0 if prediction == 1 else 0.0
216
+
217
+ # 5. Niveau de risque
218
+ if prob_1 < 0.3:
219
+ risk_level = "Low"
220
+ elif prob_1 < 0.7:
221
+ risk_level = "Medium"
222
+ else:
223
+ risk_level = "High"
224
+
225
+ return PredictionOutput(
226
+ prediction=prediction,
227
+ probability_0=prob_0,
228
+ probability_1=prob_1,
229
+ risk_level=risk_level,
230
+ )
231
+
232
+ except Exception:
233
+ logger.exception("Unexpected error during prediction")
234
+ raise HTTPException(
235
+ status_code=500,
236
+ detail={
237
+ "error": "Prediction failed",
238
+ "message": "An unexpected error occurred. Please contact support.",
239
+ },
240
+ )
241
+
242
+
243
+ # Monter l'interface Gradio sur /ui
244
+ gradio_app = create_gradio_interface()
245
+ app = gr.mount_gradio_app(app, gradio_app, path="/ui")
246
 
247
 
248
  if __name__ == "__main__":
249
+ import uvicorn
250
+
251
+ print("🚀 Lancement de l'API en mode développement...")
252
+ print("📖 Documentation : http://localhost:8000/docs")
253
+ print("🎨 Interface Gradio : http://localhost:8000/ui")
254
+
255
+ uvicorn.run(
256
+ "app:app",
257
+ host="0.0.0.0",
258
+ port=8000,
259
+ reload=True,
260
+ log_level="info",
261
+ )
requirements.txt CHANGED
@@ -1,9 +1,121 @@
1
- # Minimal requirements for HF Spaces deployment
2
- # Only the dependencies needed for app.py and model loading
3
- gradio>=5.9.0
4
- huggingface-hub>=0.27.0
5
- joblib>=1.4.0
6
- scikit-learn>=1.6.0
7
- imbalanced-learn>=0.13.0
8
- xgboost>=2.1.0
9
- numpy>=2.0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0 ; python_version >= "3.12" and python_version < "4.0"
2
+ alembic==1.17.2 ; python_version >= "3.12" and python_version < "4.0"
3
+ annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
4
+ anyio==4.12.0 ; python_version >= "3.12" and python_version < "4.0"
5
+ audioop-lts==0.2.2 ; python_version >= "3.13" and python_version < "4.0"
6
+ blinker==1.9.0 ; python_version >= "3.12" and python_version < "4.0"
7
+ brotli==1.2.0 ; python_version >= "3.12" and python_version < "4.0"
8
+ cachetools==6.2.4 ; python_version >= "3.12" and python_version < "4.0"
9
+ certifi==2025.11.12 ; python_version >= "3.12" and python_version < "4.0"
10
+ cffi==2.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy"
11
+ charset-normalizer==3.4.4 ; python_version >= "3.12" and python_version < "4.0"
12
+ click==8.3.1 ; python_version >= "3.12" and python_version < "4.0"
13
+ cloudpickle==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
14
+ colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and (platform_system == "Windows" or sys_platform == "win32")
15
+ contourpy==1.3.3 ; python_version >= "3.12" and python_version < "4.0"
16
+ cryptography==46.0.3 ; python_version >= "3.12" and python_version < "4.0"
17
+ cycler==0.12.1 ; python_version >= "3.12" and python_version < "4.0"
18
+ databricks-sdk==0.76.0 ; python_version >= "3.12" and python_version < "4.0"
19
+ deprecated==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
20
+ docker==7.1.0 ; python_version >= "3.12" and python_version < "4.0"
21
+ fastapi==0.115.14 ; python_version >= "3.12" and python_version < "4.0"
22
+ ffmpy==1.0.0 ; python_version >= "3.12" and python_version < "4.0"
23
+ filelock==3.20.1 ; python_version >= "3.12" and python_version < "4.0"
24
+ flask-cors==6.0.2 ; python_version >= "3.12" and python_version < "4.0"
25
+ flask==3.1.2 ; python_version >= "3.12" and python_version < "4.0"
26
+ fonttools==4.61.1 ; python_version >= "3.12" and python_version < "4.0"
27
+ fsspec==2025.12.0 ; python_version >= "3.12" and python_version < "4.0"
28
+ gitdb==4.0.12 ; python_version >= "3.12" and python_version < "4.0"
29
+ gitpython==3.1.45 ; python_version >= "3.12" and python_version < "4.0"
30
+ google-auth==2.45.0 ; python_version >= "3.12" and python_version < "4.0"
31
+ gradio-client==2.0.2 ; python_version >= "3.12" and python_version < "4.0"
32
+ gradio==6.2.0 ; python_version >= "3.12" and python_version < "4.0"
33
+ graphene==3.4.3 ; python_version >= "3.12" and python_version < "4.0"
34
+ graphql-core==3.2.7 ; python_version >= "3.12" and python_version < "4.0"
35
+ graphql-relay==3.2.0 ; python_version >= "3.12" and python_version < "4.0"
36
+ greenlet==3.3.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32")
37
+ groovy==0.1.2 ; python_version >= "3.12" and python_version < "4.0"
38
+ gunicorn==23.0.0 ; python_version >= "3.12" and python_version < "4.0" and platform_system != "Windows"
39
+ h11==0.16.0 ; python_version >= "3.12" and python_version < "4.0"
40
+ hf-xet==1.2.0 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "arm64" or platform_machine == "aarch64")
41
+ httpcore==1.0.9 ; python_version >= "3.12" and python_version < "4.0"
42
+ httptools==0.7.1 ; python_version >= "3.12" and python_version < "4.0"
43
+ httpx==0.28.1 ; python_version >= "3.12" and python_version < "4.0"
44
+ huey==2.5.5 ; python_version >= "3.12" and python_version < "4.0"
45
+ huggingface-hub==1.2.3 ; python_version >= "3.12" and python_version < "4.0"
46
+ idna==3.11 ; python_version >= "3.12" and python_version < "4.0"
47
+ imbalanced-learn==0.13.0 ; python_version >= "3.12" and python_version < "4.0"
48
+ importlib-metadata==8.7.1 ; python_version >= "3.12" and python_version < "4.0"
49
+ itsdangerous==2.2.0 ; python_version >= "3.12" and python_version < "4.0"
50
+ jinja2==3.1.6 ; python_version >= "3.12" and python_version < "4.0"
51
+ joblib==1.5.3 ; python_version >= "3.12" and python_version < "4.0"
52
+ kiwisolver==1.4.9 ; python_version >= "3.12" and python_version < "4.0"
53
+ limits==5.6.0 ; python_version >= "3.12" and python_version < "4.0"
54
+ mako==1.3.10 ; python_version >= "3.12" and python_version < "4.0"
55
+ markdown-it-py==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
56
+ markupsafe==3.0.3 ; python_version >= "3.12" and python_version < "4.0"
57
+ matplotlib==3.10.8 ; python_version >= "3.12" and python_version < "4.0"
58
+ mdurl==0.1.2 ; python_version >= "3.12" and python_version < "4.0"
59
+ mlflow-skinny==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
60
+ mlflow-tracing==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
61
+ mlflow==3.8.1 ; python_version >= "3.12" and python_version < "4.0"
62
+ numpy==2.4.0 ; python_version >= "3.12" and python_version < "4.0"
63
+ nvidia-nccl-cu12==2.28.9 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine != "aarch64"
64
+ opentelemetry-api==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
65
+ opentelemetry-proto==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
66
+ opentelemetry-sdk==1.39.1 ; python_version >= "3.12" and python_version < "4.0"
67
+ opentelemetry-semantic-conventions==0.60b1 ; python_version >= "3.12" and python_version < "4.0"
68
+ orjson==3.11.5 ; python_version >= "3.12" and python_version < "4.0"
69
+ packaging==25.0 ; python_version >= "3.12" and python_version < "4.0"
70
+ pandas==2.3.3 ; python_version >= "3.12" and python_version < "4.0"
71
+ pillow==12.0.0 ; python_version >= "3.12" and python_version < "4.0"
72
+ protobuf==6.33.2 ; python_version >= "3.12" and python_version < "4.0"
73
+ pyarrow==22.0.0 ; python_version >= "3.12" and python_version < "4.0"
74
+ pyasn1-modules==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
75
+ pyasn1==0.6.1 ; python_version >= "3.12" and python_version < "4.0"
76
+ pycparser==2.23 ; python_version >= "3.12" and python_version < "4.0" and platform_python_implementation != "PyPy" and implementation_name != "PyPy"
77
+ pydantic-core==2.41.5 ; python_version >= "3.12" and python_version < "4.0"
78
+ pydantic==2.12.5 ; python_version >= "3.12" and python_version < "4.0"
79
+ pydub==0.25.1 ; python_version >= "3.12" and python_version < "4.0"
80
+ pygments==2.19.2 ; python_version >= "3.12" and python_version < "4.0"
81
+ pyparsing==3.3.1 ; python_version >= "3.12" and python_version < "4.0"
82
+ python-dateutil==2.9.0.post0 ; python_version >= "3.12" and python_version < "4.0"
83
+ python-dotenv==1.2.1 ; python_version >= "3.12" and python_version < "4.0"
84
+ python-json-logger==4.0.0 ; python_version >= "3.12" and python_version < "4.0"
85
+ python-multipart==0.0.21 ; python_version >= "3.12" and python_version < "4.0"
86
+ pytz==2025.2 ; python_version >= "3.12" and python_version < "4.0"
87
+ pywin32==311 ; python_version >= "3.12" and python_version < "4.0" and sys_platform == "win32"
88
+ pyyaml==6.0.3 ; python_version >= "3.12" and python_version < "4.0"
89
+ requests==2.32.5 ; python_version >= "3.12" and python_version < "4.0"
90
+ rich==14.2.0 ; python_version >= "3.12" and python_version < "4.0"
91
+ rsa==4.9.1 ; python_version >= "3.12" and python_version < "4.0"
92
+ safehttpx==0.1.7 ; python_version >= "3.12" and python_version < "4.0"
93
+ scikit-learn==1.6.1 ; python_version >= "3.12" and python_version < "4.0"
94
+ scipy==1.16.3 ; python_version >= "3.12" and python_version < "4.0"
95
+ semantic-version==2.10.0 ; python_version >= "3.12" and python_version < "4.0"
96
+ shellingham==1.5.4 ; python_version >= "3.12" and python_version < "4.0"
97
+ six==1.17.0 ; python_version >= "3.12" and python_version < "4.0"
98
+ sklearn-compat==0.1.5 ; python_version >= "3.12" and python_version < "4.0"
99
+ slowapi==0.1.9 ; python_version >= "3.12" and python_version < "4.0"
100
+ smmap==5.0.2 ; python_version >= "3.12" and python_version < "4.0"
101
+ sqlalchemy==2.0.45 ; python_version >= "3.12" and python_version < "4.0"
102
+ sqlparse==0.5.5 ; python_version >= "3.12" and python_version < "4.0"
103
+ starlette==0.46.2 ; python_version >= "3.12" and python_version < "4.0"
104
+ threadpoolctl==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
105
+ tomlkit==0.13.3 ; python_version >= "3.12" and python_version < "4.0"
106
+ tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
107
+ typer-slim==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
108
+ typer==0.21.0 ; python_version >= "3.12" and python_version < "4.0"
109
+ typing-extensions==4.15.0 ; python_version >= "3.12" and python_version < "4.0"
110
+ typing-inspection==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
111
+ tzdata==2025.3 ; python_version >= "3.12" and python_version < "4.0"
112
+ urllib3==2.6.2 ; python_version >= "3.12" and python_version < "4.0"
113
+ uvicorn==0.32.1 ; python_version >= "3.12" and python_version < "4.0"
114
+ uvloop==0.22.1 ; python_version >= "3.12" and python_version < "4.0" and sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy"
115
+ waitress==3.0.2 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
116
+ watchfiles==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
117
+ websockets==15.0.1 ; python_version >= "3.12" and python_version < "4.0"
118
+ werkzeug==3.1.4 ; python_version >= "3.12" and python_version < "4.0"
119
+ wrapt==2.0.1 ; python_version >= "3.12" and python_version < "4.0"
120
+ xgboost==2.1.4 ; python_version >= "3.12" and python_version < "4.0"
121
+ zipp==3.23.0 ; python_version >= "3.12" and python_version < "4.0"
src/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """Module src pour l'API FastAPI."""
src/auth.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module d'authentification pour l'API.
4
+
5
+ Fournit un système de vérification de clé API via header HTTP.
6
+ """
7
+ from fastapi import Header, HTTPException, status
8
+ from fastapi.security import APIKeyHeader
9
+
10
+ from src.config import get_settings
11
+
12
+ # Schéma pour la documentation Swagger
13
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
14
+
15
+
16
+ async def verify_api_key(x_api_key: str = Header(None)) -> str:
17
+ """
18
+ Vérifie que la clé API fournie est valide.
19
+
20
+ Cette fonction est utilisée comme dépendance FastAPI (Depends).
21
+ Elle vérifie le header HTTP "X-API-Key" et compare avec la clé configurée.
22
+
23
+ Args:
24
+ x_api_key: Clé API fournie dans le header HTTP.
25
+
26
+ Returns:
27
+ str: La clé API validée.
28
+
29
+ Raises:
30
+ HTTPException: 401 si la clé est manquante ou invalide.
31
+
32
+ Comment ça marche :
33
+ 1. FastAPI extrait automatiquement le header "X-API-Key"
34
+ 2. La fonction compare avec la clé configurée dans .env
35
+ 3. Si valide → continue, sinon → erreur 401
36
+
37
+ Exemple d'utilisation :
38
+ ```python
39
+ @app.post("/predict", dependencies=[Depends(verify_api_key)])
40
+ async def predict(...):
41
+ # Cette route est protégée !
42
+ ```
43
+
44
+ Exemple de requête curl :
45
+ ```bash
46
+ curl -X POST http://localhost:8000/predict \\
47
+ -H "X-API-Key: your-secret-key" \\
48
+ -H "Content-Type: application/json" \\
49
+ -d '{...}'
50
+ ```
51
+ """
52
+ settings = get_settings()
53
+
54
+ # En mode DEBUG, on peut désactiver l'auth
55
+ if settings.DEBUG:
56
+ return "debug-mode-no-auth-required"
57
+
58
+ # Vérifier que la clé est fournie
59
+ if not x_api_key:
60
+ raise HTTPException(
61
+ status_code=status.HTTP_401_UNAUTHORIZED,
62
+ detail={
63
+ "error": "API Key missing",
64
+ "message": "Le header 'X-API-Key' est requis pour accéder à cette ressource",
65
+ "solution": "Ajoutez le header: -H 'X-API-Key: votre-cle-api'",
66
+ },
67
+ headers={"WWW-Authenticate": "ApiKey"},
68
+ )
69
+
70
+ # Vérifier que la clé est correcte
71
+ if x_api_key != settings.API_KEY:
72
+ raise HTTPException(
73
+ status_code=status.HTTP_401_UNAUTHORIZED,
74
+ detail={
75
+ "error": "Invalid API Key",
76
+ "message": "La clé API fournie est invalide",
77
+ "solution": "Vérifiez votre clé API ou contactez l'administrateur",
78
+ },
79
+ headers={"WWW-Authenticate": "ApiKey"},
80
+ )
81
+
82
+ return x_api_key
83
+
84
+
85
+ def get_api_key_dependency():
86
+ """
87
+ Retourne la dépendance d'authentification si nécessaire.
88
+
89
+ Permet de conditionner l'authentification selon la config.
90
+
91
+ Returns:
92
+ Depends(verify_api_key) si auth requise, None sinon.
93
+ """
94
+ settings = get_settings()
95
+ if settings.is_api_key_required:
96
+ from fastapi import Depends
97
+
98
+ return Depends(verify_api_key)
99
+ return None
src/config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de configuration de l'application.
4
+
5
+ Charge les variables d'environnement depuis .env et fournit
6
+ une interface pour accéder à la configuration de manière sécurisée.
7
+ """
8
+ import os
9
+ from functools import lru_cache
10
+
11
+ from dotenv import load_dotenv
12
+
13
+ # Charger .env au démarrage du module
14
+ load_dotenv()
15
+
16
+
17
+ class Settings:
18
+ """
19
+ Configuration de l'application.
20
+
21
+ Toutes les valeurs sensibles (API keys, etc.) sont chargées depuis
22
+ les variables d'environnement ou le fichier .env.
23
+ """
24
+
25
+ # ===== SÉCURITÉ =====
26
+ API_KEY: str = os.getenv("API_KEY", "dev-key-change-me-in-production")
27
+
28
+ # ===== API =====
29
+ API_VERSION: str = os.getenv("API_VERSION", "1.0.0")
30
+ API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
31
+ API_PORT: int = int(os.getenv("API_PORT", "8000"))
32
+
33
+ # ===== MODÈLE =====
34
+ HF_MODEL_REPO: str = os.getenv(
35
+ "HF_MODEL_REPO", "ASI-Engineer/employee-turnover-model"
36
+ )
37
+ MODEL_FILENAME: str = os.getenv("MODEL_FILENAME", "model/model.pkl")
38
+
39
+ # ===== ENVIRONNEMENT =====
40
+ DEBUG: bool = os.getenv("DEBUG", "False").lower() == "true"
41
+ LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
42
+
43
+ @property
44
+ def is_api_key_required(self) -> bool:
45
+ """
46
+ Vérifie si l'API key est requise.
47
+
48
+ Returns:
49
+ False en mode DEBUG, True en production.
50
+ """
51
+ return not self.DEBUG
52
+
53
+
54
+ @lru_cache()
55
+ def get_settings() -> Settings:
56
+ """
57
+ Retourne l'instance singleton des settings.
58
+
59
+ Le décorateur @lru_cache() assure qu'on ne crée qu'une seule instance.
60
+
61
+ Returns:
62
+ Settings: Configuration de l'application.
63
+ """
64
+ return Settings()
src/gradio_ui.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Interface Gradio pour l'API Employee Turnover Prediction.
4
+
5
+ Cette interface permet de:
6
+ - Tester les prédictions de manière interactive
7
+ - Visualiser la documentation de l'API
8
+ - Comprendre les champs requis
9
+ """
10
+ import gradio as gr
11
+
12
+ from src.models import load_model, get_model_info
13
+ from src.preprocessing import preprocess_for_prediction
14
+ from src.schemas import EmployeeInput
15
+
16
+
17
+ def predict_turnover(
18
+ # SONDAGE
19
+ nombre_participation_pee: int,
20
+ nb_formations_suivies: int,
21
+ nombre_employee_sous_responsabilite: int,
22
+ distance_domicile_travail: int,
23
+ niveau_education: int,
24
+ domaine_etude: str,
25
+ ayant_enfants: str,
26
+ frequence_deplacement: str,
27
+ annees_depuis_la_derniere_promotion: int,
28
+ annes_sous_responsable_actuel: int,
29
+ # EVALUATION
30
+ satisfaction_employee_environnement: int,
31
+ note_evaluation_precedente: int,
32
+ niveau_hierarchique_poste: int,
33
+ satisfaction_employee_nature_travail: int,
34
+ satisfaction_employee_equipe: int,
35
+ satisfaction_employee_equilibre_pro_perso: int,
36
+ note_evaluation_actuelle: int,
37
+ heure_supplementaires: str,
38
+ augementation_salaire_precedente: float,
39
+ # SIRH
40
+ age: int,
41
+ genre: str,
42
+ revenu_mensuel: float,
43
+ statut_marital: str,
44
+ departement: str,
45
+ poste: str,
46
+ nombre_experiences_precedentes: int,
47
+ nombre_heures_travailless: int,
48
+ annee_experience_totale: int,
49
+ annees_dans_l_entreprise: int,
50
+ annees_dans_le_poste_actuel: int,
51
+ ) -> str:
52
+ """Effectue une prédiction de turnover."""
53
+ try:
54
+ # Créer l'objet EmployeeInput
55
+ employee = EmployeeInput(
56
+ nombre_participation_pee=nombre_participation_pee,
57
+ nb_formations_suivies=nb_formations_suivies,
58
+ nombre_employee_sous_responsabilite=nombre_employee_sous_responsabilite,
59
+ distance_domicile_travail=distance_domicile_travail,
60
+ niveau_education=niveau_education,
61
+ domaine_etude=domaine_etude,
62
+ ayant_enfants=ayant_enfants,
63
+ frequence_deplacement=frequence_deplacement,
64
+ annees_depuis_la_derniere_promotion=annees_depuis_la_derniere_promotion,
65
+ annes_sous_responsable_actuel=annes_sous_responsable_actuel,
66
+ satisfaction_employee_environnement=satisfaction_employee_environnement,
67
+ note_evaluation_precedente=note_evaluation_precedente,
68
+ niveau_hierarchique_poste=niveau_hierarchique_poste,
69
+ satisfaction_employee_nature_travail=satisfaction_employee_nature_travail,
70
+ satisfaction_employee_equipe=satisfaction_employee_equipe,
71
+ satisfaction_employee_equilibre_pro_perso=satisfaction_employee_equilibre_pro_perso,
72
+ note_evaluation_actuelle=note_evaluation_actuelle,
73
+ heure_supplementaires=heure_supplementaires,
74
+ augementation_salaire_precedente=augementation_salaire_precedente,
75
+ age=age,
76
+ genre=genre,
77
+ revenu_mensuel=revenu_mensuel,
78
+ statut_marital=statut_marital,
79
+ departement=departement,
80
+ poste=poste,
81
+ nombre_experiences_precedentes=nombre_experiences_precedentes,
82
+ nombre_heures_travailless=nombre_heures_travailless,
83
+ annee_experience_totale=annee_experience_totale,
84
+ annees_dans_l_entreprise=annees_dans_l_entreprise,
85
+ annees_dans_le_poste_actuel=annees_dans_le_poste_actuel,
86
+ )
87
+
88
+ # Préprocessing
89
+ features = preprocess_for_prediction(employee)
90
+
91
+ # Prédiction
92
+ model = load_model()
93
+ prediction = model.predict(features)[0]
94
+ proba = model.predict_proba(features)[0]
95
+
96
+ # Résultat
97
+ risk_level = "🔴 RISQUE ÉLEVÉ" if prediction == 1 else "🟢 RISQUE FAIBLE"
98
+ confidence = max(proba) * 100
99
+
100
+ result = f"""
101
+ ## {risk_level}
102
+
103
+ ### Résultat de la prédiction
104
+ - **Prédiction**: {"Départ probable" if prediction == 1 else "Maintien probable"}
105
+ - **Confiance**: {confidence:.1f}%
106
+ - **Probabilité de départ**: {proba[1] * 100:.1f}%
107
+ - **Probabilité de maintien**: {proba[0] * 100:.1f}%
108
+
109
+ ### Interprétation
110
+ {"⚠️ Cet employé présente des facteurs de risque de départ. Il est recommandé d'engager un dialogue pour comprendre ses attentes." if prediction == 1 else "✅ Cet employé semble stable. Continuez à maintenir un environnement de travail positif."}
111
+ """
112
+ return result
113
+
114
+ except Exception as e:
115
+ return f"❌ **Erreur**: {str(e)}"
116
+
117
+
118
+ # Documentation de l'API
119
+ API_DOCS = """
120
+ # 🚀 Employee Turnover Prediction API
121
+
122
+ ## Description
123
+ Cette API prédit le risque de départ (turnover) d'un employé en utilisant un modèle
124
+ de Machine Learning entraîné sur des données RH.
125
+
126
+ ## Endpoints disponibles
127
+
128
+ ### `GET /`
129
+ Page d'accueil avec informations sur l'API.
130
+
131
+ ### `GET /health`
132
+ Vérification de l'état de l'API.
133
+ ```bash
134
+ curl https://asi-engineer-oc-p5-dev.hf.space/health
135
+ ```
136
+
137
+ ### `GET /docs`
138
+ Documentation Swagger interactive.
139
+
140
+ ### `POST /predict`
141
+ Effectue une prédiction de turnover.
142
+
143
+ ## Exemple d'utilisation avec curl
144
+
145
+ ```bash
146
+ curl -X POST https://asi-engineer-oc-p5-dev.hf.space/predict \\
147
+ -H "Content-Type: application/json" \\
148
+ -d '{
149
+ "nombre_participation_pee": 0,
150
+ "nb_formations_suivies": 2,
151
+ "nombre_employee_sous_responsabilite": 1,
152
+ "distance_domicile_travail": 15,
153
+ "niveau_education": 3,
154
+ "domaine_etude": "Infra & Cloud",
155
+ "ayant_enfants": "Y",
156
+ "frequence_deplacement": "Occasionnel",
157
+ "annees_depuis_la_derniere_promotion": 2,
158
+ "annes_sous_responsable_actuel": 5,
159
+ "satisfaction_employee_environnement": 3,
160
+ "note_evaluation_precedente": 4,
161
+ "niveau_hierarchique_poste": 2,
162
+ "satisfaction_employee_nature_travail": 3,
163
+ "satisfaction_employee_equipe": 3,
164
+ "satisfaction_employee_equilibre_pro_perso": 2,
165
+ "note_evaluation_actuelle": 4,
166
+ "heure_supplementaires": "Non",
167
+ "augementation_salaire_precedente": 5.5,
168
+ "age": 35,
169
+ "genre": "M",
170
+ "revenu_mensuel": 4500.0,
171
+ "statut_marital": "Marié(e)",
172
+ "departement": "Commercial",
173
+ "poste": "Manager",
174
+ "nombre_experiences_precedentes": 3,
175
+ "nombre_heures_travailless": 45,
176
+ "annee_experience_totale": 10,
177
+ "annees_dans_l_entreprise": 5,
178
+ "annees_dans_le_poste_actuel": 2
179
+ }'
180
+ ```
181
+
182
+ ## Exemple avec Python
183
+
184
+ ```python
185
+ import requests
186
+
187
+ url = "https://asi-engineer-oc-p5-dev.hf.space/predict"
188
+
189
+ data = {
190
+ "nombre_participation_pee": 0,
191
+ "nb_formations_suivies": 2,
192
+ "nombre_employee_sous_responsabilite": 1,
193
+ "distance_domicile_travail": 15,
194
+ "niveau_education": 3,
195
+ "domaine_etude": "Infra & Cloud",
196
+ "ayant_enfants": "Y",
197
+ "frequence_deplacement": "Occasionnel",
198
+ "annees_depuis_la_derniere_promotion": 2,
199
+ "annes_sous_responsable_actuel": 5,
200
+ "satisfaction_employee_environnement": 3,
201
+ "note_evaluation_precedente": 4,
202
+ "niveau_hierarchique_poste": 2,
203
+ "satisfaction_employee_nature_travail": 3,
204
+ "satisfaction_employee_equipe": 3,
205
+ "satisfaction_employee_equilibre_pro_perso": 2,
206
+ "note_evaluation_actuelle": 4,
207
+ "heure_supplementaires": "Non",
208
+ "augementation_salaire_precedente": 5.5,
209
+ "age": 35,
210
+ "genre": "M",
211
+ "revenu_mensuel": 4500.0,
212
+ "statut_marital": "Marié(e)",
213
+ "departement": "Commercial",
214
+ "poste": "Manager",
215
+ "nombre_experiences_precedentes": 3,
216
+ "nombre_heures_travailless": 45,
217
+ "annee_experience_totale": 10,
218
+ "annees_dans_l_entreprise": 5,
219
+ "annees_dans_le_poste_actuel": 2
220
+ }
221
+
222
+ response = requests.post(url, json=data)
223
+ print(response.json())
224
+ ```
225
+
226
+ ## Réponse attendue
227
+
228
+ ```json
229
+ {
230
+ "prediction": 0,
231
+ "probability": {
232
+ "stay": 0.85,
233
+ "leave": 0.15
234
+ },
235
+ "risk_level": "low",
236
+ "model_version": "1.0.0"
237
+ }
238
+ ```
239
+
240
+ ## Codes d'erreur
241
+
242
+ | Code | Description |
243
+ |------|-------------|
244
+ | 200 | Succès |
245
+ | 422 | Données invalides (validation Pydantic) |
246
+ | 429 | Trop de requêtes (rate limit: 20/min) |
247
+ | 500 | Erreur serveur |
248
+
249
+ ## Modèle utilisé
250
+
251
+ - **Type**: XGBoost Pipeline
252
+ - **Source**: HuggingFace Hub (`ASI-Engineer/employee-turnover-model`)
253
+ - **Features**: 25 variables RH (sondage, évaluation, SIRH)
254
+ """
255
+
256
+
257
+ def create_gradio_interface():
258
+ """Crée l'interface Gradio complète."""
259
+
260
+ # Obtenir les infos du modèle
261
+ try:
262
+ model_info = get_model_info()
263
+ model_status = f"✅ Modèle chargé: {model_info.get('type', 'Unknown')}"
264
+ except Exception:
265
+ model_status = "⏳ Modèle en cours de chargement..."
266
+
267
+ with gr.Blocks(
268
+ title="Employee Turnover Prediction",
269
+ ) as demo:
270
+ gr.Markdown(
271
+ """
272
+ # 🏢 Employee Turnover Prediction
273
+
274
+ Prédisez le risque de départ d'un employé grâce au Machine Learning.
275
+
276
+ **Naviguez entre les onglets** pour utiliser l'interface de prédiction
277
+ ou consulter la documentation de l'API.
278
+ """
279
+ )
280
+
281
+ gr.Markdown(f"**Statut**: {model_status}")
282
+
283
+ with gr.Tabs():
284
+ # Onglet Prédiction
285
+ with gr.TabItem("🎯 Prédiction"):
286
+ gr.Markdown("### Remplissez les informations de l'employé")
287
+
288
+ with gr.Row():
289
+ # Colonne SONDAGE
290
+ with gr.Column():
291
+ gr.Markdown("#### 📋 Données Sondage")
292
+ nombre_participation_pee = gr.Slider(
293
+ 0, 10, value=0, step=1, label="Participations PEE"
294
+ )
295
+ nb_formations_suivies = gr.Slider(
296
+ 0, 10, value=2, step=1, label="Formations suivies"
297
+ )
298
+ nombre_employee_sous_responsabilite = gr.Slider(
299
+ 0, 20, value=0, step=1, label="Employés sous responsabilité"
300
+ )
301
+ distance_domicile_travail = gr.Slider(
302
+ 0, 50, value=15, step=1, label="Distance domicile (km)"
303
+ )
304
+ niveau_education = gr.Slider(
305
+ 1, 5, value=3, step=1, label="Niveau éducation (1-5)"
306
+ )
307
+ domaine_etude = gr.Dropdown(
308
+ [
309
+ "Infra & Cloud",
310
+ "Transformation Digitale",
311
+ "Marketing",
312
+ "Entrepreunariat",
313
+ "Ressources Humaines",
314
+ "Autre",
315
+ ],
316
+ value="Infra & Cloud",
317
+ label="Domaine d'études",
318
+ )
319
+ ayant_enfants = gr.Radio(
320
+ ["Y", "N"], value="N", label="A des enfants"
321
+ )
322
+ frequence_deplacement = gr.Dropdown(
323
+ ["Aucun", "Occasionnel", "Frequent"],
324
+ value="Occasionnel",
325
+ label="Fréquence déplacements",
326
+ )
327
+ annees_depuis_la_derniere_promotion = gr.Slider(
328
+ 0, 15, value=2, step=1, label="Années depuis promotion"
329
+ )
330
+ annes_sous_responsable_actuel = gr.Slider(
331
+ 0, 20, value=3, step=1, label="Années sous responsable"
332
+ )
333
+
334
+ # Colonne EVALUATION
335
+ with gr.Column():
336
+ gr.Markdown("#### 📊 Données Évaluation")
337
+ satisfaction_employee_environnement = gr.Slider(
338
+ 1, 5, value=3, step=1, label="Satisfaction environnement"
339
+ )
340
+ note_evaluation_precedente = gr.Slider(
341
+ 1, 5, value=3, step=1, label="Évaluation précédente"
342
+ )
343
+ niveau_hierarchique_poste = gr.Slider(
344
+ 1, 5, value=2, step=1, label="Niveau hiérarchique"
345
+ )
346
+ satisfaction_employee_nature_travail = gr.Slider(
347
+ 1, 5, value=3, step=1, label="Satisfaction nature travail"
348
+ )
349
+ satisfaction_employee_equipe = gr.Slider(
350
+ 1, 5, value=3, step=1, label="Satisfaction équipe"
351
+ )
352
+ satisfaction_employee_equilibre_pro_perso = gr.Slider(
353
+ 1, 5, value=3, step=1, label="Équilibre pro/perso"
354
+ )
355
+ note_evaluation_actuelle = gr.Slider(
356
+ 1, 5, value=3, step=1, label="Évaluation actuelle"
357
+ )
358
+ heure_supplementaires = gr.Radio(
359
+ ["Oui", "Non"], value="Non", label="Heures supplémentaires"
360
+ )
361
+ augementation_salaire_precedente = gr.Slider(
362
+ 0,
363
+ 25,
364
+ value=5.0,
365
+ step=0.5,
366
+ label="Augmentation précédente (%)",
367
+ )
368
+
369
+ # Colonne SIRH
370
+ with gr.Column():
371
+ gr.Markdown("#### 👤 Données SIRH")
372
+ age = gr.Slider(18, 65, value=35, step=1, label="Âge")
373
+ genre = gr.Radio(["M", "F"], value="M", label="Genre")
374
+ revenu_mensuel = gr.Slider(
375
+ 1500,
376
+ 15000,
377
+ value=4500,
378
+ step=100,
379
+ label="Revenu mensuel (€)",
380
+ )
381
+ statut_marital = gr.Dropdown(
382
+ ["Célibataire", "Marié(e)", "Divorcé(e)"],
383
+ value="Célibataire",
384
+ label="Statut marital",
385
+ )
386
+ departement = gr.Dropdown(
387
+ ["Commercial", "Consulting", "Ressources Humaines"],
388
+ value="Commercial",
389
+ label="Département",
390
+ )
391
+ poste = gr.Dropdown(
392
+ [
393
+ "Cadre Commercial",
394
+ "Assistant de Direction",
395
+ "Consultant",
396
+ "Tech Lead",
397
+ "Manager",
398
+ "Senior Manager",
399
+ "Représentant Commercial",
400
+ "Directeur Technique",
401
+ "Ressources Humaines",
402
+ ],
403
+ value="Consultant",
404
+ label="Poste",
405
+ )
406
+ nombre_experiences_precedentes = gr.Slider(
407
+ 0, 10, value=2, step=1, label="Expériences précédentes"
408
+ )
409
+ nombre_heures_travailless = gr.Slider(
410
+ 35, 80, value=40, step=1, label="Heures travaillées/sem"
411
+ )
412
+ annee_experience_totale = gr.Slider(
413
+ 0, 40, value=10, step=1, label="Années d'expérience totale"
414
+ )
415
+ annees_dans_l_entreprise = gr.Slider(
416
+ 0, 30, value=5, step=1, label="Années dans l'entreprise"
417
+ )
418
+ annees_dans_le_poste_actuel = gr.Slider(
419
+ 0, 20, value=2, step=1, label="Années dans le poste"
420
+ )
421
+
422
+ # Bouton et résultat
423
+ predict_btn = gr.Button(
424
+ "🔮 Prédire le risque de départ", variant="primary"
425
+ )
426
+ result = gr.Markdown(label="Résultat")
427
+
428
+ predict_btn.click(
429
+ fn=predict_turnover,
430
+ inputs=[
431
+ nombre_participation_pee,
432
+ nb_formations_suivies,
433
+ nombre_employee_sous_responsabilite,
434
+ distance_domicile_travail,
435
+ niveau_education,
436
+ domaine_etude,
437
+ ayant_enfants,
438
+ frequence_deplacement,
439
+ annees_depuis_la_derniere_promotion,
440
+ annes_sous_responsable_actuel,
441
+ satisfaction_employee_environnement,
442
+ note_evaluation_precedente,
443
+ niveau_hierarchique_poste,
444
+ satisfaction_employee_nature_travail,
445
+ satisfaction_employee_equipe,
446
+ satisfaction_employee_equilibre_pro_perso,
447
+ note_evaluation_actuelle,
448
+ heure_supplementaires,
449
+ augementation_salaire_precedente,
450
+ age,
451
+ genre,
452
+ revenu_mensuel,
453
+ statut_marital,
454
+ departement,
455
+ poste,
456
+ nombre_experiences_precedentes,
457
+ nombre_heures_travailless,
458
+ annee_experience_totale,
459
+ annees_dans_l_entreprise,
460
+ annees_dans_le_poste_actuel,
461
+ ],
462
+ outputs=result,
463
+ )
464
+
465
+ # Onglet Documentation
466
+ with gr.TabItem("📚 Documentation API"):
467
+ gr.Markdown(API_DOCS)
468
+
469
+ # Onglet À propos
470
+ with gr.TabItem("ℹ️ À propos"):
471
+ gr.Markdown(
472
+ """
473
+ ## À propos de ce projet
474
+
475
+ ### 🎓 Contexte
476
+ Ce projet a été réalisé dans le cadre du **Projet 5 OpenClassrooms** :
477
+ "Déployez votre modèle de Machine Learning".
478
+
479
+ ### 🎯 Objectif
480
+ Développer une API de prédiction du turnover (départ) des employés,
481
+ permettant aux équipes RH d'anticiper et de prévenir les départs.
482
+
483
+ ### 🛠️ Technologies utilisées
484
+ - **FastAPI** : Framework API REST performant
485
+ - **XGBoost** : Modèle de Machine Learning
486
+ - **Gradio** : Interface utilisateur
487
+ - **HuggingFace Hub** : Hébergement du modèle
488
+ - **HuggingFace Spaces** : Déploiement de l'application
489
+ - **GitHub Actions** : CI/CD automatisé
490
+
491
+ ### 📊 Le modèle
492
+ Le modèle a été entraîné sur des données RH comprenant :
493
+ - Données de sondage de satisfaction
494
+ - Données d'évaluation de performance
495
+ - Données administratives SIRH
496
+
497
+ ### 🔗 Liens utiles
498
+ - [GitHub Repository](https://github.com/chaton59/OC_P5)
499
+ - [API Documentation (Swagger)](/docs)
500
+ - [HuggingFace Model](https://huggingface.co/ASI-Engineer/employee-turnover-model)
501
+
502
+ ### 👤 Auteur
503
+ Projet OpenClassrooms - Formation Data Scientist
504
+ """
505
+ )
506
+
507
+ return demo
508
+
509
+
510
+ # Pour lancer en standalone
511
+ if __name__ == "__main__":
512
+ demo = create_gradio_interface()
513
+ demo.launch(server_name="0.0.0.0", server_port=7860)
src/logger.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de logging structuré pour l'API Employee Turnover.
4
+
5
+ Fournit un système de logging centralisé avec :
6
+ - Logs structurés en JSON
7
+ - Rotation automatique des fichiers
8
+ - Niveaux de log configurables
9
+ - Intégration FastAPI
10
+ """
11
+ import logging
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Any, Dict
15
+
16
+ from pythonjsonlogger import jsonlogger
17
+
18
+ from src.config import get_settings
19
+
20
+ settings = get_settings()
21
+
22
+ # Créer le dossier logs s'il n'existe pas
23
+ LOG_DIR = Path("logs")
24
+ LOG_DIR.mkdir(exist_ok=True)
25
+
26
+ # Fichiers de logs
27
+ LOG_FILE = LOG_DIR / "api.log"
28
+ ERROR_LOG_FILE = LOG_DIR / "error.log"
29
+
30
+
31
+ class CustomJsonFormatter(jsonlogger.JsonFormatter):
32
+ """
33
+ Formatter JSON personnalisé avec champs supplémentaires.
34
+ """
35
+
36
+ def add_fields(
37
+ self,
38
+ log_record: Dict[str, Any],
39
+ record: logging.LogRecord,
40
+ message_dict: Dict[str, Any],
41
+ ) -> None:
42
+ """
43
+ Ajoute des champs personnalisés aux logs JSON.
44
+ """
45
+ super().add_fields(log_record, record, message_dict)
46
+
47
+ # Ajouter des métadonnées
48
+ log_record["level"] = record.levelname
49
+ log_record["logger"] = record.name
50
+ log_record["module"] = record.module
51
+ log_record["function"] = record.funcName
52
+ log_record["line"] = record.lineno
53
+
54
+ # Timestamp ISO 8601
55
+ if not log_record.get("timestamp"):
56
+ log_record["timestamp"] = self.formatTime(record, self.datefmt)
57
+
58
+
59
+ def setup_logger(name: str = "employee_turnover_api") -> logging.Logger:
60
+ """
61
+ Configure et retourne un logger structuré.
62
+
63
+ Args:
64
+ name: Nom du logger.
65
+
66
+ Returns:
67
+ Logger configuré avec handlers console et fichiers.
68
+
69
+ Examples:
70
+ >>> logger = setup_logger()
71
+ >>> logger.info("API démarrée", extra={"version": "2.0.0"})
72
+ """
73
+ logger = logging.getLogger(name)
74
+
75
+ # Éviter duplication si déjà configuré
76
+ if logger.handlers:
77
+ return logger
78
+
79
+ # Niveau de log depuis configuration
80
+ log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
81
+ logger.setLevel(log_level)
82
+
83
+ # === HANDLER CONSOLE (stdout) ===
84
+ console_handler = logging.StreamHandler(sys.stdout)
85
+ console_handler.setLevel(log_level)
86
+
87
+ # Format simple pour la console en dev, JSON en prod
88
+ if settings.DEBUG:
89
+ console_format = logging.Formatter(
90
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
91
+ datefmt="%Y-%m-%d %H:%M:%S",
92
+ )
93
+ else:
94
+ console_format = CustomJsonFormatter(
95
+ "%(timestamp)s %(level)s %(name)s %(message)s"
96
+ )
97
+
98
+ console_handler.setFormatter(console_format)
99
+ logger.addHandler(console_handler)
100
+
101
+ # === HANDLER FICHIER (tous les logs) ===
102
+ file_handler = logging.FileHandler(LOG_FILE, encoding="utf-8")
103
+ file_handler.setLevel(log_level)
104
+ file_handler.setFormatter(
105
+ CustomJsonFormatter("%(timestamp)s %(level)s %(name)s %(message)s")
106
+ )
107
+ logger.addHandler(file_handler)
108
+
109
+ # === HANDLER ERREURS UNIQUEMENT ===
110
+ error_handler = logging.FileHandler(ERROR_LOG_FILE, encoding="utf-8")
111
+ error_handler.setLevel(logging.ERROR)
112
+ error_handler.setFormatter(
113
+ CustomJsonFormatter("%(timestamp)s %(level)s %(name)s %(message)s")
114
+ )
115
+ logger.addHandler(error_handler)
116
+
117
+ # Éviter propagation au root logger
118
+ logger.propagate = False
119
+
120
+ return logger
121
+
122
+
123
+ def log_request(
124
+ method: str,
125
+ path: str,
126
+ status_code: int,
127
+ duration_ms: float,
128
+ **kwargs: Any,
129
+ ) -> None:
130
+ """
131
+ Log une requête HTTP avec métadonnées.
132
+
133
+ Args:
134
+ method: Méthode HTTP (GET, POST...).
135
+ path: Chemin de l'endpoint.
136
+ status_code: Code de statut HTTP.
137
+ duration_ms: Durée de la requête en millisecondes.
138
+ **kwargs: Métadonnées additionnelles.
139
+
140
+ Examples:
141
+ >>> log_request("POST", "/predict", 200, 45.3, user_id="123")
142
+ """
143
+ logger = logging.getLogger("employee_turnover_api")
144
+
145
+ log_data = {
146
+ "method": method,
147
+ "path": path,
148
+ "status_code": status_code,
149
+ "duration_ms": round(duration_ms, 2),
150
+ **kwargs,
151
+ }
152
+
153
+ # Niveau selon status code
154
+ if status_code >= 500:
155
+ logger.error(f"Request {method} {path}", extra=log_data)
156
+ elif status_code >= 400:
157
+ logger.warning(f"Request {method} {path}", extra=log_data)
158
+ else:
159
+ logger.info(f"Request {method} {path}", extra=log_data)
160
+
161
+
162
+ def log_prediction(
163
+ employee_id: str | None,
164
+ prediction: int,
165
+ probability: float,
166
+ risk_level: str,
167
+ duration_ms: float,
168
+ ) -> None:
169
+ """
170
+ Log une prédiction effectuée.
171
+
172
+ Args:
173
+ employee_id: ID de l'employé (optionnel).
174
+ prediction: Prédiction (0 ou 1).
175
+ probability: Probabilité de turnover.
176
+ risk_level: Niveau de risque ("low", "medium", "high").
177
+ duration_ms: Durée du preprocessing + pr��diction.
178
+
179
+ Examples:
180
+ >>> log_prediction("EMP123", 1, 0.87, "high", 23.4)
181
+ """
182
+ logger = logging.getLogger("employee_turnover_api")
183
+
184
+ logger.info(
185
+ "Prediction made",
186
+ extra={
187
+ "employee_id": employee_id,
188
+ "prediction": prediction,
189
+ "probability": round(probability, 4),
190
+ "risk_level": risk_level,
191
+ "duration_ms": round(duration_ms, 2),
192
+ },
193
+ )
194
+
195
+
196
+ def log_model_load(model_type: str, duration_ms: float, success: bool) -> None:
197
+ """
198
+ Log le chargement du modèle.
199
+
200
+ Args:
201
+ model_type: Type de modèle chargé.
202
+ duration_ms: Durée du chargement.
203
+ success: Si le chargement a réussi.
204
+
205
+ Examples:
206
+ >>> log_model_load("XGBoost Pipeline", 1234.5, True)
207
+ """
208
+ logger = logging.getLogger("employee_turnover_api")
209
+
210
+ log_data = {
211
+ "model_type": model_type,
212
+ "duration_ms": round(duration_ms, 2),
213
+ "success": success,
214
+ }
215
+
216
+ if success:
217
+ logger.info("Model loaded successfully", extra=log_data)
218
+ else:
219
+ logger.error("Model loading failed", extra=log_data)
220
+
221
+
222
+ # Créer le logger global
223
+ logger = setup_logger()
src/models.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de chargement et gestion du modèle MLflow.
4
+
5
+ Ce module encapsule la logique de chargement du modèle depuis Hugging Face Hub
6
+ via MLflow, avec gestion des erreurs et versioning.
7
+ """
8
+ from typing import Any, Optional
9
+
10
+ from fastapi import HTTPException
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ # Configuration
14
+ HF_MODEL_REPO = "ASI-Engineer/employee-turnover-model"
15
+ MODEL_FILENAME = "model/model.pkl"
16
+
17
+ # Cache global du modèle
18
+ _model_cache: Optional[Any] = None
19
+
20
+
21
+ def load_model(force_reload: bool = False) -> Any:
22
+ """
23
+ Charge le modèle depuis Hugging Face Hub via MLflow.
24
+
25
+ Cette fonction implémente un système de cache pour éviter de recharger
26
+ le modèle à chaque appel. Le modèle est chargé une seule fois au démarrage
27
+ de l'application et mis en cache.
28
+
29
+ Args:
30
+ force_reload: Si True, force le rechargement du modèle même s'il est en cache.
31
+
32
+ Returns:
33
+ Le modèle MLflow chargé et prêt pour l'inférence.
34
+
35
+ Raises:
36
+ HTTPException: 500 si le modèle ne peut pas être chargé.
37
+
38
+ Examples:
39
+ >>> model = load_model()
40
+ >>> # Utiliser le modèle pour prédiction
41
+ >>> predictions = model.predict(X)
42
+ """
43
+ global _model_cache
44
+
45
+ # Retourner le modèle en cache si disponible
46
+ if _model_cache is not None and not force_reload:
47
+ return _model_cache
48
+
49
+ try:
50
+ import joblib
51
+
52
+ print(f"🔄 Chargement du modèle depuis HF Hub: {HF_MODEL_REPO}")
53
+
54
+ # Télécharger le modèle depuis Hugging Face Hub
55
+ model_path = hf_hub_download(
56
+ repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME, repo_type="model"
57
+ )
58
+
59
+ print(f"📦 Modèle téléchargé: {model_path}")
60
+
61
+ # Charger le modèle avec joblib
62
+ model = joblib.load(model_path)
63
+
64
+ # Mettre en cache
65
+ _model_cache = model
66
+
67
+ print(f"✅ Modèle chargé avec succès: {type(model).__name__}")
68
+ return model
69
+
70
+ except Exception as e:
71
+ error_msg = f"❌ Erreur lors du chargement du modèle: {str(e)}"
72
+ print(error_msg)
73
+ raise HTTPException(
74
+ status_code=500,
75
+ detail={
76
+ "error": "Model loading failed",
77
+ "message": str(e),
78
+ "model_repo": HF_MODEL_REPO,
79
+ "solution": "Vérifiez que le modèle est disponible sur HF Hub et correctement entraîné",
80
+ },
81
+ )
82
+
83
+
84
+ def get_model_info() -> dict:
85
+ """
86
+ Retourne les informations sur le modèle chargé.
87
+
88
+ Returns:
89
+ Dict contenant les métadonnées du modèle.
90
+
91
+ Raises:
92
+ HTTPException: 500 si le modèle n'est pas chargé.
93
+ """
94
+ try:
95
+ model = load_model()
96
+
97
+ return {
98
+ "status": "✅ Modèle chargé",
99
+ "model_type": type(model).__name__,
100
+ "hf_hub_repo": HF_MODEL_REPO,
101
+ "model_file": MODEL_FILENAME,
102
+ "cached": _model_cache is not None,
103
+ }
104
+
105
+ except Exception as e:
106
+ raise HTTPException(
107
+ status_code=500,
108
+ detail={"error": "Model info unavailable", "message": str(e)},
109
+ )
110
+
111
+
112
+ def load_preprocessing_artifacts(run_id: str) -> dict:
113
+ """
114
+ Charge les artifacts de preprocessing (scaler, encoders) depuis MLflow.
115
+
116
+ Args:
117
+ run_id: ID du run MLflow contenant les artifacts.
118
+
119
+ Returns:
120
+ Dict contenant les artifacts de preprocessing.
121
+
122
+ Raises:
123
+ HTTPException: 500 si les artifacts ne peuvent pas être chargés.
124
+
125
+ Note:
126
+ Cette fonction sera implémentée quand les preprocessing artifacts
127
+ seront disponibles dans le modèle HF Hub.
128
+ """
129
+ raise NotImplementedError(
130
+ "Le chargement des preprocessing artifacts sera implémenté "
131
+ "lors de l'intégration complète avec MLflow"
132
+ )
133
+
134
+
135
+ if __name__ == "__main__":
136
+ # Test de chargement du modèle
137
+ print("=" * 80)
138
+ print("TEST DE CHARGEMENT DU MODÈLE")
139
+ print("=" * 80)
140
+
141
+ try:
142
+ model = load_model()
143
+ print("\n✅ Test réussi!")
144
+ print(f"Type de modèle: {type(model).__name__}")
145
+
146
+ # Afficher les infos
147
+ info = get_model_info()
148
+ print("\nInformations du modèle:")
149
+ for key, value in info.items():
150
+ print(f" {key}: {value}")
151
+
152
+ except Exception as e:
153
+ print(f"\n❌ Test échoué: {e}")
src/preprocessing.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de preprocessing pour transformer les données d'entrée avant prédiction.
4
+
5
+ Ce module applique les mêmes transformations que le pipeline d'entraînement :
6
+ - Feature engineering (ratios, moyennes)
7
+ - Encoding (OneHot, Ordinal)
8
+ - Scaling (StandardScaler)
9
+ """
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
13
+
14
+ from src.schemas import EmployeeInput
15
+
16
+
17
+ def create_input_dataframe(employee: EmployeeInput) -> pd.DataFrame:
18
+ """
19
+ Convertit un objet EmployeeInput Pydantic en DataFrame pandas.
20
+
21
+ Args:
22
+ employee: Données validées d'un employé.
23
+
24
+ Returns:
25
+ DataFrame avec une seule ligne contenant toutes les features.
26
+ """
27
+ data = {
28
+ # SONDAGE
29
+ "nombre_participation_pee": [employee.nombre_participation_pee],
30
+ "nb_formations_suivies": [employee.nb_formations_suivies],
31
+ "nombre_employee_sous_responsabilite": [
32
+ employee.nombre_employee_sous_responsabilite
33
+ ],
34
+ "distance_domicile_travail": [employee.distance_domicile_travail],
35
+ "niveau_education": [employee.niveau_education],
36
+ "domaine_etude": [employee.domaine_etude],
37
+ "ayant_enfants": [employee.ayant_enfants],
38
+ "frequence_deplacement": [employee.frequence_deplacement],
39
+ "annees_depuis_la_derniere_promotion": [
40
+ employee.annees_depuis_la_derniere_promotion
41
+ ],
42
+ "annes_sous_responsable_actuel": [employee.annes_sous_responsable_actuel],
43
+ # EVALUATION
44
+ "satisfaction_employee_environnement": [
45
+ employee.satisfaction_employee_environnement
46
+ ],
47
+ "note_evaluation_precedente": [employee.note_evaluation_precedente],
48
+ "niveau_hierarchique_poste": [employee.niveau_hierarchique_poste],
49
+ "satisfaction_employee_nature_travail": [
50
+ employee.satisfaction_employee_nature_travail
51
+ ],
52
+ "satisfaction_employee_equipe": [employee.satisfaction_employee_equipe],
53
+ "satisfaction_employee_equilibre_pro_perso": [
54
+ employee.satisfaction_employee_equilibre_pro_perso
55
+ ],
56
+ "note_evaluation_actuelle": [employee.note_evaluation_actuelle],
57
+ "heure_supplementaires": [employee.heure_supplementaires],
58
+ "augementation_salaire_precedente": [employee.augementation_salaire_precedente],
59
+ # SIRH
60
+ "age": [employee.age],
61
+ "genre": [employee.genre],
62
+ "revenu_mensuel": [employee.revenu_mensuel],
63
+ "statut_marital": [employee.statut_marital],
64
+ "departement": [employee.departement],
65
+ "poste": [employee.poste],
66
+ "nombre_experiences_precedentes": [employee.nombre_experiences_precedentes],
67
+ "nombre_heures_travailless": [employee.nombre_heures_travailless],
68
+ "annee_experience_totale": [employee.annee_experience_totale],
69
+ "annees_dans_l_entreprise": [employee.annees_dans_l_entreprise],
70
+ "annees_dans_le_poste_actuel": [employee.annees_dans_le_poste_actuel],
71
+ }
72
+
73
+ return pd.DataFrame(data)
74
+
75
+
76
+ def engineer_features(df: pd.DataFrame) -> pd.DataFrame:
77
+ """
78
+ Applique le feature engineering (mêmes transformations que l'entraînement).
79
+
80
+ Args:
81
+ df: DataFrame avec les colonnes brutes.
82
+
83
+ Returns:
84
+ DataFrame avec les features engineered ajoutées.
85
+ """
86
+ df = df.copy()
87
+
88
+ # Ratios (+ 1 pour éviter division par zéro)
89
+ df["revenu_par_anciennete"] = df["revenu_mensuel"] / (
90
+ df["annees_dans_l_entreprise"] + 1
91
+ )
92
+ df["experience_par_anciennete"] = df["annee_experience_totale"] / (
93
+ df["annees_dans_l_entreprise"] + 1
94
+ )
95
+ df["promo_par_anciennete"] = df["annees_depuis_la_derniere_promotion"] / (
96
+ df["annees_dans_l_entreprise"] + 1
97
+ )
98
+
99
+ # Moyenne de satisfaction
100
+ df["satisfaction_moyenne"] = df[
101
+ [
102
+ "satisfaction_employee_environnement",
103
+ "satisfaction_employee_nature_travail",
104
+ "satisfaction_employee_equipe",
105
+ "satisfaction_employee_equilibre_pro_perso",
106
+ ]
107
+ ].mean(axis=1)
108
+
109
+ return df
110
+
111
+
112
+ def encode_and_scale(df: pd.DataFrame) -> pd.DataFrame:
113
+ """
114
+ Encode les variables catégorielles et scale les numériques.
115
+ IMPORTANT: Doit correspondre EXACTEMENT au pipeline d'entraînement.
116
+
117
+ Args:
118
+ df: DataFrame avec features engineered.
119
+
120
+ Returns:
121
+ DataFrame transformé avec 50 colonnes (comme training).
122
+ """
123
+ df = df.copy()
124
+
125
+ # === ENCODING ===
126
+
127
+ # NOTE: ayant_enfants et heure_supplementaires sont SUPPRIMÉS
128
+ # (ne font pas partie des features du modèle d'entraînement)
129
+ cols_to_drop = ["ayant_enfants", "heure_supplementaires"]
130
+ df = df.drop(columns=[col for col in cols_to_drop if col in df.columns])
131
+
132
+ # OneHot pour variables catégorielles non-ordonnées
133
+ # IMPORTANT: Utiliser les mêmes catégories que lors de l'entraînement
134
+ cat_non_ord = ["genre", "statut_marital", "departement", "poste", "domaine_etude"]
135
+
136
+ # Définir toutes les catégories possibles (depuis training data)
137
+ categories_dict = {
138
+ "genre": ["F", "M"],
139
+ "statut_marital": ["Célibataire", "Divorcé(e)", "Marié(e)"],
140
+ "departement": ["Commercial", "Consulting", "Ressources Humaines"],
141
+ "poste": [
142
+ "Assistant de Direction",
143
+ "Cadre Commercial",
144
+ "Consultant",
145
+ "Directeur Technique",
146
+ "Manager",
147
+ "Représentant Commercial",
148
+ "Ressources Humaines",
149
+ "Senior Manager",
150
+ "Tech Lead",
151
+ ],
152
+ "domaine_etude": [
153
+ "Autre",
154
+ "Entrepreunariat",
155
+ "Infra & Cloud",
156
+ "Marketing",
157
+ "Ressources Humaines",
158
+ "Transformation Digitale",
159
+ ],
160
+ }
161
+
162
+ onehot = OneHotEncoder(
163
+ sparse_output=False,
164
+ handle_unknown="ignore",
165
+ categories=[categories_dict[col] for col in cat_non_ord],
166
+ )
167
+
168
+ encoded_non_ord = pd.DataFrame(
169
+ onehot.fit_transform(df[cat_non_ord]),
170
+ columns=onehot.get_feature_names_out(cat_non_ord),
171
+ index=df.index,
172
+ )
173
+
174
+ # Ordinal pour fréquence déplacement
175
+ ordinal = OrdinalEncoder(categories=[["Aucun", "Occasionnel", "Frequent"]])
176
+ df["frequence_deplacement"] = ordinal.fit_transform(
177
+ df[["frequence_deplacement"]]
178
+ ).flatten()
179
+
180
+ # Supprimer les colonnes catégorielles originales
181
+ df = df.drop(columns=cat_non_ord)
182
+
183
+ # Concaténer les encodages OneHot
184
+ df = pd.concat([df, encoded_non_ord], axis=1)
185
+
186
+ # === SCALING ===
187
+
188
+ # Colonnes numériques à scaler
189
+ quantitative_cols = df.select_dtypes(include=[np.number]).columns.tolist()
190
+
191
+ # Retirer les colonnes OneHot du scaling (elles sont déjà 0/1)
192
+ cols_to_scale = [
193
+ col
194
+ for col in quantitative_cols
195
+ if df[col].nunique() > 2 # Exclut colonnes binaires (0/1)
196
+ ]
197
+
198
+ # Appliquer le scaling uniquement s'il y a des colonnes
199
+ if cols_to_scale:
200
+ scaler = StandardScaler()
201
+ df[cols_to_scale] = scaler.fit_transform(df[cols_to_scale])
202
+
203
+ return df
204
+
205
+
206
+ def preprocess_for_prediction(employee: EmployeeInput) -> np.ndarray:
207
+ """
208
+ Pipeline complet de preprocessing pour une prédiction.
209
+
210
+ Args:
211
+ employee: Données validées d'un employé.
212
+
213
+ Returns:
214
+ Array numpy transformé prêt pour model.predict().
215
+
216
+ Examples:
217
+ >>> from src.schemas import EmployeeInput
218
+ >>> employee = EmployeeInput(...)
219
+ >>> X = preprocess_for_prediction(employee)
220
+ >>> prediction = model.predict(X)
221
+ """
222
+ # 1. Créer DataFrame
223
+ df = create_input_dataframe(employee)
224
+
225
+ # 2. Feature engineering
226
+ df = engineer_features(df)
227
+
228
+ # 3. Encoding et scaling
229
+ df = encode_and_scale(df)
230
+
231
+ # 4. Convertir en numpy array (le modèle attend un array)
232
+ return df.values
233
+
234
+
235
+ # TODO: Implémenter le chargement des artifacts sauvegardés
236
+ # def load_preprocessing_artifacts(run_id: str) -> dict:
237
+ # """
238
+ # Charge les encoders et scaler depuis MLflow.
239
+ #
240
+ # Returns:
241
+ # dict avec keys: 'onehot_encoder', 'ordinal_encoder', 'scaler'
242
+ # """
243
+ # pass
src/rate_limit.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module de rate limiting pour protéger l'API contre les abus.
4
+
5
+ Utilise SlowAPI pour limiter le nombre de requêtes par IP/utilisateur.
6
+ """
7
+ from slowapi import Limiter
8
+ from slowapi.util import get_remote_address
9
+
10
+ from src.config import get_settings
11
+
12
+ settings = get_settings()
13
+
14
+ # Créer le limiter avec stratégie par IP
15
+ limiter = Limiter(
16
+ key_func=get_remote_address,
17
+ default_limits=["100/minute"] if not settings.DEBUG else [],
18
+ storage_uri="memory://", # En production: utiliser Redis
19
+ strategy="fixed-window",
20
+ )
21
+
22
+
23
+ def get_rate_limit_key(request):
24
+ """
25
+ Fonction pour obtenir la clé de rate limiting.
26
+
27
+ En production, on pourrait utiliser l'API Key au lieu de l'IP.
28
+
29
+ Args:
30
+ request: Requête FastAPI.
31
+
32
+ Returns:
33
+ Clé unique pour identifier l'utilisateur.
34
+ """
35
+ # Priorité: API Key > IP
36
+ api_key = request.headers.get("X-API-Key")
37
+ if api_key:
38
+ return f"api_key:{api_key}"
39
+
40
+ return get_remote_address(request)
src/schemas.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Schémas Pydantic pour validation des données d'entrée de l'API.
4
+
5
+ Ces schémas correspondent aux colonnes brutes du dataset avant preprocessing,
6
+ permettant une validation stricte des inputs avec messages d'erreur clairs.
7
+ """
8
+ from enum import Enum
9
+ from typing import Literal
10
+
11
+ from pydantic import BaseModel, Field, field_validator
12
+
13
+
14
+ # Enums pour les valeurs catégorielles
15
+ class GenreEnum(str, Enum):
16
+ """Genre de l'employé."""
17
+
18
+ M = "M"
19
+ F = "F"
20
+
21
+
22
+ class StatutMaritalEnum(str, Enum):
23
+ """Statut marital de l'employé."""
24
+
25
+ CELIBATAIRE = "Célibataire"
26
+ MARIE = "Marié(e)"
27
+ DIVORCE = "Divorcé(e)"
28
+
29
+
30
+ class DepartementEnum(str, Enum):
31
+ """Département de l'employé."""
32
+
33
+ COMMERCIAL = "Commercial"
34
+ CONSULTING = "Consulting"
35
+ RESSOURCES_HUMAINES = "Ressources Humaines"
36
+
37
+
38
+ class DomaineEtudeEnum(str, Enum):
39
+ """Domaine d'études de l'employé."""
40
+
41
+ INFRA_CLOUD = "Infra & Cloud"
42
+ TRANSFORMATION_DIGITALE = "Transformation Digitale"
43
+ MARKETING = "Marketing"
44
+ ENTREPREUNARIAT = "Entrepreunariat"
45
+ RESSOURCES_HUMAINES = "Ressources Humaines"
46
+ AUTRE = "Autre"
47
+
48
+
49
+ class PosteEnum(str, Enum):
50
+ """Poste de l'employé."""
51
+
52
+ CADRE_COMMERCIAL = "Cadre Commercial"
53
+ ASSISTANT_DIRECTION = "Assistant de Direction"
54
+ CONSULTANT = "Consultant"
55
+ TECH_LEAD = "Tech Lead"
56
+ MANAGER = "Manager"
57
+ SENIOR_MANAGER = "Senior Manager"
58
+ REPRESENTANT_COMMERCIAL = "Représentant Commercial"
59
+ DIRECTEUR_TECHNIQUE = "Directeur Technique"
60
+ RESSOURCES_HUMAINES = "Ressources Humaines"
61
+
62
+
63
+ class FrequenceDeplacementEnum(str, Enum):
64
+ """Fréquence des déplacements professionnels."""
65
+
66
+ AUCUN = "Aucun"
67
+ OCCASIONNEL = "Occasionnel"
68
+ FREQUENT = "Frequent"
69
+
70
+
71
+ class EmployeeInput(BaseModel):
72
+ """
73
+ Schéma de validation pour les données d'entrée d'un employé.
74
+
75
+ Tous les champs correspondent aux colonnes brutes des 3 fichiers CSV
76
+ (sondage, eval, sirh) avant preprocessing.
77
+ """
78
+
79
+ # === Données SONDAGE ===
80
+ nombre_participation_pee: int = Field(
81
+ ..., ge=0, description="Nombre de participations au PEE"
82
+ )
83
+ nb_formations_suivies: int = Field(
84
+ ..., ge=0, le=10, description="Nombre de formations suivies"
85
+ )
86
+ nombre_employee_sous_responsabilite: int = Field(
87
+ ..., ge=0, description="Nombre d'employés sous responsabilité"
88
+ )
89
+ distance_domicile_travail: int = Field(
90
+ ..., ge=0, le=50, description="Distance domicile-travail en km"
91
+ )
92
+ niveau_education: int = Field(
93
+ ..., ge=1, le=5, description="Niveau d'éducation (1-5)"
94
+ )
95
+ domaine_etude: DomaineEtudeEnum = Field(..., description="Domaine d'études")
96
+ ayant_enfants: Literal["Y", "N"] = Field(..., description="A des enfants (Y/N)")
97
+ frequence_deplacement: FrequenceDeplacementEnum = Field(
98
+ ..., description="Fréquence des déplacements"
99
+ )
100
+ annees_depuis_la_derniere_promotion: int = Field(
101
+ ..., ge=0, description="Années depuis la dernière promotion"
102
+ )
103
+ annes_sous_responsable_actuel: int = Field(
104
+ ..., ge=0, description="Années sous le responsable actuel"
105
+ )
106
+
107
+ # === Données EVALUATION ===
108
+ satisfaction_employee_environnement: int = Field(
109
+ ..., ge=1, le=4, description="Satisfaction environnement (1-4)"
110
+ )
111
+ note_evaluation_precedente: int = Field(
112
+ ..., ge=1, le=5, description="Note évaluation précédente (1-5)"
113
+ )
114
+ niveau_hierarchique_poste: int = Field(
115
+ ..., ge=1, le=5, description="Niveau hiérarchique (1-5)"
116
+ )
117
+ satisfaction_employee_nature_travail: int = Field(
118
+ ..., ge=1, le=4, description="Satisfaction nature du travail (1-4)"
119
+ )
120
+ satisfaction_employee_equipe: int = Field(
121
+ ..., ge=1, le=4, description="Satisfaction équipe (1-4)"
122
+ )
123
+ satisfaction_employee_equilibre_pro_perso: int = Field(
124
+ ..., ge=1, le=4, description="Satisfaction équilibre pro/perso (1-4)"
125
+ )
126
+ note_evaluation_actuelle: int = Field(
127
+ ..., ge=1, le=5, description="Note évaluation actuelle (1-5)"
128
+ )
129
+ heure_supplementaires: Literal["Oui", "Non"] = Field(
130
+ ..., description="Fait des heures supplémentaires"
131
+ )
132
+ augementation_salaire_precedente: float = Field(
133
+ ..., ge=0, le=100, description="Augmentation salaire précédente (%)"
134
+ )
135
+
136
+ # === Données SIRH ===
137
+ age: int = Field(..., ge=18, le=70, description="Âge de l'employé")
138
+ genre: GenreEnum = Field(..., description="Genre")
139
+ revenu_mensuel: float = Field(..., ge=1000, description="Revenu mensuel (€)")
140
+ statut_marital: StatutMaritalEnum = Field(..., description="Statut marital")
141
+ departement: DepartementEnum = Field(..., description="Département")
142
+ poste: PosteEnum = Field(..., description="Intitulé du poste")
143
+ nombre_experiences_precedentes: int = Field(
144
+ ..., ge=0, description="Nombre d'expériences précédentes"
145
+ )
146
+ nombre_heures_travailless: int = Field(
147
+ ..., ge=35, le=80, description="Nombre d'heures travaillées par semaine"
148
+ )
149
+ annee_experience_totale: int = Field(
150
+ ..., ge=0, description="Années d'expérience totale"
151
+ )
152
+ annees_dans_l_entreprise: int = Field(
153
+ ..., ge=0, description="Années dans l'entreprise"
154
+ )
155
+ annees_dans_le_poste_actuel: int = Field(
156
+ ..., ge=0, description="Années dans le poste actuel"
157
+ )
158
+
159
+ @field_validator("augementation_salaire_precedente")
160
+ @classmethod
161
+ def validate_augmentation(cls, v: float) -> float:
162
+ """Nettoie le format de l'augmentation (enlève % si présent)."""
163
+ if isinstance(v, str):
164
+ v = float(v.replace(" %", "").replace("%", ""))
165
+ return v
166
+
167
+ class Config:
168
+ """Configuration Pydantic."""
169
+
170
+ json_schema_extra = {
171
+ "example": {
172
+ # Exemple basé sur la première ligne des CSV
173
+ "nombre_participation_pee": 0,
174
+ "nb_formations_suivies": 0,
175
+ "nombre_employee_sous_responsabilite": 1,
176
+ "distance_domicile_travail": 1,
177
+ "niveau_education": 2,
178
+ "domaine_etude": "Infra & Cloud",
179
+ "ayant_enfants": "Y",
180
+ "frequence_deplacement": "Occasionnel",
181
+ "annees_depuis_la_derniere_promotion": 0,
182
+ "annes_sous_responsable_actuel": 5,
183
+ "satisfaction_employee_environnement": 2,
184
+ "note_evaluation_precedente": 3,
185
+ "niveau_hierarchique_poste": 2,
186
+ "satisfaction_employee_nature_travail": 4,
187
+ "satisfaction_employee_equipe": 1,
188
+ "satisfaction_employee_equilibre_pro_perso": 1,
189
+ "note_evaluation_actuelle": 3,
190
+ "heure_supplementaires": "Oui",
191
+ "augementation_salaire_precedente": 11.0,
192
+ "age": 41,
193
+ "genre": "F",
194
+ "revenu_mensuel": 5993.0,
195
+ "statut_marital": "Célibataire",
196
+ "departement": "Commercial",
197
+ "poste": "Cadre Commercial",
198
+ "nombre_experiences_precedentes": 8,
199
+ "nombre_heures_travailless": 80,
200
+ "annee_experience_totale": 8,
201
+ "annees_dans_l_entreprise": 6,
202
+ "annees_dans_le_poste_actuel": 4,
203
+ }
204
+ }
205
+
206
+
207
+ class PredictionOutput(BaseModel):
208
+ """Schéma de sortie pour les prédictions."""
209
+
210
+ prediction: int = Field(..., description="Classe prédite (0=reste, 1=part)")
211
+ probability_0: float = Field(
212
+ ..., ge=0, le=1, description="Probabilité de rester (classe 0)"
213
+ )
214
+ probability_1: float = Field(
215
+ ..., ge=0, le=1, description="Probabilité de partir (classe 1)"
216
+ )
217
+ risk_level: str = Field(..., description="Niveau de risque (Low/Medium/High)")
218
+
219
+ class Config:
220
+ """Configuration Pydantic."""
221
+
222
+ json_schema_extra = {
223
+ "example": {
224
+ "prediction": 1,
225
+ "probability_0": 0.35,
226
+ "probability_1": 0.65,
227
+ "risk_level": "High",
228
+ }
229
+ }
230
+
231
+
232
+ class HealthCheck(BaseModel):
233
+ """Schéma pour le endpoint health check."""
234
+
235
+ status: str = Field(..., description="Status de l'API")
236
+ model_loaded: bool = Field(..., description="Modèle chargé ou non")
237
+ model_type: str = Field(..., description="Type du modèle")
238
+ version: str = Field(..., description="Version de l'API")
239
+
240
+ class Config:
241
+ """Configuration Pydantic."""
242
+
243
+ json_schema_extra = {
244
+ "example": {
245
+ "status": "healthy",
246
+ "model_loaded": True,
247
+ "model_type": "Pipeline",
248
+ "version": "1.0.0",
249
+ }
250
+ }