"""
Data Drift Simulator — AI for Product Managers
Watch model performance degrade as data distribution changes over time.
"""
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
def generate_base_data(n=500, seed=42):
    """Generate a synthetic fraud-detection dataset.

    Args:
        n: Number of transactions to sample.
        seed: Seed for the deterministic random generator.

    Returns:
        Tuple ``(X, labels)`` where ``X`` is an ``(n, 5)`` array with columns
        ``[amount, hour, distance, txn_count, is_online]`` and ``labels`` is a
        binary fraud indicator drawn from a noisy logistic model over X.
    """
    rng = np.random.RandomState(seed)

    # Draw each raw feature from its own distribution (order matters for
    # reproducibility of the RNG stream).
    amount = rng.exponential(200, n)
    hour = rng.randint(0, 24, n)
    distance = rng.exponential(50, n)
    txn_count = rng.poisson(5, n)
    is_online = rng.choice([0, 1], n, p=[0.6, 0.4])
    noise = rng.normal(0, 0.5, n)

    # Accumulate the linear risk score term by term (same order as a single
    # left-to-right sum): late-night, high-amount, distant, online-heavy
    # transactions are riskier.
    risk = -4 + 0.003 * amount
    risk = risk + 0.1 * (hour < 5).astype(float)
    risk = risk + 0.01 * distance
    risk = risk + 0.15 * txn_count
    risk = risk + 0.5 * is_online
    risk = risk + noise

    # Label is 1 where the sigmoid of the risk score exceeds 0.5.
    labels = np.where(1 / (1 + np.exp(-risk)) > 0.5, 1, 0)
    X = np.column_stack([amount, hour, distance, txn_count, is_online])
    return X, labels
def apply_drift(X_base, month, drift_type, intensity, rng):
    """Return a drifted copy of X_base for the given simulation month.

    Columns are ``[amount, hour, distance, txn_count, is_online]``.

    Args:
        X_base: Feature matrix to drift; it is copied, never modified.
        month: Simulation month index (0-based).
        drift_type: One of "Gradual", "Sudden", "Seasonal".
        intensity: Drift intensity in percent (0-100).
        rng: numpy RandomState, consumed only by the "Sudden" channel reshuffle.

    Returns:
        A new array with the month's drift applied.
    """
    drifted = X_base.copy()
    strength = intensity / 100.0

    if drift_type == "Gradual":
        # Features creep linearly over a two-year horizon: amounts and
        # distances inflate, transaction counts rise (capped at 20).
        ramp = strength * month / 24.0
        drifted[:, 0] = drifted[:, 0] * (1 + ramp * 0.5)
        drifted[:, 2] = drifted[:, 2] * (1 + ramp * 0.3)
        drifted[:, 3] = np.clip(drifted[:, 3] + ramp * 2, 0, 20)
    elif drift_type == "Sudden" and month >= 6:
        # One-off regime change from month 6 on: larger amounts/distances and
        # a freshly resampled, online-heavy channel mix.
        drifted[:, 0] = drifted[:, 0] * (1 + strength * 0.8)
        drifted[:, 2] = drifted[:, 2] * (1 + strength * 0.6)
        drifted[:, 4] = rng.choice([0, 1], len(drifted), p=[0.3, 0.7])
    elif drift_type == "Seasonal":
        # Cyclical pattern with a 12-month period (holiday fraud spikes).
        wave = strength * 0.5 * np.sin(2 * np.pi * month / 12)
        drifted[:, 0] = drifted[:, 0] * (1 + wave)
        drifted[:, 3] = np.clip(drifted[:, 3] * (1 + wave * 0.5), 0, 20)
    return drifted
def simulate_drift(drift_type, intensity, months):
    """Train a baseline fraud model, then watch it degrade under drift.

    A RandomForest is fit once on clean "month 0" data and then frozen.
    Each simulated month gets a fresh test sample with drift applied, and
    the frozen model is scored against it.

    Args:
        drift_type: "Gradual", "Sudden", or "Seasonal".
        intensity: Drift intensity in percent (0-100).
        months: Simulation length in months (coerced to int).

    Returns:
        (fig, summary): Plotly figure with performance + drift subplots,
        and a Markdown analysis string with metrics and recommendations.
    """
    months = int(months)
    rng = np.random.RandomState(42)

    # Train the baseline model on month-0 data; it is never retrained.
    X_train, y_train = generate_base_data(500, seed=42)
    model = RandomForestClassifier(n_estimators=50, random_state=42, n_jobs=-1)
    model.fit(X_train, y_train)

    # Score the frozen model against each month's drifted sample.
    accuracies, f1_scores, month_labels = [], [], []
    drift_amounts = []
    for m in range(months + 1):
        X_test, y_test = generate_base_data(200, seed=100 + m)
        X_drifted = apply_drift(X_test, m, drift_type, intensity, rng)
        preds = model.predict(X_drifted)
        accuracies.append(accuracy_score(y_test, preds))
        f1_scores.append(f1_score(y_test, preds, zero_division=0))
        month_labels.append(m)
        # Relative mean absolute feature change as a crude drift magnitude.
        mean_diff = np.mean(np.abs(X_drifted - X_test)) / (np.mean(np.abs(X_test)) + 1e-6)
        drift_amounts.append(mean_diff)

    # First month where F1 falls more than 10% below the month-0 baseline.
    baseline_f1 = f1_scores[0]
    degradation_month = next(
        (i for i, f in enumerate(f1_scores) if f < baseline_f1 * 0.9), None
    )

    # Two stacked charts: metric curves on top, drift magnitude below.
    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=("Model Performance Over Time", "Data Drift Magnitude"),
        vertical_spacing=0.15
    )
    fig.add_trace(go.Scatter(
        x=month_labels, y=accuracies, name="Accuracy",
        line=dict(color="#3b82f6", width=2), mode="lines+markers"
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=month_labels, y=f1_scores, name="F1 Score",
        line=dict(color="#10b981", width=2), mode="lines+markers"
    ), row=1, col=1)
    # Dashed red line marks the 10% degradation threshold; dotted vertical
    # line marks the month it was first crossed (if ever).
    fig.add_hline(y=baseline_f1 * 0.9, line_dash="dash", line_color="red",
                  annotation_text="10% degradation threshold", row=1, col=1)
    if degradation_month is not None:
        fig.add_vline(x=degradation_month, line_dash="dot", line_color="red", row=1, col=1)
    fig.add_trace(go.Bar(
        x=month_labels, y=drift_amounts, name="Drift Magnitude",
        marker_color="#f59e0b", opacity=0.7
    ), row=2, col=1)
    fig.update_layout(height=600, margin=dict(l=20, r=20, t=40, b=20))
    fig.update_yaxes(title_text="Score", range=[0, 1.05], row=1, col=1)
    fig.update_yaxes(title_text="Drift", row=2, col=1)
    fig.update_xaxes(title_text="Month", row=2, col=1)

    # Markdown summary table. Guard the division so a degenerate 0-accuracy
    # baseline cannot raise ZeroDivisionError.
    final_acc = accuracies[-1]
    final_f1 = f1_scores[-1]
    acc_drop = (accuracies[0] - final_acc) / max(accuracies[0], 1e-9) * 100
    summary = f"""## Drift Analysis
| Metric | Month 0 | Month {months} | Change |
|--------|---------|----------|--------|
| Accuracy | {accuracies[0]:.1%} | {final_acc:.1%} | {'-' if acc_drop > 0 else '+'}{abs(acc_drop):.1f}% |
| F1 Score | {f1_scores[0]:.1%} | {final_f1:.1%} | {'-' if f1_scores[0] > final_f1 else '+'}{abs(f1_scores[0] - final_f1)*100:.1f}pp |
"""
    if degradation_month is not None:
        summary += f"**Alert:** Performance degraded past 10% threshold at **month {degradation_month}**.\n\n"
    else:
        summary += "**Status:** No significant degradation detected in this timeframe.\n\n"

    # Drift-type-specific retraining advice.
    if drift_type == "Gradual":
        # BUG FIX: compare against None explicitly. A month index of 0 is
        # falsy, so the old `if degradation_month` truthiness check would
        # have silently fallen back to the 6-month default (and the old
        # expression would also have computed max(3, -1)).
        if degradation_month is not None:
            rec_interval = max(3, degradation_month - 1)
        else:
            rec_interval = 6
        summary += f"**Recommendation:** For gradual drift, retrain every **{rec_interval} months**. Set up automated performance monitoring with alerts at 5% degradation."
    elif drift_type == "Sudden":
        summary += "**Recommendation:** For sudden drift, you need **real-time monitoring** and the ability to retrain within days. Set up alerts for sharp accuracy drops and have a retraining pipeline ready."
    else:
        summary += "**Recommendation:** For seasonal drift, retrain **before each peak season** using recent data. Consider maintaining separate models for peak vs off-peak periods."
    return fig, summary
# ── Gradio UI ──────────────────────────────────────────────────────────────────
# (BUG FIX: restored mis-encoded "β" mojibake to the intended em-dash /
# box-drawing characters in the header rule and the user-facing Markdown.)
with gr.Blocks(title="Data Drift Simulator", theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown(
        "# Data Drift Simulator\n\n"
        "**PM Decision:** ML models degrade over time as real-world data changes. Use this "
        "to understand why you must require monitoring in every ML project and budget for "
        "retraining. A model that's 95% accurate at launch might drop to 70% in 6 months.\n\n"
        "Watch a fraud detection model's performance degrade as data distribution changes. "
        "**ML models aren't like software — they don't stay accurate forever.**"
    )
    # Simulation controls.
    with gr.Row():
        drift_type = gr.Dropdown(
            choices=["Gradual", "Sudden", "Seasonal"],
            value="Gradual",
            label="Drift Type"
        )
        intensity = gr.Slider(10, 100, value=50, step=5, label="Drift Intensity (%)")
        months = gr.Slider(6, 24, value=18, step=1, label="Simulation Length (months)")
    run_btn = gr.Button("Simulate Drift", variant="primary")
    chart = gr.Plot(label="Performance Over Time")
    analysis = gr.Markdown()
    # Run on click, and once on page load so the demo starts populated.
    run_btn.click(simulate_drift, [drift_type, intensity, months], [chart, analysis])
    demo.load(simulate_drift, [drift_type, intensity, months], [chart, analysis])
    gr.Markdown(
        "---\n"
        "**PM Takeaway:** Always ask: 'What happens when the data changes? How will we know, "
        "and how often will we retrain?' Budget for monitoring from day one.\n\n"
        "*AI for Product Managers*"
    )

if __name__ == "__main__":
    demo.launch()