Spaces:
Sleeping
Sleeping
"""
Training pipeline — baseline logistic regression + XGBoost final model.

Usage
-----
    python -m src.train                              # default data path
    python -m src.train --data path/to/outage.csv    # custom path

Outputs (→ artifacts/)
    logistic_baseline.joblib   fitted LR model
    xgb_final.joblib           fitted XGBoost model
    preprocessor.joblib        sklearn ColumnTransformer
    metrics.json               evaluation results
    drift_reference.joblib     column-level stats for future drift checks
"""
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict | |
| import joblib | |
| import numpy as np | |
| from sklearn.calibration import calibration_curve | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import ( | |
| average_precision_score, | |
| brier_score_loss, | |
| classification_report, | |
| roc_auc_score, | |
| ) | |
| from sklearn.model_selection import StratifiedKFold, cross_val_score | |
| from xgboost import XGBClassifier | |
| from src.config import ( | |
| ARTIFACTS_DIR, | |
| DRIFT_REF_FILE, | |
| METRICS_FILE, | |
| MODEL_BASELINE_FILE, | |
| MODEL_FINAL_FILE, | |
| RANDOM_STATE, | |
| ) | |
| from src.data import get_dataset | |
| from src.features import engineer_features, prepare_splits, _resolve_columns | |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s — %(message)s")
| logger = logging.getLogger(__name__) | |
| CV_FOLDS = 5 | |
def _evaluate(
    model: Any,
    X: np.ndarray,
    y: np.ndarray,
    label: str,
) -> Dict[str, float]:
    """Score *model* on (X, y): ROC-AUC, PR-AUC and Brier score.

    Logs the metrics under *label* and returns them rounded to 4 decimals.
    """
    probs = model.predict_proba(X)[:, 1]
    scored = {
        "roc_auc": roc_auc_score(y, probs),
        "pr_auc": average_precision_score(y, probs),
        "brier": brier_score_loss(y, probs),
    }
    logger.info(
        "[%s] ROC-AUC=%.4f PR-AUC=%.4f Brier=%.4f",
        label,
        scored["roc_auc"],
        scored["pr_auc"],
        scored["brier"],
    )
    return {name: round(value, 4) for name, value in scored.items()}
def _cross_validate(model: Any, X: np.ndarray, y: np.ndarray, label: str) -> float:
    """Stratified k-fold cross-validation on ROC-AUC.

    Logs mean ± std across CV_FOLDS folds and returns the rounded mean.
    """
    cv = StratifiedKFold(n_splits=CV_FOLDS, shuffle=True, random_state=RANDOM_STATE)
    scores = cross_val_score(model, X, y, cv=cv, scoring="roc_auc", n_jobs=-1)
    mean, std = scores.mean(), scores.std()
    # Fix mojibake: the log separator was garbled ("Β±") in the original source.
    logger.info("[%s CV] ROC-AUC = %.4f ± %.4f", label, mean, std)
    return round(float(mean), 4)
def train_baseline(X_train: np.ndarray, y_train: np.ndarray) -> LogisticRegression:
    """Fit a class-balanced logistic-regression baseline and persist it.

    The fitted model is dumped to ARTIFACTS_DIR / MODEL_BASELINE_FILE and
    also returned for immediate evaluation.
    """
    baseline = LogisticRegression(
        class_weight="balanced",
        max_iter=1000,
        solver="lbfgs",
        random_state=RANDOM_STATE,
    )
    baseline.fit(X_train, y_train)
    joblib.dump(baseline, ARTIFACTS_DIR / MODEL_BASELINE_FILE)
    return baseline
def train_xgb(X_train: np.ndarray, y_train: np.ndarray) -> XGBClassifier:
    """XGBoost with scale_pos_weight to handle class imbalance.

    The fitted model is dumped to ARTIFACTS_DIR / MODEL_FINAL_FILE and
    returned for immediate evaluation.
    """
    # minlength=2 keeps the unpack safe even if one class is absent from
    # y_train (plain bincount would yield a 1-element array and raise).
    neg, pos = np.bincount(y_train, minlength=2)[:2]
    scale = neg / max(pos, 1)
    xgb = XGBClassifier(
        n_estimators=400,
        max_depth=5,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        scale_pos_weight=scale,
        eval_metric="aucpr",
        # NOTE: `use_label_encoder=False` was dropped — the parameter is
        # deprecated since XGBoost 1.6 and removed in 2.x, where passing it
        # triggers warnings/errors. Modern XGBoost never label-encodes.
        random_state=RANDOM_STATE,
        n_jobs=-1,
    )
    xgb.fit(X_train, y_train)
    joblib.dump(xgb, ARTIFACTS_DIR / MODEL_FINAL_FILE)
    return xgb
def save_drift_reference(df_raw, path: Path = ARTIFACTS_DIR / DRIFT_REF_FILE) -> None:
    """
    Persist column-level statistics from training data so we can detect
    covariate drift at inference time.

    For every numeric feature produced by engineer_features we store
    mean/std/min/max of the non-null values.
    """
    df = engineer_features(df_raw)
    num_cols, _ = _resolve_columns(df)
    stats: Dict[str, Dict[str, float]] = {}
    for col in num_cols:
        if col not in df.columns:
            continue
        s = df[col].dropna()
        if s.empty:
            # An all-NaN column would yield NaN stats that poison later
            # drift comparisons — skip it rather than record garbage.
            continue
        stats[col] = {
            "mean": float(s.mean()),
            "std": float(s.std()),
            "min": float(s.min()),
            "max": float(s.max()),
        }
    joblib.dump(stats, path)
    # Fix mojibake: the log separator was garbled ("β") in the original source.
    logger.info("Drift reference saved — %d features tracked", len(stats))
def calibration_summary(model: Any, X: np.ndarray, y: np.ndarray) -> None:
    """Log calibration quality — in production this feeds a monitoring dashboard.

    Bins predicted probabilities into 8 quantile bins and reports the largest
    gap between predicted probability and observed positive fraction.
    """
    y_prob = model.predict_proba(X)[:, 1]
    fraction_pos, mean_pred = calibration_curve(y, y_prob, n_bins=8, strategy="quantile")
    max_gap = float(np.abs(fraction_pos - mean_pred).max())
    # Fix mojibake: the log separator was garbled ("β") in the original source.
    logger.info(
        "Calibration check — max bin gap: %.3f %s",
        max_gap,
        # 0.10 is the team's accepted threshold before recalibration is advised.
        "(acceptable)" if max_gap < 0.10 else "(consider Platt scaling)",
    )
def run(data_path: str | None = None) -> Dict[str, Any]:
    """End-to-end training: load data, fit baseline + XGBoost, persist artifacts.

    Parameters
    ----------
    data_path : optional path to the outage dataset; None uses the default.

    Returns
    -------
    dict with test metrics and CV ROC-AUC for both models (also written to
    ARTIFACTS_DIR / METRICS_FILE).
    """
    logger.info("=" * 60)
    # Fix mojibake: banner/log separators were garbled ("β") in the original.
    logger.info("GRID RISK PLATFORM — Training pipeline")
    logger.info("=" * 60)
    df = get_dataset(data_path)
    # preprocessor/feature_names are returned by prepare_splits but not used
    # here — presumably persisted inside prepare_splits; TODO confirm.
    X_train, X_test, y_train, y_test, preprocessor, feature_names = prepare_splits(df)
    # --- Baseline --------------------------------------------------------
    lr = train_baseline(X_train, y_train)
    lr_cv = _cross_validate(lr, X_train, y_train, "LR-baseline")
    lr_test = _evaluate(lr, X_test, y_test, "LR-baseline test")
    # --- XGBoost ---------------------------------------------------------
    xgb = train_xgb(X_train, y_train)
    xgb_cv = _cross_validate(xgb, X_train, y_train, "XGB-final")
    xgb_test = _evaluate(xgb, X_test, y_test, "XGB-final test")
    calibration_summary(xgb, X_test, y_test)
    # --- Persist metrics -------------------------------------------------
    metrics = {
        "baseline": {**lr_test, "cv_roc_auc": lr_cv},
        "xgboost": {**xgb_test, "cv_roc_auc": xgb_cv},
    }
    with open(ARTIFACTS_DIR / METRICS_FILE, "w") as f:
        json.dump(metrics, f, indent=2)
    logger.info("Metrics saved → %s", ARTIFACTS_DIR / METRICS_FILE)
    # --- Drift reference -------------------------------------------------
    save_drift_reference(df)
    logger.info("Training complete. Artifacts → %s", ARTIFACTS_DIR)
    return metrics
def main() -> None:
    """CLI entry point: parse the --data option and launch a training run."""
    parser = argparse.ArgumentParser(description="Train Grid Risk models")
    parser.add_argument("--data", type=str, default=None, help="Path to outage CSV/XLSX")
    cli_args = parser.parse_args()
    run(data_path=cli_args.data)
if __name__ == "__main__":
    # Allow running from project root: python -m src.train
    # NOTE(review): by the time this guard executes, the module-level
    # `from src...` imports above have already succeeded, so this insert
    # cannot rescue a failing import of this file itself — it only affects
    # modules imported later. Consider moving path setup before the imports
    # (or relying on `-m` from the project root) — confirm intent.
    sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
    main()