Spaces:
Sleeping
Sleeping
File size: 8,748 Bytes
3e3dc68 9c2cd7f 3e3dc68 9c2cd7f 3e3dc68 bd359c3 9c2cd7f bd359c3 9c2cd7f bd359c3 befeb85 bd359c3 befeb85 9c2cd7f befeb85 9c2cd7f befeb85 3e3dc68 b9120fa 9c2cd7f b9120fa 9c2cd7f b9120fa 9c2cd7f b9120fa 9c2cd7f 53d71a9 9c2cd7f 3e3dc68 7ab1194 3e3dc68 3ab49ae 93ec097 3e3dc68 9c2cd7f befeb85 3e3dc68 7ab1194 3ab49ae bd359c3 3e3dc68 9c2cd7f 3e3dc68 7ab1194 3e3dc68 bd359c3 3e3dc68 bd359c3 befeb85 bd359c3 befeb85 3e3dc68 7ab1194 befeb85 7ab1194 bd359c3 7ab1194 3e3dc68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
import plotly.graph_objects as go
import numpy as np
from typing import Tuple
from plotly.subplots import make_subplots
def _padded_range(values: np.ndarray, pad_fraction: float = 0.1) -> Tuple[float, float]:
    """Return ``(low, high)`` axis limits padding the data range by ``pad_fraction`` per side.

    A constant series would otherwise collapse to a zero-height range, so a
    unit span is used in that case.
    """
    low, high = float(np.min(values)), float(np.max(values))
    span = (high - low) or 1.0
    return low - span * pad_fraction, high + span * pad_fraction


def _make_ks_p_value():
    """Return a callable ``(reference, current) -> p_value`` for a two-sample KS test.

    Frouros' ``KSTest`` is used when it is installed; otherwise this falls
    back to ``scipy.stats.ks_2samp``.
    """
    try:
        from frouros.detectors.data_drift import KSTest
    except ImportError:
        from scipy import stats

        def scipy_p_value(reference: np.ndarray, current: np.ndarray) -> float:
            return float(stats.ks_2samp(reference, current).pvalue)

        return scipy_p_value

    def frouros_p_value(reference: np.ndarray, current: np.ndarray) -> float:
        detector = KSTest()
        # Frouros' KSTest compares univariate samples, so pass 1-D arrays.
        detector.fit(X=reference)
        output = detector.compare(X=current)
        # ``compare`` returns ``(result, callback_logs)``; also accept a bare
        # result object in case a different Frouros version returns one.
        result = output if hasattr(output, "p_value") else output[0]
        return float(result.p_value)

    return frouros_p_value


def _detect_window_drift(
    y: np.ndarray,
    drift_points: np.ndarray,
    window_size: int,
    alpha: float,
    n_samples: int,
) -> Tuple[list, list]:
    """Run a KS drift test over consecutive, non-overlapping windows of ``y``.

    Window ``i`` covers ``y[i*window_size:(i+1)*window_size]`` for the first
    ``n_samples // window_size`` windows; a trailing partial window is skipped.
    Window 0 is the initial reference and is never flagged. Every later window
    is compared with the current reference. When its p-value is below
    ``alpha``, it is flagged and becomes the new reference, so drift is
    reported at the transition and later windows are judged against the new
    concept. If the test itself raises, the window is flagged iff it contains
    one of the known ``drift_points``.

    Returns:
        ``(window_centers, flags)`` where ``flags[i]`` is 1 for drift, else 0.
    """
    ks_p_value = _make_ks_p_value()
    reference = y[:window_size]
    window_centers, flags = [], []
    for i in range(n_samples // window_size):
        start, end = i * window_size, (i + 1) * window_size
        window = y[start:end]
        drifted = False
        if i > 0:
            try:
                drifted = ks_p_value(reference, window) < alpha
            except Exception:
                # Best effort: fall back to the ground-truth drift points.
                drifted = any(start <= dp < end for dp in drift_points)
            else:
                if drifted:
                    # Adopt the new concept as the reference after a detected drift.
                    reference = window
        window_centers.append((start + end) / 2)
        flags.append(1 if drifted else 0)
    return window_centers, flags


def create_drift_visualization(
    X: np.ndarray,
    y: np.ndarray,
    drift_points: np.ndarray,
    drift_type: str,
    *,
    window_size: int = 20,
    alpha: float = 0.05,
) -> go.Figure:
    """Visualize drifting time-series data with Plotly.

    The figure has two rows:

    1. The series itself. Each segment between consecutive drift points is
       shaded in a different colour, and every drift point gets a dashed red
       line.
    2. Per-window drift detection. The series is split into consecutive
       ``window_size``-sample windows, and each window is compared with a
       reference window using a two-sample Kolmogorov-Smirnov test.

    Args:
        X: Time-axis values, one per sample.
        y: Observed values, presumably the same length as ``X``.
        drift_points: Integer indices into ``X`` where drift occurs.
        drift_type: ``"sudden"``, ``"gradual"``, ``"incremental"`` or
            ``"recurring"`` selects the figure title; anything else gives
            ``"Concept Drift"``.
        window_size: Number of samples per detection window (must be >= 1).
        alpha: Significance level; a window is flagged when the KS p-value
            is below it.

    Returns:
        The assembled Plotly figure.

    Raises:
        ValueError: If ``window_size`` is less than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    y = np.asarray(y)
    # Accept plain lists as well as arrays; indices must be integers.
    drift_points = np.asarray(drift_points, dtype=int)

    # Two stacked rows: the series on top, the detection band below.
    fig = make_subplots(
        rows=2, cols=1,
        row_heights=[0.7, 0.3],
        vertical_spacing=0.1,
        subplot_titles=("Time Series Data", f"Drift Detection ({window_size}-sample windows)"),
    )

    y_plot_min, y_plot_max = _padded_range(y)

    # === Row 1: time series ===
    # Background: shade each drift segment with full-height translucent bars.
    segment_boundaries = [0, *drift_points.tolist(), len(X)]
    colors = ['rgba(100, 150, 255, 0.15)', 'rgba(255, 150, 100, 0.15)', 'rgba(150, 255, 150, 0.15)',
              'rgba(255, 200, 100, 0.15)', 'rgba(200, 150, 255, 0.15)', 'rgba(150, 255, 200, 0.15)']
    segments = zip(segment_boundaries, segment_boundaries[1:])
    for i, (start_idx, end_idx) in enumerate(segments):
        fig.add_trace(go.Bar(
            x=X[start_idx:end_idx],
            y=[y_plot_max - y_plot_min] * (end_idx - start_idx),
            base=y_plot_min,
            marker=dict(
                color=colors[i % len(colors)],
                line=dict(width=0),
            ),
            name=f'Segment {i+1}',
            showlegend=False,
            hoverinfo='skip',
        ), row=1, col=1)

    # Main line plot.
    fig.add_trace(go.Scatter(
        x=X,
        y=y,
        mode='lines+markers',
        name='Data',
        line=dict(color='rgb(50, 100, 180)', width=2),
        marker=dict(size=4, color='rgb(50, 100, 180)'),
        hovertemplate='Time: %{x}<br>Value: %{y:.2f}<extra></extra>',
    ), row=1, col=1)

    # Mark each drift point.
    for i, drift_point in enumerate(drift_points):
        fig.add_vline(
            x=X[drift_point],
            line_dash="dash",
            line_color="red",
            line_width=2,
            annotation_text=f"Drift {i+1}",
            annotation_position="top",
            row=1, col=1,
        )

    # === Row 2: windowed drift detection ===
    window_centers, drift_detected = _detect_window_drift(
        y, drift_points, window_size, alpha, n_samples=len(X)
    )

    # Orange = drift, green = no drift.
    bar_colors = ['rgba(255, 140, 0, 0.7)' if d == 1 else 'rgba(100, 200, 100, 0.7)'
                  for d in drift_detected]
    fig.add_trace(go.Bar(
        x=window_centers,
        # Every bar fills the band; the colour encodes the status, and the
        # real 0/1 flag travels in customdata so the hover reports it.
        y=[1] * len(drift_detected),
        customdata=drift_detected,
        marker=dict(color=bar_colors, line=dict(width=0)),
        name='Drift Detected',
        showlegend=False,
        hovertemplate='Window: %{x:.0f}<br>Drift: %{customdata}<extra></extra>',
        width=window_size * 0.9,
    ), row=2, col=1)

    # Layout.
    title_map = {
        "sudden": "Sudden Drift",
        "gradual": "Gradual Drift",
        "incremental": "Incremental Drift",
        "recurring": "Reoccurring Concepts",
    }
    fig.update_layout(
        title=dict(
            text=title_map.get(drift_type, "Concept Drift"),
            x=0.5,
            xanchor='center',
            font=dict(size=20),
        ),
        hovermode='closest',
        template='plotly_white',
        height=700,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
        ),
        plot_bgcolor='white',
        # Overlay lets the segment bars sit behind the line without stacking.
        barmode='overlay',
        bargap=0,
    )

    # Row 1 axes.
    fig.update_xaxes(title_text="Time", showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
    fig.update_yaxes(title_text="Value", showgrid=True, gridwidth=1, gridcolor='LightGray',
                     range=[y_plot_min, y_plot_max], row=1, col=1)

    # Row 2 axes.
    fig.update_xaxes(title_text="Time", showgrid=True, gridwidth=1, gridcolor='LightGray', row=2, col=1)
    fig.update_yaxes(title_text="Drift", showgrid=True, gridwidth=1, gridcolor='LightGray',
                     range=[-0.2, 1.5], tickvals=[0, 1], ticktext=['No Drift', 'Drift'], row=2, col=1)

    return fig
def create_comparison_visualization(drift_data_dict: dict) -> go.Figure:
    """Compare the four drift types side by side in a 2x2 grid.

    Each panel shows the series as a line. Segments between drift points are
    shaded in alternating colours, and each drift point gets a dashed red line.

    Args:
        drift_data_dict: Maps a drift type name (``"sudden"``, ``"gradual"``,
            ``"incremental"``, ``"recurring"``) to an ``(X, y, drift_points)``
            tuple. Types missing from the dict leave their panel empty; other
            keys are ignored.

    Returns:
        A Plotly figure with one subplot per drift type.
    """
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=("Sudden Drift", "Gradual Drift", "Incremental Drift", "Reoccurring Concepts"),
        vertical_spacing=0.15,
        horizontal_spacing=0.1,
    )

    # (drift type, row, col): order matches subplot_titles above.
    panels = [
        ("sudden", 1, 1),
        ("gradual", 1, 2),
        ("incremental", 2, 1),
        ("recurring", 2, 2),
    ]
    colors = ['rgba(100, 150, 255, 0.15)', 'rgba(255, 150, 100, 0.15)', 'rgba(150, 255, 150, 0.15)',
              'rgba(255, 200, 100, 0.15)', 'rgba(200, 150, 255, 0.15)', 'rgba(150, 255, 200, 0.15)']

    for drift_type, row, col in panels:
        if drift_type not in drift_data_dict:
            continue
        X, y, drift_points = drift_data_dict[drift_type]
        y = np.asarray(y)
        # Accept plain lists as well as arrays; indices must be integers.
        drift_points = np.asarray(drift_points, dtype=int)

        # Pad the y range by 10% each side; a constant series gets a unit
        # span so the background bars and axis range are not zero-height.
        y_min, y_max = float(y.min()), float(y.max())
        y_range = (y_max - y_min) or 1.0
        y_plot_min = y_min - y_range * 0.1
        y_plot_max = y_max + y_range * 0.1

        # Background: shade each drift segment with full-height translucent bars.
        segment_boundaries = [0, *drift_points.tolist(), len(X)]
        segments = zip(segment_boundaries, segment_boundaries[1:])
        for i, (start_idx, end_idx) in enumerate(segments):
            fig.add_trace(
                go.Bar(
                    x=X[start_idx:end_idx],
                    y=[y_plot_max - y_plot_min] * (end_idx - start_idx),
                    base=y_plot_min,
                    marker=dict(
                        color=colors[i % len(colors)],
                        line=dict(width=0),
                    ),
                    showlegend=False,
                    hoverinfo='skip',
                ),
                row=row, col=col,
            )

        # Line plot of the series.
        fig.add_trace(
            go.Scatter(
                x=X,
                y=y,
                mode='lines',
                line=dict(color='rgb(50, 100, 180)', width=1.5),
                showlegend=False,
            ),
            row=row, col=col,
        )

        # Mark each drift point.
        for drift_point in drift_points:
            fig.add_vline(
                x=X[drift_point],
                line_dash="dash",
                line_color="red",
                line_width=1,
                row=row, col=col,
            )

    # Layout (axis settings apply to every subplot).
    fig.update_xaxes(title_text="Time", showgrid=True, gridcolor='LightGray')
    fig.update_yaxes(title_text="Value", showgrid=True, gridcolor='LightGray')
    fig.update_layout(
        height=800,
        title_text="Concept Drift Types Comparison",
        showlegend=False,
        template='plotly_white',
        # Overlay lets the segment bars sit behind the line without stacking.
        barmode='overlay',
        bargap=0,
    )

    return fig
|