Spaces:
Sleeping
Sleeping
| import plotly.graph_objects as go | |
| import numpy as np | |
| from typing import Tuple | |
| from plotly.subplots import make_subplots | |
def _detect_window_drift(
    y: np.ndarray,
    drift_points: np.ndarray,
    n_samples: int,
    window_size: int,
    alpha: float,
) -> Tuple[list, list]:
    """Flag non-overlapping fixed-size windows of ``y`` whose distribution changed.

    The series is cut into ``n_samples // window_size`` windows; trailing samples
    that do not fill a whole window are ignored. Window 0 is the initial
    reference and is never flagged. Every later window is compared with the
    current reference using Frouros' two-sample Kolmogorov-Smirnov ``KSTest``;
    a p-value below ``alpha`` flags the window, and the flagged window then
    becomes the new reference, so later windows are compared against the most
    recently detected concept.

    If the test raises for a window, that window falls back to the ground
    truth: it is flagged when one of ``drift_points`` lies inside it. One
    ``RuntimeWarning`` reports how many windows used this fallback.

    Args:
        y: Series values.
        drift_points: Ground-truth drift indices, used only by the fallback.
        n_samples: Number of samples to split into windows.
        window_size: Samples per window (>= 1).
        alpha: Significance level for the KS test.

    Returns:
        ``(window_centers, drift_flags)``: the centre of each window in
        sample-index units, and a 0/1 flag per window (1 = drift).

    Raises:
        ImportError: if ``frouros`` is not installed.
    """
    import warnings

    from frouros.detectors.data_drift import KSTest

    n_windows = n_samples // window_size
    window_centers = []
    drift_flags = []
    reference_window = y[:window_size]
    fallback_count = 0
    last_error = None

    for i in range(n_windows):
        start_idx = i * window_size
        end_idx = start_idx + window_size
        window_data = y[start_idx:end_idx]

        if i == 0:
            # The first window is the reference itself, so it cannot drift.
            has_drift = False
        else:
            try:
                detector = KSTest()
                # KSTest is a univariate test: pass 1-D samples.
                detector.fit(X=np.ravel(reference_window))
                # compare() returns (result, callbacks_logs); result carries p_value.
                result, _ = detector.compare(X=np.ravel(window_data))
                has_drift = bool(result.p_value < alpha)
            except Exception as exc:  # best-effort: fall back to ground truth
                fallback_count += 1
                last_error = exc
                has_drift = any(start_idx <= dp < end_idx for dp in drift_points)
            # After a detected drift the new window becomes the reference, so
            # only the transition into a new concept is reported.
            if has_drift:
                reference_window = window_data

        window_centers.append((start_idx + end_idx) / 2)
        drift_flags.append(1 if has_drift else 0)

    if fallback_count:
        warnings.warn(
            f"KS drift test failed for {fallback_count} window(s) "
            f"(last error: {last_error!r}); those windows fell back to the "
            "ground-truth drift points.",
            RuntimeWarning,
            stacklevel=3,
        )
    return window_centers, drift_flags


def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.ndarray, drift_type: str,
                               *, window_size: int = 20, alpha: float = 0.05) -> go.Figure:
    """Visualize a drifting time series with Plotly.

    Row 1 shows the series as lines+markers over translucent background bands
    (one band per segment between consecutive drift points) and a dashed red
    vertical line at each drift point. Row 2 shows one bar per
    ``window_size``-sample window, orange when a KS test flagged drift in that
    window and green otherwise.

    Args:
        X: Time-axis values of the series; indexed by the entries of ``drift_points``.
        y: Series values (assumed to be the same length as ``X``).
        drift_points: Integer indices into ``X`` where drift occurs.
        drift_type: ``"sudden"``, ``"gradual"``, ``"incremental"`` or
            ``"recurring"``; selects the figure title (``"Concept Drift"`` otherwise).
        window_size: Samples per detection window (keyword-only, default 20).
        alpha: A window is flagged when its KS p-value is below this
            (keyword-only, default 0.05).

    Returns:
        The assembled ``go.Figure``.

    Raises:
        ValueError: if ``window_size`` is less than 1.
        ImportError: if ``frouros`` is not installed.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    # Two stacked subplots: the series on top, per-window drift flags below.
    fig = make_subplots(
        rows=2, cols=1,
        row_heights=[0.7, 0.3],
        vertical_spacing=0.1,
        subplot_titles=("Time Series Data", f"Drift Detection ({window_size}-sample windows)")
    )

    # Y-axis range: data range plus a 10% margin on each side. A constant series
    # has zero spread, which would collapse the axis range, so use a fixed pad.
    y_min, y_max = y.min(), y.max()
    y_range = y_max - y_min
    pad = y_range * 0.1 if y_range > 0 else 0.5
    y_plot_min = y_min - pad
    y_plot_max = y_max + pad

    # === Row 1: time series ===
    # Background: each segment between consecutive drift points is drawn as a
    # full-height translucent bar band.
    segment_boundaries = [0] + np.asarray(drift_points).tolist() + [len(X)]
    colors = ['rgba(100, 150, 255, 0.15)', 'rgba(255, 150, 100, 0.15)', 'rgba(150, 255, 150, 0.15)',
              'rgba(255, 200, 100, 0.15)', 'rgba(200, 150, 255, 0.15)', 'rgba(150, 255, 200, 0.15)']
    segment_pairs = zip(segment_boundaries[:-1], segment_boundaries[1:])
    for i, (start_idx, end_idx) in enumerate(segment_pairs):
        fig.add_trace(go.Bar(
            x=X[start_idx:end_idx],
            y=[y_plot_max - y_plot_min] * (end_idx - start_idx),
            base=y_plot_min,
            marker=dict(
                color=colors[i % len(colors)],
                line=dict(width=0)
            ),
            name=f'Segment {i+1}',
            showlegend=False,
            hoverinfo='skip'
        ), row=1, col=1)

    # Main line plot.
    fig.add_trace(go.Scatter(
        x=X,
        y=y,
        mode='lines+markers',
        name='Data',
        line=dict(color='rgb(50, 100, 180)', width=2),
        marker=dict(size=4, color='rgb(50, 100, 180)'),
        hovertemplate='Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
    ), row=1, col=1)

    # Mark each drift point with a labelled dashed line.
    for i, drift_point in enumerate(drift_points):
        fig.add_vline(
            x=X[drift_point],
            line_dash="dash",
            line_color="red",
            line_width=2,
            annotation_text=f"Drift {i+1}",
            annotation_position="top",
            row=1, col=1
        )

    # === Row 2: per-window drift classification (Frouros KS test) ===
    window_centers, drift_detected = _detect_window_drift(
        y, drift_points, len(X), window_size, alpha
    )

    # Orange = drift, green = no drift.
    bar_colors = ['rgba(255, 140, 0, 0.7)' if d == 1 else 'rgba(100, 200, 100, 0.7)'
                  for d in drift_detected]
    # NOTE(review): window centres and bar widths are in sample-index units,
    # while row 1 uses X values; the rows line up only if X is 0..n-1 — confirm.
    fig.add_trace(go.Bar(
        x=window_centers,
        # Every bar is full height; colour and hover carry the drift flag.
        y=[1] * len(drift_detected),
        customdata=drift_detected,
        marker=dict(color=bar_colors, line=dict(width=0)),
        name='Drift Detected',
        showlegend=False,
        hovertemplate='Window: %{x:.0f}<br>Drift: %{customdata}<extra></extra>',
        width=window_size * 0.9
    ), row=2, col=1)

    # Layout.
    title_map = {
        "sudden": "Sudden Drift",
        "gradual": "Gradual Drift",
        "incremental": "Incremental Drift",
        "recurring": "Reoccurring Concepts"
    }
    fig.update_layout(
        title=dict(
            text=title_map.get(drift_type, "Concept Drift"),
            x=0.5,
            xanchor='center',
            font=dict(size=20)
        ),
        hovermode='closest',
        template='plotly_white',
        height=700,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        ),
        plot_bgcolor='white',
        # Overlay the background bands instead of stacking/grouping them.
        barmode='overlay',
        bargap=0
    )

    # Axes of the first subplot.
    fig.update_xaxes(title_text="Time", showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
    fig.update_yaxes(title_text="Value", showgrid=True, gridwidth=1, gridcolor='LightGray',
                     range=[y_plot_min, y_plot_max], row=1, col=1)
    # Axes of the second subplot.
    fig.update_xaxes(title_text="Time", showgrid=True, gridwidth=1, gridcolor='LightGray', row=2, col=1)
    fig.update_yaxes(title_text="Drift", showgrid=True, gridwidth=1, gridcolor='LightGray',
                     range=[-0.2, 1.5], tickvals=[0, 1], ticktext=['No Drift', 'Drift'], row=2, col=1)
    return fig
def create_comparison_visualization(drift_data_dict: dict) -> go.Figure:
    """Compare the four concept-drift types side by side in a 2x2 grid.

    Args:
        drift_data_dict: Maps a drift-type key (``"sudden"``, ``"gradual"``,
            ``"incremental"``, ``"recurring"``) to an ``(X, y, drift_points)``
            tuple. A missing key leaves its panel empty; other keys are ignored.

    Returns:
        A Plotly figure with one panel per drift type. Each panel shows
        translucent background bands between consecutive drift points, the
        series as a line, and a dashed red vertical line at each drift point.
    """
    # Panel (row, col) for each drift type; insertion order matches the titles.
    panel_of = {
        "sudden": (1, 1),
        "gradual": (1, 2),
        "incremental": (2, 1),
        "recurring": (2, 2),
    }
    band_palette = [
        'rgba(100, 150, 255, 0.15)', 'rgba(255, 150, 100, 0.15)', 'rgba(150, 255, 150, 0.15)',
        'rgba(255, 200, 100, 0.15)', 'rgba(200, 150, 255, 0.15)', 'rgba(150, 255, 200, 0.15)',
    ]

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=("Sudden Drift", "Gradual Drift", "Incremental Drift", "Reoccurring Concepts"),
        vertical_spacing=0.15,
        horizontal_spacing=0.1,
    )

    for kind, (r, c) in panel_of.items():
        if kind not in drift_data_dict:
            continue  # no data for this type: leave the panel empty
        xs, ys, change_idx = drift_data_dict[kind]

        # Bands span the data range plus a 10% margin above and below.
        lo, hi = ys.min(), ys.max()
        spread = hi - lo
        band_base = lo - spread * 0.1
        band_height = (hi + spread * 0.1) - band_base

        # One band per segment between consecutive drift points.
        edges = [0] + change_idx.tolist() + [len(xs)]
        for seg_no, (left, right) in enumerate(zip(edges, edges[1:])):
            band = go.Bar(
                x=xs[left:right],
                y=[band_height] * (right - left),
                base=band_base,
                marker=dict(
                    color=band_palette[seg_no % len(band_palette)],
                    line=dict(width=0),
                ),
                showlegend=False,
                hoverinfo='skip',
            )
            fig.add_trace(band, row=r, col=c)

        series = go.Scatter(
            x=xs,
            y=ys,
            mode='lines',
            line=dict(color='rgb(50, 100, 180)', width=1.5),
            showlegend=False,
        )
        fig.add_trace(series, row=r, col=c)

        for idx in change_idx:
            fig.add_vline(
                x=xs[idx],
                line_dash="dash",
                line_color="red",
                line_width=1,
                row=r, col=c,
            )

    # Shared axis styling for every panel, then figure-level layout.
    fig.update_xaxes(title_text="Time", showgrid=True, gridcolor='LightGray')
    fig.update_yaxes(title_text="Value", showgrid=True, gridcolor='LightGray')
    fig.update_layout(
        height=800,
        title_text="Concept Drift Types Comparison",
        showlegend=False,
        template='plotly_white',
        barmode='overlay',
        bargap=0,
    )
    return fig