import argparse
import os
import sqlite3
from datetime import datetime, timedelta, timezone
from pathlib import Path

import mlflow
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from evidently import Report
from evidently.presets import DataDriftPreset
|
|
load_dotenv(dotenv_path=".env")  # loads MLFLOW_TRACKING_URI etc. from the local .env


# Per-label row counts for the synthetic reference distribution
# (balanced: equal counts for all four AG News labels).
REFERENCE_COUNTS = {"World": 1900, "Sports": 1900, "Business": 1900, "Sci/Tech": 1900}


# Parameters of the clipped-normal confidence distribution used for the reference.
REFERENCE_CONFIDENCE_MEAN = 0.97
REFERENCE_CONFIDENCE_STD = 0.05


# Minimum rows in the current window before drift tests are worth running.
MIN_SAMPLES = 30
|
|
|
|
def build_reference_df(counts=None, mean=None, std=None, seed=42) -> pd.DataFrame:
    """Synthesize the reference confidence distribution.

    For every label, draws ``count`` confidences from a normal distribution
    clipped to [0.5, 1.0]; the fixed seed keeps the reference reproducible
    across runs.

    Args:
        counts: Mapping of label -> number of rows; defaults to
            ``REFERENCE_COUNTS``.
        mean: Mean of the confidence normal; defaults to
            ``REFERENCE_CONFIDENCE_MEAN``.
        std: Std-dev of the confidence normal; defaults to
            ``REFERENCE_CONFIDENCE_STD``.
        seed: RNG seed (default 42, matching the original behavior).

    Returns:
        DataFrame with columns ``label`` and ``confidence``.
    """
    # Resolve defaults lazily so the module-level constants stay the
    # single source of truth for the no-argument call.
    counts = REFERENCE_COUNTS if counts is None else counts
    mean = REFERENCE_CONFIDENCE_MEAN if mean is None else mean
    std = REFERENCE_CONFIDENCE_STD if std is None else std

    rng = np.random.default_rng(seed=seed)
    rows = []
    for label, count in counts.items():
        confidences = np.clip(rng.normal(mean, std, count), 0.5, 1.0)
        rows.extend(
            {"label": label, "confidence": round(float(conf), 4)}
            for conf in confidences
        )
    return pd.DataFrame(rows)
|
|
|
|
def load_current_df(db_path: str, hours: int) -> pd.DataFrame:
    """Load model-produced classifications from the last *hours* hours.

    Only non-cached rows (``cached = 0``) are returned, ordered by
    creation time.

    Args:
        db_path: Path to the SQLite database file.
        hours: Lookback window size in hours.

    Returns:
        DataFrame with columns ``label``, ``confidence``, ``created_at``.
    """
    # Naive UTC ISO string so the lexicographic comparison matches the
    # stored `created_at` format (datetime.utcnow() is deprecated in 3.12+;
    # this produces the exact same string).
    cutoff = (
        datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
    ).isoformat()
    conn = sqlite3.connect(db_path)
    try:
        # Parameterized query — the cutoff is bound, never interpolated.
        return pd.read_sql_query(
            """
            SELECT label, confidence, created_at
            FROM classifications
            WHERE created_at >= ? AND cached = 0
            ORDER BY created_at
            """,
            conn,
            params=(cutoff,),
        )
    finally:
        # Close even if the query raises — the original leaked the handle
        # on failure.
        conn.close()
|
|
|
|
def run_evidently(reference_df: pd.DataFrame, current_df: pd.DataFrame, html_path: str):
    """Run Evidently's data-drift preset and persist the HTML report.

    Compares the ``label`` and ``confidence`` columns of both frames,
    saves the visual report to *html_path*, and returns the snapshot
    serialized as a plain dict for downstream parsing.
    """
    columns = ["label", "confidence"]
    report = Report([DataDriftPreset()])
    snapshot = report.run(
        reference_data=reference_df[columns],
        current_data=current_df[columns],
    )
    snapshot.save_html(html_path)
    return snapshot.dict()
|
|
|
|
def extract_drift_result(report_dict: dict) -> tuple[bool, float, list[str]]:
    """Interpret an Evidently 0.7 snapshot dict.

    The ``DriftedColumnsCount`` metric carries the overall share of drifted
    columns. Each ``ValueDrift`` metric carries a per-column p-value; a
    column counts as drifted when that p-value falls below its configured
    threshold (default 0.05). Overall drift is flagged when at least half
    of the columns drifted.

    Returns:
        Tuple of (drift_detected, drifted_share, drifted_column_names).
    """
    share = 0.0
    flagged: list[str] = []

    for metric in report_dict.get("metrics", []):
        metric_name = metric.get("metric_name", "")
        metric_value = metric.get("value")

        if "DriftedColumnsCount" in metric_name and isinstance(metric_value, dict):
            share = metric_value.get("share", 0.0)

        if "ValueDrift" in metric_name and isinstance(metric_value, (int, float)):
            cfg = metric.get("config", {})
            if metric_value < cfg.get("threshold", 0.05):
                flagged.append(cfg.get("column", ""))

    return share >= 0.5, share, flagged
|
|
|
|
def log_to_mlflow(
    drift_detected: bool,
    drift_share: float,
    drifted_columns: list[str],
    current_df: pd.DataFrame,
    html_path: str,
):
    """Record drift metrics, label shares, and the HTML report in MLflow.

    Args:
        drift_detected: Overall drift verdict.
        drift_share: Share of columns flagged as drifted.
        drifted_columns: Names of the drifted columns (may be empty).
        current_df: Current-window data; used for summary statistics.
        html_path: Path of the Evidently HTML report to attach as artifact.

    Raises:
        KeyError: if MLFLOW_TRACKING_URI is not set in the environment.
    """
    mlflow.set_tracking_uri(os.environ["MLFLOW_TRACKING_URI"])
    mlflow.set_experiment("drift-monitoring")

    # datetime.utcnow() is deprecated in 3.12+; aware-UTC now formats the
    # same with this strftime pattern.
    run_name = f"drift-{datetime.now(timezone.utc).strftime('%Y%m%d-%H%M')}"
    with mlflow.start_run(run_name=run_name):
        mlflow.log_metric("drift_detected", int(drift_detected))
        mlflow.log_metric("drift_share_of_columns", round(drift_share, 4))
        mlflow.log_metric("sample_count", len(current_df))
        mlflow.log_metric("mean_confidence", round(current_df["confidence"].mean(), 4))
        mlflow.log_metric("p10_confidence", round(current_df["confidence"].quantile(0.10), 4))

        # One metric per label so dashboards can track class balance over time;
        # "/" is not a valid metric-name character, hence the replace.
        label_dist = current_df["label"].value_counts(normalize=True)
        for label, share in label_dist.items():
            key = label.lower().replace("/", "_")
            mlflow.log_metric(f"label_share_{key}", round(share, 4))

        if drifted_columns:
            mlflow.set_tag("drifted_columns", ", ".join(drifted_columns))

        mlflow.log_artifact(html_path)

    print(f"[drift] metrics logged to MLflow run '{run_name}'")
|
|
|
|
def run(db_path: str = "classifications.db", hours: int = 24, output_dir: str = "reports"):
    """End-to-end drift check: build the reference, load the current window,
    run Evidently, and log the results to MLflow.

    Args:
        db_path: SQLite database with the ``classifications`` table.
        hours: Lookback window in hours for the "current" data.
        output_dir: Directory for HTML drift reports (created if missing).
    """
    # parents=True so a nested output_dir (e.g. "out/reports") also works;
    # the original raised FileNotFoundError in that case.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # datetime.utcnow() is deprecated in 3.12+; aware-UTC formats identically.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M")
    html_path = str(Path(output_dir) / f"drift_{timestamp}.html")

    print("[drift] building reference distribution (AG News test set — 25% per label)...")
    reference_df = build_reference_df()

    print(f"[drift] loading current data from '{db_path}' (last {hours}h, model-only rows)...")
    current_df = load_current_df(db_path, hours)

    # Too few rows makes the statistical tests meaningless — bail out early.
    if len(current_df) < MIN_SAMPLES:
        print(
            f"[drift] only {len(current_df)} rows in window "
            f"(need ≥ {MIN_SAMPLES}). Run the pipeline longer and retry."
        )
        return

    # Quick console summary before the full Evidently run.
    ref_dist = reference_df["label"].value_counts(normalize=True).round(3)
    cur_dist = current_df["label"].value_counts(normalize=True).round(3)
    print("\n Label distribution comparison:")
    print(f" {'Label':<12} {'Reference':>10} {'Current':>10}")
    for label in REFERENCE_COUNTS:
        print(f" {label:<12} {ref_dist.get(label, 0):>10.1%} {cur_dist.get(label, 0):>10.1%}")
    print(
        f"\n Current confidence — mean: {current_df['confidence'].mean():.3f} "
        f"p10: {current_df['confidence'].quantile(0.1):.3f}\n"
    )

    print("[drift] running Evidently drift report...")
    report_dict = run_evidently(reference_df, current_df, html_path)
    print(f"[drift] report saved → {html_path}")

    drift_detected, drift_share, drifted_columns = extract_drift_result(report_dict)

    print("[drift] logging metrics to MLflow (DagsHub)...")
    log_to_mlflow(drift_detected, drift_share, drifted_columns, current_df, html_path)

    status = "DRIFT DETECTED" if drift_detected else "no drift detected"
    col_info = f" drifted columns: {drifted_columns}" if drifted_columns else ""
    print(f"\n[drift] result: {status}{col_info}")
    print(f"[drift] open {html_path} in a browser for the full visual report")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse flags and forward them straight to run().
    parser = argparse.ArgumentParser(description="Run drift detection on classifications.db")
    parser.add_argument("--db", default="classifications.db", help="Path to SQLite DB")
    parser.add_argument("--hours", type=int, default=24, help="Lookback window in hours")
    parser.add_argument("--output-dir", default="reports", help="Where to save HTML reports")
    args = parser.parse_args()


    run(db_path=args.db, hours=args.hours, output_dir=args.output_dir)
|
|