Yoon-gu Hwang Claude commited on
Commit
9c2cd7f
ยท
1 Parent(s): bd359c3

Add 2-row layout with drift classification bar graph

Browse files

- Split visualization into 2 rows using subplots
- Top row: Time series data with background segment colors
- Bottom row: Drift detection bar graph (20-sample windows)
- Red bars indicate drift detected, green bars indicate no drift
- Improves understanding of drift occurrence patterns

๐Ÿค– Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. visualizer.py +55 -18
visualizer.py CHANGED
@@ -1,11 +1,18 @@
1
  import plotly.graph_objects as go
2
  import numpy as np
3
  from typing import Tuple
 
4
 
5
  def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.ndarray, drift_type: str) -> go.Figure:
6
  """๋“œ๋ฆฌํ”„ํŠธ ๋ฐ์ดํ„ฐ๋ฅผ Plotly๋กœ ์‹œ๊ฐํ™”"""
7
 
8
- fig = go.Figure()
 
 
 
 
 
 
9
 
10
  # Y์ถ• ๋ฒ”์œ„ ๊ณ„์‚ฐ
11
  y_min, y_max = y.min(), y.max()
@@ -13,6 +20,7 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
13
  y_plot_min = y_min - y_range * 0.1
14
  y_plot_max = y_max + y_range * 0.1
15
 
 
16
  # ๋ฐฐ๊ฒฝ: drift ๊ตฌ๊ฐ„์„ bar graph๋กœ ํ‘œ์‹œ
17
  segment_boundaries = [0] + drift_points.tolist() + [len(X)]
18
  colors = ['rgba(100, 150, 255, 0.15)', 'rgba(255, 150, 100, 0.15)', 'rgba(150, 255, 150, 0.15)',
@@ -34,7 +42,7 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
34
  name=f'Segment {i+1}',
35
  showlegend=False,
36
  hoverinfo='skip'
37
- ))
38
 
39
  # ๋ฉ”์ธ ๋ผ์ธ ๊ทธ๋ž˜ํ”„
40
  fig.add_trace(go.Scatter(
@@ -45,7 +53,7 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
45
  line=dict(color='rgb(50, 100, 180)', width=2),
46
  marker=dict(size=4, color='rgb(50, 100, 180)'),
47
  hovertemplate='Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
48
- ))
49
 
50
  # ๋“œ๋ฆฌํ”„ํŠธ ๋ฐœ์ƒ ์ง€์  ํ‘œ์‹œ
51
  for i, drift_point in enumerate(drift_points):
@@ -55,9 +63,41 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
55
  line_color="red",
56
  line_width=2,
57
  annotation_text=f"Drift {i+1}",
58
- annotation_position="top"
 
59
  )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # ๋ ˆ์ด์•„์›ƒ ์„ค์ •
62
  title_map = {
63
  "sudden": "Sudden Drift",
@@ -73,11 +113,9 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
73
  xanchor='center',
74
  font=dict(size=20)
75
  ),
76
- xaxis_title="Time",
77
- yaxis_title="Value",
78
  hovermode='closest',
79
  template='plotly_white',
80
- height=500,
81
  showlegend=True,
82
  legend=dict(
83
  yanchor="top",
@@ -85,22 +123,21 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
85
  xanchor="right",
86
  x=0.99
87
  ),
88
- xaxis=dict(
89
- showgrid=True,
90
- gridwidth=1,
91
- gridcolor='LightGray'
92
- ),
93
- yaxis=dict(
94
- showgrid=True,
95
- gridwidth=1,
96
- gridcolor='LightGray',
97
- range=[y_plot_min, y_plot_max]
98
- ),
99
  plot_bgcolor='white',
100
  barmode='overlay',
101
  bargap=0
102
  )
103
 
 
 
 
 
 
 
 
 
 
 
104
  return fig
105
 
106
 
 
1
  import plotly.graph_objects as go
2
  import numpy as np
3
  from typing import Tuple
4
+ from plotly.subplots import make_subplots
5
 
6
  def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.ndarray, drift_type: str) -> go.Figure:
7
  """๋“œ๋ฆฌํ”„ํŠธ ๋ฐ์ดํ„ฐ๋ฅผ Plotly๋กœ ์‹œ๊ฐํ™”"""
8
 
9
+ # 2๊ฐœ์˜ row๋กœ subplot ์ƒ์„ฑ
10
+ fig = make_subplots(
11
+ rows=2, cols=1,
12
+ row_heights=[0.7, 0.3],
13
+ vertical_spacing=0.1,
14
+ subplot_titles=("Time Series Data", "Drift Detection (20-sample windows)")
15
+ )
16
 
17
  # Y์ถ• ๋ฒ”์œ„ ๊ณ„์‚ฐ
18
  y_min, y_max = y.min(), y.max()
 
20
  y_plot_min = y_min - y_range * 0.1
21
  y_plot_max = y_max + y_range * 0.1
22
 
23
+ # === ์ฒซ ๋ฒˆ์งธ row: ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ ===
24
  # ๋ฐฐ๊ฒฝ: drift ๊ตฌ๊ฐ„์„ bar graph๋กœ ํ‘œ์‹œ
25
  segment_boundaries = [0] + drift_points.tolist() + [len(X)]
26
  colors = ['rgba(100, 150, 255, 0.15)', 'rgba(255, 150, 100, 0.15)', 'rgba(150, 255, 150, 0.15)',
 
42
  name=f'Segment {i+1}',
43
  showlegend=False,
44
  hoverinfo='skip'
45
+ ), row=1, col=1)
46
 
47
  # ๋ฉ”์ธ ๋ผ์ธ ๊ทธ๋ž˜ํ”„
48
  fig.add_trace(go.Scatter(
 
53
  line=dict(color='rgb(50, 100, 180)', width=2),
54
  marker=dict(size=4, color='rgb(50, 100, 180)'),
55
  hovertemplate='Time: %{x}<br>Value: %{y:.2f}<extra></extra>'
56
+ ), row=1, col=1)
57
 
58
  # ๋“œ๋ฆฌํ”„ํŠธ ๋ฐœ์ƒ ์ง€์  ํ‘œ์‹œ
59
  for i, drift_point in enumerate(drift_points):
 
63
  line_color="red",
64
  line_width=2,
65
  annotation_text=f"Drift {i+1}",
66
+ annotation_position="top",
67
+ row=1, col=1
68
  )
69
 
70
+ # === ๋‘ ๋ฒˆ์งธ row: Drift Classification ===
71
+ window_size = 20
72
+ n_windows = len(X) // window_size
73
+ window_centers = []
74
+ drift_detected = []
75
+
76
+ for i in range(n_windows):
77
+ start_idx = i * window_size
78
+ end_idx = (i + 1) * window_size
79
+ window_center = (start_idx + end_idx) / 2
80
+
81
+ # ์ด window์— drift point๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธ
82
+ has_drift = any(start_idx <= dp < end_idx for dp in drift_points)
83
+
84
+ window_centers.append(window_center)
85
+ drift_detected.append(1 if has_drift else 0)
86
+
87
+ # Drift detection bar graph
88
+ bar_colors = ['rgba(255, 80, 80, 0.7)' if d == 1 else 'rgba(100, 200, 100, 0.7)'
89
+ for d in drift_detected]
90
+
91
+ fig.add_trace(go.Bar(
92
+ x=window_centers,
93
+ y=drift_detected,
94
+ marker=dict(color=bar_colors, line=dict(width=0)),
95
+ name='Drift Detected',
96
+ showlegend=False,
97
+ hovertemplate='Window: %{x:.0f}<br>Drift: %{y}<extra></extra>',
98
+ width=window_size * 0.9
99
+ ), row=2, col=1)
100
+
101
  # ๋ ˆ์ด์•„์›ƒ ์„ค์ •
102
  title_map = {
103
  "sudden": "Sudden Drift",
 
113
  xanchor='center',
114
  font=dict(size=20)
115
  ),
 
 
116
  hovermode='closest',
117
  template='plotly_white',
118
+ height=700,
119
  showlegend=True,
120
  legend=dict(
121
  yanchor="top",
 
123
  xanchor="right",
124
  x=0.99
125
  ),
 
 
 
 
 
 
 
 
 
 
 
126
  plot_bgcolor='white',
127
  barmode='overlay',
128
  bargap=0
129
  )
130
 
131
+ # ์ฒซ ๋ฒˆ์งธ subplot ์ถ• ์„ค์ •
132
+ fig.update_xaxes(title_text="Time", showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=1)
133
+ fig.update_yaxes(title_text="Value", showgrid=True, gridwidth=1, gridcolor='LightGray',
134
+ range=[y_plot_min, y_plot_max], row=1, col=1)
135
+
136
+ # ๋‘ ๋ฒˆ์งธ subplot ์ถ• ์„ค์ •
137
+ fig.update_xaxes(title_text="Time", showgrid=True, gridwidth=1, gridcolor='LightGray', row=2, col=1)
138
+ fig.update_yaxes(title_text="Drift", showgrid=True, gridwidth=1, gridcolor='LightGray',
139
+ range=[-0.2, 1.5], tickvals=[0, 1], ticktext=['No Drift', 'Drift'], row=2, col=1)
140
+
141
  return fig
142
 
143