# grid-risk-platform / src/train.py
# Author: Nashid-Noor — "Initial commit for HF Spaces without binaries" (992aa4f)
"""
Training pipeline β€” baseline logistic regression + XGBoost final model.
Usage
-----
python -m src.train # default data path
python -m src.train --data path/to/outage.csv # custom path
Outputs (β†’ artifacts/)
logistic_baseline.joblib fitted LR model
xgb_final.joblib fitted XGBoost model
preprocessor.joblib sklearn ColumnTransformer
metrics.json evaluation results
drift_reference.joblib column-level stats for future drift checks
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict
import joblib
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
average_precision_score,
brier_score_loss,
classification_report,
roc_auc_score,
)
from sklearn.model_selection import StratifiedKFold, cross_val_score
from xgboost import XGBClassifier
from src.config import (
ARTIFACTS_DIR,
DRIFT_REF_FILE,
METRICS_FILE,
MODEL_BASELINE_FILE,
MODEL_FINAL_FILE,
RANDOM_STATE,
)
from src.data import get_dataset
from src.features import engineer_features, prepare_splits, _resolve_columns
# Root logging configuration for the whole training run.
# The separator is an em dash — the original source carried the mojibake
# sequence "β€”" here (double-encoded UTF-8), now repaired.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
)
logger = logging.getLogger(__name__)

# Fold count used by the stratified cross-validation helper below.
CV_FOLDS = 5
def _evaluate(
    model: Any,
    X: np.ndarray,
    y: np.ndarray,
    label: str,
) -> Dict[str, float]:
    """Score *model* on (X, y) and return the headline metrics.

    Computes ROC-AUC, PR-AUC (average precision) and the Brier score from
    the positive-class probabilities, logs them under *label*, and returns
    them rounded to four decimal places.
    """
    probs = model.predict_proba(X)[:, 1]
    scores = {
        "roc_auc": roc_auc_score(y, probs),
        "pr_auc": average_precision_score(y, probs),
        "brier": brier_score_loss(y, probs),
    }
    logger.info(
        "[%s] ROC-AUC=%.4f PR-AUC=%.4f Brier=%.4f",
        label,
        scores["roc_auc"],
        scores["pr_auc"],
        scores["brier"],
    )
    return {name: round(value, 4) for name, value in scores.items()}
def _cross_validate(model: Any, X: np.ndarray, y: np.ndarray, label: str) -> float:
    """Stratified k-fold cross-validation on ROC-AUC.

    Parameters
    ----------
    model : estimator implementing the sklearn fit/predict_proba API.
    X, y : training features and binary labels.
    label : tag prefixed to the log line.

    Returns
    -------
    Mean ROC-AUC across ``CV_FOLDS`` folds, rounded to 4 decimals.
    """
    cv = StratifiedKFold(n_splits=CV_FOLDS, shuffle=True, random_state=RANDOM_STATE)
    scores = cross_val_score(model, X, y, cv=cv, scoring="roc_auc", n_jobs=-1)
    mean, std = scores.mean(), scores.std()
    # "±" repaired from the mojibake sequence "Β±" in the original source.
    logger.info("[%s CV] ROC-AUC = %.4f ± %.4f", label, mean, std)
    return round(mean, 4)
def train_baseline(X_train: np.ndarray, y_train: np.ndarray) -> LogisticRegression:
    """Fit, persist, and return a class-balanced logistic-regression baseline.

    The fitted model is dumped to ``ARTIFACTS_DIR / MODEL_BASELINE_FILE``.
    """
    baseline = LogisticRegression(
        solver="lbfgs",
        max_iter=1000,
        class_weight="balanced",
        random_state=RANDOM_STATE,
    )
    baseline.fit(X_train, y_train)
    joblib.dump(baseline, ARTIFACTS_DIR / MODEL_BASELINE_FILE)
    return baseline
def train_xgb(X_train: np.ndarray, y_train: np.ndarray) -> XGBClassifier:
    """Fit, persist, and return the final XGBoost classifier.

    Class imbalance is handled via ``scale_pos_weight`` (negative/positive
    ratio computed from the training labels).

    Parameters
    ----------
    X_train, y_train : preprocessed feature matrix and binary 0/1 labels.

    Returns
    -------
    The fitted model (also dumped to ``ARTIFACTS_DIR / MODEL_FINAL_FILE``).
    """
    # minlength=2 keeps the unpack safe when one class is absent from
    # y_train (np.bincount would otherwise return a length-1 array);
    # the int cast tolerates float-coded 0/1 labels.
    neg, pos = np.bincount(np.asarray(y_train).astype(int), minlength=2)
    scale = neg / max(pos, 1)  # guard against division by zero
    xgb = XGBClassifier(
        n_estimators=400,
        max_depth=5,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        scale_pos_weight=scale,
        eval_metric="aucpr",
        # use_label_encoder was deprecated in XGBoost 1.3 and removed in
        # 2.0; passing it on modern versions only produces warnings.
        random_state=RANDOM_STATE,
        n_jobs=-1,
    )
    xgb.fit(X_train, y_train)
    joblib.dump(xgb, ARTIFACTS_DIR / MODEL_FINAL_FILE)
    return xgb
def save_drift_reference(df_raw, path: Path = ARTIFACTS_DIR / DRIFT_REF_FILE) -> None:
    """Persist column-level statistics from training data so we can detect
    covariate drift at inference time.

    For every numeric feature produced by ``engineer_features`` we store
    mean/std/min/max in a joblib-serialized dict.

    Parameters
    ----------
    df_raw : raw training dataframe (before feature engineering).
    path : destination file; defaults to the configured artifacts location.
    """
    df = engineer_features(df_raw)
    num_cols, _ = _resolve_columns(df)
    stats: Dict[str, Dict[str, float]] = {}
    for col in num_cols:
        if col not in df.columns:
            continue
        s = df[col].dropna()
        if s.empty:
            # All-NaN column: mean/std/min/max would all be NaN, which
            # would poison later drift comparisons — skip it instead.
            continue
        stats[col] = {
            "mean": float(s.mean()),
            "std": float(s.std()),
            "min": float(s.min()),
            "max": float(s.max()),
        }
    joblib.dump(stats, path)
    # Em dash repaired from the mojibake sequence "β€”" in the original.
    logger.info("Drift reference saved — %d features tracked", len(stats))
def calibration_summary(model: Any, X: np.ndarray, y: np.ndarray) -> None:
    """Log calibration quality — in production this feeds a monitoring dashboard.

    Builds an 8-bin quantile reliability curve from the positive-class
    probabilities and reports the largest |observed − predicted| gap.
    Log-only; returns nothing.
    """
    y_prob = model.predict_proba(X)[:, 1]
    fraction_pos, mean_pred = calibration_curve(y, y_prob, n_bins=8, strategy="quantile")
    diffs = np.abs(fraction_pos - mean_pred)
    max_gap = diffs.max()
    # Em dash repaired from the mojibake sequence "β€”" in the original.
    logger.info(
        "Calibration check — max bin gap: %.3f %s",
        max_gap,
        "(acceptable)" if max_gap < 0.10 else "(consider Platt scaling)",
    )
def run(data_path: str | None = None) -> Dict[str, Any]:
    """Execute the end-to-end training pipeline and return the metrics dict.

    Steps
    -----
    1. Load the dataset and build train/test splits + preprocessor.
    2. Train, cross-validate, and evaluate the logistic-regression baseline.
    3. Train, cross-validate, and evaluate the final XGBoost model,
       then log its calibration quality.
    4. Persist metrics (JSON) and the drift reference.

    Parameters
    ----------
    data_path : optional path to the raw outage CSV/XLSX; ``None`` means
        the default location resolved by ``get_dataset``.

    Log-string mojibake ("β€”", "β†’") from the original source is repaired
    to "—" and "→" below.
    """
    logger.info("=" * 60)
    logger.info("GRID RISK PLATFORM — Training pipeline")
    logger.info("=" * 60)
    df = get_dataset(data_path)
    # preprocessor / feature_names are unused here; presumably persisted
    # inside prepare_splits — TODO confirm against src.features.
    X_train, X_test, y_train, y_test, preprocessor, feature_names = prepare_splits(df)
    # --- Baseline --------------------------------------------------------
    lr = train_baseline(X_train, y_train)
    lr_cv = _cross_validate(lr, X_train, y_train, "LR-baseline")
    lr_test = _evaluate(lr, X_test, y_test, "LR-baseline test")
    # --- XGBoost ---------------------------------------------------------
    xgb = train_xgb(X_train, y_train)
    xgb_cv = _cross_validate(xgb, X_train, y_train, "XGB-final")
    xgb_test = _evaluate(xgb, X_test, y_test, "XGB-final test")
    calibration_summary(xgb, X_test, y_test)
    # --- Persist metrics -------------------------------------------------
    metrics = {
        "baseline": {**lr_test, "cv_roc_auc": lr_cv},
        "xgboost": {**xgb_test, "cv_roc_auc": xgb_cv},
    }
    metrics_path = ARTIFACTS_DIR / METRICS_FILE
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    logger.info("Metrics saved → %s", metrics_path)
    # --- Drift reference -------------------------------------------------
    save_drift_reference(df)
    logger.info("Training complete. Artifacts → %s", ARTIFACTS_DIR)
    return metrics
def main() -> None:
    """CLI entry point: parse the ``--data`` option and launch training."""
    cli = argparse.ArgumentParser(description="Train Grid Risk models")
    cli.add_argument("--data", type=str, default=None, help="Path to outage CSV/XLSX")
    parsed = cli.parse_args()
    run(data_path=parsed.data)
if __name__ == "__main__":
# Allow running from project root: python -m src.train
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
main()