Spaces:
Sleeping
Sleeping
Yoon-gu Hwang Claude commited on
Commit Β·
93ec097
1
Parent(s): 7ab1194
Change visualization from bar chart to scatter plot
Browse files- Replace bar chart with scatter plot using square markers
- Simplify data distribution visualization
- Improve color scheme for better visibility
π€ Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- visualizer.py +50 -42
visualizer.py
CHANGED
|
@@ -9,19 +9,24 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
|
|
| 9 |
|
| 10 |
# incremental driftλ μ°μ κ°μΌλ‘ μ²λ¦¬
|
| 11 |
if drift_type == "incremental":
|
| 12 |
-
#
|
| 13 |
colors = []
|
| 14 |
for val in y:
|
| 15 |
# νλμ(0)μμ μ΄λ‘μ(1)λ‘ μ μ§μ λ³ν
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
-
|
|
|
|
| 21 |
x=X,
|
| 22 |
y=np.ones(len(X)),
|
|
|
|
| 23 |
marker=dict(
|
| 24 |
color=colors,
|
|
|
|
|
|
|
| 25 |
line=dict(width=0)
|
| 26 |
),
|
| 27 |
showlegend=False,
|
|
@@ -30,26 +35,27 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
|
|
| 30 |
))
|
| 31 |
else:
|
| 32 |
# μ΄μ§ λΆλ₯ (0: νλμ, 1: μ΄λ‘μ)
|
| 33 |
-
class_0_indices = np.where(y == 0)[0]
|
| 34 |
-
class_1_indices = np.where(y == 1)[0]
|
| 35 |
-
|
| 36 |
# Class 0 (νλμ)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
name='Class 0',
|
| 43 |
showlegend=True,
|
| 44 |
hovertemplate='Time: %{x}<br>Class: 0<extra></extra>'
|
| 45 |
))
|
| 46 |
|
| 47 |
# Class 1 (μ΄λ‘μ)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
name='Class 1',
|
| 54 |
showlegend=True,
|
| 55 |
hovertemplate='Time: %{x}<br>Class: 1<extra></extra>'
|
|
@@ -68,7 +74,7 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
|
|
| 68 |
text=title_map.get(drift_type, "Concept Drift"),
|
| 69 |
x=0.5,
|
| 70 |
xanchor='center',
|
| 71 |
-
font=dict(size=20
|
| 72 |
),
|
| 73 |
xaxis_title="Time",
|
| 74 |
yaxis_title="Data distribution",
|
|
@@ -84,15 +90,14 @@ def create_drift_visualization(X: np.ndarray, y: np.ndarray, drift_points: np.nd
|
|
| 84 |
),
|
| 85 |
xaxis=dict(
|
| 86 |
showgrid=False,
|
| 87 |
-
showticklabels=
|
| 88 |
),
|
| 89 |
yaxis=dict(
|
| 90 |
showgrid=False,
|
| 91 |
showticklabels=False,
|
| 92 |
-
range=[0, 1.
|
| 93 |
),
|
| 94 |
-
plot_bgcolor='white'
|
| 95 |
-
bargap=0
|
| 96 |
)
|
| 97 |
|
| 98 |
return fig
|
|
@@ -120,41 +125,45 @@ def create_comparison_visualization(drift_data_dict: dict) -> go.Figure:
|
|
| 120 |
# Incremental drift: μ°μ μμ λ³ν
|
| 121 |
colors = []
|
| 122 |
for val in y:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
fig.add_trace(
|
| 128 |
-
go.
|
| 129 |
x=X,
|
| 130 |
y=np.ones(len(X)),
|
| 131 |
-
|
|
|
|
| 132 |
showlegend=False
|
| 133 |
),
|
| 134 |
row=row, col=col
|
| 135 |
)
|
| 136 |
else:
|
| 137 |
# μ΄μ§ λΆλ₯
|
| 138 |
-
|
| 139 |
-
|
| 140 |
|
| 141 |
-
if
|
| 142 |
fig.add_trace(
|
| 143 |
-
go.
|
| 144 |
-
x=X[
|
| 145 |
-
y=np.ones(
|
| 146 |
-
|
|
|
|
| 147 |
showlegend=False
|
| 148 |
),
|
| 149 |
row=row, col=col
|
| 150 |
)
|
| 151 |
|
| 152 |
-
if
|
| 153 |
fig.add_trace(
|
| 154 |
-
go.
|
| 155 |
-
x=X[
|
| 156 |
-
y=np.ones(
|
| 157 |
-
|
|
|
|
| 158 |
showlegend=False
|
| 159 |
),
|
| 160 |
row=row, col=col
|
|
@@ -162,12 +171,11 @@ def create_comparison_visualization(drift_data_dict: dict) -> go.Figure:
|
|
| 162 |
|
| 163 |
# λ μ΄μμ μ€μ
|
| 164 |
fig.update_xaxes(title_text="Time", showgrid=False, showticklabels=False)
|
| 165 |
-
fig.update_yaxes(title_text="Data distribution", showgrid=False, showticklabels=False, range=[0, 1.
|
| 166 |
fig.update_layout(
|
| 167 |
height=800,
|
| 168 |
title_text="Concept Drift Types Comparison",
|
| 169 |
showlegend=False,
|
| 170 |
-
bargap=0,
|
| 171 |
template='plotly_white'
|
| 172 |
)
|
| 173 |
|
|
|
|
| 9 |
|
| 10 |
# incremental driftλ μ°μ κ°μΌλ‘ μ²λ¦¬
|
| 11 |
if drift_type == "incremental":
|
| 12 |
+
# μμ 리μ€νΈ μμ±
|
| 13 |
colors = []
|
| 14 |
for val in y:
|
| 15 |
# νλμ(0)μμ μ΄λ‘μ(1)λ‘ μ μ§μ λ³ν
|
| 16 |
+
r = int(50 + 50 * val)
|
| 17 |
+
g = int(100 + 79 * val)
|
| 18 |
+
b = int(200 - 87 * val)
|
| 19 |
+
colors.append(f'rgb({r}, {g}, {b})')
|
| 20 |
|
| 21 |
+
# Scatterλ‘ νν (μμ μμ)
|
| 22 |
+
fig.add_trace(go.Scatter(
|
| 23 |
x=X,
|
| 24 |
y=np.ones(len(X)),
|
| 25 |
+
mode='markers',
|
| 26 |
marker=dict(
|
| 27 |
color=colors,
|
| 28 |
+
size=10,
|
| 29 |
+
symbol='square',
|
| 30 |
line=dict(width=0)
|
| 31 |
),
|
| 32 |
showlegend=False,
|
|
|
|
| 35 |
))
|
| 36 |
else:
|
| 37 |
# μ΄μ§ λΆλ₯ (0: νλμ, 1: μ΄λ‘μ)
|
|
|
|
|
|
|
|
|
|
| 38 |
# Class 0 (νλμ)
|
| 39 |
+
class_0_mask = y == 0
|
| 40 |
+
if np.any(class_0_mask):
|
| 41 |
+
fig.add_trace(go.Scatter(
|
| 42 |
+
x=X[class_0_mask],
|
| 43 |
+
y=np.ones(np.sum(class_0_mask)),
|
| 44 |
+
mode='markers',
|
| 45 |
+
marker=dict(color='rgb(65, 105, 225)', size=10, symbol='square', line=dict(width=0)),
|
| 46 |
name='Class 0',
|
| 47 |
showlegend=True,
|
| 48 |
hovertemplate='Time: %{x}<br>Class: 0<extra></extra>'
|
| 49 |
))
|
| 50 |
|
| 51 |
# Class 1 (μ΄λ‘μ)
|
| 52 |
+
class_1_mask = y == 1
|
| 53 |
+
if np.any(class_1_mask):
|
| 54 |
+
fig.add_trace(go.Scatter(
|
| 55 |
+
x=X[class_1_mask],
|
| 56 |
+
y=np.ones(np.sum(class_1_mask)),
|
| 57 |
+
mode='markers',
|
| 58 |
+
marker=dict(color='rgb(50, 205, 50)', size=10, symbol='square', line=dict(width=0)),
|
| 59 |
name='Class 1',
|
| 60 |
showlegend=True,
|
| 61 |
hovertemplate='Time: %{x}<br>Class: 1<extra></extra>'
|
|
|
|
| 74 |
text=title_map.get(drift_type, "Concept Drift"),
|
| 75 |
x=0.5,
|
| 76 |
xanchor='center',
|
| 77 |
+
font=dict(size=20)
|
| 78 |
),
|
| 79 |
xaxis_title="Time",
|
| 80 |
yaxis_title="Data distribution",
|
|
|
|
| 90 |
),
|
| 91 |
xaxis=dict(
|
| 92 |
showgrid=False,
|
| 93 |
+
showticklabels=True
|
| 94 |
),
|
| 95 |
yaxis=dict(
|
| 96 |
showgrid=False,
|
| 97 |
showticklabels=False,
|
| 98 |
+
range=[0.5, 1.5]
|
| 99 |
),
|
| 100 |
+
plot_bgcolor='white'
|
|
|
|
| 101 |
)
|
| 102 |
|
| 103 |
return fig
|
|
|
|
| 125 |
# Incremental drift: μ°μ μμ λ³ν
|
| 126 |
colors = []
|
| 127 |
for val in y:
|
| 128 |
+
r = int(50 + 50 * val)
|
| 129 |
+
g = int(100 + 79 * val)
|
| 130 |
+
b = int(200 - 87 * val)
|
| 131 |
+
colors.append(f'rgb({r}, {g}, {b})')
|
| 132 |
|
| 133 |
fig.add_trace(
|
| 134 |
+
go.Scatter(
|
| 135 |
x=X,
|
| 136 |
y=np.ones(len(X)),
|
| 137 |
+
mode='markers',
|
| 138 |
+
marker=dict(color=colors, size=5, symbol='square', line=dict(width=0)),
|
| 139 |
showlegend=False
|
| 140 |
),
|
| 141 |
row=row, col=col
|
| 142 |
)
|
| 143 |
else:
|
| 144 |
# μ΄μ§ λΆλ₯
|
| 145 |
+
class_0_mask = y == 0
|
| 146 |
+
class_1_mask = y == 1
|
| 147 |
|
| 148 |
+
if np.any(class_0_mask):
|
| 149 |
fig.add_trace(
|
| 150 |
+
go.Scatter(
|
| 151 |
+
x=X[class_0_mask],
|
| 152 |
+
y=np.ones(np.sum(class_0_mask)),
|
| 153 |
+
mode='markers',
|
| 154 |
+
marker=dict(color='rgb(65, 105, 225)', size=5, symbol='square', line=dict(width=0)),
|
| 155 |
showlegend=False
|
| 156 |
),
|
| 157 |
row=row, col=col
|
| 158 |
)
|
| 159 |
|
| 160 |
+
if np.any(class_1_mask):
|
| 161 |
fig.add_trace(
|
| 162 |
+
go.Scatter(
|
| 163 |
+
x=X[class_1_mask],
|
| 164 |
+
y=np.ones(np.sum(class_1_mask)),
|
| 165 |
+
mode='markers',
|
| 166 |
+
marker=dict(color='rgb(50, 205, 50)', size=5, symbol='square', line=dict(width=0)),
|
| 167 |
showlegend=False
|
| 168 |
),
|
| 169 |
row=row, col=col
|
|
|
|
| 171 |
|
| 172 |
# λ μ΄μμ μ€μ
|
| 173 |
fig.update_xaxes(title_text="Time", showgrid=False, showticklabels=False)
|
| 174 |
+
fig.update_yaxes(title_text="Data distribution", showgrid=False, showticklabels=False, range=[0.5, 1.5])
|
| 175 |
fig.update_layout(
|
| 176 |
height=800,
|
| 177 |
title_text="Concept Drift Types Comparison",
|
| 178 |
showlegend=False,
|
|
|
|
| 179 |
template='plotly_white'
|
| 180 |
)
|
| 181 |
|