Yoon-gu Hwang
월별 데이터 드리프트 감지 및 분석 대시보드 추가
1ee2788
import os
import sqlite3
import subprocess
import sys
from datetime import datetime, timedelta

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from frouros.detectors.data_drift import KSTest
from scipy.stats import wasserstein_distance
# Create the SQLite database on first run. This executes at import time, so
# the file exists before any loader below tries to open it.
if not os.path.exists('drift_detection.db'):
    print("Database not found. Creating new database...")
    # Run the helper with the interpreter that is executing this script
    # instead of whatever "python" resolves to on PATH (which may be missing
    # or belong to a different environment). check=True aborts start-up if
    # the helper script fails.
    subprocess.run([sys.executable, 'create_database.py'], check=True)
    print("Database created successfully!")
def get_korean_holidays_2025():
    """Return the Korean public holidays tracked for January-August 2025.

    Returns:
        list[tuple[datetime, datetime, str]]: ``(start, end, name)`` entries
        in chronological order. ``start`` and ``end`` are calendar days at
        midnight and are both part of the holiday; substitute holidays get
        their own entry. ``name`` is the Korean label shown on the charts.
    """
    # NOTE(review): only the holidays below are listed; one-off designated
    # days (temporary holidays, election days) are not - confirm whether
    # they should be added.
    return [
        (datetime(2025, 1, 1), datetime(2025, 1, 1), "신정"),  # New Year's Day
        (datetime(2025, 1, 28), datetime(2025, 1, 30), "설날 연휴"),  # Lunar New Year holidays
        (datetime(2025, 3, 1), datetime(2025, 3, 1), "삼일절"),  # Independence Movement Day
        (datetime(2025, 3, 3), datetime(2025, 3, 3), "삼일절 대체공휴일"),  # substitute holiday
        (datetime(2025, 5, 5), datetime(2025, 5, 5), "어린이날·부처님오신날"),  # Children's Day / Buddha's Birthday
        (datetime(2025, 5, 6), datetime(2025, 5, 6), "대체공휴일"),  # substitute holiday
        (datetime(2025, 6, 6), datetime(2025, 6, 6), "현충일"),  # Memorial Day
        (datetime(2025, 8, 15), datetime(2025, 8, 15), "광복절"),  # Liberation Day
    ]
def get_weekends(start_date, end_date):
    """Return the weekends that fall inside ``[start_date, end_date]``.

    Args:
        start_date: First day of the range (``datetime``-like, inclusive).
        end_date: Last day of the range (``datetime``-like, inclusive).

    Returns:
        list[tuple]: ``(weekend_start, weekend_end)`` pairs in chronological
        order. A full weekend is ``(Saturday, Sunday)``; the Sunday is
        reported even when it lies after ``end_date``. When the range starts
        on a Sunday, that lone day is reported as ``(start_date, start_date)``
        so the leading partial weekend is not lost. An inverted range (or one
        whose bounds do not compare, e.g. ``NaT``) yields an empty list.
    """
    # "not <=" rather than ">" so incomparable bounds such as NaT also exit.
    if not start_date <= end_date:
        return []

    weekends = []
    if start_date.weekday() == 6:
        # The range opens on a Sunday whose Saturday is outside the range.
        weekends.append((start_date, start_date))

    # Jump straight to the first Saturday on/after start_date, then step a
    # week at a time instead of visiting every day.
    saturday = start_date + timedelta(days=(5 - start_date.weekday()) % 7)
    while saturday <= end_date:
        weekends.append((saturday, saturday + timedelta(days=1)))
        saturday += timedelta(days=7)
    return weekends
def load_drift_data():
    """Load every drift record, joined with its model name.

    Reads ``drift_record`` from ``drift_detection.db`` and LEFT JOINs
    ``model_info`` on ``model_id``, so records without a matching model are
    kept with a null ``model_name``. Rows are ordered by ``prediction_date``
    and then ``model_id``.

    Returns:
        pandas.DataFrame: columns ``model_id``, ``model_name``, ``precision``,
        ``recall``, ``sample_numbers``, ``js_value``, ``wd_value`` and
        ``prediction_date``; the four metric columns are rounded to 4
        decimals for display.
    """
    query = """
        SELECT
            dr.model_id,
            mi.model_name,
            dr.precision,
            dr.recall,
            dr.sample_numbers,
            dr.js_value,
            dr.wd_value,
            dr.prediction_date
        FROM drift_record dr
        LEFT JOIN model_info mi ON dr.model_id = mi.model_id
        ORDER BY dr.prediction_date, dr.model_id
    """
    conn = sqlite3.connect('drift_detection.db')
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()

    # Round the metric columns for a tidier table display.
    numeric_cols = ['precision', 'recall', 'js_value', 'wd_value']
    df[numeric_cols] = df[numeric_cols].round(4)
    return df
def load_model_info():
    """Load the ``model_info`` table from ``drift_detection.db``.

    Returns:
        pandas.DataFrame: columns ``model_id``, ``model_name``,
        ``release_date`` and ``prediction_period``, ordered by ``model_id``.
    """
    query = """
        SELECT
            model_id,
            model_name,
            release_date,
            prediction_period
        FROM model_info
        ORDER BY model_id
    """
    conn = sqlite3.connect('drift_detection.db')
    try:
        return pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()
def split_data_by_month(df):
    """Return a copy of ``df`` tagged with the month of each prediction.

    ``prediction_date`` is converted to datetimes and a ``month`` column
    holding the matching monthly ``Period`` is added. The input frame is
    left untouched.
    """
    parsed_dates = pd.to_datetime(df['prediction_date'])
    return df.assign(
        prediction_date=parsed_dates,
        month=parsed_dates.dt.to_period('M'),
    )
def detect_drift_ks_test(reference_data, current_data, alpha=0.05):
    """Test two samples for distribution drift with a Kolmogorov-Smirnov test.

    Args:
        reference_data: Baseline sample the detector is fitted on.
        current_data: Sample compared against the baseline.
        alpha: Significance level; drift is reported when the p-value is
            strictly below it. Defaults to 0.05.

    Returns:
        dict: ``p_value`` and ``statistic`` as reported by the detector, and
        ``drift_detected`` as a plain ``bool``.
    """
    detector = KSTest()
    detector.fit(X=reference_data)
    # compare() returns a pair; only the test result (first item) is used.
    result, _ = detector.compare(X=current_data)
    return {
        'p_value': result.p_value,
        'statistic': result.statistic,
        'drift_detected': bool(result.p_value < alpha),
    }
def calculate_monthly_drift(df, metric='precision'):
    """Compare each month's ``metric`` distribution against the first month.

    The earliest month present in the data is the baseline (January only if
    the data starts in January). Every later month is compared to it with a
    KS test and the Wasserstein distance. Missing (NaN) metric values are
    excluded so a single gap cannot invalidate a month's statistics.

    Args:
        df: Drift records containing ``prediction_date`` and ``metric``.
        metric: Name of the column to analyse.

    Returns:
        pandas.DataFrame: one row per compared month with columns ``month``,
        ``month_name`` (both ``YYYY-MM``), ``ks_statistic``, ``p_value``,
        ``drift_detected``, ``wasserstein_distance`` and ``sample_size``.
        Empty when fewer than two months are available or the baseline month
        has no usable values.
    """
    df_with_month = split_data_by_month(df)

    # Group once instead of re-filtering the whole frame for every month.
    values_by_month = {
        month: values.dropna().to_numpy()
        for month, values in df_with_month.groupby('month')[metric]
    }
    months = sorted(values_by_month)
    if len(months) < 2:
        return pd.DataFrame()

    baseline_month, *later_months = months
    baseline_data = values_by_month[baseline_month]
    if len(baseline_data) == 0:
        return pd.DataFrame()

    drift_results = []
    for month in later_months:
        current_data = values_by_month[month]
        if len(current_data) == 0:
            continue
        ks_result = detect_drift_ks_test(baseline_data, current_data)
        drift_results.append({
            'month': str(month),
            'month_name': month.strftime('%Y-%m'),
            'ks_statistic': ks_result['statistic'],
            'p_value': ks_result['p_value'],
            'drift_detected': ks_result['drift_detected'],
            'wasserstein_distance': wasserstein_distance(baseline_data, current_data),
            'sample_size': len(current_data),
        })
    return pd.DataFrame(drift_results)
def create_metric_chart(df, metric='precision'):
    """Build a per-model line chart of ``metric`` over time.

    Weekends inside the data's date range are shaded gray. Holidays from
    ``get_korean_holidays_2025`` that overlap the range are shaded red and
    labelled with their name.

    Args:
        df: Drift records with ``prediction_date``, ``model_name`` and the
            ``metric`` column. The frame is not modified.
        metric: Column plotted on the y-axis.

    Returns:
        A Plotly figure; an empty figure titled "No data available" when
        ``df`` is empty.
    """
    if df.empty:
        return px.line(title="No data available")

    df = df.copy()
    df['prediction_date'] = pd.to_datetime(df['prediction_date'])

    # Human-readable axis labels; unknown metrics fall back to a capitalized
    # column name.
    metric_names = {
        'precision': 'Precision',
        'recall': 'Recall',
        'js_value': 'JS Divergence',
        'wd_value': 'Wasserstein Distance'
    }
    metric_display = metric_names.get(metric, metric.capitalize())

    fig = px.line(
        df,
        x='prediction_date',
        y=metric,
        color='model_name',
        labels={
            'prediction_date': 'Date',
            metric: metric_display,
            'model_name': 'Model'
        },
        markers=True
    )

    start_date = df['prediction_date'].min()
    end_date = df['prediction_date'].max()
    one_day = timedelta(days=1)

    # Weekend shading (light gray).
    for weekend_start, weekend_end in get_weekends(start_date, end_date):
        fig.add_vrect(
            x0=weekend_start,
            # A band stops at x1, so go one day past the last weekend day to
            # cover Sunday too (same convention as the holiday bands below).
            x1=weekend_end + one_day,
            fillcolor="gray",
            opacity=0.2,
            layer="below",
            line_width=0,
        )

    # Korean holiday shading (light red), only for holidays overlapping the
    # plotted date range.
    for holiday_start, holiday_end, holiday_name in get_korean_holidays_2025():
        if holiday_start <= end_date and holiday_end >= start_date:
            fig.add_vrect(
                x0=holiday_start,
                x1=holiday_end + one_day,  # Add 1 day to include the end date
                fillcolor="red",
                opacity=0.25,
                layer="below",
                line_width=0,
                annotation_text=holiday_name,
                annotation_position="top left",
                annotation=dict(font_size=10, font_color="darkred")
            )

    fig.update_layout(
        hovermode='x unified',
        xaxis_title='Date',
        yaxis_title=metric_display,
        legend_title='Model',
        height=450,
        margin=dict(t=20, b=50, l=50, r=20)
    )
    return fig
def create_drift_markers_chart(df, metric='precision'):
    """Build the metric time series with a marker on every drifted month.

    Starts from ``create_metric_chart`` and adds a dashed red vertical line,
    annotated "Drift Detected", at the first day of each month that
    ``calculate_monthly_drift`` flags as drifted.

    Args:
        df: Drift records (see ``create_metric_chart``).
        metric: Column to plot and to test for drift.

    Returns:
        A Plotly figure.
    """
    drift_df = calculate_monthly_drift(df, metric)
    fig = create_metric_chart(df, metric)
    if drift_df.empty:
        return fig

    for _, row in drift_df[drift_df['drift_detected']].iterrows():
        month_start = pd.Period(row['month']).to_timestamp()
        fig.add_vline(
            # Pass the position as epoch milliseconds: placing the annotation
            # makes Plotly average the line's x values, which raises a
            # TypeError for Timestamp objects on some Plotly versions. Date
            # axes accept epoch milliseconds directly.
            x=month_start.timestamp() * 1000,
            line_dash="dash",
            line_color="red",
            line_width=2,
            annotation_text=f"Drift Detected<br>{row['month_name']}",
            annotation_position="top",
            annotation=dict(font_size=9, font_color="red")
        )
    return fig
def create_monthly_drift_chart(df, metric='precision'):
    """Build a bar/line chart of the monthly drift scores for ``metric``.

    Bars show the KS statistic per month (red when drift was detected, blue
    otherwise) with the p-value printed above each bar. The Wasserstein
    distance is overlaid as a line on a secondary y-axis.

    Args:
        df: Drift records passed on to ``calculate_monthly_drift``.
        metric: Column to analyse.

    Returns:
        A Plotly figure; a figure holding only a "Not enough data" note when
        no monthly drift results are available.
    """
    drift_df = calculate_monthly_drift(df, metric)
    if drift_df.empty:
        return go.Figure().add_annotation(
            text="Not enough data for drift detection",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False
        )

    # Same display names as the time-series chart, so the title reads
    # "JS Divergence" rather than "Js_value".
    metric_names = {
        'precision': 'Precision',
        'recall': 'Recall',
        'js_value': 'JS Divergence',
        'wd_value': 'Wasserstein Distance'
    }
    metric_display = metric_names.get(metric, metric.capitalize())

    fig = go.Figure()

    # KS statistic bars, coloured by the drift verdict.
    fig.add_trace(go.Bar(
        x=drift_df['month_name'],
        y=drift_df['ks_statistic'],
        name='KS Statistic',
        marker_color=['red' if drifted else 'blue' for drifted in drift_df['drift_detected']],
        text=[f"p={p:.4f}" for p in drift_df['p_value']],
        textposition='outside'
    ))

    # Wasserstein distance on the secondary y-axis.
    fig.add_trace(go.Scatter(
        x=drift_df['month_name'],
        y=drift_df['wasserstein_distance'],
        name='Wasserstein Distance',
        yaxis='y2',
        mode='lines+markers',
        line=dict(color='orange', width=2),
        marker=dict(size=8)
    ))

    fig.update_layout(
        title=f'Monthly Drift Detection for {metric_display}',
        xaxis_title='Month',
        yaxis_title='KS Statistic',
        yaxis2=dict(
            title='Wasserstein Distance',
            overlaying='y',
            side='right'
        ),
        height=500,
        hovermode='x unified',
        showlegend=True
    )
    return fig
def create_drift_heatmap(df):
    """Build a metric-by-month heatmap of drift intensity.

    Each cell is coloured by ``1 - p_value`` of that month's KS test, so a
    low p-value (likely drift) appears red and a high p-value appears green.
    Cells without a result are left blank instead of being drawn as "no
    drift", and the colour range is pinned to [0, 1] so colours mean the
    same thing regardless of the data.

    Args:
        df: Drift records passed on to ``calculate_monthly_drift``.

    Returns:
        A Plotly figure; a figure holding only a "Not enough data" note when
        no metric has monthly drift results.
    """
    metrics = ['precision', 'recall', 'js_value', 'wd_value']
    metric_names = ['Precision', 'Recall', 'JS Divergence', 'WD Value']

    all_drift_data = {}
    all_months = set()
    for metric in metrics:
        drift_df = calculate_monthly_drift(df, metric)
        if not drift_df.empty:
            all_drift_data[metric] = drift_df
            all_months.update(drift_df['month_name'].values)

    if not all_drift_data:
        return go.Figure().add_annotation(
            text="Not enough data for drift heatmap",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False
        )

    months = sorted(all_months)
    z_data = []
    hover_text = []
    for metric in metrics:
        drift_df = all_drift_data.get(metric)
        # Index the metric's results by month for direct lookup. The first
        # row wins if a month ever appears more than once.
        rows_by_month = {}
        if drift_df is not None:
            for _, row in drift_df.iterrows():
                rows_by_month.setdefault(row['month_name'], row)

        row_z = []
        row_hover = []
        for month in months:
            row = rows_by_month.get(month)
            if row is None:
                # None renders as a gap rather than a misleading 0.
                row_z.append(None)
                row_hover.append("No data")
                continue
            # Invert the p-value so stronger evidence of drift is a higher
            # (redder) value.
            row_z.append(1 - row['p_value'])
            row_hover.append(
                f"KS: {row['ks_statistic']:.4f}<br>" +
                f"p-value: {row['p_value']:.4f}<br>" +
                f"WD: {row['wasserstein_distance']:.4f}<br>" +
                f"Drift: {'Yes' if row['drift_detected'] else 'No'}"
            )
        z_data.append(row_z)
        hover_text.append(row_hover)

    fig = go.Figure(data=go.Heatmap(
        z=z_data,
        x=months,
        y=metric_names,
        zmin=0,
        zmax=1,
        colorscale='RdYlGn_r',  # Red for drift, Green for no drift
        text=hover_text,
        hovertemplate='%{y}<br>%{x}<br>%{text}<extra></extra>',
        colorbar=dict(title="Drift<br>Intensity")
    ))
    fig.update_layout(
        title='Drift Detection Heatmap (All Metrics)',
        xaxis_title='Month',
        yaxis_title='Metric',
        height=400
    )
    return fig
def update_chart(metric):
    """Reload the drift records and return the line chart for ``metric``."""
    return create_metric_chart(load_drift_data(), metric)
def update_all_drift_visualizations(metric):
    """Reload the drift records and rebuild the three drift figures.

    Returns a tuple of (time series with drift markers, monthly drift score
    chart, all-metric heatmap). The first two follow ``metric``; the heatmap
    always covers every metric.
    """
    records = load_drift_data()
    return (
        create_drift_markers_chart(records, metric),
        create_monthly_drift_chart(records, metric),
        create_drift_heatmap(records),
    )
# Create Gradio interface: one metric selector driving three drift plots,
# plus a tab with the raw tables.
with gr.Blocks(title="Drift Detection Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Drift Detection Dashboard")
    # "Per-model metric time series and monthly data drift analysis"
    gr.Markdown("모델별 메트릭 시계열 및 월별 데이터 드리프트 분석")

    with gr.Row():
        metric_dropdown = gr.Dropdown(
            choices=[
                ("Precision", "precision"),
                ("Recall", "recall"),
                ("JS Divergence", "js_value"),
                ("Wasserstein Distance", "wd_value")
            ],
            value="precision",
            label="Metric to Analyze",
            scale=1
        )

    with gr.Tabs():
        with gr.Tab("📈 Time Series + Drift Markers"):
            # "Time series chart (drift occurrence points marked)"
            gr.Markdown("### 시계열 차트 (드리프트 발생 지점 표시)")
            drift_markers_plot = gr.Plot()
        with gr.Tab("📊 Monthly Drift Scores"):
            # "Monthly drift scores (vs. January)"
            gr.Markdown("### 월별 드리프트 점수 (1월 대비)")
            monthly_drift_plot = gr.Plot()
        with gr.Tab("🔥 Drift Heatmap"):
            # "Drift heatmap across all metrics"
            gr.Markdown("### 전체 메트릭 드리프트 히트맵")
            heatmap_plot = gr.Plot()
        with gr.Tab("📋 Data Tables"):
            # "Raw data"
            gr.Markdown("### 원본 데이터")
            with gr.Row():
                with gr.Column(scale=2):
                    # Passing the loader itself (not its result) lets Gradio
                    # call it on every page load, so the table is refreshed
                    # together with the plots instead of being frozen at
                    # import time.
                    dataframe_output = gr.Dataframe(
                        value=load_drift_data,
                        interactive=False,
                        wrap=True,
                        label="Drift Records"
                    )
                with gr.Column(scale=1):
                    model_info_output = gr.Dataframe(
                        value=load_model_info,
                        interactive=False,
                        wrap=True,
                        label="Model Info"
                    )

    # Event handlers: changing the metric redraws all three plots.
    metric_dropdown.change(
        fn=update_all_drift_visualizations,
        inputs=[metric_dropdown],
        outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot]
    )

    # Load initial data
    def load_initial_data():
        """Build the three plots for the default metric on page load."""
        # Reuse the dropdown handler so both code paths stay in sync.
        return update_all_drift_visualizations('precision')

    demo.load(
        fn=load_initial_data,
        outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot]
    )

if __name__ == "__main__":
    demo.launch()