# Source: turing-space/turing/monitoring/drift_detector.py
# Synced from GitHub by github-actions[bot] (commit 5abc469)
"""
Drift Detection Module using Deepchecks
Implements drift detection using Deepchecks integrated checks:
- Drift check for text properties
- Label distribution drift
- Custom metrics comparison
"""
from typing import Dict, List
from loguru import logger
import numpy as np
import pandas as pd
try:
from deepchecks.nlp import SingleDataset
from deepchecks.nlp.checks import Drift, TextPropertyDrift
except ImportError:
logger.warning("Deepchecks not installed. Install with: pip install deepchecks[nlp]")
SingleDataset = None
Drift = None
TextPropertyDrift = None
from turing import config
class DriftDetector:
    """
    Detects data drift using Deepchecks integrated checks comparing production data
    against baseline/reference datasets.

    When Deepchecks is not installed (or one of its checks raises at runtime),
    every public method falls back to SciPy statistical tests: a KS test on
    text-length distributions and a chi-square test on label counts.
    """

    def __init__(self, p_value_threshold: float = None, alert_threshold: float = None):
        """
        Initialize drift detector with Deepchecks.

        Args:
            p_value_threshold: P-value threshold for drift detection (default from config)
            alert_threshold: More sensitive threshold for critical alerts (default from config)
        """
        # Use explicit `is None` checks (not `x or default`) so a deliberately
        # passed 0.0 threshold is honored instead of silently replaced by the
        # config default (0.0 is falsy).
        if p_value_threshold is None:
            p_value_threshold = config.DRIFT_P_VALUE_THRESHOLD
        if alert_threshold is None:
            alert_threshold = config.DRIFT_ALERT_THRESHOLD
        self.p_value_threshold = p_value_threshold
        self.alert_threshold = alert_threshold
        # `Drift` is set to None at module import time when deepchecks is missing.
        self.use_deepchecks = Drift is not None

    @staticmethod
    def _label_counts(labels, minlength: int = 0) -> np.ndarray:
        """
        Collapse label data into per-class counts.

        Accepts either a 1-D array/list of integer class ids (counted with
        ``np.bincount``) or a 2-D one-hot / multi-label array (summed per
        column). Inputs are coerced with ``np.asarray`` so plain Python lists
        are accepted, as the public docstrings promise.

        Args:
            labels: 1-D class ids or 2-D indicator matrix
            minlength: Minimum number of classes in the 1-D result (ignored
                for 2-D input, whose columns are assumed already aligned)

        Returns:
            1-D numpy array of per-class counts
        """
        arr = np.asarray(labels)
        if arr.ndim == 1:
            return np.bincount(arr.astype(int), minlength=minlength)
        return np.sum(arr, axis=0)

    def detect_text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
        language: str = "java",
    ) -> Dict:
        """
        Detect drift in text properties using Deepchecks TextPropertyDrift.

        Args:
            production_texts: Text data in production
            reference_texts: Reference/baseline text data
            language: Language of the texts (currently unused by the check;
                kept for backward compatibility with existing callers)

        Returns:
            Dictionary with drift detection results
        """
        if not self.use_deepchecks:
            logger.warning("Deepchecks not available, using fallback method")
            return self._fallback_text_property_drift(production_texts, reference_texts)
        try:
            # Create Deepchecks datasets
            ref_df = pd.DataFrame({'text': reference_texts})
            prod_df = pd.DataFrame({'text': production_texts})
            reference_dataset = SingleDataset(
                ref_df,
                text_column='text',
                task_type='text_classification'
            )
            production_dataset = SingleDataset(
                prod_df,
                text_column='text',
                task_type='text_classification'
            )
            # Run TextPropertyDrift check
            check = TextPropertyDrift()
            result = check.run(
                reference_dataset,
                production_dataset,
                model_classes=None
            )
            # Extract results
            scores = result.to_dict()
            is_drifted = result.failed
            drift_dict = {
                "check_result": scores,
                "drifted": is_drifted,
                "alert": is_drifted,
                "method": "deepchecks_text_property_drift",
            }
            if is_drifted:
                logger.warning("Text property drift detected (Deepchecks)")
            return drift_dict
        except Exception as e:
            # Any Deepchecks failure degrades gracefully to the manual KS test.
            logger.error(f"Deepchecks TextPropertyDrift failed: {e}")
            return self._fallback_text_property_drift(production_texts, reference_texts)

    def _fallback_text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
    ) -> Dict:
        """
        Fallback: two-sample KS test on character-length distributions.

        Used when Deepchecks is unavailable or its check raised.
        """
        from scipy.stats import ks_2samp

        # ks_2samp raises on empty samples; treat "nothing to compare" as no drift.
        if len(production_texts) == 0 or len(reference_texts) == 0:
            return {
                "statistic": None,
                "p_value": 1.0,
                "drifted": False,
                "alert": False,
                "mean_production": 0.0,
                "mean_reference": 0.0,
                "method": "fallback_ks_test",
            }
        production_lengths = np.array([len(text) for text in production_texts])
        reference_lengths = np.array([len(text) for text in reference_texts])
        statistic, p_value = ks_2samp(reference_lengths, production_lengths)
        # Cast numpy bools to plain bool so results stay JSON-serializable.
        is_drifted = bool(p_value < self.p_value_threshold)
        return {
            "statistic": float(statistic),
            "p_value": float(p_value),
            "drifted": is_drifted,
            "alert": bool(is_drifted and p_value < self.alert_threshold),
            "mean_production": float(np.mean(production_lengths)),
            "mean_reference": float(np.mean(reference_lengths)),
            "method": "fallback_ks_test",
        }

    def detect_label_distribution_drift(
        self,
        production_labels: np.ndarray,
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Detect drift in label distribution using Deepchecks Drift check.

        Args:
            production_labels: Production label data (numpy array or list)
            reference_labels: Reference/baseline label data

        Returns:
            Dictionary with drift detection results
        """
        if not self.use_deepchecks:
            logger.warning("Deepchecks not available, using fallback method")
            return self._fallback_label_drift(production_labels, reference_labels)
        try:
            # Per-class counts; production is padded to at least the reference width.
            ref_counts = self._label_counts(reference_labels)
            prod_counts = self._label_counts(production_labels, minlength=len(ref_counts))
            # Single-row DataFrames with one column per label class
            n_labels = len(ref_counts)
            ref_df = pd.DataFrame({
                f'label_{i}': [int(ref_counts[i])] for i in range(n_labels)
            })
            prod_df = pd.DataFrame({
                f'label_{i}': [int(prod_counts[i])] for i in range(n_labels)
            })
            # Run Drift check
            check = Drift()
            reference_dataset = SingleDataset(ref_df, task_type='classification')
            production_dataset = SingleDataset(prod_df, task_type='classification')
            result = check.run(reference_dataset, production_dataset)
            is_drifted = result.failed
            drift_dict = {
                "check_result": result.to_dict(),
                "drifted": is_drifted,
                "alert": is_drifted,
                "reference_counts": ref_counts.tolist(),
                "production_counts": prod_counts.tolist(),
                "method": "deepchecks_drift_check",
            }
            if is_drifted:
                logger.warning("Label distribution drift detected (Deepchecks)")
            return drift_dict
        except Exception as e:
            # Any Deepchecks failure degrades gracefully to the chi-square test.
            logger.error(f"Deepchecks Drift check failed: {e}")
            return self._fallback_label_drift(production_labels, reference_labels)

    def _fallback_label_drift(
        self,
        production_labels: np.ndarray,
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Fallback: chi-square test of homogeneity on per-class label counts.

        Used when Deepchecks is unavailable or its check raised.
        """
        from scipy.stats import chi2_contingency

        ref_counts = self._label_counts(reference_labels)
        prod_counts = self._label_counts(production_labels, minlength=len(ref_counts))
        # Truncate to the shared class range so the contingency table is rectangular.
        min_len = min(len(prod_counts), len(ref_counts))
        prod_counts = prod_counts[:min_len]
        ref_counts = ref_counts[:min_len]
        contingency_table = np.array([ref_counts, prod_counts])
        try:
            chi2, p_value, dof, expected = chi2_contingency(contingency_table)
        except Exception as e:
            # e.g. empty table, or a class with zero counts in both samples
            logger.warning(f"Chi-square test failed: {e}")
            return {
                "statistic": None,
                "p_value": 1.0,
                "drifted": False,
                "alert": False,
                "method": "fallback_chi_square",
            }
        # Cast numpy bools to plain bool so results stay JSON-serializable.
        is_drifted = bool(p_value < self.p_value_threshold)
        is_alert = bool(p_value < self.alert_threshold)
        return {
            "statistic": float(chi2),
            "p_value": float(p_value),
            "drifted": is_drifted,
            "alert": is_alert,
            "method": "fallback_chi_square",
        }

    def detect_word_count_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
    ) -> Dict:
        """
        Detect drift in word count distribution.
        Uses Deepchecks TextPropertyDrift or fallback KS test.

        Args:
            production_texts: Text data in production
            reference_texts: Reference/baseline text data

        Returns:
            Dictionary with drift detection results
        """
        # Delegate to TextPropertyDrift, which includes word count analysis
        return self.detect_text_property_drift(
            production_texts,
            reference_texts,
            language="unknown"
        )

    def detect_all_drifts(
        self,
        production_texts: List[str],
        production_labels: np.ndarray,
        reference_texts: List[str],
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Run all drift detection checks using Deepchecks.

        Args:
            production_texts: Production text data
            production_labels: Production label data
            reference_texts: Reference/baseline text data
            reference_labels: Reference/baseline label data

        Returns:
            Dictionary with per-check results plus an "overall" summary entry
        """
        results = {
            "text_property": self.detect_text_property_drift(
                production_texts,
                reference_texts,
            ),
            "label_distribution": self.detect_label_distribution_drift(
                production_labels,
                reference_labels,
            ),
        }
        # "overall" is assigned after the aggregates are computed, so it is
        # not counted in its own summary.
        any_drifted = any(r.get("drifted", False) for r in results.values())
        any_alert = any(r.get("alert", False) for r in results.values())
        results["overall"] = {
            "drifted": any_drifted,
            "alert": any_alert,
            "num_drifts": sum(1 for r in results.values() if r.get("drifted", False)),
            "methods": [r.get("method", "unknown") for r in results.values()],
        }
        return results

    def detect_all_drifts_from_baseline(
        self,
        production_texts: List[str],
        production_labels: np.ndarray,
        baseline_stats: Dict,
    ) -> Dict:
        """
        Legacy method for backward compatibility.
        Reconstructs reference label data from baseline statistics and runs the
        fallback drift checks.

        NOTE: raw reference texts are not stored in baseline_stats, so the
        text-length check compares production against itself and therefore
        never reports text drift; only the label check is meaningful here.

        Args:
            production_texts: Production text data
            production_labels: Production label data
            baseline_stats: Dictionary with baseline statistics (legacy format)

        Returns:
            Dictionary with per-check results plus an "overall" summary entry
        """
        # Rebuild per-sample labels from the stored per-class counts so that
        # _fallback_label_drift's bincount recovers the baseline distribution.
        # (Passing the counts array directly would bincount the counts
        # themselves.) NOTE(review): assumes "label_counts" holds per-class
        # counts — confirm against the baseline writer.
        counts = np.asarray(baseline_stats.get("label_counts", []), dtype=int)
        reference_labels = np.repeat(np.arange(len(counts)), counts)
        results = {
            "text_length": self._fallback_text_property_drift(
                production_texts,
                production_texts,  # Use production as fallback reference
            ),
            "label_distribution": self._fallback_label_drift(
                production_labels,
                reference_labels,
            ),
        }
        any_drifted = any(r.get("drifted", False) for r in results.values())
        any_alert = any(r.get("alert", False) for r in results.values())
        results["overall"] = {
            "drifted": any_drifted,
            "alert": any_alert,
            "num_drifts": sum(1 for r in results.values() if r.get("drifted", False)),
        }
        return results