File size: 12,367 Bytes
38593e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5abc469
38593e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
Drift Detection Module using Deepchecks

Implements drift detection using Deepchecks integrated checks:
- Drift check for text properties
- Label distribution drift
- Custom metrics comparison
"""

from typing import Dict, List

from loguru import logger
import numpy as np
import pandas as pd

try:
    from deepchecks.nlp import SingleDataset
    from deepchecks.nlp.checks import Drift, TextPropertyDrift
except ImportError:
    logger.warning("Deepchecks not installed. Install with: pip install deepchecks[nlp]")
    SingleDataset = None
    Drift = None
    TextPropertyDrift = None

from turing import config


class DriftDetector:
    """
    Detect data drift by comparing production data against a baseline/reference set.

    Each public ``detect_*`` method first tries a Deepchecks check (only when the
    Deepchecks imports at module load succeeded) and otherwise falls back to a
    SciPy test:
    - a two-sample Kolmogorov-Smirnov test on a per-text numeric property;
    - a chi-square test of homogeneity on per-label counts.

    Every result dict carries ``drifted``, ``alert`` and ``method``. Fallback
    results also carry ``statistic`` and ``p_value``.
    """

    def __init__(self, p_value_threshold: float = None, alert_threshold: float = None):
        """
        Initialize the drift detector.

        Args:
            p_value_threshold: A fallback test reports drift when its p-value is
                below this value. ``None`` uses ``config.DRIFT_P_VALUE_THRESHOLD``.
            alert_threshold: A drifted result is escalated to an alert when its
                p-value is also below this value (intended to be the stricter
                threshold). ``None`` uses ``config.DRIFT_ALERT_THRESHOLD``.
        """
        # Compare against None instead of using `or`, so an explicit 0.0 is honored.
        self.p_value_threshold = (
            config.DRIFT_P_VALUE_THRESHOLD if p_value_threshold is None else p_value_threshold
        )
        self.alert_threshold = (
            config.DRIFT_ALERT_THRESHOLD if alert_threshold is None else alert_threshold
        )
        # Drift is None when the module-level Deepchecks import failed.
        self.use_deepchecks = Drift is not None

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _label_counts(labels) -> np.ndarray:
        """
        Convert label data into a per-label count vector.

        1-D input is treated as integer label ids and counted with ``np.bincount``.
        Input with more dimensions is summed over axis 0 (presumably a multi-hot
        samples x labels matrix; confirm against callers). Lists are accepted.
        """
        arr = np.asarray(labels)
        if arr.ndim == 1:
            return np.bincount(arr.astype(int))
        return np.sum(arr, axis=0)

    @staticmethod
    def _pad_to_same_length(first: np.ndarray, second: np.ndarray):
        """Right-pad two 1-D count vectors with zeros so they have equal length."""
        size = max(len(first), len(second))
        return (
            np.pad(first, (0, size - len(first))),
            np.pad(second, (0, size - len(second))),
        )

    @staticmethod
    def _word_count(text: str) -> int:
        """Return the number of whitespace-separated tokens in ``text``."""
        return len(text.split())

    @staticmethod
    def _inconclusive_result(method: str) -> Dict:
        """Return the result used when a fallback test cannot be run (no drift reported)."""
        return {"statistic": None, "p_value": 1.0, "drifted": False, "alert": False, "method": method}

    def _threshold_result(self, statistic: float, p_value: float, method: str) -> Dict:
        """Classify a test outcome against the configured p-value thresholds."""
        # bool() keeps the flags JSON-serializable even if thresholds are numpy scalars.
        drifted = bool(p_value < self.p_value_threshold)
        return {
            "statistic": statistic,
            "p_value": p_value,
            "drifted": drifted,
            # An alert always implies drift, identically for every fallback test.
            "alert": drifted and bool(p_value < self.alert_threshold),
            "method": method,
        }

    @staticmethod
    def _summarize(results: Dict) -> Dict:
        """Aggregate per-check result dicts into the ``overall`` summary entry."""
        checks = list(results.values())
        return {
            "drifted": any(r.get("drifted", False) for r in checks),
            "alert": any(r.get("alert", False) for r in checks),
            "num_drifts": sum(1 for r in checks if r.get("drifted", False)),
            "methods": [r.get("method", "unknown") for r in checks],
        }

    def _chi_square_label_drift(self, reference_counts, production_counts) -> Dict:
        """
        Run a chi-square test of homogeneity between two per-label count vectors.

        Vectors of different lengths are zero-padded, so a label that appears in
        only one dataset still counts as evidence of drift. Labels whose count is
        zero in both datasets are dropped: a zero column yields a zero expected
        frequency, and ``chi2_contingency`` raises on that.

        Returns:
            Dict with ``statistic``, ``p_value``, ``drifted``, ``alert`` and ``method``.
            The result is inconclusive (not drifted, ``statistic`` None) when either
            side has no counts or the test fails.
        """
        from scipy.stats import chi2_contingency

        ref, prod = self._pad_to_same_length(
            np.asarray(reference_counts, dtype=float),
            np.asarray(production_counts, dtype=float),
        )
        if ref.sum() == 0 or prod.sum() == 0:
            logger.warning("Chi-square test skipped: reference or production label counts are empty")
            return self._inconclusive_result("fallback_chi_square")

        observed = (ref + prod) > 0
        contingency_table = np.array([ref[observed], prod[observed]])

        try:
            chi2, p_value, _dof, _expected = chi2_contingency(contingency_table)
        except Exception as e:  # best-effort monitoring: report, don't crash the caller
            logger.warning(f"Chi-square test failed: {e}")
            return self._inconclusive_result("fallback_chi_square")

        return self._threshold_result(float(chi2), float(p_value), "fallback_chi_square")

    # ------------------------------------------------------------------ #
    # Text drift
    # ------------------------------------------------------------------ #

    def detect_text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
        language: str = "java",
    ) -> Dict:
        """
        Detect drift in text properties using Deepchecks TextPropertyDrift.

        Falls back to a KS test on character length when Deepchecks is not
        available or the check raises.

        Args:
            production_texts: Text data in production.
            reference_texts: Reference/baseline text data.
            language: Accepted for API compatibility; currently not used.

        Returns:
            Dictionary with drift detection results.
        """
        if not self.use_deepchecks:
            logger.warning("Deepchecks not available, using fallback method")
            return self._fallback_text_property_drift(production_texts, reference_texts)

        try:
            # NOTE(review): confirm SingleDataset / TextPropertyDrift and the
            # `failed` / `to_dict()` result attributes against the installed
            # deepchecks version. Any exception here routes to the fallback.
            reference_dataset = SingleDataset(
                pd.DataFrame({'text': reference_texts}),
                text_column='text',
                task_type='text_classification',
            )
            production_dataset = SingleDataset(
                pd.DataFrame({'text': production_texts}),
                text_column='text',
                task_type='text_classification',
            )

            result = TextPropertyDrift().run(
                reference_dataset,
                production_dataset,
                model_classes=None,
            )
            is_drifted = result.failed

            if is_drifted:
                logger.warning("Text property drift detected (Deepchecks)")

            return {
                "check_result": result.to_dict(),
                "drifted": is_drifted,
                "alert": is_drifted,
                "method": "deepchecks_text_property_drift",
            }

        except Exception as e:
            logger.error(f"Deepchecks TextPropertyDrift failed: {e}")
            return self._fallback_text_property_drift(production_texts, reference_texts)

    def _fallback_text_property_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
        measure=len,
    ) -> Dict:
        """
        Fallback: run a two-sample KS test on a per-text numeric property.

        Args:
            production_texts: Text data in production.
            reference_texts: Reference/baseline text data.
            measure: Callable mapping one text to a number. Defaults to ``len``
                (character count).

        Returns:
            Dict with ``statistic``, ``p_value``, ``drifted``, ``alert``,
            ``mean_production``, ``mean_reference`` and ``method``. If either side
            is empty, ``ks_2samp`` would raise, so an inconclusive (not drifted)
            result is returned with ``statistic`` None and a None mean for the
            empty side.
        """
        from scipy.stats import ks_2samp

        production_values = np.array([measure(text) for text in production_texts], dtype=float)
        reference_values = np.array([measure(text) for text in reference_texts], dtype=float)

        if production_values.size == 0 or reference_values.size == 0:
            logger.warning("KS test skipped: production or reference texts are empty")
            result = self._inconclusive_result("fallback_ks_test")
        else:
            statistic, p_value = ks_2samp(reference_values, production_values)
            result = self._threshold_result(float(statistic), float(p_value), "fallback_ks_test")

        result["mean_production"] = float(production_values.mean()) if production_values.size else None
        result["mean_reference"] = float(reference_values.mean()) if reference_values.size else None
        return result

    def detect_word_count_drift(
        self,
        production_texts: List[str],
        reference_texts: List[str],
    ) -> Dict:
        """
        Detect drift in the word-count distribution.

        With Deepchecks available, this delegates to
        ``detect_text_property_drift``. Otherwise, or if the Deepchecks check
        fails, it runs the KS fallback on whitespace-separated word counts
        rather than character length.

        Args:
            production_texts: Text data in production.
            reference_texts: Reference/baseline text data.

        Returns:
            Dictionary with drift detection results.
        """
        if self.use_deepchecks:
            result = self.detect_text_property_drift(
                production_texts,
                reference_texts,
                language="unknown",
            )
            if result.get("method") == "deepchecks_text_property_drift":
                return result
            # Deepchecks failed and the delegate fell back to character length;
            # recompute on word counts so the result matches this method's contract.
        return self._fallback_text_property_drift(
            production_texts,
            reference_texts,
            measure=self._word_count,
        )

    # ------------------------------------------------------------------ #
    # Label drift
    # ------------------------------------------------------------------ #

    def detect_label_distribution_drift(
        self,
        production_labels: np.ndarray,
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Detect drift in label distribution using the Deepchecks Drift check.

        Falls back to a chi-square test on label counts when Deepchecks is not
        available or the check raises.

        Args:
            production_labels: Production label data (numpy array or list). 1-D
                input is integer label ids; 2-D input is summed over axis 0.
            reference_labels: Reference/baseline label data, same conventions.

        Returns:
            Dictionary with drift detection results.
        """
        if not self.use_deepchecks:
            logger.warning("Deepchecks not available, using fallback method")
            return self._fallback_label_drift(production_labels, reference_labels)

        try:
            ref_counts, prod_counts = self._pad_to_same_length(
                self._label_counts(reference_labels),
                self._label_counts(production_labels),
            )

            # NOTE(review): each DataFrame holds a single row of aggregate counts,
            # so Deepchecks compares one sample against one sample. Confirm this
            # is the intended input for the Drift check.
            ref_df = pd.DataFrame({f'label_{i}': [int(c)] for i, c in enumerate(ref_counts)})
            prod_df = pd.DataFrame({f'label_{i}': [int(c)] for i, c in enumerate(prod_counts)})

            reference_dataset = SingleDataset(ref_df, task_type='classification')
            production_dataset = SingleDataset(prod_df, task_type='classification')
            result = Drift().run(reference_dataset, production_dataset)
            is_drifted = result.failed

            if is_drifted:
                logger.warning("Label distribution drift detected (Deepchecks)")

            return {
                "check_result": result.to_dict(),
                "drifted": is_drifted,
                "alert": is_drifted,
                "reference_counts": ref_counts.tolist(),
                "production_counts": prod_counts.tolist(),
                "method": "deepchecks_drift_check",
            }

        except Exception as e:
            logger.error(f"Deepchecks Drift check failed: {e}")
            return self._fallback_label_drift(production_labels, reference_labels)

    def _fallback_label_drift(
        self,
        production_labels: np.ndarray,
        reference_labels: np.ndarray,
    ) -> Dict:
        """Fallback: run a chi-square test on label counts derived from raw label data."""
        return self._chi_square_label_drift(
            self._label_counts(reference_labels),
            self._label_counts(production_labels),
        )

    # ------------------------------------------------------------------ #
    # Aggregation
    # ------------------------------------------------------------------ #

    def detect_all_drifts(
        self,
        production_texts: List[str],
        production_labels: np.ndarray,
        reference_texts: List[str],
        reference_labels: np.ndarray,
    ) -> Dict:
        """
        Run all drift detection checks and add an ``overall`` summary.

        Args:
            production_texts: Production text data.
            production_labels: Production label data.
            reference_texts: Reference/baseline text data.
            reference_labels: Reference/baseline label data.

        Returns:
            Dict with ``text_property``, ``label_distribution`` and ``overall``
            (``drifted``, ``alert``, ``num_drifts``, ``methods``).
        """
        results = {
            "text_property": self.detect_text_property_drift(
                production_texts,
                reference_texts,
            ),
            "label_distribution": self.detect_label_distribution_drift(
                production_labels,
                reference_labels,
            ),
        }
        results["overall"] = self._summarize(results)
        return results

    def detect_all_drifts_from_baseline(
        self,
        production_texts: List[str],
        production_labels: np.ndarray,
        baseline_stats: Dict,
    ) -> Dict:
        """
        Legacy entry point that compares production data against stored baseline statistics.

        Only ``baseline_stats["label_counts"]`` is read. It is treated as a
        per-label count vector, not as raw label ids. baseline_stats carries no
        reference texts, so the ``text_length`` check compares the production
        texts against themselves and never reports drift.

        Args:
            production_texts: Production text data.
            production_labels: Production label data (1-D ids or 2-D, see
                ``detect_label_distribution_drift``).
            baseline_stats: Dictionary with baseline statistics (legacy format).

        Returns:
            Dict with ``text_length``, ``label_distribution`` and ``overall``.
        """
        reference_counts = np.atleast_1d(
            np.asarray(baseline_stats.get("label_counts", []), dtype=float)
        )
        if reference_counts.ndim > 1:
            # Mirror _label_counts: collapse a samples x labels array into per-label totals.
            reference_counts = reference_counts.sum(axis=0)

        results = {
            "text_length": self._fallback_text_property_drift(
                production_texts,
                production_texts,  # no reference texts in baseline_stats; self-comparison placeholder
            ),
            "label_distribution": self._chi_square_label_drift(
                reference_counts,
                self._label_counts(production_labels),
            ),
        }
        results["overall"] = self._summarize(results)
        return results