File size: 4,043 Bytes
992aa4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac96642
 
 
 
 
 
 
992aa4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Inference module — load artifacts and score new observations.

Includes lightweight covariate drift detection that compares incoming
feature distributions against training-time reference statistics.
"""
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import joblib
import numpy as np
import pandas as pd

from src.config import (
    ARTIFACTS_DIR,
    DRIFT_REF_FILE,
    FEATURE_NAMES_FILE,
    MODEL_FINAL_FILE,
    PREPROCESSOR_FILE,
)
from src.features import engineer_features, _resolve_columns

logger = logging.getLogger(__name__)

# Drift thresholds, compared against the z-score of a column's batch mean
# relative to its reference statistics: |batch_mean - ref_mean| / ref_std.
# z >= DRIFT_ALERT_THRESHOLD is reported as "alert", otherwise
# z >= DRIFT_WARN_THRESHOLD is reported as "warn", and anything lower as "ok".
DRIFT_WARN_THRESHOLD = 2.0
DRIFT_ALERT_THRESHOLD = 3.5


class GridRiskPredictor:
    """Predictor wrapping the saved model, preprocessor and drift reference.

    All artifacts are loaded once in ``__init__``; the scoring and drift
    methods only read them afterwards.
    """

    def __init__(self, artifacts_dir: Path = ARTIFACTS_DIR) -> None:
        """
        Load the saved artifacts from *artifacts_dir*.

        Parameters
        ----------
        artifacts_dir : directory holding the files named by
            ``MODEL_FINAL_FILE``, ``PREPROCESSOR_FILE``, ``FEATURE_NAMES_FILE``
            and, optionally, ``DRIFT_REF_FILE``. A plain string path is
            accepted as well.
        """
        # Coerce so that a ``str`` argument also works with the ``/`` joins.
        artifacts_dir = Path(artifacts_dir)

        # NOTE(security): joblib.load unpickles its input, which can execute
        # arbitrary code. Only point this at artifact directories you trust.
        self.model = joblib.load(artifacts_dir / MODEL_FINAL_FILE)
        self.preprocessor = joblib.load(artifacts_dir / PREPROCESSOR_FILE)

        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(artifacts_dir / FEATURE_NAMES_FILE, encoding="utf-8") as f:
            self.feature_names: List[str] = json.load(f)

        # The drift reference is optional; without it check_drift() is a no-op.
        drift_path = artifacts_dir / DRIFT_REF_FILE
        self.drift_ref: Optional[Dict[str, Dict[str, float]]] = (
            joblib.load(drift_path) if drift_path.exists() else None
        )

    def predict(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """
        Score a DataFrame of raw outage records.

        Columns listed in the preprocessor's ``feature_names_in_`` (when that
        attribute exists) but absent after feature engineering are added as
        all-NaN columns before the transform.

        Returns
        -------
        probabilities : np.ndarray  – P(high_impact)
        labels        : np.ndarray  – binary prediction at 0.5 threshold
        """
        df = engineer_features(df)

        expected_cols = getattr(self.preprocessor, "feature_names_in_", [])
        missing = [col for col in expected_cols if col not in df.columns]
        if missing:
            logger.debug("Filling missing feature columns with NaN: %s", missing)
            # Add every filler column in a single concat. This builds a new
            # frame (so the frame returned by engineer_features is never
            # written to in place) and avoids the fragmentation caused by
            # inserting columns one at a time.
            filler = pd.DataFrame(np.nan, index=df.index, columns=missing)
            df = pd.concat([df, filler], axis=1)

        X = self.preprocessor.transform(df)
        probs = self.model.predict_proba(X)[:, 1]
        labels = (probs >= 0.5).astype(int)
        return probs, labels

    def predict_single(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convenience wrapper for a single observation (used by UI).

        Returns
        -------
        dict with keys ``probability`` (float), ``prediction`` (int) and
        ``risk_tier`` (str, as produced by ``_risk_tier``).
        """
        probs, labels = self.predict(pd.DataFrame([record]))
        prob = float(probs[0])
        return {
            "probability": prob,
            "prediction": int(labels[0]),
            "risk_tier": _risk_tier(prob),
        }

    # ------------------------------------------------------------------
    # Drift detection
    # ------------------------------------------------------------------
    def check_drift(self, df: pd.DataFrame) -> Dict[str, str]:
        """
        Compare incoming batch column means against training reference.

        For each reference column present in the engineered batch, the
        z-score ``|batch_mean - ref_mean| / ref_std`` is compared with
        ``DRIFT_WARN_THRESHOLD`` and ``DRIFT_ALERT_THRESHOLD``.

        Columns are left out of the result when no z-score can be computed:
        the column is absent from the batch, the reference std is zero, or
        the z-score is NaN (for example an empty or all-NaN batch column).

        Returns a dict of {feature: status} where status ∈ {ok, warn, alert}.
        An empty dict is returned when no drift reference was loaded.
        """
        if self.drift_ref is None:
            logger.warning("No drift reference found — skipping check.")
            return {}

        df = engineer_features(df)
        results: Dict[str, str] = {}
        for col, ref in self.drift_ref.items():
            if col not in df.columns:
                continue
            col_mean = df[col].dropna().mean()
            ref_mean, ref_std = ref["mean"], ref["std"]
            if ref_std == 0:
                continue
            z = abs(col_mean - ref_mean) / ref_std
            if pd.isna(z):
                # A NaN z-score fails both threshold comparisons below and
                # would otherwise be reported as "ok" despite there being
                # no data to judge.
                logger.debug("Skipping drift check for %r: no usable data.", col)
                continue
            if z >= DRIFT_ALERT_THRESHOLD:
                status = "alert"
            elif z >= DRIFT_WARN_THRESHOLD:
                status = "warn"
            else:
                status = "ok"
            results[col] = status

        drifted = {k: v for k, v in results.items() if v != "ok"}
        if drifted:
            logger.warning("Drift detected: %s", drifted)
        return results


def _risk_tier(prob: float) -> str:
    if prob >= 0.75:
        return "CRITICAL"
    if prob >= 0.50:
        return "HIGH"
    if prob >= 0.25:
        return "MODERATE"
    return "LOW"