File size: 5,675 Bytes
1aa566a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Drift detection using PSI and Kolmogorov-Smirnov test.

PSI interpretation:
  PSI < 0.1  -> no significant drift
  PSI < 0.2  -> moderate drift
  PSI >= 0.2 -> significant drift, action required
"""
from __future__ import annotations

import json
import time
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from scipy import stats

from src.utils.config import settings, resolve
from src.utils.logging_config import get_logger

log = get_logger(__name__)

_N_BINS = 10
_PSI_THRESHOLD = settings.monitoring.drift.psi_threshold
_KS_PVALUE = settings.monitoring.drift.ks_pvalue_threshold


class DriftDetector:
    """Detect feature and performance drift between reference and live distributions.

    Feature drift is flagged per feature by the Population Stability Index
    (PSI) and the two-sample Kolmogorov-Smirnov test; performance drift by
    relative RMSE degradation. Every report is appended as one JSON line to
    the configured drift-report path.
    """

    def __init__(self, reference_df: Optional[pd.DataFrame] = None) -> None:
        """Create a detector, optionally seeding the reference distribution.

        Args:
            reference_df: If given, passed to :meth:`set_reference`.
        """
        self._reference: Optional[pd.DataFrame] = None
        self._report_path = resolve(settings.monitoring.drift_report_path)

        if reference_df is not None:
            self.set_reference(reference_df)

    def set_reference(self, df: pd.DataFrame) -> None:
        """Set the reference (training) distribution.

        A defensive copy is stored so later caller-side mutation of *df*
        cannot silently shift the drift baseline.
        """
        self._reference = df.copy()
        log.info("Reference distribution set (%d samples, %d features)", *df.shape)

    def has_reference(self) -> bool:
        """Return True once a reference distribution has been set."""
        return self._reference is not None

    def detect_feature_drift(
        self,
        live_df: pd.DataFrame,
        features: Optional[list[str]] = None,
    ) -> dict:
        """Compute PSI and KS test for each feature.

        Args:
            live_df: Live/production samples to compare against the reference.
            features: Columns to check; defaults to every reference column
                except "timestamp". Columns missing from either frame are
                silently skipped.

        Returns:
            Report dict with drift_detected, feature_results,
            drifted_features, n_live_samples, and timestamp.

        Raises:
            RuntimeError: If no reference distribution has been set.
        """
        if self._reference is None:
            raise RuntimeError("Call set_reference() before detect_feature_drift().")

        features = features or [c for c in self._reference.columns if c != "timestamp"]
        results: dict = {}

        for feat in features:
            if feat not in self._reference.columns or feat not in live_df.columns:
                continue
            ref_vals = self._reference[feat].dropna().to_numpy(dtype=float)
            live_vals = live_df[feat].dropna().to_numpy(dtype=float)
            # Guard: _psi divides by len() and ks_2samp rejects empty samples,
            # so an all-NaN column would otherwise crash the whole report.
            if ref_vals.size == 0 or live_vals.size == 0:
                log.warning("Skipping feature %r: no usable values after dropna()", feat)
                continue
            psi = self._psi(ref_vals, live_vals)
            ks_stat, ks_pvalue = stats.ks_2samp(ref_vals, live_vals)
            # Either signal alone is enough to flag the feature as drifted.
            drifted = bool((psi >= _PSI_THRESHOLD) or (ks_pvalue < _KS_PVALUE))

            results[feat] = {
                "psi": round(float(psi), 4),
                "ks_stat": round(float(ks_stat), 4),
                "ks_pvalue": round(float(ks_pvalue), 4),
                "drifted": drifted,
            }

        drifted_features = [f for f, v in results.items() if v["drifted"]]
        drift_detected = len(drifted_features) > 0

        report = {
            "drift_detected": drift_detected,
            "feature_results": results,
            "drifted_features": drifted_features,
            "n_live_samples": len(live_df),
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }

        self._save_report(report, "feature")

        if drift_detected:
            log.warning("FEATURE DRIFT detected! Drifted features: %s", drifted_features)
        else:
            log.info("No significant feature drift detected.")

        return report

    def detect_performance_drift(
        self,
        recent_rmse: float,
        baseline_rmse: float,
    ) -> dict:
        """Check if recent RMSE exceeds the baseline by the configured threshold.

        Args:
            recent_rmse: RMSE on recent/live predictions.
            baseline_rmse: RMSE at training/validation time.

        Returns:
            Report dict with drift_detected, both RMSE values, pct_change,
            threshold_pct, and timestamp.
        """
        threshold = settings.monitoring.performance.degradation_threshold
        # max(..., 1e-9) guards against division by a zero baseline.
        pct_change = (recent_rmse - baseline_rmse) / max(baseline_rmse, 1e-9)
        degraded = pct_change > threshold

        report = {
            "drift_detected": degraded,
            "recent_rmse": round(recent_rmse, 4),
            "baseline_rmse": round(baseline_rmse, 4),
            "pct_change": round(pct_change * 100, 2),
            "threshold_pct": threshold * 100,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }

        self._save_report(report, "performance")

        if degraded:
            log.warning(
                "PERFORMANCE DRIFT detected! RMSE increased by %.1f%% "
                "(%.4f -> %.4f, threshold=%.0f%%)",
                pct_change * 100, baseline_rmse, recent_rmse, threshold * 100,
            )
        else:
            # Consistent with detect_feature_drift: log the healthy case too.
            log.info("No significant performance drift detected.")
        return report

    @staticmethod
    def _psi(reference: np.ndarray, current: np.ndarray, n_bins: int = _N_BINS) -> float:
        """Population Stability Index between two 1-D samples.

        Bins are the deciles (for n_bins=10) of the *reference* sample,
        widened by epsilon at both ends so boundary values are counted.
        Empty bins are floored at 1e-4 to keep log() finite.
        """
        bins = np.percentile(reference, np.linspace(0, 100, n_bins + 1))
        bins[0] -= 1e-9
        bins[-1] += 1e-9

        ref_counts = np.histogram(reference, bins=bins)[0]
        cur_counts = np.histogram(current, bins=bins)[0]

        ref_pct = np.where(ref_counts == 0, 1e-4, ref_counts / len(reference))
        cur_pct = np.where(cur_counts == 0, 1e-4, cur_counts / len(current))

        psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
        return float(psi)

    def _save_report(self, report: dict, report_type: str) -> None:
        """Append *report* (tagged with report_type) as one JSON line.

        Note: mutates *report* in place — callers rely on the returned dict
        containing the "report_type" key.
        """
        report["report_type"] = report_type
        # Fresh deployments may not have the report directory yet; without
        # this, the first append raises FileNotFoundError. Path() wrap keeps
        # this safe whether resolve() returns a str or a Path.
        Path(self._report_path).parent.mkdir(parents=True, exist_ok=True)
        with open(self._report_path, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(report, default=_json_default) + "\n")


def _json_default(obj):
    import numpy as np
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serialisable")