Spaces:
Sleeping
Sleeping
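"""Drift detection using PSI and the Kolmogorov-Smirnov test.

PSI interpretation:
    PSI < 0.1         -> no significant drift
    0.1 <= PSI < 0.2  -> moderate drift
    PSI >= 0.2        -> significant drift, action required

The PSI and KS p-value cut-offs actually applied are read from
``settings.monitoring.drift`` (``psi_threshold`` and ``ks_pvalue_threshold``).
"""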
| from __future__ import annotations | |
| import json | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pandas as pd | |
| from scipy import stats | |
| from src.utils.config import settings, resolve | |
| from src.utils.logging_config import get_logger | |
log = get_logger(__name__)

# PSI is computed over this many reference-quantile bins.
_N_BINS = 10

# Drift cut-offs, read once at import time from the monitoring.drift config:
# a feature drifts when PSI >= _PSI_THRESHOLD or the KS p-value < _KS_PVALUE.
_PSI_THRESHOLD, _KS_PVALUE = (
    settings.monitoring.drift.psi_threshold,
    settings.monitoring.drift.ks_pvalue_threshold,
)
class DriftDetector:
    """Detect feature and performance drift between reference and live distributions.

    Feature drift is flagged per feature when the Population Stability Index
    reaches ``_PSI_THRESHOLD`` or the two-sample Kolmogorov-Smirnov p-value
    falls below ``_KS_PVALUE``. Performance drift is flagged when the relative
    RMSE increase exceeds
    ``settings.monitoring.performance.degradation_threshold``.

    Every report produced is appended as one JSON line to the file resolved
    from ``settings.monitoring.drift_report_path``.
    """

    def __init__(self, reference_df: Optional[pd.DataFrame] = None) -> None:
        """Create a detector, optionally seeding it with a reference distribution.

        Args:
            reference_df: Reference (training) data. If omitted, call
                :meth:`set_reference` before :meth:`detect_feature_drift`.
        """
        self._reference: Optional[pd.DataFrame] = None
        self._report_path = resolve(settings.monitoring.drift_report_path)
        if reference_df is not None:
            self.set_reference(reference_df)

    def set_reference(self, df: pd.DataFrame) -> None:
        """Set the reference (training) distribution.

        A copy is stored so later mutation of *df* by the caller does not
        change the baseline.
        """
        self._reference = df.copy()
        log.info("Reference distribution set (%d samples, %d features)", *df.shape)

    def has_reference(self) -> bool:
        """Return True if a reference distribution has been set."""
        return self._reference is not None

    def detect_feature_drift(
        self,
        live_df: pd.DataFrame,
        features: Optional[list[str]] = None,
    ) -> dict:
        """Compute PSI and a two-sample KS test for each feature.

        Args:
            live_df: Live (recent) observations.
            features: Columns to check. If falsy, every reference column except
                ``"timestamp"`` is checked. Columns missing from either frame
                are skipped, as are columns with no non-null values on either
                side (both statistics are undefined on an empty sample).

        Returns:
            Report dict with ``drift_detected``, ``feature_results`` (per
            feature: ``psi``, ``ks_stat``, ``ks_pvalue``, ``drifted``),
            ``drifted_features``, ``n_live_samples`` (row count of *live_df*,
            NaNs included), ``timestamp`` and ``report_type`` (``"feature"``).
            The report is also appended to the drift report file.

        Raises:
            RuntimeError: If no reference distribution has been set.
        """
        if self._reference is None:
            raise RuntimeError("Call set_reference() before detect_feature_drift().")
        features = features or [c for c in self._reference.columns if c != "timestamp"]
        results: dict = {}
        for feat in features:
            if feat not in self._reference.columns or feat not in live_df.columns:
                continue
            ref_vals = self._reference[feat].dropna().to_numpy(dtype=float)
            live_vals = live_df[feat].dropna().to_numpy(dtype=float)
            if ref_vals.size == 0 or live_vals.size == 0:
                # ks_2samp / percentile raise on empty input; skip this feature
                # instead of aborting the report for all the others.
                log.warning(
                    "Skipping drift check for %r: no non-null values "
                    "(reference=%d, live=%d)",
                    feat, ref_vals.size, live_vals.size,
                )
                continue
            psi = self._psi(ref_vals, live_vals)
            ks_stat, ks_pvalue = stats.ks_2samp(ref_vals, live_vals)
            drifted = bool((psi >= _PSI_THRESHOLD) or (ks_pvalue < _KS_PVALUE))
            results[feat] = {
                "psi": round(float(psi), 4),
                "ks_stat": round(float(ks_stat), 4),
                "ks_pvalue": round(float(ks_pvalue), 4),
                "drifted": drifted,
            }
        drifted_features = [f for f, v in results.items() if v["drifted"]]
        drift_detected = len(drifted_features) > 0
        report = {
            "drift_detected": drift_detected,
            "feature_results": results,
            "drifted_features": drifted_features,
            "n_live_samples": len(live_df),
            "timestamp": self._utc_timestamp(),
        }
        self._save_report(report, "feature")
        if drift_detected:
            log.warning("FEATURE DRIFT detected! Drifted features: %s", drifted_features)
        else:
            log.info("No significant feature drift detected.")
        return report

    def detect_performance_drift(
        self,
        recent_rmse: float,
        baseline_rmse: float,
    ) -> dict:
        """Check if recent RMSE exceeds the baseline by the configured threshold.

        Degradation is the relative increase
        ``(recent_rmse - baseline_rmse) / baseline_rmse`` compared against
        ``settings.monitoring.performance.degradation_threshold`` (a fraction,
        e.g. 0.1 means 10%).

        Returns:
            Report dict with ``drift_detected``, ``recent_rmse``,
            ``baseline_rmse``, ``pct_change`` (percent), ``threshold_pct``,
            ``timestamp`` and ``report_type`` (``"performance"``). The report
            is also appended to the drift report file.
        """
        threshold = settings.monitoring.performance.degradation_threshold
        # Floor the denominator so a zero baseline does not divide by zero.
        pct_change = (recent_rmse - baseline_rmse) / max(baseline_rmse, 1e-9)
        # bool() so NumPy inputs don't leak np.bool_ into the returned report.
        degraded = bool(pct_change > threshold)
        report = {
            "drift_detected": degraded,
            "recent_rmse": round(float(recent_rmse), 4),
            "baseline_rmse": round(float(baseline_rmse), 4),
            "pct_change": round(float(pct_change) * 100, 2),
            "threshold_pct": threshold * 100,
            "timestamp": self._utc_timestamp(),
        }
        self._save_report(report, "performance")
        if degraded:
            log.warning(
                "PERFORMANCE DRIFT detected! RMSE increased by %.1f%% "
                "(%.4f -> %.4f, threshold=%.0f%%)",
                pct_change * 100, baseline_rmse, recent_rmse, threshold * 100,
            )
        return report

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``."""
        return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    @staticmethod
    def _psi(reference: np.ndarray, current: np.ndarray, n_bins: int = _N_BINS) -> float:
        """Population Stability Index of *current* relative to *reference*.

        Bins are delimited by the reference quantiles. The two outermost bins
        are open-ended, so live values below/above the reference range are
        still counted and the live proportions sum to 1. Empty bins are floored
        at a proportion of 1e-4 to keep the log term finite.

        Raises:
            ValueError: If either sample is empty or ``n_bins < 1``.
        """
        reference = np.asarray(reference, dtype=float)
        current = np.asarray(current, dtype=float)
        if reference.size == 0 or current.size == 0:
            raise ValueError("PSI requires non-empty reference and current samples")
        if n_bins < 1:
            raise ValueError("n_bins must be >= 1")
        # Interior cut points only; the outer edges are implicitly -inf / +inf.
        cuts = np.percentile(reference, np.linspace(0, 100, n_bins + 1))[1:-1]
        # side="right" places a value equal to a cut point in the upper bin,
        # i.e. half-open [lo, hi) bins (the same convention as np.histogram).
        ref_counts = np.bincount(np.searchsorted(cuts, reference, side="right"), minlength=n_bins)
        cur_counts = np.bincount(np.searchsorted(cuts, current, side="right"), minlength=n_bins)
        ref_pct = np.where(ref_counts == 0, 1e-4, ref_counts / reference.size)
        cur_pct = np.where(cur_counts == 0, 1e-4, cur_counts / current.size)
        psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
        return float(psi)

    def _save_report(self, report: dict, report_type: str) -> None:
        """Tag *report* with ``report_type`` (in place) and append it as one JSON line."""
        report["report_type"] = report_type
        path = Path(self._report_path)
        # Create the report directory on first use so a fresh deployment does
        # not fail after the drift statistics have already been computed.
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("a", encoding="utf-8") as fh:
            fh.write(json.dumps(report, default=_json_default) + "\n")
def _json_default(obj):
    """Fallback ``default=`` hook for ``json.dumps`` that handles NumPy values.

    Converts NumPy booleans, integers and floats to the matching Python
    scalar, and NumPy arrays to (nested) lists.

    Raises:
        TypeError: For any other type, as ``json.dumps`` expects from a
            ``default`` hook.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serialisable")