Ariyan-Pro's picture
🚀 Initial upload: Self-Healing ML Pipelines
11898c7 verified
"""Data drift detection module."""
import numpy as np
from scipy import stats
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import logging
logger = logging.getLogger(__name__)
@dataclass
class DriftResult:
"""Data class for drift detection results."""
drift_score: float
p_value: Optional[float]
is_drift: bool
method: str
threshold: float
feature_name: Optional[str] = None
statistic: Optional[float] = None
sample_sizes: Optional[Dict[str, int]] = None
class DataDriftDetector:
"""Detects data drift between reference and current distributions."""
def __init__(self, method: str = "ks", threshold: float = 0.05):
"""
Initialize drift detector.
Args:
method: Detection method ('ks' for Kolmogorov-Smirnov,
'psi' for Population Stability Index)
threshold: Significance threshold
"""
self.method = method
self.threshold = threshold
self.supported_methods = ["ks", "psi", "wasserstein"]
if method not in self.supported_methods:
raise ValueError(f"Method {method} not supported. "
f"Choose from {self.supported_methods}")
def detect_drift(
self,
reference: Union[np.ndarray, List[float]],
current: Union[np.ndarray, List[float]],
feature_name: Optional[str] = None
) -> DriftResult:
"""
Detect drift between reference and current data.
Args:
reference: Reference distribution data
current: Current distribution data
feature_name: Optional name of the feature being analyzed
Returns:
DriftResult containing detection results
"""
# Convert to numpy arrays
ref_array = np.array(reference).flatten()
curr_array = np.array(current).flatten()
# Record sample sizes
sample_sizes = {
"reference": len(ref_array),
"current": len(curr_array)
}
if self.method == "ks":
return self._ks_drift(ref_array, curr_array, feature_name, sample_sizes)
elif self.method == "psi":
return self._psi_drift(ref_array, curr_array, feature_name, sample_sizes)
elif self.method == "wasserstein":
return self._wasserstein_drift(ref_array, curr_array, feature_name, sample_sizes)
def _ks_drift(
self,
reference: np.ndarray,
current: np.ndarray,
feature_name: Optional[str],
sample_sizes: Dict[str, int]
) -> DriftResult:
"""Kolmogorov-Smirnov test for drift detection."""
stat, p_value = stats.ks_2samp(reference, current)
drift_score = 1 - p_value
is_drift = p_value < self.threshold
logger.debug(
f"KS test: statistic={stat:.4f}, p={p_value:.4f}, "
f"drift_score={drift_score:.4f}, drift={is_drift}"
)
return DriftResult(
drift_score=drift_score,
p_value=p_value,
is_drift=is_drift,
method="kolmogorov_smirnov",
threshold=self.threshold,
feature_name=feature_name,
statistic=stat,
sample_sizes=sample_sizes
)
def _psi_drift(
self,
reference: np.ndarray,
current: np.ndarray,
feature_name: Optional[str],
sample_sizes: Dict[str, int]
) -> DriftResult:
"""Population Stability Index for drift detection."""
# Create bins based on reference data (ensure at least 2 bins)
n_bins = max(2, min(10, len(reference) // 20))
bins = np.histogram_bin_edges(reference, bins=n_bins)
# Calculate histograms
ref_hist, _ = np.histogram(reference, bins=bins)
curr_hist, _ = np.histogram(current, bins=bins)
# Add small epsilon to avoid division by zero
epsilon = 1e-10
# Normalize to get proportions
ref_prop = (ref_hist + epsilon) / (len(reference) + epsilon * n_bins)
curr_prop = (curr_hist + epsilon) / (len(current) + epsilon * n_bins)
# Calculate PSI
psi_values = (curr_prop - ref_prop) * np.log((curr_prop + epsilon) / (ref_prop + epsilon))
psi_total = np.sum(psi_values)
# Convert PSI to drift score (0-1)
drift_score = min(1.0, psi_total / 0.5) # PSI > 0.25 indicates drift
is_drift = psi_total > 0.25
return DriftResult(
drift_score=drift_score,
p_value=None,
is_drift=is_drift,
method="population_stability_index",
threshold=0.25,
feature_name=feature_name,
statistic=psi_total,
sample_sizes=sample_sizes
)
def _wasserstein_drift(
self,
reference: np.ndarray,
current: np.ndarray,
feature_name: Optional[str],
sample_sizes: Dict[str, int]
) -> DriftResult:
"""Wasserstein distance for drift detection."""
from scipy.stats import wasserstein_distance
distance = wasserstein_distance(reference, current)
# Normalize distance to 0-1 range (empirical threshold)
drift_score = min(1.0, distance / 10.0)
is_drift = distance > 1.0 # Empirical threshold
return DriftResult(
drift_score=drift_score,
p_value=None,
is_drift=is_drift,
method="wasserstein_distance",
threshold=1.0,
feature_name=feature_name,
statistic=distance,
sample_sizes=sample_sizes
)
def detect_batch_drift(
self,
reference_data: Dict[str, np.ndarray],
current_data: Dict[str, np.ndarray]
) -> Dict[str, DriftResult]:
"""
Detect drift for multiple features.
Args:
reference_data: Dictionary of reference features
current_data: Dictionary of current features
Returns:
Dictionary mapping feature names to drift results
"""
results = {}
for feature_name in reference_data.keys():
if feature_name in current_data:
result = self.detect_drift(
reference_data[feature_name],
current_data[feature_name],
feature_name
)
results[feature_name] = result
return results
# Legacy function for backward compatibility
def detect_drift(reference, current, threshold=0.05):
"""Legacy function for simple drift detection."""
detector = DataDriftDetector(method="ks", threshold=threshold)
result = detector.detect_drift(reference, current)
return {
"drift_score": result.drift_score,
"drift_detected": result.is_drift,
"p_value": result.p_value,
"statistic": result.statistic
}