# NOTE: removed non-code page chrome ("Spaces: Running Running") captured
# from the hosting UI during extraction; it is not part of the module.
"""
Drift Detection Module using Deepchecks

Implements drift detection using Deepchecks integrated checks:
- Drift check for text properties
- Label distribution drift
- Custom metrics comparison
"""
from typing import Dict, List

from loguru import logger
import numpy as np
import pandas as pd

# Deepchecks is an optional dependency: when it is missing, the names below
# are set to None and DriftDetector degrades to manual statistical tests
# (see its _fallback_* methods).
try:
    from deepchecks.nlp import SingleDataset
    from deepchecks.nlp.checks import Drift, TextPropertyDrift
except ImportError:
    logger.warning("Deepchecks not installed. Install with: pip install deepchecks[nlp]")
    # Sentinels: DriftDetector tests `Drift is not None` to decide whether
    # the Deepchecks code paths are usable at runtime.
    SingleDataset = None
    Drift = None
    TextPropertyDrift = None

from turing import config
| class DriftDetector: | |
| """ | |
| Detects data drift using Deepchecks integrated checks comparing production data | |
| against baseline/reference datasets. | |
| """ | |
| def __init__(self, p_value_threshold: float = None, alert_threshold: float = None): | |
| """ | |
| Initialize drift detector with Deepchecks. | |
| Args: | |
| p_value_threshold: P-value threshold for drift detection (default from config) | |
| alert_threshold: More sensitive threshold for critical alerts (default from config) | |
| """ | |
| self.p_value_threshold = p_value_threshold or config.DRIFT_P_VALUE_THRESHOLD | |
| self.alert_threshold = alert_threshold or config.DRIFT_ALERT_THRESHOLD | |
| self.use_deepchecks = Drift is not None | |
| def detect_text_property_drift( | |
| self, | |
| production_texts: List[str], | |
| reference_texts: List[str], | |
| language: str = "java", | |
| ) -> Dict: | |
| """ | |
| Detect drift in text properties using Deepchecks TextPropertyDrift. | |
| Args: | |
| production_texts: Text data in production | |
| reference_texts: Reference/baseline text data | |
| language: Language of the texts | |
| Returns: | |
| Dictionary with drift detection results | |
| """ | |
| if not self.use_deepchecks: | |
| logger.warning("Deepchecks not available, using fallback method") | |
| return self._fallback_text_property_drift(production_texts, reference_texts) | |
| try: | |
| # Create Deepchecks datasets | |
| ref_df = pd.DataFrame({'text': reference_texts}) | |
| prod_df = pd.DataFrame({'text': production_texts}) | |
| reference_dataset = SingleDataset( | |
| ref_df, | |
| text_column='text', | |
| task_type='text_classification' | |
| ) | |
| production_dataset = SingleDataset( | |
| prod_df, | |
| text_column='text', | |
| task_type='text_classification' | |
| ) | |
| # Run TextPropertyDrift check | |
| check = TextPropertyDrift() | |
| result = check.run( | |
| reference_dataset, | |
| production_dataset, | |
| model_classes=None | |
| ) | |
| # Extract results | |
| scores = result.to_dict() | |
| is_drifted = result.failed | |
| drift_dict = { | |
| "check_result": scores, | |
| "drifted": is_drifted, | |
| "alert": is_drifted, | |
| "method": "deepchecks_text_property_drift", | |
| } | |
| if is_drifted: | |
| logger.warning("Text property drift detected (Deepchecks)") | |
| return drift_dict | |
| except Exception as e: | |
| logger.error(f"Deepchecks TextPropertyDrift failed: {e}") | |
| return self._fallback_text_property_drift(production_texts, reference_texts) | |
| def _fallback_text_property_drift( | |
| self, | |
| production_texts: List[str], | |
| reference_texts: List[str], | |
| ) -> Dict: | |
| """Fallback to manual calculation if Deepchecks fails.""" | |
| from scipy.stats import ks_2samp | |
| production_lengths = np.array([len(text) for text in production_texts]) | |
| reference_lengths = np.array([len(text) for text in reference_texts]) | |
| statistic, p_value = ks_2samp(reference_lengths, production_lengths) | |
| is_drifted = p_value < self.p_value_threshold | |
| return { | |
| "statistic": float(statistic), | |
| "p_value": float(p_value), | |
| "drifted": is_drifted, | |
| "alert": is_drifted and p_value < self.alert_threshold, | |
| "mean_production": float(np.mean(production_lengths)), | |
| "mean_reference": float(np.mean(reference_lengths)), | |
| "method": "fallback_ks_test", | |
| } | |
| def detect_label_distribution_drift( | |
| self, | |
| production_labels: np.ndarray, | |
| reference_labels: np.ndarray, | |
| ) -> Dict: | |
| """ | |
| Detect drift in label distribution using Deepchecks Drift check. | |
| Args: | |
| production_labels: Production label data (numpy array or list) | |
| reference_labels: Reference/baseline label data | |
| Returns: | |
| Dictionary with drift detection results | |
| """ | |
| if not self.use_deepchecks: | |
| logger.warning("Deepchecks not available, using fallback method") | |
| return self._fallback_label_drift(production_labels, reference_labels) | |
| try: | |
| # Prepare data | |
| if len(reference_labels.shape) == 1: | |
| ref_counts = np.bincount(reference_labels.astype(int)) | |
| else: | |
| ref_counts = np.sum(reference_labels, axis=0) | |
| if len(production_labels.shape) == 1: | |
| prod_counts = np.bincount( | |
| production_labels.astype(int), | |
| minlength=len(ref_counts) | |
| ) | |
| else: | |
| prod_counts = np.sum(production_labels, axis=0) | |
| # Create DataFrames with label columns | |
| n_labels = len(ref_counts) | |
| ref_df = pd.DataFrame({ | |
| f'label_{i}': [int(ref_counts[i])] for i in range(n_labels) | |
| }) | |
| prod_df = pd.DataFrame({ | |
| f'label_{i}': [int(prod_counts[i])] for i in range(n_labels) | |
| }) | |
| # Run Drift check | |
| check = Drift() | |
| reference_dataset = SingleDataset(ref_df, task_type='classification') | |
| production_dataset = SingleDataset(prod_df, task_type='classification') | |
| result = check.run(reference_dataset, production_dataset) | |
| is_drifted = result.failed | |
| drift_dict = { | |
| "check_result": result.to_dict(), | |
| "drifted": is_drifted, | |
| "alert": is_drifted, | |
| "reference_counts": ref_counts.tolist(), | |
| "production_counts": prod_counts.tolist(), | |
| "method": "deepchecks_drift_check", | |
| } | |
| if is_drifted: | |
| logger.warning("Label distribution drift detected (Deepchecks)") | |
| return drift_dict | |
| except Exception as e: | |
| logger.error(f"Deepchecks Drift check failed: {e}") | |
| return self._fallback_label_drift(production_labels, reference_labels) | |
| def _fallback_label_drift( | |
| self, | |
| production_labels: np.ndarray, | |
| reference_labels: np.ndarray, | |
| ) -> Dict: | |
| """Fallback to manual Chi-Square test if Deepchecks fails.""" | |
| from scipy.stats import chi2_contingency | |
| if len(reference_labels.shape) == 1: | |
| ref_counts = np.bincount(reference_labels.astype(int)) | |
| else: | |
| ref_counts = np.sum(reference_labels, axis=0) | |
| if len(production_labels.shape) == 1: | |
| prod_counts = np.bincount( | |
| production_labels.astype(int), | |
| minlength=len(ref_counts) | |
| ) | |
| else: | |
| prod_counts = np.sum(production_labels, axis=0) | |
| min_len = min(len(prod_counts), len(ref_counts)) | |
| prod_counts = prod_counts[:min_len] | |
| ref_counts = ref_counts[:min_len] | |
| contingency_table = np.array([ref_counts, prod_counts]) | |
| try: | |
| chi2, p_value, dof, expected = chi2_contingency(contingency_table) | |
| except Exception as e: | |
| logger.warning(f"Chi-square test failed: {e}") | |
| return {"statistic": None, "p_value": 1.0, "drifted": False, "alert": False} | |
| is_drifted = p_value < self.p_value_threshold | |
| is_alert = p_value < self.alert_threshold | |
| return { | |
| "statistic": float(chi2), | |
| "p_value": float(p_value), | |
| "drifted": is_drifted, | |
| "alert": is_alert, | |
| "method": "fallback_chi_square", | |
| } | |
| def detect_word_count_drift( | |
| self, | |
| production_texts: List[str], | |
| reference_texts: List[str], | |
| ) -> Dict: | |
| """ | |
| Detect drift in word count distribution. | |
| Uses Deepchecks TextPropertyDrift or fallback KS test. | |
| Args: | |
| production_texts: Text data in production | |
| reference_texts: Reference/baseline text data | |
| Returns: | |
| Dictionary with drift detection results | |
| """ | |
| # Use TextPropertyDrift which includes word count analysis | |
| return self.detect_text_property_drift( | |
| production_texts, | |
| reference_texts, | |
| language="unknown" | |
| ) | |
| def detect_all_drifts( | |
| self, | |
| production_texts: List[str], | |
| production_labels: np.ndarray, | |
| reference_texts: List[str], | |
| reference_labels: np.ndarray, | |
| ) -> Dict: | |
| """ | |
| Run all drift detection checks using Deepchecks. | |
| Args: | |
| production_texts: Production text data | |
| production_labels: Production label data | |
| reference_texts: Reference/baseline text data | |
| reference_labels: Reference/baseline label data | |
| Returns: | |
| Dictionary with aggregated drift detection results | |
| """ | |
| results = { | |
| "text_property": self.detect_text_property_drift( | |
| production_texts, | |
| reference_texts, | |
| ), | |
| "label_distribution": self.detect_label_distribution_drift( | |
| production_labels, | |
| reference_labels, | |
| ), | |
| } | |
| any_drifted = any(r.get("drifted", False) for r in results.values()) | |
| any_alert = any(r.get("alert", False) for r in results.values()) | |
| results["overall"] = { | |
| "drifted": any_drifted, | |
| "alert": any_alert, | |
| "num_drifts": sum(1 for r in results.values() if r.get("drifted", False)), | |
| "methods": [r.get("method", "unknown") for r in results.values()], } | |
| return results | |
| def detect_all_drifts_from_baseline( | |
| self, | |
| production_texts: List[str], | |
| production_labels: np.ndarray, | |
| baseline_stats: Dict, | |
| ) -> Dict: | |
| """ | |
| Legacy method for backward compatibility. | |
| Converts baseline_stats dict to reference_texts and reference_labels if available. | |
| Otherwise reconstructs reference data from baseline statistics. | |
| Args: | |
| production_texts: Production text data | |
| production_labels: Production label data | |
| baseline_stats: Dictionary with baseline statistics (legacy format) | |
| Returns: | |
| Dictionary with aggregated drift detection results | |
| """ | |
| results = { | |
| "text_length": self._fallback_text_property_drift( | |
| production_texts, | |
| production_texts, # Use production as fallback reference | |
| ), | |
| "label_distribution": self._fallback_label_drift( | |
| production_labels, | |
| np.array(baseline_stats.get("label_counts", [])), | |
| ), | |
| } | |
| any_drifted = any(r.get("drifted", False) for r in results.values()) | |
| any_alert = any(r.get("alert", False) for r in results.values()) | |
| results["overall"] = { | |
| "drifted": any_drifted, | |
| "alert": any_alert, | |
| "num_drifts": sum(1 for r in results.values() if r.get("drifted", False)), | |
| } | |
| return results |