"""Gradio dashboard that plots per-model metrics and monthly drift scores.

Data is read from the SQLite file ``drift_detection.db``. When that file does
not exist at import time, ``create_database.py`` is run to create it.
"""
import os
import sqlite3
import subprocess
import sys
from contextlib import closing
from datetime import datetime, timedelta

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from frouros.detectors.data_drift import KSTest
from scipy.stats import wasserstein_distance

# Single source of truth for the database location used by every loader.
DB_PATH = 'drift_detection.db'

# Human-readable axis/title labels for the metric columns of drift_record.
METRIC_DISPLAY_NAMES = {
    'precision': 'Precision',
    'recall': 'Recall',
    'js_value': 'JS Divergence',
    'wd_value': 'Wasserstein Distance',
}


def _ensure_database():
    """Create the SQLite database by running create_database.py if missing."""
    if os.path.exists(DB_PATH):
        return
    print("Database not found. Creating new database...")
    # Use the interpreter running this script so the child process sees the
    # same environment; fall back to 'python' if the path is unavailable.
    subprocess.run([sys.executable or 'python', 'create_database.py'], check=True)
    print("Database created successfully!")


# Initialize database if it doesn't exist (the UI below reads it at import).
_ensure_database()


def get_korean_holidays_2025():
    """Return the hard-coded holiday list as (start, end, name) tuples.

    ``start`` and ``end`` are ``datetime`` objects at midnight; ``end`` is the
    last day of the holiday (inclusive).
    """
    holidays = [
        (datetime(2025, 1, 1), datetime(2025, 1, 1), "신정"),
        (datetime(2025, 1, 28), datetime(2025, 1, 30), "설날 연휴"),
        (datetime(2025, 3, 1), datetime(2025, 3, 1), "삼일절"),
        (datetime(2025, 3, 3), datetime(2025, 3, 3), "삼일절 대체공휴일"),
        (datetime(2025, 5, 5), datetime(2025, 5, 5), "어린이날·부처님오신날"),
        (datetime(2025, 5, 6), datetime(2025, 5, 6), "대체공휴일"),
        (datetime(2025, 6, 6), datetime(2025, 6, 6), "현충일"),
        (datetime(2025, 8, 15), datetime(2025, 8, 15), "광복절"),
    ]
    return holidays


def get_weekends(start_date, end_date):
    """Return (saturday, sunday) pairs for every Saturday in the range.

    A pair is produced for each Saturday ``s`` with
    ``start_date <= s <= end_date``; the Sunday is ``s + 1 day`` and may lie
    past ``end_date``.
    """
    one_week = timedelta(days=7)
    # Jump straight to the first Saturday (weekday() == 5) on/after the start.
    saturday = start_date + timedelta(days=(5 - start_date.weekday()) % 7)

    weekends = []
    while saturday <= end_date:
        weekends.append((saturday, saturday + timedelta(days=1)))
        saturday += one_week
    return weekends


def load_drift_data():
    """Load every drift_record row joined with its model name.

    Returns:
        DataFrame ordered by prediction_date and model_id, with the metric
        columns rounded to 4 decimals for display.
    """
    query = """
    SELECT
        dr.model_id,
        mi.model_name,
        dr.precision,
        dr.recall,
        dr.sample_numbers,
        dr.js_value,
        dr.wd_value,
        dr.prediction_date
    FROM drift_record dr
    LEFT JOIN model_info mi ON dr.model_id = mi.model_id
    ORDER BY dr.prediction_date, dr.model_id
    """
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        df = pd.read_sql_query(query, conn)

    # Round numeric columns for better display
    for col in ('precision', 'recall', 'js_value', 'wd_value'):
        if col in df.columns:
            df[col] = df[col].round(4)

    return df


def load_model_info():
    """Load the model_info table ordered by model_id."""
    query = """
    SELECT
        model_id,
        model_name,
        release_date,
        prediction_period
    FROM model_info
    ORDER BY model_id
    """
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        return pd.read_sql_query(query, conn)


def split_data_by_month(df):
    """Return a copy of ``df`` with a parsed date and a monthly ``month`` column.

    ``prediction_date`` is converted to datetime and ``month`` holds the
    corresponding monthly ``Period``. The input frame is not modified.
    """
    df = df.copy()
    df['prediction_date'] = pd.to_datetime(df['prediction_date'])
    df['month'] = df['prediction_date'].dt.to_period('M')
    return df


def detect_drift_ks_test(reference_data, current_data, alpha=0.05):
    """Run a two-sample Kolmogorov-Smirnov test via frouros' ``KSTest``.

    Args:
        reference_data: Baseline sample the detector is fitted on.
        current_data: Sample compared against the baseline.
        alpha: Significance level; drift is flagged when p-value < alpha.

    Returns:
        dict with ``p_value``, ``statistic`` and ``drift_detected``.
    """
    detector = KSTest()
    detector.fit(X=reference_data)
    result, _ = detector.compare(X=current_data)

    return {
        'p_value': result.p_value,
        'statistic': result.statistic,
        'drift_detected': result.p_value < alpha,
    }


def calculate_monthly_drift(df, metric='precision'):
    """Compare each month's metric distribution against the earliest month.

    The earliest month present in the data is the baseline (January in the
    dashboard's wording). Missing metric values are dropped before testing so
    that NaN does not propagate into the statistics.

    Args:
        df: Drift records containing ``prediction_date`` and ``metric``.
        metric: Name of the column to test.

    Returns:
        DataFrame with one row per non-baseline month (KS statistic, p-value,
        drift flag, Wasserstein distance, sample size); empty when fewer than
        two months are available.
    """
    df_with_month = split_data_by_month(df)
    months = sorted(df_with_month['month'].unique())

    if len(months) < 2:
        return pd.DataFrame()

    def values_for(month):
        # Non-null metric values recorded in the given month.
        in_month = df_with_month['month'] == month
        return df_with_month.loc[in_month, metric].dropna().values

    # Use the first month as baseline
    baseline_data = values_for(months[0])

    drift_results = []
    for month in months[1:]:
        current_data = values_for(month)
        if len(current_data) == 0 or len(baseline_data) == 0:
            continue

        ks_result = detect_drift_ks_test(baseline_data, current_data)
        drift_results.append({
            'month': str(month),
            'month_name': month.strftime('%Y-%m'),
            'ks_statistic': ks_result['statistic'],
            'p_value': ks_result['p_value'],
            'drift_detected': ks_result['drift_detected'],
            'wasserstein_distance': wasserstein_distance(baseline_data, current_data),
            'sample_size': len(current_data),
        })

    return pd.DataFrame(drift_results)


def create_metric_chart(df, metric='precision'):
    """Create a Plotly line chart of ``metric`` over time, one line per model.

    Weekends are shaded gray and the hard-coded holidays that overlap the
    data's date range are shaded red with their name as annotation.
    """
    if df.empty:
        return px.line(title="No data available")

    # Convert prediction_date to datetime
    df = df.copy()
    df['prediction_date'] = pd.to_datetime(df['prediction_date'])

    metric_display = METRIC_DISPLAY_NAMES.get(metric, metric.capitalize())

    # Create line chart
    fig = px.line(
        df,
        x='prediction_date',
        y=metric,
        color='model_name',
        labels={
            'prediction_date': 'Date',
            metric: metric_display,
            'model_name': 'Model',
        },
        markers=True,
    )

    # Get date range from data
    start_date = df['prediction_date'].min()
    end_date = df['prediction_date'].max()

    # Add weekend shading (light gray)
    for weekend_start, weekend_end in get_weekends(start_date, end_date):
        fig.add_vrect(
            x0=weekend_start,
            x1=weekend_end,
            fillcolor="gray",
            opacity=0.2,
            layer="below",
            line_width=0,
        )

    # Add Korean holiday shading (light red)
    for holiday_start, holiday_end, holiday_name in get_korean_holidays_2025():
        # Skip holidays that do not overlap the plotted date range.
        if holiday_start > end_date or holiday_end < start_date:
            continue
        fig.add_vrect(
            x0=holiday_start,
            x1=holiday_end + timedelta(days=1),  # Add 1 day to include the end date
            fillcolor="red",
            opacity=0.25,
            layer="below",
            line_width=0,
            annotation_text=holiday_name,
            annotation_position="top left",
            annotation=dict(font_size=10, font_color="darkred"),
        )

    fig.update_layout(
        hovermode='x unified',
        xaxis_title='Date',
        yaxis_title=metric_display,
        legend_title='Model',
        height=450,
        margin=dict(t=20, b=50, l=50, r=20),
    )

    return fig


def create_drift_markers_chart(df, metric='precision'):
    """Create the metric time series with a marker at each drifting month.

    A dashed red vertical line and a "Drift Detected" label are placed at the
    first day of every month whose KS test flagged drift.
    """
    drift_df = calculate_monthly_drift(df, metric)

    # Create base chart
    fig = create_metric_chart(df, metric)

    if drift_df.empty:
        return fig

    for _, row in drift_df[drift_df['drift_detected']].iterrows():
        # Vertical line at the month boundary (first day of the month).
        month_date = pd.Period(row['month']).to_timestamp()
        fig.add_vline(
            x=month_date,
            line_dash="dash",
            line_color="red",
            line_width=2,
        )
        # The label is added separately: add_vline's own annotation support
        # averages the x values, which can fail for Timestamp x on date axes.
        fig.add_annotation(
            x=month_date,
            y=1,
            xref="x",
            yref="paper",
            xanchor="center",
            yanchor="bottom",
            showarrow=False,
            text=f"Drift Detected<br>{row['month_name']}",
            font=dict(size=9, color="red"),
        )

    return fig


def create_monthly_drift_chart(df, metric='precision'):
    """Create a bar chart of monthly KS statistics with Wasserstein overlay.

    Bars are red for months with detected drift and blue otherwise, labelled
    with the p-value; the Wasserstein distance is drawn on a secondary y-axis.
    """
    drift_df = calculate_monthly_drift(df, metric)

    if drift_df.empty:
        return go.Figure().add_annotation(
            text="Not enough data for drift detection",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )

    bar_colors = ['red' if drifted else 'blue' for drifted in drift_df['drift_detected']]
    bar_labels = [f"p={p:.4f}" for p in drift_df['p_value']]

    fig = go.Figure()

    # KS Statistic bars
    fig.add_trace(go.Bar(
        x=drift_df['month_name'],
        y=drift_df['ks_statistic'],
        name='KS Statistic',
        marker_color=bar_colors,
        text=bar_labels,
        textposition='outside',
    ))

    # Wasserstein Distance (secondary y-axis)
    fig.add_trace(go.Scatter(
        x=drift_df['month_name'],
        y=drift_df['wasserstein_distance'],
        name='Wasserstein Distance',
        yaxis='y2',
        mode='lines+markers',
        line=dict(color='orange', width=2),
        marker=dict(size=8),
    ))

    # Use the shared display names so e.g. 'js_value' is not shown as 'Js_value'.
    metric_display = METRIC_DISPLAY_NAMES.get(metric, metric.capitalize())
    fig.update_layout(
        title=f'Monthly Drift Detection for {metric_display}',
        xaxis_title='Month',
        yaxis_title='KS Statistic',
        yaxis2=dict(
            title='Wasserstein Distance',
            overlaying='y',
            side='right',
        ),
        height=500,
        hovermode='x unified',
        showlegend=True,
    )

    return fig


def create_drift_heatmap(df):
    """Create a metric-by-month heatmap coloured by ``1 - p_value``.

    Cells without a drift result are drawn with value 0 and the hover text
    "No data".
    """
    metrics = ['precision', 'recall', 'js_value', 'wd_value']
    metric_names = ['Precision', 'Recall', 'JS Divergence', 'WD Value']

    drift_by_metric = {}
    for metric in metrics:
        drift_df = calculate_monthly_drift(df, metric)
        if not drift_df.empty:
            drift_by_metric[metric] = drift_df

    if not drift_by_metric:
        return go.Figure().add_annotation(
            text="Not enough data for drift heatmap",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False,
        )

    months = sorted({
        name
        for drift_df in drift_by_metric.values()
        for name in drift_df['month_name'].values
    })

    z_data = []
    hover_text = []
    for metric in metrics:
        drift_df = drift_by_metric.get(metric)
        if drift_df is None:
            z_data.append([0] * len(months))
            hover_text.append(["No data"] * len(months))
            continue

        row_z = []
        row_hover = []
        for month in months:
            month_data = drift_df[drift_df['month_name'] == month]
            if month_data.empty:
                row_z.append(0)
                row_hover.append("No data")
                continue

            row = month_data.iloc[0]
            # Invert the p-value so that stronger evidence of drift
            # (lower p-value) maps to a higher colour value.
            row_z.append(1 - row['p_value'])
            row_hover.append(
                f"KS: {row['ks_statistic']:.4f}<br>" +
                f"p-value: {row['p_value']:.4f}<br>" +
                f"WD: {row['wasserstein_distance']:.4f}<br>" +
                f"Drift: {'Yes' if row['drift_detected'] else 'No'}"
            )
        z_data.append(row_z)
        hover_text.append(row_hover)

    fig = go.Figure(data=go.Heatmap(
        z=z_data,
        x=months,
        y=metric_names,
        colorscale='RdYlGn_r',  # Red for drift, Green for no drift
        text=hover_text,
        hovertemplate='%{y}<br>%{x}<br>%{text}',
        colorbar=dict(title="Drift<br>Intensity"),
    ))

    fig.update_layout(
        title='Drift Detection Heatmap (All Metrics)',
        xaxis_title='Month',
        yaxis_title='Metric',
        height=400,
    )

    return fig


def update_chart(metric):
    """Reload the drift records and return the plain metric chart."""
    return create_metric_chart(load_drift_data(), metric)


def update_all_drift_visualizations(metric):
    """Reload the drift records and rebuild the three drift figures.

    Returns:
        Tuple of (drift markers chart, monthly drift chart, drift heatmap).
    """
    df = load_drift_data()

    drift_markers_chart = create_drift_markers_chart(df, metric)
    monthly_drift_chart = create_monthly_drift_chart(df, metric)
    drift_heatmap = create_drift_heatmap(df)

    return drift_markers_chart, monthly_drift_chart, drift_heatmap


# Create Gradio interface
with gr.Blocks(title="Drift Detection Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Drift Detection Dashboard")
    gr.Markdown("모델별 메트릭 시계열 및 월별 데이터 드리프트 분석")

    with gr.Row():
        metric_dropdown = gr.Dropdown(
            choices=[
                ("Precision", "precision"),
                ("Recall", "recall"),
                ("JS Divergence", "js_value"),
                ("Wasserstein Distance", "wd_value"),
            ],
            value="precision",
            label="Metric to Analyze",
            scale=1,
        )

    with gr.Tabs():
        with gr.Tab("📈 Time Series + Drift Markers"):
            gr.Markdown("### 시계열 차트 (드리프트 발생 지점 표시)")
            drift_markers_plot = gr.Plot()

        with gr.Tab("📊 Monthly Drift Scores"):
            gr.Markdown("### 월별 드리프트 점수 (1월 대비)")
            monthly_drift_plot = gr.Plot()

        with gr.Tab("🔥 Drift Heatmap"):
            gr.Markdown("### 전체 메트릭 드리프트 히트맵")
            heatmap_plot = gr.Plot()

        with gr.Tab("📋 Data Tables"):
            gr.Markdown("### 원본 데이터")
            with gr.Row():
                with gr.Column(scale=2):
                    dataframe_output = gr.Dataframe(
                        value=load_drift_data(),
                        interactive=False,
                        wrap=True,
                        label="Drift Records",
                    )
                with gr.Column(scale=1):
                    model_info_output = gr.Dataframe(
                        value=load_model_info(),
                        interactive=False,
                        wrap=True,
                        label="Model Info",
                    )

    # Event handlers
    metric_dropdown.change(
        fn=update_all_drift_visualizations,
        inputs=[metric_dropdown],
        outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot],
    )

    # Load initial data
    def load_initial_data():
        """Build the three drift figures for the default 'precision' metric."""
        return update_all_drift_visualizations('precision')

    demo.load(
        fn=load_initial_data,
        outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot],
    )

if __name__ == "__main__":
    demo.launch()