Yoon-gu Hwang Claude commited on
Commit
93ec097
Β·
1 Parent(s): 7ab1194

Change visualization from bar chart to scatter plot

Browse files

- Replace bar chart with scatter plot using square markers
- Simplify data distribution visualization
- Improve color scheme for better visibility

πŸ€– Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. visualizer.py +50 -42
visualizer.py CHANGED
@@ -9,19 +9,24 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
9
 
10
  # incremental driftλŠ” 연속 κ°’μœΌλ‘œ 처리
11
  if drift_type == "incremental":
12
- # 0-1 사이 값을 μƒ‰μƒμœΌλ‘œ λ§€ν•‘
13
  colors = []
14
  for val in y:
15
  # νŒŒλž€μƒ‰(0)μ—μ„œ μ΄ˆλ‘μƒ‰(1)둜 점진적 λ³€ν™”
16
- blue = int(255 * (1 - val))
17
- green = int(255 * val)
18
- colors.append(f'rgb({blue}, {green}, 150)')
 
19
 
20
- fig.add_trace(go.Bar(
 
21
  x=X,
22
  y=np.ones(len(X)),
 
23
  marker=dict(
24
  color=colors,
 
 
25
  line=dict(width=0)
26
  ),
27
  showlegend=False,
@@ -30,26 +35,27 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
30
  ))
31
  else:
32
  # 이진 λΆ„λ₯˜ (0: νŒŒλž€μƒ‰, 1: μ΄ˆλ‘μƒ‰)
33
- class_0_indices = np.where(y == 0)[0]
34
- class_1_indices = np.where(y == 1)[0]
35
-
36
  # Class 0 (νŒŒλž€μƒ‰)
37
- if len(class_0_indices) > 0:
38
- fig.add_trace(go.Bar(
39
- x=X[class_0_indices],
40
- y=np.ones(len(class_0_indices)),
41
- marker=dict(color='rgb(70, 130, 180)', line=dict(width=0)),
 
 
42
  name='Class 0',
43
  showlegend=True,
44
  hovertemplate='Time: %{x}<br>Class: 0<extra></extra>'
45
  ))
46
 
47
  # Class 1 (μ΄ˆλ‘μƒ‰)
48
- if len(class_1_indices) > 0:
49
- fig.add_trace(go.Bar(
50
- x=X[class_1_indices],
51
- y=np.ones(len(class_1_indices)),
52
- marker=dict(color='rgb(60, 179, 113)', line=dict(width=0)),
 
 
53
  name='Class 1',
54
  showlegend=True,
55
  hovertemplate='Time: %{x}<br>Class: 1<extra></extra>'
@@ -68,7 +74,7 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
68
  text=title_map.get(drift_type, "Concept Drift"),
69
  x=0.5,
70
  xanchor='center',
71
- font=dict(size=20, weight='bold')
72
  ),
73
  xaxis_title="Time",
74
  yaxis_title="Data distribution",
@@ -84,15 +90,14 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
84
  ),
85
  xaxis=dict(
86
  showgrid=False,
87
- showticklabels=False
88
  ),
89
  yaxis=dict(
90
  showgrid=False,
91
  showticklabels=False,
92
- range=[0, 1.2]
93
  ),
94
- plot_bgcolor='white',
95
- bargap=0
96
  )
97
 
98
  return fig
@@ -120,41 +125,45 @@ def create_comparison_visualization(drift_data_dict: dict) -> go.Figure:
120
  # Incremental drift: 연속 색상 λ³€ν™”
121
  colors = []
122
  for val in y:
123
- blue = int(255 * (1 - val))
124
- green = int(255 * val)
125
- colors.append(f'rgb({blue}, {green}, 150)')
 
126
 
127
  fig.add_trace(
128
- go.Bar(
129
  x=X,
130
  y=np.ones(len(X)),
131
- marker=dict(color=colors, line=dict(width=0)),
 
132
  showlegend=False
133
  ),
134
  row=row, col=col
135
  )
136
  else:
137
  # 이진 λΆ„λ₯˜
138
- class_0_indices = np.where(y == 0)[0]
139
- class_1_indices = np.where(y == 1)[0]
140
 
141
- if len(class_0_indices) > 0:
142
  fig.add_trace(
143
- go.Bar(
144
- x=X[class_0_indices],
145
- y=np.ones(len(class_0_indices)),
146
- marker=dict(color='rgb(70, 130, 180)', line=dict(width=0)),
 
147
  showlegend=False
148
  ),
149
  row=row, col=col
150
  )
151
 
152
- if len(class_1_indices) > 0:
153
  fig.add_trace(
154
- go.Bar(
155
- x=X[class_1_indices],
156
- y=np.ones(len(class_1_indices)),
157
- marker=dict(color='rgb(60, 179, 113)', line=dict(width=0)),
 
158
  showlegend=False
159
  ),
160
  row=row, col=col
@@ -162,12 +171,11 @@ def create_comparison_visualization(drift_data_dict: dict) -> go.Figure:
162
 
163
  # λ ˆμ΄μ•„μ›ƒ μ„€μ •
164
  fig.update_xaxes(title_text="Time", showgrid=False, showticklabels=False)
165
- fig.update_yaxes(title_text="Data distribution", showgrid=False, showticklabels=False, range=[0, 1.2])
166
  fig.update_layout(
167
  height=800,
168
  title_text="Concept Drift Types Comparison",
169
  showlegend=False,
170
- bargap=0,
171
  template='plotly_white'
172
  )
173
 
 
9
 
10
  # incremental driftλŠ” 연속 κ°’μœΌλ‘œ 처리
11
  if drift_type == "incremental":
12
+ # 색상 리슀트 생성
13
  colors = []
14
  for val in y:
15
  # νŒŒλž€μƒ‰(0)μ—μ„œ μ΄ˆλ‘μƒ‰(1)둜 점진적 λ³€ν™”
16
+ r = int(50 + 50 * val)
17
+ g = int(100 + 79 * val)
18
+ b = int(200 - 87 * val)
19
+ colors.append(f'rgb({r}, {g}, {b})')
20
 
21
+ # Scatter둜 ν‘œν˜„ (색상 μ˜μ—­)
22
+ fig.add_trace(go.Scatter(
23
  x=X,
24
  y=np.ones(len(X)),
25
+ mode='markers',
26
  marker=dict(
27
  color=colors,
28
+ size=10,
29
+ symbol='square',
30
  line=dict(width=0)
31
  ),
32
  showlegend=False,
 
35
  ))
36
  else:
37
  # 이진 λΆ„λ₯˜ (0: νŒŒλž€μƒ‰, 1: μ΄ˆλ‘μƒ‰)
 
 
 
38
  # Class 0 (νŒŒλž€μƒ‰)
39
+ class_0_mask = y == 0
40
+ if np.any(class_0_mask):
41
+ fig.add_trace(go.Scatter(
42
+ x=X[class_0_mask],
43
+ y=np.ones(np.sum(class_0_mask)),
44
+ mode='markers',
45
+ marker=dict(color='rgb(65, 105, 225)', size=10, symbol='square', line=dict(width=0)),
46
  name='Class 0',
47
  showlegend=True,
48
  hovertemplate='Time: %{x}<br>Class: 0<extra></extra>'
49
  ))
50
 
51
  # Class 1 (μ΄ˆλ‘μƒ‰)
52
+ class_1_mask = y == 1
53
+ if np.any(class_1_mask):
54
+ fig.add_trace(go.Scatter(
55
+ x=X[class_1_mask],
56
+ y=np.ones(np.sum(class_1_mask)),
57
+ mode='markers',
58
+ marker=dict(color='rgb(50, 205, 50)', size=10, symbol='square', line=dict(width=0)),
59
  name='Class 1',
60
  showlegend=True,
61
  hovertemplate='Time: %{x}<br>Class: 1<extra></extra>'
 
74
  text=title_map.get(drift_type, "Concept Drift"),
75
  x=0.5,
76
  xanchor='center',
77
+ font=dict(size=20)
78
  ),
79
  xaxis_title="Time",
80
  yaxis_title="Data distribution",
 
90
  ),
91
  xaxis=dict(
92
  showgrid=False,
93
+ showticklabels=True
94
  ),
95
  yaxis=dict(
96
  showgrid=False,
97
  showticklabels=False,
98
+ range=[0.5, 1.5]
99
  ),
100
+ plot_bgcolor='white'
 
101
  )
102
 
103
  return fig
 
125
  # Incremental drift: 연속 색상 λ³€ν™”
126
  colors = []
127
  for val in y:
128
+ r = int(50 + 50 * val)
129
+ g = int(100 + 79 * val)
130
+ b = int(200 - 87 * val)
131
+ colors.append(f'rgb({r}, {g}, {b})')
132
 
133
  fig.add_trace(
134
+ go.Scatter(
135
  x=X,
136
  y=np.ones(len(X)),
137
+ mode='markers',
138
+ marker=dict(color=colors, size=5, symbol='square', line=dict(width=0)),
139
  showlegend=False
140
  ),
141
  row=row, col=col
142
  )
143
  else:
144
  # 이진 λΆ„λ₯˜
145
+ class_0_mask = y == 0
146
+ class_1_mask = y == 1
147
 
148
+ if np.any(class_0_mask):
149
  fig.add_trace(
150
+ go.Scatter(
151
+ x=X[class_0_mask],
152
+ y=np.ones(np.sum(class_0_mask)),
153
+ mode='markers',
154
+ marker=dict(color='rgb(65, 105, 225)', size=5, symbol='square', line=dict(width=0)),
155
  showlegend=False
156
  ),
157
  row=row, col=col
158
  )
159
 
160
+ if np.any(class_1_mask):
161
  fig.add_trace(
162
+ go.Scatter(
163
+ x=X[class_1_mask],
164
+ y=np.ones(np.sum(class_1_mask)),
165
+ mode='markers',
166
+ marker=dict(color='rgb(50, 205, 50)', size=5, symbol='square', line=dict(width=0)),
167
  showlegend=False
168
  ),
169
  row=row, col=col
 
171
 
172
  # λ ˆμ΄μ•„μ›ƒ μ„€μ •
173
  fig.update_xaxes(title_text="Time", showgrid=False, showticklabels=False)
174
+ fig.update_yaxes(title_text="Data distribution", showgrid=False, showticklabels=False, range=[0.5, 1.5])
175
  fig.update_layout(
176
  height=800,
177
  title_text="Concept Drift Types Comparison",
178
  showlegend=False,
 
179
  template='plotly_white'
180
  )
181