| """Data drift detection module.""" |
| import numpy as np |
| from scipy import stats |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Union |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class DriftResult:
    """Outcome of a single drift check between a reference and a current sample.

    Attributes:
        drift_score: Score reported by the detection method that produced
            this result.
        p_value: P-value of the underlying test, or ``None`` when the method
            does not produce one.
        is_drift: Whether the method classified the current sample as
            drifted.
        method: Name of the detection method that produced the result.
        threshold: Cut-off the method compared against to set ``is_drift``.
        feature_name: Name of the analysed feature, if the caller gave one.
        statistic: Raw statistic computed by the method, if any.
        sample_sizes: Number of observations per sample, if recorded.
    """

    # Required fields: every detection method fills these in.
    drift_score: float
    p_value: Optional[float]
    is_drift: bool
    method: str
    threshold: float

    # Optional context; defaults to None when not supplied.
    feature_name: Optional[str] = None
    statistic: Optional[float] = None
    sample_sizes: Optional[Dict[str, int]] = None
| |
|
|
class DataDriftDetector:
    """Detect data drift between a reference sample and a current sample.

    Three univariate methods are available:

    * ``"ks"`` -- two-sample Kolmogorov-Smirnov test; drift is flagged when
      the p-value falls below ``threshold``.
    * ``"psi"`` -- Population Stability Index over bins derived from the
      reference sample; drift is flagged when the PSI exceeds
      ``PSI_DRIFT_THRESHOLD``.
    * ``"wasserstein"`` -- Wasserstein distance; drift is flagged when the
      distance exceeds ``WASSERSTEIN_DRIFT_THRESHOLD``.

    The PSI and Wasserstein cut-offs are class attributes so that a subclass
    can tune them without touching the detection code.
    """

    # Method names accepted by ``__init__``.
    SUPPORTED_METHODS = ("ks", "psi", "wasserstein")

    # PSI: drift when the index exceeds the threshold; ``drift_score`` is the
    # index divided by the scale, capped at 1.0.
    PSI_DRIFT_THRESHOLD = 0.25
    PSI_SCORE_SCALE = 0.5
    # Additive smoothing that keeps empty bins out of log(0) / division by 0.
    PSI_EPSILON = 1e-10

    # Wasserstein: drift when the distance exceeds the threshold;
    # ``drift_score`` is the distance divided by the scale, capped at 1.0.
    WASSERSTEIN_DRIFT_THRESHOLD = 1.0
    WASSERSTEIN_SCORE_SCALE = 10.0

    def __init__(self, method: str = "ks", threshold: float = 0.05):
        """
        Initialize drift detector.

        Args:
            method: Detection method ('ks' for Kolmogorov-Smirnov,
                'psi' for Population Stability Index,
                'wasserstein' for Wasserstein distance)
            threshold: Significance threshold compared against the p-value
                of the KS test. The 'psi' and 'wasserstein' methods do not
                use it; they compare against the class-level cut-offs.

        Raises:
            ValueError: If ``method`` is not one of the supported methods.
        """
        self.method = method
        self.threshold = threshold
        self.supported_methods = list(self.SUPPORTED_METHODS)

        if method not in self.supported_methods:
            raise ValueError(f"Method {method} not supported. "
                             f"Choose from {self.supported_methods}")

    def detect_drift(
        self,
        reference: Union[np.ndarray, List[float]],
        current: Union[np.ndarray, List[float]],
        feature_name: Optional[str] = None
    ) -> "DriftResult":
        """
        Detect drift between reference and current data.

        Both inputs are flattened to one dimension before being compared.

        Args:
            reference: Reference distribution data
            current: Current distribution data
            feature_name: Optional name of the feature being analyzed

        Returns:
            DriftResult containing detection results

        Raises:
            ValueError: If either input is empty, or if ``self.method`` was
                changed to an unsupported value after construction.
        """
        ref_array = np.asarray(reference).ravel()
        curr_array = np.asarray(current).ravel()

        # Fail uniformly for every method: PSI would otherwise return a
        # meaningless number for an empty sample instead of an error.
        if ref_array.size == 0 or curr_array.size == 0:
            raise ValueError(
                "Cannot detect drift on empty data "
                f"(reference size={ref_array.size}, "
                f"current size={curr_array.size})"
            )

        sample_sizes = {
            "reference": int(ref_array.size),
            "current": int(curr_array.size),
        }

        handlers = {
            "ks": self._ks_drift,
            "psi": self._psi_drift,
            "wasserstein": self._wasserstein_drift,
        }
        handler = handlers.get(self.method)
        if handler is None:
            # Only reachable when ``self.method`` was reassigned after
            # __init__; raise rather than silently returning None.
            raise ValueError(f"Method {self.method} not supported. "
                             f"Choose from {self.supported_methods}")

        return handler(ref_array, curr_array, feature_name, sample_sizes)

    def _ks_drift(
        self,
        reference: np.ndarray,
        current: np.ndarray,
        feature_name: Optional[str],
        sample_sizes: Dict[str, int]
    ) -> "DriftResult":
        """Kolmogorov-Smirnov test for drift detection.

        ``drift_score`` is ``1 - p_value``; drift is flagged when the p-value
        is below ``self.threshold``.
        """
        stat, p_value = stats.ks_2samp(reference, current)
        # Convert NumPy scalars to builtins so the result matches the
        # DriftResult field annotations.
        stat = float(stat)
        p_value = float(p_value)
        drift_score = 1.0 - p_value
        is_drift = bool(p_value < self.threshold)

        logger.debug(
            "KS test: statistic=%.4f, p=%.4f, drift_score=%.4f, drift=%s",
            stat, p_value, drift_score, is_drift,
        )

        return DriftResult(
            drift_score=drift_score,
            p_value=p_value,
            is_drift=is_drift,
            method="kolmogorov_smirnov",
            threshold=self.threshold,
            feature_name=feature_name,
            statistic=stat,
            sample_sizes=sample_sizes
        )

    def _psi_statistic(self, reference: np.ndarray, current: np.ndarray) -> float:
        """Compute the Population Stability Index of ``current`` vs ``reference``.

        Bin edges come from the reference sample (between 2 and 10 equal-width
        bins, roughly one per 20 reference observations). The two outermost
        edges are then opened to -inf / +inf so that current values falling
        outside the reference range are counted in the edge bins instead of
        being dropped by ``np.histogram``.
        """
        n_bins = max(2, min(10, len(reference) // 20))
        bins = np.histogram_bin_edges(reference, bins=n_bins).astype(float)
        # Reference counts are unaffected: every reference value already lies
        # inside the original outer edges.
        bins[0] = -np.inf
        bins[-1] = np.inf

        ref_hist, _ = np.histogram(reference, bins=bins)
        curr_hist, _ = np.histogram(current, bins=bins)

        epsilon = self.PSI_EPSILON

        # Smoothed bin proportions; epsilon keeps empty bins strictly positive.
        ref_prop = (ref_hist + epsilon) / (len(reference) + epsilon * n_bins)
        curr_prop = (curr_hist + epsilon) / (len(current) + epsilon * n_bins)

        psi_values = (curr_prop - ref_prop) * np.log(
            (curr_prop + epsilon) / (ref_prop + epsilon)
        )

        return float(np.sum(psi_values))

    def _psi_drift(
        self,
        reference: np.ndarray,
        current: np.ndarray,
        feature_name: Optional[str],
        sample_sizes: Dict[str, int]
    ) -> "DriftResult":
        """Population Stability Index for drift detection.

        ``drift_score`` is the PSI divided by ``PSI_SCORE_SCALE`` and capped at
        1.0; drift is flagged when the PSI exceeds ``PSI_DRIFT_THRESHOLD``.
        ``self.threshold`` is not used by this method.
        """
        psi_total = self._psi_statistic(reference, current)

        drift_score = min(1.0, psi_total / self.PSI_SCORE_SCALE)
        is_drift = bool(psi_total > self.PSI_DRIFT_THRESHOLD)

        return DriftResult(
            drift_score=drift_score,
            p_value=None,
            is_drift=is_drift,
            method="population_stability_index",
            threshold=self.PSI_DRIFT_THRESHOLD,
            feature_name=feature_name,
            statistic=psi_total,
            sample_sizes=sample_sizes
        )

    def _wasserstein_drift(
        self,
        reference: np.ndarray,
        current: np.ndarray,
        feature_name: Optional[str],
        sample_sizes: Dict[str, int]
    ) -> "DriftResult":
        """Wasserstein distance for drift detection.

        ``drift_score`` is the distance divided by ``WASSERSTEIN_SCORE_SCALE``
        and capped at 1.0; drift is flagged when the distance exceeds
        ``WASSERSTEIN_DRIFT_THRESHOLD``. The distance is in the units of the
        data, and ``self.threshold`` is not used by this method.
        """
        distance = float(stats.wasserstein_distance(reference, current))

        drift_score = min(1.0, distance / self.WASSERSTEIN_SCORE_SCALE)
        is_drift = bool(distance > self.WASSERSTEIN_DRIFT_THRESHOLD)

        return DriftResult(
            drift_score=drift_score,
            p_value=None,
            is_drift=is_drift,
            method="wasserstein_distance",
            threshold=self.WASSERSTEIN_DRIFT_THRESHOLD,
            feature_name=feature_name,
            statistic=distance,
            sample_sizes=sample_sizes
        )

    def detect_batch_drift(
        self,
        reference_data: Dict[str, np.ndarray],
        current_data: Dict[str, np.ndarray]
    ) -> Dict[str, "DriftResult"]:
        """
        Detect drift for multiple features.

        Only features present in both dictionaries are analyzed. A reference
        feature missing from ``current_data`` is skipped with a warning;
        features that appear only in ``current_data`` are ignored.

        Args:
            reference_data: Dictionary of reference features
            current_data: Dictionary of current features

        Returns:
            Dictionary mapping feature names to drift results
        """
        results = {}

        for feature_name, reference_values in reference_data.items():
            if feature_name not in current_data:
                logger.warning(
                    "Feature %r missing from current data; skipping drift check",
                    feature_name,
                )
                continue
            results[feature_name] = self.detect_drift(
                reference_values,
                current_data[feature_name],
                feature_name,
            )

        return results
|
|
|
|
| |
def detect_drift(reference, current, threshold=0.05):
    """Legacy helper: run a Kolmogorov-Smirnov drift check and return a dict.

    Builds a ``DataDriftDetector`` with ``method="ks"`` and the given
    ``threshold``, runs it on the two samples, and exposes the result under
    the keys ``drift_score``, ``drift_detected``, ``p_value`` and
    ``statistic``.
    """
    outcome = DataDriftDetector(method="ks", threshold=threshold).detect_drift(
        reference, current
    )

    # (output key, DriftResult attribute) pairs, in output order.
    exported = (
        ("drift_score", "drift_score"),
        ("drift_detected", "is_drift"),
        ("p_value", "p_value"),
        ("statistic", "statistic"),
    )
    return {key: getattr(outcome, attr) for key, attr in exported}
|
|
|