Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- .flake8 +8 -2
- .gitignore +9 -0
- docs/mlflow_guide.md +412 -0
- examples/01_find_best_model.py +106 -0
- examples/02_compare_models.py +165 -0
- examples/03_model_registry.py +205 -0
- examples/README.md +111 -0
- ml_model/preprocess.py +12 -14
- ml_model/train_model.py +42 -15
- pyproject.toml +1 -1
- requirements.txt +9 -8
- scripts/fix_lint.py +68 -0
- tests/conftest.py +8 -0
- tests/test_mlflow_local.py +54 -0
- tests/test_mlflow_quick.py +130 -0
.flake8
CHANGED
|
@@ -10,6 +10,12 @@ exclude =
|
|
| 10 |
.cache,
|
| 11 |
.eggs,
|
| 12 |
build,
|
| 13 |
-
dist
|
|
|
|
| 14 |
# Max line pour compat Black (default 88 vs PEP8 79)
|
| 15 |
-
max-line-length = 88
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
.cache,
|
| 11 |
.eggs,
|
| 12 |
build,
|
| 13 |
+
dist,
|
| 14 |
+
mlruns
|
| 15 |
# Max line pour compat Black (default 88 vs PEP8 79)
|
| 16 |
+
max-line-length = 88
|
| 17 |
+
|
| 18 |
+
# Ignorer certains warnings pour les scripts d'exemple (non-critique)
|
| 19 |
+
per-file-ignores =
|
| 20 |
+
examples/*.py:F541,E722,F841
|
| 21 |
+
tests/test_mlflow_*.py:F401,E402,F811,F541
|
.gitignore
CHANGED
|
@@ -35,3 +35,12 @@ Thumbs.db
|
|
| 35 |
secrets.json
|
| 36 |
data/raw/ # Pour datasets volumineux en data science (OC_P5)
|
| 37 |
notebooks/*.ipynb_checkpoints/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
secrets.json
|
| 36 |
data/raw/ # Pour datasets volumineux en data science (OC_P5)
|
| 37 |
notebooks/*.ipynb_checkpoints/
|
| 38 |
+
|
| 39 |
+
# MLflow
|
| 40 |
+
mlflow.db
|
| 41 |
+
mlflow.db-shm
|
| 42 |
+
mlflow.db-wal
|
| 43 |
+
mlruns/
|
| 44 |
+
mlflow_ui.log
|
| 45 |
+
mlflow_comparison.png
|
| 46 |
+
nohup.out
|
docs/mlflow_guide.md
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 Guide MLflow - Projet Employee Turnover
|
| 2 |
+
|
| 3 |
+
## 📋 Table des matières
|
| 4 |
+
1. [Workflow complet MLflow](#workflow-complet)
|
| 5 |
+
2. [Comparer plusieurs runs](#comparer-runs)
|
| 6 |
+
3. [Trouver le meilleur modèle](#meilleur-modèle)
|
| 7 |
+
4. [Charger un modèle pour l'API](#api-integration)
|
| 8 |
+
5. [Best Practices](#best-practices)
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## 1. Workflow complet MLflow
|
| 13 |
+
|
| 14 |
+
### 🎯 Concept clé
|
| 15 |
+
MLflow suit ce workflow :
|
| 16 |
+
```
|
| 17 |
+
Entraînement → Tracking → Registry → Déploiement → API
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### Architecture actuelle du projet
|
| 21 |
+
```
|
| 22 |
+
train_model.py
|
| 23 |
+
↓ (log params/metrics/model)
|
| 24 |
+
mlflow.db (SQLite)
|
| 25 |
+
↓ (query)
|
| 26 |
+
MLflow UI (http://localhost:5000)
|
| 27 |
+
↓ (select best model)
|
| 28 |
+
Model Registry (XGBoost_Employee_Turnover)
|
| 29 |
+
↓ (load)
|
| 30 |
+
API FastAPI/Flask
|
| 31 |
+
↓ (serve)
|
| 32 |
+
Prédictions
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## 2. Comparer plusieurs runs
|
| 38 |
+
|
| 39 |
+
### Scénario : Tester différents hyperparamètres
|
| 40 |
+
|
| 41 |
+
**Exemple : Tester 3 configurations différentes**
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
# tests/test_multiple_runs.py
|
| 45 |
+
import mlflow
|
| 46 |
+
from ml_model.preprocess import preprocess_data
|
| 47 |
+
from ml_model.train_model import train_model
|
| 48 |
+
|
| 49 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 50 |
+
mlflow.set_experiment("Hyperparameter_Tuning") # Créer une expérience dédiée
|
| 51 |
+
|
| 52 |
+
# Chemins des données
|
| 53 |
+
data_paths = {
|
| 54 |
+
"sondage_path": "data/extrait_sondage.csv",
|
| 55 |
+
"eval_path": "data/extrait_eval.csv",
|
| 56 |
+
"sirh_path": "data/extrait_sirh.csv",
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
# Préparer les données une seule fois
|
| 60 |
+
X, y, scaler, encoders = preprocess_data(data_paths)
|
| 61 |
+
|
| 62 |
+
# Tester 3 configurations
|
| 63 |
+
configs = [
|
| 64 |
+
{"name": "Baseline", "n_iter": 100, "cv": 3},
|
| 65 |
+
{"name": "Intensive", "n_iter": 500, "cv": 5},
|
| 66 |
+
{"name": "Quick", "n_iter": 50, "cv": 3},
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
for config in configs:
|
| 70 |
+
with mlflow.start_run(run_name=config["name"]):
|
| 71 |
+
# Log la configuration testée
|
| 72 |
+
mlflow.log_param("config_name", config["name"])
|
| 73 |
+
mlflow.log_param("n_iter", config["n_iter"])
|
| 74 |
+
mlflow.log_param("cv", config["cv"])
|
| 75 |
+
|
| 76 |
+
# Entraîner (modifier train_model pour accepter n_iter/cv)
|
| 77 |
+
model, params, cv_f1 = train_model(X, y)
|
| 78 |
+
|
| 79 |
+
print(f"✅ {config['name']}: F1={cv_f1:.4f}")
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
**Résultat dans MLflow UI** :
|
| 83 |
+
- Va sur **Experiments** → **Hyperparameter_Tuning**
|
| 84 |
+
- Tu verras 3 runs avec leurs métriques côte à côte
|
| 85 |
+
- Clique sur **"Compare"** pour voir un tableau comparatif
|
| 86 |
+
|
| 87 |
+
---
|
| 88 |
+
|
| 89 |
+
## 3. Trouver le meilleur modèle
|
| 90 |
+
|
| 91 |
+
### Option A : Via l'API MLflow (recommandé pour l'API)
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
# api/get_best_model.py
|
| 95 |
+
import mlflow
|
| 96 |
+
from mlflow.tracking import MlflowClient
|
| 97 |
+
|
| 98 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 99 |
+
client = MlflowClient()
|
| 100 |
+
|
| 101 |
+
def get_best_model_from_experiment(experiment_name="Default", metric="cv_f1"):
|
| 102 |
+
"""
|
| 103 |
+
Trouve le meilleur modèle d'une expérience basé sur une métrique.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
experiment_name: Nom de l'expérience MLflow
|
| 107 |
+
metric: Métrique à optimiser (cv_f1, test_f1, etc.)
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
run_id du meilleur modèle
|
| 111 |
+
"""
|
| 112 |
+
# Récupérer l'expérience
|
| 113 |
+
experiment = client.get_experiment_by_name(experiment_name)
|
| 114 |
+
if not experiment:
|
| 115 |
+
raise ValueError(f"Expérience '{experiment_name}' introuvable")
|
| 116 |
+
|
| 117 |
+
# Rechercher tous les runs de l'expérience
|
| 118 |
+
runs = client.search_runs(
|
| 119 |
+
experiment_ids=[experiment.experiment_id],
|
| 120 |
+
order_by=[f"metrics.{metric} DESC"], # Trier par métrique décroissante
|
| 121 |
+
max_results=1 # Prendre seulement le meilleur
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if not runs:
|
| 125 |
+
raise ValueError(f"Aucun run trouvé dans l'expérience '{experiment_name}'")
|
| 126 |
+
|
| 127 |
+
best_run = runs[0]
|
| 128 |
+
print(f"🏆 Meilleur modèle trouvé:")
|
| 129 |
+
print(f" Run ID: {best_run.info.run_id}")
|
| 130 |
+
print(f" {metric}: {best_run.data.metrics.get(metric, 'N/A')}")
|
| 131 |
+
print(f" Date: {best_run.info.start_time}")
|
| 132 |
+
|
| 133 |
+
return best_run.info.run_id
|
| 134 |
+
|
| 135 |
+
# Exemple d'utilisation
|
| 136 |
+
if __name__ == "__main__":
|
| 137 |
+
best_run_id = get_best_model_from_experiment("Default", "cv_f1")
|
| 138 |
+
|
| 139 |
+
# Charger le modèle
|
| 140 |
+
model_uri = f"runs:/{best_run_id}/model"
|
| 141 |
+
model = mlflow.sklearn.load_model(model_uri)
|
| 142 |
+
print(f"✅ Modèle chargé : {type(model)}")
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Option B : Via le Model Registry (pour production)
|
| 146 |
+
|
| 147 |
+
```python
|
| 148 |
+
# api/load_production_model.py
|
| 149 |
+
import mlflow
|
| 150 |
+
|
| 151 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 152 |
+
|
| 153 |
+
def load_production_model(model_name="XGBoost_Employee_Turnover", stage="Production"):
|
| 154 |
+
"""
|
| 155 |
+
Charge le modèle en production depuis le Model Registry.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
model_name: Nom du modèle dans le Registry
|
| 159 |
+
stage: Stage du modèle ("Staging", "Production", "Archived")
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Modèle chargé
|
| 163 |
+
"""
|
| 164 |
+
model_uri = f"models:/{model_name}/{stage}"
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
model = mlflow.sklearn.load_model(model_uri)
|
| 168 |
+
print(f"✅ Modèle '{model_name}' ({stage}) chargé")
|
| 169 |
+
return model
|
| 170 |
+
except Exception as e:
|
| 171 |
+
print(f"⚠️ Erreur : {e}")
|
| 172 |
+
print(f"💡 Astuce : Promouvoir une version en '{stage}' dans MLflow UI")
|
| 173 |
+
|
| 174 |
+
# Fallback : Charger la dernière version
|
| 175 |
+
model_uri = f"models:/{model_name}/latest"
|
| 176 |
+
model = mlflow.sklearn.load_model(model_uri)
|
| 177 |
+
print(f"✅ Fallback : Dernière version chargée")
|
| 178 |
+
return model
|
| 179 |
+
|
| 180 |
+
# Utilisation
|
| 181 |
+
if __name__ == "__main__":
|
| 182 |
+
model = load_production_model()
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
---
|
| 186 |
+
|
| 187 |
+
## 4. API Integration - Exemple complet
|
| 188 |
+
|
| 189 |
+
### Créer une API Flask/FastAPI avec MLflow
|
| 190 |
+
|
| 191 |
+
```python
|
| 192 |
+
# api/app.py
|
| 193 |
+
from fastapi import FastAPI, HTTPException
|
| 194 |
+
from pydantic import BaseModel
|
| 195 |
+
import mlflow
|
| 196 |
+
import pandas as pd
|
| 197 |
+
import numpy as np
|
| 198 |
+
|
| 199 |
+
# Configuration
|
| 200 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 201 |
+
app = FastAPI(title="Employee Turnover Prediction API")
|
| 202 |
+
|
| 203 |
+
# Charger le modèle au démarrage
|
| 204 |
+
MODEL_NAME = "XGBoost_Employee_Turnover"
|
| 205 |
+
model = None
|
| 206 |
+
|
| 207 |
+
@app.on_event("startup")
|
| 208 |
+
def load_model():
|
| 209 |
+
global model
|
| 210 |
+
try:
|
| 211 |
+
# Charger le dernier modèle du Registry
|
| 212 |
+
model_uri = f"models:/{MODEL_NAME}/latest"
|
| 213 |
+
model = mlflow.sklearn.load_model(model_uri)
|
| 214 |
+
print(f"✅ Modèle chargé : {MODEL_NAME}")
|
| 215 |
+
except Exception as e:
|
| 216 |
+
print(f"❌ Erreur chargement modèle : {e}")
|
| 217 |
+
raise
|
| 218 |
+
|
| 219 |
+
# Schéma de requête
|
| 220 |
+
class PredictionRequest(BaseModel):
|
| 221 |
+
features: list[float] # Liste de 50 features (après prétraitement)
|
| 222 |
+
|
| 223 |
+
class Config:
|
| 224 |
+
json_schema_extra = {
|
| 225 |
+
"example": {
|
| 226 |
+
"features": [0.5, 1.2, -0.3, 0.8] + [0.0] * 46 # 50 features
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
class PredictionResponse(BaseModel):
|
| 231 |
+
prediction: int # 0 ou 1
|
| 232 |
+
probability: float # Probabilité de départ (classe 1)
|
| 233 |
+
model_version: str
|
| 234 |
+
|
| 235 |
+
# Endpoint de prédiction
|
| 236 |
+
@app.post("/predict", response_model=PredictionResponse)
|
| 237 |
+
def predict(request: PredictionRequest):
|
| 238 |
+
"""
|
| 239 |
+
Prédit si un employé va quitter l'entreprise.
|
| 240 |
+
|
| 241 |
+
- **features**: Liste de 50 features numériques (après prétraitement)
|
| 242 |
+
- Retourne la prédiction (0=reste, 1=part) et la probabilité
|
| 243 |
+
"""
|
| 244 |
+
if model is None:
|
| 245 |
+
raise HTTPException(status_code=503, detail="Modèle non chargé")
|
| 246 |
+
|
| 247 |
+
try:
|
| 248 |
+
# Convertir en DataFrame
|
| 249 |
+
X = pd.DataFrame([request.features])
|
| 250 |
+
|
| 251 |
+
# Prédiction
|
| 252 |
+
prediction = int(model.predict(X)[0])
|
| 253 |
+
probability = float(model.predict_proba(X)[0][1])
|
| 254 |
+
|
| 255 |
+
return PredictionResponse(
|
| 256 |
+
prediction=prediction,
|
| 257 |
+
probability=round(probability, 4),
|
| 258 |
+
model_version=MODEL_NAME
|
| 259 |
+
)
|
| 260 |
+
except Exception as e:
|
| 261 |
+
raise HTTPException(status_code=400, detail=f"Erreur prédiction : {str(e)}")
|
| 262 |
+
|
| 263 |
+
# Endpoint de santé
|
| 264 |
+
@app.get("/health")
|
| 265 |
+
def health():
|
| 266 |
+
return {
|
| 267 |
+
"status": "ok",
|
| 268 |
+
"model_loaded": model is not None,
|
| 269 |
+
"model_name": MODEL_NAME
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
# Endpoint pour lister les modèles disponibles
|
| 273 |
+
@app.get("/models")
|
| 274 |
+
def list_models():
|
| 275 |
+
from mlflow.tracking import MlflowClient
|
| 276 |
+
client = MlflowClient()
|
| 277 |
+
|
| 278 |
+
models = []
|
| 279 |
+
for rm in client.search_registered_models():
|
| 280 |
+
latest_versions = rm.latest_versions
|
| 281 |
+
models.append({
|
| 282 |
+
"name": rm.name,
|
| 283 |
+
"versions": len(latest_versions),
|
| 284 |
+
"latest_version": latest_versions[0].version if latest_versions else None
|
| 285 |
+
})
|
| 286 |
+
|
| 287 |
+
return {"models": models}
|
| 288 |
+
|
| 289 |
+
# Lancer avec : uvicorn api.app:app --reload
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
**Tester l'API** :
|
| 293 |
+
```bash
|
| 294 |
+
# Installer FastAPI
|
| 295 |
+
pip install fastapi uvicorn
|
| 296 |
+
|
| 297 |
+
# Lancer le serveur
|
| 298 |
+
uvicorn api.app:app --reload --port 8000
|
| 299 |
+
|
| 300 |
+
# Tester
|
| 301 |
+
curl -X POST http://localhost:8000/predict \
|
| 302 |
+
-H "Content-Type: application/json" \
|
| 303 |
+
-d '{"features": [0.5, 1.2, -0.3] + [0.0] * 47}'
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
---
|
| 307 |
+
|
| 308 |
+
## 5. Best Practices
|
| 309 |
+
|
| 310 |
+
### ✅ Stratégie de versioning des modèles
|
| 311 |
+
|
| 312 |
+
```python
|
| 313 |
+
# Workflow recommandé
|
| 314 |
+
# 1. Entraîner plusieurs modèles → Experiment "Development"
|
| 315 |
+
# 2. Sélectionner le meilleur → Promouvoir en "Staging"
|
| 316 |
+
# 3. Valider en staging → Promouvoir en "Production"
|
| 317 |
+
# 4. API charge toujours "Production"
|
| 318 |
+
|
| 319 |
+
from mlflow.tracking import MlflowClient
|
| 320 |
+
|
| 321 |
+
client = MlflowClient()
|
| 322 |
+
model_name = "XGBoost_Employee_Turnover"
|
| 323 |
+
|
| 324 |
+
# Promouvoir version 2 en Production
|
| 325 |
+
client.transition_model_version_stage(
|
| 326 |
+
name=model_name,
|
| 327 |
+
version=2,
|
| 328 |
+
stage="Production"
|
| 329 |
+
)
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
### 📊 Logging avancé
|
| 333 |
+
|
| 334 |
+
```python
|
| 335 |
+
# Dans train_model.py, ajouter plus de contexte
|
| 336 |
+
with mlflow.start_run():
|
| 337 |
+
# Log dataset info
|
| 338 |
+
mlflow.log_param("n_samples", len(X))
|
| 339 |
+
mlflow.log_param("n_features", X.shape[1])
|
| 340 |
+
mlflow.log_param("class_imbalance_ratio", sum(y==0)/sum(y==1))
|
| 341 |
+
|
| 342 |
+
# Log artifacts (graphiques, etc.)
|
| 343 |
+
import matplotlib.pyplot as plt
|
| 344 |
+
|
| 345 |
+
# Confusion matrix plot
|
| 346 |
+
plt.figure()
|
| 347 |
+
# ... plot code ...
|
| 348 |
+
plt.savefig("confusion_matrix.png")
|
| 349 |
+
mlflow.log_artifact("confusion_matrix.png")
|
| 350 |
+
|
| 351 |
+
# Log code version
|
| 352 |
+
import subprocess
|
| 353 |
+
git_commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode()
|
| 354 |
+
mlflow.set_tag("git_commit", git_commit)
|
| 355 |
+
```
|
| 356 |
+
|
| 357 |
+
### 🔄 Retraining workflow
|
| 358 |
+
|
| 359 |
+
```python
|
| 360 |
+
# scripts/retrain_model.py
|
| 361 |
+
import mlflow
|
| 362 |
+
from datetime import datetime
|
| 363 |
+
|
| 364 |
+
def retrain_and_compare():
|
| 365 |
+
"""Entraîne un nouveau modèle et le compare à la production."""
|
| 366 |
+
|
| 367 |
+
# 1. Charger le modèle en production
|
| 368 |
+
prod_model = mlflow.sklearn.load_model("models:/XGBoost_Employee_Turnover/Production")
|
| 369 |
+
|
| 370 |
+
# 2. Entraîner nouveau modèle
|
| 371 |
+
X, y, _, _ = preprocess_data(data_paths)
|
| 372 |
+
new_model, params, new_f1 = train_model(X, y)
|
| 373 |
+
|
| 374 |
+
# 3. Comparer les performances
|
| 375 |
+
from sklearn.model_selection import cross_val_score
|
| 376 |
+
prod_f1 = cross_val_score(prod_model, X, y, cv=5, scoring='f1').mean()
|
| 377 |
+
|
| 378 |
+
print(f"Production F1: {prod_f1:.4f}")
|
| 379 |
+
print(f"New model F1: {new_f1:.4f}")
|
| 380 |
+
|
| 381 |
+
# 4. Si meilleur, promouvoir automatiquement
|
| 382 |
+
if new_f1 > prod_f1:
|
| 383 |
+
print("✅ Nouveau modèle meilleur ! Promotion en Staging...")
|
| 384 |
+
# Enregistrer dans Registry
|
| 385 |
+
# ... code de promotion ...
|
| 386 |
+
else:
|
| 387 |
+
print("⚠️ Nouveau modèle moins bon, conservation du modèle actuel")
|
| 388 |
+
```
|
| 389 |
+
|
| 390 |
+
---
|
| 391 |
+
|
| 392 |
+
## 📚 Ressources
|
| 393 |
+
|
| 394 |
+
- **MLflow Docs**: https://mlflow.org/docs/latest/index.html
|
| 395 |
+
- **Model Registry**: https://mlflow.org/docs/latest/model-registry.html
|
| 396 |
+
- **Python API**: https://mlflow.org/docs/latest/python_api/index.html
|
| 397 |
+
|
| 398 |
+
---
|
| 399 |
+
|
| 400 |
+
## 🎯 Prochaines étapes pour ton projet
|
| 401 |
+
|
| 402 |
+
1. ✅ **MLflow configuré** - Tracking local avec SQLite
|
| 403 |
+
2. ✅ **Modèle enregistré** - XGBoost_Employee_Turnover v1
|
| 404 |
+
3. 🔄 **TODO: Créer l'API** - FastAPI avec chargement du modèle
|
| 405 |
+
4. 🔄 **TODO: Tester comparaison** - Multiple runs avec hyperparams différents
|
| 406 |
+
5. 🔄 **TODO: CI/CD** - Auto-retraining et déploiement
|
| 407 |
+
|
| 408 |
+
**Commande pour démarrer l'API** :
|
| 409 |
+
```bash
|
| 410 |
+
# Créer api/app.py avec le code ci-dessus
|
| 411 |
+
uvicorn api.app:app --reload --port 8000
|
| 412 |
+
```
|
examples/01_find_best_model.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Exemple 1 : Trouver le meilleur modèle dans MLflow
|
| 4 |
+
Usage: python examples/01_find_best_model.py
|
| 5 |
+
"""
|
| 6 |
+
import mlflow
|
| 7 |
+
from mlflow.tracking import MlflowClient
|
| 8 |
+
|
| 9 |
+
# Configuration
|
| 10 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 11 |
+
client = MlflowClient()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def find_best_model(experiment_name="Default", metric="cv_f1"):
|
| 15 |
+
"""Trouve le meilleur modèle basé sur une métrique."""
|
| 16 |
+
|
| 17 |
+
print(f"🔍 Recherche du meilleur modèle dans '{experiment_name}'...")
|
| 18 |
+
print(f"📊 Métrique d'optimisation : {metric}\n")
|
| 19 |
+
|
| 20 |
+
# Récupérer l'expérience
|
| 21 |
+
experiment = client.get_experiment_by_name(experiment_name)
|
| 22 |
+
if not experiment:
|
| 23 |
+
print(f"❌ Expérience '{experiment_name}' introuvable")
|
| 24 |
+
return None
|
| 25 |
+
|
| 26 |
+
# Rechercher tous les runs
|
| 27 |
+
runs = client.search_runs(
|
| 28 |
+
experiment_ids=[experiment.experiment_id],
|
| 29 |
+
filter_string="", # Pas de filtre
|
| 30 |
+
order_by=[f"metrics.{metric} DESC"],
|
| 31 |
+
max_results=5, # Top 5
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
if not runs:
|
| 35 |
+
print(f"❌ Aucun run trouvé")
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
print(f"📈 Top 5 des modèles :\n")
|
| 39 |
+
print(f"{'Rank':<6} {'Run ID':<35} {metric:<12} {'Date':<20}")
|
| 40 |
+
print("-" * 75)
|
| 41 |
+
|
| 42 |
+
for i, run in enumerate(runs, 1):
|
| 43 |
+
metric_value = run.data.metrics.get(metric, 0.0)
|
| 44 |
+
timestamp = run.info.start_time
|
| 45 |
+
from datetime import datetime
|
| 46 |
+
|
| 47 |
+
date_str = datetime.fromtimestamp(timestamp / 1000).strftime("%Y-%m-%d %H:%M")
|
| 48 |
+
|
| 49 |
+
print(f"{i:<6} {run.info.run_id:<35} {metric_value:<12.4f} {date_str:<20}")
|
| 50 |
+
|
| 51 |
+
# Meilleur modèle
|
| 52 |
+
best_run = runs[0]
|
| 53 |
+
best_metric = best_run.data.metrics.get(metric, 0.0)
|
| 54 |
+
|
| 55 |
+
print(f"\n🏆 Meilleur modèle :")
|
| 56 |
+
print(f" Run ID : {best_run.info.run_id}")
|
| 57 |
+
print(f" {metric:<10}: {best_metric:.4f}")
|
| 58 |
+
print(f" Status : {best_run.info.status}")
|
| 59 |
+
|
| 60 |
+
# Afficher les hyperparamètres
|
| 61 |
+
print(f"\n⚙️ Hyperparamètres :")
|
| 62 |
+
for key, value in best_run.data.params.items():
|
| 63 |
+
print(f" {key:<25} : {value}")
|
| 64 |
+
|
| 65 |
+
# Afficher toutes les métriques
|
| 66 |
+
print(f"\n📊 Métriques :")
|
| 67 |
+
for key, value in best_run.data.metrics.items():
|
| 68 |
+
print(f" {key:<25} : {value:.4f}")
|
| 69 |
+
|
| 70 |
+
return best_run.info.run_id
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_best_model(run_id):
|
| 74 |
+
"""Charge le modèle à partir d'un run_id."""
|
| 75 |
+
print(f"\n📦 Chargement du modèle...")
|
| 76 |
+
|
| 77 |
+
model_uri = f"runs:/{run_id}/model"
|
| 78 |
+
try:
|
| 79 |
+
model = mlflow.sklearn.load_model(model_uri)
|
| 80 |
+
print(f"✅ Modèle chargé avec succès")
|
| 81 |
+
print(f" Type : {type(model).__name__}")
|
| 82 |
+
|
| 83 |
+
# Afficher la pipeline si c'est une Pipeline
|
| 84 |
+
if hasattr(model, "steps"):
|
| 85 |
+
print(f" Pipeline steps :")
|
| 86 |
+
for name, step in model.steps:
|
| 87 |
+
print(f" - {name}: {type(step).__name__}")
|
| 88 |
+
|
| 89 |
+
return model
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f"❌ Erreur lors du chargement : {e}")
|
| 92 |
+
return None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
# Trouver le meilleur modèle
|
| 97 |
+
best_run_id = find_best_model("Default", "cv_f1")
|
| 98 |
+
|
| 99 |
+
if best_run_id:
|
| 100 |
+
# Charger le modèle
|
| 101 |
+
model = load_best_model(best_run_id)
|
| 102 |
+
|
| 103 |
+
if model:
|
| 104 |
+
print(f"\n💡 Pour utiliser ce modèle dans ton API :")
|
| 105 |
+
print(f" model_uri = 'runs:/{best_run_id}/model'")
|
| 106 |
+
print(f" model = mlflow.sklearn.load_model(model_uri)")
|
examples/02_compare_models.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Exemple 2 : Comparer plusieurs modèles avec différents hyperparamètres
|
| 4 |
+
Usage: python examples/02_compare_models.py
|
| 5 |
+
"""
|
| 6 |
+
import mlflow
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from mlflow.tracking import MlflowClient
|
| 9 |
+
|
| 10 |
+
# Configuration
|
| 11 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 12 |
+
client = MlflowClient()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def compare_all_runs(experiment_name="Default"):
|
| 16 |
+
"""Compare tous les runs d'une expérience."""
|
| 17 |
+
|
| 18 |
+
print(f"📊 Comparaison de tous les runs dans '{experiment_name}'\n")
|
| 19 |
+
|
| 20 |
+
# Récupérer l'expérience
|
| 21 |
+
experiment = client.get_experiment_by_name(experiment_name)
|
| 22 |
+
if not experiment:
|
| 23 |
+
print(f"❌ Expérience '{experiment_name}' introuvable")
|
| 24 |
+
return None
|
| 25 |
+
|
| 26 |
+
# Récupérer tous les runs
|
| 27 |
+
runs = client.search_runs(
|
| 28 |
+
experiment_ids=[experiment.experiment_id], order_by=["start_time DESC"]
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
if not runs:
|
| 32 |
+
print(f"❌ Aucun run trouvé")
|
| 33 |
+
return None
|
| 34 |
+
|
| 35 |
+
print(f"✅ {len(runs)} run(s) trouvé(s)\n")
|
| 36 |
+
|
| 37 |
+
# Créer un DataFrame pour comparaison
|
| 38 |
+
data = []
|
| 39 |
+
for run in runs:
|
| 40 |
+
from datetime import datetime
|
| 41 |
+
|
| 42 |
+
row = {
|
| 43 |
+
"run_id": run.info.run_id[:8], # 8 premiers caractères
|
| 44 |
+
"status": run.info.status,
|
| 45 |
+
"start_time": datetime.fromtimestamp(run.info.start_time / 1000).strftime(
|
| 46 |
+
"%Y-%m-%d %H:%M"
|
| 47 |
+
),
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
# Ajouter les métriques
|
| 51 |
+
for metric_name in ["cv_f1", "test_precision", "test_recall", "test_f1"]:
|
| 52 |
+
row[metric_name] = run.data.metrics.get(metric_name, None)
|
| 53 |
+
|
| 54 |
+
# Ajouter quelques hyperparamètres importants
|
| 55 |
+
for param_name in ["clf__n_estimators", "clf__max_depth", "clf__learning_rate"]:
|
| 56 |
+
param_value = run.data.params.get(param_name, None)
|
| 57 |
+
if param_value:
|
| 58 |
+
try:
|
| 59 |
+
row[param_name] = (
|
| 60 |
+
float(param_value)
|
| 61 |
+
if "." in str(param_value)
|
| 62 |
+
else int(param_value)
|
| 63 |
+
)
|
| 64 |
+
except:
|
| 65 |
+
row[param_name] = param_value
|
| 66 |
+
|
| 67 |
+
data.append(row)
|
| 68 |
+
|
| 69 |
+
# Créer DataFrame
|
| 70 |
+
df = pd.DataFrame(data)
|
| 71 |
+
|
| 72 |
+
# Afficher le tableau
|
| 73 |
+
print("📈 Comparaison des modèles :")
|
| 74 |
+
print("=" * 120)
|
| 75 |
+
pd.set_option("display.max_columns", None)
|
| 76 |
+
pd.set_option("display.width", 120)
|
| 77 |
+
print(df.to_string(index=False))
|
| 78 |
+
print("=" * 120)
|
| 79 |
+
|
| 80 |
+
# Statistiques
|
| 81 |
+
print(f"\n📊 Statistiques :")
|
| 82 |
+
if "cv_f1" in df.columns:
|
| 83 |
+
print(f" CV F1 moyen : {df['cv_f1'].mean():.4f}")
|
| 84 |
+
print(f" CV F1 max : {df['cv_f1'].max():.4f}")
|
| 85 |
+
print(f" CV F1 min : {df['cv_f1'].min():.4f}")
|
| 86 |
+
print(f" Écart-type : {df['cv_f1'].std():.4f}")
|
| 87 |
+
|
| 88 |
+
# Meilleur run
|
| 89 |
+
if "cv_f1" in df.columns:
|
| 90 |
+
best_idx = df["cv_f1"].idxmax()
|
| 91 |
+
best_run = df.iloc[best_idx]
|
| 92 |
+
print(f"\n🏆 Meilleur run : {best_run['run_id']}")
|
| 93 |
+
print(f" CV F1 : {best_run['cv_f1']:.4f}")
|
| 94 |
+
|
| 95 |
+
return df
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def plot_metrics_comparison(experiment_name="Default"):
|
| 99 |
+
"""Génère un graphique de comparaison (nécessite matplotlib)."""
|
| 100 |
+
try:
|
| 101 |
+
import matplotlib.pyplot as plt
|
| 102 |
+
|
| 103 |
+
experiment = client.get_experiment_by_name(experiment_name)
|
| 104 |
+
if not experiment:
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
runs = client.search_runs(
|
| 108 |
+
experiment_ids=[experiment.experiment_id], order_by=["start_time ASC"]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Extraire les données
|
| 112 |
+
run_names = [f"Run {i + 1}" for i in range(len(runs))]
|
| 113 |
+
cv_f1_scores = [run.data.metrics.get("cv_f1", 0) for run in runs]
|
| 114 |
+
test_f1_scores = [run.data.metrics.get("test_f1", 0) for run in runs]
|
| 115 |
+
|
| 116 |
+
# Créer le graphique
|
| 117 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 118 |
+
|
| 119 |
+
x = range(len(runs))
|
| 120 |
+
width = 0.35
|
| 121 |
+
|
| 122 |
+
ax.bar(
|
| 123 |
+
[i - width / 2 for i in x], cv_f1_scores, width, label="CV F1", alpha=0.8
|
| 124 |
+
)
|
| 125 |
+
ax.bar(
|
| 126 |
+
[i + width / 2 for i in x],
|
| 127 |
+
test_f1_scores,
|
| 128 |
+
width,
|
| 129 |
+
label="Test F1",
|
| 130 |
+
alpha=0.8,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
ax.set_xlabel("Runs")
|
| 134 |
+
ax.set_ylabel("F1 Score")
|
| 135 |
+
ax.set_title(f"Comparaison des F1 scores - Expérience: {experiment_name}")
|
| 136 |
+
ax.set_xticks(x)
|
| 137 |
+
ax.set_xticklabels(run_names, rotation=45)
|
| 138 |
+
ax.legend()
|
| 139 |
+
ax.grid(axis="y", alpha=0.3)
|
| 140 |
+
|
| 141 |
+
plt.tight_layout()
|
| 142 |
+
plt.savefig("mlflow_comparison.png", dpi=300, bbox_inches="tight")
|
| 143 |
+
print(f"\n📊 Graphique sauvegardé : mlflow_comparison.png")
|
| 144 |
+
|
| 145 |
+
except ImportError:
|
| 146 |
+
print("\n⚠️ matplotlib non installé, graphique non généré")
|
| 147 |
+
print(" Installation : pip install matplotlib")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
# Comparer tous les runs
|
| 152 |
+
df = compare_all_runs("Default")
|
| 153 |
+
|
| 154 |
+
if df is not None:
|
| 155 |
+
# Générer un graphique
|
| 156 |
+
plot_metrics_comparison("Default")
|
| 157 |
+
|
| 158 |
+
print(f"\n💡 Conseils :")
|
| 159 |
+
print(f" - Les runs avec CV F1 élevé sont de meilleurs candidats")
|
| 160 |
+
print(
|
| 161 |
+
f" - Vérifier que test_f1 est cohérent avec cv_f1 (pas de surapprentissage)"
|
| 162 |
+
)
|
| 163 |
+
print(
|
| 164 |
+
f" - Favoriser les modèles avec moins de paramètres si performances similaires"
|
| 165 |
+
)
|
examples/03_model_registry.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Exemple 3 : Gérer le Model Registry (versions, stages, promotion)
|
| 4 |
+
Usage: python examples/03_model_registry.py
|
| 5 |
+
"""
|
| 6 |
+
import mlflow
|
| 7 |
+
from mlflow.tracking import MlflowClient
|
| 8 |
+
|
| 9 |
+
# Configuration
|
| 10 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 11 |
+
client = MlflowClient()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def list_registered_models():
    """Print every model registered in the MLflow Model Registry and return them."""
    print("📦 Modèles enregistrés dans le Model Registry :\n")

    registered = client.search_registered_models()

    if not registered:
        print("❌ Aucun modèle enregistré")
        return []

    # Stage → icon lookup; anything not Production/Staging falls back to 📝.
    stage_icons = {"Production": "🚀", "Staging": "🧪"}

    for entry in registered:
        print(f"🔹 {entry.name}")
        print(f"   Description : {entry.description or 'N/A'}")
        print(f"   Création    : {entry.creation_timestamp}")
        print(f"   Versions    : {len(entry.latest_versions)}")

        # Walk every version of this model and flag its current stage.
        for version in client.search_model_versions(f"name='{entry.name}'"):
            stage = version.current_stage
            emoji = stage_icons.get(stage, "📝")
            print(f"   {emoji} Version {version.version} - {stage}")
            print(f"      Run ID: {version.run_id}")
            print(f"      Source: {version.source}")
        print()

    return registered
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_model_details(model_name="XGBoost_Employee_Turnover"):
    """Print general information and every version of one registered model.

    Args:
        model_name: Name of the model in the MLflow Model Registry.

    Returns:
        The RegisteredModel object, or None if the lookup failed.
    """
    # Hoisted out of the try-block: the import is unconditional.
    from datetime import datetime

    print(f"🔍 Détails du modèle '{model_name}' :\n")

    try:
        # Fetch the registry entry for this model.
        rm = client.get_registered_model(model_name)

        # Fix: plain strings instead of f-strings without placeholders (flake8 F541).
        print("📦 Informations générales :")
        print(f"   Nom         : {rm.name}")
        print(f"   Description : {rm.description or 'N/A'}")
        print(
            f"   Création    : {datetime.fromtimestamp(rm.creation_timestamp / 1000).strftime('%Y-%m-%d %H:%M')}"
        )
        print(
            f"   Dernière MAJ: {datetime.fromtimestamp(rm.last_updated_timestamp / 1000).strftime('%Y-%m-%d %H:%M')}"
        )

        # List every version, newest first.
        versions = client.search_model_versions(f"name='{model_name}'")

        print(f"\n📊 Versions ({len(versions)}) :")
        print(f"{'Version':<10} {'Stage':<15} {'Run ID':<35} {'Date':<20}")
        print("-" * 85)

        for mv in sorted(versions, key=lambda v: int(v.version), reverse=True):
            date_str = datetime.fromtimestamp(mv.creation_timestamp / 1000).strftime(
                "%Y-%m-%d %H:%M"
            )
            print(
                f"{mv.version:<10} {mv.current_stage:<15} {mv.run_id:<35} {date_str:<20}"
            )

        # Highlight the production version, if any.
        prod_versions = [v for v in versions if v.current_stage == "Production"]
        if prod_versions:
            print(f"\n🚀 Version en production : {prod_versions[0].version}")
        else:
            print("\n⚠️ Aucune version en production")

        return rm

    except Exception as e:
        # Boundary handler: report the error and return None so callers can continue.
        print(f"❌ Erreur : {e}")
        return None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def promote_model(model_name, version, stage="Staging"):
    """Promote a model version to a given stage.

    Args:
        model_name: Registered model name.
        version: Version number to promote.
        stage: Target stage — "Staging", "Production", or "Archived".
    """
    print(f"🔄 Promotion du modèle '{model_name}' v{version} → {stage}...")

    try:
        # Move the version to its new stage; any version already in that
        # stage gets archived so only one version holds it at a time.
        client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage=stage,
            archive_existing_versions=True,
        )

        # Fix: plain string instead of an f-string without placeholders (F541).
        print("✅ Modèle promu avec succès !")
        print(f"   {model_name} v{version} est maintenant en {stage}")

        # Show the refreshed state of the promoted version.
        mv = client.get_model_version(model_name, version)
        print(f"   Status : {mv.status}")

    except Exception as e:
        # Boundary handler: report and keep going.
        print(f"❌ Erreur lors de la promotion : {e}")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def load_model_from_registry(
    model_name="XGBoost_Employee_Turnover", stage="Production"
):
    """Load a model from the Registry, falling back to the latest version.

    Args:
        model_name: Registered model name.
        stage: Stage to load from ("Production", "Staging", ...).

    Returns:
        The loaded scikit-learn compatible model.

    Raises:
        mlflow.exceptions.MlflowException: if neither the requested stage
            nor the 'latest' fallback can be loaded.
    """
    print(f"📦 Chargement du modèle '{model_name}' ({stage})...\n")

    model_uri = f"models:/{model_name}/{stage}"

    try:
        model = mlflow.sklearn.load_model(model_uri)
        # Fix: plain strings instead of f-strings without placeholders (F541).
        print("✅ Modèle chargé avec succès")
        print(f"   URI  : {model_uri}")
        print(f"   Type : {type(model).__name__}")
        return model

    except mlflow.exceptions.MlflowException as e:
        # Fix: surface the caught exception instead of discarding it (F841).
        print(f"⚠️ Aucun modèle en {stage} ({e})")
        print("   Essai avec 'latest'...")

        # Fallback on the most recent registered version.
        model_uri = f"models:/{model_name}/latest"
        model = mlflow.sklearn.load_model(model_uri)
        print("✅ Dernière version chargée")
        return model
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def demo_workflow():
    """Run the full Model Registry demo: list, inspect, promote (dry), load."""
    print("=" * 80)
    print("🎯 DEMO - Workflow Model Registry")
    print("=" * 80 + "\n")

    # 1. List all registered models.
    print("1️⃣ Liste des modèles\n")
    models = list_registered_models()

    if not models:
        print("⚠️ Aucun modèle trouvé. Exécute d'abord un training avec MLflow.")
        return

    # 2. Details of the first model found.
    model_name = models[0].name
    print("\n" + "=" * 80)
    print(f"2️⃣ Détails du modèle '{model_name}'\n")
    get_model_details(model_name)

    # 3. Promotion example (kept commented so the demo stays read-only).
    print("\n" + "=" * 80)
    print("3️⃣ Promotion d'un modèle\n")
    print("💡 Pour promouvoir la version 1 en Production :")
    print(f"   promote_model('{model_name}', version=1, stage='Production')")
    print("   (Décommente dans le code pour exécuter)")

    # Uncomment this line to actually promote:
    # promote_model(model_name, version=1, stage="Production")

    # 4. Load a model: try Production first, fall back to latest.
    print("\n" + "=" * 80)
    print("4️⃣ Chargement d'un modèle\n")

    try:
        load_model_from_registry(model_name, "Production")
    except Exception:
        # Fix: narrowed the bare `except:` (E722) and dropped the unused
        # `model` bindings (F841) — the demo only needs the side effects.
        print("⚠️ Aucun modèle en Production, chargement de 'latest'")
        load_model_from_registry(model_name, "None")

    print("\n" + "=" * 80)
    print("✅ Demo terminée !")
    print("=" * 80)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
if __name__ == "__main__":
|
| 205 |
+
demo_workflow()
|
examples/README.md
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📚 Exemples MLflow
|
| 2 |
+
|
| 3 |
+
Ce dossier contient des exemples pratiques pour utiliser MLflow dans le projet.
|
| 4 |
+
|
| 5 |
+
## 🚀 Exemples disponibles
|
| 6 |
+
|
| 7 |
+
### 1. Trouver le meilleur modèle
|
| 8 |
+
```bash
|
| 9 |
+
python examples/01_find_best_model.py
|
| 10 |
+
```
|
| 11 |
+
**Ce qu'il fait** :
|
| 12 |
+
- Liste les 5 meilleurs runs selon une métrique (cv_f1)
|
| 13 |
+
- Affiche les hyperparamètres et métriques du meilleur
|
| 14 |
+
- Charge le modèle pour vérifier qu'il fonctionne
|
| 15 |
+
|
| 16 |
+
**Utilisation** : Parfait pour identifier quel modèle utiliser dans ton API
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
### 2. Comparer plusieurs modèles
|
| 21 |
+
```bash
|
| 22 |
+
python examples/02_compare_models.py
|
| 23 |
+
```
|
| 24 |
+
**Ce qu'il fait** :
|
| 25 |
+
- Compare tous les runs d'une expérience
|
| 26 |
+
- Affiche un tableau avec métriques et hyperparamètres
|
| 27 |
+
- Génère un graphique de comparaison (si matplotlib installé)
|
| 28 |
+
- Calcule des statistiques (moyenne, max, min, écart-type)
|
| 29 |
+
|
| 30 |
+
**Utilisation** : Pour analyser tes expériences d'hyperparameter tuning
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
|
| 34 |
+
### 3. Gérer le Model Registry
|
| 35 |
+
```bash
|
| 36 |
+
python examples/03_model_registry.py
|
| 37 |
+
```
|
| 38 |
+
**Ce qu'il fait** :
|
| 39 |
+
- Liste tous les modèles enregistrés
|
| 40 |
+
- Affiche les versions et leurs stages
|
| 41 |
+
- Démontre comment promouvoir un modèle
|
| 42 |
+
- Charge un modèle depuis le Registry
|
| 43 |
+
|
| 44 |
+
**Utilisation** : Workflow de versioning pour la production
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## 📖 Guide complet
|
| 49 |
+
|
| 50 |
+
Consulte `docs/mlflow_guide.md` pour :
|
| 51 |
+
- Architecture MLflow complète
|
| 52 |
+
- Intégration API FastAPI/Flask
|
| 53 |
+
- Best practices
|
| 54 |
+
- Workflow de retraining
|
| 55 |
+
|
| 56 |
+
## 🎯 Workflow recommandé
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# 1. Entraîner plusieurs modèles
|
| 60 |
+
MLFLOW_TRACKING_URI=sqlite:///mlflow.db python tests/test_mlflow_quick.py
|
| 61 |
+
|
| 62 |
+
# 2. Trouver le meilleur
|
| 63 |
+
python examples/01_find_best_model.py
|
| 64 |
+
|
| 65 |
+
# 3. Comparer tous les runs
|
| 66 |
+
python examples/02_compare_models.py
|
| 67 |
+
|
| 68 |
+
# 4. Gérer le Registry
|
| 69 |
+
python examples/03_model_registry.py
|
| 70 |
+
|
| 71 |
+
# 5. Promouvoir en production (dans le code Python)
|
| 72 |
+
from mlflow.tracking import MlflowClient
|
| 73 |
+
client = MlflowClient()
|
| 74 |
+
client.transition_model_version_stage(
|
| 75 |
+
name="XGBoost_Employee_Turnover",
|
| 76 |
+
version=1,
|
| 77 |
+
stage="Production"
|
| 78 |
+
)
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## 🔗 Intégration API
|
| 82 |
+
|
| 83 |
+
Une fois le meilleur modèle identifié :
|
| 84 |
+
|
| 85 |
+
```python
|
| 86 |
+
import mlflow
|
| 87 |
+
|
| 88 |
+
# Option A : Charger par run_id
|
| 89 |
+
model = mlflow.sklearn.load_model("runs:/RUN_ID/model")
|
| 90 |
+
|
| 91 |
+
# Option B : Charger depuis le Registry
|
| 92 |
+
model = mlflow.sklearn.load_model("models:/XGBoost_Employee_Turnover/Production")
|
| 93 |
+
|
| 94 |
+
# Prédiction
|
| 95 |
+
predictions = model.predict(X_new)
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
## 💡 Tips
|
| 99 |
+
|
| 100 |
+
- **Métrique principale** : `cv_f1` (F1-score en cross-validation)
|
| 101 |
+
- **Métriques secondaires** : `test_precision`, `test_recall`, `test_f1`
|
| 102 |
+
- **Vérifier** : Que test_f1 ≈ cv_f1 (pas de surapprentissage)
|
| 103 |
+
- **Favoriser** : Modèles simples si performances similaires
|
| 104 |
+
|
| 105 |
+
## 🌐 MLflow UI
|
| 106 |
+
|
| 107 |
+
Pour visualiser graphiquement :
|
| 108 |
+
```bash
|
| 109 |
+
mlflow ui --backend-store-uri sqlite:///mlflow.db --port 5000
|
| 110 |
+
```
|
| 111 |
+
Puis ouvre http://localhost:5000
|
ml_model/preprocess.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
import numpy as np
|
| 3 |
-
|
| 4 |
-
from scipy.stats.mstats import winsorize
|
| 5 |
from scipy import stats
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def load_raw_data(
|
|
@@ -97,9 +97,16 @@ def preprocess_data(raw_data_paths=None):
|
|
| 97 |
)
|
| 98 |
|
| 99 |
# Assemblage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
df_engineered = pd.concat(
|
| 101 |
[
|
| 102 |
central_df[quantitative_cols],
|
|
|
|
| 103 |
encoded_non_ord,
|
| 104 |
encoded_ord,
|
| 105 |
central_df["a_quitte_l_entreprise"],
|
|
@@ -107,17 +114,8 @@ def preprocess_data(raw_data_paths=None):
|
|
| 107 |
axis=1,
|
| 108 |
) # Inclut cible
|
| 109 |
|
| 110 |
-
# Scaling (quantitatives + ordinal)
|
| 111 |
-
cols_to_scale = (
|
| 112 |
-
quantitative_cols.tolist()
|
| 113 |
-
+ cat_ord
|
| 114 |
-
+ [
|
| 115 |
-
"revenu_par_anciennete",
|
| 116 |
-
"experience_par_anciennete",
|
| 117 |
-
"satisfaction_moyenne",
|
| 118 |
-
"promo_par_anciennete",
|
| 119 |
-
]
|
| 120 |
-
)
|
| 121 |
scaler = StandardScaler()
|
| 122 |
df_engineered[cols_to_scale] = scaler.fit_transform(df_engineered[cols_to_scale])
|
| 123 |
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
import pandas as pd
|
|
|
|
| 3 |
from scipy import stats
|
| 4 |
+
from scipy.stats.mstats import winsorize
|
| 5 |
+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
|
| 6 |
|
| 7 |
|
| 8 |
def load_raw_data(
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
# Assemblage
|
| 100 |
+
engineered_cols = [
|
| 101 |
+
"revenu_par_anciennete",
|
| 102 |
+
"experience_par_anciennete",
|
| 103 |
+
"satisfaction_moyenne",
|
| 104 |
+
"promo_par_anciennete",
|
| 105 |
+
]
|
| 106 |
df_engineered = pd.concat(
|
| 107 |
[
|
| 108 |
central_df[quantitative_cols],
|
| 109 |
+
central_df[engineered_cols],
|
| 110 |
encoded_non_ord,
|
| 111 |
encoded_ord,
|
| 112 |
central_df["a_quitte_l_entreprise"],
|
|
|
|
| 114 |
axis=1,
|
| 115 |
) # Inclut cible
|
| 116 |
|
| 117 |
+
# Scaling (quantitatives + ordinal + engineered)
|
| 118 |
+
cols_to_scale = quantitative_cols.tolist() + engineered_cols + cat_ord
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
scaler = StandardScaler()
|
| 120 |
df_engineered[cols_to_scale] = scaler.fit_transform(df_engineered[cols_to_scale])
|
| 121 |
|
ml_model/train_model.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
from imblearn.over_sampling import SMOTE
|
| 4 |
from imblearn.pipeline import Pipeline as ImbPipeline
|
|
|
|
|
|
|
|
|
|
| 5 |
from xgboost import XGBClassifier
|
| 6 |
-
from scipy.stats import uniform, randint
|
| 7 |
|
| 8 |
|
| 9 |
def train_model(X, y):
|
|
@@ -42,17 +44,42 @@ def train_model(X, y):
|
|
| 42 |
n_jobs=-1,
|
| 43 |
random_state=42,
|
| 44 |
)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
return best_model, best_params, cv_f1
|
|
|
|
| 1 |
+
import mlflow
|
| 2 |
+
import mlflow.sklearn
|
| 3 |
from imblearn.over_sampling import SMOTE
|
| 4 |
from imblearn.pipeline import Pipeline as ImbPipeline
|
| 5 |
+
from scipy.stats import randint, uniform
|
| 6 |
+
from sklearn.metrics import classification_report, confusion_matrix
|
| 7 |
+
from sklearn.model_selection import RandomizedSearchCV, train_test_split
|
| 8 |
from xgboost import XGBClassifier
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def train_model(X, y):
|
|
|
|
| 44 |
n_jobs=-1,
|
| 45 |
random_state=42,
|
| 46 |
)
|
| 47 |
+
|
| 48 |
+
# Ajout MLflow : Encapsule training pour tracking auto (./mlruns)
|
| 49 |
+
with mlflow.start_run(run_name="XGBoost_Tuning"):
|
| 50 |
+
random.fit(X_train, y_train)
|
| 51 |
+
|
| 52 |
+
best_model = random.best_estimator_ # type: ignore[assignment]
|
| 53 |
+
best_params = random.best_params_
|
| 54 |
+
cv_f1 = random.best_score_
|
| 55 |
+
|
| 56 |
+
mlflow.log_params(
|
| 57 |
+
best_params
|
| 58 |
+
) # Choix : Log tous hyperparams pour reproductibilité.
|
| 59 |
+
mlflow.log_metric(
|
| 60 |
+
"cv_f1", cv_f1
|
| 61 |
+
) # Choix : Métrique clé (F1 CV pour déséquilibre).
|
| 62 |
+
|
| 63 |
+
y_pred = best_model.predict(X_test) # type: ignore[attr-defined]
|
| 64 |
+
report = classification_report(y_test, y_pred, output_dict=True) # type: ignore[arg-type]
|
| 65 |
+
|
| 66 |
+
# Type ignore car classification_report avec output_dict=True retourne dict, pas str
|
| 67 |
+
mlflow.log_metric("test_precision", float(report["1"]["precision"])) # type: ignore[index]
|
| 68 |
+
mlflow.log_metric("test_recall", float(report["1"]["recall"])) # type: ignore[index]
|
| 69 |
+
mlflow.log_metric("test_f1", float(report["1"]["f1-score"])) # type: ignore[index]
|
| 70 |
+
|
| 71 |
+
# Log model et récupère URI pour l'enregistrement
|
| 72 |
+
model_info = mlflow.sklearn.log_model(best_model, "model") # type: ignore[attr-defined]
|
| 73 |
+
|
| 74 |
+
# Enregistre dans Model Registry pour apparaître dans la page "Models"
|
| 75 |
+
mlflow.register_model(
|
| 76 |
+
model_uri=model_info.model_uri, name="XGBoost_Employee_Turnover"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Éval test (pédagogique)
|
| 80 |
+
print("Meilleurs params:", best_params)
|
| 81 |
+
print("Meilleur CV F1:", cv_f1)
|
| 82 |
+
print(classification_report(y_test, y_pred))
|
| 83 |
+
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
|
| 84 |
|
| 85 |
return best_model, best_params, cv_f1
|
pyproject.toml
CHANGED
|
@@ -4,7 +4,7 @@ version = "0.1.0"
|
|
| 4 |
description = "Projet OpenClassRoom mise en API d'un modèle ML"
|
| 5 |
authors = ["chaton59 <v.trouillez@gmail.com>"]
|
| 6 |
readme = "README.md"
|
| 7 |
-
packages = [{include = "src"}]
|
| 8 |
|
| 9 |
[tool.poetry.dependencies]
|
| 10 |
python = "^3.12"
|
|
|
|
| 4 |
description = "Projet OpenClassRoom mise en API d'un modèle ML"
|
| 5 |
authors = ["chaton59 <v.trouillez@gmail.com>"]
|
| 6 |
readme = "README.md"
|
| 7 |
+
packages = [{include = "src"}, {include = "ml_model"}]
|
| 8 |
|
| 9 |
[tool.poetry.dependencies]
|
| 10 |
python = "^3.12"
|
requirements.txt
CHANGED
|
@@ -36,11 +36,12 @@ uvicorn==0.38.0 ; python_version >= "3.12"
|
|
| 36 |
uvloop==0.22.1 ; sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy" and python_version >= "3.12"
|
| 37 |
watchfiles==1.1.1 ; python_version >= "3.12"
|
| 38 |
websockets==15.0.1 ; python_version >= "3.12"
|
| 39 |
-
scikit-learn==1.
|
| 40 |
-
xgboost==2.
|
| 41 |
-
imbalanced-learn==0.
|
| 42 |
-
scipy==1.
|
| 43 |
-
numpy==
|
| 44 |
-
pandas==2.
|
| 45 |
-
joblib==1.
|
| 46 |
-
|
|
|
|
|
|
| 36 |
uvloop==0.22.1 ; sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy" and python_version >= "3.12"
|
| 37 |
watchfiles==1.1.1 ; python_version >= "3.12"
|
| 38 |
websockets==15.0.1 ; python_version >= "3.12"
|
| 39 |
+
scikit-learn==1.6.1
|
| 40 |
+
xgboost==2.1.4
|
| 41 |
+
imbalanced-learn==0.13.0
|
| 42 |
+
scipy==1.14.1
|
| 43 |
+
numpy==2.0.2
|
| 44 |
+
pandas==2.2.3
|
| 45 |
+
joblib==1.4.2
|
| 46 |
+
mlflow==3.8.0
|
| 47 |
+
huggingface-hub==0.26.5
|
scripts/fix_lint.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script pour corriger automatiquement les problèmes de lint du projet.
|
| 4 |
+
Usage: python scripts/fix_lint.py
|
| 5 |
+
"""
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def run_command(cmd, description):
    """Run a shell command, echo its captured output, and return its exit code."""
    separator = "=" * 60
    print(f"\n{separator}")
    print(f"🔧 {description}")
    print(separator)

    completed = subprocess.run(cmd, shell=True, capture_output=True, text=True)

    # Relay whatever the tool wrote, preserving the stdout/stderr split.
    if completed.stdout:
        print(completed.stdout)
    if completed.stderr:
        print(completed.stderr, file=sys.stderr)

    return completed.returncode
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """Run formatters, linter and tests over the project; return combined status."""
    project_root = Path(__file__).parent.parent

    print(f"📁 Projet : {project_root}")

    # Each step: (shell command, human-readable description), run in order.
    steps = [
        (  # 1. Format with Black
            f"cd {project_root} && .venv/bin/black ml_model/ tests/ examples/ --line-length 88",
            "Formatage avec Black",
        ),
        (  # 2. Sort imports with isort
            f"cd {project_root} && .venv/bin/python -m isort ml_model/ tests/ examples/ --profile black",
            "Tri des imports avec isort",
        ),
        (  # 3. Check with Flake8
            f"cd {project_root} && .venv/bin/python -m flake8 ml_model/ tests/ examples/ --max-line-length=88 --extend-ignore=E203,W503",
            "Vérification avec Flake8",
        ),
        (  # 4. Run the tests
            f"cd {project_root} && .venv/bin/python -m pytest tests/test_basic.py -v",
            "Exécution des tests",
        ),
    ]

    returncode = 0
    for cmd, description in steps:
        returncode += run_command(cmd, description)

    print(f"\n{'=' * 60}")
    if returncode == 0:
        print("✅ Tous les checks passent !")
    else:
        print("⚠️ Certains problèmes subsistent. Vérifiez les logs ci-dessus.")
    print(f"{'=' * 60}\n")

    return returncode
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
sys.exit(main())
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration pytest pour ajouter le dossier racine au PYTHONPATH."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# Ajouter le dossier racine du projet au PYTHONPATH
|
| 7 |
+
project_root = Path(__file__).parent.parent
|
| 8 |
+
sys.path.insert(0, str(project_root))
|
tests/test_mlflow_local.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
Local smoke-test script for train_model.py with MLflow.

Uses preprocess_data to load and prepare the data, then runs the full
training with MLflow tracking and lists the produced artifacts.
"""
import os

import mlflow

from ml_model.preprocess import preprocess_data
from ml_model.train_model import train_model

# Configure MLflow to use SQLite (required for the Model Registry).
mlflow.set_tracking_uri("sqlite:///mlflow.db")

if __name__ == "__main__":
    print("🔄 Chargement et prétraitement des données...")
    # Paths of the raw data extracts consumed by preprocess_data.
    data_paths = {
        "sondage_path": "data/extrait_sondage.csv",
        "eval_path": "data/extrait_eval.csv",
        "sirh_path": "data/extrait_sirh.csv",
    }

    X, y, scaler, onehot, ordinal = preprocess_data(raw_data_paths=data_paths)
    print(f"✅ Données prétraitées :")
    print(f"   X shape: {X.shape}, y shape: {y.shape}")
    print(f"   Distribution y: {y.value_counts().to_dict()}")

    print("\n🚀 Lancement de l'entraînement avec MLflow tracking...")
    print("   (Cela peut prendre quelques minutes avec n_iter=1000...)\n")

    best_model, best_params, cv_f1 = train_model(X, y)

    print(f"\n✅ Entraînement terminé !")
    print(f"   CV F1-score: {cv_f1:.4f}")

    # Walk the local tracking directory to show what MLflow produced.
    print("\n📁 Vérification des artifacts MLflow dans ./mlruns :")
    if os.path.exists("./mlruns"):
        for root, dirs, files in os.walk("./mlruns"):
            level = root.replace("./mlruns", "").count(os.sep)
            indent = " " * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = " " * 2 * (level + 1)
            for file in files[:5]:  # Cap at 5 files per directory to keep output short
                print(f"{subindent}{file}")
            if len(files) > 5:
                print(f"{subindent}... ({len(files) - 5} autres fichiers)")
    else:
        print("   ⚠️ Dossier ./mlruns non trouvé")

    print("\n💡 Pour visualiser les runs MLflow, exécutez :")
    print("   mlflow ui")
    print("   Puis ouvrez http://localhost:5000 dans votre navigateur")
|
tests/test_mlflow_quick.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test rapide MLflow avec n_iter=10 au lieu de 1000.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import mlflow
|
| 8 |
+
from sklearn.model_selection import RandomizedSearchCV
|
| 9 |
+
|
| 10 |
+
from ml_model.preprocess import preprocess_data
|
| 11 |
+
from ml_model.train_model import train_model
|
| 12 |
+
|
| 13 |
+
# Configure MLflow pour utiliser SQLite (nécessaire pour Model Registry)
|
| 14 |
+
mlflow.set_tracking_uri("sqlite:///mlflow.db")
|
| 15 |
+
|
| 16 |
+
# Patch temporaire pour test rapide
|
| 17 |
+
import ml_model.train_model as train_module
|
| 18 |
+
|
| 19 |
+
original_train = train_module.train_model
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def quick_train(X, y):
    """Fast variant of the project training: n_iter=10 and cv=3.

    Runs the same SMOTE + XGBoost randomized search as train_model, logs
    params/metrics/model to MLflow and registers the model in the Model
    Registry, but with a drastically reduced search budget so it finishes
    quickly. Returns (best_model, best_params, cv_f1).
    """
    import mlflow
    import mlflow.sklearn
    from imblearn.over_sampling import SMOTE
    from imblearn.pipeline import Pipeline as ImbPipeline
    from scipy.stats import randint, uniform
    from sklearn.metrics import classification_report, confusion_matrix
    from sklearn.model_selection import RandomizedSearchCV, train_test_split
    from xgboost import XGBClassifier

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # Class imbalance ratio (negatives / positives) used to bound scale_pos_weight.
    ratio = sum(y == 0) / sum(y == 1)

    pipeline = ImbPipeline(
        [("sampler", SMOTE(random_state=42)), ("clf", XGBClassifier(random_state=42))]
    )
    param_dist = {
        "clf__max_depth": randint(3, 15),
        "clf__n_estimators": randint(100, 1000),
        "clf__learning_rate": uniform(0.001, 0.5),
        "clf__subsample": uniform(0.4, 0.6),
        "clf__reg_alpha": uniform(0, 3),
        "clf__gamma": uniform(0, 10),
        "clf__colsample_bytree": uniform(0.5, 0.5),
        "clf__min_child_weight": randint(1, 15),
        "clf__scale_pos_weight": uniform(1, ratio),
        "clf__tree_method": ["auto", "hist"],
    }

    random = RandomizedSearchCV(
        pipeline,
        param_dist,
        n_iter=10,  # 🚀 Quick test!
        cv=3,  # Also reduces the number of CV folds
        scoring="f1",
        n_jobs=-1,
        random_state=42,
    )

    with mlflow.start_run(run_name="XGBoost_Quick_Test"):
        random.fit(X_train, y_train)

        best_model = random.best_estimator_  # type: ignore[assignment]
        best_params = random.best_params_
        cv_f1 = random.best_score_

        mlflow.log_params(best_params)
        mlflow.log_metric("cv_f1", cv_f1)

        y_pred = best_model.predict(X_test)  # type: ignore[attr-defined]
        report = classification_report(y_test, y_pred, output_dict=True)  # type: ignore[arg-type]

        # report["1"] holds precision/recall/f1 for the positive class.
        mlflow.log_metric("test_precision", float(report["1"]["precision"]))  # type: ignore[index]
        mlflow.log_metric("test_recall", float(report["1"]["recall"]))  # type: ignore[index]
        mlflow.log_metric("test_f1", float(report["1"]["f1-score"]))  # type: ignore[index]

        # Log the model artifact and register it in the Model Registry.
        model_info = mlflow.sklearn.log_model(best_model, "model")  # type: ignore[attr-defined]
        mlflow.register_model(
            model_uri=model_info.model_uri, name="XGBoost_Employee_Turnover"
        )

        print("Meilleurs params:", best_params)
        print("Meilleur CV F1:", cv_f1)
        print(classification_report(y_test, y_pred))
        print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))

    return best_model, best_params, cv_f1
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
print("🔄 Chargement et prétraitement des données...")
|
| 97 |
+
data_paths = {
|
| 98 |
+
"sondage_path": "data/extrait_sondage.csv",
|
| 99 |
+
"eval_path": "data/extrait_eval.csv",
|
| 100 |
+
"sirh_path": "data/extrait_sirh.csv",
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
X, y, scaler, onehot, ordinal = preprocess_data(raw_data_paths=data_paths)
|
| 104 |
+
print(f"✅ Données prétraitées :")
|
| 105 |
+
print(f" X shape: {X.shape}, y shape: {y.shape}")
|
| 106 |
+
print(f" Distribution y: {y.value_counts().to_dict()}")
|
| 107 |
+
|
| 108 |
+
print("\n🚀 Lancement du test rapide (n_iter=10, cv=3)...\n")
|
| 109 |
+
|
| 110 |
+
best_model, best_params, cv_f1 = quick_train(X, y)
|
| 111 |
+
|
| 112 |
+
print(f"\n✅ Test terminé ! CV F1-score: {cv_f1:.4f}")
|
| 113 |
+
|
| 114 |
+
print("\n📁 Vérification des artifacts MLflow dans ./mlruns :")
|
| 115 |
+
if os.path.exists("./mlruns"):
|
| 116 |
+
for root, dirs, files in os.walk("./mlruns"):
|
| 117 |
+
level = root.replace("./mlruns", "").count(os.sep)
|
| 118 |
+
if level < 3: # Limite la profondeur
|
| 119 |
+
indent = " " * 2 * level
|
| 120 |
+
print(f"{indent}{os.path.basename(root)}/")
|
| 121 |
+
if level == 2: # Affiche fichiers dans les runs
|
| 122 |
+
subindent = " " * 2 * (level + 1)
|
| 123 |
+
for file in files[:3]:
|
| 124 |
+
print(f"{subindent}{file}")
|
| 125 |
+
if len(files) > 3:
|
| 126 |
+
print(f"{subindent}... (+{len(files) - 3} fichiers)")
|
| 127 |
+
|
| 128 |
+
print("\n💡 Pour visualiser les runs MLflow, exécutez :")
|
| 129 |
+
print(" mlflow ui")
|
| 130 |
+
print(" Puis ouvrez http://localhost:5000")
|