Yoon-gu Hwang Claude commited on
Commit
1ee2788
·
1 Parent(s): 403ebf2

월별 데이터 드리프트 감지 및 분석 대시보드 추가

Browse files

- Frouros 라이브러리를 이용한 데이터 드리프트 감지 (KS Test, Wasserstein Distance)
- 4개의 탭으로 구성된 종합 드리프트 분석 대시보드:
1. Time Series + Drift Markers: 드리프트 발생 지점 표시
2. Monthly Drift Scores: 월별 드리프트 점수 (KS Statistic, WD)
3. Drift Heatmap: 전체 메트릭 드리프트 히트맵
4. Data Tables: 원본 데이터 및 모델 정보

- 1월을 기준(baseline)으로 2-8월 드리프트 비교
- p-value < 0.05 기준으로 드리프트 자동 감지
- 시각화: 바 차트, 라인 차트, 히트맵 등 다양한 방식 제공

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +250 -23
  2. pyproject.toml +3 -0
  3. requirements.txt +3 -0
app.py CHANGED
@@ -2,9 +2,13 @@ import sqlite3
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
 
5
  from datetime import datetime, timedelta
6
  import os
7
  import subprocess
 
 
 
8
 
9
  # Initialize database if it doesn't exist
10
  if not os.path.exists('drift_detection.db'):
@@ -87,6 +91,59 @@ def load_model_info():
87
 
88
  return df
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def create_metric_chart(df, metric='precision'):
91
  """Create Plotly line chart for selected metric over time by model"""
92
  if df.empty:
@@ -163,15 +220,170 @@ def create_metric_chart(df, metric='precision'):
163
 
164
  return fig
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def update_chart(metric):
167
  """Update chart based on selected metric"""
168
  df = load_drift_data()
169
  chart = create_metric_chart(df, metric)
170
  return chart
171
 
 
 
 
 
 
 
 
 
172
  # Create Gradio interface
173
  with gr.Blocks(title="Drift Detection Dashboard", theme=gr.themes.Soft()) as demo:
174
  gr.Markdown("# Drift Detection Dashboard")
 
175
 
176
  with gr.Row():
177
  metric_dropdown = gr.Dropdown(
@@ -182,44 +394,59 @@ with gr.Blocks(title="Drift Detection Dashboard", theme=gr.themes.Soft()) as dem
182
  ("Wasserstein Distance", "wd_value")
183
  ],
184
  value="precision",
185
- label="Metric",
186
  scale=1
187
  )
188
 
189
- with gr.Row():
190
- plot_output = gr.Plot()
 
 
191
 
192
- with gr.Row():
193
- with gr.Column(scale=2):
194
- dataframe_output = gr.Dataframe(
195
- value=load_drift_data(),
196
- interactive=False,
197
- wrap=True,
198
- label="Drift Records"
199
- )
200
- with gr.Column(scale=1):
201
- model_info_output = gr.Dataframe(
202
- value=load_model_info(),
203
- interactive=False,
204
- wrap=True,
205
- label="Model Info"
206
- )
 
 
 
 
 
 
 
 
 
 
207
 
208
  # Event handlers
209
  metric_dropdown.change(
210
- fn=update_chart,
211
  inputs=[metric_dropdown],
212
- outputs=[plot_output]
213
  )
214
 
215
- # Load initial data and chart
216
  def load_initial_data():
217
  df = load_drift_data()
218
- return df, create_metric_chart(df, 'precision')
 
 
 
219
 
220
  demo.load(
221
  fn=load_initial_data,
222
- outputs=[dataframe_output, plot_output]
223
  )
224
 
225
  if __name__ == "__main__":
 
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
5
+ import plotly.graph_objects as go
6
  from datetime import datetime, timedelta
7
  import os
8
  import subprocess
9
+ import numpy as np
10
+ from frouros.detectors.data_drift import KSTest
11
+ from scipy.stats import wasserstein_distance
12
 
13
  # Initialize database if it doesn't exist
14
  if not os.path.exists('drift_detection.db'):
 
91
 
92
  return df
93
 
94
+ def split_data_by_month(df):
95
+ """Split dataframe by month"""
96
+ df = df.copy()
97
+ df['prediction_date'] = pd.to_datetime(df['prediction_date'])
98
+ df['month'] = df['prediction_date'].dt.to_period('M')
99
+ return df
100
+
101
+ def detect_drift_ks_test(reference_data, current_data):
102
+ """Detect drift using Kolmogorov-Smirnov test"""
103
+ detector = KSTest()
104
+ detector.fit(X=reference_data)
105
+ result, _ = detector.compare(X=current_data)
106
+ return {
107
+ 'p_value': result.p_value,
108
+ 'statistic': result.statistic,
109
+ 'drift_detected': result.p_value < 0.05
110
+ }
111
+
112
+ def calculate_monthly_drift(df, metric='precision'):
113
+ """Calculate drift for each month compared to January (baseline)"""
114
+ df_with_month = split_data_by_month(df)
115
+
116
+ months = sorted(df_with_month['month'].unique())
117
+ if len(months) < 2:
118
+ return pd.DataFrame()
119
+
120
+ # Use January as baseline
121
+ baseline_month = months[0]
122
+ baseline_data = df_with_month[df_with_month['month'] == baseline_month][metric].values
123
+
124
+ drift_results = []
125
+ for month in months[1:]:
126
+ current_data = df_with_month[df_with_month['month'] == month][metric].values
127
+
128
+ if len(current_data) > 0 and len(baseline_data) > 0:
129
+ # KS Test
130
+ ks_result = detect_drift_ks_test(baseline_data, current_data)
131
+
132
+ # Wasserstein Distance
133
+ wd = wasserstein_distance(baseline_data, current_data)
134
+
135
+ drift_results.append({
136
+ 'month': str(month),
137
+ 'month_name': month.strftime('%Y-%m'),
138
+ 'ks_statistic': ks_result['statistic'],
139
+ 'p_value': ks_result['p_value'],
140
+ 'drift_detected': ks_result['drift_detected'],
141
+ 'wasserstein_distance': wd,
142
+ 'sample_size': len(current_data)
143
+ })
144
+
145
+ return pd.DataFrame(drift_results)
146
+
147
  def create_metric_chart(df, metric='precision'):
148
  """Create Plotly line chart for selected metric over time by model"""
149
  if df.empty:
 
220
 
221
  return fig
222
 
223
+ def create_drift_markers_chart(df, metric='precision'):
224
+ """Create time series chart with drift markers"""
225
+ df_with_month = split_data_by_month(df)
226
+ drift_df = calculate_monthly_drift(df, metric)
227
+
228
+ # Create base chart
229
+ fig = create_metric_chart(df, metric)
230
+
231
+ # Add drift markers for each month with drift
232
+ if not drift_df.empty:
233
+ for _, row in drift_df[drift_df['drift_detected']].iterrows():
234
+ month_str = row['month']
235
+ # Add vertical line at month boundary
236
+ month_date = pd.Period(month_str).to_timestamp()
237
+ fig.add_vline(
238
+ x=month_date,
239
+ line_dash="dash",
240
+ line_color="red",
241
+ line_width=2,
242
+ annotation_text=f"Drift Detected<br>{row['month_name']}",
243
+ annotation_position="top",
244
+ annotation=dict(font_size=9, font_color="red")
245
+ )
246
+
247
+ return fig
248
+
249
+ def create_monthly_drift_chart(df, metric='precision'):
250
+ """Create bar chart of monthly drift scores"""
251
+ drift_df = calculate_monthly_drift(df, metric)
252
+
253
+ if drift_df.empty:
254
+ return go.Figure().add_annotation(
255
+ text="Not enough data for drift detection",
256
+ xref="paper", yref="paper",
257
+ x=0.5, y=0.5, showarrow=False
258
+ )
259
+
260
+ fig = go.Figure()
261
+
262
+ # KS Statistic bars
263
+ fig.add_trace(go.Bar(
264
+ x=drift_df['month_name'],
265
+ y=drift_df['ks_statistic'],
266
+ name='KS Statistic',
267
+ marker_color=['red' if d else 'blue' for d in drift_df['drift_detected']],
268
+ text=[f"p={p:.4f}" for p in drift_df['p_value']],
269
+ textposition='outside'
270
+ ))
271
+
272
+ # Wasserstein Distance (secondary y-axis)
273
+ fig.add_trace(go.Scatter(
274
+ x=drift_df['month_name'],
275
+ y=drift_df['wasserstein_distance'],
276
+ name='Wasserstein Distance',
277
+ yaxis='y2',
278
+ mode='lines+markers',
279
+ line=dict(color='orange', width=2),
280
+ marker=dict(size=8)
281
+ ))
282
+
283
+ fig.update_layout(
284
+ title=f'Monthly Drift Detection for {metric.capitalize()}',
285
+ xaxis_title='Month',
286
+ yaxis_title='KS Statistic',
287
+ yaxis2=dict(
288
+ title='Wasserstein Distance',
289
+ overlaying='y',
290
+ side='right'
291
+ ),
292
+ height=500,
293
+ hovermode='x unified',
294
+ showlegend=True
295
+ )
296
+
297
+ return fig
298
+
299
+ def create_drift_heatmap(df):
300
+ """Create heatmap showing drift across all metrics and months"""
301
+ metrics = ['precision', 'recall', 'js_value', 'wd_value']
302
+ metric_names = ['Precision', 'Recall', 'JS Divergence', 'WD Value']
303
+
304
+ all_drift_data = {}
305
+ all_months = set()
306
+
307
+ for metric in metrics:
308
+ drift_df = calculate_monthly_drift(df, metric)
309
+ if not drift_df.empty:
310
+ all_drift_data[metric] = drift_df
311
+ all_months.update(drift_df['month_name'].values)
312
+
313
+ if not all_drift_data:
314
+ return go.Figure().add_annotation(
315
+ text="Not enough data for drift heatmap",
316
+ xref="paper", yref="paper",
317
+ x=0.5, y=0.5, showarrow=False
318
+ )
319
+
320
+ months = sorted(list(all_months))
321
+ z_data = []
322
+ hover_text = []
323
+
324
+ for metric in metrics:
325
+ if metric in all_drift_data:
326
+ drift_df = all_drift_data[metric]
327
+ row_z = []
328
+ row_hover = []
329
+ for month in months:
330
+ month_data = drift_df[drift_df['month_name'] == month]
331
+ if not month_data.empty:
332
+ row = month_data.iloc[0]
333
+ # Use p-value as color intensity (lower p-value = more drift = darker color)
334
+ row_z.append(1 - row['p_value']) # Invert so drift shows as high value
335
+ row_hover.append(
336
+ f"KS: {row['ks_statistic']:.4f}<br>" +
337
+ f"p-value: {row['p_value']:.4f}<br>" +
338
+ f"WD: {row['wasserstein_distance']:.4f}<br>" +
339
+ f"Drift: {'Yes' if row['drift_detected'] else 'No'}"
340
+ )
341
+ else:
342
+ row_z.append(0)
343
+ row_hover.append("No data")
344
+ z_data.append(row_z)
345
+ hover_text.append(row_hover)
346
+ else:
347
+ z_data.append([0] * len(months))
348
+ hover_text.append(["No data"] * len(months))
349
+
350
+ fig = go.Figure(data=go.Heatmap(
351
+ z=z_data,
352
+ x=months,
353
+ y=metric_names,
354
+ colorscale='RdYlGn_r', # Red for drift, Green for no drift
355
+ text=hover_text,
356
+ hovertemplate='%{y}<br>%{x}<br>%{text}<extra></extra>',
357
+ colorbar=dict(title="Drift<br>Intensity")
358
+ ))
359
+
360
+ fig.update_layout(
361
+ title='Drift Detection Heatmap (All Metrics)',
362
+ xaxis_title='Month',
363
+ yaxis_title='Metric',
364
+ height=400
365
+ )
366
+
367
+ return fig
368
+
369
  def update_chart(metric):
370
  """Update chart based on selected metric"""
371
  df = load_drift_data()
372
  chart = create_metric_chart(df, metric)
373
  return chart
374
 
375
+ def update_all_drift_visualizations(metric):
376
+ """Update all drift-related visualizations"""
377
+ df = load_drift_data()
378
+ drift_markers_chart = create_drift_markers_chart(df, metric)
379
+ monthly_drift_chart = create_monthly_drift_chart(df, metric)
380
+ drift_heatmap = create_drift_heatmap(df)
381
+ return drift_markers_chart, monthly_drift_chart, drift_heatmap
382
+
383
  # Create Gradio interface
384
  with gr.Blocks(title="Drift Detection Dashboard", theme=gr.themes.Soft()) as demo:
385
  gr.Markdown("# Drift Detection Dashboard")
386
+ gr.Markdown("모델별 메트릭 시계열 및 월별 데이터 드리프트 분석")
387
 
388
  with gr.Row():
389
  metric_dropdown = gr.Dropdown(
 
394
  ("Wasserstein Distance", "wd_value")
395
  ],
396
  value="precision",
397
+ label="Metric to Analyze",
398
  scale=1
399
  )
400
 
401
+ with gr.Tabs():
402
+ with gr.Tab("📈 Time Series + Drift Markers"):
403
+ gr.Markdown("### 시계열 차트 (드리프트 발생 지점 표시)")
404
+ drift_markers_plot = gr.Plot()
405
 
406
+ with gr.Tab("📊 Monthly Drift Scores"):
407
+ gr.Markdown("### 월별 드리프트 점수 (1월 대비)")
408
+ monthly_drift_plot = gr.Plot()
409
+
410
+ with gr.Tab("🔥 Drift Heatmap"):
411
+ gr.Markdown("### 전체 메트릭 드리프트 히트맵")
412
+ heatmap_plot = gr.Plot()
413
+
414
+ with gr.Tab("📋 Data Tables"):
415
+ gr.Markdown("### 원본 데이터")
416
+ with gr.Row():
417
+ with gr.Column(scale=2):
418
+ dataframe_output = gr.Dataframe(
419
+ value=load_drift_data(),
420
+ interactive=False,
421
+ wrap=True,
422
+ label="Drift Records"
423
+ )
424
+ with gr.Column(scale=1):
425
+ model_info_output = gr.Dataframe(
426
+ value=load_model_info(),
427
+ interactive=False,
428
+ wrap=True,
429
+ label="Model Info"
430
+ )
431
 
432
  # Event handlers
433
  metric_dropdown.change(
434
+ fn=update_all_drift_visualizations,
435
  inputs=[metric_dropdown],
436
+ outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot]
437
  )
438
 
439
+ # Load initial data
440
  def load_initial_data():
441
  df = load_drift_data()
442
+ drift_markers = create_drift_markers_chart(df, 'precision')
443
+ monthly_drift = create_monthly_drift_chart(df, 'precision')
444
+ heatmap = create_drift_heatmap(df)
445
+ return drift_markers, monthly_drift, heatmap
446
 
447
  demo.load(
448
  fn=load_initial_data,
449
+ outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot]
450
  )
451
 
452
  if __name__ == "__main__":
pyproject.toml CHANGED
@@ -5,7 +5,10 @@ description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.13"
7
  dependencies = [
 
8
  "gradio>=5.49.1",
 
9
  "pandas>=2.3.3",
10
  "plotly>=6.3.1",
 
11
  ]
 
5
  readme = "README.md"
6
  requires-python = ">=3.13"
7
  dependencies = [
8
+ "frouros>=0.9.0",
9
  "gradio>=5.49.1",
10
+ "numpy>=2.1.3",
11
  "pandas>=2.3.3",
12
  "plotly>=6.3.1",
13
+ "scipy>=1.14.1",
14
  ]
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  pandas
2
  plotly
 
 
 
 
1
  pandas
2
  plotly
3
+ frouros
4
+ scipy
5
+ numpy