File size: 18,949 Bytes
cea1951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
from __future__ import annotations
 
import time
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
 
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import wandb
 
from river import (
    drift as river_drift,
    preprocessing as river_pp,
    tree as river_tree,
    metrics as river_metrics,
    stream as river_stream,
)
 
warnings.filterwarnings("ignore")
 
 
# ─── Drift Event dataclass ────────────────────────────────────────────────────
 
@dataclass
class DriftEvent:
    """One drift detection, as recorded by ``DriftMonitor.update``.

    All fields except ``timestamp`` are supplied by the code that creates the
    event; ``timestamp`` defaults to the wall-clock time of construction.
    """

    # Stream position (sample index) at which the detector fired.
    sample_index: int
    # Name of the detector that fired: "ADWIN" | "KSWIN".
    detector: str
    # Running AUC value handed in together with the sample.
    running_auc: float
    # Follow-up tag attached to the event: "retrain" | "alert".
    action: str
    # Seconds since the epoch, taken from time.time() when not given.
    timestamp: float = field(default_factory=time.time)
 
 
# ─── DriftMonitor ─────────────────────────────────────────────────────────────
 
class DriftMonitor:
    """
    Runs River's ADWIN and KSWIN drift detectors behind a single interface.

    Each detection is stored as a DriftEvent in ``self.events`` and is also
    returned to the caller of ``update``.

    Parameters
    ----------
    delta_adwin : forwarded to ``river.drift.ADWIN`` as ``delta``
    alpha_kswin : forwarded to ``river.drift.KSWIN`` as ``alpha``
    window_size : forwarded to ``river.drift.KSWIN`` as ``window_size``
    use_adwin   : create and consult the ADWIN detector
    use_kswin   : create and consult the KSWIN detector
    """

    def __init__(
        self,
        delta_adwin: float = 0.002,
        alpha_kswin: float = 0.005,
        window_size: int = 100,
        use_adwin: bool = True,
        use_kswin: bool = True,
    ):
        self.use_adwin = use_adwin
        self.use_kswin = use_kswin

        # A disabled detector is never instantiated; its slot holds None.
        if use_adwin:
            self.adwin = river_drift.ADWIN(delta=delta_adwin)
        else:
            self.adwin = None
        if use_kswin:
            self.kswin = river_drift.KSWIN(alpha=alpha_kswin, window_size=window_size)
        else:
            self.kswin = None

        self.events: List[DriftEvent] = []
        self._n_adwin_resets = 0
        self._n_kswin_resets = 0

    def _record(
        self, detector: str, action: str, sample_idx: int, running_auc: float
    ) -> DriftEvent:
        """Create a DriftEvent, append it to ``self.events`` and return it."""
        event = DriftEvent(
            sample_index=sample_idx,
            detector=detector,
            running_auc=running_auc,
            action=action,
        )
        self.events.append(event)
        return event

    def update(self, error: float, sample_idx: int, running_auc: float) -> Optional[DriftEvent]:
        """
        Feed one prediction error to the enabled detectors.

        ADWIN is consulted first. If it reports drift, an event with action
        "retrain" is returned immediately and KSWIN does not see this sample.
        Otherwise KSWIN is updated and may produce an event with action
        "alert". Returns None when no detector fires.
        """
        if self.use_adwin:
            self.adwin.update(error)
            if self.adwin.drift_detected:
                self._n_adwin_resets += 1
                return self._record("ADWIN", "retrain", sample_idx, running_auc)

        if not self.use_kswin:
            return None

        self.kswin.update(error)
        if not self.kswin.drift_detected:
            return None

        self._n_kswin_resets += 1
        return self._record("KSWIN", "alert", sample_idx, running_auc)

    @property
    def total_detections(self) -> int:
        """Number of events recorded so far, across both detectors."""
        return len(self.events)

    @property
    def adwin_detections(self) -> int:
        """Number of times ADWIN has reported drift."""
        return self._n_adwin_resets

    @property
    def kswin_detections(self) -> int:
        """Number of times KSWIN has reported drift."""
        return self._n_kswin_resets

    def summary(self) -> Dict:
        """Return the detection counters plus the sample index of every event."""
        drift_positions = [event.sample_index for event in self.events]
        return {
            "total_detections": self.total_detections,
            "adwin_detections": self.adwin_detections,
            "kswin_detections": self.kswin_detections,
            "drift_sample_indices": drift_positions,
        }
 
 
# ─── OnlineLearner ────────────────────────────────────────────────────────────
 
class OnlineLearner:
    """
    Test-then-train online learner built on River.

    The model is a ``StandardScaler | HoeffdingAdaptiveTreeClassifier``
    pipeline. Every sample is scored first, the absolute prediction error is
    passed to the DriftMonitor, and only then is the sample learned. Whenever
    the monitor returns an event (from either detector) the whole pipeline is
    rebuilt from scratch.

    Parameters
    ----------
    monitor : DriftMonitor instance
    grace_period_normal     : tree grace period for the initial model and for
                              rebuilds triggered by a non-"retrain" event
    grace_period_post_drift : tree grace period for rebuilds triggered by a
                              "retrain" event
    log_wandb : log metrics to W&B if True
    seed      : seed for the tree and for the synthetic label noise drawn in
                ``run_stream``
    """

    def __init__(
        self,
        monitor: DriftMonitor,
        grace_period_normal: int = 200,
        grace_period_post_drift: int = 50,
        log_wandb: bool = False,
        seed: int = 42,
    ):
        self.monitor = monitor
        self.grace_period_normal = grace_period_normal
        self.grace_period_post_drift = grace_period_post_drift
        self.log_wandb = log_wandb
        self.seed = seed
        self._build_model(grace_period_normal)

        self.auc_metric = river_metrics.ROCAUC()
        self.errors: List[float] = []
        self.running_aucs: List[float] = []
        self.retrain_count: int = 0

    def _build_model(self, grace_period: int) -> None:
        """Replace ``self.pipeline`` with a fresh, untrained scaler + tree."""
        self.pipeline = (
            river_pp.StandardScaler()
            | river_tree.HoeffdingAdaptiveTreeClassifier(
                grace_period=grace_period,
                delta=1e-5,
                seed=self.seed,
            )
        )

    def run_stream(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        drift_inject_at: Optional[int] = None,
        drift_income_mult: float = 0.4,
        drift_label_noise: float = 0.12,
        drift_duration: int = 5000,
        verbose_every: int = 5000,
    ) -> Dict:
        """
        Stream all rows through the online learner.

        Parameters
        ----------
        drift_inject_at   : sample index to start injecting synthetic drift
                            (None = no injection; 0 is a valid start index)
        drift_income_mult : income multiplier during drift window
        drift_label_noise : probability of flipping a label during drift
        drift_duration    : how many samples the drift lasts
        verbose_every     : print progress every N samples (<= 0 disables it)

        Returns
        -------
        dict with ``final_auc`` (NaN when the stream is empty),
        ``total_samples``, ``elapsed_seconds``, ``throughput``,
        ``total_retrains`` and the entries of ``DriftMonitor.summary()``.
        """
        n_samples = len(X)
        # Compare against None so that an injection starting at index 0 works.
        inject = drift_inject_at is not None
        drift_end = drift_inject_at + drift_duration if inject else None

        print(f"🌊 Streaming {n_samples:,} samples through online learner...")
        if inject:
            print(f"   Synthetic drift will be injected at sample {drift_inject_at:,} "
                  f"for {drift_duration:,} samples (income×{drift_income_mult})")

        # Dedicated generator seeded from self.seed: label noise is
        # reproducible and does not touch numpy's global random state.
        rng = np.random.default_rng(self.seed)
        income_col = "AMT_INCOME_TOTAL" if "AMT_INCOME_TOTAL" in X.columns else None

        # Defined up front so an empty stream still yields a results dict.
        current_auc = float("nan")
        start = time.perf_counter()

        for i, (xi, yi) in enumerate(river_stream.iter_pandas(X, y)):

            # ── Optional: synthetic drift injection ─────────────────────
            if inject and drift_inject_at <= i < drift_end:
                xi = dict(xi)
                if income_col is not None:
                    xi[income_col] = xi[income_col] * drift_income_mult
                if rng.random() < drift_label_noise:
                    yi = 1 - yi  # assumes binary 0/1 labels

            # ── Predict ─────────────────────────────────────────────────
            y_prob = self.pipeline.predict_proba_one(xi)
            # Fall back to 0.5 when no probability is reported for class 1.
            p1 = y_prob.get(1, 0.5)

            # ── Update metric ────────────────────────────────────────────
            self.auc_metric.update(yi, p1)
            current_auc = self.auc_metric.get()
            self.running_aucs.append(current_auc)

            error = abs(yi - p1)
            self.errors.append(error)

            # ── Drift detection ──────────────────────────────────────────
            evt = self.monitor.update(error, i, current_auc)
            if evt is not None:
                # Any event rebuilds the model; the action only selects the
                # grace period of the replacement tree.
                self.retrain_count += 1
                if evt.action == "retrain":
                    gp = self.grace_period_post_drift
                else:
                    gp = self.grace_period_normal
                self._build_model(gp)

                # Only the first 8 rebuilds are announced on stdout.
                if self.retrain_count <= 8:
                    print(f"  🚨 [{evt.detector}] Drift @ sample {i:,} | "
                          f"AUC: {current_auc:.4f} | Retrain #{self.retrain_count}")

                if self.log_wandb:
                    wandb.log({
                        "online/drift_detected_at": i,
                        "online/detector": evt.detector,
                        "online/auc_at_drift": current_auc,
                        "online/retrain_count": self.retrain_count,
                    })

            # ── Learn ────────────────────────────────────────────────────
            # Runs after the possible rebuild, so a fresh model starts with
            # the sample that triggered the detection.
            self.pipeline.learn_one(xi, yi)

            # ── Periodic logging ─────────────────────────────────────────
            if verbose_every > 0 and (i + 1) % verbose_every == 0:
                elapsed = time.perf_counter() - start
                print(f"  [{i+1:>7,}] AUC={current_auc:.4f} | "
                      f"Drifts={self.monitor.total_detections} | "
                      f"Elapsed={elapsed:.0f}s")
                if self.log_wandb:
                    wandb.log({"online/auc": current_auc, "online/sample": i + 1})

        elapsed = time.perf_counter() - start
        results = {
            "final_auc":         current_auc,
            "total_samples":     n_samples,
            "elapsed_seconds":   elapsed,
            # Guard against a zero-length interval on very short streams.
            "throughput":        n_samples / elapsed if elapsed > 0 else 0.0,
            "total_retrains":    self.retrain_count,
            **self.monitor.summary(),
        }
        print(f"\n✅ Stream complete | Final AUC: {current_auc:.5f} | "
              f"Drift events: {self.monitor.total_detections} | "
              f"Time: {elapsed:.1f}s | Throughput: {results['throughput']:.0f} samples/s")
        return results
 
 
# ─── DriftSimulator ───────────────────────────────────────────────────────────
 
class DriftSimulator:
    """
    Applies a fixed set of synthetic shock scenarios to batch data and
    scores a model on each of them.

    Parameters
    ----------
    model_predict_fn : callable — takes a pd.DataFrame, returns probability array
    feature_cols     : list of feature column names
    """

    # Scenario name -> keyword arguments for ``_apply_shock``.
    SCENARIOS = {
        "Baseline":                {"income_mult": 1.0,  "emp_mask": 0.00, "label_noise": 0.00},
        "Mild Income Shock -30%":  {"income_mult": 0.70, "emp_mask": 0.05, "label_noise": 0.02},
        "Severe Income Shock -60%":{"income_mult": 0.40, "emp_mask": 0.15, "label_noise": 0.05},
        "Mass Job Loss 20%":       {"income_mult": 0.50, "emp_mask": 0.20, "label_noise": 0.08},
        "Full Economic Collapse":  {"income_mult": 0.25, "emp_mask": 0.40, "label_noise": 0.15},
    }

    def __init__(self, model_predict_fn, feature_cols: List[str]):
        self.predict = model_predict_fn
        self.feature_cols = feature_cols

    def _apply_shock(
        self,
        X: pd.DataFrame,
        y: np.ndarray,
        income_mult: float,
        emp_mask: float,
        label_noise: float,
        seed: int = 42,
    ) -> Tuple[pd.DataFrame, np.ndarray]:
        """
        Return shocked copies of ``X`` and ``y``; the inputs are not modified.

        Parameters
        ----------
        X           : feature frame
        y           : binary 0/1 labels (array-like; a Series is accepted)
        income_mult : factor applied to every column whose name contains "INCOME"
        emp_mask    : per-row probability of zeroing every column whose name
                      contains "EMPLOY"
        label_noise : fraction of labels to flip (``int(label_noise * n)`` rows)
        seed        : seed for the row mask and the flipped-label selection

        Returns
        -------
        (X_shock, y_shock) where ``y_shock`` is always a NumPy array
        """
        rng = np.random.RandomState(seed)
        X_shock = X.copy()
        # Force a NumPy copy so the positional flips below cannot be
        # misread as label lookups when a Series with a custom index is given.
        y_shock = np.array(y, copy=True)

        # Income shock
        for col in [c for c in X_shock.columns if "INCOME" in c]:
            X_shock[col] = X_shock[col] * income_mult

        # Employment shock — zero out employment columns for `emp_mask` fraction.
        # "EMPLOY" also matches "DAYS_EMPLOYED", so one substring test suffices.
        emp_cols = [c for c in X_shock.columns if "EMPLOY" in c]
        mask = rng.random(len(X_shock)) < emp_mask
        for col in emp_cols:
            X_shock.loc[mask, col] = 0

        # Label noise: flip a fixed number of distinct positions.
        n_flip = int(label_noise * len(y_shock))
        noise_idx = rng.choice(len(y_shock), n_flip, replace=False)
        y_shock[noise_idx] = 1 - y_shock[noise_idx]

        return X_shock, y_shock

    def run_all_scenarios(
        self,
        X: pd.DataFrame,
        y: np.ndarray,
        log_wandb: bool = False,
    ) -> pd.DataFrame:
        """
        Score the model on every entry of ``SCENARIOS``.

        Parameters
        ----------
        X         : feature frame; columns missing from ``feature_cols`` are
                    added as 0.0 and extra columns are dropped before predicting
        y         : binary 0/1 labels
        log_wandb : log each scenario's AUC to W&B if True

        Returns
        -------
        DataFrame with one row per scenario and the columns ``scenario``,
        ``auc``, ``income_mult``, ``emp_mask``, ``label_noise``, ``auc_drop``
        (baseline AUC minus scenario AUC) and ``pct_drop`` (percent, 2 decimals).
        """
        # Imported here because scikit-learn is only needed for scoring.
        from sklearn.metrics import roc_auc_score

        rows = []
        for name, params in self.SCENARIOS.items():
            X_s, y_s = self._apply_shock(X, y, **params)
            # Align to the model's feature list: missing columns become 0.0.
            X_s = X_s.reindex(columns=self.feature_cols, fill_value=0.0)

            preds = self.predict(X_s)
            auc = roc_auc_score(y_s, preds)

            rows.append({
                "scenario":    name,
                "auc":         auc,
                "income_mult": params["income_mult"],
                "emp_mask":    params["emp_mask"],
                "label_noise": params["label_noise"],
            })

            if log_wandb:
                wandb.log({"drift_sim/scenario": name, "drift_sim/auc": auc})

            print(f"  {name:<35s} | AUC: {auc:.5f}")

        df = pd.DataFrame(rows)
        baseline_auc = df.loc[df["scenario"] == "Baseline", "auc"].iloc[0]
        df["auc_drop"] = baseline_auc - df["auc"]
        df["pct_drop"] = (df["auc_drop"] / baseline_auc * 100).round(2)
        return df
 
 
# ─── DriftDashboard ───────────────────────────────────────────────────────────
 
class DriftDashboard:
    """
    Static helpers that render the drift analysis plots.

    ``plot_error_stream`` produces a Matplotlib figure; the two scenario
    plots produce Plotly figures. Each helper optionally writes the figure
    to ``save_path`` and always returns it.
    """

    @staticmethod
    def plot_error_stream(
        errors: List[float],
        drift_events: List[DriftEvent],
        drift_inject_at: Optional[int] = None,
        window: int = 500,
        save_path: Optional[str] = None,
    ) -> plt.Figure:
        """
        Smoothed error stream with drift markers.

        Parameters
        ----------
        errors          : per-sample prediction errors
        drift_events    : events to mark as vertical lines (colour by detector)
        drift_inject_at : sample index of the synthetic drift marker
                          (None = no marker; 0 is a valid index)
        window          : rolling-mean window applied to ``errors``
        save_path       : if given, the figure is saved there at 150 dpi

        Returns
        -------
        Matplotlib figure with the error stream (top) and the cumulative
        detection count (bottom).
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 9), sharex=True)
        smoothed = pd.Series(errors, dtype=float).rolling(window).mean()
        # Compare against None so a marker at sample 0 is still drawn.
        show_injection = drift_inject_at is not None

        ax1.plot(smoothed, color="#1565C0", linewidth=1.2, label=f"Error (rolling {window})")
        ax1.fill_between(range(len(smoothed)), smoothed, alpha=0.15, color="#1565C0")

        colors = {"ADWIN": "#F44336", "KSWIN": "#FF9800"}
        for evt in drift_events:
            # Detectors not listed in `colors` fall back to purple.
            c = colors.get(evt.detector, "#9C27B0")
            ax1.axvline(evt.sample_index, color=c, linewidth=0.7, alpha=0.8)

        if show_injection:
            ax1.axvline(drift_inject_at, color="orange", linewidth=2.5,
                        linestyle="--", label="Synthetic Drift Injected")

        ax1.set_ylabel("Prediction Error")
        ax1.set_title("ADWIN + KSWIN Drift Detection — Error Stream", fontsize=13, fontweight="bold")
        ax1.legend(loc="upper right")

        # Cumulative detections
        if drift_events:
            indices = [e.sample_index for e in drift_events]
            ax2.step(indices, range(1, len(indices) + 1), color="#F44336", linewidth=2)
            if show_injection:
                ax2.axvline(drift_inject_at, color="orange", linewidth=2.5, linestyle="--")
        ax2.set_ylabel("Cumulative Detections")
        ax2.set_xlabel("Sample Index")
        ax2.set_title("Cumulative Drift Events", fontsize=11)

        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"✅ Saved drift stream plot → {save_path}")
        return fig

    @staticmethod
    def plot_scenario_degradation(
        drift_df: pd.DataFrame,
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """
        Plotly bar chart of AUC across drift scenarios.

        Parameters
        ----------
        drift_df  : frame with ``scenario``, ``auc`` and ``auc_drop`` columns
        save_path : if given, the figure is written there with ``write_image``

        Returns
        -------
        Plotly figure with one bar per row and a dashed line at AUC 0.70.
        """
        PALETTE = ["#4CAF50", "#8BC34A", "#FF9800", "#F44336", "#B71C1C"]
        n_bars = len(drift_df)
        # Cycle the palette so frames with more rows than colours stay coloured.
        bar_colors = [PALETTE[i % len(PALETTE)] for i in range(n_bars)]
        bar_labels = [
            f"{a:.4f}<br>({d:+.4f})"
            for a, d in zip(drift_df["auc"], -drift_df["auc_drop"])
        ]

        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=drift_df["scenario"],
            y=drift_df["auc"],
            marker_color=bar_colors,
            text=bar_labels,
            textposition="outside",
        ))
        fig.add_hline(y=0.70, line_dash="dash", line_color="#F44336",
                      annotation_text="Min Acceptable AUC (0.70)")

        layout = dict(
            title="Model AUC Under Concept Drift Scenarios",
            xaxis_title="Scenario",
            yaxis_title="ROC-AUC",
            height=480,
            template="plotly_white",
        )
        if n_bars:
            # Floor stays at 0.5 unless a scenario scores lower, in which
            # case the axis is extended so that bar remains visible.
            auc_low = float(drift_df["auc"].min())
            auc_high = float(drift_df["auc"].max())
            layout["yaxis_range"] = [min(0.5, auc_low - 0.05), auc_high + 0.05]
        fig.update_layout(**layout)

        if save_path:
            fig.write_image(save_path)
            print(f"✅ Saved scenario degradation plot → {save_path}")
        return fig

    @staticmethod
    def plot_income_sensitivity(
        drift_df: pd.DataFrame,
        save_path: Optional[str] = None,
    ) -> go.Figure:
        """
        Plotly line: AUC vs income multiplier.

        Parameters
        ----------
        drift_df  : frame with ``income_mult`` and ``auc`` columns
        save_path : if given, the figure is written there with ``write_image``

        Returns
        -------
        Plotly figure with one marker per row, coloured by AUC band
        (> 0.75, > 0.65, otherwise), and a dashed line at AUC 0.70.
        """
        # Sort along the x-axis; rows in arbitrary order would make the
        # connecting line double back on itself.
        ordered = drift_df.sort_values("income_mult", kind="stable")
        marker_colors = [
            "#4CAF50" if a > 0.75 else "#FF9800" if a > 0.65 else "#F44336"
            for a in ordered["auc"]
        ]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=ordered["income_mult"],
            y=ordered["auc"],
            mode="lines+markers",
            name="Ensemble AUC",
            line=dict(color="#F44336", width=3),
            marker=dict(size=10, color=marker_colors),
        ))
        fig.add_hline(y=0.70, line_dash="dash", line_color="#666",
                      annotation_text="Min Acceptable AUC")
        fig.update_layout(
            title="AUC Degradation vs Income Shock Severity",
            xaxis_title="Remaining Income Fraction (1.0 = no shock)",
            yaxis_title="ROC-AUC",
            height=420,
            template="plotly_white",
        )

        if save_path:
            fig.write_image(save_path)
            print(f"✅ Saved income sensitivity plot → {save_path}")
        return fig