"""
Drift Detection Module using Deepchecks

Implements drift detection using Deepchecks integrated checks:
- Drift check for text properties
- Label distribution drift
- Custom metrics comparison

When Deepchecks cannot be imported, or a Deepchecks check raises, each
detector falls back to a scipy test: a two-sample Kolmogorov-Smirnov test on
per-text lengths, or a chi-square test on per-label counts.
"""

from typing import Callable, Dict, List, Optional

from loguru import logger
import numpy as np
import pandas as pd

try:
    # NOTE(review): confirm these names exist in the installed deepchecks
    # release. If any of them is missing, this import raises and every
    # detector silently uses the scipy fallbacks below.
    from deepchecks.nlp import SingleDataset
    from deepchecks.nlp.checks import Drift, TextPropertyDrift
except ImportError as import_error:
    # Report the underlying error: an ImportError here can also mean the
    # package is installed but does not export these names.
    logger.warning(
        f"Deepchecks not available ({import_error}). "
        "Install with: pip install deepchecks[nlp]"
    )
    SingleDataset = None
    Drift = None
    TextPropertyDrift = None

from turing import config


class DriftDetector:
    """
    Detects data drift using Deepchecks integrated checks comparing
    production data against baseline/reference datasets.

    Single-check methods return a dict with at least ``"drifted"``,
    ``"alert"`` and ``"method"`` keys; the ``detect_all_drifts*`` methods
    collect those dicts and add an ``"overall"`` summary entry.
    """

    def __init__(
        self,
        p_value_threshold: Optional[float] = None,
        alert_threshold: Optional[float] = None,
    ):
        """
        Initialize drift detector with Deepchecks.

        Args:
            p_value_threshold: P-value threshold for drift detection
                (default ``config.DRIFT_P_VALUE_THRESHOLD``)
            alert_threshold: More sensitive threshold for critical alerts
                (default ``config.DRIFT_ALERT_THRESHOLD``)
        """
        # Compare against None rather than truthiness so an explicit 0.0 is kept.
        self.p_value_threshold = (
            config.DRIFT_P_VALUE_THRESHOLD
            if p_value_threshold is None
            else p_value_threshold
        )
        self.alert_threshold = (
            config.DRIFT_ALERT_THRESHOLD if alert_threshold is None else alert_threshold
        )
        # The Deepchecks code paths use all three names, not just ``Drift``.
        self.use_deepchecks = all(
            obj is not None for obj in (SingleDataset, Drift, TextPropertyDrift)
        )

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _word_count(text: str) -> int:
        """Return the number of whitespace-separated tokens in ``text``."""
        return len(text.split())

    @staticmethod
    def _label_counts(labels, minlength: int = 0) -> np.ndarray:
        """
        Aggregate labels into one count per label.

        1-D input is treated as integer class ids and counted with
        ``np.bincount`` (padded to at least ``minlength`` entries); any other
        input is summed over axis 0, e.g. a samples x labels indicator matrix.
        """
        labels = np.asarray(labels)
        if labels.ndim == 1:
            return np.bincount(labels.astype(int), minlength=minlength)
        return np.sum(labels, axis=0)

    @staticmethod
    def _overall_summary(results: Dict[str, Dict]) -> Dict:
        """Aggregate per-check result dicts into the ``"overall"`` entry."""
        checks = list(results.values())
        return {
            "drifted": any(r.get("drifted", False) for r in checks),
            "alert": any(r.get("alert", False) for r in checks),
            "num_drifts": sum(1 for r in checks if r.get("drifted", False)),
            "methods": [r.get("method", "unknown") for r in checks],
        }

    def _chi_square_label_drift(self, ref_counts, prod_counts) -> Dict:
        """
        Chi-square test comparing reference and production per-label counts.

        Count vectors of different lengths are zero-padded so that a label
        seen in only one of the samples still contributes to the test.
        Returns a no-drift result (``statistic`` None, ``p_value`` 1.0) when
        either side has no labels or the test cannot be computed.
        """
        from scipy.stats import chi2_contingency

        no_result = {
            "statistic": None,
            "p_value": 1.0,
            "drifted": False,
            "alert": False,
            "method": "fallback_chi_square",
        }

        ref_counts = np.asarray(ref_counts)
        prod_counts = np.asarray(prod_counts)
        n_labels = max(len(ref_counts), len(prod_counts))
        ref_counts = np.pad(ref_counts, (0, n_labels - len(ref_counts)))
        prod_counts = np.pad(prod_counts, (0, n_labels - len(prod_counts)))

        if ref_counts.sum() == 0 or prod_counts.sum() == 0:
            logger.warning(
                "Chi-square test skipped: reference or production labels are empty"
            )
            return no_result

        # Labels absent from both samples carry no information but would put
        # zeros into the expected-frequency table, which chi2_contingency rejects.
        observed = (ref_counts + prod_counts) > 0
        contingency_table = np.array([ref_counts[observed], prod_counts[observed]])

        try:
            chi2, p_value, _, _ = chi2_contingency(contingency_table)
        except Exception as e:
            logger.warning(f"Chi-square test failed: {e}")
            return no_result

        p_value = float(p_value)
        return {
            "statistic": float(chi2),
            "p_value": p_value,
            "drifted": p_value < self.p_value_threshold,
            "alert": p_value < self.alert_threshold,
            "method": "fallback_chi_square",
        }

    # ------------------------------------------------------------------ #
    # Text drift
    # ------------------------------------------------------------------ #

    def detect_text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
        language: str = "java",
    ) -> Dict:
        """
        Detect drift in text properties using Deepchecks TextPropertyDrift.

        Falls back to a KS test on character lengths when Deepchecks is not
        available or the check raises.

        Args:
            production_texts: Text data in production
            reference_texts: Reference/baseline text data
            language: Language of the texts (currently unused; kept for API
                compatibility)

        Returns:
            Dictionary with drift detection results
        """
        return self._text_property_drift(
            production_texts, reference_texts, length_fn=len
        )

    def _text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
        length_fn: Callable[[str], int],
    ) -> Dict:
        """
        Run Deepchecks TextPropertyDrift, falling back to a KS test on
        ``length_fn(text)`` values when Deepchecks is unavailable or raises.
        """
        if not self.use_deepchecks:
            logger.warning("Deepchecks not available, using fallback method")
            return self._fallback_text_property_drift(
                production_texts, reference_texts, length_fn=length_fn
            )

        try:
            # Create Deepchecks datasets
            ref_df = pd.DataFrame({'text': reference_texts})
            prod_df = pd.DataFrame({'text': production_texts})
            reference_dataset = SingleDataset(
                ref_df, text_column='text', task_type='text_classification'
            )
            production_dataset = SingleDataset(
                prod_df, text_column='text', task_type='text_classification'
            )

            # NOTE(review): confirm the ``run`` signature and that ``failed``
            # is an attribute of the installed deepchecks CheckResult.
            result = TextPropertyDrift().run(
                reference_dataset, production_dataset, model_classes=None
            )
            scores = result.to_dict()
            is_drifted = result.failed
        except Exception as e:
            logger.error(f"Deepchecks TextPropertyDrift failed: {e}")
            return self._fallback_text_property_drift(
                production_texts, reference_texts, length_fn=length_fn
            )

        if is_drifted:
            logger.warning("Text property drift detected (Deepchecks)")
        return {
            "check_result": scores,
            "drifted": is_drifted,
            "alert": is_drifted,
            "method": "deepchecks_text_property_drift",
        }

    def _fallback_text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
        length_fn: Callable[[str], int] = len,
    ) -> Dict:
        """
        Fallback two-sample KS test on a per-text measure (characters by default).

        Returns a no-drift result (``statistic`` None, ``p_value`` 1.0) when
        either side is empty, because ``ks_2samp`` rejects empty samples.
        """
        from scipy.stats import ks_2samp

        production_lengths = np.array([length_fn(text) for text in production_texts])
        reference_lengths = np.array([length_fn(text) for text in reference_texts])

        def _mean(values: np.ndarray) -> Optional[float]:
            return float(np.mean(values)) if values.size else None

        summary = {
            "mean_production": _mean(production_lengths),
            "mean_reference": _mean(reference_lengths),
            "method": "fallback_ks_test",
        }

        if production_lengths.size == 0 or reference_lengths.size == 0:
            logger.warning("KS test skipped: production or reference texts are empty")
            return {
                "statistic": None,
                "p_value": 1.0,
                "drifted": False,
                "alert": False,
                **summary,
            }

        test = ks_2samp(reference_lengths, production_lengths)
        p_value = float(test.pvalue)
        is_drifted = p_value < self.p_value_threshold
        return {
            "statistic": float(test.statistic),
            "p_value": p_value,
            "drifted": is_drifted,
            "alert": is_drifted and p_value < self.alert_threshold,
            **summary,
        }

    def detect_word_count_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
    ) -> Dict:
        """
        Detect drift in word count distribution.

        Uses Deepchecks TextPropertyDrift when available; otherwise (or if the
        check raises) a KS test on per-text word counts, where a word is a
        whitespace-separated token.

        Args:
            production_texts: Text data in production
            reference_texts: Reference/baseline text data

        Returns:
            Dictionary with drift detection results
        """
        return self._text_property_drift(
            production_texts, reference_texts, length_fn=self._word_count
        )

    # ------------------------------------------------------------------ #
    # Label drift
    # ------------------------------------------------------------------ #

    def detect_label_distribution_drift(
        self,
        production_labels: np.ndarray,
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Detect drift in label distribution using Deepchecks Drift check.

        Args:
            production_labels: Production label data (numpy array or list);
                1-D integer class ids or a 2-D samples x labels matrix
            reference_labels: Reference/baseline label data, same layout

        Returns:
            Dictionary with drift detection results
        """
        # Accept lists as documented; the helpers below rely on ndarray APIs.
        production_labels = np.asarray(production_labels)
        reference_labels = np.asarray(reference_labels)

        if not self.use_deepchecks:
            logger.warning("Deepchecks not available, using fallback method")
            return self._fallback_label_drift(production_labels, reference_labels)

        try:
            ref_counts = self._label_counts(reference_labels)
            prod_counts = self._label_counts(
                production_labels, minlength=len(ref_counts)
            )

            # NOTE(review): each frame holds a single row of aggregate counts,
            # so the Drift check sees one sample per dataset; confirm this is
            # a meaningful input for it.
            n_labels = len(ref_counts)
            ref_df = pd.DataFrame({
                f'label_{i}': [int(ref_counts[i])] for i in range(n_labels)
            })
            prod_df = pd.DataFrame({
                f'label_{i}': [int(prod_counts[i])] for i in range(n_labels)
            })

            reference_dataset = SingleDataset(ref_df, task_type='classification')
            production_dataset = SingleDataset(prod_df, task_type='classification')
            result = Drift().run(reference_dataset, production_dataset)
            check_result = result.to_dict()
            is_drifted = result.failed
        except Exception as e:
            logger.error(f"Deepchecks Drift check failed: {e}")
            return self._fallback_label_drift(production_labels, reference_labels)

        if is_drifted:
            logger.warning("Label distribution drift detected (Deepchecks)")
        return {
            "check_result": check_result,
            "drifted": is_drifted,
            "alert": is_drifted,
            "reference_counts": ref_counts.tolist(),
            "production_counts": prod_counts.tolist(),
            "method": "deepchecks_drift_check",
        }

    def _fallback_label_drift(
        self,
        production_labels: np.ndarray,
        reference_labels: np.ndarray,
    ) -> Dict:
        """Fallback to a chi-square test on per-label counts if Deepchecks fails."""
        ref_counts = self._label_counts(reference_labels)
        prod_counts = self._label_counts(production_labels, minlength=len(ref_counts))
        return self._chi_square_label_drift(ref_counts, prod_counts)

    # ------------------------------------------------------------------ #
    # Aggregated checks
    # ------------------------------------------------------------------ #

    def detect_all_drifts(
        self,
        production_texts: List[str],
        production_labels: np.ndarray,
        reference_texts: List[str],
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Run all drift detection checks using Deepchecks.

        Args:
            production_texts: Production text data
            production_labels: Production label data
            reference_texts: Reference/baseline text data
            reference_labels: Reference/baseline label data

        Returns:
            Dictionary with one entry per check plus an ``"overall"`` summary
            (``drifted``, ``alert``, ``num_drifts``, ``methods``)
        """
        results = {
            "text_property": self.detect_text_property_drift(
                production_texts,
                reference_texts,
            ),
            "label_distribution": self.detect_label_distribution_drift(
                production_labels,
                reference_labels,
            ),
        }
        results["overall"] = self._overall_summary(results)
        return results

    def detect_all_drifts_from_baseline(
        self,
        production_texts: List[str],
        production_labels: np.ndarray,
        baseline_stats: Dict,
    ) -> Dict:
        """
        Legacy method for backward compatibility.

        Only ``baseline_stats["label_counts"]`` is used, and it is compared
        directly against production per-label counts with a chi-square test.
        No reference texts are reconstructed: the ``"text_length"`` entry
        compares production texts against themselves and therefore never
        reports drift.

        Args:
            production_texts: Production text data
            production_labels: Production label data
            baseline_stats: Dictionary with baseline statistics (legacy format)

        Returns:
            Dictionary with one entry per check plus an ``"overall"`` summary
        """
        # NOTE(review): assumes "label_counts" holds one aggregate count per
        # label (the shape _label_counts produces); confirm against the code
        # that writes baseline stats.
        baseline_counts = np.asarray(baseline_stats.get("label_counts", []))
        production_counts = self._label_counts(
            production_labels, minlength=len(baseline_counts)
        )

        results = {
            "text_length": self._fallback_text_property_drift(
                production_texts,
                production_texts,  # Use production as fallback reference
            ),
            "label_distribution": self._chi_square_label_drift(
                baseline_counts,
                production_counts,
            ),
        }
        results["overall"] = self._overall_summary(results)
        return results