File size: 6,482 Bytes
8ed9c1d 1396866 8ed9c1d 1396866 8ed9c1d 1396866 8ed9c1d 1396866 8ed9c1d 1396866 8ed9c1d 1396866 8ed9c1d 1396866 8ed9c1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
"""
Data Drift Detection using Scipy KS Test.
Detects distribution shifts between baseline and new data.
"""
import json
import pickle
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import requests
from scipy.stats import ks_2samp
# Configuration
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent  # repository root, 4 levels above this file
BASELINE_DIR = Path(__file__).parent.parent / "baseline"  # holds reference_data.pkl from prepare_baseline.py
REPORTS_DIR = Path(__file__).parent.parent / "reports"  # JSON drift reports are written here
REPORTS_DIR.mkdir(parents=True, exist_ok=True)  # NOTE: import-time side effect — ensures the reports dir exists
PUSHGATEWAY_URL = "http://localhost:9091"  # Prometheus Pushgateway endpoint for batch-job metrics
P_VALUE_THRESHOLD = 0.05 # Significance level (per-feature, before Bonferroni correction)
def load_baseline() -> np.ndarray:
    """Return the reference feature matrix used as the drift baseline.

    Raises:
        FileNotFoundError: if the baseline pickle has not been generated yet.
    """
    baseline_path = BASELINE_DIR / "reference_data.pkl"
    if not baseline_path.exists():
        raise FileNotFoundError(
            f"Baseline data not found at {baseline_path}\n"
            f"Run `python prepare_baseline.py` first!"
        )
    # NOTE(review): pickle is only acceptable here because the file is produced
    # by our own prepare_baseline.py — never load untrusted pickles.
    with baseline_path.open('rb') as f:
        X_baseline = pickle.load(f)
    print(f"Loaded baseline data: {X_baseline.shape}")
    return X_baseline
def load_new_data() -> np.ndarray:
    """
    Load new/production data to check for drift.

    In production, this would fetch from:
        - Database
        - S3 bucket
        - API logs
        - Data lake

    For now, reads data/test.csv if present, otherwise simulates drifted
    data by adding Gaussian noise to the baseline.

    Returns:
        2-D feature matrix (at most 500 rows).
    """
    # Option 1: Load from file
    data_path = PROJECT_ROOT / "data" / "test.csv"
    if data_path.exists():
        df = pd.read_csv(data_path)
        # Extract same features as baseline (drop label/meta columns)
        feature_columns = [col for col in df.columns if col not in ['label', 'id', 'timestamp']]
        X_new = df[feature_columns].values[:500]  # Take 500 samples
        print(f"Loaded new data from file: {X_new.shape}")
        return X_new

    # Option 2: Simulate (for testing)
    print("Simulating new data (no test file found)")
    X_baseline = load_baseline()
    # Bug fix: the noise must match the actual slice shape — the baseline may
    # contain fewer than 500 rows, and a hard-coded (500, n) shape would raise
    # a broadcasting error in that case.
    sample = X_baseline[:500]
    X_new = sample + np.random.normal(0, 0.1, sample.shape)
    return X_new
def run_drift_detection(X_baseline: np.ndarray, X_new: np.ndarray,
                        p_value_threshold: float = 0.05) -> Dict:
    """
    Run Kolmogorov-Smirnov drift detection using scipy.

    Each feature column gets its own two-sample KS test; the most drifted
    feature (smallest p-value) drives the decision, compared against a
    Bonferroni-corrected threshold to control false positives across features.

    Args:
        X_baseline: Reference data, shape (n_samples, n_features)
        X_new: New data to check, same feature count as the baseline
        p_value_threshold: Base significance level before the Bonferroni
            correction (default 0.05, matching P_VALUE_THRESHOLD)

    Returns:
        Drift detection results (all values JSON-serializable)
    """
    print("\n" + "=" * 60)
    print("Running Drift Detection (Kolmogorov-Smirnov Test)")
    print("=" * 60)

    num_features = X_baseline.shape[1]

    # Run KS test for each feature
    p_values = []
    distances = []
    for i in range(num_features):
        statistic, p_value = ks_2samp(X_baseline[:, i], X_new[:, i])
        p_values.append(p_value)
        distances.append(statistic)

    # Aggregate: the worst (most drifted) feature decides
    min_p_value = np.min(p_values)
    max_distance = np.max(distances)

    # Bonferroni correction for multiple testing across features
    adjusted_threshold = p_value_threshold / num_features
    drift_detected = min_p_value < adjusted_threshold

    results = {
        "timestamp": datetime.now().isoformat(),
        "drift_detected": int(drift_detected),  # int (not bool) for JSON/Prometheus
        "p_value": float(min_p_value),
        "threshold": adjusted_threshold,
        "distance": float(max_distance),
        "baseline_samples": X_baseline.shape[0],
        "new_samples": X_new.shape[0],
        "num_features": num_features
    }

    # Print results
    print(f"\nResults:")
    print(f" Drift Detected: {'YES' if results['drift_detected'] else 'NO'}")
    print(f" P-Value: {results['p_value']:.6f} (adjusted threshold: {adjusted_threshold:.6f})")
    print(f" Distance: {results['distance']:.6f}")
    print(f" Baseline: {X_baseline.shape[0]} samples")
    print(f" New Data: {X_new.shape[0]} samples")
    return results
def save_report(results: Dict):
    """Persist a drift-detection result dict as a timestamped JSON report."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_path = REPORTS_DIR / f"drift_report_{stamp}.json"
    report_path.write_text(json.dumps(results, indent=2))
    print(f"\nReport saved to: {report_path}")
def push_to_prometheus(results: Dict):
    """
    Push drift metrics to Prometheus via Pushgateway.

    This allows Prometheus to scrape short-lived job metrics. Push failures
    are logged but never raised, so a monitoring outage cannot fail the job.
    """
    # Prometheus text exposition format — metric lines must start at column 0.
    metrics = f"""# TYPE drift_detected gauge
# HELP drift_detected Whether data drift was detected (1=yes, 0=no)
drift_detected {results['drift_detected']}
# TYPE drift_p_value gauge
# HELP drift_p_value P-value from drift detection test
drift_p_value {results['p_value']}
# TYPE drift_distance gauge
# HELP drift_distance Statistical distance between distributions
drift_distance {results['distance']}
# TYPE drift_check_timestamp gauge
# HELP drift_check_timestamp Unix timestamp of last drift check
drift_check_timestamp {datetime.now().timestamp()}
"""
    try:
        response = requests.post(
            f"{PUSHGATEWAY_URL}/metrics/job/drift_detection/instance/hopcroft",
            data=metrics,
            headers={'Content-Type': 'text/plain'},
            # Bug fix: requests has no default timeout — without one, an
            # unreachable Pushgateway would hang this batch job forever.
            timeout=10,
        )
        response.raise_for_status()
        print(f"Metrics pushed to Pushgateway at {PUSHGATEWAY_URL}")
    except requests.exceptions.RequestException as e:
        print(f"Failed to push to Pushgateway: {e}")
        print(f" Make sure Pushgateway is running: docker compose ps pushgateway")
def main():
    """Entry point: load data, test for drift, save a report, push metrics.

    Returns 0 when no drift was found, 1 on drift or any error (suitable
    for use as a process exit code in CI/cron).
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Hopcroft Data Drift Detection")
    print(banner)
    try:
        X_baseline = load_baseline()
        X_new = load_new_data()

        results = run_drift_detection(X_baseline, X_new)

        save_report(results)
        push_to_prometheus(results)

        print("\n" + banner)
        print("Drift Detection Complete!")
        print(banner)

        # Guard clause: the common, healthy case exits first.
        if not results['drift_detected']:
            print("\nNo significant drift detected")
            return 0
        print("\nWARNING: Data drift detected!")
        print(f" P-value: {results['p_value']:.6f} < {P_VALUE_THRESHOLD}")
        return 1
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # sys.exit instead of the built-in exit(): the latter is an interactive
    # helper injected by the site module and is not guaranteed in all runtimes.
    sys.exit(main())