Spaces:
Sleeping
Sleeping
import os
import sqlite3
import subprocess
import sys
from datetime import datetime, timedelta

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from frouros.detectors.data_drift import KSTest
from scipy.stats import wasserstein_distance
# Bootstrap the SQLite database on first run.
# The creation script is launched with the interpreter that is running this
# app (sys.executable) instead of whatever "python" resolves to on PATH, so it
# runs in the same environment and does not fail when no "python" is on PATH.
if not os.path.exists('drift_detection.db'):
    print("Database not found. Creating new database...")
    # check=True aborts start-up with CalledProcessError if the script fails.
    subprocess.run([sys.executable, 'create_database.py'], check=True)
    print("Database created successfully!")
def get_korean_holidays_2025():
    """Return South Korean public holidays for January-August 2025.

    Substitute holidays are listed as their own entries.

    Returns:
        list[tuple[datetime, datetime, str]]: One ``(first_day, last_day,
        name)`` tuple per holiday period. Both dates are inclusive and the
        name is the Korean holiday name.
    """
    # The names are user-facing labels (chart annotations); they are kept in
    # Korean and stored as proper Hangul rather than mis-decoded bytes.
    return [
        (datetime(2025, 1, 1), datetime(2025, 1, 1), "신정"),
        (datetime(2025, 1, 28), datetime(2025, 1, 30), "설날 연휴"),
        (datetime(2025, 3, 1), datetime(2025, 3, 1), "삼일절"),
        (datetime(2025, 3, 3), datetime(2025, 3, 3), "삼일절 대체공휴일"),
        (datetime(2025, 5, 5), datetime(2025, 5, 5), "어린이날·부처님오신날"),
        (datetime(2025, 5, 6), datetime(2025, 5, 6), "대체공휴일"),
        (datetime(2025, 6, 6), datetime(2025, 6, 6), "현충일"),
        (datetime(2025, 8, 15), datetime(2025, 8, 15), "광복절"),
    ]
def get_weekends(start_date, end_date):
    """Collect the weekends whose Saturday falls within the given range.

    Args:
        start_date: First day to consider (inclusive).
        end_date: Last day to consider (inclusive).

    Returns:
        list[tuple]: ``(saturday, sunday)`` pairs in chronological order,
        where ``sunday`` is ``saturday`` plus one day. Any time-of-day carried
        by ``start_date`` is carried through to the returned values.
    """
    one_day = timedelta(days=1)
    one_week = timedelta(days=7)

    # Walk forward day by day until the first Saturday (weekday() == 5).
    saturday = start_date
    while saturday <= end_date and saturday.weekday() != 5:
        saturday += one_day

    # From there every later Saturday is exactly one week apart.
    weekend_spans = []
    while saturday <= end_date:
        weekend_spans.append((saturday, saturday + one_day))
        saturday += one_week
    return weekend_spans
def load_drift_data():
    """Load every drift record, joined with its model name, from SQLite.

    Reads ``drift_detection.db`` in the current working directory. Rows are
    ordered by prediction date, then model id. The ``precision``, ``recall``,
    ``js_value`` and ``wd_value`` columns are rounded to 4 decimals for
    display.

    Returns:
        pandas.DataFrame: Columns ``model_id``, ``model_name``, ``precision``,
        ``recall``, ``sample_numbers``, ``js_value``, ``wd_value`` and
        ``prediction_date``.
    """
    query = """
        SELECT
            dr.model_id,
            mi.model_name,
            dr.precision,
            dr.recall,
            dr.sample_numbers,
            dr.js_value,
            dr.wd_value,
            dr.prediction_date
        FROM drift_record dr
        LEFT JOIN model_info mi ON dr.model_id = mi.model_id
        ORDER BY dr.prediction_date, dr.model_id
    """
    conn = sqlite3.connect('drift_detection.db')
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()

    # Round numeric columns for better display
    for col in ('precision', 'recall', 'js_value', 'wd_value'):
        if col in df.columns:
            df[col] = df[col].round(4)
    return df
def load_model_info():
    """Load the ``model_info`` table from SQLite, ordered by model id.

    Reads ``drift_detection.db`` in the current working directory.

    Returns:
        pandas.DataFrame: Columns ``model_id``, ``model_name``,
        ``release_date`` and ``prediction_period``.
    """
    query = """
        SELECT
            model_id,
            model_name,
            release_date,
            prediction_period
        FROM model_info
        ORDER BY model_id
    """
    conn = sqlite3.connect('drift_detection.db')
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()
    return df
def split_data_by_month(df):
    """Return a copy of ``df`` annotated with the month of each row.

    Despite the name, nothing is split: the copy has ``prediction_date``
    parsed to datetimes and an added ``month`` column holding the monthly
    period of that date. The input frame is left unmodified.

    Args:
        df: DataFrame with a ``prediction_date`` column.

    Returns:
        pandas.DataFrame: The annotated copy.
    """
    parsed_dates = pd.to_datetime(df['prediction_date'])
    return df.assign(
        prediction_date=parsed_dates,
        month=parsed_dates.dt.to_period('M'),
    )
def detect_drift_ks_test(reference_data, current_data, alpha=0.05):
    """Detect drift between two samples with a Kolmogorov-Smirnov test.

    A frouros ``KSTest`` detector is fitted on ``reference_data`` and then
    compared against ``current_data``.

    Args:
        reference_data: Baseline sample (callers in this file pass the
            ``.values`` array of one metric column).
        current_data: Sample to compare against the baseline.
        alpha: Significance level; drift is reported when the p-value is
            strictly below it. Defaults to 0.05.

    Returns:
        dict: ``p_value`` and ``statistic`` as reported by the detector, and
        ``drift_detected`` (``p_value < alpha``).
    """
    detector = KSTest()
    detector.fit(X=reference_data)
    # compare() returns a pair; only the first element (the test result) is used.
    result, _ = detector.compare(X=current_data)
    return {
        'p_value': result.p_value,
        'statistic': result.statistic,
        'drift_detected': result.p_value < alpha
    }
def calculate_monthly_drift(df, metric='precision'):
    """Score each month's ``metric`` values against the earliest month.

    The first month present in the data (in sorted order) is the baseline;
    every later month is compared with it using the KS test and the
    Wasserstein distance.

    Args:
        df: DataFrame with a ``prediction_date`` column and a ``metric`` column.
        metric: Name of the column to analyse.

    Returns:
        pandas.DataFrame: One row per non-baseline month with columns
        ``month``, ``month_name``, ``ks_statistic``, ``p_value``,
        ``drift_detected``, ``wasserstein_distance`` and ``sample_size``.
        An empty, column-less DataFrame when fewer than two months exist.
    """
    monthly = split_data_by_month(df)
    ordered_months = sorted(monthly['month'].unique())
    if len(ordered_months) < 2:
        return pd.DataFrame()

    def values_for(period):
        # Metric values of all rows that fall in the given month.
        return monthly.loc[monthly['month'] == period, metric].values

    # The earliest month acts as the reference distribution.
    reference_month, *later_months = ordered_months
    reference_values = values_for(reference_month)

    rows = []
    for period in later_months:
        sample = values_for(period)
        if len(sample) == 0 or len(reference_values) == 0:
            continue
        ks_outcome = detect_drift_ks_test(reference_values, sample)
        distance = wasserstein_distance(reference_values, sample)
        rows.append({
            'month': str(period),
            'month_name': period.strftime('%Y-%m'),
            'ks_statistic': ks_outcome['statistic'],
            'p_value': ks_outcome['p_value'],
            'drift_detected': ks_outcome['drift_detected'],
            'wasserstein_distance': distance,
            'sample_size': len(sample)
        })
    return pd.DataFrame(rows)
def create_metric_chart(df, metric='precision'):
    """Create a Plotly line chart of one metric over time, one line per model.

    Weekends are shaded gray and the 2025 Korean public holidays red (with
    the holiday name as annotation). Both kinds of band run from the start of
    the first day to the start of the day after the last one, so the final day
    (Sunday, or the last holiday) is covered as well.

    Args:
        df: DataFrame with ``prediction_date``, ``model_name`` and the
            ``metric`` column.
        metric: Column to plot on the y axis.

    Returns:
        A Plotly figure; an empty titled figure when ``df`` has no rows.
    """
    if df.empty:
        return px.line(title="No data available")

    # Work on a copy so the caller's frame keeps its original date column.
    df = df.copy()
    df['prediction_date'] = pd.to_datetime(df['prediction_date'])

    # Metric display names
    metric_names = {
        'precision': 'Precision',
        'recall': 'Recall',
        'js_value': 'JS Divergence',
        'wd_value': 'Wasserstein Distance'
    }
    metric_display = metric_names.get(metric, metric.capitalize())

    fig = px.line(
        df,
        x='prediction_date',
        y=metric,
        color='model_name',
        labels={
            'prediction_date': 'Date',
            metric: metric_display,
            'model_name': 'Model'
        },
        markers=True
    )

    # Shading is limited to the date range covered by the data.
    start_date = df['prediction_date'].min()
    end_date = df['prediction_date'].max()
    one_day = timedelta(days=1)

    # Weekend shading (light gray). get_weekends() yields (Saturday, Sunday);
    # one day is added so the band also covers Sunday itself, matching the
    # holiday bands below.
    for weekend_start, weekend_end in get_weekends(start_date, end_date):
        fig.add_vrect(
            x0=weekend_start,
            x1=weekend_end + one_day,
            fillcolor="gray",
            opacity=0.2,
            layer="below",
            line_width=0,
        )

    # Korean holiday shading (light red), only for holidays overlapping the data.
    for holiday_start, holiday_end, holiday_name in get_korean_holidays_2025():
        if holiday_start <= end_date and holiday_end >= start_date:
            fig.add_vrect(
                x0=holiday_start,
                x1=holiday_end + one_day,  # Add 1 day to include the end date
                fillcolor="red",
                opacity=0.25,
                layer="below",
                line_width=0,
                annotation_text=holiday_name,
                annotation_position="top left",
                annotation=dict(font_size=10, font_color="darkred")
            )

    fig.update_layout(
        hovermode='x unified',
        xaxis_title='Date',
        yaxis_title=metric_display,
        legend_title='Model',
        height=450,
        margin=dict(t=20, b=50, l=50, r=20)
    )
    return fig
def create_drift_markers_chart(df, metric='precision'):
    """Create the metric time-series chart with drift markers.

    Starts from the chart built by ``create_metric_chart`` and adds a dashed
    red vertical line at the first day of every month for which
    ``calculate_monthly_drift`` reports drift.

    Args:
        df: Drift records as accepted by ``create_metric_chart``.
        metric: Column to plot and to test for drift.

    Returns:
        The Plotly figure with drift markers added.
    """
    drift_df = calculate_monthly_drift(df, metric)

    # Create base chart
    fig = create_metric_chart(df, metric)

    # Add drift markers for each month with drift
    if not drift_df.empty:
        for _, row in drift_df[drift_df['drift_detected']].iterrows():
            # Vertical line at the month boundary (first day of the month).
            month_start = pd.Period(row['month']).to_timestamp()
            fig.add_vline(
                # NOTE: passed as epoch milliseconds because some plotly
                # releases raise TypeError when an annotated vline gets a
                # Timestamp position; date axes accept the numeric form.
                x=month_start.timestamp() * 1000,
                line_dash="dash",
                line_color="red",
                line_width=2,
                annotation_text=f"Drift Detected<br>{row['month_name']}",
                annotation_position="top",
                annotation=dict(font_size=9, font_color="red")
            )
    return fig
def create_monthly_drift_chart(df, metric='precision'):
    """Create a chart of monthly drift scores for one metric.

    KS statistics are drawn as bars (red when drift was detected, blue
    otherwise, labelled with the p-value) and the Wasserstein distance as an
    orange line on a secondary y axis.

    Args:
        df: Drift records as accepted by ``calculate_monthly_drift``.
        metric: Column to analyse.

    Returns:
        A Plotly figure; a figure carrying only an explanatory annotation when
        no drift results are available.
    """
    drift_df = calculate_monthly_drift(df, metric)
    if drift_df.empty:
        return go.Figure().add_annotation(
            text="Not enough data for drift detection",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False
        )

    # Same display names as the time-series chart, so the title reads
    # "JS Divergence" rather than "Js_value".
    metric_names = {
        'precision': 'Precision',
        'recall': 'Recall',
        'js_value': 'JS Divergence',
        'wd_value': 'Wasserstein Distance'
    }
    metric_display = metric_names.get(metric, metric.capitalize())

    fig = go.Figure()

    # KS Statistic bars
    fig.add_trace(go.Bar(
        x=drift_df['month_name'],
        y=drift_df['ks_statistic'],
        name='KS Statistic',
        marker_color=['red' if d else 'blue' for d in drift_df['drift_detected']],
        text=[f"p={p:.4f}" for p in drift_df['p_value']],
        textposition='outside'
    ))

    # Wasserstein Distance (secondary y-axis)
    fig.add_trace(go.Scatter(
        x=drift_df['month_name'],
        y=drift_df['wasserstein_distance'],
        name='Wasserstein Distance',
        yaxis='y2',
        mode='lines+markers',
        line=dict(color='orange', width=2),
        marker=dict(size=8)
    ))

    fig.update_layout(
        title=f'Monthly Drift Detection for {metric_display}',
        xaxis_title='Month',
        yaxis_title='KS Statistic',
        yaxis2=dict(
            title='Wasserstein Distance',
            overlaying='y',
            side='right'
        ),
        height=500,
        hovermode='x unified',
        showlegend=True
    )
    return fig
def create_drift_heatmap(df):
    """Create a metric-by-month heatmap of drift intensity.

    Each cell holds ``1 - p_value`` of the KS test for that metric and month,
    so stronger evidence of drift gives a higher value. Cells without a
    result hold 0 and the hover text "No data".

    Args:
        df: Drift records as accepted by ``calculate_monthly_drift``.

    Returns:
        A Plotly figure; a figure carrying only an explanatory annotation when
        no metric produced drift results.
    """
    metrics = ['precision', 'recall', 'js_value', 'wd_value']
    metric_names = ['Precision', 'Recall', 'JS Divergence', 'WD Value']

    # Keep only the metrics that produced at least one monthly result.
    results_by_metric = {}
    for metric in metrics:
        result = calculate_monthly_drift(df, metric)
        if not result.empty:
            results_by_metric[metric] = result

    if not results_by_metric:
        return go.Figure().add_annotation(
            text="Not enough data for drift heatmap",
            xref="paper", yref="paper",
            x=0.5, y=0.5, showarrow=False
        )

    months = sorted({
        name
        for result in results_by_metric.values()
        for name in result['month_name'].values
    })

    def build_cell(result, month):
        # Returns (intensity, hover text) for one metric/month combination.
        if result is None:
            return 0, "No data"
        matches = result[result['month_name'] == month]
        if matches.empty:
            return 0, "No data"
        row = matches.iloc[0]
        # Inverted p-value: a lower p-value (more drift) gives a higher intensity.
        hover = (
            f"KS: {row['ks_statistic']:.4f}<br>"
            f"p-value: {row['p_value']:.4f}<br>"
            f"WD: {row['wasserstein_distance']:.4f}<br>"
            f"Drift: {'Yes' if row['drift_detected'] else 'No'}"
        )
        return 1 - row['p_value'], hover

    z_data = []
    hover_text = []
    for metric in metrics:
        cells = [build_cell(results_by_metric.get(metric), month) for month in months]
        z_data.append([intensity for intensity, _ in cells])
        hover_text.append([hover for _, hover in cells])

    heatmap = go.Heatmap(
        z=z_data,
        x=months,
        y=metric_names,
        colorscale='RdYlGn_r',  # Red for drift, Green for no drift
        text=hover_text,
        hovertemplate='%{y}<br>%{x}<br>%{text}<extra></extra>',
        colorbar=dict(title="Drift<br>Intensity")
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title='Drift Detection Heatmap (All Metrics)',
        xaxis_title='Month',
        yaxis_title='Metric',
        height=400
    )
    return fig
def update_chart(metric):
    """Reload the drift records and return the time-series chart for ``metric``."""
    return create_metric_chart(load_drift_data(), metric)
def update_all_drift_visualizations(metric):
    """Reload the drift records and rebuild the three drift figures.

    Args:
        metric: Column used for the marker chart and the monthly score chart
            (the heatmap always covers every metric).

    Returns:
        tuple: ``(drift markers chart, monthly drift chart, drift heatmap)``.
    """
    records = load_drift_data()
    return (
        create_drift_markers_chart(records, metric),
        create_monthly_drift_chart(records, metric),
        create_drift_heatmap(records),
    )
# Create Gradio interface.
# Layout: a metric selector on top, then tabs for the three drift charts and
# the raw tables. User-facing Korean text is stored as proper Hangul.
with gr.Blocks(title="Drift Detection Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Drift Detection Dashboard")
    gr.Markdown("모델별 메트릭 시계열 및 월별 데이터 드리프트 분석")

    with gr.Row():
        metric_dropdown = gr.Dropdown(
            choices=[
                ("Precision", "precision"),
                ("Recall", "recall"),
                ("JS Divergence", "js_value"),
                ("Wasserstein Distance", "wd_value")
            ],
            value="precision",
            label="Metric to Analyze",
            scale=1
        )

    with gr.Tabs():
        with gr.Tab("📈 Time Series + Drift Markers"):
            gr.Markdown("### 시계열 차트 (드리프트 발생 지점 표시)")
            drift_markers_plot = gr.Plot()

        with gr.Tab("📊 Monthly Drift Scores"):
            gr.Markdown("### 월별 드리프트 점수 (1월 대비)")
            monthly_drift_plot = gr.Plot()

        with gr.Tab("🔥 Drift Heatmap"):
            gr.Markdown("### 전체 메트릭 드리프트 히트맵")
            heatmap_plot = gr.Plot()

        with gr.Tab("📋 Data Tables"):
            gr.Markdown("### 원본 데이터")
            with gr.Row():
                with gr.Column(scale=2):
                    # The tables are read once, when the interface is built.
                    dataframe_output = gr.Dataframe(
                        value=load_drift_data(),
                        interactive=False,
                        wrap=True,
                        label="Drift Records"
                    )
                with gr.Column(scale=1):
                    model_info_output = gr.Dataframe(
                        value=load_model_info(),
                        interactive=False,
                        wrap=True,
                        label="Model Info"
                    )

    # Event handlers: changing the metric rebuilds all three drift figures.
    metric_dropdown.change(
        fn=update_all_drift_visualizations,
        inputs=[metric_dropdown],
        outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot]
    )

    # Load initial data
    def load_initial_data():
        """Build the three drift figures for the default metric (precision)."""
        # Delegates to the change handler so both paths stay in sync.
        return update_all_drift_visualizations('precision')

    demo.load(
        fn=load_initial_data,
        outputs=[drift_markers_plot, monthly_drift_plot, heatmap_plot]
    )

if __name__ == "__main__":
    demo.launch()